diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md index e822461059..60eb17a584 100644 --- a/.github/ISSUE_TEMPLATE/question.md +++ b/.github/ISSUE_TEMPLATE/question.md @@ -1,7 +1,7 @@ --- -name: Question -about: Question relating to MONAI -title: '' +name: Question (please use the Discussion tab) +about: https://github.com/Project-MONAI/MONAI/discussions +title: 'Please use MONAI Discussion tab for questions' labels: '' assignees: '' --- diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index f7024f1a08..e1eeb92c6b 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -12,6 +12,6 @@ A few sentences describing the changes proposed in this pull request. - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. -- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests`. +- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 52bd4fddae..44ffdb12a8 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -30,7 +30,7 @@ jobs: # This job only runs for pull request comments if: | - contains( 'madil90,Nic-Ma,wyli,', format('{0},', github.actor)) && + contains( 'Nic-Ma,wyli,pxLi,', format('{0},', github.actor)) && github.event.comment.body == '/build' steps: - name: Check if comment is issued by authorized person diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 0000000000..4e72a7dbcf --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,71 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ dev, main ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ dev ] + schedule: + - cron: '18 1 * * 0' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'cpp', 'python' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] + # Learn more about CodeQL language support at https://git.io/codeql-language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + # - name: Autobuild + # uses: github/codeql-action/autobuild@v1 + + # ℹī¸ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏ī¸ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + - name: Build + run: | + python -m pip install -r requirements-dev.txt + BUILD_MONAI=1 ./runtests.sh --build + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml new file mode 100644 index 0000000000..28349a5de5 --- /dev/null +++ b/.github/workflows/conda.yml @@ -0,0 +1,57 @@ +name: conda + +on: + schedule: + - cron: "0 3 * * *" # at 03:00 UTC + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +concurrency: + # automatically cancel the previously triggered workflows when there's a newer version + group: conda-tests-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + cron-conda: + if: github.repository == 'Project-MONAI/MONAI' + strategy: + fail-fast: false + matrix: + os: [windows-latest, macOS-latest, ubuntu-latest] + python-version: ["3.7"] + runs-on: ${{ matrix.os }} + env: + QUICKTEST: True + steps: + - if: runner.os == 'windows' + name: Config pagefile (Windows only) + uses: al-cheb/configure-pagefile-action@v1.2 + with: + minimum-size: 8 + maximum-size: 16 + disk-root: "D:" + - uses: actions/checkout@v2 + - uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + python-version: ${{ matrix.python-version }} + - name: Install env (CPU ${{ runner.os }}) + shell: bash -l {0} + run: | + conda info + conda list + conda env create --file environment-dev.yml + - if: runner.os == 'windows' + name: Windows only install + shell: bash -l {0} + run: | + conda activate monai + # this `cpuonly` and -c conda-forge is needed to reduce the paging file size on a github instance + conda install pytorch torchvision torchaudio cpuonly -c pytorch -c conda-forge + conda deactivate + - name: Test env(CPU ${{ runner.os }}) + shell: bash -l {0} + run: | + conda activate monai + $(pwd)/runtests.sh --build --unittests + conda deactivate diff --git a/.github/workflows/cron-mmar.yml b/.github/workflows/cron-mmar.yml new file mode 100644 index 0000000000..f61ba59368 --- /dev/null +++ b/.github/workflows/cron-mmar.yml @@ -0,0 +1,42 @@ +name: cron-mmar + +on: + schedule: + - cron: "0 2 * * *" # at 02:00 UTC + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +concurrency: + # automatically cancel the previously triggered workflows when there's a newer version + group: mmar-tests-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + cron-load: + if: github.repository == 'Project-MONAI/MONAI' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: cache weekly timestamp + id: pip-cache + run: echo "::set-output name=datew::$(date '+%Y-%V')" + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel + python -m pip install -r requirements-dev.txt + - name: Loading MMARs + run: | + # clean up temporary files + $(pwd)/runtests.sh --build --clean + # run tests + python -m tests.ngc_mmar_loading diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index a36cfbcdb9..b6ed2274ee 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -10,12 +10,12 @@ jobs: cron-gpu: if: github.repository == 'Project-MONAI/MONAI' container: - image: nvcr.io/nvidia/pytorch:20.03-py3 # CUDA 10.2 + image: nvcr.io/nvidia/pytorch:21.06-py3 # CUDA 11.3 options: "--gpus all" runs-on: [self-hosted, linux, x64, common] strategy: matrix: - pytorch-version: [1.5.1, 1.6.0, 1.7.1, 1.8.1, latest] + pytorch-version: [1.6.0, 1.7.1, 1.8.1, 1.9.1, latest] steps: - uses: actions/checkout@v2 - name: Install the dependencies @@ -25,14 +25,14 @@ jobs: python -m pip uninstall -y torch torchvision if [ ${{ matrix.pytorch-version }} == "latest" ]; then python -m pip install torch torchvision - elif [ ${{ matrix.pytorch-version }} == "1.5.1" ]; then - python -m pip install torch==1.5.1 torchvision==0.6.1 elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then python -m pip install torch==1.6.0 torchvision==0.7.0 elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then python -m pip install torch==1.7.1 torchvision==0.8.2 elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then python -m pip install torch==1.8.1 torchvision==0.9.1 + elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then + python -m pip install torch==1.9.1 torchvision==0.10.1 fi python -m pip install -r requirements-dev.txt python -m pip list @@ -48,8 +48,8 @@ jobs: python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))' - BUILD_MONAI=1 ./runtests.sh --coverage --unittests # unit tests with coverage report - BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests # unit tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report coverage xml if pgrep python; then pkill python; fi - name: Upload coverage @@ -62,7 +62,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:21.02", "pytorch:21.08"] # 21.02 for backward comp. + container: ["pytorch:21.02", "pytorch:22.02"] # 21.02 for backward comp. container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -91,8 +91,8 @@ jobs: python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))' - BUILD_MONAI=1 ./runtests.sh --coverage --unittests # unit tests with coverage report - BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests # unit tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report coverage xml if pgrep python; then pkill python; fi - name: Upload coverage @@ -106,7 +106,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:21.02", "pytorch:21.08"] # 21.02 for backward comp. + container: ["pytorch:21.02", "pytorch:22.02"] # 21.02 for backward comp. container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -173,7 +173,7 @@ jobs: cron-docker: if: github.repository == 'Project-MONAI/MONAI' container: - image: localhost:5000/local_monai:dockerhub # use currently latest, locally available dockerhub image + image: docker://projectmonai/monai:latest # this might be slow and has the pull count limitations options: "--gpus all" runs-on: [self-hosted, linux, x64, common] steps: @@ -190,8 +190,8 @@ jobs: python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))' ngc --version - BUILD_MONAI=1 ./runtests.sh --coverage --pytype --unittests # unit tests with pytype checks, coverage report - BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --pytype --unittests --disttests # unit tests with pytype checks, coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report coverage xml if pgrep python; then pkill python; fi - name: Upload coverage @@ -204,7 +204,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:21.08-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:21.09-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, common] steps: @@ -215,7 +215,7 @@ jobs: which python python -m pip install --upgrade pip wheel python -m pip install -r requirements-dev.txt - BUILD_MONAI=0 python setup.py develop # install monai + BUILD_MONAI=1 python setup.py develop # install monai nvidia-smi export CUDA_VISIBLE_DEVICES=$(python -m tests.utils) echo $CUDA_VISIBLE_DEVICES @@ -234,5 +234,7 @@ jobs: trap 'if pgrep python; then pkill python; fi;' ERR python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & cd /opt/tutorials + python -c 'import monai; monai.config.print_debug_info()' $(pwd)/runner.sh + python -c 'import monai; monai.config.print_debug_info()' if pgrep python; then pkill python; fi diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 3104224e2b..7140cd7dd8 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -13,20 +13,21 @@ on: workflow_dispatch: jobs: - versioning: + versioning_dev: # compute versioning file from python setup.py # upload as artifact - # (also used in release.yml) if: github.repository == 'Project-MONAI/MONAI' - container: - image: localhost:5000/local_monai:latest - runs-on: [self-hosted, linux, x64, build_only] + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 # full history so that we can git describe with: ref: dev fetch-depth: 0 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 - shell: bash run: | git describe @@ -43,13 +44,11 @@ jobs: ls -al rm -rf {*,.[^.]*} - local_docker: - # builds two versions: local_monai:latest and local_monai:dockerhub - # latest: used for local tests - # dockerhub: release, no flake package + docker_build_dev: + # builds projectmonai/monai:latest if: github.repository == 'Project-MONAI/MONAI' - needs: versioning - runs-on: [self-hosted, linux, x64, build_only] + needs: versioning_dev + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 with: @@ -58,67 +57,47 @@ jobs: uses: actions/download-artifact@v2 with: name: _version.py + - name: Install Latest Docker + run: | + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - + sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" + sudo apt-get update + sudo apt-get install docker-ce - name: docker_build shell: bash run: | # get tag info for versioning cat _version.py mv _version.py monai/ - # build and run original docker image for local registry - docker build -t localhost:5000/local_monai:latest -f Dockerfile . - docker push localhost:5000/local_monai:latest - # build once more w/ tag "latest": remove flake package as it is not needed on hub.docker.com + + # build "latest": remove flake package as it is not needed on hub.docker.com sed -i '/flake/d' requirements-dev.txt docker build -t projectmonai/monai:latest -f Dockerfile . - # also push as tag "dockerhub" to local registry - docker image tag projectmonai/monai:latest localhost:5000/local_monai:dockerhub - docker push localhost:5000/local_monai:dockerhub + # distribute as always w/ tag "latest" to hub.docker.com echo "${{ secrets.DOCKER_PW }}" | docker login -u projectmonai --password-stdin + docker push projectmonai/monai:latest docker logout docker image prune -f - docker_test_latest: - if: github.repository == 'Project-MONAI/MONAI' - needs: local_docker - container: - image: localhost:5000/local_monai:latest - runs-on: [self-hosted, linux, x64, common] - steps: - - name: Import - run: | - export CUDA_VISIBLE_DEVICES=$(python -m tests.utils) - echo $CUDA_VISIBLE_DEVICES - trap 'if pgrep python; then pkill python; fi;' ERR - python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & - python -c 'import monai; monai.config.print_config()' - cd /opt/monai - ls -al - ngc --version - python -m tests.min_tests - if pgrep python; then pkill python; fi - env: - QUICKTEST: True - docker_test_dockerhub: if: github.repository == 'Project-MONAI/MONAI' - needs: local_docker + needs: docker_build_dev container: - image: localhost:5000/local_monai:dockerhub - runs-on: [self-hosted, linux, x64, common] + image: docker://projectmonai/monai:latest + options: "--shm-size=4g --ipc=host" + runs-on: ubuntu-latest steps: - name: Import run: | export CUDA_VISIBLE_DEVICES=$(python -m tests.utils) echo $CUDA_VISIBLE_DEVICES - trap 'if pgrep python; then pkill python; fi;' ERR - python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & - python -c 'import monai; monai.config.print_config()' + python -c 'import monai; monai.config.print_debug_info()' cd /opt/monai ls -al ngc --version - python -m tests.min_tests - if pgrep python; then pkill python; fi + ./runtests.sh --min + shell: bash env: QUICKTEST: True diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index ed025e98fe..8ed9790dca 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -7,7 +7,7 @@ on: jobs: integration-py3: container: - image: nvcr.io/nvidia/pytorch:20.12-py3 # CUDA 11.1 + image: nvcr.io/nvidia/pytorch:21.12-py3 # CUDA 11.5 options: --gpus all runs-on: [self-hosted, linux, x64, common] steps: @@ -34,7 +34,7 @@ jobs: which python python -m pip install --upgrade pip wheel python -m pip uninstall -y torch torchvision - python -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install torch==1.11.0+cu115 torchvision==0.12.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html python -m pip install -r requirements-dev.txt - name: Run integration tests run: | @@ -46,8 +46,8 @@ jobs: python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))' - BUILD_MONAI=1 ./runtests.sh --net - BUILD_MONAI=1 ./runtests.sh --unittests + BUILD_MONAI=1 ./runtests.sh --build --net + BUILD_MONAI=1 ./runtests.sh --build --unittests --disttests if pgrep python; then pkill python; fi shell: bash - name: Add reaction diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index 999567ae16..8ec6bdd3c7 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -19,35 +19,41 @@ jobs: strategy: matrix: environment: - - "PT16+CUDA110" + - "PT19+CUDA114" - "PT17+CUDA102" - - "PT17+CUDA110" - "PT18+CUDA102" - - "PT19+CUDA114" - - "PT19+CUDA102" + - "PT18+CUDA112" + - "PT110+CUDA116" + - "PT110+CUDA102" + - "PT111+CUDA102" include: - - environment: PT16+CUDA110 - # we explicitly set pytorch to -h to avoid pip install error - pytorch: "-h" - base: "nvcr.io/nvidia/pytorch:20.07-py3" + # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - environment: PT17+CUDA102 pytorch: "torch==1.7.1 torchvision==0.8.2" base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04" - - environment: PT17+CUDA110 - # we explicitly set pytorch to -h to avoid pip install error - pytorch: "-h" - base: "nvcr.io/nvidia/pytorch:20.09-py3" - environment: PT18+CUDA102 pytorch: "torch==1.8.1 torchvision==0.9.1" base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04" + - environment: PT18+CUDA112 + # we explicitly set pytorch to -h to avoid pip install error + # 21.03: 1.9.0a0+df837d0 + pytorch: "-h" + base: "nvcr.io/nvidia/pytorch:21.03-py3" - environment: PT19+CUDA114 # we explicitly set pytorch to -h to avoid pip install error - # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - # 21.08: 1.10.0a0+3fd9dcf + # 21.10: 1.10.0a0+0aef44c pytorch: "-h" - base: "nvcr.io/nvidia/pytorch:21.08-py3" - - environment: PT19+CUDA102 - pytorch: "torch==1.9.0 torchvision==0.10.0" + base: "nvcr.io/nvidia/pytorch:21.10-py3" + - environment: PT110+CUDA116 + # we explicitly set pytorch to -h to avoid pip install error + # 22.02: 1.11.0a0+17540c5 + pytorch: "-h" + base: "nvcr.io/nvidia/pytorch:22.02-py3" + - environment: PT110+CUDA102 + pytorch: "torch==1.10.1 torchvision==0.11.2" + base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04" + - environment: PT111+CUDA102 + pytorch: "torch==1.11.0 torchvision==0.12.0" base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04" container: image: ${{ matrix.base }} @@ -59,9 +65,10 @@ jobs: run: | if [ ${{ matrix.environment }} = "PT17+CUDA102" ] || \ [ ${{ matrix.environment }} = "PT18+CUDA102" ] || \ - [ ${{ matrix.environment }} = "PT19+CUDA102" ] + [ ${{ matrix.environment }} = "PT110+CUDA102" ] || \ + [ ${{ matrix.environment }} = "PT111+CUDA102" ] then - PYVER=3.6 PYSFX=3 DISTUTILS=python3-distutils && \ + PYVER=3.7 PYSFX=3 DISTUTILS=python3-distutils && \ apt-get update && apt-get install -y --no-install-recommends \ curl \ pkg-config \ @@ -100,6 +107,8 @@ jobs: run: | which python python -m pip install --upgrade pip wheel + # fixes preinstalled ruamel_yaml error from the docker image + rm -rf $(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")/ruamel* python -m pip install ${{ matrix.pytorch }} python -m pip install -r requirements-dev.txt python -m pip list @@ -120,8 +129,8 @@ jobs: python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))' python -c "import monai; monai.config.print_config()" # build for the current self-hosted CI Tesla V100 - BUILD_MONAI=1 TORCH_CUDA_ARCH_LIST="7.0" ./runtests.sh --quick --unittests - if [ ${{ matrix.environment }} = "PT19+CUDA102" ]; then + BUILD_MONAI=1 TORCH_CUDA_ARCH_LIST="7.0" ./runtests.sh --build --quick --unittests --disttests + if [ ${{ matrix.environment }} = "PT110+CUDA102" ]; then # test the clang-format tool downloading once coverage run -m tests.clang_format_utils fi diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml new file mode 100644 index 0000000000..c3294c2b2a --- /dev/null +++ b/.github/workflows/pythonapp-min.yml @@ -0,0 +1,172 @@ +name: build-min + +on: + # quick tests for pull requests and the releasing branches + push: + branches: + - dev + - main + - releasing/* + pull_request: + +concurrency: + # automatically cancel the previously triggered workflows when there's a newer version + group: build-min-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + # caching of these jobs: + # - docker-py3-pip- (shared) + # - ubuntu py37 pip- + # - os-latest-pip- (shared) + min-dep-os: # min dependencies installed tests for different OS + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [windows-latest, macOS-latest, ubuntu-latest] + timeout-minutes: 40 + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Prepare pip wheel + run: | + which python + python -m pip install --upgrade pip wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + echo "::set-output name=dir::$(pip cache dir)" + shell: bash + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }} + - if: runner.os == 'windows' + name: Install torch cpu from pytorch.org (Windows only) + run: | + python -m pip install torch==1.11.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install the dependencies + run: | + # min. requirements + python -m pip install torch==1.11.0 + python -m pip install -r requirements-min.txt + python -m pip list + BUILD_MONAI=0 python setup.py develop # no compile of extensions + shell: bash + - name: Run quick tests (CPU ${{ runner.os }}) + run: | + python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' + python -c "import monai; monai.config.print_config()" + ./runtests.sh --min + shell: bash + env: + QUICKTEST: True + + min-dep-py3: # min dependencies installed tests for different python + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.7, 3.8, 3.9] + timeout-minutes: 40 + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Prepare pip wheel + run: | + which python + python -m pip install --user --upgrade pip setuptools wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + echo "::set-output name=dir::$(pip cache dir)" + shell: bash + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} + - name: Install the dependencies + run: | + # min. requirements + python -m pip install torch==1.11.0 + python -m pip install -r requirements-min.txt + python -m pip list + BUILD_MONAI=0 python setup.py develop # no compile of extensions + shell: bash + - name: Run quick tests (CPU ${{ runner.os }}) + run: | + python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' + python -c "import monai; monai.config.print_config()" + ./runtests.sh --min + env: + QUICKTEST: True + + min-dep-pytorch: # min dependencies installed tests for different pytorch + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + pytorch-version: [1.6.0, 1.7.1, 1.8.1, 1.9.1, 1.10.1, latest] + timeout-minutes: 40 + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Prepare pip wheel + run: | + which python + python -m pip install --user --upgrade pip setuptools wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + echo "::set-output name=dir::$(pip cache dir)" + shell: bash + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} + - name: Install the dependencies + run: | + # min. requirements + if [ ${{ matrix.pytorch-version }} == "latest" ]; then + python -m pip install torch + elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then + python -m pip install torch==1.6.0 + elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then + python -m pip install torch==1.7.1 + elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then + python -m pip install torch==1.8.1 + elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then + python -m pip install torch==1.9.1 + elif [ ${{ matrix.pytorch-version }} == "1.10.1" ]; then + python -m pip install torch==1.10.1 + fi + python -m pip install -r requirements-min.txt + python -m pip list + BUILD_MONAI=0 python setup.py develop # no compile of extensions + shell: bash + - name: Run quick tests (pytorch ${{ matrix.pytorch-version }}) + run: | + python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' + python -c "import monai; monai.config.print_config()" + ./runtests.sh --min + env: + QUICKTEST: True diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 3f18263e9e..cf251c2293 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -44,9 +44,9 @@ jobs: - name: Lint and type check run: | # clean up temporary files - $(pwd)/runtests.sh --clean + $(pwd)/runtests.sh --build --clean # Git hub actions have 2 cores, so parallize pytype - $(pwd)/runtests.sh --codeformat -j 2 + $(pwd)/runtests.sh --build --codeformat -j 2 quick-py3: # full dependencies installed tests for different OS runs-on: ${{ matrix.os }} @@ -87,10 +87,10 @@ jobs: - if: runner.os == 'windows' name: Install torch cpu from pytorch.org (Windows only) run: | - python -m pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install the dependencies run: | - python -m pip install torch==1.9.0 torchvision==0.10.0 + python -m pip install torch==1.11.0 torchvision==0.12.0 cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list @@ -106,100 +106,6 @@ jobs: env: QUICKTEST: True - min-dep-os: # min dependencies installed tests for different OS - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [windows-latest, macOS-latest, ubuntu-latest] - timeout-minutes: 40 - steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - name: Prepare pip wheel - run: | - which python - python -m pip install --upgrade pip wheel - - name: cache weekly timestamp - id: pip-cache - run: | - echo "::set-output name=datew::$(date '+%Y-%V')" - echo "::set-output name=dir::$(pip cache dir)" - shell: bash - - name: cache for pip - uses: actions/cache@v2 - id: cache - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }} - - if: runner.os == 'windows' - name: Install torch cpu from pytorch.org (Windows only) - run: | - python -m pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - - name: Install the dependencies - run: | - # min. requirements - python -m pip install torch==1.9.0 - python -m pip install -r requirements-min.txt - python -m pip list - BUILD_MONAI=0 python setup.py develop # no compile of extensions - shell: bash - - name: Run quick tests (CPU ${{ runner.os }}) - run: | - python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' - python -c "import monai; monai.config.print_config()" - ./runtests.sh --min - env: - QUICKTEST: True - - min-dep-py3: # min dependencies installed tests for different python - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.6, 3.7, 3.8, 3.9] - timeout-minutes: 40 - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Prepare pip wheel - run: | - which python - python -m pip install --user --upgrade pip setuptools wheel - - name: cache weekly timestamp - id: pip-cache - run: | - echo "::set-output name=datew::$(date '+%Y-%V')" - echo "::set-output name=dir::$(pip cache dir)" - shell: bash - - name: cache for pip - uses: actions/cache@v2 - id: cache - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} - - name: Install the dependencies - run: | - # min. requirements - python -m pip install torch==1.9.0 - python -m pip install -r requirements-min.txt - python -m pip list - BUILD_MONAI=0 python setup.py develop # no compile of extensions - shell: bash - - name: Run quick tests (CPU ${{ runner.os }}) - run: | - python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' - python -c "import monai; monai.config.print_config()" - ./runtests.sh --min - env: - QUICKTEST: True - packaging: runs-on: ubuntu-latest env: @@ -231,7 +137,7 @@ jobs: # install the latest pytorch for testing # however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated # fresh torch installation according to pyproject.toml - python -m pip install torch>=1.5 torchvision + python -m pip install torch>=1.6 torchvision - name: Check packages run: | pip uninstall monai diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index bfdc639788..9cef1ff090 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 with: @@ -87,18 +87,18 @@ jobs: versioning: # compute versioning file from python setup.py # upload as artifact - # (also used in docker.yml) if: github.repository == 'Project-MONAI/MONAI' needs: packaging - container: - image: localhost:5000/local_monai:latest - runs-on: [self-hosted, linux, x64, build_only] + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 # full history so that we can git describe with: - ref: main fetch-depth: 0 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 - shell: bash run: | git describe @@ -118,11 +118,9 @@ jobs: release_tag_docker: if: github.repository == 'Project-MONAI/MONAI' needs: versioning - runs-on: [self-hosted, linux, x64, build_only] + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - with: - ref: main - name: Download version uses: actions/download-artifact@v2 with: @@ -136,6 +134,13 @@ jobs: run: | echo "$RELEASE_VERSION" cat _version.py + - if: startsWith(github.ref, 'refs/tags/') + name: Install latest docker + run: | + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - + sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" + sudo apt-get update + sudo apt-get install docker-ce - if: startsWith(github.ref, 'refs/tags/') name: build with the tag env: @@ -144,6 +149,17 @@ jobs: run: | # get tag info for versioning mv _version.py monai/ + # version checks + target=" \"version\": \"$RELEASE_VERSION\"" + local=`grep "\"version\"" monai/_version.py` + echo "$target" + echo "$local" + if [[ "$local" == "$target" ]]; then + echo "matched version string" + else + echo "unmatched version string, please check the tagging branch." + exit 1 + fi # remove flake package as it is not needed on hub.docker.com sed -i '/flake/d' requirements-dev.txt docker build -t projectmonai/monai:"$RELEASE_VERSION" -f Dockerfile . diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index d0dc3a9f10..02911ea51f 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -21,7 +21,7 @@ jobs: coverage-py3: if: github.repository == 'Project-MONAI/MONAI' container: - image: nvcr.io/nvidia/pytorch:20.03-py3 # CUDA 10.2 + image: nvcr.io/nvidia/pytorch:21.06-py3 # CUDA 11.3 options: --gpus all runs-on: [self-hosted, linux, x64, common] steps: @@ -43,7 +43,8 @@ jobs: which python python -m pip install --upgrade pip wheel python -m pip uninstall -y torch torchvision - python -m pip install torch==1.9.0 torchvision==0.10.0 + rm -rf $(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")/ruamel* + python -m pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html python -m pip install -r requirements-dev.txt - name: Run unit tests report coverage run: | @@ -58,8 +59,8 @@ jobs: python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))' - BUILD_MONAI=1 ./runtests.sh --coverage --unittests # unit tests with coverage report - BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests # unit tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report coverage xml if pgrep python; then pkill python; fi shell: bash @@ -73,7 +74,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 with: @@ -97,13 +98,13 @@ jobs: - name: Install the dependencies run: | python -m pip install --upgrade pip wheel - python -m pip install torch==1.9.0 torchvision==0.10.0 + python -m pip install torch==1.11.0 torchvision==0.12.0 python -m pip install -r requirements-dev.txt - name: Run quick tests CPU ubuntu run: | python -m pip list python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' - BUILD_MONAI=1 ./runtests.sh --quick --unittests + BUILD_MONAI=1 ./runtests.sh --build --quick --unittests --disttests coverage xml - name: Upload coverage uses: codecov/codecov-action@v1 diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index df0b5dd759..0d899b1f30 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -33,7 +33,7 @@ jobs: export YEAR_WEEK=$(date +'%y%U') echo "Year week for tag is ${YEAR_WEEK}" if ! [[ $YEAR_WEEK =~ ^[0-9]{4}$ ]] ; then echo "Wrong 'year week' format. Should be 4 digits."; exit 1 ; fi - git tag "0.7.dev${YEAR_WEEK}" + git tag "0.9.dev${YEAR_WEEK}" git log -1 git tag --list python setup.py sdist bdist_wheel diff --git a/.gitignore b/.gitignore index 7444d7f2f9..542e08e3b6 100644 --- a/.gitignore +++ b/.gitignore @@ -69,6 +69,7 @@ instance/ # Sphinx documentation docs/_build/ +_build/ # PyBuilder target/ @@ -129,9 +130,11 @@ temp/ tests/testing_data/MedNIST* tests/testing_data/*Hippocampus* tests/testing_data/*.tiff +tests/testing_data/schema.json # clang format tool .clang-format-bin/ # VSCode .vscode/ +*.zip diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c36c96186c..3fc762db46 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.1.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace @@ -22,18 +22,35 @@ repos: args: ['--maxkb=1024'] - id: detect-private-key - #- repo: https://github.com/asottile/pyupgrade - # rev: v2.23.2 - # hooks: - # - id: pyupgrade - # args: [--py36-plus] - # name: Upgrade code + - repo: https://github.com/asottile/pyupgrade + rev: v2.31.0 + hooks: + - id: pyupgrade + args: [--py37-plus] + name: Upgrade code + exclude: | + (?x)^( + versioneer.py| + monai/_version.py + )$ - #- repo: https://github.com/asottile/yesqa - # rev: v1.2.3 - # hooks: - # - id: yesqa - # name: Unused noqa + - repo: https://github.com/asottile/yesqa + rev: v1.3.0 + hooks: + - id: yesqa + name: Unused noqa + additional_dependencies: + - flake8>=3.8.1 + - flake8-bugbear + - flake8-comprehensions + - flake8-executable + - flake8-pyi + - pep8-naming + exclude: | + (?x)^( + monai/__init__.py| + docs/source/conf.py + )$ #- repo: https://github.com/PyCQA/isort # rev: 5.9.3 diff --git a/CHANGELOG.md b/CHANGELOG.md index bdbd23e7dd..3f55ded72f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,146 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [Unreleased] -* renamed model's `n_classes` to `num_classes` + +## [0.8.1] - 2022-02-16 +### Added +* Support of `matshow3d` with given `channel_dim` +* Support of spatial 2D for `ViTAutoEnc` +* Support of `dataframe` object input in `CSVDataset` +* Support of tensor backend for `Orientation` +* Support of configurable delimiter for CSV writers +* A base workflow API +* `DataFunc` API for dataset-level preprocessing +* `write_scalar` API for logging with additional `engine` parameter in `TensorBoardHandler` +* Enhancements for NVTX Range transform logging +* Enhancements for `set_determinism` +* Performance enhancements in the cache-based datasets +* Configurable metadata keys for `monai.data.DatasetSummary` +* Flexible `kwargs` for `WSIReader` +* Logging for the learning rate schedule handler +* `GridPatchDataset` as subclass of `monai.data.IterableDataset` +* `is_onehot` option in `KeepLargestConnectedComponent` +* `channel_dim` in the image readers and support of stacking images with channels +* Skipping workflow `run` if epoch length is 0 +* Enhanced `CacheDataset` to avoid duplicated cache items +* `save_state` utility function + +### Changed +* Optionally depend on PyTorch-Ignite v0.4.8 instead of v0.4.6 +* `monai.apps.mmars.load_from_mmar` defaults to the latest version + +### Fixed +* Issue when caching large items with `pickle` +* Issue of hard-coded activation functions in `ResBlock` +* Issue of `create_file_name` assuming local disk file creation +* Issue of `WSIReader` when the backend is `TiffFile` +* Issue of `deprecated_args` when the function signature contains kwargs +* Issue of `channel_wise` computations for the intensity-based transforms +* Issue of inverting `OneOf` +* Issue of removing temporary caching file for the persistent dataset +* Error messages when reader backend is not available +* Output type casting issue in `ScaleIntensityRangePercentiles` +* Various docstring typos and broken URLs +* `mode` in the evaluator engine +* Ordering of `Orientation` and `Spacing` in `monai.apps.deepgrow.dataset` + +### Removed +* Additional deep supervision modules in `DynUnet` +* Deprecated `reduction` argument for `ContrastiveLoss` +* Decollate warning in `Workflow` +* Unique label exception in `ROCAUCMetric` +* Logger configuration logic in the event handlers + +## [0.8.0] - 2021-11-25 +### Added +* Overview of [new features in v0.8](docs/source/whatsnew_0_8.md) +* Network modules for differentiable neural network topology search (DiNTS) +* Multiple Instance Learning transforms and models for digital pathology WSI analysis +* Vision transformers for self-supervised representation learning +* Contrastive loss for self-supervised learning +* Finalized major improvements of 200+ components in `monai.transforms` to support input and backend in PyTorch and NumPy +* Initial registration module benchmarking with `GlobalMutualInformationLoss` as an example +* `monai.transforms` documentation with visual examples and the utility functions +* Event handler for `MLfLow` integration +* Enhanced data visualization functions including `blend_images` and `matshow3d` +* `RandGridDistortion` and `SmoothField` in `monai.transforms` +* Support of randomized shuffle buffer in iterable datasets +* Performance review and enhancements for data type casting +* Cumulative averaging API with distributed environment support +* Module utility functions including `require_pkg` and `pytorch_after` +* Various usability enhancements such as `allow_smaller` when sampling ROI and `wrap_sequence` when casting object types +* `tifffile` support in `WSIReader` +* Regression tests for the fast training workflows +* Various tutorials and demos including educational contents at [MONAI Bootcamp 2021](https://github.com/Project-MONAI/MONAIBootcamp2021) +### Changed +* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.10-py3` from `nvcr.io/nvidia/pytorch:21.08-py3` +* Decoupled `TraceKeys` and `TraceableTransform` APIs from `InvertibleTransform` +* Skipping affine-based resampling when `resample=False` in `NiftiSaver` +* Deprecated `threshold_values: bool` and `num_classes: int` in `AsDiscrete` +* Enhanced `apply_filter` for spatially 1D, 2D and 3D inputs with non-separable kernels +* Logging with `logging` in downloading and model archives in `monai.apps` +* API documentation site now defaults to `stable` instead of `latest` +* `skip-magic-trailing-comma` in coding style enforcements +* Pre-merge CI pipelines now include unit tests with Nvidia Ampere architecture +### Removed +* Support for PyTorch 1.5 +* The deprecated `DynUnetV1` and the related network blocks +* GitHub self-hosted CI/CD pipelines for package releases +### Fixed +* Support of path-like objects as file path inputs in most modules +* Issue of `decollate_batch` for dictionary of empty lists +* Typos in documentation and code examples in various modules +* Issue of no available keys when `allow_missing_keys=True` for the `MapTransform` +* Issue of redundant computation when normalization factors are 0.0 and 1.0 in `ScaleIntensity` +* Incorrect reports of registered readers in `ImageReader` +* Wrong numbering of iterations in `StatsHandler` +* Naming conflicts in network modules and aliases +* Incorrect output shape when `reduction="none"` in `FocalLoss` +* Various usability issues reported by users + +## [0.7.0] - 2021-09-24 +### Added +* Overview of [new features in v0.7](docs/source/whatsnew_0_7.md) +* Initial phase of major usability improvements in `monai.transforms` to support input and backend in PyTorch and NumPy +* Performance enhancements, with [profiling and tuning guides](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md) for typical use cases +* Reproducing [training modules and workflows](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) of state-of-the-art Kaggle competition solutions +* 24 new transforms, including + * `OneOf` meta transform + * DeepEdit guidance signal transforms for interactive segmentation + * Transforms for self-supervised pre-training + * Integration of [NVIDIA Tools Extension](https://developer.nvidia.com/blog/nvidia-tools-extension-api-nvtx-annotation-tool-for-profiling-code-in-python-and-c-c/) (NVTX) + * Integration of [cuCIM](https://github.com/rapidsai/cucim) + * Stain normalization and contextual grid for digital pathology +* `Transchex` network for vision-language transformers for chest X-ray analysis +* `DatasetSummary` utility in `monai.data` +* `WarmupCosineSchedule` +* Deprecation warnings and documentation support for better backwards compatibility +* Padding with additional `kwargs` and different backend API +* Additional options such as `dropout` and `norm` in various networks and their submodules + +### Changed +* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.08-py3` from `nvcr.io/nvidia/pytorch:21.06-py3` +* Deprecated input argument `n_classes`, in favor of `num_classes` +* Deprecated input argument `dimensions` and `ndims`, in favor of `spatial_dims` +* Updated the Sphinx-based documentation theme for better readability +* `NdarrayTensor` type is replaced by `NdarrayOrTensor` for simpler annotations +* Self-attention-based network blocks now support both 2D and 3D inputs + +### Removed +* The deprecated `TransformInverter`, in favor of `monai.transforms.InvertD` +* GitHub self-hosted CI/CD pipelines for nightly and post-merge tests +* `monai.handlers.utils.evenly_divisible_all_gather` +* `monai.handlers.utils.string_list_all_gather` + +### Fixed +* A Multi-thread cache writing issue in `LMDBDataset` +* Output shape convention inconsistencies of the image readers +* Output directory and file name flexibility issue for `NiftiSaver`, `PNGSaver` +* Requirement of the `label` field in test-time augmentation +* Input argument flexibility issues for `ThreadDataLoader` +* Decoupled `Dice` and `CrossEntropy` intermediate results in `DiceCELoss` +* Improved documentation, code examples, and warning messages in various modules +* Various usability issues reported by users ## [0.6.0] - 2021-07-08 ### Added @@ -25,6 +164,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * Fully compatible with PyTorch 1.9 * `--disttests` and `--min` options for `runtests.sh` * Initial support of pre-merge tests with Nvidia Blossom system + ### Changed * Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.06-py3` from `nvcr.io/nvidia/pytorch:21.04-py3` @@ -34,11 +174,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * Unified the terms: `post_transform` is renamed to `postprocessing`, `pre_transform` is renamed to `preprocessing` * Unified the postprocessing transforms and event handlers to accept the "channel-first" data format * `evenly_divisible_all_gather` and `string_list_all_gather` moved to `monai.utils.dist` + ### Removed * Support of 'batched' input for postprocessing transforms and event handlers * `TorchVisionFullyConvModel` * `set_visible_devices` utility function * `SegmentationSaver` and `TransformsInverter` handlers + ### Fixed * Issue of handling big-endian image headers * Multi-thread issue for non-random transforms in the cache-based datasets @@ -269,9 +411,11 @@ the postprocessing steps should be used before calling the metrics methods * Optionally depend on PyTorch-Ignite v0.4.2 instead of v0.3.0 * Optionally depend on torchvision, ITK * Enhanced CI tests with 8 new testing environments + ### Removed * `MONAI/examples` folder (relocated into [`Project-MONAI/tutorials`](https://github.com/Project-MONAI/tutorials)) * `MONAI/research` folder (relocated to [`Project-MONAI/research-contributions`](https://github.com/Project-MONAI/research-contributions)) + ### Fixed * `dense_patch_slices` incorrect indexing * Data type issue in `GeneralizedWassersteinDiceLoss` @@ -302,6 +446,7 @@ the postprocessing steps should be used before calling the metrics methods * Cross-platform CI tests supporting multiple Python versions * Optional import mechanism * Experimental features for third-party transforms integration + ### Changed > For more details please visit [the project wiki](https://github.com/Project-MONAI/MONAI/wiki/Notable-changes-between-0.1.0-and-0.2.0) * Core modules now require numpy >= 1.17 @@ -311,9 +456,11 @@ the postprocessing steps should be used before calling the metrics methods * Base Docker image upgraded to `nvcr.io/nvidia/pytorch:20.03-py3` from `nvcr.io/nvidia/pytorch:19.10-py3` * Enhanced local testing tools * Documentation website domain changed to https://docs.monai.io + ### Removed * Support of Python < 3.6 * Automatic installation of optional dependencies including pytorch-ignite, nibabel, tensorboard, pillow, scipy, scikit-image + ### Fixed * Various issues in type and argument names consistency * Various issues in docstring and documentation site @@ -336,7 +483,10 @@ the postprocessing steps should be used before calling the metrics methods [highlights]: https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md -[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/0.6.0...HEAD +[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/0.8.1...HEAD +[0.8.1]: https://github.com/Project-MONAI/MONAI/compare/0.8.0...0.8.1 +[0.8.0]: https://github.com/Project-MONAI/MONAI/compare/0.7.0...0.8.0 +[0.7.0]: https://github.com/Project-MONAI/MONAI/compare/0.6.0...0.7.0 [0.6.0]: https://github.com/Project-MONAI/MONAI/compare/0.5.3...0.6.0 [0.5.3]: https://github.com/Project-MONAI/MONAI/compare/0.5.0...0.5.3 [0.5.0]: https://github.com/Project-MONAI/MONAI/compare/0.4.0...0.5.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0dce26582a..129a839fd7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -43,6 +43,7 @@ This section highlights all the necessary preparation steps required before send To collaborate efficiently, please read through this section and follow them. * [Checking the coding style](#checking-the-coding-style) +* [Licensing information](#licensing-information) * [Unit testing](#unit-testing) * [Building documentation](#building-the-documentation) * [Signing your work](#signing-your-work) @@ -63,9 +64,11 @@ python -m pip install -U -r requirements-dev.txt ./runtests.sh --autofix ``` -License information: all source code files should start with this paragraph: +#### Licensing information +All source code files should start with this paragraph: + ``` -# Copyright MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -113,12 +116,18 @@ or (for new features that would not break existing functionality): ``` It is recommended that the new test `test_[module_name].py` is constructed by using only -python 3.6+ build-in functions, `torch`, `numpy`, `coverage` (for reporting code coverages) and `parameterized` (for organising test cases) packages. +python 3.7+ build-in functions, `torch`, `numpy`, `coverage` (for reporting code coverages) and `parameterized` (for organising test cases) packages. If it requires any other external packages, please make sure: - the packages are listed in [`requirements-dev.txt`](requirements-dev.txt) - the new test `test_[module_name].py` is added to the `exclude_cases` in [`./tests/min_tests.py`](./tests/min_tests.py) so that the minimal CI runner will not execute it. +##### Testing data +Testing data such as images and binary files should not be placed in the source code repository. +Please deploy them to a reliable file sharing location (the current preferred one is [https://github.com/Project-MONAI/MONAI-extra-test-data/releases](https://github.com/Project-MONAI/MONAI-extra-test-data/releases)). +At test time, the URLs within `tests/testing_data/data_config.json` are accessible +via the APIs provided in `tests.utils`: `tests.utils.testing_data_config` and `tests.utils.download_url_or_skip_test`. + _If it's not tested, it's broken_ All new functionality should be accompanied by an appropriate set of tests. @@ -228,7 +237,7 @@ Notably, for example, ``import monai.transforms.Spacing`` is the equivalent of ``monai.transforms.spatial.array.Spacing`` if ``class Spacing`` defined in file `monai/transforms/spatial/array.py` is decorated with ``@export("monai.transforms")``. -For string definition, [f-string](https://www.python.org/dev/peps/pep-0498/) is recommended to use over `%-print` and `format-print` from python 3.6. So please try to use `f-string` if you need to define any string object. +For string definition, [f-string](https://www.python.org/dev/peps/pep-0498/) is recommended to use over `%-print` and `format-print`. So please try to use `f-string` if you need to define any string object. #### Backwards compatibility MONAI is currently under active development, and with major version zero (following the [Semantic Versioning](https://semver.org/)). @@ -289,9 +298,9 @@ When major features are ready for a milestone, to prepare for a new release: repository's artifacts (e.g. the file at https://github.com/Project-MONAI/MONAI/actions/runs/66570977). - Check the release test at [TestPyPI](https://test.pypi.org/project/monai/), download the artifacts when the CI finishes. - Optionally run [the cron testing jobs](https://github.com/Project-MONAI/MONAI/blob/dev/.github/workflows/cron.yml) on `releasing/[version number]`. +- Rebase `releasing/[version number]` to `main`, make sure all the test pipelines succeed. - Once the release candidate is verified, tag and push a milestone, for example, `git push origin 0.1.0`. The tag must be with the latest commit of `releasing/[version number]`. -- Rebase `releasing/[version number]` to `main`, make sure all the test pipelines succeed. - Upload the packages to [PyPI](https://pypi.org/project/monai/). This could be done manually by ``twine upload dist/*``, given the artifacts are unzipped to the folder ``dist/``. - Merge `releasing/[version number]` to `dev`, this step must make sure that the tagging commit unchanged on `dev`. diff --git a/Dockerfile b/Dockerfile index 77fe1f828f..4171309c70 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:21.08-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.02-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" diff --git a/README.md b/README.md index e9facef64d..bb217b7247 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ [![codecov](https://codecov.io/gh/Project-MONAI/MONAI/branch/dev/graph/badge.svg)](https://codecov.io/gh/Project-MONAI/MONAI) [![PyPI version](https://badge.fury.io/py/monai.svg)](https://badge.fury.io/py/monai) -MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/). +MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/). Its ambitions are: - developing a community of academic, industrial and clinical researchers collaborating on a common foundation; - creating state-of-the-art, end-to-end training workflows for healthcare imaging; @@ -19,7 +19,7 @@ Its ambitions are: ## Features > _The codebase is currently under active development._ -> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New in 0.6](https://docs.monai.io/en/latest/whatsnew_0_6.html) of the current milestone release._ +> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the current milestone release._ - flexible pre-processing for multi-dimensional medical imaging data; - compositional & portable APIs for ease of integration in existing workflows; @@ -47,7 +47,7 @@ Examples and notebook tutorials are located at [Project-MONAI/tutorials](https:/ Technical documentation is available at [docs.monai.io](https://docs.monai.io). ## Contributing -For guidance on making a contribution to MONAI, see the [contributing guidelines](CONTRIBUTING.md). +For guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md). ## Community Join the conversation on Twitter [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9). diff --git a/docs/images/blend.png b/docs/images/blend.png new file mode 100644 index 0000000000..fdf8a21385 Binary files /dev/null and b/docs/images/blend.png differ diff --git a/docs/images/brats_distributed.png b/docs/images/brats_distributed.png new file mode 100644 index 0000000000..90877336ee Binary files /dev/null and b/docs/images/brats_distributed.png differ diff --git a/docs/images/dints-overview.png b/docs/images/dints-overview.png new file mode 100644 index 0000000000..e5f592a8e5 Binary files /dev/null and b/docs/images/dints-overview.png differ diff --git a/docs/images/distributed_training.png b/docs/images/distributed_training.png deleted file mode 100644 index 378855aab3..0000000000 Binary files a/docs/images/distributed_training.png and /dev/null differ diff --git a/docs/images/fast_training.png b/docs/images/fast_training.png index d0584b9dac..34e47bcb21 100644 Binary files a/docs/images/fast_training.png and b/docs/images/fast_training.png differ diff --git a/docs/images/matshow3d.png b/docs/images/matshow3d.png new file mode 100644 index 0000000000..f71e69a99f Binary files /dev/null and b/docs/images/matshow3d.png differ diff --git a/docs/images/mil-patches.jpg b/docs/images/mil-patches.jpg new file mode 100644 index 0000000000..fd904943be Binary files /dev/null and b/docs/images/mil-patches.jpg differ diff --git a/docs/images/nsight_comparison.png b/docs/images/nsight_comparison.png new file mode 100644 index 0000000000..9b91826513 Binary files /dev/null and b/docs/images/nsight_comparison.png differ diff --git a/docs/images/rand_gaussian_noise.png b/docs/images/rand_gaussian_noise.png new file mode 100644 index 0000000000..a824ea8cc6 Binary files /dev/null and b/docs/images/rand_gaussian_noise.png differ diff --git a/docs/images/ssl_overview.png b/docs/images/ssl_overview.png new file mode 100644 index 0000000000..68fa1af576 Binary files /dev/null and b/docs/images/ssl_overview.png differ diff --git a/docs/images/threaddataloader.png b/docs/images/threaddataloader.png new file mode 100644 index 0000000000..565df8d0d4 Binary files /dev/null and b/docs/images/threaddataloader.png differ diff --git a/docs/requirements.txt b/docs/requirements.txt index 00dd4d2c1e..f9749e9e36 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ -f https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl -torch>=1.5 -pytorch-ignite==0.4.5 +torch>=1.6 +pytorch-ignite==0.4.8 numpy>=1.17 itk>=5.2 nibabel @@ -20,3 +20,11 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops +transformers +mlflow +tensorboardX +imagecodecs; platform_system == "Linux" +tifffile; platform_system == "Linux" +pyyaml +fire +jsonschema diff --git a/docs/source/api.rst b/docs/source/api.rst index 0596a25514..c2b19adeb2 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -7,6 +7,7 @@ API Reference :maxdepth: 1 apps + bundle transforms losses networks diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 11d60767ec..239ae9eb17 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -19,8 +19,8 @@ Applications :members: -Clara MMARs ------------ +`Clara MMARs` +------------- .. autofunction:: download_mmar .. autofunction:: load_from_mmar @@ -114,7 +114,11 @@ Clara MMARs .. automodule:: monai.apps.pathology.transforms.spatial.array .. autoclass:: SplitOnGrid :members: +.. autoclass:: TileOnGrid + :members: .. automodule:: monai.apps.pathology.transforms.spatial.dictionary .. autoclass:: SplitOnGridd :members: +.. autoclass:: TileOnGridd + :members: diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst new file mode 100644 index 0000000000..87c4bf36d2 --- /dev/null +++ b/docs/source/bundle.rst @@ -0,0 +1,40 @@ +:github_url: https://github.com/Project-MONAI/MONAI + +.. _bundle: + +Model Bundle +============ +.. currentmodule:: monai.bundle + +`Config Item` +------------- +.. autoclass:: Instantiable + :members: + +.. autoclass:: ComponentLocator + :members: + +.. autoclass:: ConfigComponent + :members: + +.. autoclass:: ConfigExpression + :members: + +.. autoclass:: ConfigItem + :members: + +`Reference Resolver` +-------------------- +.. autoclass:: ReferenceResolver + :members: + +`Config Parser` +--------------- +.. autoclass:: ConfigParser + :members: + +`Scripts` +--------- +.. autofunction:: run +.. autofunction:: verify_metadata +.. autofunction:: verify_net_in_out diff --git a/docs/source/conf.py b/docs/source/conf.py index 324be8a0fd..db0ca11be3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,7 +22,7 @@ # -- Project information ----------------------------------------------------- project = "MONAI" -copyright = "2020 - 2021 MONAI Consortium" +copyright = "MONAI Consortium" author = "MONAI Contributors" # The full version, including alpha/beta/rc tags @@ -40,6 +40,7 @@ "engines", "data", "apps", + "bundle", "config", "handlers", "losses", @@ -86,7 +87,7 @@ def generate_apidocs(*args): "sphinx_autodoc_typehints", ] -autoclass_content = "both" +autoclass_content = "class" add_module_names = True source_encoding = "utf-8" autosectionlabel_prefix_document = True @@ -107,16 +108,8 @@ def generate_apidocs(*args): html_theme_options = { "external_links": [{"url": "https://github.com/Project-MONAI/tutorials", "name": "Tutorials"}], "icon_links": [ - { - "name": "GitHub", - "url": "https://github.com/project-monai/monai", - "icon": "fab fa-github-square", - }, - { - "name": "Twitter", - "url": "https://twitter.com/projectmonai", - "icon": "fab fa-twitter-square", - }, + {"name": "GitHub", "url": "https://github.com/project-monai/monai", "icon": "fab fa-github-square"}, + {"name": "Twitter", "url": "https://twitter.com/projectmonai", "icon": "fab fa-twitter-square"}, ], "collapse_navigation": True, "navigation_depth": 3, diff --git a/docs/source/data.rst b/docs/source/data.rst index 6e7f5e2773..2bdf401c7f 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -21,6 +21,18 @@ Generic Interfaces :members: :special-members: __next__ +`DatasetFunc` +~~~~~~~~~~~~~ +.. autoclass:: DatasetFunc + :members: + :special-members: __next__ + +`ShuffleBuffer` +~~~~~~~~~~~~~~~ +.. autoclass:: ShuffleBuffer + :members: + :special-members: __next__ + `CSVIterableDataset` ~~~~~~~~~~~~~~~~~~~~ .. autoclass:: CSVIterableDataset @@ -138,6 +150,37 @@ WSIReader .. autoclass:: WSIReader :members: +Image writer +------------ + +resolve_writer +~~~~~~~~~~~~~~ +.. autofunction:: resolve_writer + +register_writer +~~~~~~~~~~~~~~~ +.. autofunction:: register_writer + +ImageWriter +~~~~~~~~~~~ +.. autoclass:: ImageWriter + :members: + +ITKWriter +~~~~~~~~~ +.. autoclass:: ITKWriter + :members: + +NibabelWriter +~~~~~~~~~~~~~ +.. autoclass:: NibabelWriter + :members: + +PILWriter +~~~~~~~~~ +.. autoclass:: PILWriter + :members: + Nifti format handling --------------------- @@ -166,6 +209,12 @@ Synthetic :members: +Ouput folder layout +------------------- +.. automodule:: monai.data.folder_layout + :members: + + Utilities --------- .. automodule:: monai.data.utils @@ -194,6 +243,9 @@ DatasetSummary Decathlon Datalist ~~~~~~~~~~~~~~~~~~ .. autofunction:: monai.data.load_decathlon_datalist +.. autofunction:: monai.data.load_decathlon_properties +.. autofunction:: monai.data.check_missing_files +.. autofunction:: monai.data.create_cross_validation_datalist DataLoader @@ -205,6 +257,9 @@ ThreadBuffer ~~~~~~~~~~~~ .. autoclass:: monai.data.ThreadBuffer +ThreadDataLoader +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.data.ThreadDataLoader TestTimeAugmentation ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/engines.rst b/docs/source/engines.rst index cc0ec3c659..0cd40afb78 100644 --- a/docs/source/engines.rst +++ b/docs/source/engines.rst @@ -11,20 +11,21 @@ Multi-GPU data parallel .. automodule:: monai.engines.multi_gpu_supervised_trainer :members: - Workflows --------- -.. automodule:: monai.engines.workflow -.. currentmodule:: monai.engines.workflow +.. currentmodule:: monai.engines + +`BaseWorkflow` +~~~~~~~~~~~~~~ +.. autoclass:: BaseWorkflow + :members: `Workflow` ~~~~~~~~~~ .. autoclass:: Workflow :members: -.. currentmodule:: monai.engines - `Trainer` ~~~~~~~~~ .. autoclass:: Trainer @@ -54,3 +55,8 @@ Workflows ~~~~~~~~~~~~~~~~~~~ .. autoclass:: EnsembleEvaluator :members: + +Utilities +--------- +.. automodule:: monai.engines.utils + :members: diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 5caccc6b4b..d32b6d88e3 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -150,11 +150,6 @@ GarbageCollector handler .. autoclass:: GarbageCollector :members: -Transform inverter ------------------- -.. autoclass:: TransformInverter - :members: - Post processing --------------- .. autoclass:: PostProcessing @@ -165,6 +160,11 @@ Decollate batch .. autoclass:: DecollateBatch :members: +MLFlow handler +-------------- +.. autoclass:: MLFlowHandler + :members: + NVTX Handlers ------------- .. automodule:: monai.handlers.nvtx_handlers diff --git a/docs/source/highlights.md b/docs/source/highlights.md index 141c0846d1..cc91cfcd86 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -16,7 +16,7 @@ The overall architecture and modules are shown in the following figure: The rest of this page provides more details for each module. * [Data I/O, processing and augmentation](#medical-image-data-i-o-processing-and-augmentation) -* [Datasets](#datasets) +* [Datasets and DataLoader](#datasets-and-dataloader) * [Loss functions](#losses) * [Optimizers](#optimizers) * [Network architectures](#network-architectures) @@ -25,7 +25,7 @@ The rest of this page provides more details for each module. * [Result writing](#result-writing) * [Workflows](#workflows) * [Research](#research) -* [GPU acceleration](#gpu-acceleration) +* [Performance optimization and GPU acceleration](#performance-optimization-and-gpu-acceleration) * [Applications](#applications) ## Medical image data I/O, processing and augmentation @@ -56,8 +56,15 @@ transformations. These currently include, for example: [2D transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/transforms_demo_2d.ipynb) shows the detailed usage of several MONAI medical image specific transforms. ![2d transform examples](../images/medical_transforms.png) -### 3. Fused spatial transforms and GPU acceleration -As medical image volumes are usually large (in multi-dimensional arrays), pre-processing performance affects the overall pipeline speed. MONAI provides affine transforms to execute fused spatial operations, supports GPU acceleration via native PyTorch for high performance. + +### 3. Transforms support both NumPy array and PyTorch Tensor (CPU or GPU accelerated) +From MONAI v0.7 we introduced PyTorch `Tensor` based computation in transforms, many transforms already support both `NumPy array` and `Tensor` as input types and computational backends. To get the supported backends of every transform, please execute: `python monai/transforms/utils.py`. + +To accelerate the transforms, a common approach is to leverage GPU parallel-computation. Users can first convert input data into GPU Tensor by `ToTensor` or `EnsureType` transform, then the following transforms can execute on GPU based on PyTorch `Tensor` APIs. +GPU transform tutorial is available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb). + +### 4. Fused spatial transforms +As medical image volumes are usually large (in multi-dimensional arrays), pre-processing performance affects the overall pipeline speed. MONAI provides affine transforms to execute fused spatial operations. For example: ```py @@ -67,20 +74,21 @@ affine = Affine( scale_params=(1.2, 1.2), translate_params=(200, 40), padding_mode='zeros', - device=torch.device('cuda:0') ) # convert the image using bilinear interpolation new_img = affine(image, spatial_size=(300, 400), mode='bilinear') ``` Experiments and test results are available at [Fused transforms test](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/transform_speed.ipynb). -Currently, all the geometric image transforms (Spacing, Zoom, Rotate, Resize, etc.) are designed based on the PyTorch native interfaces. [Geometric transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/3d_image_transforms.ipynb) indicates the usage of affine transforms with 3D medical images. +Currently, all the geometric image transforms (Spacing, Zoom, Rotate, Resize, etc.) are designed based on the PyTorch native interfaces. So all of them support GPU acceleration via `GPU Tensor` operations for high performance. + +[Geometric transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/3d_image_transforms.ipynb) indicates the usage of affine transforms with 3D medical images. ![3d transform examples](../images/affine.png) -### 4. Randomly crop out batch images based on positive/negative ratio +### 5. Randomly crop out batch images based on positive/negative ratio Medical image data volume may be too large to fit into GPU memory. A widely-used approach is to randomly draw small size data samples during training and run a “sliding window” routine for inference. MONAI currently provides general random sampling strategies including class-balanced fixed ratio sampling which may help stabilize the patch-based training process. A typical example is in [Spleen 3D segmentation tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb), which achieves the class-balanced sampling with `RandCropByPosNegLabel` transform. -### 5. Deterministic training for reproducibility +### 6. Deterministic training for reproducibility Deterministic training support is necessary and important for deep learning research, especially in the medical field. Users can easily set the random seed to all the random transforms in MONAI locally and will not affect other non-deterministic modules in the user's program. For example: @@ -99,16 +107,16 @@ Users can also enable/disable deterministic at the beginning of training program monai.utils.set_determinism(seed=0, additional_settings=None) ``` -### 6. Multiple transform chains +### 7. Multiple transform chains To apply different transforms on the same data and concatenate the results, MONAI provides `CopyItems` transform to make copies of specified items in the data dictionary and `ConcatItems` transform to combine specified items on the expected dimension, and also provides `DeleteItems` transform to delete unnecessary items to save memory. Typical usage is to scale the intensity of the same image into different ranges and concatenate the results together. ![multiple transform chains](../images/multi_transform_chains.png) -### 7. Debug transforms with DataStats +### 8. Debug transforms with DataStats When transforms are combined with the "compose" function, it's not easy to track the output of a specific transform. To help debug errors in the composed transforms, MONAI provides utility transforms such as `DataStats` to print out intermediate data properties such as `data shape`, `value range`, `data value`, `Additional information`, etc. It's a self-contained transform and can be integrated into any transform chain. -### 8. Post-processing transforms for model output +### 9. Post-processing transforms for model output MONAI also provides post-processing transforms for handling the model outputs. Currently, the transforms include: - Adding an activation layer (Sigmoid, Softmax, etc.). - Converting to discrete values (Argmax, One-Hot, Threshold value, etc), as below figure (b). @@ -119,12 +127,19 @@ MONAI also provides post-processing transforms for handling the model outputs. C After decollating the batch data of model output and applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualize data in the TensorBoard. [Postprocessing transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/postprocessing_transforms.ipynb) shows an example with several main transforms for post-processing. ![post-processing transforms](../images/postprocessing_transforms.png) -### 9. Integrate third-party transforms +### 10. Integrate third-party transforms The design of MONAI transforms emphasis code readability and usability. It works for array data or dictionary-based data. MONAI also provides `Adaptor` tools to accommodate different data format for 3rd party transforms. To convert the data shapes or types, utility transforms such as `ToTensor`, `ToNumpy`, `SqueezeDim` are also provided. So it's easy to enhance the transform chain by seamlessly integrating transforms from external packages, including: `ITK`, `BatchGenerator`, `TorchIO` and `Rising`. For more details, please check out the tutorial: [integrate 3rd party transforms into MONAI program](https://github.com/Project-MONAI/tutorials/blob/master/modules/integrate_3rd_party_transforms.ipynb). -### 10. IO factory for medical image formats +In digital pathology training, due to the immense burden of loading images, the CPU is preoccupied by loading images and cannot catch up with preparing the data. This causes the pipeline to become IO bound and results in under-utilization of GPU. To overcome this bottleneck, [cuCIM](https://github.com/rapidsai/cucim) has implemented an optimized version of several common transforms that we are using in digital pathology pipeline. These transforms are natively being run on GPU and act on CuPy arrays. MONAI provides `CuCIM` and `RandCuCIM` adapters to integrate the `cuCIM` library. For instance: +```py +RandCuCIM(name="color_jitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04) +CuCIM(name="scale_intensity_range", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0) +``` +It has shown a significant speed up in pathology training metastasis detection model. + +### 11. IO factory for medical image formats Many popular image formats exist in the medical domain, and they are quite different with rich metadata information. To easily handle different medical image formats in the same pipeline, [MONAI provides `LoadImage` transform](https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb), which can automatically choose image readers based on the supported suffixes and in the following priority order: - User-specified reader at runtime when calling this loader. - Registered readers from the latest to the first in the list. @@ -134,13 +149,13 @@ The `ImageReader` API is quite straightforward, users can easily extend it for t With these pre-defined image readers, MONAI can load images in formats: `NIfTI`, `DICOM`, `PNG`, `JPG`, `BMP`, `NPY/NPZ`, etc. -### 11. Save transform data into NIfTI or PNG files +### 12. Save transform data into NIfTI or PNG files To convert images into files or debug the transform chain, MONAI provides `SaveImage` transform. Users can inject this transform into the transform chain to save the results. -### 12. Automatically ensure `channel-first` data shape +### 13. Automatically ensure `channel-first` data shape Medical images have different shape formats. They can be `channel-last`, `channel-first` or even `no-channel`. We may, for example, want to load several `no-channel` images and stack them as `channel-first` data. To improve the user experience, MONAI provided an `EnsureChannelFirst` transform to automatically detect data shape according to the meta information and convert it to the `channel-first` format consistently. -### 13. Invert spatial transforms and test-time augmentations +### 14. Invert spatial transforms and test-time augmentations It is often desirable to invert the previously applied spatial transforms (resize, flip, rotate, zoom, crop, pad, etc.) within the deep learning workflows, for example, to resume to the original imaging space after processing the image data in a normalized data space. Many spatial transforms are enhanced with an `inverse` operation since in v0.5. The [model inference tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py) shows a basic example. If the pipeline includes random transformations, users may want to observe the effect that these transformations have on the output. The typical approach is that we pass the same input through the transforms multiple times with different random realizations. Then use the inverse transforms to move all the results to a common space, and calculate the metrics. MONAI provided `TestTimeAugmentation` for this feature, which by default will calculate the `mode`, `mean`, `standard deviation` and `volume variation coefficient`. @@ -148,20 +163,31 @@ If the pipeline includes random transformations, users may want to observe the e [Invert transforms and TTA tutorials](https://github.com/Project-MONAI/tutorials/blob/master/modules/inverse_transforms_and_test_time_augmentations.ipynb) introduce details about the API with usage examples. (1) The last column is the inverted data of model output: + ![invert transform](../images/invert_transforms.png) (2) The TTA results of `mode`, `mean` and `standard deviation`: + ![test time augmentation](../images/tta.png) -## Datasets +### 15. Visualization of transform examples +To help clearly introduce the transform functionalities, MONAI provides visualization examples in the [API document](https://docs.monai.io/en/latest/transforms.html) for almost every transform, including spatial transforms, intensity transforms, crop / pad transforms, etc. + +For example: + +![rand gaussian noise](../images/rand_gaussian_noise.png) + +## Datasets and DataLoader ### 1. Cache IO and transforms data to accelerate training Users often need to train the model with many (potentially thousands of) epochs over the data to achieve the desired model quality. A native PyTorch implementation may repeatedly load data and run the same preprocessing steps for every epoch during training, which can be time-consuming and unnecessary, especially when the medical image volumes are large. MONAI provides a multi-thread `CacheDataset` and `LMDBDataset` to accelerate these transformation steps during training by storing the intermediate outcomes before the first randomized transform in the transform chain. Enabling this feature could potentially give 10x training speedups in the [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb). + ![digital pathology](../images/cache_dataset.png) ### 2. Cache intermediate outcomes into persistent storage The `PersistentDataset` is similar to the CacheDataset, where the intermediate cache values are persisted to disk storage or LMDB for rapid retrieval between experimental runs (as is the case when tuning hyperparameters), or when the entire data set size exceeds available memory. The `PersistentDataset` could achieve similar performance when comparing to `CacheDataset` in [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb). + ![cachedataset speed](../images/datasets_speed.png) ### 3. SmartCache mechanism for big datasets @@ -212,6 +238,7 @@ To quickly get started with popular training data in the medical domain, MONAI p MONAI always welcome new contributions of public datasets, please refer to existing Datasets and leverage the download and extracting APIs, etc. [Public datasets tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/public_datasets.ipynb) indicates how to quickly set up training workflows with `MedNISTDataset` and `DecathlonDataset` and how to create a new `Dataset` for public data. The common workflow of predefined datasets: + ![pre-defined dataset](../images/dataset_progress.png) ### 7. Partition dataset for cross validation @@ -221,6 +248,13 @@ The `partition_dataset` utility in MONAI can perform different types of partitio CSV tables are often used in additional to image data to incorporate adjunct information, such as patient demographics, lab results, image acquisition parameters and other non-image data, MONAI provides `CSVDataset` to load CSV files and `CSVIterableDataset` to load large CSV files with scalable data access. In addition to the regular preprocessing transform while loading, it also supports multiple CSV files loading, joining tables, rows and columns selection and grouping. [CSVDatasets tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/csv_datasets.ipynb) shows detailed usage examples. +### 9. `ThreadDataLoader` vs. `DataLoader` +If the transforms are light-weighted, especially when we cache all the data in RAM, the multiprocessing of PyTorch `DataLoader` may cause unnecessary IPC time and cause the drop of GPU utilization after every epoch. MONAI provides `ThreadDataLoader` which executes the transforms in a separate thread: + +![threaddataloader](../images/threaddataloader.png) + +a `ThreadDataLoader` example is available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb). + ## Losses There are domain-specific loss functions in the medical imaging research which are not typically used in generic computer vision tasks. As an important module of MONAI, these loss functions are implemented in PyTorch, such as `DiceLoss`, `GeneralizedDiceLoss`, `MaskedDiceLoss`, `TverskyLoss`, `FocalLoss`, `DiceCELoss`, and `DiceFocalLoss`, etc. @@ -228,6 +262,7 @@ There are domain-specific loss functions in the medical imaging research which a MONAI provides several advanced features in optimizers to help accelerate the training or fine-tuning progress. For example, `Novograd` optimizer can be used to converge faster than the traditional optimizers. And users can easily define different learning rates for the model layers based [on the `generate_param_groups` utility API](https://github.com/Project-MONAI/tutorials/blob/master/modules/layer_wise_learning_rate.ipynb). Another important feature is `LearningRateFinder`. The learning rate range test increases the learning rate in a pre-training run between two boundaries in a linear or exponential manner. It provides valuable information on how well the network can be trained over a range of learning rates and what the optimal learning rates are. [LearningRateFinder tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/learning_rate.ipynb) indicates the API usage examples. + ![learning rate finder plot](../images/lr_finder.png) ## Network architectures @@ -249,7 +284,7 @@ add_module('conv1', conv_type(in_channels, out_channels, kernel_size=1, bias=Fal ``` ### 2. Implementation of generic 2D/3D networks -And there are several 1D/2D/3D-compatible implementations of intermediate blocks and generic networks, such as UNet, DynUNet, DenseNet, GAN, AHNet, VNet, SENet(and SEResNet, SEResNeXt), SegResNet, EfficientNet, Attention-based networks. All the networks can support PyTorch serialization pipeline based on `torch.jit.script`. +And there are several 1D/2D/3D-compatible implementations of intermediate blocks and generic networks, such as UNet, DynUNet, DenseNet, GAN, AHNet, VNet, SENet(and SEResNet, SEResNeXt), SegResNet, EfficientNet, Attention-based transformer networks, Multi-instance learning networks, DiNTS for AutoML, etc. All the networks can support PyTorch serialization pipeline based on `torch.jit.script`. ### 3. Network adapter to finetune final layers Instead of training from scratch, we often leverage the existing models, and finetune the final layers of a network for new learning tasks. MONAI provides a `NetAdapter` to easily replace the last layer of a model by a convolutional layer or a fully-connected layer. A typical usage example is to adapt [Torchvision models trained with ImageNet](https://pytorch.org/vision/stable/models.html) for other learning tasks. @@ -282,10 +317,22 @@ With a `Cumulative` base class, intermediate metric outcomes can be automaticall ### 3. Metrics report generation During evaluation, users usually save the metrics of every input image, then analyze the bad cases to improve the deep learning pipeline. To save detailed information of metrics, MONAI provided a handler `MetricsSaver`, which can save the final metric values, raw metric of every model output channel of every input image, metrics summary report of operations: `mean`, `median`, `max`, `min`, `percentile`, `std`, etc. The `MeanDice` reports of validation with prostate dataset are as below: + ![metrics report example](../images/metrics_report.png) ## Visualization -Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). +Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). To work with ignite program, MONAI also provides several ignite handlers to visualize training curve and metrics with `TensorBoard` or `MLFlow`, more details is available in [TensorBoard and MLFlow handlers example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb). + +To easily visualize a 3D image as frames of 2D images, MONAI provides the utility `matshow3d` based on `matplotlib` library. It can plot frames of image for the specified dimension, showing a spleen 3D image as example: +`matshow3d(volume=image, figsize=(100, 100), every_n=10, frame_dim=-1 show=True, cmap="gray")` + +![matshow3d example](../images/matshow3d.png) + +MONAI also provides the `blend_images` utility to blend the `image` and `label` to a RGB color image to better visualize the segmentation regions with the specified `cmap` mode and weights, etc. Showing a spleen segmentation `image` and the corresponding `label` as example: + +![blend example](../images/blend.png) + +For more details of `TensorBoard utility`, `matshow3d` and `blend_images`, please check the [visualziation tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/transform_visualization.ipynb). And to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models: @@ -318,7 +365,8 @@ Models ensemble is a popular strategy in machine learning and deep learning area 4. Compute the average values with weights or vote the most common value as the final result. ![model ensemble](../images/models_ensemble.png) -More details of practice is at [Model ensemble tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/models_ensemble.ipynb). +More details of practice is at [Cross validation and model ensemble tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/cross_validation_models_ensemble.ipynb). + ### 3. Transfer learning for different input / output classes `Transfer-learning` is a common and efficient training approach, especially in the medical-specific domain where obtaining large datasets for training can be difficult. So transfer learning from a pre-trained checkpoint can significantly improve the model metrics and shorten training time. @@ -364,12 +412,33 @@ G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zha [A reimplementation](https://monai.io/research/lamp-automated-model-parallelism) of the LAMP system originally proposed by: Wentao Zhu, Can Zhao, Wenqi Li, Holger Roth, Ziyue Xu, and Daguang Xu (2020) "LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation." MICCAI 2020 (Early Accept, paper link: https://arxiv.org/abs/2006.12575) + ![LAMP UNet](../images/unet-pipe.png) -## GPU acceleration +### 3. DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation +MONAI integrated the `DiNTS` module to support more flexible topologies and joint two-level search. It provides a topology guaranteed discretization algorithm and a discretization aware topology loss for the search stage to minimize the discretization gap, and a cost usage aware search method which can search 3D networks with different GPU memory requirements. For more details, please check the [DiNTS tutorial](https://monai.io/research/dints.html). + +![DiNTS](../images/dints-overview.png) + +### 4. Accounting for Dependencies in Deep Learning Based Multiple Instance Learning for Whole Slide Imaging +For [classification of digital pathology whole slide images (WSI)](https://arxiv.org/abs/2111.01556), MONAI introduces new transforms and network modules for multiple instance learning. These include self-attention transformer blocks for explicitly accounting of the dependencies between instances (image patches) during training. For more details, please check out the [multiple instance learning tutorial](https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning). ![multi-instance](../images/mil-patches.jpg) + +### 5. Self-supervised representation learning +MONAI starts to explore self-supervised representation learning in this milestone release. The Vision Transformer has been extended to learn from self-supervised reconstruction tasks with various data augmentation and a regularized contrastive loss. The weights of the pre-trained backbone could be used to enhance the performance of the novel downstream deep learning tasks. + +The [tutorial](https://github.com/Project-MONAI/tutorials/tree/master/self_supervised_pretraining) shows how to generate a good set of pre-trained weights using unlabeled data with self-supervised tasks, then use the pre-trained weights to perform fine-tuning on a fully supervised volumetric segmentation task using a transformer based `UNETR`. + +![self-supervised](../images/ssl_overview.png) + +## Performance optimization and GPU acceleration +Typically, model training is a time-consuming step during deep learning development, especially in medical imaging applications. Volumetric medical images are usually large (as multi-dimensional arrays) and the model training process can be complex. Even with powerful hardware (e.g. CPU/GPU with large RAM), it is not easy to fully leverage them to achieve high performance. MONAI provides [a fast training guide to achieve the best performance](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md). + NVIDIA GPUs have been widely applied in many areas of deep learning training and evaluation, and the CUDA parallel computation shows obvious acceleration when comparing to traditional computation methods. To fully leverage GPU features, many popular mechanisms raised, like automatic mixed precision (AMP), distributed data parallel, etc. MONAI can support these features and provides rich examples. -### 1. Auto mixed precision(AMP) +### 1. Profiling the pipelines +First of all, MONAI provides several methods based on `DLProf`, `Nsight`, `NVTX` and `NVML` for users to analyze their programs to identify the performance bottleneck. The analyses include operation-based GPU activity and overall GPU activity during model training. They will greatly help users manage computing bottlenecks and provide insights for the area to be improved for better computing efficiency. The detailed example is shown in the [performance profiling tutorial](https://github.com/Project-MONAI/tutorials/blob/master/performance_profiling/radiology/profiling_train_base_nvtx.md). + +### 2. Auto mixed precision(AMP) In 2017, NVIDIA researchers developed a methodology for mixed-precision training, which combined single-precision (FP32) with half-precision (e.g. FP16) format when training a network, and it achieved the same accuracy as FP32 training using the same hyperparameters. For the PyTorch 1.6 release, developers at NVIDIA and Facebook moved mixed precision functionality into PyTorch core as the AMP package, `torch.cuda.amp`. @@ -379,16 +448,16 @@ MONAI workflows can easily set `amp=True/False` in `SupervisedTrainer` or `Super We also executed the same test program on NVIDIA A100 GPU with the same software environment, obtained faster results: ![amp a100 results](../images/amp_training_a100.png) More details is available at [AMP training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/automatic_mixed_precision.ipynb). -We also tried to combine AMP with `CacheDataset` and `Novograd` optimizer to achieve the fast training in MONAI, able to obtain approximately 12x speedup compared with a Pytorch native implementation when the training converges at a validation mean dice of 0.93. Benchmark for reference: +We also tried to combine `AMP` with `CacheDataset`, `GPU cache`, `GPU transforms`, `ThreadDataLoader`, `DiceCE` loss function and `Novograd` optimizer to achieve the fast training in MONAI, able to obtain approximately `200x` speedup compared with a Pytorch native implementation when the training converges at a validation mean dice of `0.95`. Benchmark for reference: ![fast training results](../images/fast_training.png) More details is available at [Fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb). -### 2. Distributed data parallel -Distributed data parallel is an important feature of PyTorch to connect multiple GPU devices on single or multiple nodes to train or evaluate models. The distributed data parallel APIs of MONAI are compatible with native PyTorch distributed module, pytorch-ignite distributed module, Horovod, XLA, and the SLURM platform. MONAI provides demos for reference: train/evaluate with PyTorch DDP, train/evaluate with Horovod, train/evaluate with Ignite DDP, partition dataset and train with SmartCacheDataset, as well as a real world training example based on Decathlon challenge Task01 - Brain Tumor segmentation. The demo contains distributed caching, training, and validation. We obtained performance benchmarks for reference (based on PyTorch 1.6, CUDA 11, NVIDIA V100 GPUs): +### 3. Distributed data parallel +Distributed data parallel is an important feature of PyTorch to connect multiple GPU devices on single or multiple nodes to train or evaluate models. The distributed data parallel APIs of MONAI are compatible with native PyTorch distributed module, pytorch-ignite distributed module, Horovod, XLA, and the SLURM platform. MONAI provides demos for reference: train/evaluate with PyTorch DDP, train/evaluate with Horovod, train/evaluate with Ignite DDP, partition dataset and train with SmartCacheDataset, as well as a real world training example based on Decathlon challenge Task01 - Brain Tumor segmentation. The [tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/distributed_training/brats_training_ddp.py) contains distributed caching, training, and validation. We obtained performance benchmarks for reference (based on PyTorch 1.9.1, CUDA 11.4, NVIDIA V100 GPUs. The `optimization` means that with more GPU resources, we can split the data and cache into GPU memory and execute GPU transforms directly): -![distributed training results](../images/distributed_training.png) +![distributed training results](../images/brats_distributed.png) -### 3. C++/CUDA optimized modules +### 4. C++/CUDA optimized modules To further accelerate the domain-specific routines in the workflows, MONAI C++/CUDA implementation are introduced as extensions of the PyTorch native implementations. MONAI provides the modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions): - via `setuptools`, for modules including `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`. @@ -396,6 +465,26 @@ MONAI provides the modules using [the two ways of building C++ extensions from P The following figure shows results of MONAI's Gaussian mixture models applied to tissue and surgical tools segmentation: ![Gaussian mixture models as a postprocessing step](../images/gmm_feature_set_comparison_s.png) +### 5. Cache IO and transforms data to GPU memory +Even with `CacheDataset`, we usually need to copy the same data to GPU memory for GPU random transforms or network computation in every epoch. An efficient approach is to cache the data to GPU memory directly, then every epoch can start from GPU computation immediately. + +For example: +```py +train_transforms = [ + LoadImaged(...), + AddChanneld(...), + Orientationd(...), + Spacingd(...), + ScaleIntensityRanged(...), + EnsureTyped(..., data_type="tensor"), + ToDeviced(..., device="cuda:0"), + RandCropByPosNegLabeld(...), +] +dataset = CacheDataset(..., transform=train_trans) +``` +Here we convert to PyTorch `Tensor` with `EnsureTyped` transform and move data to GPU with `ToDeviced` transform. `CacheDataset` caches the transform results until `ToDeviced`, so it is in GPU memory. Then in every epoch, the program fetches cached data from GPU memory and only executes the random transform `RandCropByPosNegLabeld` on GPU directly. +GPU caching example is available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb). + ## Applications The research area of medical image deep learning is expanding fast. To apply the latest achievements into applications, MONAI contains many application components to build end-to-end solutions or prototypes for other similar use cases. @@ -417,3 +506,8 @@ Starting from v0.5.0, MONAI provides experimental features for building learning The following figure shows the registration of CT images acquired at different time points for a single patient using MONAI: ![3d registration](../images/3d_paired.png) + +### 4. Reproducing the state-of-the-art Kaggle competition solutions +[A reimplementation](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) of the 4th place solution of RANZCR CLiP - Catheter and Line Position Challenge in Kaggle: https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification + +The original solution is produced by Team Watercooled, and the authors are Dieter (https://www.kaggle.com/christofhenkel) and Psi (https://www.kaggle.com/philippsinger). diff --git a/docs/source/index.rst b/docs/source/index.rst index 76ba003c8d..1a4263db0d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -66,6 +66,11 @@ Technical documentation is available at `docs.monai.io `_ contrib +.. toctree:: + :maxdepth: 1 + :caption: Specifications + + mb_specification Links ----- diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index e358e603bd..ac638eb38d 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -36,3 +36,9 @@ Inferers .. autoclass:: SaliencyInferer :members: :special-members: __call__ + +`SliceInferer` +~~~~~~~~~~~~~~ +.. autoclass:: SliceInferer + :members: + :special-members: __call__ diff --git a/docs/source/installation.md b/docs/source/installation.md index 08ab109142..12bf544cba 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -4,6 +4,7 @@ 1. [From PyPI](#from-pypi) 1. [Milestone release](#milestone-release) 2. [Weekly preview release](#weekly-preview-release) +1. [From conda-forge](#from-conda-forge) 2. [From GitHub](#from-github) 1. [System-wide](#milestone-release) 2. [Editable](#weekly-preview-release) @@ -14,7 +15,7 @@ --- -MONAI's core functionality is written in Python 3 (>= 3.6) and only requires [Numpy](https://numpy.org/) and [Pytorch](https://pytorch.org/). +MONAI's core functionality is written in Python 3 (>= 3.7) and only requires [Numpy](https://numpy.org/) and [Pytorch](https://pytorch.org/). The package is currently distributed via Github as the primary source code repository, and the Python package index (PyPI). The pre-built Docker images are made available on DockerHub. @@ -47,6 +48,13 @@ To report any issues on the weekly preview, please include the version and commi python -c "import monai; print(monai.__version__); print(monai.__commit_id__)" ``` +## From conda-forge + +To install the [current milestone release](https://pypi.org/project/monai/): +```bash +conda install -c conda-forge monai +``` + ## From GitHub (_If you have installed the PyPI release version using ``pip install monai``, please run ``pip uninstall @@ -163,20 +171,28 @@ cd MONAI/ pip install -e '.[all]' ``` -To install all optional dependencies for MONAI development: +To install all optional dependencies with `pip` based on MONAI development environment settings: ```bash git clone https://github.com/Project-MONAI/MONAI.git cd MONAI/ pip install -r requirements-dev.txt ``` +To install all optional dependencies with `conda` based on MONAI development environment settings: +```bash +git clone https://github.com/Project-MONAI/MONAI.git +cd MONAI/ +conda create -n python= # eg 3.9 +conda env update -n -f environment-dev.yml +``` + Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is available via PyPI. - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/docs/source/losses.rst b/docs/source/losses.rst index fc7c302ea3..dfd8ce2ddb 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -63,6 +63,11 @@ Segmentation Losses .. autoclass:: TverskyLoss :members: +`ContrastiveLoss` +~~~~~~~~~~~~~~~~~ +.. autoclass:: ContrastiveLoss + :members: + Registration Losses ------------------- diff --git a/docs/source/mb_specification.rst b/docs/source/mb_specification.rst new file mode 100644 index 0000000000..d383dd7d8e --- /dev/null +++ b/docs/source/mb_specification.rst @@ -0,0 +1,142 @@ + +========================== +MONAI Bundle Specification +========================== + +Overview +======== + +This is the specification for the MONAI Bundle (MB) format of portable described deep learning models. The objective of a MB is to define a packaged network or model which includes the critical information necessary to allow users and programs to understand how the model is used and for what purpose. A bundle includes the stored weights of a model as a pickled state dictionary and/or a Torchscript object. Additional JSON files are included to store metadata about the model, information for constructing training, inference, and post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes to include. + +This specification defines the directory structure a bundle must have and the necessary files it must contain. Additional files may be included and the directory packaged into a zip file or included as extra files directly in a Torchscript file. + +Directory Structure +=================== + +A MONAI Bundle is defined primarily as a directory with a set of specifically named subdirectories containing the model and metadata files. The root directory should be named for the model, given as "ModelName" in this exmaple, and should contain the following structure: + +:: + + ModelName + â”Ŗ━ configs + ┃ ┗━ metadata.json + â”Ŗ━ models + ┃ â”Ŗ━ model.pt + ┃ ┗━ model.ts + ┗━ docs + â”Ŗ━ README.md + ┗━ license.txt + + +These files mostly are required to be present with the given names for the directory to define a valid bundle: + +* **metadata.json**: metadata information in JSON format relating to the type of model, definition of input and output tensors, versions of the model and used software, and other information described below. +* **model.pt**: the state dictionary of a saved model, the information to instantiate the model must be found in the metadata file. +* **model.ts**: the Torchscript saved model if the model is compatible with being saved correctly in this format. +* **README.md**: plain-language information on the model, how to use it, author information, etc. in Markdown format. +* **license.txt**: software license attached to the model, can be left blank if no license needed. + +Archive Format +============== + +The bundle directory and its contents can be compressed into a zip file to constitute a single file package. When unzipped into a directory this file will reproduce the above directory structure, and should itself also be named after the model it contains. + +The Torchscript file format is also just a zip file with a specific structure. When creating such an archive with `save_net_with_metadata` a MB-compliant Torchscript file can be created by including the contents of `metadata.json` as the `meta_values` argument of the function, and other files included as `more_extra_files` entries. These will be stored in a `extras` directory in the zip file and can be retrieved with `load_net_with_metadata` or with any other library/tool that can read zip data. In this format the `model.*` files are obviously not needed by `README.md` and `license.txt` can be added as more extra files. + +metadata.json File +================== + +This file contains the metadata information relating to the model, including what the shape and format of inputs and outputs are, what the meaning of the outputs are, what type of model is present, and other information. The JSON structure is a dictionary containing a defined set of keys with additional user-specified keys. The mandatory keys are as follows: + +* **version**: version of the stored model, this allows multiple versions of the same model to be differentiated. +* **monai_version**: version of MONAI the bundle was generated on, later versions expected to work. +* **pytorch_version**: version of Pytorch the bundle was generated on, later versions expected to work. +* **numpy_version**: version of Numpy the bundle was generated on, later versions expected to work. +* **optional_packages_version**: dictionary relating optional package names to their versions, these packages are not needed but are recommended to be installed with this stated minimum version. +* **task**: plain-language description of what the model is meant to do. +* **description**: longer form plain-language description of what the model is, what it does, etc. +* **authorship**: state author(s) of the model. +* **copyright**: state model copyright. +* **network_data_format**: defines the format, shape, and meaning of inputs and outputs to the model, contains keys "inputs" and "outputs" relating named inputs/outputs to their format specifiers (defined below). + +Tensor format specifiers are used to define input and output tensors and their meanings, and must be a dictionary containing at least these keys: + +* **type**: what sort of data the tensor represents: "image", "label", etc. +* **format**: what format of information is stored: "magnitude", "hounsfield", "kspace", "segmentation", "multiclass", etc. +* **num_channels**: number of channels the tensor has, assumed channel dimension first +* **spatial_shape**: shape of the spatial dimensions of the form "[H]", "[H, W]", or "[H, W, D]", see below for possible values of H, W, and D +* **dtype**: data type of tensor, eg. "float32", "int32" +* **value_range**: minimum and maximum values the input data is expected to have of the form "[MIN, MAX]" or "[]" if not known +* **is_patch_data**: "true" if the data is a patch of an input/output tensor or the entirely of the tensor, "false" otherwise +* **channel_def**: dictionary relating channel indices to plain-language description of what the channel contains + +Optional keys: + +* **changelog**: dictionary relating previous version names to strings describing the version. +* **intended_use**: what the model is to be used for, ie. what task it accomplishes. +* **data_source**: description of where training/validation can be sourced. +* **data_type**: type of source data used for training/validation. +* **references**: list of published referenced relating to the model. + +Spatial shape definition can be complex for models accepting inputs of varying shapes, especially if there are specific conditions on what those shapes can be. Shapes are specified as lists of either positive integers for fixed sizes or strings containing expressions defining the condition a size depends on. This can be "*" to mean any size, or use an expression with Python mathematical operators and one character variables to represent dependence on an unknown quantity. For example, "2**n" represents a size which must be a power of 2, "2**n*m" must be a multiple of a power of 2. Variables are shared between dimension expressions, so a spatial shape of `["2**n", "2**n"]` states that the dimensions must be the same powers of 2 given by `n`. + +A JSON schema for this file can be found at https://github.com/Project-MONAI/MONAI/blob/3049e280f2424962bb2a69261389fcc0b98e0036/monai/apps/mmars/schema/metadata.json + +An example JSON metadata file: + +:: + + { + "version": "0.1.0", + "changelog": { + "0.1.0": "complete the model package", + "0.0.1": "initialize the model package structure" + }, + "monai_version": "0.8.0", + "pytorch_version": "1.10.0", + "numpy_version": "1.21.2", + "optional_packages_version": {"nibabel": "3.2.1"}, + "task": "Decathlon spleen segmentation", + "description": "A pre-trained model for volumetric (3D) segmentation of the spleen from CT image", + "authorship": "MONAI team", + "copyright": "Copyright (c) MONAI Consortium", + "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/", + "data_type": "dicom", + "image_classes": "single channel data, intensity scaled to [0, 1]", + "label_classes": "single channel data, 1 is spleen, 0 is everything else", + "pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background", + "eval_metrics": { + "mean_dice": 0.96 + }, + "intended_use": "This is an example, not to be used for diagnostic purposes", + "references": [ + "Xia, Yingda, et al. '3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training.' arXiv preprint arXiv:1811.12506 (2018). https://arxiv.org/abs/1811.12506.", + "Kerfoot E., Clough J., Oksuz I., Lee J., King A.P., Schnabel J.A. (2019) Left-Ventricle Quantification Using Residual U-Net. In: Pop M. et al. (eds) Statistical Atlases and Computational Models of the Heart. Atrial Segmentation and LV Quantification Challenges. STACOM 2018. Lecture Notes in Computer Science, vol 11395. Springer, Cham. https://doi.org/10.1007/978-3-030-12029-0_40" + ], + "network_data_format":{ + "inputs": { + "image": { + "type": "image", + "format": "magnitude", + "num_channels": 1, + "spatial_shape": [160, 160, 160], + "dtype": "float32", + "value_range": [0, 1], + "is_patch_data": false, + "channel_def": {0: "image"} + } + }, + "outputs":{ + "pred": { + "type": "image", + "format": "segmentation", + "num_channels": 2, + "spatial_shape": [160, 160, 160], + "dtype": "float32", + "value_range": [0, 1], + "is_patch_data": false, + "channel_def": {0: "background", 1: "spleen"} + } + } + } + } diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index c1ed25831d..332571d345 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -8,6 +8,8 @@ Metrics `FROC` ------ +.. autofunction:: compute_fp_tp_probs +.. autofunction:: compute_froc_curve_data .. autofunction:: compute_froc_score `Metric` @@ -47,6 +49,7 @@ Metrics `Confusion matrix` ------------------ .. autofunction:: get_confusion_matrix +.. autofunction:: compute_confusion_matrix_metric .. autoclass:: ConfusionMatrixMetric :members: @@ -54,6 +57,7 @@ Metrics `Hausdorff distance` -------------------- .. autofunction:: compute_hausdorff_distance +.. autofunction:: compute_percent_hausdorff_distance .. autoclass:: HausdorffDistanceMetric :members: @@ -84,3 +88,13 @@ Metrics ---------------------------- .. autoclass:: PSNRMetric :members: + +`Cumulative average` +-------------------- +.. autoclass:: CumulativeAverage + :members: + +Utilities +--------- +.. automodule:: monai.metrics.utils + :members: diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 54c2756535..7607cd2701 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -21,7 +21,7 @@ Blocks :members: `CRF` -~~~~~~~~~~~~~ +~~~~~ .. autoclass:: CRF :members: @@ -73,6 +73,8 @@ Blocks :members: .. autoclass:: UnetUpBlock :members: +.. autoclass:: UnetOutBlock + :members: `SegResnet Block` ~~~~~~~~~~~~~~~~~ @@ -188,6 +190,26 @@ Blocks .. autoclass:: PatchEmbeddingBlock :members: +`FactorizedIncreaseBlock` +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: FactorizedIncreaseBlock + :members: + +`FactorizedReduceBlock` +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: FactorizedReduceBlock + :members: + +`P3DActiConvNormBlock` +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: P3DActiConvNormBlock + :members: + +`ActiConvNormBlock` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ActiConvNormBlock + :members: + `Warp` ~~~~~~ .. autoclass:: Warp @@ -234,6 +256,11 @@ Layers .. automodule:: monai.networks.layers.Conv :members: +`Pad` +~~~~~ +.. automodule:: monai.networks.layers.Pad + :members: + `Pool` ~~~~~~ .. automodule:: monai.networks.layers.Pool @@ -256,6 +283,19 @@ Layers .. autoclass:: Flatten :members: +`Reshape` +~~~~~~~~~ +.. autoclass:: Reshape + :members: + +`separable_filtering` +~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: separable_filtering + +`apply_filter` +~~~~~~~~~~~~~~ +.. autofunction:: apply_filter + `GaussianFilter` ~~~~~~~~~~~~~~~~ .. autoclass:: GaussianFilter @@ -267,7 +307,7 @@ Layers :members: `PHLFilter` -~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~ .. autoclass:: PHLFilter `GaussianMixtureModel` @@ -353,6 +393,11 @@ Nets .. autoclass:: EfficientNet :members: +`BlockArgs` +~~~~~~~~~~~ +.. autoclass:: BlockArgs + :members: + `EfficientNetBN` ~~~~~~~~~~~~~~~~ .. autoclass:: EfficientNetBN @@ -373,6 +418,11 @@ Nets .. autoclass:: SegResNetVAE :members: +`ResNet` +~~~~~~~~ +.. autoclass:: ResNet + :members: + `SENet` ~~~~~~~ .. autoclass:: SENet @@ -423,6 +473,11 @@ Nets .. autoclass:: Unet .. autoclass:: unet +`AttentionUnet` +~~~~~~~~~~~~~~~ +.. autoclass:: AttentionUnet + :members: + `UNETR` ~~~~~~~ .. autoclass:: UNETR @@ -470,11 +525,21 @@ Nets .. autoclass:: ViT :members: +`ViTAutoEnc` +~~~~~~~~~~~~ +.. autoclass:: ViTAutoEnc + :members: + `FullyConnectedNet` ~~~~~~~~~~~~~~~~~~~ .. autoclass:: FullyConnectedNet :members: +`VarFullyConnectedNet` +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: VarFullyConnectedNet + :members: + `Generator` ~~~~~~~~~~~ .. autoclass:: Generator @@ -500,6 +565,11 @@ Nets .. autoclass:: Critic :members: +`Transchex` +~~~~~~~~~~~~~~~~ +.. autoclass:: Transchex + :members: + `NetAdapter` ~~~~~~~~~~~~ .. autoclass:: NetAdapter @@ -515,6 +585,31 @@ Nets .. autoclass:: TorchVisionFullyConvModel :members: +`MILModel` +~~~~~~~~~~ +.. autoclass:: MILModel + :members: + +`DiNTS` +~~~~~~~ +.. autoclass:: DiNTS + :members: + +`TopologyConstruction for DiNTS` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TopologyConstruction + :members: + +`TopologyInstance for DiNTS` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TopologyInstance + :members: + +`TopologySearch for DiNTS` +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TopologySearch + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst index c766ac3cf9..67cbdc0951 100644 --- a/docs/source/optimizers.rst +++ b/docs/source/optimizers.rst @@ -6,6 +6,11 @@ Optimizers ========== .. currentmodule:: monai.optimizers +`LearningRateFinder` +-------------------- +.. autoclass:: LearningRateFinder + :members: + `Novograd` ---------- .. autoclass:: Novograd @@ -14,3 +19,18 @@ Optimizers `Generate parameter groups` --------------------------- .. autofunction:: generate_param_groups + +`ExponentialLR` +--------------- +.. autoclass:: ExponentialLR + :members: + +`LinearLR` +---------- +.. autoclass:: LinearLR + :members: + +`WarmupCosineSchedule` +---------------------- +.. autoclass:: WarmupCosineSchedule + :members: diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index b8f57e0dbe..8fc832a253 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -43,6 +43,11 @@ Generic Interfaces .. autoclass:: InvertibleTransform :members: +`TraceableTransform` +^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: TraceableTransform + :members: + `BatchInverseTransform` ^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: BatchInverseTransform @@ -53,80 +58,121 @@ Generic Interfaces .. autoclass:: Decollated :members: +`OneOf` +^^^^^^^ +.. autoclass:: OneOf + :members: + Vanilla Transforms ------------------ Crop and Pad ^^^^^^^^^^^^ +`PadListDataCollate` +"""""""""""""""""""" +.. autoclass:: PadListDataCollate + :members: + :special-members: __call__ + +`Pad` +""""" +.. autoclass:: Pad + :members: + :special-members: __call__ + `SpatialPad` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialPad.png + :alt: example of SpatialPad .. autoclass:: SpatialPad :members: :special-members: __call__ `BorderPad` """"""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/BorderPad.png + :alt: example of BorderPad .. autoclass:: BorderPad :members: :special-members: __call__ `DivisiblePad` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/DivisiblePad.png + :alt: example of DivisiblePad .. autoclass:: DivisiblePad :members: :special-members: __call__ `SpatialCrop` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialCrop.png + :alt: example of SpatialCrop .. autoclass:: SpatialCrop :members: :special-members: __call__ `CenterSpatialCrop` """"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CenterSpatialCrop.png + :alt: example of CenterSpatialCrop .. autoclass:: CenterSpatialCrop :members: :special-members: __call__ `RandSpatialCrop` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSpatialCrop.png + :alt: example of RandSpatialCrop .. autoclass:: RandSpatialCrop :members: :special-members: __call__ `RandSpatialCropSamples` """""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSpatialCropSamples.png + :alt: example of RandSpatialCropSamples .. autoclass:: RandSpatialCropSamples :members: :special-members: __call__ `CropForeground` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CropForeground.png + :alt: example of CropForeground .. autoclass:: CropForeground :members: :special-members: __call__ `RandWeightedCrop` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandWeightedCrop.png + :alt: example of RandWeightedCrop .. autoclass:: RandWeightedCrop :members: :special-members: __call__ `RandCropByPosNegLabel` """"""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCropByPosNegLabel.png + :alt: example of RandCropByPosNegLabel .. autoclass:: RandCropByPosNegLabel :members: :special-members: __call__ `RandCropByLabelClasses` """""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCropByLabelClasses.png + :alt: example of RandCropByLabelClasses .. autoclass:: RandCropByLabelClasses :members: :special-members: __call__ `ResizeWithPadOrCrop` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ResizeWithPadOrCrop.png + :alt: example of ResizeWithPadOrCrop .. autoclass:: ResizeWithPadOrCrop :members: :special-members: __call__ @@ -139,12 +185,16 @@ Crop and Pad `RandScaleCrop` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandScaleCrop.png + :alt: example of RandScaleCrop .. autoclass:: RandScaleCrop :members: :special-members: __call__ `CenterScaleCrop` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CenterScaleCrop.png + :alt: example of CenterScaleCrop .. autoclass:: CenterScaleCrop :members: :special-members: __call__ @@ -154,126 +204,168 @@ Intensity `RandGaussianNoise` """"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianNoise.png + :alt: example of RandGaussianNoise .. autoclass:: RandGaussianNoise :members: :special-members: __call__ `ShiftIntensity` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ShiftIntensity.png + :alt: example of ShiftIntensity .. autoclass:: ShiftIntensity :members: :special-members: __call__ `RandShiftIntensity` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandShiftIntensity.png + :alt: example of RandShiftIntensity .. autoclass:: RandShiftIntensity :members: :special-members: __call__ `StdShiftIntensity` """"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/StdShiftIntensity.png + :alt: example of StdShiftIntensity .. autoclass:: StdShiftIntensity :members: :special-members: __call__ `RandStdShiftIntensity` """"""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandStdShiftIntensity.png + :alt: example of RandStdShiftIntensity .. autoclass:: RandStdShiftIntensity :members: :special-members: __call__ `RandBiasField` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandBiasField.png + :alt: example of RandBiasField .. autoclass:: RandBiasField :members: :special-members: __call__ `ScaleIntensity` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensity.png + :alt: example of ScaleIntensity .. autoclass:: ScaleIntensity :members: :special-members: __call__ `RandScaleIntensity` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandScaleIntensity.png + :alt: example of RandScaleIntensity .. autoclass:: RandScaleIntensity :members: :special-members: __call__ `NormalizeIntensity` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/NormalizeIntensity.png + :alt: example of NormalizeIntensity .. autoclass:: NormalizeIntensity :members: :special-members: __call__ `ThresholdIntensity` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ThresholdIntensity.png + :alt: example of ThresholdIntensity .. autoclass:: ThresholdIntensity :members: :special-members: __call__ `ScaleIntensityRange` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensityRange.png + :alt: example of ScaleIntensityRange .. autoclass:: ScaleIntensityRange :members: :special-members: __call__ `ScaleIntensityRangePercentiles` """""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensityRangePercentiles.png + :alt: example of ScaleIntensityRangePercentiles .. autoclass:: ScaleIntensityRangePercentiles :members: :special-members: __call__ `AdjustContrast` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/AdjustContrast.png + :alt: example of AdjustContrast .. autoclass:: AdjustContrast :members: :special-members: __call__ `RandAdjustContrast` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAdjustContrast.png + :alt: example of RandAdjustContrast .. autoclass:: RandAdjustContrast :members: :special-members: __call__ `MaskIntensity` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/MaskIntensity.png + :alt: example of MaskIntensity .. autoclass:: MaskIntensity :members: :special-members: __call__ `SavitzkyGolaySmooth` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SavitzkyGolaySmooth.png + :alt: example of SavitzkyGolaySmooth .. autoclass:: SavitzkyGolaySmooth :members: :special-members: __call__ `GaussianSmooth` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSmooth.png + :alt: example of GaussianSmooth .. autoclass:: GaussianSmooth :members: :special-members: __call__ `RandGaussianSmooth` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianSmooth.png + :alt: example of RandGaussianSmooth .. autoclass:: RandGaussianSmooth :members: :special-members: __call__ `GaussianSharpen` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSharpen.png + :alt: example of GaussianSharpen .. autoclass:: GaussianSharpen :members: :special-members: __call__ `RandGaussianSharpen` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianSharpen.png + :alt: example of RandGaussianSharpen .. autoclass:: RandGaussianSharpen :members: :special-members: __call__ `RandHistogramShift` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandHistogramShift.png + :alt: example of RandHistogramShift .. autoclass:: RandHistogramShift :members: :special-members: __call__ @@ -286,43 +378,71 @@ Intensity `GibbsNoise` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GibbsNoise.png + :alt: example of GibbsNoise .. autoclass:: GibbsNoise :members: :special-members: __call__ `RandGibbsNoise` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGibbsNoise.png + :alt: example of RandGibbsNoise .. autoclass:: RandGibbsNoise :members: :special-members: __call__ `KSpaceSpikeNoise` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/KSpaceSpikeNoise.png + :alt: example of KSpaceSpikeNoise .. autoclass:: KSpaceSpikeNoise :members: :special-members: __call__ `RandKSpaceSpikeNoise` """""""""""""""""""""" - .. autoclass:: RandKSpaceSpikeNoise - :members: - :special-members: __call__ +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandKSpaceSpikeNoise.png + :alt: example of RandKSpaceSpikeNoise +.. autoclass:: RandKSpaceSpikeNoise + :members: + :special-members: __call__ + +`RandRicianNoise` +""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRicianNoise.png + :alt: example of RandRicianNoise +.. autoclass:: RandRicianNoise + :members: + :special-members: __call__ + +`RandCoarseTransform` +""""""""""""""""""""" +.. autoclass:: RandCoarseTransform + :members: + :special-members: __call__ `RandCoarseDropout` """"""""""""""""""" - .. autoclass:: RandCoarseDropout - :members: - :special-members: __call__ +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCoarseDropout.png + :alt: example of RandCoarseDropout +.. autoclass:: RandCoarseDropout + :members: + :special-members: __call__ + +`RandCoarseShuffle` +""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCoarseShuffle.png + :alt: example of RandCoarseShuffle +.. autoclass:: RandCoarseShuffle + :members: + :special-members: __call__ `HistogramNormalize` """""""""""""""""""" - .. autoclass:: HistogramNormalize - :members: - :special-members: __call__ - -`LocalPatchShuffling` -""""""""""""""""""""" -.. autoclass:: LocalPatchShuffling +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/HistogramNormalize.png + :alt: example of HistogramNormalize +.. autoclass:: HistogramNormalize :members: :special-members: __call__ @@ -381,18 +501,24 @@ Post-processing `AsDiscrete` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/AsDiscrete.png + :alt: example of AsDiscrete .. autoclass:: AsDiscrete :members: :special-members: __call__ `KeepLargestConnectedComponent` """"""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/KeepLargestConnectedComponent.png + :alt: example of KeepLargestConnectedComponent .. autoclass:: KeepLargestConnectedComponent :members: :special-members: __call__ `LabelFilter` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/LabelFilter.png + :alt: example of LabelFilter .. autoclass:: LabelFilter :members: :special-members: __call__ @@ -405,6 +531,8 @@ Post-processing `LabelToContour` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/LabelToContour.png + :alt: example of LabelToContour .. autoclass:: LabelToContour :members: :special-members: __call__ @@ -415,8 +543,8 @@ Post-processing :members: :special-members: __call__ -`Prob NMS` -"""""""""" +`ProbNMS` +""""""""" .. autoclass:: ProbNMS :members: @@ -429,44 +557,70 @@ Post-processing Spatial ^^^^^^^ +`SpatialResample` +""""""""""""""""" +.. autoclass:: SpatialResample + :members: + :special-members: __call__ + +`ResampleToMatch` +""""""""""""""""" +.. autoclass:: ResampleToMatch + :members: + :special-members: __call__ + `Spacing` """"""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Spacing.png + :alt: example of Spacing .. autoclass:: Spacing :members: :special-members: __call__ `Orientation` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Orientation.png + :alt: example of Orientation .. autoclass:: Orientation :members: :special-members: __call__ `RandRotate` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate.png + :alt: example of RandRotate .. autoclass:: RandRotate :members: :special-members: __call__ `RandFlip` """""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandFlip.png + :alt: example of RandFlip .. autoclass:: RandFlip :members: :special-members: __call__ `RandAxisFlip` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAxisFlip.png + :alt: example of RandAxisFlip .. autoclass:: RandAxisFlip :members: :special-members: __call__ `RandZoom` """""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandZoom.png + :alt: example of RandZoom .. autoclass:: RandZoom :members: :special-members: __call__ `Affine` """""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Affine.png + :alt: example of Affine .. autoclass:: Affine :members: :special-members: __call__ @@ -479,6 +633,8 @@ Spatial `RandAffine` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAffine.png + :alt: example of RandAffine .. autoclass:: RandAffine :members: :special-members: __call__ @@ -501,57 +657,110 @@ Spatial :members: :special-members: __call__ +`GridDistortion` +"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GridDistortion.png + :alt: example of GridDistortion +.. autoclass:: GridDistortion + :members: + :special-members: __call__ + +`RandGridDistortion` +"""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGridDistortion.png + :alt: example of RandGridDistortion +.. autoclass:: RandGridDistortion + :members: + :special-members: __call__ + `Rand2DElastic` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rand2DElastic.png + :alt: example of Rand2DElastic .. autoclass:: Rand2DElastic :members: :special-members: __call__ `Rand3DElastic` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rand3DElastic.png + :alt: example of Rand3DElastic .. autoclass:: Rand3DElastic :members: :special-members: __call__ `Rotate90` """""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rotate90.png + :alt: example of Rotate90 .. autoclass:: Rotate90 :members: :special-members: __call__ `RandRotate90` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate90.png + :alt: example of RandRotate90 .. autoclass:: RandRotate90 :members: :special-members: __call__ `Flip` """""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Flip.png + :alt: example of Flip .. autoclass:: Flip :members: :special-members: __call__ `Resize` """""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Resize.png + :alt: example of Resize .. autoclass:: Resize :members: :special-members: __call__ `Rotate` """""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rotate.png + :alt: example of Rotate .. autoclass:: Rotate :members: :special-members: __call__ `Zoom` """""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Zoom.png + :alt: example of Zoom .. autoclass:: Zoom :members: :special-members: __call__ -`AddCoordinateChannels` -""""""""""""""""""""""" -.. autoclass:: AddCoordinateChannels +Smooth Field +^^^^^^^^^^^^ + +`RandSmoothFieldAdjustContrast` +""""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothFieldAdjustContrast.png + :alt: example of RandSmoothFieldAdjustContrast +.. autoclass:: RandSmoothFieldAdjustContrast + :members: + :special-members: __call__ + +`RandSmoothFieldAdjustIntensity` +"""""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothFieldAdjustIntensity.png + :alt: example of RandSmoothFieldAdjustIntensity +.. autoclass:: RandSmoothFieldAdjustIntensity + :members: + :special-members: __call__ + +`RandSmoothDeform` +"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothDeform.png + :alt: example of RandSmoothDeform +.. autoclass:: RandSmoothDeform :members: :special-members: __call__ @@ -624,7 +833,6 @@ Utility :members: :special-members: __call__ - `Transpose` """"""""""" .. autoclass:: Transpose @@ -649,6 +857,7 @@ Utility :members: :special-members: __call__ + `Lambda` """""""" .. autoclass:: Lambda @@ -661,6 +870,12 @@ Utility :members: :special-members: __call__ +`RemoveRepeatedChannel` +""""""""""""""""""""""" +.. autoclass:: RemoveRepeatedChannel + :members: + :special-members: __call__ + `LabelToMask` """"""""""""" .. autoclass:: LabelToMask @@ -711,16 +926,34 @@ Utility `IntensityStats` """""""""""""""" - .. autoclass:: IntensityStats - :members: - :special-members: __call__ +.. autoclass:: IntensityStats + :members: + :special-members: __call__ `ToDevice` """""""""" - .. autoclass:: ToDevice +.. autoclass:: ToDevice :members: :special-members: __call__ +`CuCIM` +""""""" +.. autoclass:: CuCIM + :members: + :special-members: __call__ + +`RandCuCIM` +""""""""""" +.. autoclass:: RandCuCIM + :members: + :special-members: __call__ + +`AddCoordinateChannels` +""""""""""""""""""""""" +.. autoclass:: AddCoordinateChannels + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -730,72 +963,96 @@ Crop and Pad (Dict) `SpatialPadd` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialPadd.png + :alt: example of SpatialPadd .. autoclass:: SpatialPadd :members: :special-members: __call__ `BorderPadd` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/BorderPadd.png + :alt: example of BorderPadd .. autoclass:: BorderPadd :members: :special-members: __call__ `DivisiblePadd` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/DivisiblePadd.png + :alt: example of DivisiblePadd .. autoclass:: DivisiblePadd :members: :special-members: __call__ `SpatialCropd` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialCropd.png + :alt: example of SpatialCropd .. autoclass:: SpatialCropd :members: :special-members: __call__ `CenterSpatialCropd` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CenterSpatialCropd.png + :alt: example of CenterSpatialCropd .. autoclass:: CenterSpatialCropd :members: :special-members: __call__ `RandSpatialCropd` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSpatialCropd.png + :alt: example of RandSpatialCropd .. autoclass:: RandSpatialCropd :members: :special-members: __call__ `RandSpatialCropSamplesd` """"""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSpatialCropSamplesd.png + :alt: example of RandSpatialCropSamplesd .. autoclass:: RandSpatialCropSamplesd :members: :special-members: __call__ `CropForegroundd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CropForegroundd.png + :alt: example of CropForegroundd .. autoclass:: CropForegroundd :members: :special-members: __call__ `RandWeightedCropd` """"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandWeightedCropd.png + :alt: example of RandWeightedCropd .. autoclass:: RandWeightedCropd :members: :special-members: __call__ `RandCropByPosNegLabeld` """""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCropByPosNegLabeld.png + :alt: example of RandCropByPosNegLabeld .. autoclass:: RandCropByPosNegLabeld :members: :special-members: __call__ `RandCropByLabelClassesd` """"""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCropByLabelClassesd.png + :alt: example of RandCropByLabelClassesd .. autoclass:: RandCropByLabelClassesd :members: :special-members: __call__ `ResizeWithPadOrCropd` """""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ResizeWithPadOrCropd.png + :alt: example of ResizeWithPadOrCropd .. autoclass:: ResizeWithPadOrCropd :members: :special-members: __call__ @@ -808,12 +1065,16 @@ Crop and Pad (Dict) `RandScaleCropd` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandScaleCropd.png + :alt: example of RandScaleCropd .. autoclass:: RandScaleCropd :members: :special-members: __call__ `CenterScaleCropd` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CenterScaleCropd.png + :alt: example of CenterScaleCropd .. autoclass:: CenterScaleCropd :members: :special-members: __call__ @@ -823,159 +1084,235 @@ Intensity (Dict) `RandGaussianNoised` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianNoised.png + :alt: example of RandGaussianNoised .. autoclass:: RandGaussianNoised :members: :special-members: __call__ `ShiftIntensityd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ShiftIntensityd.png + :alt: example of ShiftIntensityd .. autoclass:: ShiftIntensityd :members: :special-members: __call__ `RandShiftIntensityd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandShiftIntensityd.png + :alt: example of RandShiftIntensityd .. autoclass:: RandShiftIntensityd :members: :special-members: __call__ `StdShiftIntensityd` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/StdShiftIntensityd.png + :alt: example of StdShiftIntensityd .. autoclass:: StdShiftIntensityd :members: :special-members: __call__ `RandStdShiftIntensityd` """""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandStdShiftIntensityd.png + :alt: example of RandStdShiftIntensityd .. autoclass:: RandStdShiftIntensityd :members: :special-members: __call__ `RandBiasFieldd` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandBiasFieldd.png + :alt: example of RandBiasFieldd .. autoclass:: RandBiasFieldd :members: :special-members: __call__ `ScaleIntensityd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensityd.png + :alt: example of ScaleIntensityd .. autoclass:: ScaleIntensityd :members: :special-members: __call__ `RandScaleIntensityd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandScaleIntensityd.png + :alt: example of RandScaleIntensityd .. autoclass:: RandScaleIntensityd :members: :special-members: __call__ `NormalizeIntensityd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/NormalizeIntensityd.png + :alt: example of NormalizeIntensityd .. autoclass:: NormalizeIntensityd :members: :special-members: __call__ `ThresholdIntensityd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ThresholdIntensityd.png + :alt: example of ThresholdIntensityd .. autoclass:: ThresholdIntensityd :members: :special-members: __call__ `ScaleIntensityRanged` """""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensityRanged.png + :alt: example of ScaleIntensityRanged .. autoclass:: ScaleIntensityRanged :members: :special-members: __call__ `GibbsNoised` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GibbsNoised.png + :alt: example of GibbsNoised .. autoclass:: GibbsNoised :members: :special-members: __call__ `RandGibbsNoised` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGibbsNoised.png + :alt: example of RandGibbsNoised .. autoclass:: RandGibbsNoised :members: :special-members: __call__ `KSpaceSpikeNoised` """""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/KSpaceSpikeNoised.png + :alt: example of KSpaceSpikeNoised .. autoclass:: KSpaceSpikeNoised :members: :special-members: __call__ `RandKSpaceSpikeNoised` """"""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandKSpaceSpikeNoised.png + :alt: example of RandKSpaceSpikeNoised .. autoclass:: RandKSpaceSpikeNoised :members: :special-members: __call__ +`RandRicianNoised` +"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRicianNoised.png + :alt: example of RandRicianNoised +.. autoclass:: RandRicianNoised + :members: + :special-members: __call__ + `ScaleIntensityRangePercentilesd` """"""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensityRangePercentilesd.png + :alt: example of ScaleIntensityRangePercentilesd .. autoclass:: ScaleIntensityRangePercentilesd :members: :special-members: __call__ `AdjustContrastd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/AdjustContrastd.png + :alt: example of AdjustContrastd .. autoclass:: AdjustContrastd :members: :special-members: __call__ `RandAdjustContrastd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAdjustContrastd.png + :alt: example of RandAdjustContrastd .. autoclass:: RandAdjustContrastd :members: :special-members: __call__ `MaskIntensityd` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/MaskIntensityd.png + :alt: example of MaskIntensityd .. autoclass:: MaskIntensityd :members: :special-members: __call__ +`SavitzkyGolaySmoothd` +"""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SavitzkyGolaySmoothd.png + :alt: example of SavitzkyGolaySmoothd +.. autoclass:: SavitzkyGolaySmoothd + :members: + :special-members: __call__ + `GaussianSmoothd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSmoothd.png + :alt: example of GaussianSmoothd .. autoclass:: GaussianSmoothd :members: :special-members: __call__ `RandGaussianSmoothd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianSmoothd.png + :alt: example of RandGaussianSmoothd .. autoclass:: RandGaussianSmoothd :members: :special-members: __call__ `GaussianSharpend` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSharpend.png + :alt: example of GaussianSharpend .. autoclass:: GaussianSharpend :members: :special-members: __call__ `RandGaussianSharpend` """""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianSharpend.png + :alt: example of RandGaussianSharpend .. autoclass:: RandGaussianSharpend :members: :special-members: __call__ `RandHistogramShiftd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandHistogramShiftd.png + :alt: example of RandHistogramShiftd .. autoclass:: RandHistogramShiftd :members: :special-members: __call__ `RandCoarseDropoutd` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCoarseDropoutd.png + :alt: example of RandCoarseDropoutd .. autoclass:: RandCoarseDropoutd :members: :special-members: __call__ +`RandCoarseShuffled` +"""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCoarseShuffled.png + :alt: example of RandCoarseShuffled +.. autoclass:: RandCoarseShuffled + :members: + :special-members: __call__ + `HistogramNormalized` """"""""""""""""""""" - .. autoclass:: HistogramNormalized - :members: - :special-members: __call__ +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/HistogramNormalized.png + :alt: example of HistogramNormalized +.. autoclass:: HistogramNormalized + :members: + :special-members: __call__ IO (Dict) @@ -1004,18 +1341,24 @@ Post-processing (Dict) `AsDiscreted` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/AsDiscreted.png + :alt: example of AsDiscreted .. autoclass:: AsDiscreted :members: :special-members: __call__ `KeepLargestConnectedComponentd` """""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/KeepLargestConnectedComponentd.png + :alt: example of KeepLargestConnectedComponentd .. autoclass:: KeepLargestConnectedComponentd :members: :special-members: __call__ `LabelFilterd` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/LabelFilterd.png + :alt: example of LabelFilterd .. autoclass:: LabelFilterd :members: :special-members: __call__ @@ -1028,6 +1371,8 @@ Post-processing (Dict) `LabelToContourd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/LabelToContourd.png + :alt: example of LabelToContourd .. autoclass:: LabelToContourd :members: :special-members: __call__ @@ -1062,108 +1407,195 @@ Post-processing (Dict) :members: :special-members: __call__ +`ProbNMSd` +"""""""""" +.. autoclass:: ProbNMSd + :members: + :special-members: __call__ + Spatial (Dict) ^^^^^^^^^^^^^^ +`SpatialResampled` +"""""""""""""""""" +.. autoclass:: SpatialResampled + :members: + :special-members: __call__ + +`ResampleToMatchd` +"""""""""""""""""" +.. autoclass:: ResampleToMatchd + :members: + :special-members: __call__ + `Spacingd` """""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Spacingd.png + :alt: example of Spacingd .. autoclass:: Spacingd :members: :special-members: __call__ `Orientationd` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Orientationd.png + :alt: example of Orientationd .. autoclass:: Orientationd :members: :special-members: __call__ `Flipd` """"""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Flipd.png + :alt: example of Flipd .. autoclass:: Flipd :members: :special-members: __call__ `RandFlipd` """"""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandFlipd.png + :alt: example of RandFlipd .. autoclass:: RandFlipd :members: :special-members: __call__ `RandAxisFlipd` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAxisFlipd.png + :alt: example of RandAxisFlipd .. autoclass:: RandAxisFlipd :members: :special-members: __call__ `Rotated` """"""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rotated.png + :alt: example of Rotated .. autoclass:: Rotated :members: :special-members: __call__ `RandRotated` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotated.png + :alt: example of RandRotated .. autoclass:: RandRotated :members: :special-members: __call__ `Zoomd` """"""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Zoomd.png + :alt: example of Zoomd .. autoclass:: Zoomd :members: :special-members: __call__ `RandZoomd` """"""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandZoomd.png + :alt: example of RandZoomd .. autoclass:: RandZoomd :members: :special-members: __call__ `RandRotate90d` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate90d.png + :alt: example of RandRotate90d .. autoclass:: RandRotate90d :members: :special-members: __call__ `Rotate90d` """"""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rotate90d.png + :alt: example of Rotate90d .. autoclass:: Rotate90d :members: :special-members: __call__ `Resized` """"""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Resized.png + :alt: example of Resized .. autoclass:: Resized :members: :special-members: __call__ `Affined` """"""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Affined.png + :alt: example of Affined .. autoclass:: Affined :members: :special-members: __call__ `RandAffined` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAffined.png + :alt: example of RandAffined .. autoclass:: RandAffined :members: :special-members: __call__ `Rand2DElasticd` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rand2DElasticd.png + :alt: example of Rand2DElasticd .. autoclass:: Rand2DElasticd :members: :special-members: __call__ `Rand3DElasticd` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rand3DElasticd.png + :alt: example of Rand3DElasticd .. autoclass:: Rand3DElasticd :members: :special-members: __call__ -`AddCoordinateChannelsd` -"""""""""""""""""""""""" -.. autoclass:: AddCoordinateChannelsd +`GridDistortiond` +""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GridDistortiond.png + :alt: example of GridDistortiond +.. autoclass:: GridDistortiond + :members: + :special-members: __call__ + +`RandGridDistortiond` +""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGridDistortiond.png + :alt: example of RandGridDistortiond +.. autoclass:: RandGridDistortiond + :members: + :special-members: __call__ + +Smooth Field (Dict) +^^^^^^^^^^^^^^^^^^^ + +`RandSmoothFieldAdjustContrastd` +"""""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothFieldAdjustContrastd.png + :alt: example of RandSmoothFieldAdjustContrastd +.. autoclass:: RandSmoothFieldAdjustContrastd + :members: + :special-members: __call__ + +`RandSmoothFieldAdjustIntensityd` +""""""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothFieldAdjustIntensityd.png + :alt: example of RandSmoothFieldAdjustIntensityd +.. autoclass:: RandSmoothFieldAdjustIntensityd + :members: + :special-members: __call__ + +`RandSmoothDeformd` +""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothDeformd.png + :alt: example of RandSmoothDeformd +.. autoclass:: RandSmoothDeformd :members: :special-members: __call__ @@ -1230,12 +1662,24 @@ Utility (Dict) :members: :special-members: __call__ +`ToPIL` +""""""" +.. autoclass:: ToPIL + :members: + :special-members: __call__ + `ToCupyd` """"""""" .. autoclass:: ToCupyd :members: :special-members: __call__ +`ToPILd` +"""""""" +.. autoclass:: ToPILd + :members: + :special-members: __call__ + `DeleteItemsd` """""""""""""" .. autoclass:: DeleteItemsd @@ -1248,6 +1692,12 @@ Utility (Dict) :members: :special-members: __call__ +`Transposed` +"""""""""""" +.. autoclass:: Transposed + :members: + :special-members: __call__ + `SqueezeDimd` """"""""""""" .. autoclass:: SqueezeDimd @@ -1290,6 +1740,12 @@ Utility (Dict) :members: :special-members: __call__ +`RemoveRepeatedChanneld` +"""""""""""""""""""""""" +.. autoclass:: RemoveRepeatedChanneld + :members: + :special-members: __call__ + `LabelToMaskd` """""""""""""" .. autoclass:: LabelToMaskd @@ -1352,15 +1808,37 @@ Utility (Dict) `ToDeviced` """"""""""" - .. autoclass:: ToDeviced - :members: - :special-members: __call__ +.. autoclass:: ToDeviced + :members: + :special-members: __call__ + +`CuCIMd` +"""""""" +.. autoclass:: CuCIMd + :members: + :special-members: __call__ + +`RandCuCIMd` +"""""""""""" +.. autoclass:: RandCuCIMd + :members: + :special-members: __call__ +`AddCoordinateChannelsd` +"""""""""""""""""""""""" +.. autoclass:: AddCoordinateChannelsd + :members: + :special-members: __call__ Transform Adaptors ------------------ .. automodule:: monai.transforms.adaptors +`FunctionSignature` +^^^^^^^^^^^^^^^^^^^ +.. autoclass:: FunctionSignature + :members: + `adaptor` ^^^^^^^^^ .. autofunction:: monai.transforms.adaptors.adaptor @@ -1377,3 +1855,6 @@ Utilities --------- .. automodule:: monai.transforms.utils :members: + +.. automodule:: monai.transforms.utils_pytorch_numpy_unification + :members: diff --git a/docs/source/utils.rst b/docs/source/utils.rst index a9aea7932b..881519936b 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -43,7 +43,7 @@ Profiling Deprecated ---------- -.. automodule:: monai.utils.deprecated +.. automodule:: monai.utils.deprecate_utils :members: @@ -51,3 +51,28 @@ Type conversion --------------- .. automodule:: monai.utils.type_conversion :members: + +Decorators +---------- +.. automodule:: monai.utils.decorators + :members: + +Distributed Data Parallel +------------------------- +.. automodule:: monai.utils.dist + :members: + +Enums +----- +.. automodule:: monai.utils.enums + :members: + +Jupyter Utilities +----------------- +.. automodule:: monai.utils.jupyter_utils + :members: + +State Cacher +------------ +.. automodule:: monai.utils.state_cacher + :members: diff --git a/docs/source/visualize.rst b/docs/source/visualize.rst index 850fd51770..3779feec88 100644 --- a/docs/source/visualize.rst +++ b/docs/source/visualize.rst @@ -24,3 +24,8 @@ Occlusion sensitivity .. automodule:: monai.visualize.occlusion_sensitivity :members: + +Utilities +--------- +.. automodule:: monai.visualize.utils + :members: diff --git a/docs/source/whatsnew.rst b/docs/source/whatsnew.rst index daed871e14..1f27d651db 100644 --- a/docs/source/whatsnew.rst +++ b/docs/source/whatsnew.rst @@ -6,5 +6,7 @@ What's New .. toctree:: :maxdepth: 1 + whatsnew_0_8.md + whatsnew_0_7.md whatsnew_0_6.md whatsnew_0_5.md diff --git a/docs/source/whatsnew_0_6.md b/docs/source/whatsnew_0_6.md index bdc419df37..8df0503142 100644 --- a/docs/source/whatsnew_0_6.md +++ b/docs/source/whatsnew_0_6.md @@ -1,4 +1,4 @@ -# What's new in 0.6 🎉🎉 +# What's new in 0.6 - Decollating mini-batches as an essential post-processing step - Pythonic APIs to load the pretrained models from Clara Train MMARs diff --git a/docs/source/whatsnew_0_7.md b/docs/source/whatsnew_0_7.md new file mode 100644 index 0000000000..6df64948b0 --- /dev/null +++ b/docs/source/whatsnew_0_7.md @@ -0,0 +1,63 @@ +# What's new in 0.7 + +- Performance enhancements with profiling and tuning guides +- Major usability improvements in `monai.transforms` +- Reimplementing state-of-the-art Kaggle solutions +- Vision-language multimodal transformer architectures + +## Performance enhancements with profiling and tuning guides + +Model training is often a time-consuming step during deep learning development, +especially for medical imaging applications. Even with powerful hardware (e.g. +CPU/GPU with large RAM), the workflows often require careful profiling and +tuning to achieve high performance. MONAI has been focusing on performance +enhancements, and in this version, a fast model training guide is provided +to help build highly performant workflows, with a comprehensive overview of +the profiling tools and practical strategies: +https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md. + +The following figure shows the use of [Nvidia Nsightâ„ĸ Systems](https://developer.nvidia.com/nsight-systems) for system-wide +performance analysis during a performance enhancement study. +![nsight_vis](../images/nsight_comparison.png) + +With the performance profiling and enhancements, several typical use cases were studied to +improve the training efficiency. The following figure shows that fast +training using MONAI can be `200` times faster than a regular baseline ([learn +more](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb)), and it's `20` times faster than the MONAI v0.6 fast training solution. +![fast_training](../images/fast_training.png) + +## Major usability improvements in `monai.transforms` for NumPy/PyTorch inputs and backends + + MONAI starts to roll out major usability enhancements for the + `monai.transforms` module. Many transforms are now supporting both NumPy and + PyTorch, as input types and computational backends. To get the supported backends of every transform, please execute: `python monai/transforms/utils.py`. + +One benefit of these enhancements is that the users can now better leverage the +GPUs for preprocessing. By transferring the input data onto GPU using +`ToTensor` or `EnsureType`, and applying the GPU-based transforms to the data, +[the tutorial of spleen +segmentation](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb) +shows the great potential of using the flexible modules for fast and efficient +training. + +## Reimplementing state-of-the-art Kaggle solutions + +With this release, we actively evaluate and enhance the quality and flexibility +of the MONAI core modules, using the public Kaggle challenge as a testbed. [A +reimplementation](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) +of a state-of-the-art solution at [Kaggle RANZCR CLiP - Catheter and Line +Position +Challenge](https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification) +is made available in this version. + +## Vision-language multimodal transformers + +In this release, MONAI adds support for training multimodal (vision + language) +transformers that can handle both image and textual data. MONAI introduces the +`TransCheX` model which consists of vision, language, and mixed-modality +transformer layers for processing chest X-ray and their corresponding +radiological reports within a unified framework. In addition to `TransCheX`, +users have the flexibility to alter the architecture by varying the number of +vision, language and mixed-modality layers and customizing the classification +head. In addition, the model can be initialized from pre-trained BERT language +models for fine-tuning. diff --git a/docs/source/whatsnew_0_8.md b/docs/source/whatsnew_0_8.md new file mode 100644 index 0000000000..bbaf01f5de --- /dev/null +++ b/docs/source/whatsnew_0_8.md @@ -0,0 +1,56 @@ +# What's new in 0.8 🎉🎉 + +- Differentiable neural network topology search +- Multiple instance learning for digital pathology WSI analysis +- Self-supervised representation learning +- Major usability improvements in `monai.transforms` + +## Differentiable neural network topology search +MONAI integrates `DiNTS`: [Differentiable Neural Network Topology Search for 3D +Medical Image Segmentation](https://arxiv.org/abs/2103.15954). The neural +architecture search module supports flexible multi-path topology search with +high search efficiency and budgeted memory usage. + +It provides a topology guaranteed discretization algorithm and a +discretization-aware topology loss for the search stage to minimize the +discretization gap. The module is memory usage aware and is able to search 3D +networks with different GPU memory requirements. For more details, please check out the +[DiNTS tutorial](https://monai.io/research/dints.html). + +![DiNTS](../images/dints-overview.png) + +## Multiple instance learning for digital pathology WSI analysis +For [classification of digital pathology whole slide images +(WSI)](https://arxiv.org/abs/2111.01556), MONAI introduces new transforms and +network modules for multiple instance learning. These include self-attention +transformer blocks for explicitly accounting of the dependencies between instances +(image patches) during training. For more details, +please check out the [multiple instance learning tutorial](https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning). + +![multi-instance](../images/mil-patches.jpg) + +## Self-supervised representation learning +MONAI starts to explore self-supervised representation learning in this +milestone release. The Vision Transformer has been extended to learn from self-supervised +reconstruction tasks with various data augmentation and a regularized +contrastive loss. The weights of the pre-trained backbone could be used to +enhance the performance of the novel downstream deep learning tasks. + +The [tutorial](https://github.com/Project-MONAI/tutorials/tree/master/self_supervised_pretraining) +shows how to generate a good set of pre-trained weights using unlabeled data +with self-supervised tasks, then use the pre-trained weights to perform +fine-tuning on a fully supervised volumetric segmentation task using a transformer based `UNETR`. + +![self-supervised](../images/ssl_overview.png) + +## Major usability improvements in `monai.transforms` +`monai.transforms` are now more flexible and easy to use in version 0.8. +- Input type handling and backend APIs are improved to support both + NumPy and PyTorch where possible. +- Visual examples are added to the documentation to illustrate the effects of + various image processing. +- New visualization utilities are provided and enhanced for quick qualitative + assessments of the model by visualizing, for example, the volumetric image + inputs, segmentation maps, and intermediate feature maps. + The visualization tutorial is available for + [TensorBoard utility, `matshow3d` and `blend_images`](https://github.com/Project-MONAI/tutorials/blob/master/modules/transform_visualization.ipynb). diff --git a/environment-dev.yml b/environment-dev.yml new file mode 100644 index 0000000000..a361262930 --- /dev/null +++ b/environment-dev.yml @@ -0,0 +1,63 @@ +name: monai +channels: + - pytorch + - defaults + - conda-forge +dependencies: + - numpy>=1.17 + - pytorch>=1.6 + - coverage>=5.5 + - parameterized + - setuptools>=50.3.0,!=60.0.0 + - ignite==0.4.8 + - gdown>=3.6.4 + - scipy + - nibabel + - pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571 + - tensorboard + - scikit-image>=0.14.2 + - tqdm>=4.47.0 + - python-lmdb + - flake8>=3.8.1 + - flake8-bugbear + - flake8-comprehensions + - flake8-pyi + - pylint + - mccabe + - pep8-naming + - pycodestyle + - pyflakes + - isort + - types-pkg_resources + - ninja + - torchvision + - psutil + - Sphinx==3.5.3 + - recommonmark==0.6.0 + - sphinx-autodoc-typehints==1.11.1 + - sphinx_rtd_theme==0.5.2 + - pandas + - requests + - einops + - transformers + - mlflow + - tensorboardX + - pyyaml + - fire + - jsonschema + - pip + - pip: + # pip for itk as conda-forge version only up to v5.1 + - itk>=5.2 + # black currently at v19 on conda vs v21 on pip + - black + # conda mypy v. slow + - mypy>=0.790 + # OS-specific needs to be done via pip: + # https://github.com/conda/conda/issues/8089 + - pytype>=2020.6.1; platform_system != "Windows" + - openslide-python==1.1.2 + - cucim>=21.8.2; platform_system == "Linux" + - imagecodecs; platform_system == "Linux" + - tifffile; platform_system == "Linux" + - matplotlib!=3.5.0 diff --git a/monai/README.md b/monai/README.md index a224996f38..2c30531bf3 100644 --- a/monai/README.md +++ b/monai/README.md @@ -2,6 +2,8 @@ * **apps**: high level medical domain specific deep learning applications. +* **bundle**: components to build the portable self-descriptive model bundle. + * **config**: for system configuration and diagnostic output. * **csrc**: for C++/CUDA extensions. diff --git a/monai/__init__.py b/monai/__init__.py index 2c7c920162..e56a2f3444 100644 --- a/monai/__init__.py +++ b/monai/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,29 +15,31 @@ from ._version import get_versions PY_REQUIRED_MAJOR = 3 -PY_REQUIRED_MINOR = 6 +PY_REQUIRED_MINOR = 7 version_dict = get_versions() __version__: str = version_dict.get("version", "0+unknown") __revision_id__: str = version_dict.get("full-revisionid") del get_versions, version_dict -__copyright__ = "(c) 2020 - 2021 MONAI Consortium" +__copyright__ = "(c) MONAI Consortium" __basedir__ = os.path.dirname(__file__) -if not (sys.version_info.major == PY_REQUIRED_MAJOR and sys.version_info.minor >= PY_REQUIRED_MINOR): - raise RuntimeError( - "MONAI requires Python {}.{} or higher. But the current Python is: {}".format( - PY_REQUIRED_MAJOR, PY_REQUIRED_MINOR, sys.version - ), +if sys.version_info.major != PY_REQUIRED_MAJOR or sys.version_info.minor < PY_REQUIRED_MINOR: + import warnings + + warnings.warn( + f"MONAI requires Python {PY_REQUIRED_MAJOR}.{PY_REQUIRED_MINOR} or higher. " + f"But the current Python is: {sys.version}", + category=RuntimeWarning, ) from .utils.module import load_submodules # noqa: E402 # handlers_* have some external decorators the users may not have installed # *.so files and folder "_C" may not exist when the cpp extensions are not compiled -excludes = "(^(monai.handlers))|((\\.so)$)|(^(monai._C))" +excludes = "(^(monai.handlers))|(^(monai.bundle))|((\\.so)$)|(^(monai._C))" # load directory modules only, skip loading individual files load_submodules(sys.modules[__name__], False, exclude_pattern=excludes) @@ -47,6 +49,7 @@ __all__ = [ "apps", + "bundle", "config", "data", "engines", diff --git a/monai/_extensions/__init__.py b/monai/_extensions/__init__.py index 3718894b7c..fd32d71840 100644 --- a/monai/_extensions/__init__.py +++ b/monai/_extensions/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/_extensions/gmm/gmm.cpp b/monai/_extensions/gmm/gmm.cpp index ecb85e252a..686fddb721 100644 --- a/monai/_extensions/gmm/gmm.cpp +++ b/monai/_extensions/gmm/gmm.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/_extensions/gmm/gmm.h b/monai/_extensions/gmm/gmm.h index 9a43351eb9..6317baa41a 100644 --- a/monai/_extensions/gmm/gmm.h +++ b/monai/_extensions/gmm/gmm.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/_extensions/gmm/gmm_cpu.cpp b/monai/_extensions/gmm/gmm_cpu.cpp index 144e66806c..c9b55490eb 100644 --- a/monai/_extensions/gmm/gmm_cpu.cpp +++ b/monai/_extensions/gmm/gmm_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -22,5 +22,5 @@ void learn_cpu(const float* input, const int* labels, float* gmm, float* scratch void apply_cpu(const float* gmm, const float* input, float* output, unsigned int batch_count, unsigned int element_count) { - throw std::invalid_argument("GMM recieved a cpu tensor but is not yet implemented for the cpu"); + throw std::invalid_argument("GMM received a cpu tensor but is not yet implemented for the cpu"); } diff --git a/monai/_extensions/gmm/gmm_cuda.cu b/monai/_extensions/gmm/gmm_cuda.cu index 36af48b06c..765ffe5b1c 100644 --- a/monai/_extensions/gmm/gmm_cuda.cu +++ b/monai/_extensions/gmm/gmm_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/_extensions/gmm/gmm_cuda_linalg.cuh b/monai/_extensions/gmm/gmm_cuda_linalg.cuh index 49e68c8442..9d54d80d3b 100644 --- a/monai/_extensions/gmm/gmm_cuda_linalg.cuh +++ b/monai/_extensions/gmm/gmm_cuda_linalg.cuh @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/_extensions/loader.py b/monai/_extensions/loader.py index 5f77480ecc..2b57302fb0 100644 --- a/monai/_extensions/loader.py +++ b/monai/_extensions/loader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,7 +34,7 @@ def timeout(time, message): except KeyboardInterrupt as e: if timer is not None and timer.is_alive(): raise e # interrupt from user? - raise TimeoutError(message) + raise TimeoutError(message) from e finally: if timer is not None: try: @@ -84,11 +84,7 @@ def load_module( # This will either run the build or return the existing .so object. name = module_name + platform_str.replace(".", "_") module = load( - name=name, - sources=source, - extra_cflags=define_args, - extra_cuda_cflags=define_args, - verbose=verbose_build, + name=name, sources=source, extra_cflags=define_args, extra_cuda_cflags=define_args, verbose=verbose_build ) return module diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index ef4352cabd..893f7877d2 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,4 @@ from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar -from .utils import check_hash, download_and_extract, download_url, extractall +from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index c766914026..1bfb97abd9 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,13 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import sys +from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Union import numpy as np from monai.apps.utils import download_and_extract +from monai.config.type_definitions import PathLike from monai.data import ( CacheDataset, load_decathlon_datalist, @@ -49,7 +50,16 @@ class MedNISTDataset(Randomizable, CacheDataset): cache_rate: percentage of cached data in total, default is 1.0 (cache all). will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. - if 0 a single thread will be used. Default is 0. + If num_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is speficied, 1 will be used instead. + progress: whether to display a progress bar when downloading dataset and computing the transform cache content. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. + as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. + it may help improve the performance of following logic. Raises: ValueError: When ``root_dir`` is not a directory. @@ -57,14 +67,14 @@ class MedNISTDataset(Randomizable, CacheDataset): """ - resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE" + resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz" md5 = "0bc7306e7427e00ad1c5526a6677552d" compressed_file_name = "MedNIST.tar.gz" dataset_folder_name = "MedNIST" def __init__( self, - root_dir: str, + root_dir: PathLike, section: str, transform: Union[Sequence[Callable], Callable] = (), download: bool = False, @@ -73,21 +83,32 @@ def __init__( test_frac: float = 0.1, cache_num: int = sys.maxsize, cache_rate: float = 1.0, - num_workers: int = 0, + num_workers: Optional[int] = 1, + progress: bool = True, + copy_cache: bool = True, + as_contiguous: bool = True, ) -> None: - if not os.path.isdir(root_dir): + root_dir = Path(root_dir) + if not root_dir.is_dir(): raise ValueError("Root directory root_dir must be a directory.") self.section = section self.val_frac = val_frac self.test_frac = test_frac self.set_random_state(seed=seed) - tarfile_name = os.path.join(root_dir, self.compressed_file_name) - dataset_dir = os.path.join(root_dir, self.dataset_folder_name) + tarfile_name = root_dir / self.compressed_file_name + dataset_dir = root_dir / self.dataset_folder_name self.num_class = 0 if download: - download_and_extract(self.resource, tarfile_name, root_dir, self.md5) + download_and_extract( + url=self.resource, + filepath=tarfile_name, + output_dir=root_dir, + hash_val=self.md5, + hash_type="md5", + progress=progress, + ) - if not os.path.exists(dataset_dir): + if not dataset_dir.is_dir(): raise RuntimeError( f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it." ) @@ -95,31 +116,34 @@ def __init__( if transform == (): transform = LoadImaged("image") CacheDataset.__init__( - self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers + self, + data=data, + transform=transform, + cache_num=cache_num, + cache_rate=cache_rate, + num_workers=num_workers, + progress=progress, + copy_cache=copy_cache, + as_contiguous=as_contiguous, ) - def randomize(self, data: List[int]) -> None: + def randomize(self, data: np.ndarray) -> None: self.R.shuffle(data) def get_num_classes(self) -> int: """Get number of classes.""" return self.num_class - def _generate_data_list(self, dataset_dir: str) -> List[Dict]: + def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: """ Raises: ValueError: When ``section`` is not one of ["training", "validation", "test"]. """ - class_names = sorted((x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x)))) + dataset_dir = Path(dataset_dir) + class_names = sorted(f"{x.name}" for x in dataset_dir.iterdir() if x.is_dir()) # folder name as the class name self.num_class = len(class_names) - image_files = [ - [ - os.path.join(dataset_dir, class_names[i], x) - for x in os.listdir(os.path.join(dataset_dir, class_names[i])) - ] - for i in range(self.num_class) - ] + image_files = [[f"{x}" for x in (dataset_dir / class_names[i]).iterdir()] for i in range(self.num_class)] num_each = [len(image_files[i]) for i in range(self.num_class)] image_files_list = [] image_class = [] @@ -145,13 +169,9 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: raise ValueError( f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].' ) - + # the types of label and class name should be compatible with the pytorch dataloader return [ - { - "image": image_files_list[i], - "label": image_class[i], - "class_name": class_name[i], - } + {"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]} for i in section_indices ] @@ -183,7 +203,16 @@ class DecathlonDataset(Randomizable, CacheDataset): cache_rate: percentage of cached data in total, default is 1.0 (cache all). will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. - if 0 a single thread will be used. Default is 0. + If num_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is speficied, 1 will be used instead. + progress: whether to display a progress bar when downloading dataset and computing the transform cache content. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. + as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. + it may help improve the performance of following logic. Raises: ValueError: When ``root_dir`` is not a directory. @@ -238,7 +267,7 @@ class DecathlonDataset(Randomizable, CacheDataset): def __init__( self, - root_dir: str, + root_dir: PathLike, task: str, section: str, transform: Union[Sequence[Callable], Callable] = (), @@ -247,21 +276,32 @@ def __init__( val_frac: float = 0.2, cache_num: int = sys.maxsize, cache_rate: float = 1.0, - num_workers: int = 0, + num_workers: int = 1, + progress: bool = True, + copy_cache: bool = True, + as_contiguous: bool = True, ) -> None: - if not os.path.isdir(root_dir): + root_dir = Path(root_dir) + if not root_dir.is_dir(): raise ValueError("Root directory root_dir must be a directory.") self.section = section self.val_frac = val_frac self.set_random_state(seed=seed) if task not in self.resource: raise ValueError(f"Unsupported task: {task}, available options are: {list(self.resource.keys())}.") - dataset_dir = os.path.join(root_dir, task) + dataset_dir = root_dir / task tarfile_name = f"{dataset_dir}.tar" if download: - download_and_extract(self.resource[task], tarfile_name, root_dir, self.md5[task]) + download_and_extract( + url=self.resource[task], + filepath=tarfile_name, + output_dir=root_dir, + hash_val=self.md5[task], + hash_type="md5", + progress=progress, + ) - if not os.path.exists(dataset_dir): + if not dataset_dir.exists(): raise RuntimeError( f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it." ) @@ -279,11 +319,19 @@ def __init__( "numTraining", "numTest", ] - self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys) + self._properties = load_decathlon_properties(dataset_dir / "dataset.json", property_keys) if transform == (): transform = LoadImaged(["image", "label"]) CacheDataset.__init__( - self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers + self, + data=data, + transform=transform, + cache_num=cache_num, + cache_rate=cache_rate, + num_workers=num_workers, + progress=progress, + copy_cache=copy_cache, + as_contiguous=as_contiguous, ) def get_indices(self) -> np.ndarray: @@ -293,7 +341,7 @@ def get_indices(self) -> np.ndarray: """ return self.indices - def randomize(self, data: List[int]) -> None: + def randomize(self, data: np.ndarray) -> None: self.R.shuffle(data) def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None): @@ -308,9 +356,11 @@ def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None): return {key: self._properties[key] for key in ensure_tuple(keys)} return {} - def _generate_data_list(self, dataset_dir: str) -> List[Dict]: + def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: + # the types of the item in data list should be compatible with the dataloader + dataset_dir = Path(dataset_dir) section = "training" if self.section in ["training", "validation"] else "test" - datalist = load_decathlon_datalist(os.path.join(dataset_dir, "dataset.json"), True, section) + datalist = load_decathlon_datalist(dataset_dir / "dataset.json", True, section) return self._split_datalist(datalist) def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: @@ -349,14 +399,15 @@ class CrossValidation: root_dir="./", task="Task09_Spleen", section="training", + transform=train_transform, download=True, ) dataset_fold0_train = cvdataset.get_dataset(folds=[1, 2, 3, 4]) - dataset_fold0_val = cvdataset.get_dataset(folds=0) + dataset_fold0_val = cvdataset.get_dataset(folds=0, transform=val_transform, download=False) # execute training for fold 0 ... - dataset_fold1_train = cvdataset.get_dataset(folds=[1]) - dataset_fold1_val = cvdataset.get_dataset(folds=[0, 2, 3, 4]) + dataset_fold1_train = cvdataset.get_dataset(folds=[0, 2, 3, 4]) + dataset_fold1_val = cvdataset.get_dataset(folds=1, transform=val_transform, download=False) # execute training for fold 1 ... ... @@ -366,13 +417,7 @@ class CrossValidation: """ - def __init__( - self, - dataset_cls, - nfolds: int = 5, - seed: int = 0, - **dataset_params, - ) -> None: + def __init__(self, dataset_cls, nfolds: int = 5, seed: int = 0, **dataset_params) -> None: if not hasattr(dataset_cls, "_split_datalist"): raise ValueError("dataset class must have _split_datalist API.") self.dataset_cls = dataset_cls @@ -380,20 +425,24 @@ def __init__( self.seed = seed self.dataset_params = dataset_params - def get_dataset(self, folds: Union[Sequence[int], int]): + def get_dataset(self, folds: Union[Sequence[int], int], **dataset_params): """ Generate dataset based on the specified fold indice in the cross validation group. Args: folds: index of folds for training or validation, if a list of values, concatenate the data. + dataset_params: other additional parameters for the dataset_cls base class, will override + the same parameters in `self.dataset_params`. """ nfolds = self.nfolds seed = self.seed + dataset_params_ = dict(self.dataset_params) + dataset_params_.update(dataset_params) class _NsplitsDataset(self.dataset_cls): # type: ignore def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: data = partition_dataset(data=datalist, num_partitions=nfolds, shuffle=True, seed=seed) return select_cross_validation_folds(partitions=data, folds=folds) - return _NsplitsDataset(**self.dataset_params) + return _NsplitsDataset(**dataset_params_) diff --git a/monai/apps/deepedit/__init__.py b/monai/apps/deepedit/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/apps/deepedit/__init__.py +++ b/monai/apps/deepedit/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py index 845e7bd1d0..0fcf4c2286 100644 --- a/monai/apps/deepedit/transforms.py +++ b/monai/apps/deepedit/transforms.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import logging from typing import Dict, Hashable, Mapping, Tuple @@ -7,6 +18,7 @@ from monai.config import KeysCollection from monai.transforms.transform import MapTransform, Randomizable, Transform from monai.utils import optional_import +from monai.utils.enums import PostFix logger = logging.getLogger(__name__) @@ -15,12 +27,7 @@ class DiscardAddGuidanced(MapTransform): - def __init__( - self, - keys: KeysCollection, - probability: float = 1.0, - allow_missing_keys: bool = False, - ): + def __init__(self, keys: KeysCollection, probability: float = 1.0, allow_missing_keys: bool = False): """ Discard positive and negative points randomly or Add the two channels for inference time @@ -54,11 +61,7 @@ class ResizeGuidanceCustomd(Transform): Resize the guidance based on cropped vs resized image. """ - def __init__( - self, - guidance: str, - ref_image: str, - ) -> None: + def __init__(self, guidance: str, ref_image: str) -> None: self.guidance = guidance self.ref_image = ref_image @@ -66,11 +69,11 @@ def __call__(self, data): d = dict(data) current_shape = d[self.ref_image].shape[1:] - factor = np.divide(current_shape, d["image_meta_dict"]["dim"][1:4]) + factor = np.divide(current_shape, d[PostFix.meta("image")]["dim"][1:4]) pos_clicks, neg_clicks = d["foreground"], d["background"] - pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else [] - neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else [] + pos = np.multiply(pos_clicks, factor).astype(int, copy=False).tolist() if len(pos_clicks) else [] + neg = np.multiply(neg_clicks, factor).astype(int, copy=False).tolist() if len(neg_clicks) else [] d[self.guidance] = [pos, neg] return d @@ -156,7 +159,7 @@ def _apply(self, guidance, discrepancy): guidance[0].append([-1] * len(neg)) guidance[1].append(neg) - return json.dumps(np.asarray(guidance).astype(int).tolist()) + return json.dumps(np.asarray(guidance, dtype=int).tolist()) def __call__(self, data): d = dict(data) diff --git a/monai/apps/deepgrow/__init__.py b/monai/apps/deepgrow/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/apps/deepgrow/__init__.py +++ b/monai/apps/deepgrow/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index acaeba0bc3..721781196b 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -97,7 +97,7 @@ def create_dataset( image = os.path.abspath(image) label = os.path.abspath(label) if label else None - logging.info("Image: {}; Label: {}".format(image, label if label else None)) + logging.info(f"Image: {image}; Label: {label if label else None}") data = transforms({image_key: image, label_key: label}) if dimension == 2: data = _save_data_2d( @@ -126,8 +126,8 @@ def _default_transforms(image_key, label_key, pixdim): [ LoadImaged(keys=keys), AsChannelFirstd(keys=keys), - Spacingd(keys=keys, pixdim=pixdim, mode=mode), Orientationd(keys=keys, axcodes="RAS"), + Spacingd(keys=keys, pixdim=pixdim, mode=mode), ] ) @@ -154,7 +154,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): if vol_label is not None and np.sum(label) == 0: continue - image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid) + image_file_prefix = f"vol_idx_{vol_idx:0>4d}_slice_{sid:0>3d}" image_file = os.path.join(dataset_dir, "images", image_file_prefix) image_file += ".npy" @@ -165,9 +165,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): # Test Data if vol_label is None: data_list.append( - { - "image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file, - } + {"image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file} ) continue @@ -177,7 +175,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): unique_labels_count = max(unique_labels_count, len(unique_labels)) for idx in unique_labels: - label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx)) + label_file_prefix = f"{image_file_prefix}_region_{int(idx):0>2d}" label_file = os.path.join(dataset_dir, "labels", label_file_prefix) label_file += ".npy" @@ -226,7 +224,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): label_count = 0 unique_labels_count = 0 - image_file_prefix = "vol_idx_{:0>4d}".format(vol_idx) + image_file_prefix = f"vol_idx_{vol_idx:0>4d}" image_file = os.path.join(dataset_dir, "images", image_file_prefix) image_file += ".npy" @@ -236,11 +234,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): # Test Data if vol_label is None: - data_list.append( - { - "image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file, - } - ) + data_list.append({"image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file}) else: # For all Labels unique_labels = np.unique(vol_label.flatten()) @@ -248,7 +242,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): unique_labels_count = max(unique_labels_count, len(unique_labels)) for idx in unique_labels: - label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx)) + label_file_prefix = f"{image_file_prefix}_region_{int(idx):0>2d}" label_file = os.path.join(dataset_dir, "labels", label_file_prefix) label_file += ".npy" diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 81e82c958d..73bc8e7e0b 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,6 +22,7 @@ class Interaction: """ Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. + For more details please refer to: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. This implementation is based on: Sakinis et al., Interactive segmentation of medical images through diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index db450792b0..6da614f46c 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from typing import Callable, Dict, Optional, Sequence, Union +from typing import Callable, Dict, Hashable, List, Optional, Sequence, Union import numpy as np import torch @@ -18,12 +18,15 @@ from monai.networks.layers import GaussianFilter from monai.transforms import Resize, SpatialCrop from monai.transforms.transform import MapTransform, Randomizable, Transform -from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils import InterpolateMode, ensure_tuple, ensure_tuple_rep, min_version, optional_import +from monai.transforms.utils import generate_spatial_bounding_box, is_positive +from monai.utils import InterpolateMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, min_version, optional_import +from monai.utils.enums import PostFix measure, _ = optional_import("skimage.measure", "0.14.2", min_version) distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +DEFAULT_POST_FIX = PostFix.meta() + # Transforms to support Training for Deepgrow models class FindAllValidSlicesd(Transform): @@ -51,10 +54,10 @@ def __call__(self, data): d: Dict = dict(data) label = d[self.label] if label.shape[0] != 1: - raise ValueError("Only supports single channel labels!") + raise ValueError(f"Only supports single channel labels, got label shape {label.shape}!") if len(label.shape) != 4: # only for 3D - raise ValueError("Only supports label with shape CDHW!") + raise ValueError(f"Only supports label with shape CDHW, got label shape {label.shape}!") sids = self._apply(label) if sids is not None and len(sids): @@ -145,7 +148,7 @@ def _apply(self, label, sid): def __call__(self, data): d = dict(data) self.randomize(data) - d[self.guidance] = json.dumps(self._apply(d[self.label], self.sid).astype(int).tolist()) + d[self.guidance] = json.dumps(self._apply(d[self.label], self.sid).astype(int, copy=False).tolist()) return d @@ -163,13 +166,7 @@ class AddGuidanceSignald(Transform): """ - def __init__( - self, - image: str = "image", - guidance: str = "guidance", - sigma: int = 2, - number_intensity_ch: int = 1, - ): + def __init__(self, image: str = "image", guidance: str = "guidance", sigma: int = 2, number_intensity_ch: int = 1): self.image = image self.guidance = guidance self.sigma = sigma @@ -276,12 +273,7 @@ class AddRandomGuidanced(Randomizable, Transform): """ - def __init__( - self, - guidance: str = "guidance", - discrepancy: str = "discrepancy", - probability: str = "probability", - ): + def __init__(self, guidance: str = "guidance", discrepancy: str = "discrepancy", probability: str = "probability"): self.guidance = guidance self.discrepancy = discrepancy self.probability = probability @@ -334,7 +326,7 @@ def _apply(self, guidance, discrepancy): guidance[0].append([-1] * len(neg)) guidance[1].append(neg) - return json.dumps(np.asarray(guidance).astype(int).tolist()) + return json.dumps(np.asarray(guidance, dtype=int).tolist()) def __call__(self, data): d = dict(data) @@ -377,12 +369,15 @@ class SpatialCropForegroundd(MapTransform): channel_indices: if defined, select foreground only on the specified channels of image. if None, select foreground on the whole image. margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. + allow_smaller: when computing box size with `margin`, whether allow the image size to be smaller + than box size, default to `True`. if the margined size is bigger than image size, will pad with + specified `mode`. meta_keys: explicitly indicate the key of the corresponding meta data dictionary. for example, for data with key `image`, the metadata by default is in `image_meta_dict`. the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `{key}_{meta_key_postfix}` to to fetch/store the meta data according + meta_key_postfix: if meta_keys is None, use `{key}_{meta_key_postfix}` to fetch/store the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -398,11 +393,12 @@ def __init__( keys: KeysCollection, source_key: str, spatial_size: Union[Sequence[int], np.ndarray], - select_fn: Callable = lambda x: x > 0, + select_fn: Callable = is_positive, channel_indices: Optional[IndexSelection] = None, margin: int = 0, + allow_smaller: bool = True, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix="meta_dict", + meta_key_postfix=DEFAULT_POST_FIX, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", @@ -416,6 +412,7 @@ def __init__( self.select_fn = select_fn self.channel_indices = channel_indices self.margin = margin + self.allow_smaller = allow_smaller self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) if len(self.keys) != len(self.meta_keys): raise ValueError("meta_keys should have the same length as keys.") @@ -428,11 +425,11 @@ def __init__( def __call__(self, data): d = dict(data) box_start, box_end = generate_spatial_bounding_box( - d[self.source_key], self.select_fn, self.channel_indices, self.margin + d[self.source_key], self.select_fn, self.channel_indices, self.margin, self.allow_smaller ) - center = list(np.mean([box_start, box_end], axis=0).astype(int)) - current_size = list(np.subtract(box_end, box_start).astype(int)) + center = list(np.mean([box_start, box_end], axis=0).astype(int, copy=False)) + current_size = list(np.subtract(box_end, box_start).astype(int, copy=False)) if np.all(np.less(current_size, self.spatial_size)): cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) @@ -476,18 +473,23 @@ class AddGuidanceFromPointsd(Transform): background: key that represents user background (-ve) clicks. axis: axis that represents slices in 3D volume. (axis to Depth) depth_first: if depth (slices) is positioned at first dimension. - dimensions: dimensions based on model used for deepgrow (2D vs 3D). + spatial_dims: dimensions based on model used for deepgrow (2D vs 3D). slice_key: key that represents applicable slice to add guidance. meta_keys: explicitly indicate the key of the meta data dictionary of `ref_image`. for example, for data with key `image`, the metadata by default is in `image_meta_dict`. the meta data is a dictionary object which contains: filename, original_shape, etc. if None, will try to construct meta_keys by `{ref_image}_{meta_key_postfix}`. - meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to to fetch the meta data according + meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + """ + @deprecated_arg(name="dimensions", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, ref_image, @@ -496,10 +498,11 @@ def __init__( background: str = "background", axis: int = 0, depth_first: bool = True, - dimensions: int = 2, + spatial_dims: int = 2, slice_key: str = "slice", meta_keys: Optional[str] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, + dimensions: Optional[int] = None, ): self.ref_image = ref_image self.guidance = guidance @@ -507,7 +510,7 @@ def __init__( self.background = background self.axis = axis self.depth_first = depth_first - self.dimensions = dimensions + self.dimensions = spatial_dims if dimensions is None else dimensions self.slice = slice_key self.meta_keys = meta_keys self.meta_key_postfix = meta_key_postfix @@ -533,9 +536,9 @@ def _apply(self, pos_clicks, neg_clicks, factor, slice_num): guidance = [pos, neg, slice_idx] else: if len(pos_clicks): - pos = np.multiply(pos_clicks, factor).astype(int).tolist() + pos = np.multiply(pos_clicks, factor).astype(int, copy=False).tolist() if len(neg_clicks): - neg = np.multiply(neg_clicks, factor).astype(int).tolist() + neg = np.multiply(neg_clicks, factor).astype(int, copy=False).tolist() guidance = [pos, neg] return guidance @@ -560,7 +563,7 @@ def __call__(self, data): fg_bg_clicks = [] for key in [self.foreground, self.background]: clicks = d[key] - clicks = list(np.array(clicks).astype(int)) + clicks = list(np.array(clicks, dtype=int)) if self.depth_first: for i in range(len(clicks)): clicks[i] = list(np.roll(clicks[i], 1)) @@ -591,7 +594,7 @@ class SpatialCropGuidanced(MapTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -609,7 +612,7 @@ def __init__( spatial_size, margin=20, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix="meta_dict", + meta_key_postfix=DEFAULT_POST_FIX, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", @@ -649,13 +652,17 @@ def bounding_box(self, points, img_shape): def __call__(self, data): d: Dict = dict(data) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + guidance = d[self.guidance] - original_spatial_shape = d[self.keys[0]].shape[1:] + original_spatial_shape = d[first_key].shape[1:] box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape) - center = list(np.mean([box_start, box_end], axis=0).astype(int)) + center = list(np.mean([box_start, box_end], axis=0).astype(int, copy=False)) spatial_size = self.spatial_size - box_size = list(np.subtract(box_end, box_start).astype(int)) + box_size = list(np.subtract(box_end, box_start).astype(int, copy=False)) spatial_size = spatial_size[-len(box_size) :] if len(spatial_size) < len(box_size): @@ -721,7 +728,7 @@ def __init__( guidance: str, ref_image: str, meta_keys: Optional[str] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, cropped_shape_key: str = "foreground_cropped_shape", ) -> None: self.guidance = guidance @@ -739,8 +746,8 @@ def __call__(self, data): factor = np.divide(current_shape, cropped_shape) pos_clicks, neg_clicks = guidance[0], guidance[1] - pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else [] - neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else [] + pos = np.multiply(pos_clicks, factor).astype(int, copy=False).tolist() if len(pos_clicks) else [] + neg = np.multiply(neg_clicks, factor).astype(int, copy=False).tolist() if len(neg_clicks) else [] d[self.guidance] = [pos, neg] return d @@ -778,14 +785,14 @@ class RestoreLabeld(MapTransform): One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of bool, each element corresponds to a key in ``keys``. meta_keys: explicitly indicate the key of the corresponding meta data dictionary. for example, for data with key `image`, the metadata by default is in `image_meta_dict`. the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_key is None, use `key_{meta_key_postfix} to to fetch the meta data according + meta_key_postfix: if meta_key is None, use `key_{meta_key_postfix} to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -804,7 +811,7 @@ def __init__( mode: Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] = InterpolateMode.NEAREST, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, meta_keys: Optional[str] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", @@ -895,7 +902,7 @@ class Fetch2DSliced(MapTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: use `key_{meta_key_postfix}` to to fetch the meta data according to the key data, + meta_key_postfix: use `key_{meta_key_postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -908,7 +915,7 @@ def __init__( guidance="guidance", axis: int = 0, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, ): super().__init__(keys, allow_missing_keys) diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index 396be2e87d..8f1448bb06 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index e7ff28ce44..f28ec0c65e 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,12 +19,14 @@ import json import os import warnings -from typing import Mapping, Union +from pathlib import Path +from typing import Mapping, Optional, Union import torch import monai.networks.nets as monai_nets -from monai.apps.utils import download_and_extract +from monai.apps.utils import download_and_extract, logger +from monai.config.type_definitions import PathLike from monai.utils.module import optional_import from .model_desc import MODEL_DESC @@ -40,10 +42,9 @@ def get_model_spec(idx: Union[int, str]): if isinstance(idx, str): key = idx.strip().lower() for cand in MODEL_DESC: - if str(cand[Keys.ID]).strip().lower() == key: + if str(cand.get(Keys.ID)).strip().lower() == key: return cand - print(f"Available specs are: {MODEL_DESC}.") - raise ValueError(f"Unknown MODEL_DESC request: {idx}") + return idx def _get_all_ngc_models(pattern, page_index=0, page_size=50): @@ -76,6 +77,7 @@ def _get_all_ngc_models(pattern, page_index=0, page_size=50): requests_get, has_requests = optional_import("requests", name="get") if has_requests: resp = requests_get(full_url) + resp.raise_for_status() else: raise ValueError("NGC API requires requests package. Please install it.") model_list = json.loads(resp.text) @@ -98,7 +100,9 @@ def _get_ngc_doc_url(model_name: str, model_prefix=""): return f"https://ngc.nvidia.com/catalog/models/{model_prefix}{model_name}" -def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, version: int = -1): +def download_mmar( + item, mmar_dir: Optional[PathLike] = None, progress: bool = True, api: bool = True, version: int = -1 +): """ Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train. @@ -128,22 +132,22 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, if not mmar_dir: get_dir, has_home = optional_import("torch.hub", name="get_dir") if has_home: - mmar_dir = os.path.join(get_dir(), "mmars") + mmar_dir = Path(get_dir()) / "mmars" else: raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") - + mmar_dir = Path(mmar_dir) if api: - model_dict = _get_all_ngc_models(item) + model_dict = _get_all_ngc_models(item.get(Keys.NAME, f"{item}") if isinstance(item, Mapping) else f"{item}") if len(model_dict) == 0: raise ValueError(f"api query returns no item for pattern {item}. Please change or shorten it.") model_dir_list = [] for k, v in model_dict.items(): ver = v["latest"] if version == -1 else str(version) download_url = _get_ngc_url(k, ver) - model_dir = os.path.join(mmar_dir, v["name"]) + model_dir = mmar_dir / v["name"] download_and_extract( url=download_url, - filepath=os.path.join(mmar_dir, f'{v["name"]}_{ver}.zip'), + filepath=mmar_dir / f'{v["name"]}_{ver}.zip', output_dir=model_dir, hash_val=None, hash_type="md5", @@ -152,20 +156,21 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, progress=progress, ) model_dir_list.append(model_dir) - return model_dir_list + if not model_dir_list: + raise ValueError(f"api query download no item for pattern {item}. Please change or shorten it.") + return model_dir_list[0] if not isinstance(item, Mapping): item = get_model_spec(item) - ver = item.get(Keys.VERSION, 1) if version > 0: ver = str(version) model_fullname = f"{item[Keys.NAME]}_{ver}" - model_dir = os.path.join(mmar_dir, model_fullname) + model_dir = mmar_dir / model_fullname model_url = item.get(Keys.URL) or _get_ngc_url(item[Keys.NAME], version=ver, model_prefix="nvidia/med/") download_and_extract( url=model_url, - filepath=os.path.join(mmar_dir, f"{model_fullname}.{item[Keys.FILE_TYPE]}"), + filepath=mmar_dir / f"{model_fullname}.{item[Keys.FILE_TYPE]}", output_dir=model_dir, hash_val=item[Keys.HASH_VAL], hash_type=item[Keys.HASH_TYPE], @@ -178,13 +183,15 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, def load_from_mmar( item, - mmar_dir=None, + mmar_dir: Optional[PathLike] = None, progress: bool = True, version: int = -1, map_location=None, pretrained=True, weights_only=False, model_key: str = "model", + api: bool = True, + model_file=None, ): """ Download and extract Medical Model Archive (MMAR) model weights from Nvidia Clara Train. @@ -200,6 +207,8 @@ def load_from_mmar( model_key: a key to search in the model file or config file for the model dictionary. Currently this function assumes that the model dictionary has `{"[name|path]": "test.module", "args": {'kw': 'test'}}`. + api: whether to query NGC API to get model infomation. + model_file: the relative path to the model file within an MMAR. Examples:: >>> from monai.apps import load_from_mmar @@ -209,14 +218,18 @@ def load_from_mmar( See Also: https://docs.nvidia.com/clara/ """ + if api: + item = {Keys.NAME: get_model_spec(item)[Keys.NAME] if isinstance(item, int) else f"{item}"} if not isinstance(item, Mapping): item = get_model_spec(item) - model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version) - model_file = os.path.join(model_dir, item[Keys.MODEL_FILE]) - print(f'\n*** "{item[Keys.ID]}" available at {model_dir}.') + model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version, api=api) + if model_file is None: + model_file = os.path.join("models", "model.pt") + model_file = model_dir / item.get(Keys.MODEL_FILE, model_file) + logger.info(f'\n*** "{item.get(Keys.NAME)}" available at {model_dir}.') # loading with `torch.jit.load` - if f"{model_file}".endswith(".ts"): + if model_file.name.endswith(".ts"): if not pretrained: warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.") if weights_only: @@ -232,7 +245,7 @@ def load_from_mmar( model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={}) if not model_config: # 2. search json CONFIG_FILE for model config spec. - json_path = os.path.join(model_dir, item.get(Keys.CONFIG_FILE, "config_train.json")) + json_path = model_dir / item.get(Keys.CONFIG_FILE, os.path.join("config", "config_train.json")) with open(json_path) as f: conf_dict = json.load(f) conf_dict = dict(conf_dict) @@ -264,18 +277,18 @@ def load_from_mmar( else: raise ValueError(f"Could not load model config {model_config}.") - print(f"*** Model: {model_cls}") + logger.info(f"*** Model: {model_cls}") model_kwargs = model_config.get("args", None) if model_kwargs: model_inst = model_cls(**model_kwargs) - print(f"*** Model params: {model_kwargs}") + logger.info(f"*** Model params: {model_kwargs}") else: model_inst = model_cls() if pretrained: model_inst.load_state_dict(model_dict.get(model_key, model_dict)) - print("\n---") + logger.info("\n---") doc_url = item.get(Keys.DOC) or _get_ngc_doc_url(item[Keys.NAME], model_prefix="nvidia:med:") - print(f"For more information, please visit {doc_url}\n") + logger.info(f"For more information, please visit {doc_url}\n") return model_inst diff --git a/monai/apps/mmars/model_desc.py b/monai/apps/mmars/model_desc.py index fca6f60da5..ae0e9cae30 100644 --- a/monai/apps/mmars/model_desc.py +++ b/monai/apps/mmars/model_desc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/__init__.py b/monai/apps/pathology/__init__.py index 80f32403ea..81742caf65 100644 --- a/monai/apps/pathology/__init__.py +++ b/monai/apps/pathology/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/data/__init__.py b/monai/apps/pathology/data/__init__.py index 64556b6f6e..e1b2ef7bd2 100644 --- a/monai/apps/pathology/data/__init__.py +++ b/monai/apps/pathology/data/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/data/datasets.py b/monai/apps/pathology/data/datasets.py index 3694ca4144..77e3bb34c4 100644 --- a/monai/apps/pathology/data/datasets.py +++ b/monai/apps/pathology/data/datasets.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -64,7 +64,7 @@ def __init__( self.patch_size = ensure_tuple_rep(patch_size, 2) self.image_path_list = list({x["image"] for x in self.data}) - self.image_reader_name = image_reader_name + self.image_reader_name = image_reader_name.lower() self.image_reader = WSIReader(image_reader_name) self.wsi_object_dict = None if self.image_reader_name != "openslide": @@ -119,9 +119,17 @@ class SmartCachePatchWSIDataset(SmartCacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_init_workers: the number of worker threads to initialize the cache for first epoch. If num_init_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is speficied, 1 will be used instead. num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. If num_replace_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is speficied, 1 will be used instead. progress: whether to display a progress bar when caching for the first epoch. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cache content + or every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. + as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. + it may help improve the performance of following logic. """ @@ -136,9 +144,11 @@ def __init__( replace_rate: float = 0.5, cache_num: int = sys.maxsize, cache_rate: float = 1.0, - num_init_workers: Optional[int] = None, - num_replace_workers: Optional[int] = None, + num_init_workers: Optional[int] = 1, + num_replace_workers: Optional[int] = 1, progress: bool = True, + copy_cache: bool = True, + as_contiguous: bool = True, ): patch_wsi_dataset = PatchWSIDataset( data=data, @@ -157,6 +167,8 @@ def __init__( num_replace_workers=num_replace_workers, progress=progress, shuffle=False, + copy_cache=copy_cache, + as_contiguous=as_contiguous, ) @@ -190,7 +202,7 @@ def __init__( self.patch_size = ensure_tuple_rep(patch_size, 2) # set up whole slide image reader - self.image_reader_name = image_reader_name + self.image_reader_name = image_reader_name.lower() self.image_reader = WSIReader(image_reader_name) # process data and create a list of dictionaries containing all required data and metadata @@ -293,11 +305,7 @@ def _load_a_patch(self, index): location_on_image = sample["image_locations"][patch_num] location_on_mask = sample["mask_locations"][patch_num] - image, _ = self.image_reader.get_data( - img=sample["image"], - location=location_on_image, - size=self.patch_size, - ) + image, _ = self.image_reader.get_data(img=sample["image"], location=location_on_image, size=self.patch_size) processed_sample = {"image": image, "name": sample["name"], "mask_location": location_on_mask} return processed_sample diff --git a/monai/apps/pathology/handlers/__init__.py b/monai/apps/pathology/handlers/__init__.py index 3a788ffa26..0638950bd8 100644 --- a/monai/apps/pathology/handlers/__init__.py +++ b/monai/apps/pathology/handlers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/handlers/prob_map_producer.py b/monai/apps/pathology/handlers/prob_map_producer.py index 7ac4a0e45b..0615b44b8f 100644 --- a/monai/apps/pathology/handlers/prob_map_producer.py +++ b/monai/apps/pathology/handlers/prob_map_producer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -62,9 +62,10 @@ def attach(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - self.num_images = len(engine.data_loader.dataset.data) + data_loader = engine.data_loader # type: ignore + self.num_images = len(data_loader.dataset.data) - for sample in engine.data_loader.dataset.data: + for sample in data_loader.dataset.data: name = sample["name"] self.prob_map[name] = np.zeros(sample["mask_shape"], dtype=self.dtype) self.counter[name] = len(sample["mask_locations"]) @@ -84,6 +85,8 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ + if not isinstance(engine.state.batch, dict) or not isinstance(engine.state.output, dict): + raise ValueError("engine.state.batch and engine.state.output must be dictionaries.") names = engine.state.batch["name"] locs = engine.state.batch["mask_location"] pred = engine.state.output["pred"] diff --git a/monai/apps/pathology/metrics/__init__.py b/monai/apps/pathology/metrics/__init__.py index ad62df524a..f19811dcaf 100644 --- a/monai/apps/pathology/metrics/__init__.py +++ b/monai/apps/pathology/metrics/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/metrics/lesion_froc.py b/monai/apps/pathology/metrics/lesion_froc.py index 2140de0080..6073bd0cda 100644 --- a/monai/apps/pathology/metrics/lesion_froc.py +++ b/monai/apps/pathology/metrics/lesion_froc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -78,11 +78,7 @@ def __init__( self.itc_diameter = itc_diameter self.eval_thresholds = eval_thresholds self.image_reader = WSIReader(image_reader_name) - self.nms = PathologyProbNMS( - sigma=nms_sigma, - prob_threshold=nms_prob_threshold, - box_size=nms_box_size, - ) + self.nms = PathologyProbNMS(sigma=nms_sigma, prob_threshold=nms_prob_threshold, box_size=nms_box_size) def prepare_inference_result(self, sample: Dict): """ @@ -151,12 +147,7 @@ def compute_fp_tp(self): total_tp_probs.extend(tp_probs) total_num_targets += num_targets - return ( - np.array(total_fp_probs), - np.array(total_tp_probs), - total_num_targets, - num_images, - ) + return np.array(total_fp_probs), np.array(total_tp_probs), total_num_targets, num_images def evaluate(self): """ @@ -168,17 +159,12 @@ def evaluate(self): # compute FROC curve given the evaluation of all images fps_per_image, total_sensitivity = compute_froc_curve_data( - fp_probs=fp_probs, - tp_probs=tp_probs, - num_targets=num_targets, - num_images=num_images, + fp_probs=fp_probs, tp_probs=tp_probs, num_targets=num_targets, num_images=num_images ) # compute FROC score give specific evaluation threshold froc_score = compute_froc_score( - fps_per_image=fps_per_image, - total_sensitivity=total_sensitivity, - eval_thresholds=self.eval_thresholds, + fps_per_image=fps_per_image, total_sensitivity=total_sensitivity, eval_thresholds=self.eval_thresholds ) return froc_score diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 1be96b8e34..290c0ba6a8 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,8 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .spatial.array import SplitOnGrid -from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict +from .spatial.array import SplitOnGrid, TileOnGrid +from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict from .stain.array import ExtractHEStains, NormalizeHEStains from .stain.dictionary import ( ExtractHEStainsd, diff --git a/monai/apps/pathology/transforms/spatial/__init__.py b/monai/apps/pathology/transforms/spatial/__init__.py index 07ba222ab0..eed111d2b6 100644 --- a/monai/apps/pathology/transforms/spatial/__init__.py +++ b/monai/apps/pathology/transforms/spatial/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,5 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .array import SplitOnGrid -from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict +from .array import SplitOnGrid, TileOnGrid +from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 53e0c63715..a44dce1e3f 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,13 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Sequence, Tuple, Union +import numpy as np import torch +from numpy.lib.stride_tricks import as_strided -from monai.transforms.transform import Transform +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import Randomizable, Transform +from monai.utils import convert_data_type, convert_to_dst_type +from monai.utils.enums import TransformBackends -__all__ = ["SplitOnGrid"] +__all__ = ["SplitOnGrid", "TileOnGrid"] class SplitOnGrid(Transform): @@ -24,19 +29,19 @@ class SplitOnGrid(Transform): This transform works only with torch.Tensor inputs. Args: - grid_shape: a tuple or an integer define the shape of the grid upon which to extract patches. + grid_size: a tuple or an integer define the shape of the grid upon which to extract patches. If it's an integer, the value will be repeated for each dimension. Default is 2x2 patch_size: a tuple or an integer that defines the output patch sizes. If it's an integer, the value will be repeated for each dimension. - The default is (0, 0), where the patch size will be infered from the grid shape. + The default is (0, 0), where the patch size will be inferred from the grid shape. - Note: the shape of the input image is infered based on the first image used. + Note: the shape of the input image is inferred based on the first image used. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( - self, - grid_size: Union[int, Tuple[int, int]] = (2, 2), - patch_size: Optional[Union[int, Tuple[int, int]]] = None, + self, grid_size: Union[int, Tuple[int, int]] = (2, 2), patch_size: Optional[Union[int, Tuple[int, int]]] = None ): # Grid size if isinstance(grid_size, int): @@ -50,17 +55,42 @@ def __init__( else: self.patch_size = patch_size - def __call__(self, image: torch.Tensor) -> torch.Tensor: + def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: if self.grid_size == (1, 1) and self.patch_size is None: - return torch.stack([image]) + if isinstance(image, torch.Tensor): + return torch.stack([image]) + elif isinstance(image, np.ndarray): + return np.stack([image]) # type: ignore + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + patch_size, steps = self.get_params(image.shape[1:]) - patches = ( - image.unfold(1, patch_size[0], steps[0]) - .unfold(2, patch_size[1], steps[1]) - .flatten(1, 2) - .transpose(0, 1) - .contiguous() - ) + patches: NdarrayOrTensor + if isinstance(image, torch.Tensor): + patches = ( + image.unfold(1, patch_size[0], steps[0]) + .unfold(2, patch_size[1], steps[1]) + .flatten(1, 2) + .transpose(0, 1) + .contiguous() + ) + elif isinstance(image, np.ndarray): + x_step, y_step = steps + c_stride, x_stride, y_stride = image.strides + n_channels = image.shape[0] + patches = as_strided( + image, + shape=(*self.grid_size, n_channels, patch_size[0], patch_size[1]), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + writeable=False, + ) + # flatten the first two dimensions + patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:]) + # make it a contiguous array + patches = np.ascontiguousarray(patches) + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + return patches def get_params(self, image_size): @@ -75,3 +105,158 @@ def get_params(self, image_size): ) return patch_size, steps + + +class TileOnGrid(Randomizable, Transform): + """ + Tile the 2D image into patches on a grid and maintain a subset of it. + This transform works only with np.ndarray inputs for 2D images. + + Args: + tile_count: number of tiles to extract, if None extracts all non-background tiles + Defaults to ``None``. + tile_size: size of the square tile + Defaults to ``256``. + step: step size + Defaults to ``None`` (same as tile_size) + random_offset: Randomize position of the grid, instead of starting from the top-left corner + Defaults to ``False``. + pad_full: pad image to the size evenly divisible by tile_size + Defaults to ``False``. + background_val: the background constant (e.g. 255 for white background) + Defaults to ``255``. + filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset + Defaults to ``min`` (which assumes background is high value) + + """ + + backend = [TransformBackends.NUMPY] + + def __init__( + self, + tile_count: Optional[int] = None, + tile_size: int = 256, + step: Optional[int] = None, + random_offset: bool = False, + pad_full: bool = False, + background_val: int = 255, + filter_mode: str = "min", + ): + self.tile_count = tile_count + self.tile_size = tile_size + self.random_offset = random_offset + self.pad_full = pad_full + self.background_val = background_val + self.filter_mode = filter_mode + + if step is None: + # non-overlapping grid + self.step = self.tile_size + else: + self.step = step + + self.offset = (0, 0) + self.random_idxs = np.array((0,)) + + if self.filter_mode not in ["min", "max", "random"]: + raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode)) + + def randomize(self, img_size: Sequence[int]) -> None: + + c, h, w = img_size + + self.offset = (0, 0) + if self.random_offset: + pad_h = h % self.tile_size + pad_w = w % self.tile_size + self.offset = (self.R.randint(pad_h) if pad_h > 0 else 0, self.R.randint(pad_w) if pad_w > 0 else 0) + h = h - self.offset[0] + w = w - self.offset[1] + + if self.pad_full: + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size + h = h + pad_h + w = w + pad_w + + h_n = (h - self.tile_size + self.step) // self.step + w_n = (w - self.tile_size + self.step) // self.step + tile_total = h_n * w_n + + if self.tile_count is not None and tile_total > self.tile_count: + self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False) + else: + self.random_idxs = np.array((0,)) + + def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: + img_np, *_ = convert_data_type(image, np.ndarray) + + # add random offset + self.randomize(img_size=img_np.shape) + + if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0): + img_np = img_np[:, self.offset[0] :, self.offset[1] :] + + # pad to full size, divisible by tile_size + if self.pad_full: + c, h, w = img_np.shape + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size + img_np = np.pad( # type: ignore + img_np, + [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]], + constant_values=self.background_val, + ) + + # extact tiles + x_step, y_step = self.step, self.step + h_tile, w_tile = self.tile_size, self.tile_size + c_image, h_image, w_image = img_np.shape + c_stride, x_stride, y_stride = img_np.strides + llw = as_strided( + img_np, + shape=((h_image - h_tile) // x_step + 1, (w_image - w_tile) // y_step + 1, c_image, h_tile, w_tile), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + writeable=False, + ) + img_np = llw.reshape(-1, c_image, h_tile, w_tile) # type: ignore + + # if keeping all patches + if self.tile_count is None: + # retain only patches with significant foreground content to speed up inference + # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference + thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size + if self.filter_mode == "min": + # default, keep non-background tiles (small values) + idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh) + img_np = img_np[idxs.reshape(-1)] + elif self.filter_mode == "max": + idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh) + img_np = img_np[idxs.reshape(-1)] + + else: + if len(img_np) > self.tile_count: + + if self.filter_mode == "min": + # default, keep non-background tiles (smallest values) + idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count] + img_np = img_np[idxs] + elif self.filter_mode == "max": + idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[-self.tile_count :] + img_np = img_np[idxs] + else: + # random subset (more appropriate for WSIs without distinct background) + if self.random_idxs is not None: + img_np = img_np[self.random_idxs] + + elif len(img_np) < self.tile_count: + img_np = np.pad( # type: ignore + img_np, + [[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]], + constant_values=self.background_val, + ) + + image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype) + + return image diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index 10b01a39de..d5c34a0840 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,16 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Hashable, Mapping, Optional, Tuple, Union - -import torch +import copy +from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union from monai.config import KeysCollection -from monai.transforms.transform import MapTransform +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, Randomizable -from .array import SplitOnGrid +from .array import SplitOnGrid, TileOnGrid -__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"] +__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"] class SplitOnGridd(MapTransform): @@ -27,15 +27,17 @@ class SplitOnGridd(MapTransform): This transform works only with torch.Tensor inputs. Args: - grid_shape: a tuple or an integer define the shape of the grid upon which to extract patches. + grid_size: a tuple or an integer define the shape of the grid upon which to extract patches. If it's an integer, the value will be repeated for each dimension. Default is 2x2 patch_size: a tuple or an integer that defines the output patch sizes. If it's an integer, the value will be repeated for each dimension. - The default is (0, 0), where the patch size will be infered from the grid shape. + The default is (0, 0), where the patch size will be inferred from the grid shape. - Note: the shape of the input image is infered based on the first image used. + Note: the shape of the input image is inferred based on the first image used. """ + backend = SplitOnGrid.backend + def __init__( self, keys: KeysCollection, @@ -46,11 +48,89 @@ def __init__( super().__init__(keys, allow_missing_keys) self.splitter = SplitOnGrid(grid_size=grid_size, patch_size=patch_size) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.splitter(d[key]) return d +class TileOnGridd(Randomizable, MapTransform): + """ + Tile the 2D image into patches on a grid and maintain a subset of it. + This transform works only with np.ndarray inputs for 2D images. + + Args: + tile_count: number of tiles to extract, if None extracts all non-background tiles + Defaults to ``None``. + tile_size: size of the square tile + Defaults to ``256``. + step: step size + Defaults to ``None`` (same as tile_size) + random_offset: Randomize position of the grid, instead of starting from the top-left corner + Defaults to ``False``. + pad_full: pad image to the size evenly divisible by tile_size + Defaults to ``False``. + background_val: the background constant (e.g. 255 for white background) + Defaults to ``255``. + filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset + Defaults to ``min`` (which assumes background is high value) + + """ + + backend = SplitOnGrid.backend + + def __init__( + self, + keys: KeysCollection, + tile_count: Optional[int] = None, + tile_size: int = 256, + step: Optional[int] = None, + random_offset: bool = False, + pad_full: bool = False, + background_val: int = 255, + filter_mode: str = "min", + allow_missing_keys: bool = False, + return_list_of_dicts: bool = False, + ): + super().__init__(keys, allow_missing_keys) + + self.return_list_of_dicts = return_list_of_dicts + self.seed = None + + self.splitter = TileOnGrid( + tile_count=tile_count, + tile_size=tile_size, + step=step, + random_offset=random_offset, + pad_full=pad_full, + background_val=background_val, + filter_mode=filter_mode, + ) + + def randomize(self, data: Any = None) -> None: + self.seed = self.R.randint(10000) # type: ignore + + def __call__( + self, data: Mapping[Hashable, NdarrayOrTensor] + ) -> Union[Dict[Hashable, NdarrayOrTensor], List[Dict[Hashable, NdarrayOrTensor]]]: + + self.randomize() + + d = dict(data) + for key in self.key_iterator(d): + self.splitter.set_random_state(seed=self.seed) # same random seed for all keys + d[key] = self.splitter(d[key]) + + if self.return_list_of_dicts: + d_list = [] + for i in range(len(d[self.keys[0]])): + d_list.append({k: d[k][i] if k in self.keys else copy.deepcopy(d[k]) for k in d.keys()}) + d = d_list # type: ignore + + return d + + SplitOnGridDict = SplitOnGridD = SplitOnGridd +TileOnGridDict = TileOnGridD = TileOnGridd diff --git a/monai/apps/pathology/transforms/stain/__init__.py b/monai/apps/pathology/transforms/stain/__init__.py index 824f40a579..dfa235de55 100644 --- a/monai/apps/pathology/transforms/stain/__init__.py +++ b/monai/apps/pathology/transforms/stain/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/transforms/stain/array.py b/monai/apps/pathology/transforms/stain/array.py index ccddc6b243..3b3a293451 100644 --- a/monai/apps/pathology/transforms/stain/array.py +++ b/monai/apps/pathology/transforms/stain/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -52,12 +52,12 @@ def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray: """Perform Stain Deconvolution and return stain matrix for the image. Args: - img: uint8 RGB image to perform stain deconvolution on + image: uint8 RGB image to perform stain deconvolution on Return: he: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values) """ - # check image type and vlues + # check image type and values if not isinstance(image, np.ndarray): raise TypeError("Image must be of type numpy.ndarray.") if image.min() < 0: @@ -67,7 +67,7 @@ def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray: # reshape image and calculate absorbance image = image.reshape((-1, 3)) - image = image.astype(np.float32) + 1.0 + image = image.astype(np.float32, copy=False) + 1.0 absorbance = -np.log(image.clip(max=self.tli) / self.tli) # remove transparent pixels @@ -76,7 +76,7 @@ def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray: raise ValueError("All pixels of the input image are below the absorbance threshold.") # compute eigenvectors - _, eigvecs = np.linalg.eigh(np.cov(absorbance_hat.T).astype(np.float32)) + _, eigvecs = np.linalg.eigh(np.cov(absorbance_hat.T).astype(np.float32, copy=False)) # project on the plane spanned by the eigenvectors corresponding to the two largest eigenvalues t_hat = absorbance_hat.dot(eigvecs[:, 1:3]) @@ -162,7 +162,7 @@ def __call__(self, image: np.ndarray) -> np.ndarray: Return: image_norm: stain normalized image/patch """ - # check image type and vlues + # check image type and values if not isinstance(image, np.ndarray): raise TypeError("Image must be of type numpy.ndarray.") if image.min() < 0: @@ -186,7 +186,7 @@ def __call__(self, image: np.ndarray) -> np.ndarray: conc = np.linalg.lstsq(he, y, rcond=None)[0] # normalize stain concentrations - max_conc = np.array([np.percentile(conc[0, :], 99), np.percentile(conc[1, :], 99)], dtype=np.float32) + max_conc = np.asarray([np.percentile(conc[0, :], 99), np.percentile(conc[1, :], 99)], dtype=np.float32) tmp = np.divide(max_conc, self.max_cref, dtype=np.float32) image_c = np.divide(conc, tmp[:, np.newaxis], dtype=np.float32) diff --git a/monai/apps/pathology/transforms/stain/dictionary.py b/monai/apps/pathology/transforms/stain/dictionary.py index 976af1e7c7..eb8eba43f8 100644 --- a/monai/apps/pathology/transforms/stain/dictionary.py +++ b/monai/apps/pathology/transforms/stain/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/utils.py b/monai/apps/pathology/utils.py index 54d49f5717..5a57364a11 100644 --- a/monai/apps/pathology/utils.py +++ b/monai/apps/pathology/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -62,11 +62,7 @@ class PathologyProbNMS(ProbNMS): Pathology. """ - def __call__( - self, - probs_map: Union[np.ndarray, torch.Tensor], - resolution_level: int = 0, - ): + def __call__(self, probs_map: Union[np.ndarray, torch.Tensor], resolution_level: int = 0): """ probs_map: the input probabilities map, it must have shape (H[, W, ...]). resolution_level: the level at which the probabilities map is made. diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 36fac955fe..209dc796cf 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,17 +10,22 @@ # limitations under the License. import hashlib +import logging import os import shutil +import sys import tarfile import tempfile import warnings import zipfile +from pathlib import Path from typing import TYPE_CHECKING, Optional from urllib.error import ContentTooShortError, HTTPError, URLError +from urllib.parse import urlparse from urllib.request import urlretrieve -from monai.utils import min_version, optional_import +from monai.config.type_definitions import PathLike +from monai.utils import look_up_option, min_version, optional_import gdown, has_gdown = optional_import("gdown", "3.6") @@ -31,18 +36,47 @@ else: tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") -__all__ = [ - "check_hash", - "download_url", - "extractall", - "download_and_extract", -] +__all__ = ["check_hash", "download_url", "extractall", "download_and_extract", "get_logger", "SUPPORTED_HASH_TYPES"] +DEFAULT_FMT = "%(asctime)s - %(levelname)s - %(message)s" +SUPPORTED_HASH_TYPES = {"md5": hashlib.md5, "sha1": hashlib.sha1, "sha256": hashlib.sha256, "sha512": hashlib.sha512} -def _basename(p): + +def get_logger( + module_name: str = "monai.apps", + fmt: str = DEFAULT_FMT, + datefmt: Optional[str] = None, + logger_handler: Optional[logging.Handler] = None, +): + """ + Get a `module_name` logger with the specified format and date format. + By default, the logger will print to `stdout` at the INFO level. + If `module_name` is `None`, return the root logger. + `fmt` and `datafmt` are passed to a `logging.Formatter` object + (https://docs.python.org/3/library/logging.html#formatter-objects). + `logger_handler` can be used to add an additional handler. + """ + logger = logging.getLogger(module_name) + logger.propagate = False + logger.setLevel(logging.INFO) + handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) + handler.setFormatter(formatter) + logger.addHandler(handler) + if logger_handler is not None: + logger.addHandler(logger_handler) + return logger + + +# apps module-level default logger +logger = get_logger("monai.apps") +__all__.append("logger") + + +def _basename(p: PathLike) -> str: """get the last part of the path (removing the trailing slash if it exists)""" sep = os.path.sep + (os.path.altsep or "") + "/ " - return os.path.basename(p.rstrip(sep)) + return Path(f"{p}".rstrip(sep)).name def _download_with_progress(url, filepath, progress: bool = True): @@ -69,66 +103,59 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None): self.total = tsize self.update(b * bsize - self.n) # will also set self.n = b * bsize - with TqdmUpTo( - unit="B", - unit_scale=True, - unit_divisor=1024, - miniters=1, - desc=_basename(filepath), - ) as t: + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=_basename(filepath)) as t: urlretrieve(url, filepath, reporthook=t.update_to) else: if not has_tqdm and progress: warnings.warn("tqdm is not installed, will not show the downloading progress bar.") urlretrieve(url, filepath) - except (URLError, HTTPError, ContentTooShortError, IOError) as e: - print(f"Download failed from {url} to {filepath}.") + except (URLError, HTTPError, ContentTooShortError, OSError) as e: + logger.error(f"Download failed from {url} to {filepath}.") raise e -def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") -> bool: +def check_hash(filepath: PathLike, val: Optional[str] = None, hash_type: str = "md5") -> bool: """ Verify hash signature of specified file. Args: filepath: path of source file to verify hash value. val: expected hash value of the file. - hash_type: 'md5' or 'sha1', defaults to 'md5'. + hash_type: type of hash algorithm to use, default is `"md5"`. + The supported hash types are `"md5"`, `"sha1"`, `"sha256"`, `"sha512"`. + See also: :py:data:`monai.apps.utils.SUPPORTED_HASH_TYPES`. """ if val is None: - print(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.") + logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.") return True - if hash_type.lower() == "md5": - actual_hash = hashlib.md5() - elif hash_type.lower() == "sha1": - actual_hash = hashlib.sha1() - else: - raise NotImplementedError(f"Unknown 'hash_type' {hash_type}.") + actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES) + actual_hash = actual_hash_func() try: with open(filepath, "rb") as f: for chunk in iter(lambda: f.read(1024 * 1024), b""): actual_hash.update(chunk) except Exception as e: - print(f"Exception in check_hash: {e}") + logger.error(f"Exception in check_hash: {e}") return False if val != actual_hash.hexdigest(): - print(f"check_hash failed {actual_hash.hexdigest()}.") + logger.error(f"check_hash failed {actual_hash.hexdigest()}.") return False - print(f"Verified '{_basename(filepath)}', {hash_type}: {val}.") + logger.info(f"Verified '{_basename(filepath)}', {hash_type}: {val}.") return True def download_url( - url: str, filepath: str = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True + url: str, filepath: PathLike = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True ) -> None: """ Download file from specified URL link, support process bar and hash check. Args: url: source URL link to download file. - filepath: target filepath to save the downloaded file. If undefined, `os.path.basename(url)` will be used. + filepath: target filepath to save the downloaded file (including the filename). + If undefined, `os.path.basename(url)` will be used. hash_val: expected hash value to validate the downloaded file. if None, skip hash validation. hash_type: 'md5' or 'sha1', defaults to 'md5'. @@ -146,33 +173,36 @@ def download_url( """ if not filepath: - filepath = os.path.abspath(os.path.join(".", _basename(url))) - print(f"Default downloading to '{filepath}'") - if os.path.exists(filepath): + filepath = Path(".", _basename(url)).resolve() + logger.info(f"Default downloading to '{filepath}'") + filepath = Path(filepath) + if filepath.exists(): if not check_hash(filepath, hash_val, hash_type): raise RuntimeError( f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}." ) - print(f"File exists: {filepath}, skipped downloading.") + logger.info(f"File exists: {filepath}, skipped downloading.") return - - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_name = os.path.join(tmp_dir, f"{_basename(filepath)}") - if url.startswith("https://drive.google.com"): - if not has_gdown: - raise RuntimeError("To download files from Google Drive, please install the gdown dependency.") - gdown.download(url, tmp_name, quiet=not progress) - else: - _download_with_progress(url, tmp_name, progress=progress) - if not os.path.exists(tmp_name): - raise RuntimeError( - f"Download of file from {url} to {filepath} failed due to network issue or denied permission." - ) - file_dir = os.path.dirname(filepath) - if file_dir: - os.makedirs(file_dir, exist_ok=True) - shutil.move(tmp_name, filepath) # copy the downloaded to a user-specified cache. - print(f"Downloaded: {filepath}") + try: + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_name = Path(tmp_dir, _basename(filepath)) + if urlparse(url).netloc == "drive.google.com": + if not has_gdown: + raise RuntimeError("To download files from Google Drive, please install the gdown dependency.") + gdown.download(url, f"{tmp_name}", quiet=not progress) + else: + _download_with_progress(url, tmp_name, progress=progress) + if not tmp_name.exists(): + raise RuntimeError( + f"Download of file from {url} to {filepath} failed due to network issue or denied permission." + ) + file_dir = filepath.parent + if file_dir: + os.makedirs(file_dir, exist_ok=True) + shutil.move(f"{tmp_name}", f"{filepath}") # copy the downloaded to a user-specified cache. + except (PermissionError, NotADirectoryError): # project-monai/monai issue #3613 #3757 for windows + pass + logger.info(f"Downloaded: {filepath}") if not check_hash(filepath, hash_val, hash_type): raise RuntimeError( f"{hash_type} check of downloaded file failed: URL={url}, " @@ -181,8 +211,8 @@ def download_url( def extractall( - filepath: str, - output_dir: str = ".", + filepath: PathLike, + output_dir: PathLike = ".", hash_val: Optional[str] = None, hash_type: str = "md5", file_type: str = "", @@ -211,24 +241,25 @@ def extractall( """ if has_base: # the extracted files will be in this folder - cache_dir = os.path.join(output_dir, _basename(filepath).split(".")[0]) + cache_dir = Path(output_dir, _basename(filepath).split(".")[0]) else: - cache_dir = output_dir - if os.path.exists(cache_dir) and len(os.listdir(cache_dir)) > 0: - print(f"Non-empty folder exists in {cache_dir}, skipped extracting.") + cache_dir = Path(output_dir) + if cache_dir.exists() and next(cache_dir.iterdir(), None) is not None: + logger.info(f"Non-empty folder exists in {cache_dir}, skipped extracting.") return + filepath = Path(filepath) if hash_val and not check_hash(filepath, hash_val, hash_type): raise RuntimeError( f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}." ) - print(f"Writing into directory: {output_dir}.") + logger.info(f"Writing into directory: {output_dir}.") _file_type = file_type.lower().strip() - if filepath.endswith("zip") or _file_type == "zip": + if filepath.name.endswith("zip") or _file_type == "zip": zip_file = zipfile.ZipFile(filepath) zip_file.extractall(output_dir) zip_file.close() return - if filepath.endswith("tar") or filepath.endswith("tar.gz") or "tar" in _file_type: + if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type: tar_file = tarfile.open(filepath) tar_file.extractall(output_dir) tar_file.close() @@ -240,8 +271,8 @@ def extractall( def download_and_extract( url: str, - filepath: str = "", - output_dir: str = ".", + filepath: PathLike = "", + output_dir: PathLike = ".", hash_val: Optional[str] = None, hash_type: str = "md5", file_type: str = "", @@ -268,6 +299,6 @@ def download_and_extract( progress: whether to display progress bar. """ with tempfile.TemporaryDirectory() as tmp_dir: - filename = filepath or os.path.join(tmp_dir, f"{_basename(url)}") + filename = filepath or Path(tmp_dir, _basename(url)).resolve() download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress) extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base) diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py new file mode 100644 index 0000000000..72c8805e9f --- /dev/null +++ b/monai/bundle/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable +from .config_parser import ConfigParser +from .reference_resolver import ReferenceResolver +from .scripts import run, verify_metadata, verify_net_in_out +from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py new file mode 100644 index 0000000000..0ff0a476ef --- /dev/null +++ b/monai/bundle/__main__.py @@ -0,0 +1,19 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from monai.bundle.scripts import run, verify_metadata, verify_net_in_out + +if __name__ == "__main__": + from monai.utils import optional_import + + fire, _ = optional_import("fire") + fire.Fire() diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py new file mode 100644 index 0000000000..0531c6f14e --- /dev/null +++ b/monai/bundle/config_item.py @@ -0,0 +1,385 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import inspect +import os +import sys +import warnings +from abc import ABC, abstractmethod +from importlib import import_module +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union + +from monai.bundle.utils import EXPR_KEY +from monai.utils import ensure_tuple, first, instantiate, optional_import + +__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"] + + +class Instantiable(ABC): + """ + Base class for an instantiable object. + """ + + @abstractmethod + def is_disabled(self, *args: Any, **kwargs: Any) -> bool: + """ + Return a boolean flag to indicate whether the object should be instantiated. + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def instantiate(self, *args: Any, **kwargs: Any) -> object: + """ + Instantiate the target component and return the instance. + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + +class ComponentLocator: + """ + Scan all the available classes and functions in the MONAI package and map them with the module paths in a table. + It's used to locate the module path for provided component name. + + Args: + excludes: if any string of the `excludes` exists in the full module name, don't import this module. + + """ + + MOD_START = "monai" + + def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None): + self.excludes = [] if excludes is None else ensure_tuple(excludes) + self._components_table: Optional[Dict[str, List]] = None + + def _find_module_names(self) -> List[str]: + """ + Find all the modules start with MOD_START and don't contain any of `excludes`. + + """ + return [ + m for m in sys.modules.keys() if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes) + ] + + def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dict[str, List]: + """ + Find all the classes and functions in the modules with specified `modnames`. + + Args: + modnames: names of the target modules to find all the classes and functions. + + """ + table: Dict[str, List] = {} + # all the MONAI modules are already loaded by `load_submodules` + for modname in ensure_tuple(modnames): + try: + # scan all the classes and functions in the module + module = import_module(modname) + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: + if name not in table: + table[name] = [] + table[name].append(modname) + except ModuleNotFoundError: + pass + return table + + def get_component_module_name(self, name: str) -> Optional[Union[List[str], str]]: + """ + Get the full module name of the class or function with specified ``name``. + If target component name exists in multiple packages or modules, return a list of full module names. + + Args: + name: name of the expected class or function. + + """ + if not isinstance(name, str): + raise ValueError(f"`name` must be a valid string, but got: {name}.") + if self._components_table is None: + # init component and module mapping table + self._components_table = self._find_classes_or_functions(self._find_module_names()) + + mods: Optional[Union[List[str], str]] = self._components_table.get(name, None) + if isinstance(mods, list) and len(mods) == 1: + mods = mods[0] + return mods + + +class ConfigItem: + """ + Basic data structure to represent a configuration item. + + A `ConfigItem` instance can optionally have a string id, so that other items can refer to it. + It has a build-in `config` property to store the configuration object. + + Args: + config: content of a config item, can be objects of any types, + a configuration resolver may interpret the content to generate a configuration object. + id: name of the current config item, defaults to empty string. + + """ + + def __init__(self, config: Any, id: str = "") -> None: + self.config = config + self.id = id + + def get_id(self) -> str: + """ + Get the ID name of current config item, useful to identify config items during parsing. + + """ + return self.id + + def update_config(self, config: Any): + """ + Replace the content of `self.config` with new `config`. + A typical usage is to modify the initial config content at runtime. + + Args: + config: content of a `ConfigItem`. + + """ + self.config = config + + def get_config(self): + """ + Get the config content of current config item. + + """ + return self.config + + def __repr__(self) -> str: + return str(self.config) + + +class ConfigComponent(ConfigItem, Instantiable): + """ + Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to + represent a component of `class` or `function` and supports instantiation. + + Currently, three special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals: + + - class or function identifier of the python module, specified by ``"_target_"``, + indicating a build-in python class or function such as ``"LoadImageDict"``, + or a full module name, such as ``"monai.transforms.LoadImageDict"``. + - ``"_requires_"`` (optional): specifies reference IDs (string starts with ``"@"``) or ``ConfigExpression`` + of the dependencies for this ``ConfigComponent`` object. These dependencies will be + evaluated/instantiated before this object is instantiated. It is useful when the + component doesn't explicitly depends on the other `ConfigItems` via its arguments, + but requires the dependencies to be instantiated/evaluated beforehand. + - ``"_disabled_"`` (optional): a flag to indicate whether to skip the instantiation. + + Other fields in the config content are input arguments to the python module. + + .. code-block:: python + + from monai.bundle import ComponentLocator, ConfigComponent + + locator = ComponentLocator(excludes=["modules_to_exclude"]) + config = { + "_target_": "LoadImaged", + "keys": ["image", "label"] + } + + configer = ConfigComponent(config, id="test", locator=locator) + image_loader = configer.instantiate() + print(image_loader) # + + Args: + config: content of a config item. + id: name of the current config item, defaults to empty string. + locator: a ``ComponentLocator`` to convert a module name string into the actual python module. + if `None`, a ``ComponentLocator(excludes=excludes)`` will be used. + excludes: if ``locator`` is None, create a new ``ComponentLocator`` with ``excludes``. + See also: :py:class:`monai.bundle.ComponentLocator`. + + """ + + non_arg_keys = {"_target_", "_disabled_", "_requires_"} + + def __init__( + self, + config: Any, + id: str = "", + locator: Optional[ComponentLocator] = None, + excludes: Optional[Union[Sequence[str], str]] = None, + ) -> None: + super().__init__(config=config, id=id) + self.locator = ComponentLocator(excludes=excludes) if locator is None else locator + + @staticmethod + def is_instantiable(config: Any) -> bool: + """ + Check whether this config represents a `class` or `function` that is to be instantiated. + + Args: + config: input config content to check. + + """ + return isinstance(config, Mapping) and "_target_" in config + + def resolve_module_name(self): + """ + Resolve the target module name from current config content. + The config content must have ``"_target_"`` key. + + """ + config = dict(self.get_config()) + target = config.get("_target_") + if not isinstance(target, str): + raise ValueError("must provide a string for the `_target_` of component to instantiate.") + + module = self.locator.get_component_module_name(target) + if module is None: + # target is the full module name, no need to parse + return target + + if isinstance(module, list): + warnings.warn( + f"there are more than 1 component have name `{target}`: {module}, use the first one `{module[0]}." + f" if want to use others, please set its full module path in `_target_` directly." + ) + module = module[0] + return f"{module}.{target}" + + def resolve_args(self): + """ + Utility function used in `instantiate()` to resolve the arguments from current config content. + + """ + return {k: v for k, v in self.get_config().items() if k not in self.non_arg_keys} + + def is_disabled(self) -> bool: # type: ignore + """ + Utility function used in `instantiate()` to check whether to skip the instantiation. + + """ + _is_disabled = self.get_config().get("_disabled_", False) + return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled) + + def instantiate(self, **kwargs) -> object: # type: ignore + """ + Instantiate component based on ``self.config`` content. + The target component must be a `class` or a `function`, otherwise, return `None`. + + Args: + kwargs: args to override / add the config args when instantiation. + + """ + if not self.is_instantiable(self.get_config()) or self.is_disabled(): + # if not a class or function or marked as `disabled`, skip parsing and return `None` + return None + + modname = self.resolve_module_name() + args = self.resolve_args() + args.update(kwargs) + return instantiate(modname, **args) + + +class ConfigExpression(ConfigItem): + """ + Subclass of :py:class:`monai.bundle.ConfigItem`, the `ConfigItem` represents an executable expression + (execute based on ``eval()``, or import the module to the `globals` if it's an import statement). + + See also: + + - https://docs.python.org/3/library/functions.html#eval. + + For example: + + .. code-block:: python + + import monai + from monai.bundle import ConfigExpression + + config = "$monai.__version__" + expression = ConfigExpression(config, id="test", globals={"monai": monai}) + print(expression.execute()) + + Args: + config: content of a config item. + id: name of current config item, defaults to empty string. + globals: additional global context to evaluate the string. + + """ + + prefix = EXPR_KEY + run_eval = False if os.environ.get("MONAI_EVAL_EXPR", "1") == "0" else True + + def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None: + super().__init__(config=config, id=id) + self.globals = globals if globals is not None else {} + + def _parse_import_string(self, import_string: str): + """parse single import statement such as "from monai.transforms import Resize""" + node = first(ast.iter_child_nodes(ast.parse(import_string))) + if not isinstance(node, (ast.Import, ast.ImportFrom)): + return None + if len(node.names) < 1: + return None + if len(node.names) > 1: + warnings.warn(f"ignoring multiple import alias '{import_string}'.") + name, asname = f"{node.names[0].name}", node.names[0].asname + asname = name if asname is None else f"{asname}" + if isinstance(node, ast.ImportFrom): + self.globals[asname], _ = optional_import(f"{node.module}", name=f"{name}") + return self.globals[asname] + if isinstance(node, ast.Import): + self.globals[asname], _ = optional_import(f"{name}") + return self.globals[asname] + return None + + def evaluate(self, locals: Optional[Dict] = None): + """ + Execute the current config content and return the result if it is expression, based on Python `eval()`. + For more details: https://docs.python.org/3/library/functions.html#eval. + + Args: + locals: besides ``globals``, may also have some local symbols used in the expression at runtime. + + """ + value = self.get_config() + if not ConfigExpression.is_expression(value): + return None + optional_module = self._parse_import_string(value[len(self.prefix) :]) + if optional_module is not None: + return optional_module + if not self.run_eval: + return f"{value[len(self.prefix) :]}" + return eval(value[len(self.prefix) :], self.globals, locals) + + @classmethod + def is_expression(cls, config: Union[Dict, List, str]) -> bool: + """ + Check whether the config is an executable expression string. + Currently, a string starts with ``"$"`` character is interpreted as an expression. + + Args: + config: input config content to check. + + """ + return isinstance(config, str) and config.startswith(cls.prefix) + + @classmethod + def is_import_statement(cls, config: Union[Dict, List, str]) -> bool: + """ + Check whether the config is an import statement (a special case of expression). + + Args: + config: input config content to check. + """ + if not cls.is_expression(config): + return False + if "import" not in config: + return False + return isinstance( + first(ast.iter_child_nodes(ast.parse(f"{config[len(cls.prefix) :]}"))), (ast.Import, ast.ImportFrom) + ) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py new file mode 100644 index 0000000000..800e18ade0 --- /dev/null +++ b/monai/bundle/config_parser.py @@ -0,0 +1,426 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from monai.bundle.reference_resolver import ReferenceResolver +from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY +from monai.config import PathLike +from monai.utils import ensure_tuple, look_up_option, optional_import + +yaml, _ = optional_import("yaml") + +__all__ = ["ConfigParser"] + +_default_globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"} + + +class ConfigParser: + """ + The primary configuration parser. It traverses a structured config (in the form of nested Python dict or list), + creates ``ConfigItem``, and assign unique IDs according to the structures. + + This class provides convenient access to the set of ``ConfigItem`` of the config by ID. + A typical workflow of config parsing is as follows: + + - Initialize ``ConfigParser`` with the ``config`` source. + - Call ``get_parsed_content()`` to get expected component with `id`. + + .. code-block:: python + + from monai.bundle import ConfigParser + + config = { + "my_dims": 2, + "dims_1": "$@my_dims + 1", + "my_xform": {"_target_": "LoadImage"}, + "my_net": {"_target_": "BasicUNet", "spatial_dims": "@dims_1", "in_channels": 1, "out_channels": 4}, + "trainer": {"_target_": "SupervisedTrainer", "network": "@my_net", "preprocessing": "@my_xform"} + } + # in the example $@my_dims + 1 is an expression, which adds 1 to the value of @my_dims + parser = ConfigParser(config) + + # get/set configuration content, the set method should happen before calling parse() + print(parser["my_net"]["in_channels"]) # original input channels 1 + parser["my_net"]["in_channels"] = 4 # change input channels to 4 + print(parser["my_net"]["in_channels"]) + + # instantiate the network component + parser.parse(True) + net = parser.get_parsed_content("my_net", instantiate=True) + print(net) + + # also support to get the configuration content of parsed `ConfigItem` + trainer = parser.get_parsed_content("trainer", instantiate=False) + print(trainer) + + Args: + config: input config source to parse. + excludes: when importing modules to instantiate components, + excluding components from modules specified in ``excludes``. + globals: pre-import packages as global variables to ``ConfigExpression``, + so that expressions, for example, ``"$monai.data.list_data_collate"`` can use ``monai`` modules. + The current supported globals and alias names are + ``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``. + These are MONAI's minimal dependencies. Additional packages could be included with `globals={"itk": "itk"}`. + + See also: + + - :py:class:`monai.bundle.ConfigItem` + - :py:class:`monai.bundle.scripts.run` + + """ + + suffixes = ("json", "yaml", "yml") + suffix_match = rf".*\.({'|'.join(suffixes)})" + path_match = rf"({suffix_match}$)" + # match relative id names, e.g. "@#data", "@##transform#1" + relative_id_prefix = re.compile(rf"(?:{ID_REF_KEY}|{MACRO_KEY}){ID_SEP_KEY}+") + meta_key = "_meta_" # field key to save metadata + + def __init__( + self, + config: Any = None, + excludes: Optional[Union[Sequence[str], str]] = None, + globals: Optional[Dict[str, Any]] = None, + ): + self.config = None + self.globals: Dict[str, Any] = {} + _globals = _default_globals.copy() + if isinstance(_globals, dict) and globals is not None: + _globals.update(globals) + if _globals is not None: + for k, v in _globals.items(): + self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v + + self.locator = ComponentLocator(excludes=excludes) + self.ref_resolver = ReferenceResolver() + if config is None: + config = {self.meta_key: {}} + self.set(config=config) + + def __repr__(self): + return f"{self.config}" + + def __getitem__(self, id: Union[str, int]): + """ + Get the config by id. + + Args: + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``. + + """ + if id == "": + return self.config + config = self.config + for k in str(id).split(ID_SEP_KEY): + if not isinstance(config, (dict, list)): + raise ValueError(f"config must be dict or list for key `{k}`, but got {type(config)}: {config}.") + indexing = k if isinstance(config, dict) else int(k) + config = config[indexing] + return config + + def __setitem__(self, id: Union[str, int], config: Any): + """ + Set config by ``id``. Note that this method should be used before ``parse()`` or ``get_parsed_content()`` + to ensure the updates are included in the parsed content. + + Args: + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``. + config: config to set at location ``id``. + + """ + if id == "": + self.config = config + self.ref_resolver.reset() + return + keys = str(id).split(ID_SEP_KEY) + # get the last parent level config item and replace it + last_id = ID_SEP_KEY.join(keys[:-1]) + conf_ = self[last_id] + indexing = keys[-1] if isinstance(conf_, dict) else int(keys[-1]) + conf_[indexing] = config + self.ref_resolver.reset() + return + + def get(self, id: str = "", default: Optional[Any] = None): + """ + Get the config by id. + + Args: + id: id to specify the expected position. See also :py:meth:`__getitem__`. + default: default value to return if the specified ``id`` is invalid. + + """ + try: + return self[id] + except KeyError: + return default + + def set(self, config: Any, id: str = ""): + """ + Set config by ``id``. See also :py:meth:`__setitem__`. + + """ + self[id] = config + + def parse(self, reset: bool = True): + """ + Recursively resolve `self.config` to replace the macro tokens with target content. + Then recursively parse the config source, add every item as ``ConfigItem`` to the reference resolver. + + Args: + reset: whether to reset the ``reference_resolver`` before parsing. Defaults to `True`. + + """ + if reset: + self.ref_resolver.reset() + self.resolve_macro_and_relative_ids() + self._do_parse(config=self.get()) + + def get_parsed_content(self, id: str = "", **kwargs): + """ + Get the parsed result of ``ConfigItem`` with the specified ``id``. + + - If the item is ``ConfigComponent`` and ``instantiate=True``, the result is the instance. + - If the item is ``ConfigExpression`` and ``eval_expr=True``, the result is the evaluated output. + - Else, the result is the configuration content of `ConfigItem`. + + Args: + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``. + kwargs: additional keyword arguments to be passed to ``_resolve_one_item``. + Currently support ``reset`` (for parse), ``instantiate`` and ``eval_expr``. All defaulting to True. + + """ + if not self.ref_resolver.is_resolved(): + # not parsed the config source yet, parse it + self.parse(kwargs.get("reset", True)) + return self.ref_resolver.get_resolved_content(id=id, **kwargs) + + def read_meta(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs): + """ + Read the metadata from specified JSON or YAML file. + The metadata as a dictionary will be stored at ``self.config["_meta_"]``. + + Args: + f: filepath of the metadata file, the content must be a dictionary, + if providing a list of files, wil merge the content of them. + if providing a dictionary directly, use it as metadata. + kwargs: other arguments for ``json.load`` or ``yaml.safe_load``, depends on the file format. + + """ + self.set(self.load_config_files(f, **kwargs), self.meta_key) + + def read_config(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs): + """ + Read the config from specified JSON or YAML file. + The config content in the `self.config` dictionary. + + Args: + f: filepath of the config file, the content must be a dictionary, + if providing a list of files, wil merge the content of them. + if providing a dictionary directly, use it as config. + kwargs: other arguments for ``json.load`` or ``yaml.safe_load``, depends on the file format. + + """ + content = {self.meta_key: self.get(self.meta_key, {})} + content.update(self.load_config_files(f, **kwargs)) + self.set(config=content) + + def _do_resolve(self, config: Any, id: str = ""): + """ + Recursively resolve `self.config` to replace the relative ids with absolute ids, for example, + `@##A` means `A` in the upper level. and replace the macro tokens with target content, + The macro tokens start with "%", can be from another structured file, like: + ``"%default_net"``, ``"%/data/config.json#net"``. + + Args: + config: input config file to resolve. + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``. + + """ + if isinstance(config, (dict, list)): + for k, v in enumerate(config) if isinstance(config, list) else config.items(): + sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k + config[k] = self._do_resolve(v, sub_id) + if isinstance(config, str): + config = self.resolve_relative_ids(id, config) + if config.startswith(MACRO_KEY): + path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :]) + parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path)) + return self._do_resolve(config=deepcopy(parser[ids])) + return config + + def resolve_macro_and_relative_ids(self): + """ + Recursively resolve `self.config` to replace the relative ids with absolute ids, for example, + `@##A` means `A` in the upper level. and replace the macro tokens with target content, + The macro tokens are marked as starting with "%", can be from another structured file, like: + ``"%default_net"``, ``"%/data/config.json#net"``. + + """ + self.set(self._do_resolve(config=deepcopy(self.get()))) + + def _do_parse(self, config, id: str = ""): + """ + Recursively parse the nested data in config source, add every item as `ConfigItem` to the resolver. + + Args: + config: config source to parse. + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``. + + """ + if isinstance(config, (dict, list)): + for k, v in enumerate(config) if isinstance(config, list) else config.items(): + sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k + self._do_parse(config=v, id=sub_id) + + # copy every config item to make them independent and add them to the resolver + item_conf = deepcopy(config) + if ConfigComponent.is_instantiable(item_conf): + self.ref_resolver.add_item(ConfigComponent(config=item_conf, id=id, locator=self.locator)) + elif ConfigExpression.is_expression(item_conf): + self.ref_resolver.add_item(ConfigExpression(config=item_conf, id=id, globals=self.globals)) + else: + self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id)) + + @classmethod + def load_config_file(cls, filepath: PathLike, **kwargs): + """ + Load config file with specified file path (currently support JSON and YAML files). + + Args: + filepath: path of target file to load, supported postfixes: `.json`, `.yml`, `.yaml`. + kwargs: other arguments for ``json.load`` or ```yaml.safe_load``, depends on the file format. + + """ + _filepath: str = str(Path(filepath)) + if not re.compile(cls.path_match, re.IGNORECASE).findall(_filepath): + raise ValueError(f'unknown file input: "{filepath}"') + with open(_filepath) as f: + if _filepath.lower().endswith(cls.suffixes[0]): + return json.load(f, **kwargs) + if _filepath.lower().endswith(cls.suffixes[1:]): + return yaml.safe_load(f, **kwargs) + raise ValueError(f"only support JSON or YAML config file so far, got name {_filepath}.") + + @classmethod + def load_config_files(cls, files: Union[PathLike, Sequence[PathLike], dict], **kwargs) -> dict: + """ + Load config files into a single config dict. + + Args: + files: path of target files to load, supported postfixes: `.json`, `.yml`, `.yaml`. + kwargs: other arguments for ``json.load`` or ```yaml.safe_load``, depends on the file format. + """ + if isinstance(files, dict): # already a config dict + return files + content = {} + for i in ensure_tuple(files): + content.update(cls.load_config_file(i, **kwargs)) + return content + + @classmethod + def export_config_file(cls, config: Dict, filepath: PathLike, fmt="json", **kwargs): + """ + Export the config content to the specified file path (currently support JSON and YAML files). + + Args: + config: source config content to export. + filepath: target file path to save. + fmt: format of config content, currently support ``"json"`` and ``"yaml"``. + kwargs: other arguments for ``json.dump`` or ``yaml.safe_dump``, depends on the file format. + + """ + _filepath: str = str(Path(filepath)) + writer = look_up_option(fmt.lower(), {"json", "yaml"}) + with open(_filepath, "w") as f: + if writer == "json": + return json.dump(config, f, **kwargs) + if writer == "yaml": + return yaml.safe_dump(config, f, **kwargs) + raise ValueError(f"only support JSON or YAML config file so far, got {writer}.") + + @classmethod + def split_path_id(cls, src: str) -> Tuple[str, str]: + """ + Split `src` string into two parts: a config file path and component id. + The file path should end with `(json|yaml|yml)`. The component id should be separated by `#` if it exists. + If no path or no id, return "". + + Args: + src: source string to split. + + """ + result = re.compile(rf"({cls.suffix_match}(?=(?:{ID_SEP_KEY}.*)|$))", re.IGNORECASE).findall(src) + if not result: + return "", src # the src is a pure id + path_name = result[0][0] # at most one path_name + _, ids = src.rsplit(path_name, 1) + return path_name, ids[len(ID_SEP_KEY) :] if ids.startswith(ID_SEP_KEY) else "" + + @classmethod + def resolve_relative_ids(cls, id: str, value: str) -> str: + """ + To simplify the reference or macro tokens ID in the nested config content, it's available to use + relative ID name which starts with the `ID_SEP_KEY`, for example, "@#A" means `A` in the same level, + `@##A` means `A` in the upper level. + It resolves the relative ids to absolute ids. For example, if the input data is: + + .. code-block:: python + + { + "A": 1, + "B": {"key": "@##A", "value1": 2, "value2": "%#value1", "value3": [3, 4, "@#1"]}, + } + + It will resolve `B` to `{"key": "@A", "value1": 2, "value2": "%B#value1", "value3": [3, 4, "@B#value3#1"]}`. + + Args: + id: id name for current config item to compute relative id. + value: input value to resolve relative ids. + + """ + # get the prefixes like: "@####", "%###", "@#" + prefixes = sorted(set().union(cls.relative_id_prefix.findall(value)), reverse=True) + current_id = id.split(ID_SEP_KEY) + + for p in prefixes: + sym = ID_REF_KEY if ID_REF_KEY in p else MACRO_KEY + length = p[len(sym) :].count(ID_SEP_KEY) + if length > len(current_id): + raise ValueError(f"the relative id in `{value}` is out of the range of config content.") + if length == len(current_id): + new = "" # root id is `""` + else: + new = ID_SEP_KEY.join(current_id[:-length]) + ID_SEP_KEY + value = value.replace(p, sym + new) + return value diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py new file mode 100644 index 0000000000..f9f73c9c71 --- /dev/null +++ b/monai/bundle/reference_resolver.py @@ -0,0 +1,276 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any, Dict, Optional, Sequence, Set + +from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem +from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY +from monai.utils import look_up_option + +__all__ = ["ReferenceResolver"] + + +class ReferenceResolver: + """ + Utility class to manage a set of ``ConfigItem`` and resolve the references between them. + + This class maintains a set of ``ConfigItem`` objects and their associated IDs. + The IDs must be unique within this set. A string in ``ConfigItem`` + starting with ``@`` will be treated as a reference to other ``ConfigItem`` objects by ID. + Since ``ConfigItem`` may have a nested dictionary or list structure, + the reference string may also contain a ``#`` character to refer to a substructure by + key indexing for a dictionary or integer indexing for a list. + + In this class, resolving references is essentially substitution of the reference strings with the + corresponding python objects. A typical workflow of resolving references is as follows: + + - Add multiple ``ConfigItem`` objects to the ``ReferenceResolver`` by ``add_item()``. + - Call ``get_resolved_content()`` to automatically resolve the references. This is done (recursively) by: + - Convert the items to objects, for those do not have references to other items. + - If it is instantiable, instantiate it and cache the class instance in ``resolved_content``. + - If it is an expression, evaluate it and save the value in ``resolved_content``. + - Substitute the reference strings with the corresponding objects. + + Args: + items: ``ConfigItem``s to resolve, this could be added later with ``add_item()``. + + """ + + _vars = "__local_refs" + sep = ID_SEP_KEY # separator for key indexing + ref = ID_REF_KEY # reference prefix + # match a reference string, e.g. "@id#key", "@id#key#0", "@_target_#key" + id_matcher = re.compile(rf"{ref}(?:\w*)(?:{sep}\w*)*") + + def __init__(self, items: Optional[Sequence[ConfigItem]] = None): + # save the items in a dictionary with the `ConfigItem.id` as key + self.items: Dict[str, Any] = {} if items is None else {i.get_id(): i for i in items} + self.resolved_content: Dict[str, Any] = {} + + def reset(self): + """ + Clear all the added `ConfigItem` and all the resolved content. + + """ + self.items = {} + self.resolved_content = {} + + def is_resolved(self) -> bool: + return bool(self.resolved_content) + + def add_item(self, item: ConfigItem): + """ + Add a ``ConfigItem`` to the resolver. + + Args: + item: a ``ConfigItem``. + + """ + id = item.get_id() + if id in self.items: + return + self.items[id] = item + + def get_item(self, id: str, resolve: bool = False, **kwargs): + """ + Get the ``ConfigItem`` by id. + + If ``resolve=True``, the returned item will be resolved, that is, + all the reference strings are substituted by the corresponding ``ConfigItem`` objects. + + Args: + id: id of the expected config item. + resolve: whether to resolve the item if it is not resolved, default to False. + kwargs: keyword arguments to pass to ``_resolve_one_item()``. + Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True. + """ + if resolve and id not in self.resolved_content: + self._resolve_one_item(id=id, **kwargs) + return self.items.get(id) + + def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **kwargs): + """ + Resolve and return one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``. + If it has unresolved references, recursively resolve the referring items first. + + Args: + id: id name of ``ConfigItem`` to be resolved. + waiting_list: set of ids pending to be resolved. + It's used to detect circular references such as: + `{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`. + kwargs: keyword arguments to pass to ``_resolve_one_item()``. + Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True. + + """ + if id in self.resolved_content: + return self.resolved_content[id] + try: + item = look_up_option(id, self.items, print_all_options=False) + except ValueError as err: + raise KeyError(f"id='{id}' is not found in the config resolver.") from err + item_config = item.get_config() + + if waiting_list is None: + waiting_list = set() + waiting_list.add(id) + + for t, v in self.items.items(): + if ( + t not in self.resolved_content + and isinstance(v, ConfigExpression) + and v.is_import_statement(v.get_config()) + ): + self.resolved_content[t] = v.evaluate() if kwargs.get("eval_expr", True) else v + for d in self.find_refs_in_config(config=item_config, id=id): + # if current item has reference already in the waiting list, that's circular references + if d in waiting_list: + raise ValueError(f"detected circular references '{d}' for id='{id}' in the config content.") + # check whether the component has any unresolved references + if d not in self.resolved_content: + # this referring item is not resolved + try: + look_up_option(d, self.items, print_all_options=False) + except ValueError as err: + raise ValueError(f"the referring item `@{d}` is not defined in the config content.") from err + # recursively resolve the reference first + self._resolve_one_item(id=d, waiting_list=waiting_list, **kwargs) + waiting_list.discard(d) + + # all references are resolved, then try to resolve current config item + new_config = self.update_config_with_refs(config=item_config, id=id, refs=self.resolved_content) + item.update_config(config=new_config) + # save the resolved result into `resolved_content` to recursively resolve others + if isinstance(item, ConfigComponent): + self.resolved_content[id] = item.instantiate() if kwargs.get("instantiate", True) else item + elif isinstance(item, ConfigExpression): + run_eval = kwargs.get("eval_expr", True) + self.resolved_content[id] = ( + item.evaluate(locals={f"{self._vars}": self.resolved_content}) if run_eval else item + ) + else: + self.resolved_content[id] = new_config + return self.resolved_content[id] + + def get_resolved_content(self, id: str, **kwargs): + """ + Get the resolved ``ConfigItem`` by id. + + Args: + id: id name of the expected item. + kwargs: additional keyword arguments to be passed to ``_resolve_one_item``. + Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True. + + """ + return self._resolve_one_item(id=id, **kwargs) + + @classmethod + def match_refs_pattern(cls, value: str) -> Set[str]: + """ + Match regular expression for the input string to find the references. + The reference string starts with ``"@"``, like: ``"@XXX#YYY#ZZZ"``. + + Args: + value: input value to match regular expression. + + """ + refs: Set[str] = set() + # regular expression pattern to match "@XXX" or "@XXX#YYY" + result = cls.id_matcher.findall(value) + value_is_expr = ConfigExpression.is_expression(value) + for item in result: + if value_is_expr or value == item: + # only check when string starts with "$" or the whole content is "@XXX" + refs.add(item[len(cls.ref) :]) + return refs + + @classmethod + def update_refs_pattern(cls, value: str, refs: Dict) -> str: + """ + Match regular expression for the input string to update content with the references. + The reference part starts with ``"@"``, like: ``"@XXX#YYY#ZZZ"``. + References dictionary must contain the referring IDs as keys. + + Args: + value: input value to match regular expression. + refs: all the referring components with ids as keys, default to `None`. + + """ + # regular expression pattern to match "@XXX" or "@XXX#YYY" + result = cls.id_matcher.findall(value) + value_is_expr = ConfigExpression.is_expression(value) + for item in result: + ref_id = item[len(cls.ref) :] # remove the ref prefix "@" + if ref_id not in refs: + raise KeyError(f"can not find expected ID '{ref_id}' in the references.") + if value_is_expr: + # replace with local code, will be used in the `evaluate` logic with `locals={"refs": ...}` + value = value.replace(item, f"{cls._vars}['{ref_id}']") + elif value == item: + # the whole content is "@XXX", it will avoid the case that regular string contains "@" + value = refs[ref_id] + return value + + @classmethod + def find_refs_in_config(cls, config, id: str, refs: Optional[Set[str]] = None) -> Set[str]: + """ + Recursively search all the content of input config item to get the ids of references. + References mean: the IDs of other config items (``"@XXX"`` in this config item), or the + sub-item in the config is `instantiable`, or the sub-item in the config is `expression`. + For `dict` and `list`, recursively check the sub-items. + + Args: + config: input config content to search. + id: ID name for the input config item. + refs: list of the ID name of found references, default to `None`. + + """ + refs_: Set[str] = refs or set() + if isinstance(config, str): + return refs_.union(cls.match_refs_pattern(value=config)) + if not isinstance(config, (list, dict)): + return refs_ + for k, v in config.items() if isinstance(config, dict) else enumerate(config): + sub_id = f"{id}{cls.sep}{k}" if id != "" else f"{k}" + if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): + refs_.add(sub_id) + refs_ = cls.find_refs_in_config(v, sub_id, refs_) + return refs_ + + @classmethod + def update_config_with_refs(cls, config, id: str, refs: Optional[Dict] = None): + """ + With all the references in ``refs``, update the input config content with references + and return the new config. + + Args: + config: input config content to update. + id: ID name for the input config. + refs: all the referring content with ids, default to `None`. + + """ + refs_: Dict = refs or {} + if isinstance(config, str): + return cls.update_refs_pattern(config, refs_) + if not isinstance(config, (list, dict)): + return config + ret = type(config)() + for idx, v in config.items() if isinstance(config, dict) else enumerate(config): + sub_id = f"{id}{cls.sep}{idx}" if id != "" else f"{idx}" + if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): + updated = refs_[sub_id] + if ConfigComponent.is_instantiable(v) and updated is None: + # the component is disabled + continue + else: + updated = cls.update_config_with_refs(v, sub_id, refs_) + ret.update({idx: updated}) if isinstance(ret, dict) else ret.append(updated) + return ret diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py new file mode 100644 index 0000000000..5bbde5fd62 --- /dev/null +++ b/monai/bundle/scripts.py @@ -0,0 +1,323 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import pprint +import re +from typing import Dict, Optional, Sequence, Tuple, Union + +import torch + +from monai.apps.utils import download_url, get_logger +from monai.bundle.config_parser import ConfigParser +from monai.config import PathLike +from monai.utils import check_parent_dir, get_equivalent_dtype, optional_import + +validate, _ = optional_import("jsonschema", name="validate") +ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") + +logger = get_logger(module_name=__name__) + + +def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = True, **kwargs) -> Dict: + """ + Update the `args` with the input `kwargs`. + For dict data, recursively update the content based on the keys. + + Args: + args: source args to update. + ignore_none: whether to ignore input args with None value, default to `True`. + kwargs: destination args to update. + + """ + args_: Dict = args if isinstance(args, dict) else {} # type: ignore + if isinstance(args, str): + # args are defined in a structured file + args_ = ConfigParser.load_config_file(args) + + # recursively update the default args with new args + for k, v in kwargs.items(): + if ignore_none and v is None: + continue + if isinstance(v, dict) and isinstance(args_.get(k), dict): + args_[k] = _update_args(args_[k], ignore_none, **v) + else: + args_[k] = v + return args_ + + +def _log_input_summary(tag, args: Dict): + logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---") + for name, val in args.items(): + logger.info(f"> {name}: {pprint.pformat(val)}") + logger.info("---\n\n") + + +def _get_var_names(expr: str): + """ + Parse the expression and discover what variables are present in it based on ast module. + + Args: + expr: source expression to parse. + + """ + tree = ast.parse(expr) + return [m.id for m in ast.walk(tree) if isinstance(m, ast.Name)] + + +def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1) -> Tuple: + """ + Get spatial shape for fake data according to the specified shape pattern. + It supports `int` number and `string` with formats like: "32", "32 * n", "32 ** p", "32 ** p *n". + + Args: + shape: specified pattern for the spatial shape. + p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1. + p: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1. + any: specified size to generate fake data shape if dim of expected shape is "*", default to 1. + + """ + ret = [] + for i in shape: + if isinstance(i, int): + ret.append(i) + elif isinstance(i, str): + if i == "*": + ret.append(any) + else: + for c in _get_var_names(i): + if c not in ["p", "n"]: + raise ValueError(f"only support variables 'p' and 'n' so far, but got: {c}.") + ret.append(eval(i, {"p": p, "n": n})) + else: + raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.") + return tuple(ret) + + +def run( + runner_id: Optional[str] = None, + meta_file: Optional[Union[str, Sequence[str]]] = None, + config_file: Optional[Union[str, Sequence[str]]] = None, + args_file: Optional[str] = None, + **override, +): + """ + Specify `meta_file` and `config_file` to run monai bundle components and workflows. + + Typical usage examples: + + .. code-block:: bash + + # Execute this module as a CLI entry: + python -m monai.bundle run trainer --meta_file --config_file + + # Override config values at runtime by specifying the component id and its new value: + python -m monai.bundle run trainer --net#input_chns 1 ... + + # Override config values with another config file `/path/to/another.json`: + python -m monai.bundle run evaluator --net %/path/to/another.json ... + + # Override config values with part content of another config file: + python -m monai.bundle run trainer --net %/data/other.json#net_arg ... + + # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. + # Other args still can override the default args at runtime: + python -m monai.bundle run --args_file "/workspace/data/args.json" --config_file + + Args: + runner_id: ID name of the runner component or workflow, it must have a `run` method. Defaults to ``""``. + meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. + config_file: filepath of the config file, if `None`, must be provided in `args_file`. + if it is a list of file paths, the content of them will be merged. + args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`, + `config_file`, and override pairs. so that the command line inputs can be simplified. + override: id-value pairs to override or add the corresponding config content. + e.g. ``--net#input_chns 42``. + + """ + + _args = _update_args(args=args_file, runner_id=runner_id, meta_file=meta_file, config_file=config_file, **override) + if "config_file" not in _args: + raise ValueError(f"`config_file` is required for 'monai.bundle run'.\n{run.__doc__}") + _log_input_summary(tag="run", args=_args) + + parser = ConfigParser() + parser.read_config(f=_args.pop("config_file")) + if "meta_file" in _args: + parser.read_meta(f=_args.pop("meta_file")) + id = _args.pop("runner_id", "") + + # the rest key-values in the _args are to override config content + for k, v in _args.items(): + parser[k] = v + + workflow = parser.get_parsed_content(id=id) + if not hasattr(workflow, "run"): + raise ValueError(f"The parsed workflow {type(workflow)} (id={id}) does not have a `run` method.\n{run.__doc__}") + return workflow.run() + + +def verify_metadata( + meta_file: Optional[Union[str, Sequence[str]]] = None, + filepath: Optional[PathLike] = None, + create_dir: Optional[bool] = None, + hash_val: Optional[str] = None, + hash_type: Optional[str] = None, + args_file: Optional[str] = None, + **kwargs, +): + """ + Verify the provided `metadata` file based on the predefined `schema`. + `metadata` content must contain the `schema` field for the URL of shcema file to download. + The schema standard follows: http://json-schema.org/. + + Args: + meta_file: filepath of the metadata file to verify, if `None`, must be provided in `args_file`. + if it is a list of file paths, the content of them will be merged. + filepath: file path to store the downloaded schema. + create_dir: whether to create directories if not existing, default to `True`. + hash_val: if not None, define the hash value to verify the downloaded schema file. + hash_type: if not None, define the hash type to verify the downloaded schema file. Defaults to "md5". + args_file: a JSON or YAML file to provide default values for all the args in this function. + so that the command line inputs can be simplified. + kwargs: other arguments for `jsonschema.validate()`. for more details: + https://python-jsonschema.readthedocs.io/en/stable/validate/#jsonschema.validate. + + """ + + _args = _update_args( + args=args_file, + meta_file=meta_file, + filepath=filepath, + create_dir=create_dir, + hash_val=hash_val, + hash_type=hash_type, + **kwargs, + ) + _log_input_summary(tag="verify_metadata", args=_args) + + filepath_ = _args.pop("filepath") + create_dir_ = _args.pop("create_dir", True) + check_parent_dir(path=filepath_, create_dir=create_dir_) + + metadata = ConfigParser.load_config_files(files=_args.pop("meta_file")) + url = metadata.get("schema") + if url is None: + raise ValueError("must provide the `schema` field in the metadata for the URL of schema file.") + download_url( + url=url, + filepath=filepath_, + hash_val=_args.pop("hash_val", None), + hash_type=_args.pop("hash_type", "md5"), + progress=True, + ) + schema = ConfigParser.load_config_file(filepath=filepath_) + + try: + # the rest key-values in the _args are for `validate` API + validate(instance=metadata, schema=schema, **_args) + except ValidationError as e: + # as the error message is very long, only extract the key information + logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.") + return + logger.info("metadata is verified with no error.") + + +def verify_net_in_out( + net_id: Optional[str] = None, + meta_file: Optional[Union[str, Sequence[str]]] = None, + config_file: Optional[Union[str, Sequence[str]]] = None, + device: Optional[str] = None, + p: Optional[int] = None, + n: Optional[int] = None, + any: Optional[int] = None, + args_file: Optional[str] = None, + **override, +): + """ + Verify the input and output data shape and data type of network defined in the metadata. + Will test with fake Tensor data according to the required data shape in `metadata`. + + Typical usage examples: + + .. code-block:: bash + + python -m monai.bundle verify_net_in_out network --meta_file --config_file + + Args: + net_id: ID name of the network component to verify, it must be `torch.nn.Module`. + meta_file: filepath of the metadata file to get network args, if `None`, must be provided in `args_file`. + if it is a list of file paths, the content of them will be merged. + config_file: filepath of the config file to get network definition, if `None`, must be provided in `args_file`. + if it is a list of file paths, the content of them will be merged. + device: target device to run the network forward computation, if None, prefer to "cuda" if existing. + p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1. + n: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1. + any: specified size to generate fake data shape if dim of expected shape is "*", default to 1. + args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, + `net_id` and override pairs. so that the command line inputs can be simplified. + override: id-value pairs to override or add the corresponding config content. + e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``. + + """ + + _args = _update_args( + args=args_file, + net_id=net_id, + meta_file=meta_file, + config_file=config_file, + device=device, + p=p, + n=n, + any=any, + **override, + ) + _log_input_summary(tag="verify_net_in_out", args=_args) + + parser = ConfigParser() + parser.read_config(f=_args.pop("config_file")) + parser.read_meta(f=_args.pop("meta_file")) + id = _args.pop("net_id", "") + device_ = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu")) + p = _args.pop("p", 1) + n = _args.pop("n", 1) + any = _args.pop("any", 1) + + # the rest key-values in the _args are to override config content + for k, v in _args.items(): + parser[k] = v + + try: + key: str = id # mark the full id when KeyError + net = parser.get_parsed_content(key).to(device_) + key = "_meta_#network_data_format#inputs#image#num_channels" + input_channels = parser[key] + key = "_meta_#network_data_format#inputs#image#spatial_shape" + input_spatial_shape = tuple(parser[key]) + key = "_meta_#network_data_format#inputs#image#dtype" + input_dtype = get_equivalent_dtype(parser[key], torch.Tensor) + key = "_meta_#network_data_format#outputs#pred#num_channels" + output_channels = parser[key] + key = "_meta_#network_data_format#outputs#pred#dtype" + output_dtype = get_equivalent_dtype(parser[key], torch.Tensor) + except KeyError as e: + raise KeyError(f"Failed to verify due to missing expected key in the config: {key}.") from e + + net.eval() + with torch.no_grad(): + spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) # type: ignore + test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_) + output = net(test_data) + if output.shape[1] != output_channels: + raise ValueError(f"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.") + if output.dtype != output_dtype: + raise ValueError(f"dtype of output data `{output.dtype}` doesn't match: {output_dtype}.") + logger.info("data shape of network is verified with no error.") diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py new file mode 100644 index 0000000000..ba5c2729e7 --- /dev/null +++ b/monai/bundle/utils.py @@ -0,0 +1,18 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY"] + + +ID_REF_KEY = "@" # start of a reference to a ConfigItem +ID_SEP_KEY = "#" # separator for the ID of a ConfigItem +EXPR_KEY = "$" # start of a ConfigExpression +MACRO_KEY = "%" # start of a macro of a config diff --git a/monai/config/__init__.py b/monai/config/__init__.py index c929cb2362..bf1b66fe92 100644 --- a/monai/config/__init__.py +++ b/monai/config/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,11 +12,21 @@ from .deviceconfig import ( USE_COMPILED, IgniteInfo, + get_config_values, get_gpu_info, + get_optional_config_values, get_system_info, print_config, print_debug_info, print_gpu_info, print_system_info, ) -from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayOrTensor, NdarrayTensor, TensorOrList +from .type_definitions import ( + DtypeLike, + IndexSelection, + KeysCollection, + NdarrayOrTensor, + NdarrayTensor, + PathLike, + TensorOrList, +) diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 273431fc72..fd7ca572e6 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -73,6 +73,8 @@ def get_optional_config_values(): output["psutil"] = psutil_version output["pandas"] = get_package_version("pandas") output["einops"] = get_package_version("einops") + output["transformers"] = get_package_version("transformers") + output["mlflow"] = get_package_version("mlflow") return output @@ -88,6 +90,7 @@ def print_config(file=sys.stdout): print(f"{k} version: {v}", file=file, flush=True) print(f"MONAI flags: HAS_EXT = {HAS_EXT}, USE_COMPILED = {USE_COMPILED}") print(f"MONAI rev id: {monai.__revision_id__}") + print(f"MONAI __file__: {monai.__file__}") print("\nOptional dependencies:", file=file, flush=True) for k, v in get_optional_config_values().items(): @@ -121,7 +124,7 @@ def get_system_info() -> OrderedDict: elif output["System"] == "Darwin": _dict_append(output, "Mac version", lambda: platform.mac_ver()[0]) else: - with open("/etc/os-release", "r") as rel_f: + with open("/etc/os-release") as rel_f: linux_ver = re.search(r'PRETTY_NAME="(.*)"', rel_f.read()) if linux_ver: _dict_append(output, "Linux version", lambda: linux_ver.group(1)) @@ -158,9 +161,9 @@ def get_system_info() -> OrderedDict: ), ) mem = psutil.virtual_memory() - _dict_append(output, "Total physical memory (GB)", lambda: round(mem.total / 1024 ** 3, 1)) - _dict_append(output, "Available memory (GB)", lambda: round(mem.available / 1024 ** 3, 1)) - _dict_append(output, "Used memory (GB)", lambda: round(mem.used / 1024 ** 3, 1)) + _dict_append(output, "Total physical memory (GB)", lambda: round(mem.total / 1024**3, 1)) + _dict_append(output, "Available memory (GB)", lambda: round(mem.available / 1024**3, 1)) + _dict_append(output, "Used memory (GB)", lambda: round(mem.used / 1024**3, 1)) return output @@ -198,8 +201,7 @@ def get_gpu_info() -> OrderedDict: if num_gpus > 0: _dict_append(output, "Current device", torch.cuda.current_device) - if hasattr(torch.cuda, "get_arch_list"): # get_arch_list is new in torch 1.7.1 - _dict_append(output, "Library compiled for CUDA architectures", torch.cuda.get_arch_list) + _dict_append(output, "Library compiled for CUDA architectures", torch.cuda.get_arch_list) for gpu in range(num_gpus): gpu_info = torch.cuda.get_device_properties(gpu) @@ -207,7 +209,7 @@ def get_gpu_info() -> OrderedDict: _dict_append(output, f"GPU {gpu} Is integrated", lambda: bool(gpu_info.is_integrated)) _dict_append(output, f"GPU {gpu} Is multi GPU board", lambda: bool(gpu_info.is_multi_gpu_board)) _dict_append(output, f"GPU {gpu} Multi processor count", lambda: gpu_info.multi_processor_count) - _dict_append(output, f"GPU {gpu} Total memory (GB)", lambda: round(gpu_info.total_memory / 1024 ** 3, 1)) + _dict_append(output, f"GPU {gpu} Total memory (GB)", lambda: round(gpu_info.total_memory / 1024**3, 1)) _dict_append(output, f"GPU {gpu} CUDA capability (maj.min)", lambda: f"{gpu_info.major}.{gpu_info.minor}") return output diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py index 91ac74961b..16919c2ec4 100644 --- a/monai/config/type_definitions.py +++ b/monai/config/type_definitions.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Collection, Hashable, Iterable, Sequence, TypeVar, Union import numpy as np @@ -29,7 +30,15 @@ # may be implemented). Consistent use of the concept and recorded documentation of # the rationale and convention behind it lowers the learning curve for new # developers. For readability, short names are preferred. -__all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor", "NdarrayOrTensor", "TensorOrList"] +__all__ = [ + "KeysCollection", + "IndexSelection", + "DtypeLike", + "NdarrayTensor", + "NdarrayOrTensor", + "TensorOrList", + "PathLike", +] #: KeysCollection @@ -51,18 +60,20 @@ # container must be iterable. IndexSelection = Union[Iterable[int], int] -#: Type of datatypes: Adapted from https://github.com/numpy/numpy/blob/master/numpy/typing/_dtype_like.py -DtypeLike = Union[np.dtype, type, None] +#: Type of datatypes: Adapted from https://github.com/numpy/numpy/blob/v1.21.4/numpy/typing/_dtype_like.py#L121 +DtypeLike = Union[np.dtype, type, str, None] + +#: NdarrayOrTensor: Union of numpy.ndarray and torch.Tensor to be used for typing +NdarrayOrTensor = Union[np.ndarray, torch.Tensor] #: NdarrayTensor # # Generic type which can represent either a numpy.ndarray or a torch.Tensor # Unlike Union can create a dependence between parameter(s) / return(s) -NdarrayTensor = TypeVar("NdarrayTensor", np.ndarray, torch.Tensor) - - -#: NdarrayOrTensor: Union of numpy.ndarray and torch.Tensor to be used for typing -NdarrayOrTensor = Union[np.ndarray, torch.Tensor] +NdarrayTensor = TypeVar("NdarrayTensor", bound=NdarrayOrTensor) #: TensorOrList: The TensorOrList type is used for defining `batch-first Tensor` or `list of channel-first Tensor`. TensorOrList = Union[torch.Tensor, Sequence[torch.Tensor]] + +#: PathLike: The PathLike type is used for defining a file path. +PathLike = Union[str, os.PathLike] diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index b4bb0f2c04..ac43e6fd3e 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -31,6 +31,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::enum_(m, "BoundType") .value("replicate", monai::BoundType::Replicate, "a a a | a b c d | d d d") .value("nearest", monai::BoundType::Replicate, "a a a | a b c d | d d d") + .value("border", monai::BoundType::Replicate, "a a a | a b c d | d d d") .value("dct1", monai::BoundType::DCT1, "d c b | a b c d | c b a") .value("mirror", monai::BoundType::DCT1, "d c b | a b c d | c b a") .value("dct2", monai::BoundType::DCT2, "c b a | a b c d | d c b") @@ -43,6 +44,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("wrap", monai::BoundType::DFT, "b c d | a b c d | a b c") // .value("sliding", monai::BoundType::Sliding) .value("zero", monai::BoundType::Zero, "0 0 0 | a b c d | 0 0 0") + .value("zeros", monai::BoundType::Zero, "0 0 0 | a b c d | 0 0 0") .export_values(); // resample interpolation mode diff --git a/monai/csrc/filtering/bilateral/bilateral.cpp b/monai/csrc/filtering/bilateral/bilateral.cpp index 2720d312e2..183e13bb23 100644 --- a/monai/csrc/filtering/bilateral/bilateral.cpp +++ b/monai/csrc/filtering/bilateral/bilateral.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/bilateral/bilateral.h b/monai/csrc/filtering/bilateral/bilateral.h index c7a68d7457..288684666a 100644 --- a/monai/csrc/filtering/bilateral/bilateral.h +++ b/monai/csrc/filtering/bilateral/bilateral.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp index 2e6c7dbe20..51573ebbc0 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp index 847a452396..ec28c05520 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -57,7 +57,7 @@ void BilateralFilterPHLCpu( int coord = offsetRemainder / desc.strides[d]; offsetRemainder -= coord * desc.strides[d]; - features[i * featureChannels + desc.channelCount + d] = invSpatialSigma * coord; + features[i * featureChannels + desc.channelCount + d] = (scalar_t)invSpatialSigma * coord; } } diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu index f73ae19ac9..a24e6ed092 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu index 719d1643d3..20df9419fa 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/filtering.h b/monai/csrc/filtering/filtering.h index 3dcdfc473b..3e680010ed 100644 --- a/monai/csrc/filtering/filtering.h +++ b/monai/csrc/filtering/filtering.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/permutohedral/hash_table.cuh b/monai/csrc/filtering/permutohedral/hash_table.cuh index 1acff5f276..5507034dea 100644 --- a/monai/csrc/filtering/permutohedral/hash_table.cuh +++ b/monai/csrc/filtering/permutohedral/hash_table.cuh @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp index d8fd3eaaeb..6ffe966f2c 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/permutohedral/permutohedral.h b/monai/csrc/filtering/permutohedral/permutohedral.h index 1c9d1a031e..cdc0f693d0 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral.h +++ b/monai/csrc/filtering/permutohedral/permutohedral.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp index 8c0dc8e546..7561320d2f 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -464,14 +464,14 @@ class PermutohedralLattice { // depending where we ended up, we may have to copy data if (oldValue != hashTableBase) { memcpy(hashTableBase, oldValue, hashTable.size() * vd * sizeof(scalar_t)); - delete oldValue; + delete[] oldValue; } else { - delete newValue; + delete[] newValue; } - delete zero; - delete neighbor1; - delete neighbor2; + delete[] zero; + delete[] neighbor1; + delete[] neighbor2; } private: diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu index d1d78eb940..fa06590fa9 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu +++ b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/lltm/lltm.h b/monai/csrc/lltm/lltm.h index 33e17416f8..398ea92bb9 100644 --- a/monai/csrc/lltm/lltm.h +++ b/monai/csrc/lltm/lltm.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/lltm/lltm_cpu.cpp b/monai/csrc/lltm/lltm_cpu.cpp index 295c592d00..7cd2251f9f 100644 --- a/monai/csrc/lltm/lltm_cpu.cpp +++ b/monai/csrc/lltm/lltm_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/lltm/lltm_cuda.cu b/monai/csrc/lltm/lltm_cuda.cu index 4633348477..1293bec7e9 100644 --- a/monai/csrc/lltm/lltm_cuda.cu +++ b/monai/csrc/lltm/lltm_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/bounds_common.h b/monai/csrc/resample/bounds_common.h index 4997c7d968..2a7778b76b 100644 --- a/monai/csrc/resample/bounds_common.h +++ b/monai/csrc/resample/bounds_common.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/interpolation_common.h b/monai/csrc/resample/interpolation_common.h index 35899298bf..8e2edbc8b7 100644 --- a/monai/csrc/resample/interpolation_common.h +++ b/monai/csrc/resample/interpolation_common.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/pushpull.h b/monai/csrc/resample/pushpull.h index 1c20cc0114..b056bb77c2 100644 --- a/monai/csrc/resample/pushpull.h +++ b/monai/csrc/resample/pushpull.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/pushpull_cpu.cpp b/monai/csrc/resample/pushpull_cpu.cpp index dd10dd76ee..c638958a47 100644 --- a/monai/csrc/resample/pushpull_cpu.cpp +++ b/monai/csrc/resample/pushpull_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -1527,10 +1527,10 @@ MONAI_NAMESPACE_DEVICE { // cpu iy0 = bound::index(bound1, iy0, src_Y); iz0 = bound::index(bound2, iz0, src_Z); - // Offsets into source volume offset_t o000, o100, o010, o001, o110, o011, o101, o111; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; @@ -1539,18 +1539,12 @@ MONAI_NAMESPACE_DEVICE { // cpu o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; + } else if (!(do_push || do_count)) { + o000 = o100 = o010 = o001 = o110 = o011 = o101 = o111 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t gz = static_cast(0); @@ -1657,16 +1651,19 @@ MONAI_NAMESPACE_DEVICE { // cpu grad_ptr_NXYZ[grad_sC] = gy; grad_ptr_NXYZ[grad_sC * 2] = gz; } + if (do_push || do_count) { + // Offsets into 'push' volume + o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; + o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) { @@ -1678,14 +1675,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~ else if (do_sgrad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1758,16 +1747,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; - o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o000, w000, s000); bound::add(out_ptr_N, o100, w100, s100); @@ -1822,21 +1801,19 @@ MONAI_NAMESPACE_DEVICE { // cpu ix0 = bound::index(bound0, ix0, src_X); iy0 = bound::index(bound1, iy0, src_Y); - // Offsets into source volume offset_t o00, o10, o01, o11; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o00 = ix0 * src_sX + iy0 * src_sY; o10 = ix1 * src_sX + iy0 * src_sY; o01 = ix0 * src_sX + iy1 * src_sY; o11 = ix1 * src_sX + iy1 * src_sY; + } else if (!(do_push || do_count)) { + o00 = o10 = o01 = o11 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; @@ -1893,12 +1870,15 @@ MONAI_NAMESPACE_DEVICE { // cpu (*grad_ptr_NXY) = gx; grad_ptr_NXY[grad_sC] = gy; } + if (do_push || do_count) { + // Offsets into 'push' volume + o00 = ix0 * out_sX + iy0 * out_sY; + o10 = ix1 * out_sX + iy0 * out_sY; + o01 = ix0 * out_sX + iy1 * out_sY; + o11 = ix1 * out_sX + iy1 * out_sY; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) { @@ -1908,10 +1888,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1926,11 +1902,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -1960,12 +1931,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o00, w00, s00); bound::add(out_ptr_N, o10, w10, s10); @@ -1996,20 +1961,19 @@ MONAI_NAMESPACE_DEVICE { // cpu ix1 = bound::index(bound0, ix0 + 1, src_X); ix0 = bound::index(bound0, ix0, src_X); - // Offsets into source volume offset_t o0, o1; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o0 = ix0 * src_sX; o1 = ix1 * src_sX; + } else if (!(do_push || do_count)) { + o0 = o1 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { if (trgt_K == 0) { // backward w.r.t. push/pull - - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t gx = static_cast(0); scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -2035,10 +1999,13 @@ MONAI_NAMESPACE_DEVICE { // cpu // -> zero (make sure this is done at initialization) } } + if (do_push || do_count) { + // Offsets into 'push' volume + o0 = ix0 * out_sX; + o1 = ix1 * out_sX; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { @@ -2047,8 +2014,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -2058,9 +2023,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -2081,10 +2043,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o0, w0, s0); bound::add(out_ptr_N, o1, w1, s1); diff --git a/monai/csrc/resample/pushpull_cuda.cu b/monai/csrc/resample/pushpull_cuda.cu index 38d34ffe98..461962cb80 100644 --- a/monai/csrc/resample/pushpull_cuda.cu +++ b/monai/csrc/resample/pushpull_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -1491,10 +1491,10 @@ MONAI_NAMESPACE_DEVICE { // cuda iy0 = bound::index(bound1, iy0, src_Y); iz0 = bound::index(bound2, iz0, src_Z); - // Offsets into source volume offset_t o000, o100, o010, o001, o110, o011, o101, o111; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; @@ -1503,18 +1503,12 @@ MONAI_NAMESPACE_DEVICE { // cuda o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; + } else if (!(do_push || do_count)) { + o000 = o100 = o010 = o001 = o110 = o011 = o101 = o111 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t gz = static_cast(0); @@ -1621,16 +1615,19 @@ MONAI_NAMESPACE_DEVICE { // cuda grad_ptr_NXYZ[grad_sC] = gy; grad_ptr_NXYZ[grad_sC * 2] = gz; } + if (do_push || do_count) { + // Offsets into 'push' volume + o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; + o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) { @@ -1642,14 +1639,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~ else if (do_sgrad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1672,15 +1661,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; - o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -1722,16 +1702,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; - o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o000, w000, s000); bound::add(out_ptr_N, o100, w100, s100); @@ -1786,21 +1756,19 @@ MONAI_NAMESPACE_DEVICE { // cuda ix0 = bound::index(bound0, ix0, src_X); iy0 = bound::index(bound1, iy0, src_Y); - // Offsets into source volume offset_t o00, o10, o01, o11; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o00 = ix0 * src_sX + iy0 * src_sY; o10 = ix1 * src_sX + iy0 * src_sY; o01 = ix0 * src_sX + iy1 * src_sY; o11 = ix1 * src_sX + iy1 * src_sY; + } else if (!(do_push || do_count)) { + o00 = o10 = o01 = o11 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; @@ -1857,12 +1825,15 @@ MONAI_NAMESPACE_DEVICE { // cuda (*grad_ptr_NXY) = gx; grad_ptr_NXY[grad_sC] = gy; } + if (do_push || do_count) { + // Offsets into 'push' volume + o00 = ix0 * out_sX + iy0 * out_sY; + o10 = ix1 * out_sX + iy0 * out_sY; + o01 = ix0 * out_sX + iy1 * out_sY; + o11 = ix1 * out_sX + iy1 * out_sY; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) { @@ -1872,10 +1843,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1890,11 +1857,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -1924,12 +1886,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o00, w00, s00); bound::add(out_ptr_N, o10, w10, s10); @@ -1965,15 +1921,14 @@ MONAI_NAMESPACE_DEVICE { // cuda if (do_pull || do_grad || do_sgrad) { o0 = ix0 * src_sX; o1 = ix1 * src_sX; + } else if (!(do_push || do_count)) { + o0 = o1 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { if (trgt_K == 0) { // backward w.r.t. push/pull - - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t gx = static_cast(0); scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1999,10 +1954,13 @@ MONAI_NAMESPACE_DEVICE { // cuda // -> zero (make sure this is done at initialization) } } + if (do_push || do_count) { + // Offsets into 'push' volume + o0 = ix0 * out_sX; + o1 = ix1 * out_sX; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { @@ -2011,8 +1969,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -2022,9 +1978,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -2045,10 +1998,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o0, w0, s0); bound::add(out_ptr_N, o1, w1, s1); diff --git a/monai/csrc/utils/common_utils.h b/monai/csrc/utils/common_utils.h index 4d09377e65..aabea3c99f 100644 --- a/monai/csrc/utils/common_utils.h +++ b/monai/csrc/utils/common_utils.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/utils/meta_macros.h b/monai/csrc/utils/meta_macros.h index 980b253bbe..0f15b623e3 100644 --- a/monai/csrc/utils/meta_macros.h +++ b/monai/csrc/utils/meta_macros.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/utils/resample_utils.h b/monai/csrc/utils/resample_utils.h index bbdf258b4c..77df65e924 100644 --- a/monai/csrc/utils/resample_utils.h +++ b/monai/csrc/utils/resample_utils.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/utils/tensor_description.h b/monai/csrc/utils/tensor_description.h index dadd26c5f5..c604003b9d 100644 --- a/monai/csrc/utils/tensor_description.h +++ b/monai/csrc/utils/tensor_description.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/data/__init__.py b/monai/data/__init__.py index fca170335b..bed194d2f4 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,6 +17,7 @@ CacheNTransDataset, CSVDataset, Dataset, + DatasetFunc, LMDBDataset, NPZDictItemDataset, PersistentDataset, @@ -24,11 +25,27 @@ ZipDataset, ) from .dataset_summary import DatasetSummary -from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties +from .decathlon_datalist import ( + check_missing_files, + create_cross_validation_datalist, + load_decathlon_datalist, + load_decathlon_properties, +) +from .folder_layout import FolderLayout from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader -from .iterable_dataset import CSVIterableDataset, IterableDataset +from .image_writer import ( + SUPPORTED_WRITERS, + ImageWriter, + ITKWriter, + NibabelWriter, + PILWriter, + logger, + register_writer, + resolve_writer, +) +from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti from .png_saver import PNGSaver @@ -37,6 +54,7 @@ from .synthetic import create_test_image_2d, create_test_image_3d from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer, ThreadDataLoader +from .torchscript_utils import load_net_with_metadata, save_net_with_metadata from .utils import ( compute_importance_map, compute_shape_offset, @@ -52,12 +70,14 @@ iter_patch_slices, json_hashing, list_data_collate, + orientation_ras_lps, pad_list_data_collate, partition_dataset, partition_dataset_classes, pickle_hashing, rectify_header_sform_qform, - rep_scalar_to_batch, + reorient_spatial_axes, + resample_datalist, select_cross_validation_folds, set_rnd, sorted_dict, diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index 62f407bfd5..e938bdabf8 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,11 +12,13 @@ import os import warnings from collections import OrderedDict +from pathlib import Path from typing import Dict, Optional, Union import numpy as np import torch +from monai.config.type_definitions import PathLike from monai.utils import ImageMetaKey as Key @@ -33,10 +35,11 @@ class CSVSaver: def __init__( self, - output_dir: str = "./", + output_dir: PathLike = "./", filename: str = "predictions.csv", overwrite: bool = True, flush: bool = False, + delimiter: str = ",", ) -> None: """ Args: @@ -46,17 +49,20 @@ def __init__( otherwise, will append new content to the CSV file. flush: whether to write the cache data to CSV file immediately when `save_batch` and clear the cache. default to False. + delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`. + to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter. """ - self.output_dir = output_dir + self.output_dir = Path(output_dir) self._cache_dict: OrderedDict = OrderedDict() if not (isinstance(filename, str) and filename[-4:] == ".csv"): warnings.warn("CSV filename is not a string ends with '.csv'.") - self._filepath = os.path.join(output_dir, filename) - if os.path.exists(self._filepath) and overwrite: + self._filepath = self.output_dir / filename + if self._filepath.exists() and overwrite: os.remove(self._filepath) self.flush = flush + self.delimiter = delimiter self._data_index = 0 def finalize(self) -> None: @@ -64,13 +70,13 @@ def finalize(self) -> None: Writes the cached dict to a csv """ - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) + if not self.output_dir.exists(): + self.output_dir.mkdir(parents=True, exist_ok=True) with open(self._filepath, "a") as f: for k, v in self._cache_dict.items(): f.write(k) for result in v.flatten(): - f.write("," + str(result)) + f.write(self.delimiter + str(result)) f.write("\n") # clear cache content after writing self.reset_cache() diff --git a/monai/data/dataloader.py b/monai/data/dataloader.py index 2c9174e9f4..d1f5bd4fe1 100644 --- a/monai/data/dataloader.py +++ b/monai/data/dataloader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -81,8 +81,4 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None: if "worker_init_fn" not in kwargs: kwargs.update({"worker_init_fn": worker_init_fn}) - super().__init__( # type: ignore[call-overload] - dataset=dataset, - num_workers=num_workers, - **kwargs, - ) + super().__init__(dataset=dataset, num_workers=num_workers, **kwargs) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index c970e83d0d..3c1fc0abed 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,12 +26,14 @@ import numpy as np import torch +from torch.serialization import DEFAULT_PROTOCOL from torch.utils.data import Dataset as _TorchDataset from torch.utils.data import Subset -from monai.data.utils import convert_tables_to_dicts, first, pickle_hashing -from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform -from monai.utils import MAX_SEED, ensure_tuple, get_seed, min_version, optional_import +from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing +from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform, convert_to_contiguous +from monai.utils import MAX_SEED, deprecated_arg, get_seed, look_up_option, min_version, optional_import +from monai.utils.misc import first if TYPE_CHECKING: from tqdm import tqdm @@ -95,6 +97,56 @@ def __getitem__(self, index: Union[int, slice, Sequence[int]]): return self._transform(index) +class DatasetFunc(Dataset): + """ + Execute function on the input dataset and leverage the output to act as a new Dataset. + It can be used to load / fetch the basic dataset items, like the list of `image, label` paths. + Or chain together to execute more complicated logic, like `partition_dataset`, `resample_datalist`, etc. + The `data` arg of `Dataset` will be applied to the first arg of callable `func`. + Usage example:: + + data_list = DatasetFunc( + data="path to file", + func=monai.data.load_decathlon_datalist, + data_list_key="validation", + base_dir="path to base dir", + ) + # partition dataset for every rank + data_partition = DatasetFunc( + data=data_list, + func=lambda **kwargs: monai.data.partition_dataset(**kwargs)[torch.distributed.get_rank()], + num_partitions=torch.distributed.get_world_size(), + ) + dataset = Dataset(data=data_partition, transform=transforms) + + Args: + data: input data for the func to process, will apply to `func` as the first arg. + func: callable function to generate dataset items. + kwargs: other arguments for the `func` except for the first arg. + + """ + + def __init__(self, data: Any, func: Callable, **kwargs) -> None: + super().__init__(data=None, transform=None) # type:ignore + self.src = data + self.func = func + self.kwargs = kwargs + self.reset() + + def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **kwargs): + """ + Reset the dataset items with specified `func`. + + Args: + data: if not None, execute `func` on it, default to `self.src`. + func: if not None, execute the `func` with specified `kwargs`, default to `self.func`. + kwargs: other arguments for the `func` except for the first arg. + + """ + src = self.src if data is None else data + self.data = self.func(src, **self.kwargs) if func is None else func(src, **kwargs) + + class PersistentDataset(Dataset): """ Persistent storage of pre-computed values to efficiently manage larger than memory dictionary format data, @@ -151,6 +203,8 @@ def __init__( transform: Union[Sequence[Callable], Callable], cache_dir: Optional[Union[Path, str]], hash_func: Callable[..., bytes] = pickle_hashing, + pickle_module: str = "pickle", + pickle_protocol: int = DEFAULT_PROTOCOL, ) -> None: """ Args: @@ -166,6 +220,18 @@ def __init__( If `cache_dir` is `None`, there is effectively no caching. hash_func: a callable to compute hash from data items to be cached. defaults to `monai.data.utils.pickle_hashing`. + pickle_module: string representing the module used for pickling metadata and objects, + default to `"pickle"`. due to the pickle limitation in multi-processing of Dataloader, + we can't use `pickle` as arg directly, so here we use a string name instead. + if want to use other pickle module at runtime, just register like: + >>> from monai.data import utils + >>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, + and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. + pickle_protocol: can be specified to override the default protocol, default to `2`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. """ if not isinstance(transform, Compose): @@ -173,6 +239,8 @@ def __init__( super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None self.hash_func = hash_func + self.pickle_module = pickle_module + self.pickle_protocol = pickle_protocol if self.cache_dir is not None: if not self.cache_dir.exists(): self.cache_dir.mkdir(parents=True, exist_ok=True) @@ -201,7 +269,7 @@ def _pre_transform(self, item_transformed): random transform object """ - for _transform in self.transform.transforms: # type:ignore + for _transform in self.transform.transforms: # execute all the deterministic transforms if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform): break @@ -267,13 +335,20 @@ def _cachecheck(self, item_transformed): raise e _item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed - if hashfile is not None: + if hashfile is None: + return _item_transformed + try: # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation # to make the cache more robust to manual killing of parent process # which may leave partially written cache files in an incomplete state with tempfile.TemporaryDirectory() as tmpdirname: temp_hash_file = Path(tmpdirname) / hashfile.name - torch.save(_item_transformed, temp_hash_file) + torch.save( + obj=_item_transformed, + f=temp_hash_file, + pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), + pickle_protocol=self.pickle_protocol, + ) if temp_hash_file.is_file() and not hashfile.is_file(): # On Unix, if target exists and is a file, it will be replaced silently if the user has permission. # for more details: https://docs.python.org/3/library/shutil.html#shutil.move. @@ -281,6 +356,8 @@ def _cachecheck(self, item_transformed): shutil.move(temp_hash_file, hashfile) except FileExistsError: pass + except PermissionError: # project-monai/monai issue #3613 + pass return _item_transformed def _transform(self, index: int): @@ -301,6 +378,8 @@ def __init__( cache_n_trans: int, cache_dir: Optional[Union[Path, str]], hash_func: Callable[..., bytes] = pickle_hashing, + pickle_module: str = "pickle", + pickle_protocol: int = DEFAULT_PROTOCOL, ) -> None: """ Args: @@ -317,9 +396,28 @@ def __init__( If `cache_dir` is `None`, there is effectively no caching. hash_func: a callable to compute hash from data items to be cached. defaults to `monai.data.utils.pickle_hashing`. - - """ - super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func) + pickle_module: string representing the module used for pickling metadata and objects, + default to `"pickle"`. due to the pickle limitation in multi-processing of Dataloader, + we can't use `pickle` as arg directly, so here we use a string name instead. + if want to use other pickle module at runtime, just register like: + >>> from monai.data import utils + >>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, + and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. + pickle_protocol: can be specified to override the default protocol, default to `2`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. + + """ + super().__init__( + data=data, + transform=transform, + cache_dir=cache_dir, + hash_func=hash_func, + pickle_module=pickle_module, + pickle_protocol=pickle_protocol, + ) self.cache_n_trans = cache_n_trans def _pre_transform(self, item_transformed): @@ -406,15 +504,16 @@ def __init__( lmdb_kwargs: additional keyword arguments to the lmdb environment. for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class """ - super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func) + super().__init__( + data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, pickle_protocol=pickle_protocol + ) self.progress = progress if not self.cache_dir: raise ValueError("cache_dir must be specified.") self.db_file = self.cache_dir / f"{db_name}.lmdb" - self.pickle_protocol = pickle_protocol self.lmdb_kwargs = lmdb_kwargs or {} if not self.lmdb_kwargs.get("map_size", 0): - self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size + self.lmdb_kwargs["map_size"] = 1024**4 # default map_size # lmdb is single-writer multi-reader by default # the cache is created without multi-threading self._read_env = None @@ -573,8 +672,12 @@ def __init__( transform: Union[Sequence[Callable], Callable], cache_num: int = sys.maxsize, cache_rate: float = 1.0, - num_workers: Optional[int] = None, + num_workers: Optional[int] = 1, progress: bool = True, + copy_cache: bool = True, + as_contiguous: bool = True, + hash_as_key: bool = False, + hash_func: Callable[..., bytes] = pickle_hashing, ) -> None: """ Args: @@ -584,19 +687,40 @@ def __init__( will take the minimum of (cache_num, data_length x cache_rate, data_length). cache_rate: percentage of cached data in total, default is 1.0 (cache all). will take the minimum of (cache_num, data_length x cache_rate, data_length). - num_workers: the number of worker processes to use. + num_workers: the number of worker threads to use. If num_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is speficied, 1 will be used instead. progress: whether to display a progress bar. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. + as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. + it may help improve the performance of following logic. + hash_as_key: whether to compute hash value of input data as the key to save cache, + if key exists, avoid saving duplicated content. it can help save memory when + the dataset has duplicated items or augmented dataset. + hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached. + defaults to `monai.data.utils.pickle_hashing`. + """ if not isinstance(transform, Compose): transform = Compose(transform) super().__init__(data=data, transform=transform) + self.set_num = cache_num # tracking the user-provided `cache_num` option + self.set_rate = cache_rate # tracking the user-provided `cache_rate` option self.progress = progress - self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data)) + self.copy_cache = copy_cache + self.as_contiguous = as_contiguous + self.hash_as_key = hash_as_key + self.hash_func = hash_func self.num_workers = num_workers if self.num_workers is not None: self.num_workers = max(int(self.num_workers), 1) - self._cache: List = self._fill_cache() + self.cache_num = 0 + self._cache: Union[List, Dict] = [] + self.set_data(data) def set_data(self, data: Sequence): """ @@ -607,8 +731,21 @@ def set_data(self, data: Sequence): generated cache content. """ - self.data = data - self._cache = self._fill_cache() + + def _compute_cache(): + self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data)) + return self._fill_cache() + + if self.hash_as_key: + # only compute cache for the unique items of dataset + mapping = {self.hash_func(v): v for v in data} + self.data = list(mapping.values()) + cache_ = _compute_cache() + self._cache = dict(zip(list(mapping)[: self.cache_num], cache_)) + self.data = data + else: + self.data = data + self._cache = _compute_cache() def _fill_cache(self) -> List: if self.cache_num <= 0: @@ -638,17 +775,26 @@ def _load_cache_item(self, idx: int): break _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform item = apply_transform(_xform, item) + if self.as_contiguous: + item = convert_to_contiguous(item, memory_format=torch.contiguous_format) return item def _transform(self, index: int): - if index % len(self) >= self.cache_num: # support negative index + index_: Any = index + if self.hash_as_key: + key = self.hash_func(self.data[index]) + if key in self._cache: + # if existing in cache, get the index + index_ = key # if using hash as cache keys, set the key + + if isinstance(index_, int) and index_ % len(self) >= self.cache_num: # support negative index # no cache for this index, execute all the transforms directly - return super()._transform(index) + return super()._transform(index_) # load data from cache and execute from the first random transform start_run = False if self._cache is None: self._cache = self._fill_cache() - data = self._cache[index] + data = self._cache[index_] if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms: @@ -656,7 +802,8 @@ def _transform(self, index: int): # only need to deep copy data on first non-deterministic transform if not start_run: start_run = True - data = deepcopy(data) + if self.copy_cache: + data = deepcopy(data) data = apply_transform(_transform, data) return data @@ -716,12 +863,21 @@ class SmartCacheDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_init_workers: the number of worker threads to initialize the cache for first epoch. If num_init_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is speficied, 1 will be used instead. num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. If num_replace_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is speficied, 1 will be used instead. progress: whether to display a progress bar when caching for the first epoch. shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch. it will not modify the original input data sequence in-place. seed: random seed if shuffle is `True`, default to `0`. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cache content + or every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. + as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. + it may help improve the performance of following logic. + """ def __init__( @@ -731,19 +887,25 @@ def __init__( replace_rate: float, cache_num: int = sys.maxsize, cache_rate: float = 1.0, - num_init_workers: Optional[int] = None, - num_replace_workers: Optional[int] = None, + num_init_workers: Optional[int] = 1, + num_replace_workers: Optional[int] = 1, progress: bool = True, shuffle: bool = True, seed: int = 0, + copy_cache: bool = True, + as_contiguous: bool = True, ) -> None: if shuffle: self.set_random_state(seed=seed) - data = copy(data) - self.randomize(data) self.shuffle = shuffle - super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress) + self._start_pos: int = 0 + self._update_lock: threading.Lock = threading.Lock() + self._round: int = 1 + self._replace_done: bool = False + self._replace_mgr: Optional[threading.Thread] = None + + super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache, as_contiguous) if self._cache is None: self._cache = self._fill_cache() if self.cache_num >= len(data): @@ -761,13 +923,6 @@ def __init__( self._replace_num: int = min(math.ceil(self.cache_num * replace_rate), len(data) - self.cache_num) self._replacements: List[Any] = [None for _ in range(self._replace_num)] self._replace_data_idx: List[int] = list(range(self._replace_num)) - - self._start_pos: int = 0 - self._update_lock: threading.Lock = threading.Lock() - self._round: int = 1 - self._replace_done: bool = False - self._replace_mgr: Optional[threading.Thread] = None - self._compute_data_idx() def set_data(self, data: Sequence): @@ -977,7 +1132,7 @@ def __init__(self, datasets: Sequence, transform: Optional[Callable] = None) -> super().__init__(list(datasets), transform=transform) def __len__(self) -> int: - return min((len(dataset) for dataset in self.data)) + return min(len(dataset) for dataset in self.data) def _transform(self, index: int): def to_list(x): @@ -1164,8 +1319,9 @@ class CSVDataset(Dataset): ] Args: - filename: the filename of expected CSV file to load. if providing a list - of filenames, it will load all the files and join tables. + src: if provided the filename of CSV file, it can be a str, URL, path object or file-like object to load. + also support to provide pandas `DataFrame` directly, will skip loading from filename. + if provided a list of filenames or pandas `DataFrame`, it will join the tables. row_indices: indices of the expected rows to load. it should be a list, every item can be a int number or a range `[start, end)` for the indices. for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None, @@ -1189,28 +1345,40 @@ class CSVDataset(Dataset): be the new column name, the `value` is the names of columns to combine. for example: `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}` transform: transform to apply on the loaded items of a dictionary data. + kwargs_read_csv: dictionary args to pass to pandas `read_csv` function. kwargs: additional arguments for `pandas.merge()` API to join tables. + .. deprecated:: 0.8.0 + ``filename`` is deprecated, use ``src`` instead. + """ + @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.") def __init__( self, - filename: Union[str, Sequence[str]], + src: Optional[Union[str, Sequence[str]]] = None, # also can be `DataFrame` or sequense of `DataFrame` row_indices: Optional[Sequence[Union[int, str]]] = None, col_names: Optional[Sequence[str]] = None, col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None, col_groups: Optional[Dict[str, Sequence[str]]] = None, transform: Optional[Callable] = None, + kwargs_read_csv: Optional[Dict] = None, **kwargs, ): - files = ensure_tuple(filename) - dfs = [pd.read_csv(f) for f in files] + srcs = (src,) if not isinstance(src, (tuple, list)) else src + dfs: List = [] + for i in srcs: + if isinstance(i, str): + dfs.append(pd.read_csv(i, **kwargs_read_csv) if kwargs_read_csv else pd.read_csv(i)) + elif isinstance(i, pd.DataFrame): + dfs.append(i) + else: + raise ValueError("`src` must be file path or pandas `DataFrame`.") + + # in case treating deprecated arg `filename` as kwargs, remove it from `kwargs` + kwargs.pop("filename", None) + data = convert_tables_to_dicts( - dfs=dfs, - row_indices=row_indices, - col_names=col_names, - col_types=col_types, - col_groups=col_groups, - **kwargs, + dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs ) super().__init__(data=data, transform=transform) diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py index a8598eb6c8..b447585d3e 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,8 +15,14 @@ import numpy as np import torch +from monai.config import KeysCollection from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset +from monai.transforms import concatenate +from monai.utils import convert_data_type +from monai.utils.enums import PostFix + +DEFAULT_POST_FIX = PostFix.meta() class DatasetSummary: @@ -38,7 +44,8 @@ def __init__( dataset: Dataset, image_key: Optional[str] = "image", label_key: Optional[str] = "label", - meta_key_postfix: str = "meta_dict", + meta_key: Optional[KeysCollection] = None, + meta_key_postfix: str = DEFAULT_POST_FIX, num_workers: int = 0, **kwargs, ): @@ -47,11 +54,16 @@ def __init__( dataset: dataset from which to load the data. image_key: key name of images (default: ``image``). label_key: key name of labels (default: ``label``). + meta_key: explicitly indicate the key of the corresponding meta data dictionary. + for example, for data with key `image`, the metadata by default is in `image_meta_dict`. + the meta data is a dictionary object which contains: filename, affine, original_shape, etc. + if None, will try to construct meta_keys by `{image_key}_{meta_key_postfix}`. meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from dict, the meta data is a dictionary object (default: ``meta_dict``). num_workers: how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process (default: ``0``). - kwargs: other parameters (except batch_size) for DataLoader (this class forces to use ``batch_size=1``). + kwargs: other parameters (except `batch_size` and `num_workers`) for DataLoader, + this class forces to use ``batch_size=1``. """ @@ -59,18 +71,17 @@ def __init__( self.image_key = image_key self.label_key = label_key - if image_key: - self.meta_key = "{}_{}".format(image_key, meta_key_postfix) + self.meta_key = meta_key or f"{image_key}_{meta_key_postfix}" self.all_meta_data: List = [] def collect_meta_data(self): """ This function is used to collect the meta data for all images of the dataset. """ - if not self.meta_key: - raise ValueError("To collect meta data for the dataset, `meta_key` should exist.") for data in self.data_loader: + if self.meta_key not in data: + raise ValueError(f"To collect meta data for the dataset, key `{self.meta_key}` must exist in `data`.") self.all_meta_data.append(data[self.meta_key]) def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0): @@ -78,8 +89,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: Calculate the target spacing according to all spacings. If the target spacing is very anisotropic, decrease the spacing value of the maximum axis according to percentile. - So far, this function only supports NIFTI images which store spacings in headers with key "pixdim". After loading - with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`. + So far, this function only supports NIFTI images which store spacings in headers with key "pixdim". + After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`. Args: spacing_key: key of spacing in meta data (default: ``pixdim``). @@ -92,8 +103,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: self.collect_meta_data() if spacing_key not in self.all_meta_data[0]: raise ValueError("The provided spacing_key is not in self.all_meta_data.") - - all_spacings = torch.cat([data[spacing_key][:, 1:4] for data in self.all_meta_data], dim=0).numpy() + all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0) + all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True) target_spacing = np.median(all_spacings, axis=0) if max(target_spacing) / min(target_spacing) >= anisotropic_threshold: @@ -126,18 +137,20 @@ def calculate_statistics(self, foreground_threshold: int = 0): image, label = data[self.image_key], data[self.label_key] else: image, label = data - - voxel_max.append(image.max().item()) - voxel_min.append(image.min().item()) + image, *_ = convert_data_type(data=image, output_type=torch.Tensor) + label, *_ = convert_data_type(data=label, output_type=torch.Tensor) image_foreground = image[torch.where(label > foreground_threshold)] + + voxel_max.append(image_foreground.max().item()) + voxel_min.append(image_foreground.min().item()) voxel_ct += len(image_foreground) voxel_sum += image_foreground.sum() voxel_square_sum += torch.square(image_foreground).sum() self.data_max, self.data_min = max(voxel_max), min(voxel_min) self.data_mean = (voxel_sum / voxel_ct).item() - self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct - self.data_mean ** 2)).item() + self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct - self.data_mean**2)).item() def calculate_percentiles( self, @@ -169,6 +182,8 @@ def calculate_percentiles( image, label = data[self.image_key], data[self.label_key] else: image, label = data + image, *_ = convert_data_type(data=image, output_type=torch.Tensor) + label, *_ = convert_data_type(data=label, output_type=torch.Tensor) intensities = image[torch.where(label > foreground_threshold)].tolist() if sampling_flag: diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 663b68a08e..d2a9c3d220 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,18 +11,22 @@ import json import os +import warnings +from pathlib import Path from typing import Dict, List, Optional, Sequence, Union, overload +from monai.config import KeysCollection, PathLike +from monai.data.utils import partition_dataset, select_cross_validation_folds from monai.utils import ensure_tuple @overload -def _compute_path(base_dir: str, element: str, check_path: bool = False) -> str: +def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ... @overload -def _compute_path(base_dir: str, element: List[str], check_path: bool = False) -> List[str]: +def _compute_path(base_dir: PathLike, element: List[PathLike], check_path: bool = False) -> List[str]: ... @@ -39,24 +43,24 @@ def _compute_path(base_dir, element, check_path=False): """ - def _join_path(base_dir: str, item: str): + def _join_path(base_dir: PathLike, item: PathLike): result = os.path.normpath(os.path.join(base_dir, item)) if check_path and not os.path.exists(result): # if not an existing path, don't join with base dir - return item - return result + return f"{item}" + return f"{result}" - if isinstance(element, str): + if isinstance(element, (str, os.PathLike)): return _join_path(base_dir, element) if isinstance(element, list): for e in element: - if not isinstance(e, str): + if not isinstance(e, (str, os.PathLike)): return element return [_join_path(base_dir, e) for e in element] return element -def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> List[Dict]: +def _append_paths(base_dir: PathLike, is_segmentation: bool, items: List[Dict]) -> List[Dict]: """ Args: base_dir: the base directory of the dataset. @@ -80,10 +84,10 @@ def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> Li def load_decathlon_datalist( - data_list_file_path: str, + data_list_file_path: PathLike, is_segmentation: bool = True, data_list_key: str = "training", - base_dir: Optional[str] = None, + base_dir: Optional[PathLike] = None, ) -> List[Dict]: """Load image/label paths of decathlon challenge from JSON file @@ -110,7 +114,8 @@ def load_decathlon_datalist( ] """ - if not os.path.isfile(data_list_file_path): + data_list_file_path = Path(data_list_file_path) + if not data_list_file_path.is_file(): raise ValueError(f"Data list file {data_list_file_path} does not exist.") with open(data_list_file_path) as json_file: json_data = json.load(json_file) @@ -121,15 +126,12 @@ def load_decathlon_datalist( expected_data = [{"image": i} for i in expected_data] if base_dir is None: - base_dir = os.path.dirname(data_list_file_path) + base_dir = data_list_file_path.parent return _append_paths(base_dir, is_segmentation, expected_data) -def load_decathlon_properties( - data_property_file_path: str, - property_keys: Union[Sequence[str], str], -) -> Dict: +def load_decathlon_properties(data_property_file_path: PathLike, property_keys: Union[Sequence[str], str]) -> Dict: """Load the properties from the JSON file contains data property with specified `property_keys`. Args: @@ -140,7 +142,8 @@ def load_decathlon_properties( `modality`, `labels`, `numTraining`, `numTest`, etc. """ - if not os.path.isfile(data_property_file_path): + data_property_file_path = Path(data_property_file_path) + if not data_property_file_path.is_file(): raise ValueError(f"Data property file {data_property_file_path} does not exist.") with open(data_property_file_path) as json_file: json_data = json.load(json_file) @@ -151,3 +154,97 @@ def load_decathlon_properties( raise KeyError(f"key {key} is not in the data property file.") properties[key] = json_data[key] return properties + + +def check_missing_files( + datalist: List[Dict], keys: KeysCollection, root_dir: Optional[PathLike] = None, allow_missing_keys: bool = False +): + """Checks whether some files in the Decathlon datalist are missing. + It would be helpful to check missing files before a heavy training run. + + Args: + datalist: a list of data items, every item is a dictionary. + usually generated by `load_decathlon_datalist` API. + keys: expected keys to check in the datalist. + root_dir: if not None, provides the root dir for the relative file paths in `datalist`. + allow_missing_keys: whether allow missing keys in the datalist items. + if False, raise exception if missing. default to False. + + Returns: + A list of missing filenames. + + """ + missing_files = [] + for item in datalist: + for k in ensure_tuple(keys): + if k not in item: + if not allow_missing_keys: + raise ValueError(f"key `{k}` is missing in the datalist item: {item}") + continue + + for f in ensure_tuple(item[k]): + if not isinstance(f, (str, os.PathLike)): + raise ValueError(f"filepath of key `{k}` must be a string or a list of strings, but got: {f}.") + f = Path(f) + if isinstance(root_dir, (str, os.PathLike)): + f = Path(root_dir).joinpath(f) + if not f.exists(): + missing_files.append(f) + + return missing_files + + +def create_cross_validation_datalist( + datalist: List[Dict], + nfolds: int, + train_folds: Union[Sequence[int], int], + val_folds: Union[Sequence[int], int], + train_key: str = "training", + val_key: str = "validation", + filename: Optional[Union[Path, str]] = None, + shuffle: bool = True, + seed: int = 0, + check_missing: bool = False, + keys: Optional[KeysCollection] = None, + root_dir: Optional[str] = None, + allow_missing_keys: bool = False, + raise_error: bool = True, +): + """ + Utility to create new Decathlon style datalist based on cross validation partition. + + Args: + datalist: loaded list of dictionaries for all the items to partition. + nfolds: number of the kfold split. + train_folds: indices of folds for training part. + val_folds: indices of folds for validation part. + train_key: the key of train part in the new datalist, defaults to "training". + val_key: the key of validation part in the new datalist, defaults to "validation". + filename: if not None and ends with ".json", save the new datalist into JSON file. + shuffle: whether to shuffle the datalist before partition, defaults to `True`. + seed: if `shuffle` is True, set the random seed, defaults to `0`. + check_missing: whether to check all the files specified by `keys` are existing. + keys: if not None and check_missing_files is True, the expected keys to check in the datalist. + root_dir: if not None, provides the root dir for the relative file paths in `datalist`. + allow_missing_keys: if check_missing_files is `True`, whether allow missing keys in the datalist items. + if False, raise exception if missing. default to False. + raise_error: when found missing files, if `True`, raise exception and stop, if `False`, print warning. + + """ + if check_missing and keys is not None: + files = check_missing_files(datalist, keys, root_dir, allow_missing_keys) + if files: + msg = f"some files of the datalist are missing: {files}" + if raise_error: + raise ValueError(msg) + warnings.warn(msg) + + data = partition_dataset(data=datalist, num_partitions=nfolds, shuffle=shuffle, seed=seed) + train_list = select_cross_validation_folds(partitions=data, folds=train_folds) + val_list = select_cross_validation_folds(partitions=data, folds=val_folds) + ret = {train_key: train_list, val_key: val_list} + if isinstance(filename, (str, Path)): + with open(filename, "w") as f: + json.dump(ret, f, indent=4) + + return ret diff --git a/monai/data/folder_layout.py b/monai/data/folder_layout.py new file mode 100644 index 0000000000..b2f41b0651 --- /dev/null +++ b/monai/data/folder_layout.py @@ -0,0 +1,100 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from monai.config import PathLike +from monai.data.utils import create_file_basename + +__all__ = ["FolderLayout"] + + +class FolderLayout: + """ + A utility class to create organized filenames within ``output_dir``. The + ``filename`` method could be used to create a filename following the folder structure. + + Example: + + .. code-block:: python + + from monai.data import FolderLayout + + layout = FolderLayout( + output_dir="/test_run_1/", + postfix="seg", + extension="nii", + makedirs=False) + layout.filename(subject="Sub-A", idx="00", modality="T1") + # return value: "/test_run_1/Sub-A_seg_00_modality-T1.nii" + + The output filename is a string starting with a ``subject`` ID, and + includes additional information about a customized index and image + modality. This utility class doesn't alter the underlying image data, but + provides a convenient way to create filenames. + """ + + def __init__( + self, + output_dir: PathLike, + postfix: str = "", + extension: str = "", + parent: bool = False, + makedirs: bool = False, + data_root_dir: PathLike = "", + ): + """ + Args: + output_dir: output directory. + postfix: a postfix string for output file name appended to ``subject``. + extension: output file extension to be appended to the end of an output filename. + parent: whether to add a level of parent folder to contain each image to the output filename. + makedirs: whether to create the output parent directories if they do not exist. + data_root_dir: an optional `PathLike` object to preserve the folder structure of the input `subject`. + Please see :py:func:`monai.data.utils.create_file_basename` for more details. + """ + self.output_dir = output_dir + self.postfix = postfix + self.ext = extension + self.parent = parent + self.makedirs = makedirs + self.data_root_dir = data_root_dir + + def filename(self, subject: PathLike = "subject", idx=None, **kwargs): + """ + Create a filename based on the input ``subject`` and ``idx``. + + The output filename is formed as: + + ``output_dir/[subject/]subject[_postfix][_idx][_key-value][ext]`` + + Args: + subject: subject name, used as the primary id of the output filename. + When a `PathLike` object is provided, the base filename will be used as the subject name, + the extension name of `subject` will be ignored, in favor of ``extension`` + from this class's constructor. + idx: additional index name of the image. + kwargs: additional keyword arguments to be used to form the output filename. + The key-value pairs will be appended to the output filename as ``f"_{k}-{v}"``. + """ + full_name = create_file_basename( + postfix=self.postfix, + input_file_name=subject, + folder_path=self.output_dir, + data_root_dir=self.data_root_dir, + separate_folder=self.parent, + patch_index=idx, + makedirs=self.makedirs, + ) + for k, v in kwargs.items(): + full_name += f"_{k}-{v}" + if self.ext is not None: + ext = f"{self.ext}" + full_name += f".{ext}" if ext and not ext.startswith(".") else f"{ext}" + return full_name diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 5b2a4d7abd..9eb84a58c9 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,16 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional, Sequence, Union - -import numpy as np -import torch -from torch.utils.data import IterableDataset +from typing import Callable, Dict, Iterable, Optional, Sequence, Union from monai.data.dataset import Dataset +from monai.data.iterable_dataset import IterableDataset from monai.data.utils import iter_patch from monai.transforms import apply_transform -from monai.utils import NumpyPadMode, ensure_tuple, look_up_option +from monai.utils import NumpyPadMode, deprecated_arg, ensure_tuple, look_up_option __all__ = ["PatchDataset", "GridPatchDataset", "PatchIter"] @@ -96,7 +93,7 @@ class GridPatchDataset(IterableDataset): patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) # construct the dataset - ds = GridPatchDataset(dataset=images, + ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity) # use the grid patch dataset @@ -108,49 +105,34 @@ class GridPatchDataset(IterableDataset): # coordinates: tensor([[[0, 1], [0, 2], [0, 2]], # [[0, 1], [2, 4], [0, 2]]]) + Args: + data: the data source to read image data from. + patch_iter: converts an input image (item from dataset) into a iterable of image patches. + `patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates). + see also: :py:class:`monai.data.PatchIter`. + transform: a callable data transform operates on the patches. + with_coordinates: whether to yield the coordinates of each patch, default to `True`. + + .. deprecated:: 0.8.0 + ``dataset`` is deprecated, use ``data`` instead. + """ + @deprecated_arg(name="dataset", new_name="data", since="0.8", msg_suffix="please use `data` instead.") def __init__( self, - dataset: Sequence, + data: Union[Iterable, Sequence], patch_iter: Callable, transform: Optional[Callable] = None, with_coordinates: bool = True, ) -> None: - """ - Initializes this dataset in terms of the image dataset, patch generator, and an optional transform. - - Args: - dataset: the dataset to read image data from. - patch_iter: converts an input image (item from dataset) into a iterable of image patches. - `patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates). - see also: :py:class:`monai.data.PatchIter`. - transform: a callable data transform operates on the patches. - with_coordinates: whether to yield the coordinates of each patch, default to `True`. - - """ - - self.dataset = dataset + super().__init__(data=data, transform=None) self.patch_iter = patch_iter self.transform = transform self.with_coordinates = with_coordinates def __iter__(self): - worker_info = torch.utils.data.get_worker_info() - iter_start, iter_end = 0, 1 - try: - iter_end = len(self.dataset) # TODO: support iterable self.dataset - except TypeError: - raise NotImplementedError("image dataset must implement `len()`.") - - if worker_info is not None: - # split workload - per_worker = int(np.ceil((iter_end - iter_start) / float(worker_info.num_workers))) - iter_start += worker_info.id * per_worker - iter_end = min(iter_start + per_worker, iter_end) - - for index in range(iter_start, iter_end): - image = self.dataset[index] + for image in super().__iter__(): if not self.with_coordinates: for patch, *_ in self.patch_iter(image): # patch_iter to yield at least 1 item: patch out_patch = ( @@ -204,20 +186,24 @@ class PatchDataset(Dataset): >>> torch.Size([2, 1, 3, 3]) + .. deprecated:: 0.8.0 + ``dataset`` is deprecated, use ``data`` instead. + """ + @deprecated_arg(name="dataset", new_name="data", since="0.8", msg_suffix="please use `data` instead.") def __init__( - self, dataset: Sequence, patch_func: Callable, samples_per_image: int = 1, transform: Optional[Callable] = None + self, data: Sequence, patch_func: Callable, samples_per_image: int = 1, transform: Optional[Callable] = None ) -> None: """ Args: - dataset: an image dataset to extract patches from. + data: an image dataset to extract patches from. patch_func: converts an input image (item from dataset) into a sequence of image patches. patch_func(dataset[idx]) must return a sequence of patches (length `samples_per_image`). samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements. transform: transform applied to each patch. """ - super().__init__(data=dataset, transform=transform) + super().__init__(data=data, transform=transform) self.patch_func = patch_func if samples_per_image <= 0: diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index 874b9dc004..51f4e04959 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -37,6 +37,7 @@ def __init__( labels: Optional[Sequence[float]] = None, transform: Optional[Callable] = None, seg_transform: Optional[Callable] = None, + label_transform: Optional[Callable] = None, image_only: bool = True, transform_with_metadata: bool = False, dtype: DtypeLike = np.float32, @@ -49,19 +50,20 @@ def __init__( to the images and `seg_transform` to the segmentations. Args: - image_files: list of image filenames - seg_files: if in segmentation task, list of segmentation filenames - labels: if in classification task, list of classification labels - transform: transform to apply to image arrays - seg_transform: transform to apply to segmentation arrays - image_only: if True return only the image volume, otherwise, return image volume and the metadata + image_files: list of image filenames. + seg_files: if in segmentation task, list of segmentation filenames. + labels: if in classification task, list of classification labels. + transform: transform to apply to image arrays. + seg_transform: transform to apply to segmentation arrays. + label_transform: transform to apply to the label data. + image_only: if True return only the image volume, otherwise, return image volume and the metadata. transform_with_metadata: if True, the metadata will be passed to the transforms whenever possible. - dtype: if not None convert the loaded image to this data type + dtype: if not None convert the loaded image to this data type. reader: register reader to load image file and meta data, if None, will use the default readers. If a string of reader name provided, will construct a reader object with the `*args` and `**kwargs` parameters, supported reader name: "NibabelReader", "PILReader", "ITKReader", "NumpyReader" - args: additional parameters for reader if providing a reader name - kwargs: additional parameters for reader if providing a reader name + args: additional parameters for reader if providing a reader name. + kwargs: additional parameters for reader if providing a reader name. Raises: ValueError: When ``seg_files`` length differs from ``image_files`` @@ -79,6 +81,7 @@ def __init__( self.labels = labels self.transform = transform self.seg_transform = seg_transform + self.label_transform = label_transform if image_only and transform_with_metadata: raise ValueError("transform_with_metadata=True requires image_only=False.") self.image_only = image_only @@ -117,7 +120,7 @@ def __getitem__(self, index: int): else: img = apply_transform(self.transform, img, map_items=False) - if self.seg_transform is not None: + if self.seg_files is not None and self.seg_transform is not None: if isinstance(self.seg_transform, Randomizable): self.seg_transform.set_random_state(seed=self._seed) @@ -130,6 +133,8 @@ def __getitem__(self, index: int): if self.labels is not None: label = self.labels[index] + if self.label_transform is not None: + label = apply_transform(self.label_transform, label, map_items=False) # type: ignore # construct outputs data = [img] diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index cd1486d6d3..502c6fb93b 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,37 +9,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import warnings from abc import ABC, abstractmethod +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern -from monai.config import DtypeLike, KeysCollection -from monai.data.utils import correct_nifti_header_if_necessary +from monai.config import DtypeLike, KeysCollection, PathLike +from monai.data.utils import correct_nifti_header_if_necessary, is_supported_format, orientation_ras_lps from monai.transforms.utility.array import EnsureChannelFirst -from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import - -from .utils import is_supported_format +from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg if TYPE_CHECKING: - import cucim - import itk # type: ignore + import itk import nibabel as nib - import openslide from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage - has_itk = has_nib = has_pil = has_cim = has_osl = True + has_itk = has_nib = has_pil = True else: itk, has_itk = optional_import("itk", allow_namespace_pkg=True) nib, has_nib = optional_import("nibabel") Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") PILImage, has_pil = optional_import("PIL.Image") - cucim, has_cim = optional_import("cucim") - openslide, has_osl = optional_import("openslide") + +OpenSlide, _ = optional_import("openslide", name="OpenSlide") +CuImage, _ = optional_import("cucim", name="CuImage") +TiffFile, _ = optional_import("tifffile", name="TiffFile") __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"] @@ -64,7 +62,7 @@ class ImageReader(ABC): """ @abstractmethod - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified `filename` is supported by the current reader. This method should return True if the reader is able to read the format suggested by the @@ -78,7 +76,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def read(self, data: Union[Sequence[str], str], **kwargs) -> Union[Sequence[Any], Any]: + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Sequence[Any], Any]: """ Read image data from specified file or files. Note that it returns a data object or a sequence of data objects. @@ -131,11 +129,14 @@ def _stack_images(image_list: List, meta_dict: Dict): if len(image_list) <= 1: return image_list[0] if meta_dict.get("original_channel_dim", None) not in ("no_channel", None): - raise RuntimeError("can not read a list of images which already have channel dimension.") + channel_dim = int(meta_dict["original_channel_dim"]) + return np.concatenate(image_list, axis=channel_dim) + # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified meta_dict["original_channel_dim"] = 0 return np.stack(image_list, axis=0) +@require_pkg(pkg_name="itk") class ITKReader(ImageReader): """ Load medical images based on ITK library. @@ -147,25 +148,44 @@ class ITKReader(ImageReader): Args: channel_dim: the channel dimension of the input image, default is None. This is used to set original_channel_dim in the meta data, EnsureChannelFirstD reads this field. - If None, original_channel_dim will be either `no_channel` or `-1`. + If None, `original_channel_dim` will be either `no_channel` or `-1`. - Nifti file is usually "channel last", so there is no need to specify this argument. - PNG file usually has `GetNumberOfComponentsPerPixel()==3`, so there is no need to specify this argument. series_name: the name of the DICOM series if there are multiple ones. used when loading DICOM series. + reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. + If ``False``, the spatial indexing follows the numpy convention; + otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``. + This option does not affect the metadata. + series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). + This flag is checked only when loading DICOM series. Default is ``False``. + affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. + Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix remains in the ITK convention. kwargs: additional args for `itk.imread` API. more details about available args: https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py """ - def __init__(self, channel_dim: Optional[int] = None, series_name: str = "", **kwargs): + def __init__( + self, + channel_dim: Optional[int] = None, + series_name: str = "", + reverse_indexing: bool = False, + series_meta: bool = False, + affine_lps_to_ras: bool = True, + **kwargs, + ): super().__init__() self.kwargs = kwargs self.channel_dim = channel_dim self.series_name = series_name + self.reverse_indexing = reverse_indexing + self.series_meta = series_meta + self.affine_lps_to_ras = affine_lps_to_ras - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by ITK reader. @@ -176,10 +196,10 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ return has_itk - def read(self, data: Union[Sequence[str], str], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): """ - Read image data from specified file or files, it can read a list of `no-channel` images - and stack them together as multi-channels data in `get_data()`. + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. If passing directory path instead of file path, will treat it as DICOM images series and read. Note that the returned object is ITK image object or list of ITK image objects. @@ -192,11 +212,12 @@ def read(self, data: Union[Sequence[str], str], **kwargs): """ img_ = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - if os.path.isdir(name): + name = f"{name}" + if Path(name).is_dir(): # read DICOM series # https://itk.org/ITKExamples/src/IO/GDCM/ReadDICOMSeriesAndWrite3DImage names_generator = itk.GDCMSeriesFileNames.New() @@ -212,7 +233,17 @@ def read(self, data: Union[Sequence[str], str], **kwargs): series_identifier = series_uid[0] if not self.series_name else self.series_name name = names_generator.GetFileNames(series_identifier) - img_.append(itk.imread(name, **kwargs_)) + _obj = itk.imread(name, **kwargs_) + if self.series_meta: + _reader = itk.ImageSeriesReader.New(FileNames=name) + _reader.Update() + _meta = _reader.GetMetaDataDictionaryArray() + if len(_meta) > 0: + # TODO: using the first slice's meta. this could be improved to filter unnecessary tags. + _obj.SetMetaDataDictionary(_meta[0]) + img_.append(_obj) + else: + img_.append(itk.imread(name, **kwargs_)) return img_ if len(filenames) > 1 else img_[0] def get_data(self, img): @@ -234,7 +265,7 @@ def get_data(self, img): data = self._get_array_data(i) img_array.append(data) header = self._get_meta_dict(i) - header["original_affine"] = self._get_affine(i) + header["original_affine"] = self._get_affine(i, self.affine_lps_to_ras) header["affine"] = header["original_affine"].copy() header["spatial_shape"] = self._get_spatial_shape(i) if self.channel_dim is None: # default to "no_channel" or -1 @@ -259,13 +290,14 @@ def _get_meta_dict(self, img) -> Dict: meta_dict["spacing"] = np.asarray(img.GetSpacing()) return meta_dict - def _get_affine(self, img): + def _get_affine(self, img, lps_to_ras: bool = True): """ Get or construct the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. Args: img: an ITK image object loaded from an image file. + lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to True. """ direction = itk.array_from_matrix(img.GetDirection()) @@ -277,8 +309,8 @@ def _get_affine(self, img): affine: np.ndarray = np.eye(sr + 1) affine[:sr, :sr] = direction[:sr, :sr] @ np.diag(spacing[:sr]) affine[:sr, -1] = origin[:sr] - flip_diag = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]][sr - 1] # itk to nibabel affine - affine = np.diag(flip_diag) @ affine + if lps_to_ras: + affine = orientation_ras_lps(affine) return affine def _get_spatial_shape(self, img): @@ -289,15 +321,11 @@ def _get_spatial_shape(self, img): img: an ITK image object loaded from an image file. """ - # the img data should have no channel dim - sr = itk.array_from_matrix(img.GetDirection()).shape[0] sr = max(min(sr, 3), 1) _size = list(itk.size(img)) if self.channel_dim is not None: - # channel_dim is given in the numpy convention, which is different from ITK - # size is reversed - _size.pop(-self.channel_dim) + _size.pop(self.channel_dim) return np.asarray(_size[:sr]) def _get_array_data(self, img): @@ -306,39 +334,57 @@ def _get_array_data(self, img): Following PyTorch conventions, the returned array data has contiguous channels, e.g. for an RGB image, all red channel image pixels are contiguous in memory. - The first axis of the returned array is the channel axis. + The last axis of the returned array is the channel axis. + + See also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Modules/Bridge/NumPy/wrapping/PyBuffer.i.in Args: img: an ITK image object loaded from an image file. """ - channels = img.GetNumberOfComponentsPerPixel() - np_data = itk.array_view_from_image(img).T - if channels == 1: - return np_data - if channels != np_data.shape[0]: - warnings.warn("itk_img.GetNumberOfComponentsPerPixel != numpy data channels") - return np.moveaxis(np_data, 0, -1) # channel last is compatible with `write_nifti` + np_img = itk.array_view_from_image(img, keep_axes=False) + if img.GetNumberOfComponentsPerPixel() == 1: # handling spatial images + return np_img if self.reverse_indexing else np_img.T + # handling multi-channel images + return np_img if self.reverse_indexing else np.moveaxis(np_img.T, 0, -1) +@require_pkg(pkg_name="nibabel") class NibabelReader(ImageReader): """ Load NIfTI format images based on Nibabel library. Args: as_closest_canonical: if True, load the image as closest to canonical axis format. + squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) + channel_dim: the channel dimension of the input image, default is None. + this is used to set original_channel_dim in the meta data, EnsureChannelFirstD reads this field. + if None, `original_channel_dim` will be either `no_channel` or `-1`. + most Nifti files are usually "channel last", no need to specify this argument for them. + dtype: dtype of the output data array when loading with Nibabel library. kwargs: additional args for `nibabel.load` API. more details about available args: https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py """ - def __init__(self, as_closest_canonical: bool = False, dtype: DtypeLike = np.float32, **kwargs): + def __init__( + self, + channel_dim: Optional[int] = None, + as_closest_canonical: bool = False, + squeeze_non_spatial_dims: bool = False, + dtype: DtypeLike = np.float32, + **kwargs, + ): super().__init__() + self.channel_dim = channel_dim self.as_closest_canonical = as_closest_canonical + self.squeeze_non_spatial_dims = squeeze_non_spatial_dims self.dtype = dtype self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by Nibabel reader. @@ -350,10 +396,10 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: suffixes: Sequence[str] = ["nii", "nii.gz"] return has_nib and is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[str], str], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): """ - Read image data from specified file or files, it can read a list of `no-channel` images - and stack them together as multi-channels data in `get_data()`. + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. Note that the returned object is Nibabel image object or list of Nibabel image objects. Args: @@ -365,7 +411,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): """ img_: List[Nifti1Image] = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -399,8 +445,15 @@ def get_data(self, img): header["affine"] = self._get_affine(i) header["spatial_shape"] = self._get_spatial_shape(i) data = self._get_array_data(i) + if self.squeeze_non_spatial_dims: + for d in range(len(data.shape), len(header["spatial_shape"]), -1): + if data.shape[d - 1] == 1: + data = data.squeeze(axis=d - 1) img_array.append(data) - header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 + if self.channel_dim is None: # default to "no_channel" or -1 + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 + else: + header["original_channel_dim"] = self.channel_dim _copy_compatible_dict(header, compatible_meta) return _stack_images(img_array, compatible_meta), compatible_meta @@ -449,9 +502,11 @@ def _get_spatial_shape(self, img): dim = header.get("dims") # mgh format? dim = np.insert(dim, 0, 3) ndim = dim[0] - spatial_rank = min(ndim, 3) - # the img data should have no channel dim or the last dim is channel - return np.asarray(dim[1 : spatial_rank + 1]) + size = list(dim[1:]) + if self.channel_dim is not None: + size.pop(self.channel_dim) + spatial_rank = max(min(ndim, 3), 1) + return np.asarray(size[:spatial_rank]) def _get_array_data(self, img): """ @@ -475,19 +530,21 @@ class NumpyReader(ImageReader): Args: npz_keys: if loading npz file, only load the specified keys, if None, load all the items. stack the loaded items together to construct a new first dimension. + channel_dim: if not None, explicitly specify the channel dim, otherwise, treat the array as no channel. kwargs: additional args for `numpy.load` API except `allow_pickle`. more details about available args: https://numpy.org/doc/stable/reference/generated/numpy.load.html """ - def __init__(self, npz_keys: Optional[KeysCollection] = None, **kwargs): + def __init__(self, npz_keys: Optional[KeysCollection] = None, channel_dim: Optional[int] = None, **kwargs): super().__init__() if npz_keys is not None: npz_keys = ensure_tuple(npz_keys) self.npz_keys = npz_keys + self.channel_dim = channel_dim self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by Numpy reader. @@ -498,10 +555,10 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: suffixes: Sequence[str] = ["npz", "npy"] return is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[str], str], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): """ - Read image data from specified file or files, it can read a list of `no-channel` data files - and stack them together as multi-channels data in `get_data()`. + Read image data from specified file or files, it can read a list of data files + and stack them together as multi-channel data in `get_data()`. Note that the returned object is Numpy array or list of Numpy arrays. Args: @@ -513,12 +570,12 @@ def read(self, data: Union[Sequence[str], str], **kwargs): """ img_: List[Nifti1Image] = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: img = np.load(name, allow_pickle=True, **kwargs_) - if name.endswith(".npz"): + if Path(name).name.endswith(".npz"): # load expected items from NPZ file npz_keys = [f"arr_{i}" for i in range(len(img))] if self.npz_keys is None else self.npz_keys for k in npz_keys: @@ -548,14 +605,19 @@ def get_data(self, img): for i in ensure_tuple(img): header = {} if isinstance(i, np.ndarray): - # can not detect the channel dim of numpy array, use all the dims as spatial_shape - header["spatial_shape"] = i.shape + # if `channel_dim` is None, can not detect the channel dim, use all the dims as spatial_shape + spatial_shape = np.asarray(i.shape) + if isinstance(self.channel_dim, int): + spatial_shape = np.delete(spatial_shape, self.channel_dim) + header["spatial_shape"] = spatial_shape img_array.append(i) + header["original_channel_dim"] = self.channel_dim if isinstance(self.channel_dim, int) else "no_channel" _copy_compatible_dict(header, compatible_meta) return _stack_images(img_array, compatible_meta), compatible_meta +@require_pkg(pkg_name="PIL") class PILReader(ImageReader): """ Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path. @@ -572,7 +634,7 @@ def __init__(self, converter: Optional[Callable] = None, **kwargs): self.converter = converter self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by PIL reader. @@ -583,10 +645,10 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"] return has_pil and is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ - Read image data from specified file or files, it can read a list of `no-channel` images - and stack them together as multi-channels data in `get_data()`. + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. Note that the returned object is PIL image or list of PIL image. Args: @@ -598,7 +660,7 @@ def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): """ img_: List[PILImage.Image] = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -616,6 +678,8 @@ def get_data(self, img): It computes `spatial_shape` and stores it in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the meta data of the first image is used to represent the output meta data. + Note that it will swap axis 0 and 1 after loading the array because the `HW` definition in PIL + is different from other common medical packages. Args: img: a PIL Image object loaded from a file or a list of PIL Image objects. @@ -641,12 +705,7 @@ def _get_meta_dict(self, img) -> Dict: img: a PIL Image object loaded from an image file. """ - return { - "format": img.format, - "mode": img.mode, - "width": img.width, - "height": img.height, - } + return {"format": img.format, "mode": img.mode, "width": img.width, "height": img.height} def _get_spatial_shape(self, img): """ @@ -659,26 +718,44 @@ def _get_spatial_shape(self, img): class WSIReader(ImageReader): """ - Read whole slide imaging and extract patches. + Read whole slide images and extract patches. Args: - reader_lib: backend library to load the images, available options: "OpenSlide" or "cuCIM". + backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "TiffFile". + level: the whole slide image level at which the image is extracted. (default=0) + This is overridden if the level argument is provided in `get_data`. + kwargs: additional args for backend reading API in `read()`, more details in `cuCIM`, `TiffFile`, `OpenSlide`: + https://github.com/rapidsai/cucim/blob/v21.12.00/cpp/include/cucim/cuimage.h#L100. + https://github.com/cgohlke/tifffile. + https://openslide.org/api/python/#openslide.OpenSlide. + + Note: + While "cuCIM" and "OpenSlide" backends both can load patches from large whole slide images + without loading the entire image into memory, "TiffFile" backend needs to load the entire image into memory + before extracting any patch; thus, memory consideration is needed when using "TiffFile" backend for + patch extraction. """ - def __init__(self, reader_lib: str = "OpenSlide"): + def __init__(self, backend: str = "OpenSlide", level: int = 0, **kwargs): super().__init__() - self.reader_lib = reader_lib.lower() - if self.reader_lib == "openslide": - if has_osl: - self.wsi_reader = openslide.OpenSlide - elif self.reader_lib == "cucim": - if has_cim: - self.wsi_reader = cucim.CuImage - else: - raise ValueError('`reader_lib` should be either "cuCIM" or "OpenSlide"') + self.backend = backend.lower() + func = require_pkg(self.backend)(self._set_reader) + self.wsi_reader = func(self.backend) + self.level = level + self.kwargs = kwargs + + @staticmethod + def _set_reader(backend: str): + if backend == "openslide": + return OpenSlide + if backend == "cucim": + return CuImage + if backend == "tifffile": + return TiffFile + raise ValueError("`backend` should be 'cuCIM', 'OpenSlide' or 'TiffFile'.") - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -688,26 +765,30 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ return is_supported_format(filename, ["tif", "tiff"]) - def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ - Read image data from specified file or files. - Note that the returned object is CuImage or list of CuImage objects. + Read image data from given file or list of files. Args: data: file name or a list of file names to read. + kwargs: additional args for backend reading API in `read()`, will override `self.kwargs` for existing keys. + more details in `cuCIM`, `TiffFile`, `OpenSlide`: + https://github.com/rapidsai/cucim/blob/v21.12.00/cpp/include/cucim/cuimage.h#L100. + https://github.com/cgohlke/tifffile. + https://openslide.org/api/python/#openslide.OpenSlide. - """ - if (self.reader_lib == "openslide") and (not has_osl): - raise ImportError("No module named 'openslide'") - if (self.reader_lib == "cucim") and (not has_cim): - raise ImportError("No module named 'cucim'") + Returns: + image object or list of image objects + """ img_: List = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) for name in filenames: - img = self.wsi_reader(name) - if self.reader_lib == "openslide": + img = self.wsi_reader(name, **kwargs_) + if self.backend == "openslide": img.shape = (img.dimensions[1], img.dimensions[0], 3) img_.append(img) @@ -718,7 +799,7 @@ def get_data( img, location: Tuple[int, int] = (0, 0), size: Optional[Tuple[int, int]] = None, - level: int = 0, + level: Optional[int] = None, dtype: DtypeLike = np.uint8, grid_shape: Tuple[int, int] = (1, 1), patch_size: Optional[Union[int, Tuple[int, int]]] = None, @@ -737,33 +818,71 @@ def get_data( grid_shape: (row, columns) tuple define a grid to extract patches on that patch_size: (height, width) the size of extracted patches at the given level """ + # Verify inputs + if level is None: + level = self._check_level(img, level) - if self.reader_lib == "openslide" and size is None: - # the maximum size is set to WxH - size = ( - img.shape[0] // (2 ** level) - location[0], - img.shape[1] // (2 ** level) - location[1], - ) - + # Extract a region or the entire image region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype) + # Add necessary metadata metadata: Dict = {} - metadata["spatial_shape"] = size + metadata["spatial_shape"] = np.asarray(region.shape[:-1]) metadata["original_channel_dim"] = -1 + + # Make it channel first region = EnsureChannelFirst()(region, metadata) + + # Split into patches if patch_size is None: patches = region else: tuple_patch_size = ensure_tuple_rep(patch_size, 2) patches = self._extract_patches( - region, - patch_size=tuple_patch_size, # type: ignore - grid_shape=grid_shape, - dtype=dtype, + region, patch_size=tuple_patch_size, grid_shape=grid_shape, dtype=dtype # type: ignore ) return patches, metadata + def _check_level(self, img, level): + level = self.level + + level_count = 0 + if self.backend == "openslide": + level_count = img.level_count + elif self.backend == "cucim": + level_count = img.resolutions["level_count"] + elif self.backend == "tifffile": + level_count = len(img.pages) + + if level > level_count - 1: + raise ValueError(f"The maximum level of this image is {level_count - 1} while level={level} is requested)!") + + return level + + def _get_image_size(self, img, size, level, location): + """ + Calculate the maximum region size for the given level and starting location (if size is None). + Note that region size in OpenSlide and cuCIM are WxH (but the final image output would be HxW) + """ + if size is not None: + return size[::-1] + + max_size = [] + downsampling_factor = [] + if self.backend == "openslide": + downsampling_factor = img.level_downsamples[level] + max_size = img.level_dimensions[level] + elif self.backend == "cucim": + downsampling_factor = img.resolutions["level_downsamples"][level] + max_size = img.resolutions["level_dimensions"][level] + + # subtract the top left corner of the patch (at given level) from maximum size + location_at_level = (round(location[1] / downsampling_factor), round(location[0] / downsampling_factor)) + size = [max_size[i] - location_at_level[i] for i in range(len(max_size))] + + return size + def _extract_region( self, img_obj, @@ -772,35 +891,56 @@ def _extract_region( level: int = 0, dtype: DtypeLike = np.uint8, ): - # reverse the order of dimensions for size and location to be compatible with image shape - location = location[::-1] - if size is None: - region = img_obj.read_region(location=location, level=level) + if self.backend == "tifffile": + # Read the entire image + if size is not None: + raise ValueError( + f"TiffFile backend reads the entire image only, so size '{size}'' should not be provided!", + "For more flexibility or extracting regions, please use cuCIM or OpenSlide backend.", + ) + if location != (0, 0): + raise ValueError( + f"TiffFile backend reads the entire image only, so location '{location}' should not be provided!", + "For more flexibility and extracting regions, please use cuCIM or OpenSlide backend.", + ) + region = img_obj.asarray(level=level) else: - size = size[::-1] - region = img_obj.read_region(location=location, size=size, level=level) + # Get region size to be extracted + region_size = self._get_image_size(img_obj, size, level, location) + # reverse the order of location's dimensions to become WxH (for cuCIM and OpenSlide) + region_location = location[::-1] + # Extract a region (or the entire image) + region = img_obj.read_region(location=region_location, size=region_size, level=level) region = self.convert_to_rgb_array(region, dtype) return region - def convert_to_rgb_array( - self, - raw_region, - dtype: DtypeLike = np.uint8, - ): + def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8): """Convert to RGB mode and numpy array""" - if self.reader_lib == "openslide": + if self.backend == "openslide": # convert to RGB raw_region = raw_region.convert("RGB") - # convert to numpy - raw_region = np.asarray(raw_region, dtype=dtype) - else: - num_channels = len(raw_region.channel_names) - # convert to numpy - raw_region = np.asarray(raw_region, dtype=dtype) - # remove alpha channel if exist (RGBA) - if num_channels > 3: - raw_region = raw_region[:, :, :3] + + # convert to numpy (if not already in numpy) + raw_region = np.asarray(raw_region, dtype=dtype) + + # check if the image has three dimensions (2D + color) + if raw_region.ndim != 3: + raise ValueError( + f"The input image dimension should be 3 but {raw_region.ndim} is given. " + "`WSIReader` is designed to work only with 2D colored images." + ) + + # check if the color channel is 3 (RGB) or 4 (RGBA) + if raw_region.shape[-1] not in [3, 4]: + raise ValueError( + f"There should be three or four color channels but {raw_region.shape[-1]} is given. " + "`WSIReader` is designed to work only with 2D colored images." + ) + + # remove alpha channel if exist (RGBA) + if raw_region.shape[-1] > 3: + raw_region = raw_region[..., :3] return raw_region diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py new file mode 100644 index 0000000000..cf9ef90e8c --- /dev/null +++ b/monai/data/image_writer.py @@ -0,0 +1,819 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, Mapping, Optional, Sequence, Union + +import numpy as np + +from monai.apps.utils import get_logger +from monai.config import DtypeLike, NdarrayOrTensor, PathLike +from monai.data.utils import affine_to_spacing, ensure_tuple, ensure_tuple_rep, orientation_ras_lps, to_affine_nd +from monai.transforms.spatial.array import Resize, SpatialResample +from monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray, moveaxis +from monai.utils import ( + GridSampleMode, + GridSamplePadMode, + InterpolateMode, + OptionalImportError, + convert_data_type, + look_up_option, + optional_import, + require_pkg, +) + +DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s" +EXT_WILDCARD = "*" +logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT) + +if TYPE_CHECKING: + import itk + import nibabel as nib + from PIL import Image as PILImage +else: + itk, _ = optional_import("itk", allow_namespace_pkg=True) + nib, _ = optional_import("nibabel") + PILImage, _ = optional_import("PIL.Image") + + +__all__ = [ + "ImageWriter", + "ITKWriter", + "NibabelWriter", + "PILWriter", + "SUPPORTED_WRITERS", + "register_writer", + "resolve_writer", + "logger", +] + +SUPPORTED_WRITERS: Dict = {} + + +def register_writer(ext_name, *im_writers): + """ + Register ``ImageWriter``, so that writing a file with filename extension ``ext_name`` + could be resolved to a tuple of potentially appropriate ``ImageWriter``. + The customised writers could be registered by: + + .. code-block:: python + + from monai.data import register_writer + # `MyWriter` must implement `ImageWriter` interface + register_writer("nii", MyWriter) + + Args: + ext_name: the filename extension of the image. + As an indexing key, it will be converted to a lower case string. + im_writers: one or multiple ImageWriter classes with high priority ones first. + """ + fmt = f"{ext_name}".lower() + if fmt.startswith("."): + fmt = fmt[1:] + existing = look_up_option(fmt, SUPPORTED_WRITERS, default=()) + all_writers = im_writers + existing + SUPPORTED_WRITERS[fmt] = all_writers + + +def resolve_writer(ext_name, error_if_not_found=True) -> Sequence: + """ + Resolves to a tuple of available ``ImageWriter`` in ``SUPPORTED_WRITERS`` + according to the filename extension key ``ext_name``. + + Args: + ext_name: the filename extension of the image. + As an indexing key it will be converted to a lower case string. + error_if_not_found: whether to raise an error if no suitable image writer is found. + if True , raise an ``OptionalImportError``, otherwise return an empty tuple. Default is ``True``. + """ + if not SUPPORTED_WRITERS: + init() + fmt = f"{ext_name}".lower() + if fmt.startswith("."): + fmt = fmt[1:] + avail_writers = [] + default_writers = SUPPORTED_WRITERS.get(EXT_WILDCARD, ()) + for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=default_writers): + try: + _writer() # this triggers `monai.utils.module.require_pkg` to check the system availability + avail_writers.append(_writer) + except OptionalImportError: + continue + except Exception: # other writer init errors indicating it exists + avail_writers.append(_writer) + if not avail_writers and error_if_not_found: + raise OptionalImportError(f"No ImageWriter backend found for {fmt}.") + writer_tuple = ensure_tuple(avail_writers) + SUPPORTED_WRITERS[fmt] = writer_tuple + return writer_tuple + + +class ImageWriter: + """ + The class is a collection of utilities to write images to disk. + + Main aspects to be considered are: + + - dimensionality of the data array, arrangements of spatial dimensions and channel/time dimensions + - ``convert_to_channel_last()`` + - metadata of the current affine and output affine, the data array should be converted accordingly + - ``get_meta_info()`` + - ``resample_if_needed()`` + - data type handling of the output image (as part of ``resample_if_needed()``) + + Subclasses of this class should implement the backend-specific functions: + + - ``set_data_array()`` to set the data array (input must be numpy array or torch tensor) + - this method sets the backend object's data part + - ``set_metadata()`` to set the metadata and output affine + - this method sets the metadata including affine handling and image resampling + - backend-specific data object ``create_backend_obj()`` + - backend-specific writing function ``write()`` + + The primary usage of subclasses of ``ImageWriter`` is: + + .. code-block:: python + + writer = MyWriter() # subclass of ImageWriter + writer.set_data_array(data_array) + writer.set_metadata(meta_dict) + writer.write(filename) + + This creates an image writer object based on ``data_array`` and ``meta_dict`` and write to ``filename``. + + It supports up to three spatial dimensions (with the resampling step supports for both 2D and 3D). + When saving multiple time steps or multiple channels `data_array`, time + and/or modality axes should be the at the `channel_dim`. For example, + the shape of a 2D eight-class and ``channel_dim=0``, the segmentation + probabilities to be saved could be `(8, 64, 64)`; in this case + ``data_array`` will be converted to `(64, 64, 1, 8)` (the third + dimension is reserved as a spatial dimension). + + The ``metadata`` could optionally have the following keys: + + - ``'original_affine'``: for data original affine, it will be the + affine of the output object, defaulting to an identity matrix. + - ``'affine'``: it should specify the current data affine, defaulting to an identity matrix. + - ``'spatial_shape'``: for data output spatial shape. + + When ``metadata`` is specified, the saver will may resample data from the space defined by + `"affine"` to the space defined by `"original_affine"`, for more details, please refer to the + ``resample_if_needed`` method. + """ + + def __init__(self, **kwargs): + """ + The constructor supports adding new instance members. + The current member in the base class is ``self.data_obj``, the subclasses can add more members, + so that necessary meta information can be stored in the object and shared among the class methods. + """ + self.data_obj = None + for k, v in kwargs.items(): + setattr(self, k, v) + + def set_data_array(self, data_array, **kwargs): + raise NotImplementedError(f"Subclasses of {self.__class__.__name__} must implement this method.") + + def set_metadata(self, meta_dict: Optional[Mapping], **options): + raise NotImplementedError(f"Subclasses of {self.__class__.__name__} must implement this method.") + + def write(self, filename: PathLike, verbose: bool = True, **kwargs): + """subclass should implement this method to call the backend-specific writing APIs.""" + if verbose: + logger.info(f"writing: {filename}") + + @classmethod + def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray: + """ + Subclass should implement this method to return a backend-specific data representation object. + This method is used by ``cls.write`` and the input ``data_array`` is assumed 'channel-last'. + """ + return convert_data_type(data_array, np.ndarray)[0] + + @classmethod + def resample_if_needed( + cls, + data_array: NdarrayOrTensor, + affine: Optional[NdarrayOrTensor] = None, + target_affine: Optional[NdarrayOrTensor] = None, + output_spatial_shape: Union[Sequence[int], int, None] = None, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + ): + """ + Convert the ``data_array`` into the coordinate system specified by + ``target_affine``, from the current coordinate definition of ``affine``. + + If the transform between ``affine`` and ``target_affine`` could be + achieved by simply transposing and flipping ``data_array``, no resampling + will happen. Otherwise, this function resamples ``data_array`` using the + transformation computed from ``affine`` and ``target_affine``. + + This function assumes the NIfTI dimension notations. Spatially it + supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D + respectively. When saving multiple time steps or multiple channels, + time and/or modality axes should be appended after the first three + dimensions. For example, shape of 2D eight-class segmentation + probabilities to be saved could be `(64, 64, 1, 8)`. Also, data in + shape `(64, 64, 8)` or `(64, 64, 8, 1)` will be considered as a + single-channel 3D image. The ``convert_to_channel_last`` method can be + used to convert the data to the format described here. + + Note that the shape of the resampled ``data_array`` may subject to some + rounding errors. For example, resampling a 20x20 pixel image from pixel + size (1.5, 1.5)-mm to (3.0, 3.0)-mm space will return a 10x10-pixel + image. However, resampling a 20x20-pixel image from pixel size (2.0, + 2.0)-mm to (3.0, 3.0)-mm space will output a 14x14-pixel image, where + the image shape is rounded from 13.333x13.333 pixels. In this case + ``output_spatial_shape`` could be specified so that this function + writes image data to a designated shape. + + Args: + data_array: input data array to be converted. + affine: the current affine of ``data_array``. Defaults to identity + target_affine: the designated affine of ``data_array``. + The actual output affine might be different from this value due to precision changes. + output_spatial_shape: spatial shape of the output image. + This option is used when resampling is needed. + mode: available options are {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + This option is used when resampling is needed. + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: available options are {``"zeros"``, ``"border"``, ``"reflection"``}. + This option is used when resampling is needed. + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + align_corners: boolean option of ``grid_sample`` to handle the corner convention. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + dtype: data type for resampling computation. Defaults to + ``np.float64`` for best precision. If ``None``, use the data type of input data. + The output data type of this method is always ``np.float32``. + """ + resampler = SpatialResample(mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype) + output_array, target_affine = resampler( + data_array[None], src_affine=affine, dst_affine=target_affine, spatial_size=output_spatial_shape + ) + return output_array[0], target_affine + + @classmethod + def convert_to_channel_last( + cls, + data: NdarrayOrTensor, + channel_dim: Union[None, int, Sequence[int]] = 0, + squeeze_end_dims: bool = True, + spatial_ndim: Optional[int] = 3, + contiguous: bool = False, + ): + """ + Rearrange the data array axes to make the `channel_dim`-th dim the last + dimension and ensure there are ``spatial_ndim`` number of spatial + dimensions. + + When ``squeeze_end_dims`` is ``True``, a postprocessing step will be + applied to remove any trailing singleton dimensions. + + Args: + data: input data to be converted to "channel-last" format. + channel_dim: specifies the channel axes of the data array to move to the last. + ``None`` indicates no channel dimension, a new axis will be appended as the channel dimension. + a sequence of integers indicates multiple non-spatial dimensions. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed (after the channel + has been moved to the end). So if input is `(H,W,D,C)` and C==1, then it will be saved as `(H,W,D)`. + If D is also 1, it will be saved as `(H,W)`. If ``False``, image will always be saved as `(H,W,D,C)`. + spatial_ndim: modifying the spatial dims if needed, so that output to have at least + this number of spatial dims. If ``None``, the output will have the same number of + spatial dimensions as the input. + contiguous: if ``True``, the output will be contiguous. + """ + # change data to "channel last" format + if channel_dim is not None: + _chns = ensure_tuple(channel_dim) + data = moveaxis(data, _chns, tuple(range(-len(_chns), 0))) + else: # adds a channel dimension + data = data[..., None] + # To ensure at least ``spatial_ndim`` number of spatial dims + if spatial_ndim: + while len(data.shape) < spatial_ndim + 1: # assuming the data has spatial + channel dims + data = data[..., None, :] + while len(data.shape) > spatial_ndim + 1: + data = data[..., 0, :] + # if desired, remove trailing singleton dimensions + while squeeze_end_dims and data.shape[-1] == 1: + data = np.squeeze(data, -1) + if contiguous: + data = ascontiguousarray(data) + return data + + @classmethod + def get_meta_info(cls, metadata: Optional[Mapping] = None): + """ + Extracts relevant meta information from the metadata object (using ``.get``). + Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. + """ + if not metadata: + metadata = {"original_affine": None, "affine": None, "spatial_shape": None} + original_affine = metadata.get("original_affine") + affine = metadata.get("affine") + spatial_shape = metadata.get("spatial_shape") + return original_affine, affine, spatial_shape + + +@require_pkg(pkg_name="itk") +class ITKWriter(ImageWriter): + """ + Write data and metadata into files on disk using ITK-python. + + .. code-block:: python + + import numpy as np + from monai.data import ITKWriter + + np_data = np.arange(48).reshape(3, 4, 4) + + # write as 3d spatial image no channel + writer = ITKWriter(output_dtype=np.float32) + writer.set_data_array(np_data, channel_dim=None) + # optionally set metadata affine + writer.set_metadata({"affine": np.eye(4), "original_affine": -1 * np.eye(4)}) + writer.write("test1.nii.gz") + + # write as 2d image, channel-first + writer = ITKWriter(output_dtype=np.uint8) + writer.set_data_array(np_data, channel_dim=0) + writer.set_metadata({"spatial_shape": (5, 5)}) + writer.write("test1.png") + + """ + + def __init__(self, output_dtype: DtypeLike = np.float32, affine_lps_to_ras: bool = True, **kwargs): + """ + Args: + output_dtype: output data type. + affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. + Set to ``True`` to be consistent with ``NibabelWriter``, + otherwise the affine matrix is assumed already in the ITK convention. + kwargs: keyword arguments passed to ``ImageWriter``. + + The constructor will create ``self.output_dtype`` internally. + ``affine`` and ``channel_dim`` are initialized as instance members (default ``None``, ``0``): + + - user-specified ``affine`` should be set in ``set_metadata``, + - user-specified ``channel_dim`` should be set in ``set_data_array``. + """ + super().__init__( + output_dtype=output_dtype, affine_lps_to_ras=affine_lps_to_ras, affine=None, channel_dim=0, **kwargs + ) + + def set_data_array( + self, data_array: NdarrayOrTensor, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs + ): + """ + Convert ``data_array`` into 'channel-last' numpy ndarray. + + Args: + data_array: input data array with the channel dimension specified by ``channel_dim``. + channel_dim: channel dimension of the data array. Defaults to 0. + ``None`` indicates data without any channel dimension. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed. + kwargs: keyword arguments passed to ``self.convert_to_channel_last``, + currently support ``spatial_ndim`` and ``contiguous``, defauting to ``3`` and ``False`` respectively. + """ + _r = len(data_array.shape) + self.data_obj = self.convert_to_channel_last( + data=data_array, + channel_dim=channel_dim, + squeeze_end_dims=squeeze_end_dims, + spatial_ndim=kwargs.pop("spatial_ndim", 3), + contiguous=kwargs.pop("contiguous", True), + ) + self.channel_dim = channel_dim if len(self.data_obj.shape) >= _r else None # channel dim is at the end + + def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): + """ + Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. + + Args: + meta_dict: a metadata dictionary for affine, original affine and spatial shape information. + Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. + resample: if ``True``, the data will be resampled to the original affine (specified in ``meta_dict``). + options: keyword arguments passed to ``self.resample_if_needed``, + currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``, + defaulting to ``bilinear``, ``border``, ``False``, and ``np.float64`` respectively. + """ + original_affine, affine, spatial_shape = self.get_meta_info(meta_dict) + self.data_obj, self.affine = self.resample_if_needed( + data_array=self.data_obj, + affine=affine, + target_affine=original_affine if resample else None, + output_spatial_shape=spatial_shape if resample else None, + mode=options.pop("mode", GridSampleMode.BILINEAR), + padding_mode=options.pop("padding_mode", GridSamplePadMode.BORDER), + align_corners=options.pop("align_corners", False), + dtype=options.pop("dtype", np.float64), + ) + + def write(self, filename: PathLike, verbose: bool = False, **kwargs): + """ + Create an ITK object from ``self.create_backend_obj(self.obj, ...)`` and call ``itk.imwrite``. + + Args: + filename: filename or PathLike object. + verbose: if ``True``, log the progress. + kwargs: keyword arguments passed to ``itk.imwrite``, + currently support ``compression`` and ``imageio``. + + See also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809 + """ + super().write(filename, verbose=verbose) + self.data_obj = self.create_backend_obj( + self.data_obj, + channel_dim=self.channel_dim, + affine=self.affine, + dtype=self.output_dtype, # type: ignore + affine_lps_to_ras=self.affine_lps_to_ras, # type: ignore + **kwargs, + ) + itk.imwrite( + self.data_obj, filename, compression=kwargs.pop("compression", False), imageio=kwargs.pop("imageio", None) + ) + + @classmethod + def create_backend_obj( + cls, + data_array: NdarrayOrTensor, + channel_dim: Optional[int] = 0, + affine: Optional[NdarrayOrTensor] = None, + dtype: DtypeLike = np.float32, + affine_lps_to_ras: bool = True, + **kwargs, + ): + """ + Create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``. + + Args: + data_array: input data array. + channel_dim: channel dimension of the data array. This is used to create a Vector Image if it is not ``None``. + affine: affine matrix of the data array. This is used to compute `spacing`, `direction` and `origin`. + dtype: output data type. + affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. + Set to ``True`` to be consistent with ``NibabelWriter``, + otherwise the affine matrix is assumed already in the ITK convention. + kwargs: keyword arguments. Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary. + + see also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L389 + """ + data_array = super().create_backend_obj(data_array) + _is_vec = channel_dim is not None + if _is_vec: + data_array = np.moveaxis(data_array, -1, 0) # from channel last to channel first + data_array = data_array.T.astype(dtype, copy=True, order="C") + itk_obj = itk.GetImageFromArray(data_array, is_vector=_is_vec, ttype=kwargs.pop("ttype", None)) + + d = len(itk.size(itk_obj)) + if affine is None: + affine = np.eye(d + 1, dtype=np.float64) + _affine = convert_data_type(affine, np.ndarray)[0] + if affine_lps_to_ras: + _affine = orientation_ras_lps(to_affine_nd(d, _affine)) + spacing = affine_to_spacing(_affine, r=d) + _direction: np.ndarray = np.diag(1 / spacing) + _direction = _affine[:d, :d] @ _direction + itk_obj.SetSpacing(spacing.tolist()) + itk_obj.SetOrigin(_affine[:d, -1].tolist()) + itk_obj.SetDirection(itk.GetMatrixFromArray(_direction)) + return itk_obj + + +@require_pkg(pkg_name="nibabel") +class NibabelWriter(ImageWriter): + """ + Write data and metadata into files on disk using Nibabel. + + .. code-block:: python + + import numpy as np + from monai.data import NibabelWriter + + np_data = np.arange(48).reshape(3, 4, 4) + writer = NibabelWriter() + writer.set_data_array(np_data, channel_dim=None) + writer.set_metadata({"affine": np.eye(4), "original_affine": np.eye(4)}) + writer.write("test1.nii.gz", verbose=True) + + """ + + def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs): + """ + Args: + output_dtype: output data type. + kwargs: keyword arguments passed to ``ImageWriter``. + + The constructor will create ``self.output_dtype`` internally. + ``affine`` is initialized as instance members (default ``None``), + user-specified ``affine`` should be set in ``set_metadata``. + """ + super().__init__(output_dtype=output_dtype, affine=None, **kwargs) + + def set_data_array( + self, data_array: NdarrayOrTensor, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs + ): + """ + Convert ``data_array`` into 'channel-last' numpy ndarray. + + Args: + data_array: input data array with the channel dimension specified by ``channel_dim``. + channel_dim: channel dimension of the data array. Defaults to 0. + ``None`` indicates data without any channel dimension. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed. + kwargs: keyword arguments passed to ``self.convert_to_channel_last``, + currently support ``spatial_ndim``, defauting to ``3``. + """ + self.data_obj = self.convert_to_channel_last( + data=data_array, + channel_dim=channel_dim, + squeeze_end_dims=squeeze_end_dims, + spatial_ndim=kwargs.pop("spatial_ndim", 3), + ) + + def set_metadata(self, meta_dict: Optional[Mapping], resample: bool = True, **options): + """ + Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. + + Args: + meta_dict: a metadata dictionary for affine, original affine and spatial shape information. + Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. + resample: if ``True``, the data will be resampled to the original affine (specified in ``meta_dict``). + options: keyword arguments passed to ``self.resample_if_needed``, + currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``, + defaulting to ``bilinear``, ``border``, ``False``, and ``np.float64`` respectively. + """ + original_affine, affine, spatial_shape = self.get_meta_info(meta_dict) + self.data_obj, self.affine = self.resample_if_needed( + data_array=self.data_obj, + affine=affine, + target_affine=original_affine if resample else None, + output_spatial_shape=spatial_shape if resample else None, + mode=options.pop("mode", GridSampleMode.BILINEAR), + padding_mode=options.pop("padding_mode", GridSamplePadMode.BORDER), + align_corners=options.pop("align_corners", False), + dtype=options.pop("dtype", np.float64), + ) + + def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs): + """ + Create a Nibabel object from ``self.create_backend_obj(self.obj, ...)`` and call ``nib.save``. + + Args: + filename: filename or PathLike object. + verbose: if ``True``, log the progress. + obj_kwargs: keyword arguments passed to ``self.create_backend_obj``, + + See also: + + - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.save + """ + super().write(filename, verbose=verbose) + self.data_obj = self.create_backend_obj( + self.data_obj, affine=self.affine, dtype=self.output_dtype, **obj_kwargs # type: ignore + ) + nib.save(self.data_obj, filename) + + @classmethod + def create_backend_obj( + cls, data_array: NdarrayOrTensor, affine: Optional[NdarrayOrTensor] = None, dtype: DtypeLike = None, **kwargs + ): + """ + Create an Nifti1Image object from ``data_array``. This method assumes a 'channel-last' ``data_array``. + + Args: + data_array: input data array. + affine: affine matrix of the data array. + dtype: output data type. + kwargs: keyword arguments. Current ``nib.nifti1.Nifti1Image`` will read + ``header``, ``extra``, ``file_map`` from this dictionary. + + See also: + + - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.Nifti1Image + """ + data_array = super().create_backend_obj(data_array) + if dtype is not None: + data_array = data_array.astype(dtype, copy=False) + affine = convert_data_type(affine, np.ndarray)[0] + affine = to_affine_nd(r=3, affine=affine) + return nib.nifti1.Nifti1Image( + data_array, + affine, + header=kwargs.pop("header", None), + extra=kwargs.pop("extra", None), + file_map=kwargs.pop("file_map", None), + ) + + +@require_pkg(pkg_name="PIL") +class PILWriter(ImageWriter): + """ + Write image data into files on disk using pillow. + + It's based on the Image module in PIL library: + https://pillow.readthedocs.io/en/stable/reference/Image.html + + .. code-block:: python + + import numpy as np + from monai.data import PILWriter + + np_data = np.arange(48).reshape(3, 4, 4) + writer = PILWriter(np.uint8) + writer.set_data_array(np_data, channel_dim=0) + writer.write("test1.png", verbose=True) + """ + + def __init__( + self, output_dtype: DtypeLike = np.float32, channel_dim: Optional[int] = 0, scale: Optional[int] = 255, **kwargs + ): + """ + Args: + output_dtype: output data type. + channel_dim: channel dimension of the data array. Defaults to 0. + ``None`` indicates data without any channel dimension. + scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling + [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. + kwargs: keyword arguments passed to ``ImageWriter``. + """ + super().__init__(output_dtype=output_dtype, channel_dim=channel_dim, scale=scale, **kwargs) + + def set_data_array( + self, + data_array: NdarrayOrTensor, + channel_dim: Optional[int] = 0, + squeeze_end_dims: bool = True, + contiguous: bool = False, + **kwargs, + ): + """ + Convert ``data_array`` into 'channel-last' numpy ndarray. + + Args: + data_array: input data array with the channel dimension specified by ``channel_dim``. + channel_dim: channel dimension of the data array. Defaults to 0. + ``None`` indicates data without any channel dimension. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed. + contiguous: if ``True``, the data array will be converted to a contiguous array. Default is ``False``. + kwargs: keyword arguments passed to ``self.convert_to_channel_last``, + currently support ``spatial_ndim``, defauting to ``2``. + """ + self.data_obj = self.convert_to_channel_last( + data=data_array, + channel_dim=channel_dim, + squeeze_end_dims=squeeze_end_dims, + spatial_ndim=kwargs.pop("spatial_ndim", 2), + contiguous=contiguous, + ) + + def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): + """ + Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. + + Args: + meta_dict: a metadata dictionary for affine, original affine and spatial shape information. + Optional key is ``"spatial_shape"``. + resample: if ``True``, the data will be resampled to the spatial shape specified in ``meta_dict``. + options: keyword arguments passed to ``self.resample_if_needed``, + currently support ``mode``, defaulting to ``bicubic``. + """ + spatial_shape = self.get_meta_info(meta_dict) + self.data_obj = self.resample_and_clip( + data_array=self.data_obj, + output_spatial_shape=spatial_shape if resample else None, + mode=options.pop("mode", InterpolateMode.BICUBIC), + ) + + def write(self, filename: PathLike, verbose: bool = False, **kwargs): + """ + Create a PIL image object from ``self.create_backend_obj(self.obj, ...)`` and call ``save``. + + Args: + filename: filename or PathLike object. + verbose: if ``True``, log the progress. + kwargs: optional keyword arguments passed to ``self.create_backend_obj`` + currently support ``reverse_indexing``, ``image_mode``, defaulting to ``True``, ``None`` respectively. + + See also: + + - https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save + """ + super().write(filename, verbose=verbose) + self.data_obj = self.create_backend_obj( + data_array=self.data_obj, + dtype=self.output_dtype, # type: ignore + reverse_indexing=kwargs.pop("reverse_indexing", True), + image_mode=kwargs.pop("image_mode", None), + scale=self.scale, # type: ignore + **kwargs, + ) + self.data_obj.save(filename, **kwargs) + + @classmethod + def get_meta_info(cls, metadata: Optional[Mapping] = None): + return None if not metadata else metadata.get("spatial_shape") + + @classmethod + def resample_and_clip( + cls, + data_array: NdarrayOrTensor, + output_spatial_shape: Optional[Sequence[int]] = None, + mode: Union[InterpolateMode, str] = InterpolateMode.BICUBIC, + ): + """ + Resample ``data_array`` to ``output_spatial_shape`` if needed. + Args: + data_array: input data array. This method assumes the 'channel-last' format. + output_spatial_shape: output spatial shape. + mode: interpolation mode, defautl is ``InterpolateMode.BICUBIC``. + """ + + data: np.ndarray = convert_data_type(data_array, np.ndarray)[0] + if output_spatial_shape is not None: + output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2) + mode = look_up_option(mode, InterpolateMode) + align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False + xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners) + _min, _max = np.min(data), np.max(data) + if len(data.shape) == 3: + data = np.moveaxis(data, -1, 0) # to channel first + data = xform(data) # type: ignore + data = np.moveaxis(data, 0, -1) + else: # (H, W) + data = np.expand_dims(data, 0) # make a channel + data = xform(data)[0] # type: ignore + if mode != InterpolateMode.NEAREST: + data = np.clip(data, _min, _max) + return data + + @classmethod + def create_backend_obj( + cls, + data_array: NdarrayOrTensor, + dtype: DtypeLike = None, + scale: Optional[int] = 255, + reverse_indexing: bool = True, + **kwargs, + ): + """ + Create a PIL object from ``data_array``. + + Args: + data_array: input data array. + dtype: output data type. + scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling + [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. + reverse_indexing: if ``True``, the data array's first two dimensions will be swapped. + kwargs: keyword arguments. Currently ``PILImage.fromarray`` will read + ``image_mode`` from this dictionary, defaults to ``None``. + + See also: + + - https://pillow.readthedocs.io/en/stable/reference/Image.html + """ + data: np.ndarray = super().create_backend_obj(data_array) + if scale: + # scale the data to be in an integer range + data = np.clip(data, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1] + if scale == np.iinfo(np.uint8).max: + data = (scale * data).astype(np.uint8, copy=False) + elif scale == np.iinfo(np.uint16).max: + data = (scale * data).astype(np.uint16, copy=False) + else: + raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535].") + if dtype is not None: + data = data.astype(dtype, copy=False) + if reverse_indexing: + data = np.moveaxis(data, 0, 1) + + return PILImage.fromarray(data, mode=kwargs.pop("image_mode", None)) + + +def init(): + """ + Initialize the image writer modules according to the filename extension. + """ + for ext in ("png", "jpg", "jpeg", "bmp", "tiff", "tif"): + register_writer(ext, PILWriter) # TODO: test 16-bit + for ext in ("nii.gz", "nii"): + register_writer(ext, NibabelWriter, ITKWriter) + register_writer("nrrd", ITKWriter, NibabelWriter) + register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index c4fc252586..f292bf1593 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,15 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union +import numpy as np from torch.utils.data import IterableDataset as _TorchIterableDataset from torch.utils.data import get_worker_info from monai.data.utils import convert_tables_to_dicts from monai.transforms import apply_transform -from monai.utils import ensure_tuple, optional_import +from monai.transforms.transform import Randomizable +from monai.utils import deprecated_arg, optional_import pd, _ = optional_import("pandas") @@ -25,11 +26,15 @@ class IterableDataset(_TorchIterableDataset): """ A generic dataset for iterable data source and an optional callable data transform - when fetching a data sample. + when fetching a data sample. Inherit from PyTorch IterableDataset: + https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset. For example, typical input data can be web data stream which can support multi-process access. - Note that when used with `DataLoader` and `num_workers > 0`, each worker process will have a - different copy of the dataset object, need to guarantee process-safe from data source or DataLoader. + To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers, + every process executes transforms on part of every loaded data. + Note that the order of output data may not match data source in multi-processing mode. + And each worker process will have a different copy of the dataset object, need to guarantee + process-safe from data source or DataLoader. """ @@ -44,19 +49,91 @@ def __init__(self, data: Iterable, transform: Optional[Callable] = None) -> None self.source = None def __iter__(self): + info = get_worker_info() + num_workers = info.num_workers if info is not None else 1 + id = info.id if info is not None else 0 + self.source = iter(self.data) - for data in self.source: - if self.transform is not None: - data = apply_transform(self.transform, data) - yield data + for i, item in enumerate(self.source): + if i % num_workers == id: + if self.transform is not None: + item = apply_transform(self.transform, item) + yield item + + +class ShuffleBuffer(Randomizable, IterableDataset): + """ + Extend the IterableDataset with a buffer and randomly pop items. + + Args: + data: input data source to load and transform to generate dataset for model. + transform: a callable data transform on input data. + buffer_size: size of the buffer to store items and randomly pop, default to 512. + seed: random seed to initialize the random state of all workers, set `seed += 1` in + every iter() call, refer to the PyTorch idea: + https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98. + + """ + + def __init__(self, data, transform=None, buffer_size: int = 512, seed: int = 0) -> None: + super().__init__(data=data, transform=transform) + self.size = buffer_size + self.seed = seed + self._idx = 0 + + def __iter__(self): + """ + Fetch data from the source, if buffer is not full, fill into buffer, otherwise, + randomly pop items from the buffer. + After loading all the data from source, randomly pop items from the buffer. + + """ + self.seed += 1 + super().set_random_state(seed=self.seed) # make all workers in sync + buffer = [] + source = self.data + + def _pop_item(): + self.randomize(len(buffer)) + # switch random index data and the last index data + ret, buffer[self._idx] = buffer[self._idx], buffer[-1] + buffer.pop() + return ret + + def _get_item(): + for item in source: + if len(buffer) >= self.size: + yield _pop_item() + buffer.append(item) + + while buffer: + yield _pop_item() + + self.data = _get_item() + return super().__iter__() + + def randomize(self, size: int) -> None: + self._idx = self.R.randint(size) + + def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None): + raise NotImplementedError(f"`set_random_state` is not available in {self.__class__.__name__}.") class CSVIterableDataset(IterableDataset): """ Iterable dataset to load CSV files and generate dictionary data. - It can be helpful when loading extremely big CSV files that can't read into memory directly. + It is particularly useful when data come from a stream, inherits from PyTorch IterableDataset: + https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset. + + It also can be helpful when loading extremely big CSV files that can't read into memory directly, + just treat the big CSV file as stream input, call `reset()` of `CSVIterableDataset` for every epoch. + Note that as a stream input, it can't get the length of dataset. + + To effectively shuffle the data in the big dataset, users can set a big buffer to continuously store + the loaded data, then randomly pick data from the buffer for following tasks. + To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers, - every process executes transforms on part of every loaded chunk. + every process executes transforms on part of every loaded data. Note: the order of output data may not match data source in multi-processing mode. It can load data from multiple CSV files and join the tables with additional `kwargs` arg. @@ -70,10 +147,12 @@ class CSVIterableDataset(IterableDataset): ] Args: - filename: the filename of expected CSV file to load. if providing a list - of filenames, it will load all the files and join tables. + src: if provided the filename of CSV file, it can be a str, URL, path object or file-like object to load. + also support to provide iter for stream input directly, will skip loading from filename. + if provided a list of filenames or iters, it will join the tables. chunksize: rows of a chunk when loading iterable data from CSV files, default to 1000. more details: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html. + buffer_size: size of the buffer to store the loaded chunks, if None, set to `2 x chunksize`. col_names: names of the expected columns to load. if None, load all the columns. col_types: `type` and `default value` to convert the loaded columns, if None, use original data. it should be a dictionary, every item maps to an expected column, the `key` is the column @@ -93,50 +172,101 @@ class CSVIterableDataset(IterableDataset): be the new column name, the `value` is the names of columns to combine. for example: `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}` transform: transform to apply on the loaded items of a dictionary data. + shuffle: whether to shuffle all the data in the buffer every time a new chunk loaded. + seed: random seed to initialize the random state for all the workers if `shuffle` is True, + set `seed += 1` in every iter() call, refer to the PyTorch idea: + https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98. + kwargs_read_csv: dictionary args to pass to pandas `read_csv` function. Default to ``{"chunksize": chunksize}``. kwargs: additional arguments for `pandas.merge()` API to join tables. + .. deprecated:: 0.8.0 + ``filename`` is deprecated, use ``src`` instead. + """ + @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.") def __init__( self, - filename: Union[str, Sequence[str]], + src: Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]], chunksize: int = 1000, + buffer_size: Optional[int] = None, col_names: Optional[Sequence[str]] = None, col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None, col_groups: Optional[Dict[str, Sequence[str]]] = None, transform: Optional[Callable] = None, + shuffle: bool = False, + seed: int = 0, + kwargs_read_csv: Optional[Dict] = None, **kwargs, ): - self.files = ensure_tuple(filename) + self.src = src self.chunksize = chunksize - self.iters = self.reset() + self.buffer_size = 2 * chunksize if buffer_size is None else buffer_size self.col_names = col_names self.col_types = col_types self.col_groups = col_groups + self.shuffle = shuffle + self.seed = seed + self.kwargs_read_csv = kwargs_read_csv or {"chunksize": chunksize} + # in case treating deprecated arg `filename` as kwargs, remove it from `kwargs` + kwargs.pop("filename", None) self.kwargs = kwargs + + self.iters: List[Iterable] = self.reset() super().__init__(data=None, transform=transform) # type: ignore - def reset(self, filename: Optional[Union[str, Sequence[str]]] = None): - if filename is not None: - # update files if necessary - self.files = ensure_tuple(filename) - self.iters = [pd.read_csv(f, chunksize=self.chunksize) for f in self.files] + @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.") + def reset(self, src: Optional[Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]]] = None): + """ + Reset the pandas `TextFileReader` iterable object to read data. For more details, please check: + https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration. + + Args: + src: if not None and provided the filename of CSV file, it can be a str, URL, path object + or file-like object to load. also support to provide iter for stream input directly, + will skip loading from filename. if provided a list of filenames or iters, it will join the tables. + default to `self.src`. + + """ + src = self.src if src is None else src + srcs = (src,) if not isinstance(src, (tuple, list)) else src + self.iters = [] + for i in srcs: + if isinstance(i, str): + self.iters.append(pd.read_csv(i, **self.kwargs_read_csv)) + elif isinstance(i, Iterable): + self.iters.append(i) + else: + raise ValueError("`src` must be file path or iterable object.") return self.iters - def __iter__(self): + def close(self): + """ + Close the pandas `TextFileReader` iterable objects. + If the input src is file path, TextFileReader was created internally, need to close it. + If the input src is iterable object, depends on users requirements whether to close it in this function. + For more details, please check: + https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration. + + """ + for i in self.iters: + i.close() + + def _flattened(self): for chunks in zip(*self.iters): - self.data = convert_tables_to_dicts( + yield from convert_tables_to_dicts( dfs=chunks, col_names=self.col_names, col_types=self.col_types, col_groups=self.col_groups, **self.kwargs, ) - info = get_worker_info() - if info is not None: - length = len(self.data) - per_worker = int(math.ceil(length / float(info.num_workers))) - start = info.id * per_worker - self.data = self.data[start : min(start + per_worker, length)] - - return super().__iter__() + + def __iter__(self): + if self.shuffle: + self.seed += 1 + buffer = ShuffleBuffer( + data=self._flattened(), transform=self.transform, buffer_size=self.buffer_size, seed=self.seed + ) + yield from buffer + yield from IterableDataset(data=self._flattened(), transform=self.transform) diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index b7067def73..3fdc0aa3e8 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,19 +9,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path from typing import Dict, Optional, Union import numpy as np import torch -from monai.config import DtypeLike +from monai.config import DtypeLike, PathLike from monai.data.nifti_writer import write_nifti from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key +from monai.utils import deprecated +@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.") class NiftiSaver: """ Save the data as NIfTI file, it can support single data content or a batch of data. @@ -33,11 +34,14 @@ class NiftiSaver: Note: image should include channel dimension: [B],C,H,W,[D]. + .. deprecated:: 0.8 + Use :py:class:`monai.transforms.SaveImage` instead. + """ def __init__( self, - output_dir: Union[Path, str] = "./", + output_dir: PathLike = "./", output_postfix: str = "seg", output_ext: str = ".nii.gz", resample: bool = True, @@ -47,7 +51,7 @@ def __init__( dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, squeeze_end_dims: bool = True, - data_root_dir: str = "", + data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, ) -> None: @@ -56,17 +60,18 @@ def __init__( output_dir: output image directory. output_postfix: a string appended to all output file names. output_ext: output file extension name. - resample: whether to resample before saving the data array. + resample: whether to convert the data array to it's original coordinate system + based on `original_affine` in the `meta_data`. mode: {``"bilinear"``, ``"nearest"``} This option is used when ``resample = True``. Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} This option is used when ``resample = True``. Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. @@ -107,7 +112,7 @@ def __init__( def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """ - Save data into a Nifti file. + Save data into a NIfTI file. The meta_data could optionally have the following keys: - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object. @@ -116,7 +121,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] - ``'spatial_shape'`` -- for data output shape. - ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename. - When meta_data is specified, the saver will try to resample batch data from the space + When meta_data is specified and `resample=True`, the saver will try to resample batch data from the space defined by "affine" to the space defined by "original_affine". If meta_data is None, use the default index (starting from 0) as the filename. @@ -131,7 +136,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] """ filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) self._data_index += 1 - original_affine = meta_data.get("original_affine", None) if meta_data else None + original_affine = meta_data.get("original_affine", None) if meta_data and self.resample else None affine = meta_data.get("affine", None) if meta_data else None spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None @@ -151,7 +156,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] # change data shape to be (channel, h, w, d) while len(data.shape) < 4: data = np.expand_dims(data, -1) - # change data to "channel last" format and write to nifti format file + # change data to "channel last" format and write to NIfTI format file data = np.moveaxis(np.asarray(data), 0, -1) # if desired, remove trailing singleton dimensions @@ -164,7 +169,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] file_name=path, affine=affine, target_affine=original_affine, - resample=self.resample, + resample=True, output_spatial_shape=spatial_shape, mode=self.mode, padding_mode=self.padding_mode, @@ -178,7 +183,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """ - Save a batch of data into Nifti format files. + Save a batch of data into NIfTI format files. Spatially it supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D respectively (with resampling supports for 2D and 3D only). diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index c56d4c1e8d..8a6172955f 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,17 +15,21 @@ import torch from monai.config import DtypeLike +from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform -from monai.utils import GridSampleMode, GridSamplePadMode, optional_import +from monai.transforms.utils_pytorch_numpy_unification import allclose +from monai.utils import GridSampleMode, GridSamplePadMode, deprecated, optional_import +from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") +@deprecated(since="0.8", msg_suffix="use monai.data.NibabelWriter instead.") def write_nifti( - data: np.ndarray, + data: NdarrayOrTensor, file_name: str, - affine: Optional[np.ndarray] = None, + affine: Optional[NdarrayOrTensor] = None, target_affine: Optional[np.ndarray] = None, resample: bool = True, output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None, @@ -85,19 +89,25 @@ def write_nifti( mode: {``"bilinear"``, ``"nearest"``} This option is used when ``resample = True``. Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} This option is used when ``resample = True``. Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. + + .. deprecated:: 0.8 + Use :py:meth:`monai.data.NibabelWriter` instead. + """ + data, *_ = convert_data_type(data, np.ndarray) + affine, *_ = convert_data_type(affine, np.ndarray) if not isinstance(data, np.ndarray): - raise AssertionError("input data must be numpy array.") + raise AssertionError("input data must be numpy array or torch tensor.") dtype = dtype or data.dtype sr = min(data.ndim, 3) if affine is None: @@ -106,11 +116,11 @@ def write_nifti( if target_affine is None: target_affine = affine - target_affine = to_affine_nd(sr, target_affine) + target_affine, *_ = convert_data_type(to_affine_nd(sr, target_affine), np.ndarray) - if np.allclose(affine, target_affine, atol=1e-3): + if allclose(affine, target_affine, atol=1e-3): # no affine changes, save (data, affine) - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data.astype(output_dtype, copy=False), to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return @@ -121,8 +131,8 @@ def write_nifti( data_shape = data.shape data = nib.orientations.apply_orientation(data, ornt_transform) _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) - if np.allclose(_affine, target_affine, atol=1e-3) or not resample: - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, _affine)) + if allclose(_affine, target_affine, atol=1e-3) or not resample: + results_img = nib.Nifti1Image(data.astype(output_dtype, copy=False), to_affine_nd(3, _affine)) # type: ignore nib.save(results_img, file_name) return @@ -138,11 +148,11 @@ def write_nifti( while len(output_spatial_shape_) < 3: output_spatial_shape_ = output_spatial_shape_ + [1] spatial_shape, channel_shape = data.shape[:3], data.shape[3:] - data_np = data.reshape(list(spatial_shape) + [-1]) + data_np: np.ndarray = data.reshape(list(spatial_shape) + [-1]) # type: ignore data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch data_torch = affine_xform( - torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)), + torch.as_tensor(np.ascontiguousarray(data_np, dtype=dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(transform, dtype=dtype)), spatial_size=output_spatial_shape_[:3], ) data_np = data_torch.squeeze(0).detach().cpu().numpy() @@ -152,12 +162,12 @@ def write_nifti( while len(output_spatial_shape_) < len(data.shape): output_spatial_shape_ = output_spatial_shape_ + [1] data_torch = affine_xform( - torch.as_tensor(np.ascontiguousarray(data).astype(dtype)[None, None]), - torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)), + torch.as_tensor(np.ascontiguousarray(data, dtype=dtype)[None, None]), + torch.as_tensor(np.ascontiguousarray(transform, dtype=dtype)), spatial_size=output_spatial_shape_[: len(data.shape)], ) data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy() - results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data_np.astype(output_dtype, copy=False), to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index e6fb641cca..9a1ade0efa 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,18 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path from typing import Dict, Optional, Union import numpy as np import torch +from monai.config.type_definitions import PathLike from monai.data.png_writer import write_png from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, look_up_option +from monai.utils import InterpolateMode, deprecated, look_up_option +@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.") class PNGSaver: """ Save the data as png file, it can support single data content or a batch of data. @@ -30,17 +31,20 @@ class PNGSaver: where the input image name is extracted from the provided meta data dictionary. If no meta data provided, use index from 0 as the filename prefix. + .. deprecated:: 0.8 + Use :py:class:`monai.transforms.SaveImage` instead. + """ def __init__( self, - output_dir: Union[Path, str] = "./", + output_dir: PathLike = "./", output_postfix: str = "seg", output_ext: str = ".png", resample: bool = True, mode: Union[InterpolateMode, str] = InterpolateMode.NEAREST, scale: Optional[int] = None, - data_root_dir: str = "", + data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, ) -> None: @@ -52,7 +56,7 @@ def __init__( resample: whether to resample and resize if providing spatial_shape in the metadata. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"nearest"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. data_root_dir: if not empty, it specifies the beginning parts of the input file's @@ -134,11 +138,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] raise ValueError(f"Unsupported number of channels: {data.shape[0]}, available options are [1, 3, 4]") write_png( - np.asarray(data), - file_name=path, - output_spatial_shape=spatial_shape, - mode=self.mode, - scale=self.scale, + np.asarray(data), file_name=path, output_spatial_shape=spatial_shape, mode=self.mode, scale=self.scale ) if self.print_log: diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 2baec3b872..5d05536923 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,11 +14,12 @@ import numpy as np from monai.transforms.spatial.array import Resize -from monai.utils import InterpolateMode, ensure_tuple_rep, look_up_option, optional_import +from monai.utils import InterpolateMode, deprecated, ensure_tuple_rep, look_up_option, optional_import Image, _ = optional_import("PIL", name="Image") +@deprecated(since="0.8", msg_suffix="use monai.data.PILWriter instead.") def write_png( data: np.ndarray, file_name: str, @@ -39,16 +40,19 @@ def write_png( output_spatial_shape: spatial shape of the output image. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"bicubic"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling to [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. Raises: ValueError: When ``scale`` is not one of [255, 65535]. + .. deprecated:: 0.8 + Use :py:meth:`monai.data.PILWriter` instead. + """ if not isinstance(data, np.ndarray): - raise AssertionError("input data must be numpy array.") + raise ValueError("input data must be numpy array.") if len(data.shape) == 3 and data.shape[2] == 1: # PIL Image can't save image with 1 channel data = data.squeeze(2) if output_spatial_shape is not None: @@ -59,26 +63,26 @@ def write_png( _min, _max = np.min(data), np.max(data) if len(data.shape) == 3: data = np.moveaxis(data, -1, 0) # to channel first - data = xform(data) + data = xform(data) # type: ignore data = np.moveaxis(data, 0, -1) else: # (H, W) data = np.expand_dims(data, 0) # make a channel - data = xform(data)[0] # first channel + data = xform(data)[0] # type: ignore if mode != InterpolateMode.NEAREST: - data = np.clip(data, _min, _max) # type: ignore + data = np.clip(data, _min, _max) if scale is not None: - data = np.clip(data, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1] + data = np.clip(data, 0.0, 1.0) # png writer only can scale data in range [0, 1] if scale == np.iinfo(np.uint8).max: - data = (scale * data).astype(np.uint8) + data = (scale * data).astype(np.uint8, copy=False) elif scale == np.iinfo(np.uint16).max: - data = (scale * data).astype(np.uint16) + data = (scale * data).astype(np.uint16, copy=False) else: raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") # PNG data must be int number - if data.dtype not in (np.uint8, np.uint16): # type: ignore - data = data.astype(np.uint8) + if data.dtype not in (np.uint8, np.uint16): + data = data.astype(np.uint8, copy=False) data = np.moveaxis(data, 0, 1) img = Image.fromarray(data) diff --git a/monai/data/samplers.py b/monai/data/samplers.py index f69c6091ca..40eed03187 100644 --- a/monai/data/samplers.py +++ b/monai/data/samplers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py index 6eec9fd277..46d555cf11 100644 --- a/monai/data/synthetic.py +++ b/monai/data/synthetic.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -73,7 +73,7 @@ def create_test_image_2d( else: image[circle] = rs.random() * 0.5 + 0.5 - labels = np.ceil(image).astype(np.int32) + labels = np.ceil(image).astype(np.int32, copy=False) norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape) noisyimage: np.ndarray = rescale_array(np.maximum(image, norm)) # type: ignore @@ -148,7 +148,7 @@ def create_test_image_3d( else: image[circle] = rs.random() * 0.5 + 0.5 - labels = np.ceil(image).astype(np.int32) + labels = np.ceil(image).astype(np.int32, copy=False) norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape) noisyimage: np.ndarray = rescale_array(np.maximum(image, norm)) # type: ignore diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 33239ea924..0b97c9febf 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,22 +9,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from copy import deepcopy from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch +from monai.config.type_definitions import NdarrayOrTensor from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset -from monai.data.utils import list_data_collate, pad_list_data_collate +from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms.compose import Compose +from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform -from monai.transforms.inverse_batch_transform import BatchInverseTransform +from monai.transforms.post.dictionary import Invertd from monai.transforms.transform import Randomizable -from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils.enums import CommonKeys, InverseKeys -from monai.utils.module import optional_import +from monai.transforms.utils_pytorch_numpy_unification import mode, stack +from monai.utils import CommonKeys, PostFix, optional_import if TYPE_CHECKING: from tqdm import tqdm @@ -35,6 +37,12 @@ __all__ = ["TestTimeAugmentation"] +DEFAULT_POST_FIX = PostFix.meta() + + +def _identity(x): + return x + class TestTimeAugmentation: """ @@ -67,24 +75,29 @@ class TestTimeAugmentation: orig_meta_keys: the key of the meta data of original input data, will get the `affine`, `data_shape`, etc. the meta data is a dictionary object which contains: filename, original_shape, etc. if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`. - meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + meta_key_postfix: use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. this arg only works when `meta_keys=None`. - return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the - full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended - equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. + to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. + output_device: if converted the inverted data to Tensor, move the inverted results to target device + before `post_func`, default to "cpu". + post_func: post processing for the inverted data, should be a callable function. + return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` + will return the full data. Dimensions will be same size as when passing a single image through + `inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. progress: whether to display a progress bar. Example: .. code-block:: python - transform = RandAffined(keys, ...) - post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + model = UNet(...).to(device) + transform = Compose([RandAffined(keys, ...), ...]) + transform.set_random_state(seed=123) # ensure deterministic evaluation tt_aug = TestTimeAugmentation( - transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device + transform, batch_size=5, num_workers=0, inferrer_fn=model, device=device ) mode, mean, std, vvc = tt_aug(test_data) """ @@ -93,14 +106,17 @@ def __init__( self, transform: InvertibleTransform, batch_size: int, - num_workers: int, - inferrer_fn: Callable, + num_workers: int = 0, + inferrer_fn: Callable = _identity, device: Union[str, torch.device] = "cpu", image_key=CommonKeys.IMAGE, orig_key=CommonKeys.LABEL, nearest_interp: bool = True, orig_meta_keys: Optional[str] = None, - meta_key_postfix="meta_dict", + meta_key_postfix=DEFAULT_POST_FIX, + to_tensor: bool = True, + output_device: Union[str, torch.device] = "cpu", + post_func: Callable = _identity, return_full_data: bool = False, progress: bool = True, ) -> None: @@ -110,12 +126,20 @@ def __init__( self.inferrer_fn = inferrer_fn self.device = device self.image_key = image_key - self.orig_key = orig_key - self.nearest_interp = nearest_interp - self.orig_meta_keys = orig_meta_keys - self.meta_key_postfix = meta_key_postfix self.return_full_data = return_full_data self.progress = progress + self._pred_key = CommonKeys.PRED + self.inverter = Invertd( + keys=self._pred_key, + transform=transform, + orig_keys=orig_key, + orig_meta_keys=orig_meta_keys, + meta_key_postfix=meta_key_postfix, + nearest_interp=nearest_interp, + to_tensor=to_tensor, + device=output_device, + post_func=post_func, + ) # check that the transform has at least one random component, and that all random transforms are invertible self._check_transforms() @@ -127,30 +151,31 @@ def _check_transforms(self): invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) # check at least 1 random if sum(randoms) == 0: - raise RuntimeError( - "Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform." + warnings.warn( + "TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms." ) # check that whenever randoms is True, invertibles is also true for r, i in zip(randoms, invertibles): if r and not i: - raise RuntimeError( - f"All applied random transform(s) must be invertible. Problematic transform: {type(r).__name__}" + warnings.warn( + f"Not all applied random transform(s) are invertible. Problematic transform: {type(r).__name__}" ) def __call__( self, data: Dict[str, Any], num_examples: int = 10 - ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]: + ) -> Union[Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float], NdarrayOrTensor]: """ Args: data: dictionary data to be processed. num_examples: number of realisations to be processed and results combined. Returns: - - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated across - `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean` across the whole output, - including `num_examples`. See original paper for clarification. - - if `return_full_data==False`: data is returned as-is after applying the `inferrer_fn` and then concatenating across - the first dimension containing `num_examples`. This allows the user to perform their own analysis if desired. + - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are + calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC) + is `std/mean` across the whole output, including `num_examples`. See original paper for clarification. + - if `return_full_data==False`: data is returned as-is after applying the `inferrer_fn` and then + concatenating across the first dimension containing `num_examples`. This allows the user to perform + their own analysis if desired. """ d = dict(data) @@ -159,59 +184,26 @@ def __call__( raise ValueError("num_examples should be multiple of batch size.") # generate batch of data of size == batch_size, dataset and dataloader - data_in = [d] * num_examples + data_in = [deepcopy(d) for _ in range(num_examples)] ds = Dataset(data_in, self.transform) - dl = DataLoader(ds, self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - - transform_key = self.orig_key + InverseKeys.KEY_SUFFIX - - # create inverter - inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate) - - outputs: List[np.ndarray] = [] + dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - for batch_data in tqdm(dl) if has_tqdm and self.progress else dl: - - batch_images = batch_data[self.image_key].to(self.device) + outs: List = [] + for b in tqdm(dl) if has_tqdm and self.progress else dl: # do model forward pass - batch_output = self.inferrer_fn(batch_images) - if isinstance(batch_output, torch.Tensor): - batch_output = batch_output.detach().cpu() - if isinstance(batch_output, np.ndarray): - batch_output = torch.Tensor(batch_output) - - transform_info = batch_data[transform_key] - if self.nearest_interp: - transform_info = convert_inverse_interp_mode( - trans_info=deepcopy(transform_info), - mode="nearest", - align_corners=None, - ) - - # create a dictionary containing the inferred batch and their transforms - inferred_dict = {self.orig_key: batch_output, transform_key: transform_info} - # if meta dict is present, add that too (required for some inverse transforms) - meta_dict_key = self.orig_meta_keys or f"{self.orig_key}_{self.meta_key_postfix}" - if meta_dict_key in batch_data: - inferred_dict[meta_dict_key] = batch_data[meta_dict_key] + b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device)) + outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)]) - # do inverse transformation (allow missing keys as only inverting the orig_key) - with allow_missing_keys_mode(self.transform): # type: ignore - inv_batch = inverter(inferred_dict) - - # append - outputs.append(inv_batch[self.orig_key]) - - # output - output: np.ndarray = np.concatenate(outputs) + output: NdarrayOrTensor = stack(outs, 0) if self.return_full_data: return output # calculate metrics - mode = np.array(torch.mode(torch.Tensor(output.astype(np.int64)), dim=0).values) - mean: np.ndarray = np.mean(output, axis=0) # type: ignore - std: np.ndarray = np.std(output, axis=0) # type: ignore - vvc: float = (np.std(output) / np.mean(output)).item() - return mode, mean, std, vvc + _mode = mode(output, dim=0) + mean = output.mean(0) + std = output.std(0) + vvc = (output.std() / output.mean()).item() + + return _mode, mean, std, vvc diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py index da5847465e..e21af69813 100644 --- a/monai/data/thread_buffer.py +++ b/monai/data/thread_buffer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -83,26 +83,44 @@ class ThreadDataLoader(DataLoader): iterate over data from the loader as expected however the data is generated on a separate thread. Use this class where a `DataLoader` instance is required and not just an iterable object. + The default behaviour with `repeats` set to 1 is to yield each batch as it is generated, however with a higher + value the generated batch is yielded that many times while underlying dataset asynchronously generates the next. + Typically not all relevant information is learned from a batch in a single iteration so training multiple times + on the same batch will still produce good training with minimal short-term overfitting while allowing a slow batch + generation process more time to produce a result. + + Another typical usage is to accelerate light-weight preprocessing (usually cached all the deterministic transforms + and no IO operations), because it leverages the separate thread to execute preprocessing to avoid unnecessary IPC + between multiple workers of DataLoader. And as CUDA may not work well with the multi-processing of DataLoader, + `ThreadDataLoader` can be useful for GPU transforms. For more details: + https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md. + + See: + * Fischetti et al. "Faster SGD training by minibatch persistency." ArXiv (2018) https://arxiv.org/abs/1806.07353 + * Dami et al., "Faster Neural Network Training with Data Echoing" ArXiv (2020) https://arxiv.org/abs/1907.05550 + * Ramezani et al. "GCN meets GPU: Decoupling "When to Sample" from "How to Sample"." NeurIPS (2020). + https://proceedings.neurips.cc/paper/2020/file/d714d2c5a796d5814c565d78dd16188d-Paper.pdf + Args: dataset: input dataset. buffer_size: number of items to buffer from the data source. buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items. - num_workers: number of the multi-prcessing workers in PyTorch DataLoader. + repeats: number of times to yield the same batch. + kwargs: other arguments for `DataLoader` except for `dataset`. """ def __init__( - self, - dataset: Dataset, - buffer_size: int = 1, - buffer_timeout: float = 0.01, - num_workers: int = 0, - **kwargs, + self, dataset: Dataset, buffer_size: int = 1, buffer_timeout: float = 0.01, repeats: int = 1, **kwargs ): - super().__init__(dataset, num_workers, **kwargs) + super().__init__(dataset, **kwargs) self.buffer_size = buffer_size self.buffer_timeout = buffer_timeout + self.repeats = repeats def __iter__(self): buffer = ThreadBuffer(src=super().__iter__(), buffer_size=self.buffer_size, timeout=self.buffer_timeout) - yield from buffer + + for batch in buffer: + for _ in range(self.repeats): + yield batch diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py new file mode 100644 index 0000000000..585db14712 --- /dev/null +++ b/monai/data/torchscript_utils.py @@ -0,0 +1,149 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import os +from typing import IO, Any, Mapping, Optional, Sequence, Tuple, Union + +import torch + +from monai.config import get_config_values +from monai.utils import JITMetadataKeys +from monai.utils.module import pytorch_after + +METADATA_FILENAME = "metadata.json" + + +def save_net_with_metadata( + jit_obj: torch.nn.Module, + filename_prefix_or_stream: Union[str, IO[Any]], + include_config_vals: bool = True, + append_timestamp: bool = False, + meta_values: Optional[Mapping[str, Any]] = None, + more_extra_files: Optional[Mapping[str, bytes]] = None, +) -> None: + """ + Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata + included as a JSON file. The Torchscript format is a zip file which can contain extra file data which is used + here as a mechanism for storing metadata about the network being saved. The data in `meta_values` should be + compatible with conversion to JSON using the standard library function `dumps`. The intent is this metadata will + include information about the network applicable to some use case, such as describing the input and output format, + a network name and version, a plain language description of what the network does, and other relevant scientific + information. Clients can use this information to determine automatically how to use the network, and users can + read what the network does and keep track of versions. + + Examples:: + + net = torch.jit.script(monai.networks.nets.UNet(2, 1, 1, [8, 16], [2])) + + meta = { + "name": "Test UNet", + "used_for": "demonstration purposes", + "input_dims": 2, + "output_dims": 2 + } + + # save the Torchscript bundle with the above dictionary stored as an extra file + save_net_with_metadata(m, "test", meta_values=meta) + + # load the network back, `loaded_meta` has same data as `meta` plus version information + loaded_net, loaded_meta, _ = load_net_with_metadata("test.pt") + + + Args: + jit_obj: object to save, should be generated by `script` or `trace`. + filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.pt`. + include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. + append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. + meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. + more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. + """ + + now = datetime.datetime.now() + metadict = {} + + if include_config_vals: + metadict.update(get_config_values()) + metadict[JITMetadataKeys.TIMESTAMP.value] = now.astimezone().isoformat() + + if meta_values is not None: + metadict.update(meta_values) + + json_data = json.dumps(metadict) + + # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object + if pytorch_after(1, 7): + extra_files = {METADATA_FILENAME: json_data.encode()} + + if more_extra_files is not None: + extra_files.update(more_extra_files) + else: + extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined] + extra_files[METADATA_FILENAME] = json_data.encode() + + if more_extra_files is not None: + for k, v in more_extra_files.items(): + extra_files[k] = v + + if isinstance(filename_prefix_or_stream, str): + filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) + if ext == "": + ext = ".pt" + + if append_timestamp: + filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") + else: + filename_prefix_or_stream = filename_no_ext + ext + + torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files) + + +def load_net_with_metadata( + filename_prefix_or_stream: Union[str, IO[Any]], + map_location: Optional[torch.device] = None, + more_extra_files: Sequence[str] = (), +) -> Tuple[torch.nn.Module, dict, dict]: + """ + Load the module object from the given Torchscript filename or stream, and convert the stored JSON metadata + back to a dict object. This will produce an empty dict if the metadata file is not present. + + Args: + filename_prefix_or_stream: filename or file-like stream object. + map_location: network map location as in `torch.jit.load`. + more_extra_files: other extra file data names to load from bundle, see `_extra_files` of `torch.jit.load`. + Returns: + Triple containing loaded object, metadata dict, and extra files dict containing other file data if present + """ + # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object + if pytorch_after(1, 7): + extra_files = {f: "" for f in more_extra_files} + extra_files[METADATA_FILENAME] = "" + else: + extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined] + extra_files[METADATA_FILENAME] = "" + + for f in more_extra_files: + extra_files[f] = "" + + jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) + + extra_files = dict(extra_files.items()) # compatibility with ExtraFilesMap + + if METADATA_FILENAME in extra_files: + json_data = extra_files[METADATA_FILENAME] + del extra_files[METADATA_FILENAME] + else: + json_data = "{}" + + json_data_dict = json.loads(json_data) + + return jit_obj, json_data_dict, extra_files diff --git a/monai/data/utils.py b/monai/data/utils.py index aab23217dc..495daf15e2 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,26 +11,31 @@ import hashlib import json +import logging import math import os import pickle import warnings -from collections import defaultdict +from collections import abc, defaultdict from copy import deepcopy from functools import reduce -from itertools import product, starmap -from pathlib import Path, PurePath +from itertools import product, starmap, zip_longest +from pathlib import PurePath from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from torch.utils.data._utils.collate import default_collate +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, BlendMode, + Method, NumpyPadMode, + convert_data_type, + convert_to_dst_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, @@ -40,7 +45,6 @@ look_up_option, optional_import, ) -from monai.utils.enums import Method pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") @@ -48,35 +52,46 @@ __all__ = [ - "get_random_patch", - "iter_patch_slices", - "dense_patch_slices", - "iter_patch", - "get_valid_patch_size", - "list_data_collate", - "worker_init_fn", - "set_rnd", - "correct_nifti_header_if_necessary", - "rectify_header_sform_qform", - "zoom_affine", + "AFFINE_TOL", + "SUPPORTED_PICKLE_MOD", + "affine_to_spacing", + "compute_importance_map", "compute_shape_offset", - "to_affine_nd", + "convert_tables_to_dicts", + "correct_nifti_header_if_necessary", "create_file_basename", - "compute_importance_map", + "decollate_batch", + "dense_patch_slices", + "get_random_patch", + "get_valid_patch_size", "is_supported_format", + "iter_patch", + "iter_patch_slices", + "json_hashing", + "list_data_collate", + "no_collation", + "orientation_ras_lps", + "pad_list_data_collate", "partition_dataset", "partition_dataset_classes", - "select_cross_validation_folds", - "json_hashing", "pickle_hashing", + "rectify_header_sform_qform", + "reorient_spatial_axes", + "resample_datalist", + "select_cross_validation_folds", + "set_rnd", "sorted_dict", - "decollate_batch", - "rep_scalar_to_batch", - "pad_list_data_collate", - "no_collation", - "convert_tables_to_dicts", + "to_affine_nd", + "worker_init_fn", + "zoom_affine", ] +# module to be used by `torch.save` +SUPPORTED_PICKLE_MOD = {"pickle": pickle} + +# tolerance for affine matrix computation +AFFINE_TOL = 1e-3 + def get_random_patch( dims: Sequence[int], patch_size: Sequence[int], rand_state: Optional[np.random.RandomState] = None @@ -134,9 +149,7 @@ def iter_patch_slices( def dense_patch_slices( - image_size: Sequence[int], - patch_size: Sequence[int], - scan_interval: Sequence[int], + image_size: Sequence[int], patch_size: Sequence[int], scan_interval: Sequence[int] ) -> List[Tuple[slice, ...]]: """ Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image. @@ -251,6 +264,70 @@ def get_valid_patch_size(image_size: Sequence[int], patch_size: Union[Sequence[i return tuple(min(ms, ps or ms) for ms, ps in zip(image_size, patch_size_)) +def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"): + """ + Recursively run collate logic and provide detailed loggings for debugging purposes. + It reports results at the 'critical' level, is therefore suitable in the context of exception handling. + + Args: + batch: batch input to collate + level: current level of recursion for logging purposes + logger_name: name of logger to use for logging + + See also: https://pytorch.org/docs/stable/data.html#working-with-collate-fn + """ + elem = batch[0] + elem_type = type(elem) + l_str = ">" * level + batch_str = f"{batch[:10]}{' ... ' if len(batch) > 10 else ''}" + if isinstance(elem, torch.Tensor): + try: + logging.getLogger(logger_name).critical(f"{l_str} collate/stack a list of tensors") + return torch.stack(batch, 0) + except TypeError as e: + logging.getLogger(logger_name).critical( + f"{l_str} E: {e}, type {[type(elem).__name__ for elem in batch]} in collate({batch_str})" + ) + return + except RuntimeError as e: + logging.getLogger(logger_name).critical( + f"{l_str} E: {e}, shape {[elem.shape for elem in batch]} in collate({batch_str})" + ) + return + elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_": + if elem_type.__name__ in ["ndarray", "memmap"]: + logging.getLogger(logger_name).critical(f"{l_str} collate/stack a list of numpy arrays") + return dev_collate([torch.as_tensor(b) for b in batch], level=level, logger_name=logger_name) + elif elem.shape == (): # scalars + return batch + elif isinstance(elem, (float, int, str, bytes)): + return batch + elif isinstance(elem, abc.Mapping): + out = {} + for key in elem: + logging.getLogger(logger_name).critical(f'{l_str} collate dict key "{key}" out of {len(elem)} keys') + out[key] = dev_collate([d[key] for d in batch], level=level + 1, logger_name=logger_name) + return out + elif isinstance(elem, abc.Sequence): + it = iter(batch) + els = list(it) + try: + sizes = [len(elem) for elem in els] # may not have `len` + except TypeError: + types = [type(elem).__name__ for elem in els] + logging.getLogger(logger_name).critical(f"{l_str} E: type {types} in collate({batch_str})") + return + logging.getLogger(logger_name).critical(f"{l_str} collate list of sizes: {sizes}.") + if any(s != sizes[0] for s in sizes): + logging.getLogger(logger_name).critical( + f"{l_str} collate list inconsistent sizes, got size: {sizes}, in collate({batch_str})" + ) + transposed = zip(*batch) + return [dev_collate(samples, level=level + 1, logger_name=logger_name) for samples in transposed] + logging.getLogger(logger_name).critical(f"{l_str} E: unsupported type in collate {batch_str}.") + return + + def list_data_collate(batch: Sequence): """ Enhancement for PyTorch DataLoader default collate. @@ -265,12 +342,11 @@ def list_data_collate(batch: Sequence): data = [i for k in batch for i in k] if isinstance(elem, list) else batch key = None try: - elem = batch[0] if isinstance(elem, Mapping): ret = {} for k in elem: key = k - ret[k] = default_collate([d[k] for d in data]) + ret[key] = default_collate([d[key] for d in data]) return ret return default_collate(data) except RuntimeError as re: @@ -283,7 +359,8 @@ def list_data_collate(batch: Sequence): + "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its " + "documentation)." ) - raise RuntimeError(re_str) + _ = dev_collate(data) + raise RuntimeError(re_str) from re except TypeError as re: re_str = str(re) if "numpy" in re_str and "Tensor" in re_str: @@ -294,10 +371,36 @@ def list_data_collate(batch: Sequence): + "creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem " + "(check its documentation)." ) - raise TypeError(re_str) + _ = dev_collate(data) + raise TypeError(re_str) from re -def decollate_batch(batch, detach: bool = True): +def _non_zipping_check(batch_data, detach, pad, fill_value): + """ + Utility function based on `decollate_batch`, to identify the largest batch size from the collated data. + returns batch_size, the list of non-iterable items, and the dictionary or list with their items decollated. + + See `decollate_batch` for more details. + """ + if isinstance(batch_data, Mapping): + _deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data} + elif isinstance(batch_data, Iterable): + _deco = [decollate_batch(b, detach, pad=pad, fill_value=fill_value) for b in batch_data] + else: + raise NotImplementedError(f"Unable to de-collate: {batch_data}, type: {type(batch_data)}.") + batch_size, non_iterable = 0, [] + for k, v in _deco.items() if isinstance(_deco, Mapping) else enumerate(_deco): + if not isinstance(v, Iterable) or isinstance(v, (str, bytes)) or (isinstance(v, torch.Tensor) and v.ndim == 0): + # Not running the usual list decollate here: + # don't decollate ['test', 'test'] into [['t', 't'], ['e', 'e'], ['s', 's'], ['t', 't']] + # torch.tensor(0) is iterable but iter(torch.tensor(0)) raises TypeError: iteration over a 0-d tensor + non_iterable.append(k) + elif hasattr(v, "__len__"): + batch_size = max(batch_size, len(v)) + return batch_size, non_iterable, _deco + + +def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): """De-collate a batch of data (for example, as produced by a `DataLoader`). Returns a list of structures with the original tensor's 0-th dimension sliced into elements using `torch.unbind`. @@ -316,14 +419,14 @@ def decollate_batch(batch, detach: bool = True): batch_data = { "image": torch.rand((2,1,10,10)), - "image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])} + DictPostFix.meta("image"): {"scl_slope": torch.Tensor([0.0, 0.0])} } out = decollate_batch(batch_data) print(len(out)) >>> 2 print(out[0]) - >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}} + >>> {'image': tensor([[[4.3549e-01...43e-01]]]), DictPostFix.meta("image"): {'scl_slope': 0.0}} batch_data = [torch.rand((2,1,10,10)), torch.rand((2,3,5,5))] out = decollate_batch(batch_data) @@ -335,10 +438,23 @@ def decollate_batch(batch, detach: bool = True): print(out[0]) >>> tensor([[[4.3549e-01...43e-01]]]) + batch_data = { + "image": [1, 2, 3], "meta": [4, 5], # undetermined batch size + } + out = decollate_batch(batch_data, pad=True, fill_value=0) + print(out) + >>> [{'image': 1, 'meta': 4}, {'image': 2, 'meta': 5}, {'image': 3, 'meta': 0}] + out = decollate_batch(batch_data, pad=False) + print(out) + >>> [{'image': 1, 'meta': 4}, {'image': 2, 'meta': 5}] + Args: batch: data to be de-collated. detach: whether to detach the tensors. Scalars tensors will be detached into number types instead of torch tensors. + pad: when the items in a batch indicate different batch size, whether to pad all the sequences to the longest. + If False, the batch size will be the length of the shortest sequence. + fill_value: when `pad` is True, the `fillvalue` to use when padding, defaults to `None`. """ if batch is None: return batch @@ -353,68 +469,20 @@ def decollate_batch(batch, detach: bool = True): if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) - if isinstance(batch, Mapping): - _dict_list = {key: decollate_batch(batch[key], detach) for key in batch} - return [dict(zip(_dict_list, item)) for item in zip(*_dict_list.values())] - if isinstance(batch, Iterable): - item_0 = first(batch) - if ( - not isinstance(item_0, Iterable) - or isinstance(item_0, (str, bytes)) - or (isinstance(item_0, torch.Tensor) and item_0.ndim == 0) - ): - # Not running the usual list decollate here: - # don't decollate ['test', 'test'] into [['t', 't'], ['e', 'e'], ['s', 's'], ['t', 't']] - # torch.tensor(0) is iterable but iter(torch.tensor(0)) raises TypeError: iteration over a 0-d tensor - return [decollate_batch(b, detach) for b in batch] - return [list(item) for item in zip(*(decollate_batch(b, detach) for b in batch))] - raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.") - - -def rep_scalar_to_batch(batch_data: Union[List, Dict]) -> Union[List, Dict]: - """ - Utility tp replicate the scalar items of a list or dictionary to ensure all the items have batch dimension. - It leverages `decollate_batch(detach=False)` to filter out the scalar items. - """ - - def _detect_batch_size(batch_data: Sequence): - """ - Detect the batch size from a list of data, some items in the list have batch dim, some not. - - """ - for v in batch_data: - if isinstance(v, torch.Tensor) and v.ndim > 0: - return v.shape[0] - for v in batch_data: - if issequenceiterable(v): - warnings.warn("batch_data doesn't contain batched Tensor data, use the length of first sequence data.") - return len(v) - raise RuntimeError("failed to automatically detect the batch size.") - - if isinstance(batch_data, dict): - batch_size = _detect_batch_size(list(batch_data.values())) - dict_batch = {} - for k, v in batch_data.items(): - if decollate_batch(v, detach=False) == v and not isinstance(v, list): - # if decollating a list, the result may be the same list, so should skip this case - dict_batch[k] = [deepcopy(decollate_batch(v, detach=True)) for _ in range(batch_size)] - else: - dict_batch[k] = v - - return dict_batch - if isinstance(batch_data, list): - batch_size = _detect_batch_size(batch_data) - list_batch = [] - for b in batch_data: - if decollate_batch(b, detach=False) == b and not isinstance(b, list): - list_batch.append([deepcopy(decollate_batch(b, detach=True)) for _ in range(batch_size)]) - else: - list_batch.append(b) - - return list_batch - # if not dict or list, just return the original data - return batch_data + b, non_iterable, deco = _non_zipping_check(batch, detach, pad, fill_value) + if b <= 0: # all non-iterable, single item "batch"? {"image": 1, "label": 1} + return deco + if pad: # duplicate non-iterable items to the longest batch + for k in non_iterable: + deco[k] = [deepcopy(deco[k]) for _ in range(b)] + if isinstance(deco, Mapping): + _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values()) + return [dict(zip(deco, item)) for item in _gen] + if isinstance(deco, Iterable): + _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco) + return [list(item) for item in _gen] + raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.") def pad_list_data_collate( @@ -430,10 +498,10 @@ def pad_list_data_collate( tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of different sizes. - This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added - to the list of invertible transforms. - - The inverse can be called using the static method: `monai.transforms.croppad.batch.PadListDataCollate.inverse`. + This can be used on both list and dictionary data. + Note that in the case of the dictionary data, this decollate function may add the transform information of + `PadListDataCollate` to the list of invertible transforms if input batch have different spatial shape, so need to + call static method: `monai.transforms.croppad.batch.PadListDataCollate.inverse` before inverting other transforms. Args: batch: batch of data to pad-collate @@ -467,9 +535,10 @@ def worker_init_fn(worker_id: int) -> None: def set_rnd(obj, seed: int) -> int: """ - Set seed or random state for all randomisable properties of obj. + Set seed or random state for all randomizable properties of obj. Args: + obj: object to set seed or random state for. seed: set the random state with an integer seed. """ if not hasattr(obj, "__dict__"): @@ -484,6 +553,30 @@ def set_rnd(obj, seed: int) -> int: return seed +def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_zeros: bool = True) -> NdarrayTensor: + """ + Computing the current spacing from the affine matrix. + + Args: + affine: a d x d affine matrix. + r: indexing based on the spatial rank, spacing is computed from `affine[:r, :r]`. + dtype: data type of the output. + suppress_zeros: whether to surpress the zeros with ones. + + Returns: + an `r` dimensional vector of spacing. + """ + _affine, *_ = convert_to_dst_type(affine[:r, :r], dst=affine, dtype=dtype) + if isinstance(_affine, torch.Tensor): + spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0)) + else: + spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) + if suppress_zeros: + spacing[spacing == 0] = 1.0 + spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype) + return spacing_ + + def correct_nifti_header_if_necessary(img_nii): """ Check nifti object header's format, update the header if needed. @@ -499,7 +592,7 @@ def correct_nifti_header_if_necessary(img_nii): return img_nii # do nothing for high-dimensional array # check that affine matches zooms pixdim = np.asarray(img_nii.header.get_zooms())[:dim] - norm_affine = np.sqrt(np.sum(np.square(img_nii.affine[:dim, :dim]), 0)) + norm_affine = affine_to_spacing(img_nii.affine, r=dim) if np.allclose(pixdim, norm_affine): return img_nii if hasattr(img_nii, "get_sform"): @@ -520,8 +613,8 @@ def rectify_header_sform_qform(img_nii): d = img_nii.header["dim"][0] pixdim = np.asarray(img_nii.header.get_zooms())[:d] sform, qform = img_nii.get_sform(), img_nii.get_qform() - norm_sform = np.sqrt(np.sum(np.square(sform[:d, :d]), 0)) - norm_qform = np.sqrt(np.sum(np.square(qform[:d, :d]), 0)) + norm_sform = affine_to_spacing(sform, r=d) + norm_qform = affine_to_spacing(qform, r=d) sform_mismatch = not np.allclose(norm_sform, pixdim) qform_mismatch = not np.allclose(norm_qform, pixdim) @@ -538,14 +631,14 @@ def rectify_header_sform_qform(img_nii): img_nii.set_qform(img_nii.get_sform()) return img_nii - norm = np.sqrt(np.sum(np.square(img_nii.affine[:d, :d]), 0)) + norm = affine_to_spacing(img_nii.affine, r=d) warnings.warn(f"Modifying image pixdim from {pixdim} to {norm}") img_nii.header.set_zooms(norm) return img_nii -def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = True): +def zoom_affine(affine: np.ndarray, scale: Union[np.ndarray, Sequence[float]], diagonal: bool = True): """ To make column norm of `affine` the same as `scale`. If diagonal is False, returns an affine that combines orthogonal rotation and the new scale. @@ -578,7 +671,7 @@ def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = Tru d = len(affine) - 1 # compute original pixdim - norm = np.sqrt(np.sum(np.square(affine), 0))[:-1] + norm = affine_to_spacing(affine, r=d) if len(scale_np) < d: # defaults based on affine scale_np = np.append(scale_np, norm[len(scale_np) :]) scale_np = scale_np[:d] @@ -598,7 +691,7 @@ def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = Tru def compute_shape_offset( - spatial_shape: Union[np.ndarray, Sequence[int]], in_affine: np.ndarray, out_affine: np.ndarray + spatial_shape: Union[np.ndarray, Sequence[int]], in_affine: NdarrayOrTensor, out_affine: NdarrayOrTensor ) -> Tuple[np.ndarray, np.ndarray]: """ Given input and output affine, compute appropriate shapes @@ -613,90 +706,129 @@ def compute_shape_offset( """ shape = np.array(spatial_shape, copy=True, dtype=float) sr = len(shape) - in_affine = to_affine_nd(sr, in_affine) - out_affine = to_affine_nd(sr, out_affine) + in_affine_ = convert_data_type(to_affine_nd(sr, in_affine), np.ndarray)[0] + out_affine_ = convert_data_type(to_affine_nd(sr, out_affine), np.ndarray)[0] in_coords = [(0.0, dim - 1.0) for dim in shape] - corners = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) + corners: np.ndarray = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) corners = np.concatenate((corners, np.ones_like(corners[:1]))) - corners = in_affine @ corners - corners_out = np.linalg.inv(out_affine) @ corners + corners = in_affine_ @ corners + try: + inv_mat = np.linalg.inv(out_affine_) + except np.linalg.LinAlgError as e: + raise ValueError(f"Affine {out_affine_} is not invertible") from e + corners_out = inv_mat @ corners corners_out = corners_out[:-1] / corners_out[-1] out_shape = np.round(corners_out.ptp(axis=1) + 1.0) - if np.allclose(nib.io_orientation(in_affine), nib.io_orientation(out_affine)): - # same orientation, get translate from the origin - offset = in_affine @ ([0] * sr + [1]) - offset = offset[:-1] / offset[-1] - else: - # different orientation, the min is the origin - corners = corners[:-1] / corners[-1] - offset = np.min(corners, 1) - return out_shape.astype(int), offset - - -def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: + mat = inv_mat[:-1, :-1] + k = 0 + for i in range(corners.shape[1]): + min_corner = np.min(mat @ corners[:-1, :] - mat @ corners[:-1, i : i + 1], 1) + if np.allclose(min_corner, 0.0, rtol=AFFINE_TOL): + k = i + break + offset = corners[:-1, k] + return out_shape.astype(int, copy=False), offset + + +def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.float64) -> NdarrayTensor: """ Using elements from affine, to create a new affine matrix by assigning the rotation/zoom/scaling matrix and the translation vector. - when ``r`` is an integer, output is an (r+1)x(r+1) matrix, + When ``r`` is an integer, output is an (r+1)x(r+1) matrix, where the top left kxk elements are copied from ``affine``, the last column of the output affine is copied from ``affine``'s last column. `k` is determined by `min(r, len(affine) - 1)`. - when ``r`` is an affine matrix, the output has the same as ``r``, - the top left kxk elements are copied from ``affine``, + When ``r`` is an affine matrix, the output has the same shape as ``r``, + and the top left kxk elements are copied from ``affine``, the last column of the output affine is copied from ``affine``'s last column. `k` is determined by `min(len(r) - 1, len(affine) - 1)`. Args: r (int or matrix): number of spatial dimensions or an output affine to be filled. affine (matrix): 2D affine matrix + dtype: data type of the output array. Raises: ValueError: When ``affine`` dimensions is not 2. ValueError: When ``r`` is nonpositive. Returns: - an (r+1) x (r+1) matrix + an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type) """ - affine_np = np.array(affine, dtype=np.float64) + affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0] + affine_np = affine_np.copy() if affine_np.ndim != 2: raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.") - new_affine = np.array(r, dtype=np.float64, copy=True) + new_affine = np.array(r, dtype=dtype, copy=True) if new_affine.ndim == 0: sr: int = int(new_affine.astype(np.uint)) if not np.isfinite(sr) or sr < 0: raise ValueError(f"r must be positive, got {sr}.") - new_affine = np.eye(sr + 1, dtype=np.float64) + new_affine = np.eye(sr + 1, dtype=dtype) d = max(min(len(new_affine) - 1, len(affine_np) - 1), 1) new_affine[:d, :d] = affine_np[:d, :d] if d > 1: new_affine[:d, -1] = affine_np[:d, -1] - return new_affine + output, *_ = convert_to_dst_type(new_affine, affine, dtype=dtype) + return output + + +def reorient_spatial_axes( + data_shape: Sequence[int], init_affine: NdarrayOrTensor, target_affine: NdarrayOrTensor +) -> Tuple[np.ndarray, NdarrayOrTensor]: + """ + Given the input ``init_affine``, compute the orientation transform between + it and ``target_affine`` by rearranging/flipping the axes. + + Returns the orientation transform and the updated affine (tensor or ndarray + depends on the input ``affine`` data type). + Note that this function requires external module ``nibabel.orientations``. + """ + init_affine_, *_ = convert_data_type(init_affine, np.ndarray) + target_affine_, *_ = convert_data_type(target_affine, np.ndarray) + start_ornt = nib.orientations.io_orientation(init_affine_) + target_ornt = nib.orientations.io_orientation(target_affine_) + try: + ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + except ValueError as e: + raise ValueError(f"The input affine {init_affine} and target affine {target_affine} are not compatible.") from e + new_affine = init_affine_ @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) + new_affine, *_ = convert_to_dst_type(new_affine, init_affine) + return ornt_transform, new_affine def create_file_basename( postfix: str, - input_file_name: str, - folder_path: Union[Path, str], - data_root_dir: str = "", + input_file_name: PathLike, + folder_path: PathLike, + data_root_dir: PathLike = "", separate_folder: bool = True, - patch_index: Optional[int] = None, + patch_index=None, + makedirs: bool = True, ) -> str: """ Utility function to create the path to the output file based on the input filename (file name extension is not added by this function). - When `data_root_dir` is not specified, the output file name is: + When ``data_root_dir`` is not specified, the output file name is: - `folder_path/input_file_name (no ext.) /input_file_name (no ext.)[_postfix]` + `folder_path/input_file_name (no ext.) /input_file_name (no ext.)[_postfix][_patch_index]` - otherwise the relative path with respect to `data_root_dir` will be inserted, for example: - input_file_name: /foo/bar/test1/image.png, - postfix: seg - folder_path: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg + otherwise the relative path with respect to ``data_root_dir`` will be inserted, for example: + + .. code-block:: python + + from monai.data import create_file_basename + create_file_basename( + postfix="seg", + input_file_name="/foo/bar/test1/image.png", + folder_path="/output", + data_root_dir="/foo/bar", + separate_folder=True, + makedirs=False) + # output: /output/test1/image/image_seg Args: postfix: output name's postfix @@ -710,6 +842,7 @@ def create_file_basename( `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. patch_index: if not None, append the patch index to filename. + makedirs: whether to create the folder if it does not exist. """ # get the filename and directory @@ -728,16 +861,18 @@ def create_file_basename( if separate_folder: output = os.path.join(output, filename) - # create target folder if no existing - os.makedirs(output, exist_ok=True) + + if makedirs: + # create target folder if no existing + os.makedirs(output, exist_ok=True) # add the sub-folder plus the postfix name to become the file basename in the output path - output = os.path.join(output, (filename + "_" + postfix) if len(postfix) > 0 else filename) + output = os.path.join(output, filename + "_" + postfix if postfix != "" else filename) if patch_index is not None: output += f"_{patch_index}" - return os.path.abspath(output) + return os.path.normpath(output) def compute_importance_map( @@ -768,7 +903,7 @@ def compute_importance_map( """ mode = look_up_option(mode, BlendMode) - device = torch.device(device) # type: ignore[arg-type] + device = torch.device(device) if mode == BlendMode.CONSTANT: importance_map = torch.ones(patch_size, device=device).float() elif mode == BlendMode.GAUSSIAN: @@ -795,7 +930,7 @@ def compute_importance_map( return importance_map -def is_supported_format(filename: Union[Sequence[str], str], suffixes: Sequence[str]) -> bool: +def is_supported_format(filename: Union[Sequence[PathLike], PathLike], suffixes: Sequence[str]) -> bool: """ Verify whether the specified file or files format match supported suffixes. If supported suffixes is None, skip the verification and return True. @@ -806,7 +941,7 @@ def is_supported_format(filename: Union[Sequence[str], str], suffixes: Sequence[ suffixes: all the supported image suffixes of current reader, must be a list of lower case suffixes. """ - filenames: Sequence[str] = ensure_tuple(filename) + filenames: Sequence[PathLike] = ensure_tuple(filename) for name in filenames: tokens: Sequence[str] = PurePath(name).suffixes if len(tokens) == 0 or all("." + s.lower() not in "".join(tokens) for s in suffixes): @@ -993,6 +1128,31 @@ def partition_dataset_classes( return datasets +def resample_datalist(data: Sequence, factor: float, random_pick: bool = False, seed: int = 0): + """ + Utility function to resample the loaded datalist for training, for example: + If factor < 1.0, randomly pick part of the datalist and set to Dataset, useful to quickly test the program. + If factor > 1.0, repeat the datalist to enhance the Dataset. + + Args: + data: original datalist to scale. + factor: scale factor for the datalist, for example, factor=4.5, repeat the datalist 4 times and plus + 50% of the original datalist. + random_pick: whether to randomly pick data if scale factor has decimal part. + seed: random seed to randomly pick data. + + """ + scale, repeats = math.modf(factor) + ret: List = list() + + for _ in range(int(repeats)): + ret.extend(list(deepcopy(data))) + if scale > 1e-6: + ret.extend(partition_dataset(data=data, ratios=[scale, 1 - scale], shuffle=random_pick, seed=seed)[0]) + + return ret + + def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[Sequence[int], int]) -> List: """ Select cross validation data based on data partitions and specified fold index. @@ -1029,7 +1189,7 @@ def json_hashing(item) -> bytes: """ # TODO: Find way to hash transforms content as part of the cache cache_key = hashlib.md5(json.dumps(item, sort_keys=True).encode("utf-8")).hexdigest() - return f"{cache_key}".encode("utf-8") + return f"{cache_key}".encode() def pickle_hashing(item, protocol=pickle.HIGHEST_PROTOCOL) -> bytes: @@ -1044,7 +1204,7 @@ def pickle_hashing(item, protocol=pickle.HIGHEST_PROTOCOL) -> bytes: """ cache_key = hashlib.md5(pickle.dumps(sorted_dict(item), protocol=protocol)).hexdigest() - return f"{cache_key}".encode("utf-8") + return f"{cache_key}".encode() def sorted_dict(item, key=None, reverse=False): @@ -1117,7 +1277,7 @@ def convert_tables_to_dicts( # convert data types types = {k: v["type"] for k, v in col_types.items() if v is not None and "type" in v} if types: - data_ = data_.astype(dtype=types) + data_ = data_.astype(dtype=types, copy=False) data: List[Dict] = data_.to_dict(orient="records") # group columns to generate new column @@ -1129,3 +1289,19 @@ def convert_tables_to_dicts( data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)] return data + + +def orientation_ras_lps(affine: NdarrayTensor) -> NdarrayTensor: + """ + Convert the ``affine`` between the `RAS` and `LPS` orientation + by flipping the first two spatial dimensions. + + Args: + affine: a 2D affine matrix. + """ + sr = max(affine.shape[0] - 1, 1) # spatial rank is at least 1 + flip_d = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]] + flip_diag = flip_d[min(sr - 1, 2)] + [1] * (sr - 3) + if isinstance(affine, torch.Tensor): + return torch.diag(torch.as_tensor(flip_diag).to(affine)) @ affine # type: ignore + return np.diag(flip_diag).astype(affine.dtype) @ affine # type: ignore diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index 89ebc8b47c..88f094c732 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,9 +15,13 @@ from .utils import ( GanKeys, IterationEvents, + PrepareBatch, + PrepareBatchDefault, + PrepareBatchExtraInput, default_make_latent, default_metric_cmp_fn, default_prepare_batch, engine_apply_transform, get_devices_spec, ) +from .workflow import BaseWorkflow, Workflow diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 1c37da71d4..c3e8c456b7 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,12 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from torch.utils.data import DataLoader -from monai.config import IgniteInfo +from monai.config import IgniteInfo, KeysCollection from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer @@ -45,9 +45,13 @@ class Evaluator(Workflow): epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_val_metric: compute metric when every iteration completed, and save average value to @@ -80,7 +84,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, @@ -110,7 +114,7 @@ def __init__( event_to_attr=event_to_attr, decollate=decollate, ) - self.mode = look_up_option(mode, ForwardMode) + mode = look_up_option(mode, ForwardMode) if mode == ForwardMode.EVAL: self.mode = eval_mode elif mode == ForwardMode.TRAIN: @@ -147,9 +151,13 @@ class SupervisedEvaluator(Evaluator): epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. @@ -184,7 +192,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, inferer: Optional[Inferer] = None, postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, @@ -219,7 +227,7 @@ def __init__( self.network = network self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -237,7 +245,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -246,15 +254,15 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore # execute forward computation with self.mode(self.network): if self.amp: with torch.cuda.amp.autocast(): - engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore else: - engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) @@ -270,14 +278,19 @@ class EnsembleEvaluator(Evaluator): device: an object representing the device on which to run. val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader. epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. - network: networks to evaluate in order in the evaluator, should be regular PyTorch `torch.nn.Module`. + networks: networks to evaluate in order in the evaluator, should be regular PyTorch `torch.nn.Module`. pred_keys: the keys to store every prediction data. the length must exactly match the number of networks. + if None, use "pred_{index}" as key corresponding to N networks, index from `0` to `N-1`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. @@ -309,11 +322,11 @@ def __init__( device: torch.device, val_data_loader: Union[Iterable, DataLoader], networks: Sequence[torch.nn.Module], - pred_keys: Sequence[str], + pred_keys: Optional[KeysCollection] = None, epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, inferer: Optional[Inferer] = None, postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, @@ -346,10 +359,14 @@ def __init__( ) self.networks = ensure_tuple(networks) - self.pred_keys = ensure_tuple(pred_keys) + self.pred_keys = ( + [f"{Keys.PRED}_{i}" for i in range(len(self.networks))] if pred_keys is None else ensure_tuple(pred_keys) + ) + if len(self.pred_keys) != len(self.networks): + raise ValueError("length of `pred_keys` must be same as the length of `networks`.") self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -370,7 +387,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -379,17 +396,21 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore for idx, network in enumerate(self.networks): with self.mode(network): if self.amp: with torch.cuda.amp.autocast(): + if isinstance(engine.state.output, dict): + engine.state.output.update( + {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)} + ) + else: + if isinstance(engine.state.output, dict): engine.state.output.update( {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)} ) - else: - engine.state.output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index 3671dbcfd1..0433617649 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,10 +34,7 @@ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") -__all__ = [ - "create_multigpu_supervised_trainer", - "create_multigpu_supervised_evaluator", -] +__all__ = ["create_multigpu_supervised_trainer", "create_multigpu_supervised_evaluator"] def _default_transform(_x: torch.Tensor, _y: torch.Tensor, _y_pred: torch.Tensor, loss: torch.Tensor) -> float: @@ -59,7 +56,7 @@ def create_multigpu_supervised_trainer( prepare_batch: Callable = _prepare_batch, output_transform: Callable = _default_transform, distributed: bool = False, -) -> Engine: +): """ Derived from `create_supervised_trainer` in Ignite. @@ -77,8 +74,8 @@ def create_multigpu_supervised_trainer( tuple of tensors `(batch_x, batch_y)`. output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`. - distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use - the first device as output device. + distributed: whether convert model to `DistributedDataParallel`, if `True`, `devices` must contain + only 1 GPU or CPU for current distributed rank. Returns: Engine: a trainer engine with supervised update function. @@ -90,6 +87,8 @@ def create_multigpu_supervised_trainer( devices_ = get_devices_spec(devices) if distributed: + if len(devices_) > 1: + raise ValueError(f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {devices_}.") net = DistributedDataParallel(net, device_ids=devices_) elif len(devices_) > 1: net = DataParallel(net) @@ -107,7 +106,7 @@ def create_multigpu_supervised_evaluator( prepare_batch: Callable = _prepare_batch, output_transform: Callable = _default_eval_transform, distributed: bool = False, -) -> Engine: +): """ Derived from `create_supervised_evaluator` in Ignite. @@ -125,8 +124,8 @@ def create_multigpu_supervised_evaluator( output_transform: function that receives 'x', 'y', 'y_pred' and returns value to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits output expected by metrics. If you change it you should use `output_transform` in metrics. - distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use - the first device as output device. + distributed: whether convert model to `DistributedDataParallel`, if `True`, `devices` must contain + only 1 GPU or CPU for current distributed rank. Note: `engine.state.output` for this engine is defined by `output_transform` parameter and is @@ -140,6 +139,10 @@ def create_multigpu_supervised_evaluator( if distributed: net = DistributedDataParallel(net, device_ids=devices_) + if len(devices_) > 1: + raise ValueError( + f"for distributed evaluation, `devices` must contain only 1 GPU or CPU, but got {devices_}." + ) elif len(devices_) > 1: net = DataParallel(net) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 44e265be1f..774e535e7f 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from torch.optim.optimizer import Optimizer @@ -26,7 +26,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import PT_BEFORE_1_7, min_version, optional_import +from monai.utils import min_version, optional_import, pytorch_after from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: @@ -75,9 +75,13 @@ class SupervisedTrainer(Trainer): epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. @@ -115,7 +119,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, inferer: Optional[Inferer] = None, postprocessing: Optional[Transform] = None, key_train_metric: Optional[Dict[str, Metric]] = None, @@ -172,7 +176,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -180,7 +184,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): else: inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore def _compute_pred_loss(): engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) @@ -190,7 +194,7 @@ def _compute_pred_loss(): self.network.train() # `set_to_none` only work from PyTorch 1.7.0 - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.optimizer.zero_grad() else: self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -198,13 +202,13 @@ def _compute_pred_loss(): if self.amp and self.scaler is not None: with torch.cuda.amp.autocast(): _compute_pred_loss() - self.scaler.scale(engine.state.output[Keys.LOSS]).backward() + self.scaler.scale(engine.state.output[Keys.LOSS]).backward() # type: ignore engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.scaler.step(self.optimizer) self.scaler.update() else: _compute_pred_loss() - engine.state.output[Keys.LOSS].backward() + engine.state.output[Keys.LOSS].backward() # type: ignore engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.optimizer.step() engine.fire_event(IterationEvents.MODEL_COMPLETED) @@ -241,12 +245,16 @@ class GanTrainer(Trainer): non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. d_prepare_batch: callback function to prepare batchdata for D inferer. - Defaults to return ``GanKeys.REALS`` in batchdata dict. + Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. g_prepare_batch: callback function to create batch of latent input for G inferer. - Defaults to return random latents. + Defaults to return random latents. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_train_metric: compute metric when every iteration completed, and save average value to @@ -286,7 +294,7 @@ def __init__( d_prepare_batch: Callable = default_prepare_batch, g_prepare_batch: Callable = default_make_latent, g_update_latents: bool = True, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, postprocessing: Optional[Transform] = None, key_train_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, @@ -345,18 +353,21 @@ def _iteration( if batchdata is None: raise ValueError("must provide batch data for current iteration.") - d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore batch_size = self.data_loader.batch_size # type: ignore - g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) + g_input = self.g_prepare_batch( + num_latents=batch_size, + latent_size=self.latent_shape, + device=engine.state.device, # type: ignore + non_blocking=engine.non_blocking, # type: ignore + ) g_output = self.g_inferer(g_input, self.g_network) # Train Discriminator - d_total_loss = torch.zeros( - 1, - ) + d_total_loss = torch.zeros(1) for _ in range(self.d_train_steps): # `set_to_none` only work from PyTorch 1.7.0 - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.d_optimizer.zero_grad() else: self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -367,9 +378,14 @@ def _iteration( # Train Generator if self.g_update_latents: - g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) + g_input = self.g_prepare_batch( + num_latents=batch_size, + latent_size=self.latent_shape, + device=engine.state.device, # type: ignore + non_blocking=engine.non_blocking, # type: ignore + ) g_output = self.g_inferer(g_input, self.g_network) - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.g_optimizer.zero_grad() else: self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index c94cc16916..726dfc8e98 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,13 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch from monai.config import IgniteInfo from monai.transforms import apply_transform -from monai.utils import min_version, optional_import +from monai.utils import ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys if TYPE_CHECKING: @@ -28,6 +29,9 @@ "GanKeys", "get_devices_spec", "default_prepare_batch", + "PrepareBatch", + "PrepareBatchDefault", + "PrepareBatchExtraInput", "default_make_latent", "engine_apply_transform", "default_metric_cmp_fn", @@ -100,9 +104,7 @@ def get_devices_spec(devices: Optional[Sequence[torch.device]] = None) -> List[t def default_prepare_batch( - batchdata: Dict[str, torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - non_blocking: bool = False, + batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: """ Default function to prepare the data for current iteration. @@ -125,11 +127,81 @@ def default_prepare_batch( return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), None +class PrepareBatch(ABC): + """ + Interface of customized prepare_batch in the trainer or evaluator workflows. + It takes the data of current batch, target device and non_blocking flag as input. + + """ + + @abstractmethod + def __call__( + self, + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + ): + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class PrepareBatchDefault(PrepareBatch): + """ + Default prepare batch method to return `image` and `label` only, + it's to be consistent with `default_prepare_batch` API. + """ + + def __call__( + self, + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + ): + return default_prepare_batch(batchdata, device, non_blocking) + + +class PrepareBatchExtraInput(PrepareBatch): + """ + Customized prepare_batch for trainer or evaluator that support extra input data for network. + Extra items are specified by the `extra_keys` parameter. + + Args: + extra_keys: if a string or list provided, every item is the key of extra data in current batch, + and will pass the extra data to the `network(*args)` in order. + If a dictionary is provided, every `{k, v}` pair is the key of extra data in current batch, + `k` is the param name in network, `v` is the key of extra data in current batch, + and will pass the `{k1: batch[v1], k2: batch[v2], ...}` as kwargs to the network. + + """ + + def __init__(self, extra_keys: Union[str, Sequence[str], Dict[str, str]]) -> None: + self.extra_keys = extra_keys + + def __call__( + self, + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + ): + image, label = default_prepare_batch(batchdata, device, non_blocking) + args = list() + kwargs = dict() + + def _get_data(key: str): + data = batchdata[key] + return data.to(device=device, non_blocking=non_blocking) if isinstance(data, torch.Tensor) else data + + if isinstance(self.extra_keys, (str, list, tuple)): + for k in ensure_tuple(self.extra_keys): + args.append(_get_data(k)) + elif isinstance(self.extra_keys, dict): + for k, v in self.extra_keys.items(): + kwargs.update({k: _get_data(v)}) + + return image, label, tuple(args), kwargs + + def default_make_latent( - num_latents: int, - latent_size: int, - device: Optional[Union[str, torch.device]] = None, - non_blocking: bool = False, + num_latents: int, latent_size: int, device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False ) -> torch.Tensor: return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index ffb8ce05b3..245571c6db 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,7 +10,8 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union import torch import torch.distributed as dist @@ -19,8 +20,8 @@ from monai.config import IgniteInfo from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch -from monai.transforms import Decollated, Transform -from monai.utils import ensure_tuple, min_version, optional_import +from monai.transforms import Decollated +from monai.utils import ensure_tuple, is_scalar, min_version, optional_import from .utils import engine_apply_transform @@ -37,6 +38,18 @@ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") +class BaseWorkflow(ABC): + """ + Base class for any MONAI style workflow. + `run()` is designed to execute the train, evaluation or inference logic. + + """ + + @abstractmethod + def run(self, *args, **kwargs): + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import """ Workflow defines the core work process inheriting from Ignite engine. @@ -54,9 +67,13 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona epoch_length: number of iterations for one epoch, default to `len(data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for every iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_metric: compute metric when every iteration completed, and save average value to @@ -94,7 +111,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, postprocessing: Optional[Callable] = None, key_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, @@ -152,15 +169,15 @@ def set_sampler_epoch(engine: Engine): self.scaler: Optional[torch.cuda.amp.GradScaler] = None if event_names is None: - event_names = [IterationEvents] + event_names = [IterationEvents] # type: ignore else: if not isinstance(event_names, list): raise ValueError("event_names must be a list or string or EventEnum.") - event_names += [IterationEvents] + event_names += [IterationEvents] # type: ignore for name in event_names: if isinstance(name, str): self.register_events(name, event_to_attr=event_to_attr) - elif issubclass(name, EventEnum): + elif issubclass(name, EventEnum): # type: ignore self.register_events(*name, event_to_attr=event_to_attr) else: raise ValueError("event_names must be a list or string or EventEnum.") @@ -169,8 +186,8 @@ def set_sampler_epoch(engine: Engine): self._register_decollate() if postprocessing is not None: - if not decollate and isinstance(postprocessing, Transform): - warnings.warn("MONAI transforms expect `channel-first` data, `decollate=False` may not work here.") + # tips: if `decollate=False` and `postprocessing` is MONAI transforms, it may not work well + # because all the MONAI transforms expect `channel-first` data self._register_postprocessing(postprocessing) if key_metric is not None: self._register_metrics(key_metric, additional_metrics) @@ -187,8 +204,10 @@ def _register_decollate(self): def _decollate_data(engine: Engine) -> None: # replicate the scalar values to make sure all the items have batch dimension, then decollate transform = Decollated(keys=None, detach=True) - engine.state.batch = transform(engine.state.batch) - engine.state.output = transform(engine.state.output) + if isinstance(engine.state.batch, (list, dict)): + engine.state.batch = transform(engine.state.batch) + if isinstance(engine.state.output, (list, dict)): + engine.state.output = transform(engine.state.output) def _register_postprocessing(self, posttrans: Callable): """ @@ -200,9 +219,7 @@ def _register_postprocessing(self, posttrans: Callable): def _run_postprocessing(engine: Engine) -> None: if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list): engine.state.batch, engine.state.output = engine_apply_transform( - batch=engine.state.batch, - output=engine.state.output, - transform=posttrans, + batch=engine.state.batch, output=engine.state.output, transform=posttrans ) else: for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): @@ -216,7 +233,7 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): if not isinstance(k_metric, dict): raise TypeError(f"key_metric must be None or a dict but is {type(k_metric).__name__}.") self.state.key_metric_name = list(k_metric.keys())[0] - metrics = k_metric + metrics = dict(k_metric) if add_metrics is not None and len(add_metrics) > 0: if not isinstance(add_metrics, dict): raise TypeError(f"additional metrics must be None or a dict but is {type(add_metrics).__name__}.") @@ -226,12 +243,20 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): @self.on(Events.EPOCH_COMPLETED) def _compare_metrics(engine: Engine) -> None: - if engine.state.key_metric_name is not None: - current_val_metric = engine.state.metrics[engine.state.key_metric_name] - if self.metric_cmp_fn(current_val_metric, engine.state.best_metric): - self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}") - engine.state.best_metric = current_val_metric - engine.state.best_metric_epoch = engine.state.epoch + key_metric_name = engine.state.key_metric_name # type: ignore + if key_metric_name is not None: + current_val_metric = engine.state.metrics[key_metric_name] + if not is_scalar(current_val_metric): + warnings.warn( + "key metric is not a scalar value, skip the metric comparison with the current best metric." + "please set other metrics as the key metric, or change the `reduction` mode to 'mean'." + ) + return + + if self.metric_cmp_fn(current_val_metric, engine.state.best_metric): # type: ignore + self.logger.info(f"Got new best metric of {key_metric_name}: {current_val_metric}") + engine.state.best_metric = current_val_metric # type: ignore + engine.state.best_metric_epoch = engine.state.epoch # type: ignore def _register_handlers(self, handlers: Sequence): """ @@ -247,6 +272,13 @@ def run(self) -> None: Execute training, validation or evaluation based on Ignite Engine. """ + if self.state.epoch_length == 0: + warnings.warn( + "`dataloader` is emply or the specified `epoch_length` is 0, skip the `run`." + " if running distributed training, the program may hang in `all-gather`, `all-reduce`, etc." + " because not all the ranks run the same computation logic." + ) + return super().run(data=self.data_loader, max_epochs=self.state.max_epochs) def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index c9eecc6d46..649cc3cae6 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,6 +22,7 @@ from .mean_dice import MeanDice from .metric_logger import MetricLogger, MetricLoggerKeys from .metrics_saver import MetricsSaver +from .mlflow_handler import MLFlowHandler from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler from .parameter_scheduler import ParamSchedulerHandler from .postprocessing import PostProcessing @@ -32,13 +33,5 @@ from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler -from .transform_inverter import TransformInverter -from .utils import ( - evenly_divisible_all_gather, - from_engine, - stopping_fn_from_loss, - stopping_fn_from_metric, - string_list_all_gather, - write_metrics_reports, -) +from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index f1f60abf63..91cfca354a 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -126,7 +126,7 @@ def __call__(self, engine: Engine) -> None: # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict) - if engine.state.epoch > prior_max_epochs: + if prior_max_epochs is not None and engine.state.epoch > prior_max_epochs: raise ValueError( f"Epoch count ({engine.state.epoch}) in checkpoint is larger than " f"the `engine.state.max_epochs` ({prior_max_epochs}) of engine. To further train from checkpoint, " diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index f365ff73c4..f7abca4aa0 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,10 +11,10 @@ import logging import warnings -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Mapping, Optional from monai.config import IgniteInfo -from monai.utils import min_version, optional_import +from monai.utils import is_scalar, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") @@ -126,7 +126,7 @@ def __init__(self, dirname: str, filename: Optional[str] = None): super().__init__(dirname=dirname, require_empty=False, atomic=False) self.filename = filename - def __call__(self, checkpoint: Dict, filename: str, metadata: Optional[Dict] = None) -> None: + def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None: if self.filename is not None: filename = self.filename super().__call__(checkpoint=checkpoint, filename=filename, metadata=metadata) @@ -154,14 +154,20 @@ def _final_func(engine: Engine): def _score_func(engine: Engine): if isinstance(key_metric_name, str): metric_name = key_metric_name - elif hasattr(engine.state, "key_metric_name") and isinstance(engine.state.key_metric_name, str): - metric_name = engine.state.key_metric_name + elif hasattr(engine.state, "key_metric_name"): + metric_name = engine.state.key_metric_name # type: ignore else: raise ValueError( f"Incompatible values: save_key_metric=True and key_metric_name={key_metric_name}." ) - - return (-1 if key_metric_negative_sign else 1) * engine.state.metrics[metric_name] + metric = engine.state.metrics[metric_name] + if not is_scalar(metric): + warnings.warn( + "key metric is not a scalar value, skip metric comparison and don't save a model." + "please use other metrics as key metric, or change the `reduction` mode to 'mean'." + ) + return -1 + return (-1 if key_metric_negative_sign else 1) * metric if key_metric_filename is not None and key_metric_n_saved > 1: raise ValueError("if using fixed filename to save the best metric model, we should only save 1 model.") diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 815be87754..75fb394177 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -39,6 +39,7 @@ def __init__( self, output_dir: str = "./", filename: str = "predictions.csv", + delimiter: str = ",", overwrite: bool = True, batch_transform: Callable = lambda x: x, output_transform: Callable = lambda x: x, @@ -50,14 +51,22 @@ def __init__( Args: output_dir: if `saver=None`, output CSV file directory. filename: if `saver=None`, name of the saved CSV file name. + delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`. + to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter. overwrite: if `saver=None`, whether to overwriting existing file content, if True, will clear the file before saving. otherwise, will append new content to the file. batch_transform: a callable that is used to extract the `meta_data` dictionary of the input images from `ignite.engine.state.batch`. the purpose is to get the input filenames from the `meta_data` and store with classification results together. + `engine.state` and `batch_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. output_transform: a callable that is used to extract the model prediction data from `ignite.engine.state.output`. the first dimension of its output will be treated as the batch dimension. each item in the batch will be saved individually. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. name: identifier of logging.logger to use, defaulting to `engine.logger`. save_rank: only the handler on specified rank will save to CSV file in multi-gpus validation, default to 0. @@ -68,6 +77,7 @@ def __init__( self.save_rank = save_rank self.output_dir = output_dir self.filename = filename + self.delimiter = delimiter self.overwrite = overwrite self.batch_transform = batch_transform self.output_transform = output_transform @@ -92,7 +102,14 @@ def attach(self, engine: Engine) -> None: if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED): engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize) - def _started(self, engine: Engine) -> None: + def _started(self, _engine: Engine) -> None: + """ + Initialize internal buffers. + + Args: + _engine: Ignite Engine, unused argument. + + """ self._outputs = [] self._filenames = [] @@ -114,12 +131,12 @@ def __call__(self, engine: Engine) -> None: o = o.detach() self._outputs.append(o) - def _finalize(self, engine: Engine) -> None: + def _finalize(self, _engine: Engine) -> None: """ All gather classification results from ranks and save to CSV file. Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. + _engine: Ignite Engine, unused argument. """ ws = idist.get_world_size() if self.save_rank >= ws: @@ -140,6 +157,8 @@ def _finalize(self, engine: Engine) -> None: # save to CSV file only in the expected rank if idist.get_rank() == self.save_rank: - saver = self.saver or CSVSaver(self.output_dir, self.filename, self.overwrite) + saver = self.saver or CSVSaver( + output_dir=self.output_dir, filename=self.filename, overwrite=self.overwrite, delimiter=self.delimiter + ) saver.save_batch(outputs, meta_dict) saver.finalize() diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index 368aacc6cb..e3fc4bfbf1 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,11 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from typing import Callable, Union from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import ConfusionMatrixMetric -from monai.metrics.utils import MetricReduction +from monai.utils.enums import MetricReduction class ConfusionMatrix(IgniteMetric): @@ -25,6 +25,8 @@ def __init__( self, include_background: bool = True, metric_name: str = "hit_rate", + compute_sample: bool = False, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -40,11 +42,17 @@ def __init__( ``"informedness"``, ``"markedness"``] Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned), and you can also input those names instead. + compute_sample: when reducing, if ``True``, each sample's metric will be computed based on each confusion matrix first. + if ``False``, compute reduction on the confusion matrices first, defaults to ``False``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. @@ -54,12 +62,8 @@ def __init__( metric_fn = ConfusionMatrixMetric( include_background=include_background, metric_name=metric_name, - compute_sample=False, - reduction=MetricReduction.MEAN, + compute_sample=compute_sample, + reduction=reduction, ) self.metric_name = metric_name - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py index 4e99fc6f04..a0d0ef3ad2 100644 --- a/monai/handlers/decollate_batch.py +++ b/monai/handlers/decollate_batch.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -88,7 +88,7 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - if self.batch_transform is not None: + if self.batch_transform is not None and isinstance(engine.state.batch, (list, dict)): engine.state.batch = self.batch_transform(engine.state.batch) - if self.output_transform is not None: + if self.output_transform is not None and isinstance(engine.state.output, (list, dict)): engine.state.output = self.output_transform(engine.state.output) diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index e194b50d59..8d57526676 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/handlers/garbage_collector.py b/monai/handlers/garbage_collector.py index fffca2a740..74ccac8a72 100644 --- a/monai/handlers/garbage_collector.py +++ b/monai/handlers/garbage_collector.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,6 +42,7 @@ class GarbageCollector: """ def __init__(self, trigger_event: str = "epoch", log_level: int = 10): + self.trigger_event: Events if isinstance(trigger_event, Events): self.trigger_event = trigger_event elif trigger_event.lower() == "epoch": diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index a25ef04383..739c9e9935 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Callable, Optional, Union from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import HausdorffDistanceMetric @@ -27,6 +27,7 @@ def __init__( distance_metric: str = "euclidean", percentile: Optional[float] = None, directed: bool = False, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -41,11 +42,15 @@ def __init__( percentile of the Hausdorff Distance rather than the maximum result will be achieved. Defaults to ``None``. directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: hausdorff distance of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. @@ -55,10 +60,6 @@ def __init__( distance_metric=distance_metric, percentile=percentile, directed=directed, - reduction=MetricReduction.MEAN, - ) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, + reduction=reduction, ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index cbf84e4626..f28923af68 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -41,18 +41,16 @@ class IgniteMetric(Metric): # type: ignore[valid-type, misc] # due to optional_ output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: mean_dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. """ def __init__( - self, - metric_fn: CumulativeIterationMetric, - output_transform: Callable = lambda x: x, - save_details: bool = True, + self, metric_fn: CumulativeIterationMetric, output_transform: Callable = lambda x: x, save_details: bool = True ) -> None: self._is_reduced: bool = False self.metric_fn = metric_fn @@ -101,9 +99,13 @@ def compute(self) -> Any: if self.save_details: if self._engine is None or self._name is None: raise RuntimeError("please call the attach() function to connect expected engine first.") - self._engine.state.metric_details[self._name] = self.metric_fn.get_buffer() + self._engine.state.metric_details[self._name] = self.metric_fn.get_buffer() # type: ignore - return result.item() if isinstance(result, torch.Tensor) else result + if isinstance(result, torch.Tensor): + result = result.squeeze() + if result.ndim == 0: + result = result.item() + return result def attach(self, engine: Engine, name: str) -> None: """ @@ -120,4 +122,4 @@ def attach(self, engine: Engine, name: str) -> None: self._engine = engine self._name = name if self.save_details and not hasattr(engine.state, "metric_details"): - engine.state.metric_details = {} + engine.state.metric_details = {} # type: ignore diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py index 3e57ac7bbd..db186bd73d 100644 --- a/monai/handlers/lr_schedule_handler.py +++ b/monai/handlers/lr_schedule_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index ba5805fc19..c5609c6746 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from typing import Callable, Union from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import DiceMetric @@ -24,6 +24,7 @@ class MeanDice(IgniteMetric): def __init__( self, include_background: bool = True, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -32,20 +33,20 @@ def __init__( Args: include_background: whether to include dice computation on the first channel of the predicted output. Defaults to True. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: mean dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. See also: :py:meth:`monai.metrics.meandice.compute_meandice` """ - metric_fn = DiceMetric(include_background=include_background, reduction=MetricReduction.MEAN) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + metric_fn = DiceMetric(include_background=include_background, reduction=reduction) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index 64553955b7..350d1978de 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -57,6 +57,9 @@ class MetricLogger: Args: loss_transform: Converts the `output` value from the trainer's state into a loss value + `engine.state` and `loss_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. metric_transform: Converts the metric value coming from the trainer/evaluator's state into a storable value evaluator: Optional evaluator to consume metric results from at the end of its evaluation run """ diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 97b080b244..f07aa4f39c 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -50,6 +50,9 @@ class MetricsSaver: batch_transform: a callable that is used to extract the `meta_data` dictionary of the input images from `ignite.engine.state.batch` if saving metric details. the purpose is to get the input filenames from the `meta_data` and store with metric details together. + `engine.state` and `batch_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. summary_ops: expected computation operations to generate the summary report. it can be: None, "*" or list of strings, default to None. None - don't generate summary report for every expected metric_details. @@ -67,7 +70,8 @@ class mean median max 5percentile 95percentile notnans mean 6.2500 6.2500 7.0000 5.5750 6.9250 2.0000 save_rank: only the handler on specified rank will save to files in multi-gpus validation, default to 0. - delimiter: the delimiter character in CSV file, default to "\t". + delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`. + to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter. output_type: expected output file type, supported types: ["csv"], default to "csv". """ @@ -80,7 +84,7 @@ def __init__( batch_transform: Callable = lambda x: x, summary_ops: Optional[Union[str, Sequence[str]]] = None, save_rank: int = 0, - delimiter: str = "\t", + delimiter: str = ",", output_type: str = "csv", ) -> None: self.save_dir = save_dir @@ -102,7 +106,14 @@ def attach(self, engine: Engine) -> None: engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames) engine.add_event_handler(Events.EPOCH_COMPLETED, self) - def _started(self, engine: Engine) -> None: + def _started(self, _engine: Engine) -> None: + """ + Initialize internal buffers. + + Args: + _engine: Ignite Engine, unused argument. + + """ self._filenames = [] def _get_filenames(self, engine: Engine) -> None: @@ -132,10 +143,12 @@ def __call__(self, engine: Engine) -> None: if self.metrics is not None and len(engine.state.metrics) > 0: _metrics = {k: v for k, v in engine.state.metrics.items() if k in self.metrics or "*" in self.metrics} _metric_details = {} - if self.metric_details is not None and len(engine.state.metric_details) > 0: - for k, v in engine.state.metric_details.items(): - if k in self.metric_details or "*" in self.metric_details: - _metric_details[k] = v + if hasattr(engine.state, "metric_details"): + details = engine.state.metric_details # type: ignore + if self.metric_details is not None and len(details) > 0: + for k, v in details.items(): + if k in self.metric_details or "*" in self.metric_details: + _metric_details[k] = v write_metrics_reports( save_dir=self.save_dir, diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py new file mode 100644 index 0000000000..664a1c8730 --- /dev/null +++ b/monai/handlers/mlflow_handler.py @@ -0,0 +1,199 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence + +import torch + +from monai.config import IgniteInfo +from monai.utils import min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +mlflow, _ = optional_import("mlflow") + +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + +DEFAULT_TAG = "Loss" + + +class MLFlowHandler: + """ + MLFlowHandler defines a set of Ignite Event-handlers for the MLFlow tracking logics. + It can be used for any Ignite Engine(trainer, validator and evaluator). + And it can track both epoch level and iteration level logging, then MLFlow can store + the data and visualize. + The expected data source is Ignite ``engine.state.output`` and ``engine.state.metrics``. + + Default behaviors: + - When EPOCH_COMPLETED, track each dictionary item in + ``engine.state.metrics`` in MLFlow. + - When ITERATION_COMPLETED, track expected item in + ``self.output_transform(engine.state.output)`` in MLFlow, default to `Loss`. + + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + + Args: + tracking_uri: connects to a tracking URI. can also set the `MLFLOW_TRACKING_URI` environment + variable to have MLflow find a URI from there. in both cases, the URI can either be + a HTTP/HTTPS URI for a remote server, a database connection string, or a local path + to log data to a directory. The URI defaults to path `mlruns`. + for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri. + iteration_log: whether to log data to MLFlow when iteration completed, default to `True`. + epoch_log: whether to log data to MLFlow when epoch completed, default to `True`. + epoch_logger: customized callable logger for epoch level logging with MLFlow. + Must accept parameter "engine", use default logger if None. + iteration_logger: customized callable logger for iteration level logging with MLFlow. + Must accept parameter "engine", use default logger if None. + output_transform: a callable that is used to transform the + ``ignite.engine.state.output`` into a scalar to track, or a dictionary of {key: scalar}. + By default this value logging happens when every iteration completed. + The default behavior is to track loss from output[0] as output is a decollated list + and we replicated loss value for every item of the decollated list. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + global_epoch_transform: a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine might want to track synced epoch number + with the trainer engine. + state_attributes: expected attributes from `engine.state`, if provided, will extract them + when epoch completed. + tag_name: when iteration output is a scalar, `tag_name` is used to track, defaults to `'Loss'`. + + For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html. + + """ + + def __init__( + self, + tracking_uri: Optional[str] = None, + iteration_log: bool = True, + epoch_log: bool = True, + epoch_logger: Optional[Callable[[Engine], Any]] = None, + iteration_logger: Optional[Callable[[Engine], Any]] = None, + output_transform: Callable = lambda x: x[0], + global_epoch_transform: Callable = lambda x: x, + state_attributes: Optional[Sequence[str]] = None, + tag_name: str = DEFAULT_TAG, + ) -> None: + if tracking_uri is not None: + mlflow.set_tracking_uri(tracking_uri) + + self.iteration_log = iteration_log + self.epoch_log = epoch_log + self.epoch_logger = epoch_logger + self.iteration_logger = iteration_logger + self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform + self.state_attributes = state_attributes + self.tag_name = tag_name + + def attach(self, engine: Engine) -> None: + """ + Register a set of Ignite Event-Handlers to a specified Ignite engine. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if not engine.has_event_handler(self.start, Events.STARTED): + engine.add_event_handler(Events.STARTED, self.start) + if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): + engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) + + def start(self) -> None: + """ + Check MLFlow status and start if not active. + + """ + if mlflow.active_run() is None: + mlflow.start_run() + + def close(self) -> None: + """ + Stop current running logger of MLFlow. + + """ + mlflow.end_run() + + def epoch_completed(self, engine: Engine) -> None: + """ + Handler for train or validation/evaluation epoch completed Event. + Track epoch level log, default values are from Ignite `engine.state.metrics` dict. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.epoch_logger is not None: + self.epoch_logger(engine) + else: + self._default_epoch_log(engine) + + def iteration_completed(self, engine: Engine) -> None: + """ + Handler for train or validation/evaluation iteration completed Event. + Track iteration level log. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.iteration_logger is not None: + self.iteration_logger(engine) + else: + self._default_iteration_log(engine) + + def _default_epoch_log(self, engine: Engine) -> None: + """ + Execute epoch level log operation. + Default to track the values from Ignite `engine.state.metrics` dict and + track the values of specified attributes of `engine.state`. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + log_dict = engine.state.metrics + if not log_dict: + return + + current_epoch = self.global_epoch_transform(engine.state.epoch) + mlflow.log_metrics(log_dict, step=current_epoch) + + if self.state_attributes is not None: + attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes} + mlflow.log_metrics(attrs, step=current_epoch) + + def _default_iteration_log(self, engine: Engine) -> None: + """ + Execute iteration log operation based on Ignite `engine.state.output` data. + Log the values from `self.output_transform(engine.state.output)`. + Since `engine.state.output` is a decollated list and we replicated the loss value for every item + of the decollated list, the default behavior is to track the loss from `output[0]`. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + loss = self.output_transform(engine.state.output) + if loss is None: + return + + if not isinstance(loss, dict): + loss = {self.tag_name: loss.item() if isinstance(loss, torch.Tensor) else loss} + + mlflow.log_metrics(loss, step=engine.state.iteration) diff --git a/monai/handlers/nvtx_handlers.py b/monai/handlers/nvtx_handlers.py index aba7a7ec0e..327c156f63 100644 --- a/monai/handlers/nvtx_handlers.py +++ b/monai/handlers/nvtx_handlers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -50,9 +50,7 @@ class RangeHandler: """ def __init__( - self, - events: Union[str, Tuple[Union[str, Events], Union[str, Events]]], - msg: Optional[str] = None, + self, events: Union[str, Tuple[Union[str, Events], Union[str, Events]]], msg: Optional[str] = None ) -> None: self.events = self.resolve_events(events) if msg is None: @@ -73,10 +71,7 @@ def resolve_events(self, events: Union[str, Tuple]) -> Tuple[Events, Events]: if len(events) == 1: return self.create_paired_events(events[0]) if len(events) == 2: - return ( - self.get_event(events[0]), - self.get_event(events[1]), - ) + return self.get_event(events[0]), self.get_event(events[1]) raise ValueError(f"Exactly two Ignite events should be provided [received {len(events)}].") def create_paired_events(self, event: str) -> Tuple[Events, Events]: @@ -84,22 +79,11 @@ def create_paired_events(self, event: str) -> Tuple[Events, Events]: Create pair of Ignite events from a event prefix name """ event = event.upper() - event_prefix = { - "": "", - "ENGINE": "", - "EPOCH": "EPOCH_", - "ITERATION": "ITERATION_", - "BATCH": "GET_BATCH_", - } - return ( - self.get_event(event_prefix[event] + "STARTED"), - self.get_event(event_prefix[event] + "COMPLETED"), - ) + event_prefix = {"": "", "ENGINE": "", "EPOCH": "EPOCH_", "ITERATION": "ITERATION_", "BATCH": "GET_BATCH_"} + return self.get_event(event_prefix[event] + "STARTED"), self.get_event(event_prefix[event] + "COMPLETED") def get_event(self, event: Union[str, Events]) -> Events: - if isinstance(event, str): - event = event.upper() - return Events[event] + return Events[event.upper()] if isinstance(event, str) else event def attach(self, engine: Engine) -> None: """ @@ -126,10 +110,8 @@ class RangePushHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Events, msg: Optional[str] = None) -> None: - if isinstance(event, str): - event = event.upper() - self.event = Events[event] + def __init__(self, event: Union[str, Events], msg: Optional[str] = None) -> None: + self.event = Events[event.upper()] if isinstance(event, str) else event if msg is None: msg = self.event.name self.msg = msg @@ -156,10 +138,8 @@ class RangePopHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Events) -> None: - if isinstance(event, str): - event = event.upper() - self.event = Events[event] + def __init__(self, event: Union[str, Events]) -> None: + self.event = Events[event.upper()] if isinstance(event, str) else event def attach(self, engine: Engine) -> None: """ @@ -181,10 +161,8 @@ class MarkHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Events, msg: Optional[str] = None) -> None: - if isinstance(event, str): - event = event.upper() - self.event = Events[event] + def __init__(self, event: Union[str, Events], msg: Optional[str] = None) -> None: + self.event = Events[event.upper()] if isinstance(event, str) else event if msg is None: msg = self.event.name self.msg = msg diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py index b6eb35562f..67c51fd351 100644 --- a/monai/handlers/parameter_scheduler.py +++ b/monai/handlers/parameter_scheduler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -134,7 +134,7 @@ def _exponential(initial_value: float, gamma: float, current_step: int) -> float Returns: float: new parameter value """ - return initial_value * gamma ** current_step + return initial_value * gamma**current_step @staticmethod def _step(initial_value: float, gamma: float, step_size: int, current_step: int) -> float: diff --git a/monai/handlers/postprocessing.py b/monai/handlers/postprocessing.py index 05c6bd414d..4a89c86f47 100644 --- a/monai/handlers/postprocessing.py +++ b/monai/handlers/postprocessing.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -63,9 +63,7 @@ def __call__(self, engine: Engine) -> None: """ if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list): engine.state.batch, engine.state.output = engine_apply_transform( - batch=engine.state.batch, - output=engine.state.output, - transform=self.transform, + batch=engine.state.batch, output=engine.state.output, transform=self.transform ) else: for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): diff --git a/monai/handlers/regression_metrics.py b/monai/handlers/regression_metrics.py index f203439f40..bf4ac3af1d 100644 --- a/monai/handlers/regression_metrics.py +++ b/monai/handlers/regression_metrics.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,29 +23,30 @@ class MeanSquaredError(IgniteMetric): def __init__( self, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: """ Args: + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: mean squared error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. See also: :py:class:`monai.metrics.MSEMetric` """ - metric_fn = MSEMetric(reduction=MetricReduction.MEAN) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + metric_fn = MSEMetric(reduction=reduction) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) class MeanAbsoluteError(IgniteMetric): @@ -55,25 +56,30 @@ class MeanAbsoluteError(IgniteMetric): def __init__( self, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: """ Args: - output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. - save_details: whether to save metric computation details per image, for example: mean absolute error of every image. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: mean squared error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. See also: :py:class:`monai.metrics.MAEMetric` """ - metric_fn = MAEMetric(reduction=MetricReduction.MEAN) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + metric_fn = MAEMetric(reduction=reduction) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) class RootMeanSquaredError(IgniteMetric): @@ -83,25 +89,30 @@ class RootMeanSquaredError(IgniteMetric): def __init__( self, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: """ Args: - output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. - save_details: whether to save metric computation details per image, for example: root mean squared error of every image. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: mean squared error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. See also: :py:class:`monai.metrics.RMSEMetric` """ - metric_fn = RMSEMetric(reduction=MetricReduction.MEAN) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + metric_fn = RMSEMetric(reduction=reduction) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) class PeakSignalToNoiseRatio(IgniteMetric): @@ -112,6 +123,7 @@ class PeakSignalToNoiseRatio(IgniteMetric): def __init__( self, max_val: Union[int, float], + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -120,17 +132,21 @@ def __init__( Args: max_val: The dynamic range of the images/volumes (i.e., the difference between the maximum and the minimum allowed values e.g. 255 for a uint8 image). - output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. - save_details: whether to save metric computation details per image, for example: PSNR of every image. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: mean squared error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, See also: :py:class:`monai.metrics.PSNRMetric` """ - metric_fn = PSNRMetric(max_val=max_val, reduction=MetricReduction.MEAN) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + metric_fn = PSNRMetric(max_val=max_val, reduction=reduction) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 98c8c8f8bc..68cf2e655e 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from monai.utils import Average -class ROCAUC(IgniteMetric): # type: ignore[valid-type, misc] # due to optional_import +class ROCAUC(IgniteMetric): """ Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`. @@ -36,8 +36,9 @@ class ROCAUC(IgniteMetric): # type: ignore[valid-type, misc] # due to optional output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. Note: ROCAUC expects y to be comprised of 0's and 1's. @@ -45,14 +46,6 @@ class ROCAUC(IgniteMetric): # type: ignore[valid-type, misc] # due to optional """ - def __init__( - self, - average: Union[Average, str] = Average.MACRO, - output_transform: Callable = lambda x: x, - ) -> None: + def __init__(self, average: Union[Average, str] = Average.MACRO, output_transform: Callable = lambda x: x) -> None: metric_fn = ROCAUCMetric(average=Average(average)) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=False, - ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False) diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 535f58945b..40bb5f8bed 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,7 +26,7 @@ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="Please consider using `SaveImage[d]` transform instead.") +@deprecated(since="0.6.0", removed="0.9.0", msg_suffix="Please consider using `SaveImage[d]` transform instead.") class SegmentationSaver: """ Event handler triggered on completing every iteration to save the segmentation predictions into files. @@ -72,16 +72,16 @@ def __init__( - NIfTI files {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - PNG files This option is ignored. @@ -113,9 +113,15 @@ def __init__( batch_transform: a callable that is used to extract the `meta_data` dictionary of the input images from `ignite.engine.state.batch`. the purpose is to extract necessary information from the meta data: filename, affine, original_shape, etc. + `engine.state` and `batch_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. output_transform: a callable that is used to extract the model prediction data from `ignite.engine.state.output`. the first dimension of its output will be treated as the batch dimension. each item in the batch will be saved individually. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. name: identifier of logging.logger to use, defaulting to `engine.logger`. """ diff --git a/monai/handlers/smartcache_handler.py b/monai/handlers/smartcache_handler.py index e3adcbf4a0..56fee78b1d 100644 --- a/monai/handlers/smartcache_handler.py +++ b/monai/handlers/smartcache_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index d5756074fc..e3b5de2d36 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,7 +11,7 @@ import logging import warnings -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence import torch @@ -31,30 +31,52 @@ class StatsHandler: """ StatsHandler defines a set of Ignite Event-handlers for all the log printing logics. - It's can be used for any Ignite Engine(trainer, validator and evaluator). + It can be used for any Ignite Engine(trainer, validator and evaluator). And it can support logging for epoch level and iteration level with pre-defined loggers. + Note that if `name` arg is None, will leverage `engine.logger` as default logger directly, otherwise, + get logger from `logging.getLogger(name)`, we can setup a logger outside first with the same `name`. + As the default log level of `RootLogger` is `WARNING`, may need to call + `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` before running this handler to enable + the stats logging. + Default behaviors: - When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``. - When ITERATION_COMPLETED, logs ``self.output_transform(engine.state.output)`` using ``self.logger``. + Usage example:: + + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + trainer = SupervisedTrainer(...) + StatsHandler(name="train_stats").attach(trainer) + + trainer.run() + + More details of example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/engines/unet_training_dict.py. + """ def __init__( self, + iteration_log: bool = True, + epoch_log: bool = True, epoch_print_logger: Optional[Callable[[Engine], Any]] = None, iteration_print_logger: Optional[Callable[[Engine], Any]] = None, output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, + state_attributes: Optional[Sequence[str]] = None, name: Optional[str] = None, tag_name: str = DEFAULT_TAG, key_var_format: str = DEFAULT_KEY_VAL_FORMAT, - logger_handler: Optional[logging.Handler] = None, ) -> None: """ Args: + iteration_log: whether to log data when iteration completed, default to `True`. + epoch_log: whether to log data when epoch completed, default to `True`. epoch_print_logger: customized callable printer for epoch level logging. Must accept parameter "engine", use default printer if None. iteration_print_logger: customized callable printer for iteration level logging. @@ -65,28 +87,32 @@ def __init__( By default this value logging happens when every iteration completed. The default behavior is to print loss from output[0] as output is a decollated list and we replicated loss value for every item of the decollated list. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. global_epoch_transform: a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine might want to print synced epoch number with the trainer engine. - name: identifier of logging.logger to use, defaulting to ``engine.logger``. + state_attributes: expected attributes from `engine.state`, if provided, will extract them + when epoch completed. + name: identifier of `logging.logger` to use, if None, defaulting to ``engine.logger``. tag_name: when iteration output is a scalar, tag_name is used to print tag_name: scalar_value to logger. Defaults to ``'Loss'``. key_var_format: a formatting string to control the output string format of key: value. - logger_handler: add additional handler to handle the stats data: save to file, etc. - Add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html + """ + self.iteration_log = iteration_log + self.epoch_log = epoch_log self.epoch_print_logger = epoch_print_logger self.iteration_print_logger = iteration_print_logger self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform - self.logger = logging.getLogger(name) - self._name = name - + self.state_attributes = state_attributes self.tag_name = tag_name self.key_var_format = key_var_format - if logger_handler is not None: - self.logger.addHandler(logger_handler) + self.logger = logging.getLogger(name) # if `name` is None, will default to `engine.logger` when attached + self.name = name def attach(self, engine: Engine) -> None: """ @@ -96,11 +122,16 @@ def attach(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - if self._name is None: + if self.name is None: self.logger = engine.logger - if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): + if self.logger.getEffectiveLevel() > logging.INFO or logging.root.getEffectiveLevel() > logging.INFO: + warnings.warn( + "the effective log level of engine logger or RootLogger is higher than INFO, may not record log," + " please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it." + ) + if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) - if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): + if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED): engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised) @@ -108,7 +139,7 @@ def attach(self, engine: Engine) -> None: def epoch_completed(self, engine: Engine) -> None: """ Handler for train or validation/evaluation epoch completed Event. - Print epoch level log, default values are from Ignite state.metrics dict. + Print epoch level log, default values are from Ignite `engine.state.metrics` dict. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -122,7 +153,7 @@ def epoch_completed(self, engine: Engine) -> None: def iteration_completed(self, engine: Engine) -> None: """ Handler for train or validation/evaluation iteration completed Event. - Print iteration level log, default values are from Ignite state.logs dict. + Print iteration level log, default values are from Ignite `engine.state.output`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -133,14 +164,14 @@ def iteration_completed(self, engine: Engine) -> None: else: self._default_iteration_print(engine) - def exception_raised(self, engine: Engine, e: Exception) -> None: + def exception_raised(self, _engine: Engine, e: Exception) -> None: """ Handler for train or validation/evaluation exception raised Event. Print the exception information and traceback. This callback may be skipped because the logic with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event. Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. + _engine: Ignite Engine, unused argument. e: the exception caught in Ignite during engine.run(). """ @@ -149,39 +180,46 @@ def exception_raised(self, engine: Engine, e: Exception) -> None: def _default_epoch_print(self, engine: Engine) -> None: """ - Execute epoch level log operation based on Ignite engine.state data. - print the values from Ignite state.metrics dict. + Execute epoch level log operation. + Default to print the values from Ignite `engine.state.metrics` dict and + print the values of specified attributes of `engine.state`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - prints_dict = engine.state.metrics - if not prints_dict: - return current_epoch = self.global_epoch_transform(engine.state.epoch) - out_str = f"Epoch[{current_epoch}] Metrics -- " - for name in sorted(prints_dict): - value = prints_dict[name] - out_str += self.key_var_format.format(name, value) - self.logger.info(out_str) + prints_dict = engine.state.metrics + if prints_dict is not None and len(prints_dict) > 0: + out_str = f"Epoch[{current_epoch}] Metrics -- " + for name in sorted(prints_dict): + value = prints_dict[name] + out_str += self.key_var_format.format(name, value) if is_scalar(value) else f"{name}: {str(value)}" + self.logger.info(out_str) if ( hasattr(engine.state, "key_metric_name") and hasattr(engine.state, "best_metric") and hasattr(engine.state, "best_metric_epoch") ): - out_str = f"Key metric: {engine.state.key_metric_name} " - out_str += f"best value: {engine.state.best_metric} at epoch: {engine.state.best_metric_epoch}" - self.logger.info(out_str) + out_str = f"Key metric: {engine.state.key_metric_name} " # type: ignore + out_str += f"best value: {engine.state.best_metric} " # type: ignore + out_str += f"at epoch: {engine.state.best_metric_epoch}" # type: ignore + self.logger.info(out_str) + + if self.state_attributes is not None and len(self.state_attributes) > 0: + out_str = "State values: " + for attr in self.state_attributes: + out_str += f"{attr}: {getattr(engine.state, attr, None)} " + self.logger.info(out_str) def _default_iteration_print(self, engine: Engine) -> None: """ - Execute iteration log operation based on Ignite engine.state data. - Print the values from Ignite state.logs dict. - The default behavior is to print loss from output[0] as output is a decollated list and we replicated loss - value for every item of the decollated list. + Execute iteration log operation based on Ignite `engine.state.output` data. + Print the values from `self.output_transform(engine.state.output)`. + Since `engine.state.output` is a decollated list and we replicated the loss value for every item + of the decollated list, the default behavior is to print the loss from `output[0]`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -220,7 +258,9 @@ def _default_iteration_print(self, engine: Engine) -> None: return # no value to print num_iterations = engine.state.epoch_length - current_iteration = (engine.state.iteration - 1) % num_iterations + 1 + current_iteration = engine.state.iteration + if num_iterations is not None: + current_iteration = (current_iteration - 1) % num_iterations + 1 current_epoch = engine.state.epoch num_epochs = engine.state.max_epochs diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index 4fc5b5a60a..77f0debfe9 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from typing import Callable, Union from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import SurfaceDistanceMetric @@ -26,6 +26,7 @@ def __init__( include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -38,11 +39,15 @@ def __init__( `seg_pred` and `seg_gt`. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: surface dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. @@ -51,10 +56,6 @@ def __init__( include_background=include_background, symmetric=symmetric, distance_metric=distance_metric, - reduction=MetricReduction.MEAN, - ) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, + reduction=reduction, ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index a3a0bf76b8..0105e7c8ca 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence import numpy as np import torch @@ -20,6 +20,7 @@ from monai.visualize import plot_2d_or_3d_image Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") + if TYPE_CHECKING: from ignite.engine import Engine from torch.utils.tensorboard import SummaryWriter @@ -35,8 +36,8 @@ class TensorBoardHandler: Base class for the handlers to write data into TensorBoard. Args: - summary_writer: user can specify TensorBoard SummaryWriter, - default to create a new writer. + summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter, + default to create a new TensorBoard writer. log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. """ @@ -64,7 +65,7 @@ def close(self): class TensorBoardStatsHandler(TensorBoardHandler): """ TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. - It's can be used for any Ignite Engine(trainer, validator and evaluator). + It can be used for any Ignite Engine(trainer, validator and evaluator). And it can support both epoch level and iteration level with pre-defined TensorBoard event writer. The expected data source is Ignite ``engine.state.output`` and ``engine.state.metrics``. @@ -73,25 +74,34 @@ class TensorBoardStatsHandler(TensorBoardHandler): ``engine.state.metrics`` to TensorBoard. - When ITERATION_COMPLETED, write each dictionary item in ``self.output_transform(engine.state.output)`` to TensorBoard. + + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + """ def __init__( self, summary_writer: Optional[SummaryWriter] = None, log_dir: str = "./runs", + iteration_log: bool = True, + epoch_log: bool = True, epoch_event_writer: Optional[Callable[[Engine, SummaryWriter], Any]] = None, epoch_interval: int = 1, iteration_event_writer: Optional[Callable[[Engine, SummaryWriter], Any]] = None, iteration_interval: int = 1, output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, + state_attributes: Optional[Sequence[str]] = None, tag_name: str = DEFAULT_TAG, ) -> None: """ Args: - summary_writer: user can specify TensorBoard SummaryWriter, - default to create a new writer. + summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter, + default to create a new TensorBoard writer. log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. + iteration_log: whether to write data to TensorBoard when iteration completed, default to `True`. + epoch_log: whether to write data to TensorBoard when epoch completed, default to `True`. epoch_event_writer: customized callable TensorBoard writer for epoch level. Must accept parameter "engine" and "summary_writer", use default event writer if None. epoch_interval: the epoch interval at which the epoch_event_writer is called. Defaults to 1. @@ -104,18 +114,26 @@ def __init__( By default this value plotting happens when every iteration completed. The default behavior is to print loss from output[0] as output is a decollated list and we replicated loss value for every item of the decollated list. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. global_epoch_transform: a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine might want to use trainer engines epoch number when plotting epoch vs metric curves. + state_attributes: expected attributes from `engine.state`, if provided, will extract them + when epoch completed. tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``. """ super().__init__(summary_writer=summary_writer, log_dir=log_dir) + self.iteration_log = iteration_log + self.epoch_log = epoch_log self.epoch_event_writer = epoch_event_writer self.epoch_interval = epoch_interval self.iteration_event_writer = iteration_event_writer self.iteration_interval = iteration_interval self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform + self.state_attributes = state_attributes self.tag_name = tag_name def attach(self, engine: Engine) -> None: @@ -126,17 +144,17 @@ def attach(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): + if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): engine.add_event_handler( Events.ITERATION_COMPLETED(every=self.iteration_interval), self.iteration_completed ) - if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): + if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.epoch_interval), self.epoch_completed) def epoch_completed(self, engine: Engine) -> None: """ Handler for train or validation/evaluation epoch completed Event. - Write epoch level events, default values are from Ignite state.metrics dict. + Write epoch level events, default values are from Ignite `engine.state.metrics` dict. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -150,7 +168,7 @@ def epoch_completed(self, engine: Engine) -> None: def iteration_completed(self, engine: Engine) -> None: """ Handler for train or validation/evaluation iteration completed Event. - Write iteration level events, default values are from Ignite state.logs dict. + Write iteration level events, default values are from Ignite `engine.state.output`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -161,31 +179,53 @@ def iteration_completed(self, engine: Engine) -> None: else: self._default_iteration_writer(engine, self._writer) + def _write_scalar(self, _engine: Engine, writer: SummaryWriter, tag: str, value: Any, step: int) -> None: + """ + Write scale value into TensorBoard. + Default to call `SummaryWriter.add_scalar()`. + + Args: + _engine: Ignite Engine, unused argument. + writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler. + tag: tag name in the TensorBoard. + value: value of the scalar data for current step. + step: index of current step. + + """ + writer.add_scalar(tag, value, step) + def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None: """ - Execute epoch level event write operation based on Ignite engine.state data. - Default is to write the values from Ignite state.metrics dict. + Execute epoch level event write operation. + Default to write the values from Ignite `engine.state.metrics` dict and + write the values of specified attributes of `engine.state`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. - writer: TensorBoard writer, created in TensorBoardHandler. + writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler. """ current_epoch = self.global_epoch_transform(engine.state.epoch) summary_dict = engine.state.metrics for name, value in summary_dict.items(): - writer.add_scalar(name, value, current_epoch) + if is_scalar(value): + self._write_scalar(engine, writer, name, value, current_epoch) + + if self.state_attributes is not None: + for attr in self.state_attributes: + self._write_scalar(engine, writer, attr, getattr(engine.state, attr, None), current_epoch) writer.flush() def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> None: """ - Execute iteration level event write operation based on Ignite engine.state data. - The default behavior is to print loss from output[0] as output is a decollated list and we replicated loss - value for every item of the decollated list. + Execute iteration level event write operation based on Ignite `engine.state.output` data. + Extract the values from `self.output_transform(engine.state.output)`. + Since `engine.state.output` is a decollated list and we replicated the loss value for every item + of the decollated list, the default behavior is to track the loss from `output[0]`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. - writer: TensorBoard writer, created in TensorBoardHandler. + writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler. """ loss = self.output_transform(engine.state.output) @@ -202,12 +242,20 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No " {}:{}".format(name, type(value)) ) continue # not plot multi dimensional output - writer.add_scalar( - name, value.item() if isinstance(value, torch.Tensor) else value, engine.state.iteration + self._write_scalar( + _engine=engine, + writer=writer, + tag=name, + value=value.item() if isinstance(value, torch.Tensor) else value, + step=engine.state.iteration, ) elif is_scalar(loss): # not printing multi dimensional output - writer.add_scalar( - self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss, engine.state.iteration + self._write_scalar( + _engine=engine, + writer=writer, + tag=self.tag_name, + value=loss.item() if isinstance(loss, torch.Tensor) else loss, + step=engine.state.iteration, ) else: warnings.warn( @@ -225,6 +273,7 @@ class TensorBoardImageHandler(TensorBoardHandler): 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, for 3D to ND output (shape in Batch, channel, H, W, D) input, each of ``self.max_channels`` number of images' last three dimensions will be shown as animated GIF along the last axis (typically Depth). + And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video. It can be used for any Ignite Engine (trainer, validator and evaluator). User can easily add it to engine for any expected Event, for example: ``EPOCH_COMPLETED``, @@ -239,6 +288,9 @@ class TensorBoardImageHandler(TensorBoardHandler): - Expects ``output_transform(engine.state.output)`` to return a torch tensor in format (y_pred[N, channel, ...], loss). + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + """ def __init__( @@ -252,12 +304,13 @@ def __init__( global_iter_transform: Callable = lambda x: x, index: int = 0, max_channels: int = 1, + frame_dim: int = -3, max_frames: int = 64, ) -> None: """ Args: - summary_writer: user can specify TensorBoard SummaryWriter, - default to create a new writer. + summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter, + default to create a new TensorBoard writer. log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. interval: plot content from engine.state every N epochs or every N iterations, default is 1. epoch_level: plot content from engine.state every N epochs or N iterations. `True` is epoch level, @@ -266,13 +319,21 @@ def __init__( then construct `(image, label)` pair. for example: if `ignite.engine.state.batch` is `{"image": xxx, "label": xxx, "other": xxx}`, `batch_transform` can be `lambda x: (x["image"], x["label"])`. will use the result to plot image from `result[0][index]` and plot label from `result[1][index]`. + `engine.state` and `batch_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. output_transform: a callable that is used to extract the `predictions` data from `ignite.engine.state.output`, will use the result to plot output from `result[index]`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. global_iter_transform: a callable that is used to customize global step number for TensorBoard. For example, in evaluation, the evaluator engine needs to know current epoch from trainer. index: plot which element in a data batch, default is the first element. max_channels: number of channels to plot. - max_frames: number of frames for 2D-t plot. + frame_dim: if plotting 3D image as GIF, specify the dimension used as frames, + expect input data shape as `NCHWD`, default to `-3` (the first spatial dim) + max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`. """ super().__init__(summary_writer=summary_writer, log_dir=log_dir) self.interval = interval @@ -281,6 +342,7 @@ def __init__( self.output_transform = output_transform self.global_iter_transform = global_iter_transform self.index = index + self.frame_dim = frame_dim self.max_frames = max_frames self.max_channels = max_channels @@ -320,13 +382,14 @@ def __call__(self, engine: Engine) -> None: ) plot_2d_or_3d_image( # add batch dim and plot the first item - show_images[None], - step, - self._writer, - 0, - self.max_channels, - self.max_frames, - "input_0", + data=show_images[None], + step=step, + writer=self._writer, + index=0, + max_channels=self.max_channels, + frame_dim=self.frame_dim, + max_frames=self.max_frames, + tag="input_0", ) show_labels = self.batch_transform(engine.state.batch)[1][self.index] @@ -338,7 +401,16 @@ def __call__(self, engine: Engine) -> None: "batch_transform(engine.state.batch)[1] must be None or one of " f"(numpy.ndarray, torch.Tensor) but is {type(show_labels).__name__}." ) - plot_2d_or_3d_image(show_labels[None], step, self._writer, 0, self.max_channels, self.max_frames, "input_1") + plot_2d_or_3d_image( + data=show_labels[None], + step=step, + writer=self._writer, + index=0, + max_channels=self.max_channels, + frame_dim=self.frame_dim, + max_frames=self.max_frames, + tag="input_1", + ) show_outputs = self.output_transform(engine.state.output)[self.index] if isinstance(show_outputs, torch.Tensor): @@ -349,6 +421,15 @@ def __call__(self, engine: Engine) -> None: "output_transform(engine.state.output) must be None or one of " f"(numpy.ndarray, torch.Tensor) but is {type(show_outputs).__name__}." ) - plot_2d_or_3d_image(show_outputs[None], step, self._writer, 0, self.max_channels, self.max_frames, "output") + plot_2d_or_3d_image( + data=show_outputs[None], + step=step, + writer=self._writer, + index=0, + max_channels=self.max_channels, + frame_dim=self.frame_dim, + max_frames=self.max_frames, + tag="output", + ) self._writer.flush() diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py deleted file mode 100644 index 83b5f56396..0000000000 --- a/monai/handlers/transform_inverter.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union - -import torch - -from monai.config import IgniteInfo, KeysCollection -from monai.engines.utils import CommonKeys, IterationEvents -from monai.transforms import Invertd, InvertibleTransform -from monai.utils import deprecated, ensure_tuple, ensure_tuple_rep, min_version, optional_import - -Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") -if TYPE_CHECKING: - from ignite.engine import Engine -else: - Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") - - -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="Please consider using `Invertd` transform instead.") -class TransformInverter: - """ - Ignite handler to automatically invert `transforms`. - It takes `engine.state.output` as the input data and uses the transforms information from `engine.state.batch`. - Expect both `engine.state.output` and `engine.state.batch` to be list of dictionaries data. - The inverted data is in-place saved back to `engine.state.output` with key: "{output_key}". - And the inverted meta dict will be stored in `engine.state.batch` - with key: "{meta_keys}" or "{key}_{meta_key_postfix}". - - .. deprecated:: 0.6.0 - Use :class:`monai.transforms.Invertd` instead. - - """ - - def __init__( - self, - transform: InvertibleTransform, - output_keys: KeysCollection = CommonKeys.PRED, - batch_keys: KeysCollection = CommonKeys.IMAGE, - meta_keys: Optional[KeysCollection] = None, - batch_meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", - nearest_interp: Union[bool, Sequence[bool]] = True, - to_tensor: Union[bool, Sequence[bool]] = True, - device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu", - post_func: Union[Callable, Sequence[Callable]] = lambda x: x, - num_workers: Optional[int] = 0, - ) -> None: - """ - Args: - transform: a callable data transform on input data. - output_keys: the key of expected data in `ignite.engine.output`, invert transforms on it. - it also can be a list of keys, will invert transform for each of them. - Default to "pred". it's in-place operation. - batch_keys: the key of input data in `ignite.engine.batch`. will get the applied transforms - for this input data, then invert them for the expected data with `output_keys`. - It can also be a list of keys, each matches to the `output_keys` data. default to "image". - meta_keys: explicitly indicate the key for the inverted meta data dictionary. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `{key}_{meta_key_postfix}`. - batch_meta_keys: the key of the meta data of input data in `ignite.engine.batch`, - will get the `affine`, `data_shape`, etc. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`. - meta data will also be inverted and stored in `meta_keys`. - meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to to fetch the - meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. - default is `meta_dict`, the meta data is a dictionary object. - For example, to handle orig_key `image`, read/write `affine` matrices from the - metadata `image_meta_dict` dictionary's `affine` field. - the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}". - nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms, - default to `True`. If `False`, use the same interpolation mode as the original transform. - it also can be a list of bool, each matches to the `output_keys` data. - to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. - it also can be a list of bool, each matches to the `output_keys` data. - device: if converted to Tensor, move the inverted results to target device before `post_func`, - default to "cpu", it also can be a list of string or `torch.device`, - each matches to the `output_keys` data. - post_func: post processing for the inverted data, should be a callable function. - it also can be a list of callable, each matches to the `output_keys` data. - - """ - self.inverter = Invertd( - keys=output_keys, - transform=transform, - orig_keys=batch_keys, - meta_keys=meta_keys, - orig_meta_keys=batch_meta_keys, - meta_key_postfix=meta_key_postfix, - nearest_interp=nearest_interp, - to_tensor=to_tensor, - device=device, - post_func=post_func, - ) - self.output_keys = ensure_tuple(output_keys) - self.meta_keys = ensure_tuple_rep(None, len(self.output_keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.output_keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as output_keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.output_keys)) - - def attach(self, engine: Engine) -> None: - """ - Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. - """ - engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) - - def __call__(self, engine: Engine) -> None: - """ - Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. - """ - if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list): - warnings.warn("inverter requires `engine.state.batch` and `engine.state.output` to be lists.") - else: - for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): - # combine `batch` and `output` to temporarily act as 1 dict for postprocessing - data = dict(b) - data.update(o) - ret = self.inverter(data) - - for output_key, meta_key, meta_key_postfix in zip( - self.output_keys, self.meta_keys, self.meta_key_postfix - ): - # save the inverted data into state.output - engine.state.output[i][output_key] = ret.get(output_key) - # save the inverted meta dict into state.batch - meta_key = meta_key or f"{output_key}_{meta_key_postfix}" - if meta_key in ret: - # FIXME: we save inverted meta dict into `batch` to be compatible with `SegmentationSaver` - # will deprecate both handlers soon - engine.state.batch[i][meta_key] = ret.get(meta_key) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 13f23c582a..2912e214be 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,13 +11,13 @@ import os from collections import OrderedDict -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union import numpy as np import torch -from monai.config import IgniteInfo, KeysCollection -from monai.utils import deprecated, ensure_tuple, get_torch_version_tuple, look_up_option, min_version, optional_import +from monai.config import IgniteInfo, KeysCollection, PathLike +from monai.utils import ensure_tuple, look_up_option, min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") if TYPE_CHECKING: @@ -25,14 +25,7 @@ else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") -__all__ = [ - "stopping_fn_from_metric", - "stopping_fn_from_loss", - "evenly_divisible_all_gather", - "string_list_all_gather", - "write_metrics_reports", - "from_engine", -] +__all__ = ["stopping_fn_from_metric", "stopping_fn_from_loss", "write_metrics_reports", "from_engine"] def stopping_fn_from_metric(metric_name: str): @@ -52,90 +45,18 @@ def stopping_fn_from_loss(): """ def stopping_fn(engine: Engine): - return -engine.state.output + return -engine.state.output # type:ignore return stopping_fn -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="The API had been moved to monai.utils module.") -def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: - """ - Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. - - Args: - data: source tensor to pad and execute all_gather in distributed data parallel. - - Note: - The input data on different ranks must have exactly same `dtype`. - - .. versionchanged:: 0.6.0 - The API had been moved to `monai.utils`. - - """ - if not isinstance(data, torch.Tensor): - raise ValueError("input data must be PyTorch Tensor.") - - if idist.get_world_size() <= 1: - return data - - # make sure the data is evenly-divisible on multi-GPUs - length = data.shape[0] - all_lens = idist.all_gather(length) - max_len = max(all_lens) - if length < max_len: - size = [max_len - length] + list(data.shape[1:]) - data = torch.cat([data, data.new_full(size, 0)], dim=0) - # all gather across all processes - data = idist.all_gather(data) - # delete the padding NaN items - return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0) - - -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="The API had been moved to monai.utils module.") -def string_list_all_gather(strings: List[str]) -> List[str]: - """ - Utility function for distributed data parallel to all gather a list of strings. - Note that if the item in `strings` is longer than 1024 chars, it will be truncated to 1024: - https://pytorch.org/ignite/v0.4.5/distributed.html#ignite.distributed.utils.all_gather. - - Args: - strings: a list of strings to all gather. - - .. versionchanged:: 0.6.0 - The API had been moved to `monai.utils`. - - """ - world_size = idist.get_world_size() - if world_size <= 1: - return strings - - result: List[List[str]] = [[] for _ in range(world_size)] - # get length of strings - length = len(strings) - all_lens = idist.all_gather(length) - max_len = max(all_lens) - # pad the item to make sure the same length - if length < max_len: - strings += ["" for _ in range(max_len - length)] - - if get_torch_version_tuple() <= (1, 6): - raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.") - - for s in strings: - gathered = idist.all_gather(s) - for i, g in enumerate(gathered): - if len(g) > 0: - result[i].append(g) - return [i for k in result for i in k] - - def write_metrics_reports( - save_dir: str, + save_dir: PathLike, images: Optional[Sequence[str]], metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], summary_ops: Optional[Union[str, Sequence[str]]], - deli: str = "\t", + deli: str = ",", output_type: str = "csv", ): """ @@ -167,7 +88,8 @@ class mean median max 5percentile 95percentile notnans class1 6.0000 6.0000 6.0000 6.0000 6.0000 1.0000 mean 6.2500 6.2500 7.0000 5.5750 6.9250 2.0000 - deli: the delimiter character in the file, default to "\t". + deli: the delimiter character in the saved file, default to "," as the default output type is `csv`. + to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter. output_type: expected output file type, supported types: ["csv"], default to "csv". """ @@ -204,12 +126,12 @@ class mean median max 5percentile 95percentile notnans if summary_ops is not None: supported_ops = OrderedDict( { - "mean": lambda x: np.nanmean(x), - "median": lambda x: np.nanmedian(x), - "max": lambda x: np.nanmax(x), - "min": lambda x: np.nanmin(x), + "mean": np.nanmean, + "median": np.nanmedian, + "max": np.nanmax, + "min": np.nanmin, "90percentile": lambda x: np.nanpercentile(x[0], x[1]), - "std": lambda x: np.nanstd(x), + "std": np.nanstd, "notnans": lambda x: (~np.isnan(x)).sum(), } ) @@ -223,7 +145,7 @@ def _compute_op(op: str, d: np.ndarray): return c_op(d) threshold = int(op.split("percentile")[0]) - return supported_ops["90percentile"]((d, threshold)) + return supported_ops["90percentile"]((d, threshold)) # type: ignore with open(os.path.join(save_dir, f"{k}_summary.csv"), "w") as f: f.write(f"class{deli}{deli.join(ops)}\n") @@ -272,3 +194,12 @@ def _wrapper(data): return tuple(ret) if len(ret) > 1 else ret[0] return _wrapper + + +def ignore_data(x: Any): + """ + Always return `None` for any input data. + A typical usage is to avoid logging the engine output of every iteration during evaluation. + + """ + return None diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py index 6214461a4f..171c901fbb 100644 --- a/monai/handlers/validation_handler.py +++ b/monai/handlers/validation_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 030344728d..3447782be9 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,5 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .inferer import Inferer, SaliencyInferer, SimpleInferer, SlidingWindowInferer +from .inferer import Inferer, SaliencyInferer, SimpleInferer, SliceInferer, SlidingWindowInferer from .utils import sliding_window_inference diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index ecb2c2c178..331637ba94 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,10 +16,10 @@ import torch.nn as nn from monai.inferers.utils import sliding_window_inference -from monai.utils import BlendMode, PytorchPadMode +from monai.utils import BlendMode, PytorchPadMode, ensure_tuple from monai.visualize import CAM, GradCAM, GradCAMpp -__all__ = ["Inferer", "SimpleInferer", "SlidingWindowInferer", "SaliencyInferer"] +__all__ = ["Inferer", "SimpleInferer", "SlidingWindowInferer", "SaliencyInferer", "SliceInferer"] class Inferer(ABC): @@ -30,7 +30,8 @@ class Inferer(ABC): Example code:: device = torch.device("cuda:0") - data = ToTensor()(LoadImage()(filename=img_path)).to(device) + transform = Compose([ToTensor(), LoadImage(image_only=True)]) + data = transform(img_path).to(device) model = UNet(...).to(device) inferer = SlidingWindowInferer(...) @@ -42,13 +43,7 @@ class Inferer(ABC): """ @abstractmethod - def __call__( - self, - inputs: torch.Tensor, - network: Callable[..., torch.Tensor], - *args: Any, - **kwargs: Any, - ): + def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any): """ Run inference on `inputs` with the `network` model. @@ -75,13 +70,7 @@ class SimpleInferer(Inferer): def __init__(self) -> None: Inferer.__init__(self) - def __call__( - self, - inputs: torch.Tensor, - network: Callable[..., torch.Tensor], - *args: Any, - **kwargs: Any, - ): + def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any): """Unified callable function API of Inferers. Args: @@ -121,7 +110,7 @@ class SlidingWindowInferer(Inferer): spatial dimensions. padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``} Padding mode when ``roi_size`` is larger than inputs. Defaults to ``"constant"`` - See also: https://pytorch.org/docs/stable/nn.functional.html#pad + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html cval: fill value for 'constant' padding mode. Default: 0 sw_device: device for the window data. By default the device (and accordingly the memory) of the `inputs` is used. @@ -130,6 +119,7 @@ class SlidingWindowInferer(Inferer): By default the device (and accordingly the memory) of the `inputs` is used. If for example set to device=torch.device('cpu') the gpu memory consumption is less and independent of the `inputs` and `roi_size`. Output is on the `device`. + progress: whether to print a tqdm progress bar. Note: ``sw_batch_size`` denotes the max number of windows per network inference iteration, @@ -148,6 +138,7 @@ def __init__( cval: float = 0.0, sw_device: Union[torch.device, str, None] = None, device: Union[torch.device, str, None] = None, + progress: bool = False, ) -> None: Inferer.__init__(self) self.roi_size = roi_size @@ -159,13 +150,10 @@ def __init__( self.cval = cval self.sw_device = sw_device self.device = device + self.progress = progress def __call__( - self, - inputs: torch.Tensor, - network: Callable[..., torch.Tensor], - *args: Any, - **kwargs: Any, + self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any ) -> torch.Tensor: """ @@ -189,6 +177,7 @@ def __call__( self.cval, self.sw_device, self.device, + self.progress, *args, **kwargs, ) @@ -217,13 +206,7 @@ def __init__(self, cam_name: str, target_layers: str, class_idx: Optional[int] = self.args = args self.kwargs = kwargs - def __call__( # type: ignore - self, - inputs: torch.Tensor, - network: nn.Module, - *args: Any, - **kwargs: Any, - ): + def __call__(self, inputs: torch.Tensor, network: nn.Module, *args: Any, **kwargs: Any): # type: ignore """Unified callable function API of Inferers. Args: @@ -243,3 +226,59 @@ def __call__( # type: ignore cam = GradCAMpp(network, self.target_layers, *self.args, **self.kwargs) return cam(inputs, self.class_idx, *args, **kwargs) + + +class SliceInferer(SlidingWindowInferer): + """ + SliceInferer extends SlidingWindowInferer to provide slice-by-slice (2D) inference + when provided a 3D volume. + + Args: + spatial_dim: Spatial dimension over which the slice-by-slice inference runs on the 3D volume. + For example ``0`` could slide over axial slices. ``1`` over coronal slices and ``2`` over sagittal slices. + args: other optional args to be passed to the `__init__` of base class SlidingWindowInferer. + kwargs: other optional keyword args to be passed to `__init__` of base class SlidingWindowInferer. + + Note: + ``roi_size`` in SliceInferer is expected to be a 2D tuple when a 3D volume is provided. This allows + sliding across slices along the 3D volume using a selected ``spatial_dim``. + + """ + + def __init__(self, spatial_dim: int = 0, *args, **kwargs) -> None: + self.spatial_dim = spatial_dim + super().__init__(*args, **kwargs) + + def __call__( + self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any + ) -> torch.Tensor: + """ + Args: + inputs: 3D input for inference + network: 2D model to execute inference on slices in the 3D input + args: optional args to be passed to ``network``. + kwargs: optional keyword args to be passed to ``network``. + """ + if self.spatial_dim > 2: + raise ValueError("`spatial_dim` can only be `0, 1, 2` with `[H, W, D]` respectively.") + + # Check if ``roi_size`` tuple is 2D and ``inputs`` tensor is 3D + self.roi_size = ensure_tuple(self.roi_size) + if len(self.roi_size) == 2 and len(inputs.shape[2:]) == 3: + self.roi_size = list(self.roi_size) + self.roi_size.insert(self.spatial_dim, 1) + else: + raise RuntimeError("Currently, only 2D `roi_size` with 3D `inputs` tensor is supported.") + + return super().__call__(inputs=inputs, network=lambda x: self.network_wrapper(network, x, *args, **kwargs)) + + def network_wrapper(self, network: Callable[..., torch.Tensor], x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Wrapper handles inference for 2D models over 3D volume inputs. + """ + # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D. + x = x.squeeze(dim=self.spatial_dim + 2) + out = network(x, *args, **kwargs) + # Unsqueeze the network output so it is [N, C, D, H, W] as expected by + # the default SlidingWindowInferer class + return out.unsqueeze(dim=self.spatial_dim + 2) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 0ca53529c7..36e4377bd6 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,9 @@ import torch.nn.functional as F from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size -from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option +from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option, optional_import + +tqdm, _ = optional_import("tqdm", name="tqdm") __all__ = ["sliding_window_inference"] @@ -32,6 +34,7 @@ def sliding_window_inference( cval: float = 0.0, sw_device: Union[torch.device, str, None] = None, device: Union[torch.device, str, None] = None, + progress: bool = False, *args: Any, **kwargs: Any, ) -> torch.Tensor: @@ -65,7 +68,7 @@ def sliding_window_inference( spatial dimensions. padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``} Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"`` - See also: https://pytorch.org/docs/stable/nn.functional.html#pad + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html cval: fill value for 'constant' padding mode. Default: 0 sw_device: device for the window data. By default the device (and accordingly the memory) of the `inputs` is used. @@ -74,6 +77,7 @@ def sliding_window_inference( By default the device (and accordingly the memory) of the `inputs` is used. If for example set to device=torch.device('cpu') the gpu memory consumption is less and independent of the `inputs` and `roi_size`. Output is on the `device`. + progress: whether to print a `tqdm` progress bar. args: optional args to be passed to ``predictor``. kwargs: optional keyword args to be passed to ``predictor``. @@ -120,7 +124,7 @@ def sliding_window_inference( # Perform predictions output_image, count_map = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device) _initialized = False - for slice_g in range(0, total_slices, sw_batch_size): + for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size): slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices)) unravel_slice = [ [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win]) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 1221cd3041..1922996fb6 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss from .dice import ( Dice, @@ -18,7 +19,6 @@ GeneralizedDiceLoss, GeneralizedWassersteinDiceLoss, MaskedDiceLoss, - dice, dice_ce, dice_focal, generalized_dice, diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py new file mode 100644 index 0000000000..cd5b261acf --- /dev/null +++ b/monai/losses/contrastive.py @@ -0,0 +1,87 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.nn import functional as F +from torch.nn.modules.loss import _Loss + +from monai.utils import deprecated_arg + + +class ContrastiveLoss(_Loss): + + """ + Compute the Contrastive loss defined in: + + Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International + conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v119/chen20j.html) + + Adapted from: + https://github.com/Sara-Ahmed/SiT/blob/1aacd6adcd39b71efc903d16b4e9095b97dda76f/losses.py#L5 + + """ + + @deprecated_arg(name="reduction", since="0.8", msg_suffix="`reduction` is no longer supported.") + def __init__(self, temperature: float = 0.5, batch_size: int = 1, reduction="sum") -> None: + """ + Args: + temperature: Can be scaled between 0 and 1 for learning from negative samples, ideally set to 0.5. + batch_size: The number of samples. + + Raises: + ValueError: When an input of dimension length > 2 is passed + ValueError: When input and target are of different shapes + + .. deprecated:: 0.8.0 + + `reduction` is no longer supported. + + """ + super().__init__() + + self.batch_size = batch_size + self.temperature = temperature + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be B[F]. + target: the shape should be B[F]. + """ + if len(target.shape) > 2 or len(input.shape) > 2: + raise ValueError( + f"Either target or input has dimensions greater than 2 where target " + f"shape is ({target.shape}) and input shape is ({input.shape})" + ) + + if target.shape != input.shape: + raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") + + temperature_tensor = torch.as_tensor(self.temperature).to(input.device) + + norm_i = F.normalize(input, dim=1) + norm_j = F.normalize(target, dim=1) + + negatives_mask = ~torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=torch.bool) + negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device) + + repr = torch.cat([norm_i, norm_j], dim=0) + sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2) + sim_ij = torch.diag(sim_matrix, self.batch_size) + sim_ji = torch.diag(sim_matrix, -self.batch_size) + + positives = torch.cat([sim_ij, sim_ji], dim=0) + nominator = torch.exp(positives / temperature_tensor) + denominator = negatives_mask * torch.exp(sim_matrix / temperature_tensor) + + loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1)) + + return torch.sum(loss_partial) / (2 * self.batch_size) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index d96fa1440a..0f5e263a53 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -52,12 +52,12 @@ class BendingEnergyLoss(_Loss): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__( - self, - reduction: Union[LossReduction, str] = LossReduction.MEAN, - ) -> None: + def __init__(self, normalize: bool = False, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None: """ Args: + normalize: + Whether to divide out spatial sizes in order to make the computation roughly + invariant to image scale (i.e. vector field sampling resolution). Defaults to False. reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. @@ -65,7 +65,8 @@ def __init__( - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. """ - super(BendingEnergyLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) + self.normalize = normalize def forward(self, pred: torch.Tensor) -> torch.Tensor: """ @@ -77,20 +78,35 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: """ if pred.ndim not in [3, 4, 5]: - raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") for i in range(pred.ndim - 2): if pred.shape[-i - 1] <= 4: - raise ValueError("all spatial dimensions must > 4, got pred of shape {pred.shape}") + raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") + if pred.shape[1] != pred.ndim - 2: + raise ValueError( + f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim-2}" + ) # first order gradient first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + # spatial dimensions in a shape suited for broadcasting below + if self.normalize: + spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) + energy = torch.tensor(0) for dim_1, g in enumerate(first_order_gradient): dim_1 += 2 - energy = spatial_gradient(g, dim_1) ** 2 + energy + if self.normalize: + g *= pred.shape[dim_1] / spatial_dims + energy = energy + (spatial_gradient(g, dim_1) * pred.shape[dim_1]) ** 2 + else: + energy = energy + spatial_gradient(g, dim_1) ** 2 for dim_2 in range(dim_1 + 1, pred.ndim): - energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy + if self.normalize: + energy = energy + 2 * (spatial_gradient(g, dim_2) * pred.shape[dim_2]) ** 2 + else: + energy = energy + 2 * spatial_gradient(g, dim_2) ** 2 if self.reduction == LossReduction.MEAN.value: energy = torch.mean(energy) # the batch and channel average diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 325c5300ea..610327ef63 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,22 +21,23 @@ from monai.losses.focal_loss import FocalLoss from monai.losses.spatial_mask import MaskedLoss from monai.networks import one_hot -from monai.utils import LossReduction, Weight, look_up_option +from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option class DiceLoss(_Loss): """ Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. - Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). - Axis N of `input` is expected to have logit predictions for each class rather than being image channels, - while the same axis of `target` can be 1 or N (one-hot format). The `smooth_nr` and `smooth_dr` parameters are - values added to the intersection and union components of the inter-over-union calculation to smooth results - respectively, these values should be small. The `include_background` class attribute can be set to False for - an instance of DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be - background. If the non-background segmentations are small compared to the total image size they can get - overwhelmed by the signal from the background so excluding it in such cases helps convergence. + The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). - Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric Medical Image Segmentation, 3DV, 2016. + Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input, + must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target` + can be 1 or N (one-hot format). + + The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of + the inter-over-union calculation to smooth results respectively, these values should be small. + + The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric + Medical Image Segmentation, 3DV, 2016. """ @@ -57,6 +58,8 @@ def __init__( """ Args: include_background: if False, channel index 0 (background category) is excluded from the calculation. + if the non-background segmentations are small compared to the total image size they can get overwhelmed + by the signal from the background so excluding it in such cases helps convergence. to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. sigmoid: if True, apply a sigmoid function to the prediction. softmax: if True, apply a softmax function to the prediction. @@ -111,6 +114,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: have different shapes. ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + Example: + >>> from monai.losses.dice import * # NOQA + >>> import torch + >>> from monai.losses.dice import DiceLoss + >>> B, C, H, W = 7, 5, 3, 2 + >>> input = torch.rand(B, C, H, W) + >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() + >>> target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> self = DiceLoss(reduction='none') + >>> loss = self(input, target) + >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape """ if self.sigmoid: input = torch.sigmoid(input) @@ -168,7 +182,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: f = torch.mean(f) # the batch and channel average elif self.reduction == LossReduction.SUM.value: f = torch.sum(f) # sum over the batch and channel dims - elif self.reduction != LossReduction.NONE.value: + elif self.reduction == LossReduction.NONE.value: + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2) + f = f.view(broadcast_shape) + else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return f @@ -234,7 +253,6 @@ def __init__( other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`. - squared_pred: use squared versions of targets and predictions in the denominator or not. w_type: {``"square"``, ``"simple"``, ``"uniform"``} Type of function to transform ground truth volume to a weight factor. Defaults to ``"square"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} @@ -330,20 +348,30 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: denominator = ground_o + pred_o w = self.w_func(ground_o.float()) - for b in w: - infs = torch.isinf(b) - b[infs] = 0.0 - b[infs] = torch.max(b) + infs = torch.isinf(w) + if self.batch: + w[infs] = 0.0 + w = w + infs * torch.max(w) + else: + w[infs] = 0.0 + max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1) + w = w + infs * max_values - f: torch.Tensor = 1.0 - (2.0 * (intersection * w).sum(0 if self.batch else 1) + self.smooth_nr) / ( - (denominator * w).sum(0 if self.batch else 1) + self.smooth_dr - ) + final_reduce_dim = 0 if self.batch else 1 + numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr + denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr + f: torch.Tensor = 1.0 - (numer / denom) if self.reduction == LossReduction.MEAN.value: f = torch.mean(f) # the batch and channel average elif self.reduction == LossReduction.SUM.value: f = torch.sum(f) # sum over the batch and channel dims - elif self.reduction != LossReduction.NONE.value: + elif self.reduction == LossReduction.NONE.value: + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2) + f = f.view(broadcast_shape) + else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return f @@ -419,7 +447,7 @@ def __init__( wass_loss(pred_score, grnd) # 0 """ - super(GeneralizedWassersteinDiceLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) if dist_matrix.shape[0] != dist_matrix.shape[1]: raise ValueError(f"dist_matrix must be C x C, got {dist_matrix.shape[0]} x {dist_matrix.shape[1]}.") @@ -478,7 +506,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: wass_dice_loss = torch.mean(wass_dice_loss) # the batch and channel average elif self.reduction == LossReduction.SUM.value: wass_dice_loss = torch.sum(wass_dice_loss) # sum over the batch and channel dims - elif self.reduction != LossReduction.NONE.value: + elif self.reduction == LossReduction.NONE.value: + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + broadcast_shape = input.shape[0:2] + (1,) * (len(input.shape) - 2) + wass_dice_loss = wass_dice_loss.view(broadcast_shape) + else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return wass_dice_loss @@ -536,10 +569,7 @@ def _compute_generalized_true_positive( flat_target_extended = torch.unsqueeze(flat_target, dim=1) alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) - return torch.sum( - alpha_extended * (1.0 - wasserstein_distance_map), - dim=[1, 2], - ) + return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=[1, 2]) def _compute_denominator( self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor @@ -556,10 +586,7 @@ def _compute_denominator( flat_target_extended = torch.unsqueeze(flat_target, dim=1) alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) - return torch.sum( - alpha_extended * (2.0 - wasserstein_distance_map), - dim=[1, 2], - ) + return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=[1, 2]) def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor: """ @@ -644,6 +671,7 @@ def __init__( """ super().__init__() + reduction = look_up_option(reduction, DiceCEReduction).value self.dice = DiceLoss( include_background=include_background, to_onehot_y=to_onehot_y, @@ -657,10 +685,7 @@ def __init__( smooth_dr=smooth_dr, batch=batch, ) - self.cross_entropy = nn.CrossEntropyLoss( - weight=ce_weight, - reduction=reduction, - ) + self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction=reduction) if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.") if lambda_ce < 0.0: @@ -815,11 +840,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: dice_loss = self.dice(input, target) focal_loss = self.focal(input, target) total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss - return total_loss -dice = Dice = DiceLoss +Dice = DiceLoss dice_ce = DiceCELoss dice_focal = DiceFocalLoss generalized_dice = GeneralizedDiceLoss diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index b4b3698e5b..bf31682748 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,11 +22,44 @@ class FocalLoss(_Loss): """ + FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from + high confidence correct predictions. + Reimplementation of the Focal Loss (with a build-in sigmoid activation) described in: - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017 - "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy", Zhu et al., Medical Physics 2018 + + Example: + >>> import torch + >>> from monai.losses import FocalLoss + >>> from torch.nn import BCEWithLogitsLoss + >>> shape = B, N, *DIMS = 2, 3, 5, 7, 11 + >>> input = torch.rand(*shape) + >>> target = torch.rand(*shape) + >>> # Demonstrate equivalence to BCE when gamma=0 + >>> fl_g0_criterion = FocalLoss(reduction='none', gamma=0) + >>> fl_g0_loss = fl_g0_criterion(input, target) + >>> bce_criterion = BCEWithLogitsLoss(reduction='none') + >>> bce_loss = bce_criterion(input, target) + >>> assert torch.allclose(fl_g0_loss, bce_loss) + >>> # Demonstrate "focus" by setting gamma > 0. + >>> fl_g2_criterion = FocalLoss(reduction='none', gamma=2) + >>> fl_g2_loss = fl_g2_criterion(input, target) + >>> # Mark easy and hard cases + >>> is_easy = (target > 0.7) & (input > 0.7) + >>> is_hard = (target > 0.7) & (input < 0.3) + >>> easy_loss_g0 = fl_g0_loss[is_easy].mean() + >>> hard_loss_g0 = fl_g0_loss[is_hard].mean() + >>> easy_loss_g2 = fl_g2_loss[is_easy].mean() + >>> hard_loss_g2 = fl_g2_loss[is_hard].mean() + >>> # Gamma > 0 causes the loss function to "focus" on the hard + >>> # cases. IE, easy cases are downweighted, so hard cases + >>> # receive a higher proportion of the loss. + >>> hard_to_easy_ratio_g2 = hard_loss_g2 / easy_loss_g2 + >>> hard_to_easy_ratio_g0 = hard_loss_g0 / easy_loss_g0 + >>> assert hard_to_easy_ratio_g2 > hard_to_easy_ratio_g0 """ def __init__( @@ -56,18 +89,14 @@ def __init__( - ``"sum"``: the output will be summed. Example: - .. code-block:: python - - import torch - from monai.losses import FocalLoss - - pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32) - grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64) - fl = FocalLoss(to_onehot_y=True) - fl(pred, grnd) - + >>> import torch + >>> from monai.losses import FocalLoss + >>> pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32) + >>> grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64) + >>> fl = FocalLoss(to_onehot_y=True) + >>> fl(pred, grnd) """ - super(FocalLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) self.include_background = include_background self.to_onehot_y = to_onehot_y self.gamma = gamma @@ -147,12 +176,25 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # Compute the loss mini-batch. # (1-p_t)^gamma * log(p_t) with reduced chance of overflow p = F.logsigmoid(-i * (t * 2.0 - 1.0)) - loss = torch.mean((p * self.gamma).exp() * ce, dim=-1) + flat_loss: torch.Tensor = (p * self.gamma).exp() * ce + + # Previously there was a mean over the last dimension, which did not + # return a compatible BCE loss. To maintain backwards compatible + # behavior we have a flag that performs this extra step, disable or + # parameterize if necessary. (Or justify why the mean should be there) + average_spatial_dims = True if self.reduction == LossReduction.SUM.value: - return loss.sum() - if self.reduction == LossReduction.NONE.value: - return loss - if self.reduction == LossReduction.MEAN.value: - return loss.mean() - raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + if average_spatial_dims: + flat_loss = flat_loss.mean(dim=-1) + loss = flat_loss.sum() + elif self.reduction == LossReduction.MEAN.value: + if average_spatial_dims: + flat_loss = flat_loss.mean(dim=-1) + loss = flat_loss.mean() + elif self.reduction == LossReduction.NONE.value: + spacetime_dims = input.shape[2:] + loss = flat_loss.reshape([b, n] + list(spacetime_dims)) + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + return loss diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index eed5808aa3..a06f6fb5cd 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -8,14 +8,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss from monai.networks.layers import gaussian_1d, separable_filtering -from monai.utils import LossReduction +from monai.utils import LossReduction, deprecated_arg +from monai.utils.module import look_up_option def make_rectangular_kernel(kernel_size: int) -> torch.Tensor: @@ -59,18 +60,20 @@ class LocalNormalizedCrossCorrelationLoss(_Loss): DeepReg (https://github.com/DeepRegNet/DeepReg) """ + @deprecated_arg(name="ndim", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - ndim: int = 3, + spatial_dims: int = 3, kernel_size: int = 3, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, + ndim: Optional[int] = None, ) -> None: """ Args: - ndim: number of spatial ndimensions, {``1``, ``2``, ``3``}. Defaults to 3. + spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3. kernel_size: kernel spatial size, must be odd. kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} @@ -81,22 +84,24 @@ def __init__( - ``"sum"``: the output will be summed. smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. + + .. deprecated:: 0.6.0 + ``ndim`` is deprecated, use ``spatial_dims``. """ - super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) - self.ndim = ndim - if self.ndim not in [1, 2, 3]: + if ndim is not None: + spatial_dims = ndim + self.ndim = spatial_dims + if self.ndim not in {1, 2, 3}: raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") self.kernel_size = kernel_size if self.kernel_size % 2 == 0: raise ValueError(f"kernel_size must be odd, got {self.kernel_size}") - if kernel_type not in kernel_dict.keys(): - raise ValueError( - f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' - ) - self.kernel = kernel_dict[kernel_type](self.kernel_size) + _kernel = look_up_option(kernel_type, kernel_dict) + self.kernel = _kernel(self.kernel_size) self.kernel_vol = self.get_kernel_vol() self.smooth_nr = float(smooth_nr) @@ -121,7 +126,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != pred.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") - t2, p2, tp = target ** 2, pred ** 2, target * pred + t2, p2, tp = target**2, pred**2, target * pred kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred) # sum over kernel t_sum = separable_filtering(target, kernels=[kernel.to(pred)] * self.ndim) @@ -170,6 +175,7 @@ class GlobalMutualInformationLoss(_Loss): def __init__( self, + kernel_type: str = "gaussian", num_bins: int = 23, sigma_ratio: float = 0.5, reduction: Union[LossReduction, str] = LossReduction.MEAN, @@ -178,6 +184,19 @@ def __init__( ) -> None: """ Args: + kernel_type: {``"gaussian"``, ``"b-spline"``} + ``"gaussian"``: adapted from DeepReg + Reference: https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1. + ``"b-spline"``: based on the method of Mattes et al [1,2] and adapted from ITK + References: + [1] "Nonrigid multimodality image registration" + D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank + Medical Imaging 2001: Image Processing, 2001, pp. 1609-1620. + [2] "PET-CT Image Registration in the Chest Using Free-form Deformations" + D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank + IEEE Transactions in Medical Imaging. Vol.22, No.1, + January 2003. pp.120-128. + num_bins: number of bins for intensity sigma_ratio: a hyper param for gaussian function reduction: {``"none"``, ``"mean"``, ``"sum"``} @@ -189,25 +208,99 @@ def __init__( smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. """ - super(GlobalMutualInformationLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) if num_bins <= 0: raise ValueError("num_bins must > 0, got {num_bins}") bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio - self.preterm = 1 / (2 * sigma ** 2) - self.bin_centers = bin_centers[None, None, ...] + self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"]) + self.num_bins = num_bins + self.kernel_type = kernel_type + if self.kernel_type == "gaussian": + self.preterm = 1 / (2 * sigma**2) + self.bin_centers = bin_centers[None, None, ...] self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) - def parzen_windowing(self, pred: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def parzen_windowing( + self, pred: torch.Tensor, target: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if self.kernel_type == "gaussian": + pred_weight, pred_probability = self.parzen_windowing_gaussian(pred) + target_weight, target_probability = self.parzen_windowing_gaussian(target) + elif self.kernel_type == "b-spline": + # a third order BSpline kernel is used for the pred image intensity PDF. + pred_weight, pred_probability = self.parzen_windowing_b_spline(pred, order=3) + # a zero order (box car) BSpline kernel is used for the target image intensity PDF. + target_weight, target_probability = self.parzen_windowing_b_spline(target, order=0) + else: + raise ValueError + return pred_weight, pred_probability, target_weight, target_probability + + def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> Tuple[torch.Tensor, torch.Tensor]: """ + Parzen windowing with b-spline kernel (adapted from ITK) + Args: - pred: the shape should be B[NDHW]. + img: the shape should be B[NDHW]. + order: int. + """ + + # Compute binsize for the histograms. + # + # The binsize for the image intensities needs to be adjusted so that + # we can avoid dealing with boundary conditions using the cubic + # spline as the Parzen window. We do this by increasing the size + # of the bins so that the joint histogram becomes "padded" at the + # borders. Because we are changing the binsize, + # we also need to shift the minimum by the padded amount in order to + # avoid minimum values filling in our padded region. + # + # Note that there can still be non-zero bin values in the padded region, + # it's just that these bins will never be a central bin for the Parzen + # window. + _max, _min = torch.max(img), torch.min(img) + padding = 2 + bin_size = (_max - _min) / (self.num_bins - 2 * padding) + norm_min = torch.div(_min, bin_size) - padding + + # assign bin/window index to each voxel + window_term = torch.div(img, bin_size) - norm_min # B[NDHW] + # make sure the extreme values are in valid (non-padded) bins + window_term = torch.clamp(window_term, padding, self.num_bins - padding - 1) # B[NDHW] + window_term = window_term.reshape(window_term.shape[0], -1, 1) # (batch, num_sample, 1) + bins = torch.arange(self.num_bins, device=window_term.device).reshape(1, 1, -1) # (1, 1, num_bins) + sample_bin_matrix = torch.abs(bins - window_term) # (batch, num_sample, num_bins) + + # b-spleen kernel + # (4 - 6 * abs ** 2 + 3 * abs ** 3) / 6 when 0 <= abs < 1 + # (2 - abs) ** 3 / 6 when 1 <= abs < 2 + weight = torch.zeros_like(sample_bin_matrix, dtype=torch.float) # (batch, num_sample, num_bins) + if order == 0: + weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5 + elif order == 3: + weight = ( + weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 + ) + weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6 + else: + raise ValueError(f"Do not support b-spline {order}-order parzen windowing") + + weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bins) + probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bins) + return weight, probability + + def parzen_windowing_gaussian(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parzen windowing with gaussian kernel (adapted from DeepReg implementation) + Note: the input is expected to range between 0 and 1 + Args: + img: the shape should be B[NDHW]. """ - pred = torch.clamp(pred, 0, 1) - pred = pred.reshape(pred.shape[0], -1, 1) # (batch, num_sample, 1) + img = torch.clamp(img, 0, 1) + img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1) weight = torch.exp( - -self.preterm.to(pred) * (pred - self.bin_centers.to(pred)) ** 2 + -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2 ) # (batch, num_sample, num_bin) weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin) probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin) @@ -223,11 +316,10 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ if target.shape != pred.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") - wa, pa = self.parzen_windowing(pred) # (batch, num_sample, num_bin), (batch, 1, num_bin) - wb, pb = self.parzen_windowing(target) # (batch, num_sample, num_bin), (batch, 1, num_bin) - pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1]) # (batch, num_bins, num_bins) + wa, pa, wb, pb = self.parzen_windowing(pred, target) # (batch, num_sample, num_bin), (batch, 1, num_bin) - papb = torch.bmm(pa.permute(0, 2, 1), pb) # (batch, num_bins, num_bins) + pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins) + papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins) mi = torch.sum( pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2) ) # (batch) diff --git a/monai/losses/multi_scale.py b/monai/losses/multi_scale.py index 6f9326420b..5e80af30bc 100644 --- a/monai/losses/multi_scale.py +++ b/monai/losses/multi_scale.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,12 +21,7 @@ def make_gaussian_kernel(sigma: int) -> torch.Tensor: if sigma <= 0: raise ValueError(f"expecting positive sigma, got sigma={sigma}") - return gaussian_1d( - sigma=torch.tensor(sigma), - truncated=3, - approx="sampled", - normalize=False, - ) + return gaussian_1d(sigma=torch.tensor(sigma), truncated=3, approx="sampled", normalize=False) def make_cauchy_kernel(sigma: int) -> torch.Tensor: @@ -39,10 +34,7 @@ def make_cauchy_kernel(sigma: int) -> torch.Tensor: return k -kernel_fn_dict = { - "gaussian": make_gaussian_kernel, - "cauchy": make_cauchy_kernel, -} +kernel_fn_dict = {"gaussian": make_gaussian_kernel, "cauchy": make_cauchy_kernel} class MultiScaleLoss(_Loss): @@ -67,7 +59,7 @@ def __init__( scales: list of scalars or None, if None, do not apply any scaling. kernel: gaussian or cauchy. """ - super(MultiScaleLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) if kernel not in kernel_fn_dict.keys(): raise ValueError(f"got unsupported kernel type: {kernel}", "only support gaussian and cauchy") self.kernel_fn = kernel_fn_dict[kernel] diff --git a/monai/losses/spatial_mask.py b/monai/losses/spatial_mask.py index 387300e507..aa232f882e 100644 --- a/monai/losses/spatial_mask.py +++ b/monai/losses/spatial_mask.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 1cc0e1d8d7..ee6d7d933b 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index c2197bdf2a..d18c20f7b2 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix +from .cumulative_average import CumulativeAverage from .froc import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance from .meandice import DiceMetric, compute_meandice diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 9568cf6028..320f657537 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -47,8 +47,9 @@ class ConfusionMatrixMetric(CumulativeIterationMetric): returned with the same order as input names when calling the class. compute_sample: when reducing, if ``True``, each sample's metric will be computed based on each confusion matrix first. if ``False``, compute reduction on the confusion matrices first, defaults to ``False``. - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns [(metric, not_nans), ...]. If False, aggregate() returns [metric, ...]. Here `not_nans` count the number of not nans for True Positive, False Positive, True Negative and False Negative. @@ -99,13 +100,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor warnings.warn("As for classification task, compute_sample should be False.") self.compute_sample = False - return get_confusion_matrix( - y_pred=y_pred, - y=y, - include_background=self.include_background, - ) + return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background) - def aggregate(self): # type: ignore + def aggregate(self): """ Execute reduction for the confusion matrix values. @@ -129,11 +126,7 @@ def aggregate(self): # type: ignore return results -def get_confusion_matrix( - y_pred: torch.Tensor, - y: torch.Tensor, - include_background: bool = True, -): +def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True): """ Compute confusion matrix. A tensor with the shape [BC4] will be returned. Where, the third dimension represents the number of true positive, false positive, true negative and false negative values for @@ -153,16 +146,13 @@ def get_confusion_matrix( """ if not include_background: - y_pred, y = ignore_background( - y_pred=y_pred, - y=y, - ) + y_pred, y = ignore_background(y_pred=y_pred, y=y) y = y.float() y_pred = y_pred.float() if y.shape != y_pred.shape: - raise ValueError("y_pred and y should have same shapes.") + raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") # get confusion matrix related metric batch_size, n_class = y_pred.shape[:2] diff --git a/monai/metrics/cumulative_average.py b/monai/metrics/cumulative_average.py new file mode 100644 index 0000000000..768841f6c7 --- /dev/null +++ b/monai/metrics/cumulative_average.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from monai.transforms import isnan +from monai.utils import convert_data_type + +from .metric import Cumulative + + +class CumulativeAverage(Cumulative): + """ + Cumulatively record data value and aggregate for the average value. + It supports single class or multi-class data, for example, + value can be 0.44 (a loss value) or [0.3, 0.4] (metrics of two classes). + It also supports distributed data parallel, sync data when aggregating. + For example, recording loss values and compute the overall average value in every 5 iterations: + + .. code-block:: python + + average = CumulativeAverage() + for i, d in enumerate(dataloader): + loss = ... + average.append(loss) + if i % 5 == 0: + print(f"cumulative average of loss: {average.aggregate()}") + average.reset() + + """ + + def __init__(self) -> None: + super().__init__() + self.sum = None + self.not_nans = None + + def reset(self): + """ + Reset all the running status, including buffers, sum, not nans count, etc. + + """ + super().reset() + self.sum = None + self.not_nans = None + + def aggregate(self): + """ + Sync data from all the ranks and compute the average value with previous sum value. + + """ + data = self.get_buffer() + + # compute SUM across the batch dimension + nans = isnan(data) + not_nans = convert_data_type((~nans), dtype=torch.float32)[0].sum(0) + data[nans] = 0 + f = data.sum(0) + + # clear the buffer for next update + super().reset() + self.sum = f if self.sum is None else (self.sum + f) + self.not_nans = not_nans if self.not_nans is None else (self.not_nans + not_nans) + + return self.sum / self.not_nans diff --git a/monai/metrics/froc.py b/monai/metrics/froc.py index faebbbf7a6..93ad625b90 100644 --- a/monai/metrics/froc.py +++ b/monai/metrics/froc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -46,7 +46,9 @@ def compute_fp_tp_probs( """ if not (probs.shape == y_coord.shape == x_coord.shape): - raise AssertionError("the shapes for coordinates and probabilities should be the same.") + raise ValueError( + f"the shapes between probs {probs.shape}, y_coord {y_coord.shape} and x_coord {x_coord.shape} should be the same." + ) if isinstance(probs, torch.Tensor): probs = probs.detach().cpu().numpy() @@ -96,7 +98,7 @@ def compute_froc_curve_data( num_images: the number of images under evaluation. """ - if type(fp_probs) is not type(tp_probs): + if not isinstance(fp_probs, type(tp_probs)): raise AssertionError("fp and tp probs should have same type.") if isinstance(fp_probs, torch.Tensor): fp_probs = fp_probs.detach().cpu().numpy() @@ -116,9 +118,7 @@ def compute_froc_curve_data( def compute_froc_score( - fps_per_image: np.ndarray, - total_sensitivity: np.ndarray, - eval_thresholds: Tuple = (0.25, 0.5, 1, 2, 4, 8), + fps_per_image: np.ndarray, total_sensitivity: np.ndarray, eval_thresholds: Tuple = (0.25, 0.5, 1, 2, 4, 8) ): """ This function is modified from the official evaluation code of diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 12f3b49d32..5ce739d1f4 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,9 +42,9 @@ class HausdorffDistanceMetric(CumulativeIterationMetric): percentile of the Hausdorff Distance rather than the maximum result will be achieved. Defaults to ``None``. directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. @@ -85,7 +85,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor if not torch.all(y_pred.byte() == y_pred): warnings.warn("y_pred should be a binarized tensor.") if not torch.all(y.byte() == y): - raise ValueError("y should be a binarized tensor.") + warnings.warn("y should be a binarized tensor.") dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") @@ -99,7 +99,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor directed=self.directed, ) - def aggregate(self): # type: ignore + def aggregate(self): """ Execute reduction logic for the output of `compute_hausdorff_distance`. @@ -141,17 +141,14 @@ def compute_hausdorff_distance( """ if not include_background: - y_pred, y = ignore_background( - y_pred=y_pred, - y=y, - ) + y_pred, y = ignore_background(y_pred=y_pred, y=y) if isinstance(y, torch.Tensor): y = y.float() if isinstance(y_pred, torch.Tensor): y_pred = y_pred.float() if y.shape != y_pred.shape: - raise ValueError("y_pred and y should have same shapes.") + raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") batch_size, n_class = y_pred.shape[:2] hd = np.empty((batch_size, n_class)) @@ -172,10 +169,7 @@ def compute_hausdorff_distance( def compute_percent_hausdorff_distance( - edges_pred: np.ndarray, - edges_gt: np.ndarray, - distance_metric: str = "euclidean", - percentile: Optional[float] = None, + edges_pred: np.ndarray, edges_gt: np.ndarray, distance_metric: str = "euclidean", percentile: Optional[float] = None ): """ This function is used to compute the directed Hausdorff distance. diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 226c106f7e..4179420804 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -35,9 +35,9 @@ class DiceMetric(CumulativeIterationMetric): Args: include_background: whether to skip Dice computation on the first channel of the predicted output. Defaults to ``True``. - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. @@ -72,18 +72,14 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor if not torch.all(y_pred.byte() == y_pred): warnings.warn("y_pred should be a binarized tensor.") if not torch.all(y.byte() == y): - raise ValueError("y should be a binarized tensor.") + warnings.warn("y should be a binarized tensor.") dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") # compute dice (BxC) for each channel for each batch - return compute_meandice( - y_pred=y_pred, - y=y, - include_background=self.include_background, - ) + return compute_meandice(y_pred=y_pred, y=y, include_background=self.include_background) - def aggregate(self): # type: ignore + def aggregate(self): """ Execute reduction logic for the output of `compute_meandice`. @@ -97,11 +93,7 @@ def aggregate(self): # type: ignore return (f, not_nans) if self.get_not_nans else f -def compute_meandice( - y_pred: torch.Tensor, - y: torch.Tensor, - include_background: bool = True, -) -> torch.Tensor: +def compute_meandice(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor: """Computes Dice score metric from full size Tensor and collects average. Args: @@ -122,16 +114,13 @@ def compute_meandice( """ if not include_background: - y_pred, y = ignore_background( - y_pred=y_pred, - y=y, - ) + y_pred, y = ignore_background(y_pred=y_pred, y=y) y = y.float() y_pred = y_pred.float() if y.shape != y_pred.shape: - raise ValueError("y_pred and y should have same shapes.") + raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") # reducing only spatial dimensions (not batch nor channels) n_len = len(y_pred.shape) @@ -142,8 +131,4 @@ def compute_meandice( y_pred_o = torch.sum(y_pred, dim=reduce_axis) denominator = y_o + y_pred_o - return torch.where( - y_o > 0, - (2.0 * intersection) / denominator, - torch.tensor(float("nan"), device=y_o.device), - ) + return torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device)) diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py index bb4aa7c343..7782c4c468 100644 --- a/monai/metrics/metric.py +++ b/monai/metrics/metric.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,121 +15,168 @@ import torch from monai.config import TensorOrList -from monai.utils import evenly_divisible_all_gather +from monai.utils import convert_data_type, evenly_divisible_all_gather + +__all__ = ["Metric", "IterationMetric", "Cumulative", "CumulativeIterationMetric"] class Metric(ABC): """ - Base class of all Metrics interface. - `__call__` is designed to execute metric computation. + Base class for metric computation for evaluating the performance of a model. + `__call__` is designed to execute the computation. """ @abstractmethod - def __call__(self, *args: Any, **kwds: Any): + def __call__(self, *args: Any, **kwargs: Any): """ - API to execute the metric computation. - + This method should take raw model outputs as inputs, and return values that measure the models' quality. """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") class IterationMetric(Metric): """ - Base class of Metrics interface for computation on a batch of tensors, usually the data of 1 iteration. - `__call__` is supposed to compute independent logic for several samples of `y_pred` and `y`(optional). - Usually, subclass only needs to implement the `_compute_tensor` function for computation process. - The input data shape should be `list of channel-first tensors` or a `batch-first tensor`. + Base class for metrics computation at the iteration level, that is, on a min-batch of samples + usually using the model outcome of one iteration. + + `__call__` is designed to handle `y_pred` and `y` (optional) in torch tensors or a list/tuple of tensors. + Subclasses typically implement the `_compute_tensor` function for the actual tensor computation logic. """ def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): # type: ignore """ - Execute basic computation for model prediction and ground truth. - It can support both `list of channel-first Tensor` and `batch-first Tensor`. - And users can execute on every batch of data, then accumulate the results, or - accumulate the original `y_pred` and `y`, then execute on the accumulated data. + Execute basic computation for model prediction `y_pred` and ground truth `y` (optional). + It supports inputs of a list of "channel-first" Tensor and a "batch-first" Tensor. Args: - y_pred: the model prediction data to compute, must be a list of `channel-first` Tensor + y_pred: the raw model prediction data at one iteration, must be a list of `channel-first` Tensor or a `batch-first` Tensor. y: the ground truth to compute, must be a list of `channel-first` Tensor or a `batch-first` Tensor. + Returns: + The computed metric values at the iteration level. + The output shape could be a `batch-first` tensor or a list of `batch-first` tensors. + When it's a list of tensors, each item in the list can represent a specific type of metric. + """ ret: TensorOrList + # handling a list of channel-first data if isinstance(y_pred, (list, tuple)) or isinstance(y, (list, tuple)): - # if y_pred or y is a list of channel-first data, add batch dim and compute metric - ret = self._compute_list(y_pred, y) - elif isinstance(y_pred, torch.Tensor): - y_ = y.detach() if y is not None and isinstance(y, torch.Tensor) else None - ret = self._compute_tensor(y_pred.detach(), y_) - else: - raise ValueError("y_pred or y must be a list of `channel-first` Tensors or a `batch-first` Tensor.") - - return ret + return self._compute_list(y_pred, y) + # handling a single batch-first data + if isinstance(y_pred, torch.Tensor): + y_ = y.detach() if isinstance(y, torch.Tensor) else None + return self._compute_tensor(y_pred.detach(), y_) + raise ValueError("y_pred or y must be a list/tuple of `channel-first` Tensors or a `batch-first` Tensor.") def _compute_list(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): """ - Excute the computation for the y_pred and y items of a iteration, the data is in the list shape. - Will concat the results to guarantee the output shape of ret is BCHW[D], otherwise it's list of batch-first, - which is against our principle that data in metrics should be BCHW[D] or list of channel-first. - Note: subclass may enhance the operation with multi-threads to accelerate. + Execute the metric computation for `y_pred` and `y` in a list of "channel-first" tensors. + + The return value is a "batch-first" tensor, or a list of "batch-first" tensors. + When it's a list of tensors, each item in the list can represent a specific type of metric values. + + For example, `self._compute_tensor` may be implemented as returning a list of `batch_size` items, + where each item is a tuple of three values `tp`, `fp`, `fn` for true positives, false positives, + and false negatives respectively. This function will return a list of three items, + (`tp_batched`, `fp_batched`, `fn_batched`), where each item is a `batch_size`-length tensor. + Note: subclass may enhance the operation to have multi-thread support. """ - ret: TensorOrList if y is not None: ret = [self._compute_tensor(p.detach().unsqueeze(0), y_.detach().unsqueeze(0)) for p, y_ in zip(y_pred, y)] else: ret = [self._compute_tensor(p_.detach().unsqueeze(0), None) for p_ in y_pred] - # concat the list of results - if isinstance(ret[0], torch.Tensor): - ret = torch.cat(ret, dim=0) - elif isinstance(ret[0], (list, tuple)) and all(isinstance(i, torch.Tensor) for i in ret[0]): - # if _compute_tensor() returned not only 1 Tensor, concat them separately - ret = [torch.cat([k[i] for k in ret], dim=0) for i in range(len(ret[0]))] + # concat the list of results (e.g. a batch of evaluation scores) + if isinstance(ret[0], torch.Tensor): + return torch.cat(ret, dim=0) + # the result is a list of sequence of tensors (e.g. a batch of multi-class results) + if isinstance(ret[0], (list, tuple)) and all(isinstance(i, torch.Tensor) for i in ret[0]): + return [torch.cat(batch_i, dim=0) for batch_i in zip(*ret)] return ret @abstractmethod def _compute_tensor(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None): """ - computation logic for the y_pred and y of a iteration, the data should be `batch-first` Tensors. - Every subclass metric should implement its own computation logic according to its algorithm. - + Computation logic for `y_pred` and `y` of an iteration, the data should be "batch-first" Tensors. + A subclass should implement its own computation logic. + The return value is usually a "batch_first" tensor, or a list of "batch_first" tensors. """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class Cumulative(ABC): +class Cumulative: """ Utility class for the typical cumulative computation process based on PyTorch Tensors. - It cumulates tensors in the buffer, then sync across distributed ranks and aggregate. + It provides interfaces to accumulate values in the local buffers, synchronize buffers across distributed nodes, + and aggregate the buffered values. + + In multi-processing, PyTorch programs usually distribute data to multiple nodes. Each node runs with a subset + of the data, adds values to its local buffers. Calling `get_buffer` could gather all the results and + `aggregate` can further handle the results to generate the final outcomes. - To speed up computation with multi-processing, PyTorch programs usually split data to distributed ranks - by `DistributedSampler` before an epoch, every rank then computes only based on its own data part and - `add` to the buffers in its process. Eventually, sync the values of all ranks to compute the final results. + Users can implement their own `aggregate` method to handle the results, + using `get_buffer` to get the buffered contents. Note: the data list should have the same length every time calling `add()` in a round, it will automatically create buffers according to the length of data list. - Typically, this class is expected to execute the steps referring to below examples:: + Typically, this class is expected to execute the following steps: + + .. code-block:: python + + from monai.metrics import Cumulative + + c = Cumulative() + c.append(1) # adds a value + c.extend([2, 3]) # adds a batch of values + c.extend([4, 5, 6]) # adds a batch of values + print(c.get_buffer()) # tensor([1, 2, 3, 4, 5, 6]) + print(len(c)) # 6 + c.reset() + print(len(c)) # 0 + + The following is an example of maintaining two internal buffers: + + .. code-block:: python + + from monai.metrics import Cumulative + + c = Cumulative() + c.append(1, 2) # adds a value to two buffers respectively + c.extend([3, 4], [5, 6]) # adds batches of values + print(c.get_buffer()) # [tensor([1, 3, 4]), tensor([2, 5, 6])] + print(len(c)) - cum = Cumulative() - cum.add(x, y) - cum.add(a, b) - cum.add(c, d) - cum.aggregate() - result = cum.get_buffer() - cum.reset() + The following is an example of extending with variable length data: + + .. code-block:: python + + import torch + from monai.metrics import Cumulative + + c = Cumulative() + c.extend(torch.zeros((8, 2)), torch.zeros((6, 2))) # adds batches + c.append(torch.zeros((2, ))) # adds a value + print(c.get_buffer()) # [torch.zeros((9, 2)), torch.zeros((6, 2))] + print(len(c)) """ def __init__(self): - self.buffer_num: int = 0 + """ + Initialize the internal buffers. + `self._buffers` are local buffers, they are not usually used directly. + `self._sync_buffers` are the buffers with all the results across all the nodes. + """ self._buffers: Optional[List[List[torch.Tensor]]] = None self._synced_tensors: Optional[List[Optional[torch.Tensor]]] = None self._synced: bool = False + self.reset() def reset(self): """ @@ -140,33 +187,55 @@ def reset(self): self._synced_tensors = None self._synced = False - def add(self, *data: torch.Tensor): + def extend(self, *data) -> None: + """ + Extend the local buffers with new ("batch-first") data. + A buffer will be allocated for each `data` item. + Compared with `self.append`, this method adds a "batch" of data to the local buffers. + + Args: + data: each item can be a "batch-first" tensor or a list of "channel-first" tensors. + they will be concatenated at the 0-th dimension when `get_buffer()` is called. + """ + if self._buffers is None: + self._buffers = [[] for _ in data] + for b, d in zip(self._buffers, data): + # converting to pytorch tensors so that we can use the distributed API + d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True) + try: + b.extend([x[0] for x in torch.split(d_t, 1, dim=0)]) + except (AttributeError, IndexError, RuntimeError) as e: + raise TypeError( + f"{e}. `data` should be a batch-first tensor or" + f" a list of channel-first tensors, got {type(d_t)}" + ) from e + self._synced = False + + def append(self, *data) -> None: """ - Add samples to the cumulative buffers. + Add samples to the local cumulative buffers. + A buffer will be allocated for each `data` item. + Compared with `self.extend`, this method adds a single sample (instead + of a "batch") to the local buffers. Args: - data: list of input tensor, make sure the input data order is always the same in a round. - every item of data will be added to the corresponding buffer. + data: each item will be converted into a torch tensor. + they will be stacked at the 0-th dim with a new dimension when `get_buffer()` is called. """ - data_len = len(data) if self._buffers is None: - self._buffers = [[] for _ in range(data_len)] - elif len(self._buffers) != data_len: - raise ValueError(f"data length: {data_len} doesn't match buffers length: {len(self._buffers)}.") - if self._synced_tensors is None: - self._synced_tensors = [None for _ in range(data_len)] - - for i, d in enumerate(data): - if not isinstance(d, torch.Tensor): - raise ValueError(f"the data to cumulate in a buffer must be PyTorch Tensor, but got: {type(d)}.") - self._buffers[i].append(d) + self._buffers = [[] for _ in data] + for b, d in zip(self._buffers, data): + # converting to pytorch tensors so that we can use the distributed API + d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True) + b.append(d_t) self._synced = False @abstractmethod - def aggregate(self, *args: Any, **kwds: Any): + def aggregate(self, *args: Any, **kwargs: Any): """ - Aggregate final results based on the buffers. + Aggregate final results based on the gathered buffers. + This method is expected to use `get_buffer` to gather the local buffer contents. """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -174,28 +243,67 @@ def aggregate(self, *args: Any, **kwds: Any): def _sync(self): """ All gather the buffers across distributed ranks for aggregating. - Every buffer will be concatenated as a PyTorch Tensor. + Each buffer will be concatenated as a PyTorch Tensor. """ - self._synced_tensors = [evenly_divisible_all_gather(torch.cat(b, dim=0), concat=True) for b in self._buffers] + if self._synced or self._buffers is None: + return + try: + self._synced_tensors = [ + evenly_divisible_all_gather(torch.stack(b, dim=0), concat=True) for b in self._buffers + ] + except (RuntimeError, TypeError, ValueError) as e: + raise TypeError(f"{e}. unable to sync buffer contents: {self._buffers}.") from e self._synced = True + def __len__(self): + """ + Return the length of the largest buffer. + Note that the method will trigger synchronization of the local buffers. + """ + self._sync() + if not self._synced_tensors: + return 0 + return max(len(x) for x in self._synced_tensors) + def get_buffer(self): """ - Get the synced buffers list. + Get the synchronized list of buffers. A typical usage is to generate the metrics report based on the raw metric details. + Each buffer is a PyTorch Tensor. """ - if not self._synced: - self._sync() + self._sync() return self._synced_tensors[0] if len(self._synced_tensors) == 1 else self._synced_tensors class CumulativeIterationMetric(Cumulative, IterationMetric): """ - Base class of cumulative metric which computes on batch data of every iteration and aggregate. - Typically, it computes some intermediate results for every iteration, cumulates in buffers, - then syncs across all the distributed ranks and aggregates for the final result when epoch completed. + Base class of cumulative metric which collects metrics on each mini-batch data at the iteration level. + + Typically, it computes some intermediate results for each iteration, adds them to the buffers, + then the buffer contents could be gathered and aggregated for the final result when epoch completed. + + For example, `MeanDice` inherits this class and the usage is as follows: + + .. code-block:: python + + dice_metric = DiceMetric(include_background=True, reduction="mean") + + for val_data in val_loader: + val_outputs = model(val_data["img"]) + val_outputs = [postprocessing_transform(i) for i in decollate_batch(val_outputs)] + # compute metric for current iteration + dice_metric(y_pred=val_outputs, y=val_data["seg"]) # callable to add metric to the buffer + + # aggregate the final mean dice result + metric = dice_metric.aggregate().item() + + # reset the status for next computation round + dice_metric.reset() + + And to load `predictions` and `labels` from files, then compute metrics with multi-processing, please refer to: + https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py. """ @@ -212,11 +320,13 @@ def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): # t y: the ground truth to compute, must be a list of `channel-first` Tensor or a `batch-first` Tensor. + Returns: + The computed metric values at the iteration level. """ ret = super().__call__(y_pred=y_pred, y=y) if isinstance(ret, (tuple, list)): - self.add(*ret) + self.extend(*ret) else: - self.add(ret) + self.extend(ret) return ret diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index a2a2f0853d..d5733eee97 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -30,24 +30,22 @@ class RegressionMetric(CumulativeIterationMetric): `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]). Args: - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. """ def __init__( - self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, - get_not_nans: bool = False, + self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False ) -> None: super().__init__() self.reduction = reduction self.get_not_nans = get_not_nans - def aggregate(self): # type: ignore + def aggregate(self): data = self.get_buffer() if not isinstance(data, torch.Tensor): raise ValueError("the data to aggregate must be PyTorch Tensor.") @@ -57,9 +55,7 @@ def aggregate(self): # type: ignore def _check_shape(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: if y_pred.shape != y.shape: - raise ValueError( - "y_pred and y shapes dont match, received y_pred: [{}] and y: [{}]".format(y_pred.shape, y.shape) - ) + raise ValueError(f"y_pred and y shapes dont match, received y_pred: [{y_pred.shape}] and y: [{y.shape}]") # also check if there is atleast one non-batch dimension i.e. num_dims >= 2 if len(y_pred.shape) < 2: @@ -88,17 +84,15 @@ class MSEMetric(RegressionMetric): Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. Args: - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). """ def __init__( - self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, - get_not_nans: bool = False, + self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False ) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.sq_func = partial(torch.pow, exponent=2.0) @@ -122,17 +116,15 @@ class MAEMetric(RegressionMetric): Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. Args: - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). """ def __init__( - self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, - get_not_nans: bool = False, + self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False ) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.abs_func = torch.abs @@ -157,17 +149,15 @@ class RMSEMetric(RegressionMetric): Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. Args: - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). """ def __init__( - self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, - get_not_nans: bool = False, + self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False ) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.sq_func = partial(torch.pow, exponent=2.0) @@ -198,9 +188,9 @@ class PSNRMetric(RegressionMetric): Args: max_val: The dynamic range of the images/volumes (i.e., the difference between the maximum and the minimum allowed values e.g. 255 for a uint8 image). - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). """ diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index c2679cc2ea..221fc50272 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Union, cast import numpy as np @@ -48,7 +49,7 @@ def __init__(self, average: Union[Average, str] = Average.MACRO) -> None: def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore return y_pred, y - def aggregate(self): # type: ignore + def aggregate(self): """ As AUC metric needs to execute on the overall data, so usually users accumulate `y_pred` and `y` of every iteration, then execute real computation and reduction on the accumulated data. @@ -65,8 +66,14 @@ def aggregate(self): # type: ignore def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float: if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)): raise AssertionError("y and y_pred must be 1 dimension data with same length.") - if not y.unique().equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)): - raise AssertionError("y values must be 0 or 1, can not be all 0 or all 1.") + y_unique = y.unique() + if len(y_unique) == 1: + warnings.warn(f"y values can not be all {y_unique.item()}, skip AUC computation and return `Nan`.") + return float("nan") + if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)): + warnings.warn(f"y values must be 0 or 1, but in {y_unique.tolist()}, skip AUC computation and return `Nan`.") + return float("nan") + n = len(y) indices = y_pred.argsort() y = y[indices].cpu().numpy() @@ -93,20 +100,18 @@ def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float: return auc / (nneg * (n - nneg)) -def compute_roc_auc( - y_pred: torch.Tensor, - y: torch.Tensor, - average: Union[Average, str] = Average.MACRO, -): +def compute_roc_auc(y_pred: torch.Tensor, y: torch.Tensor, average: Union[Average, str] = Average.MACRO): """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: `sklearn.metrics.roc_auc_score `_. Args: y_pred: input data to compute, typical classification model output. - it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2]. - y: ground truth to compute ROC AUC metric, the first dim is batch. - example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`). + the first dim must be batch, if multi-classes, it must be in One-Hot format. + for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data. + y: ground truth to compute ROC AUC metric, the first dim must be batch. + if multi-classes, it must be in One-Hot format. + for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data. average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``} Type of averaging performed if not binary classification. Defaults to ``"macro"``. @@ -131,9 +136,11 @@ def compute_roc_auc( y_pred_ndim = y_pred.ndimension() y_ndim = y.ndimension() if y_pred_ndim not in (1, 2): - raise ValueError("Predictions should be of shape (batch_size, num_classes) or (batch_size, ).") + raise ValueError( + f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}." + ) if y_ndim not in (1, 2): - raise ValueError("Targets should be of shape (batch_size, num_classes) or (batch_size, ).") + raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.") if y_pred_ndim == 2 and y_pred.shape[1] == 1: y_pred = y_pred.squeeze(dim=-1) y_pred_ndim = 1 @@ -144,7 +151,7 @@ def compute_roc_auc( return _calculate(y_pred, y) if y.shape != y_pred.shape: - raise AssertionError("data shapes of y_pred and y do not match.") + raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.") average = look_up_option(average, Average) if average == Average.MICRO: diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 6039f1b55e..2c84bb9e7c 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ import torch from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background -from monai.utils import MetricReduction +from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -37,9 +37,9 @@ class SurfaceDistanceMetric(CumulativeIterationMetric): `seg_pred` and `seg_gt`. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. @@ -78,7 +78,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor if not torch.all(y_pred.byte() == y_pred): warnings.warn("y_pred should be a binarized tensor.") if not torch.all(y.byte() == y): - raise ValueError("y should be a binarized tensor.") + warnings.warn("y should be a binarized tensor.") dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") @@ -91,7 +91,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor distance_metric=self.distance_metric, ) - def aggregate(self): # type: ignore + def aggregate(self): """ Execute reduction logic for the output of `compute_average_surface_distance`. @@ -134,10 +134,7 @@ def compute_average_surface_distance( """ if not include_background: - y_pred, y = ignore_background( - y_pred=y_pred, - y=y, - ) + y_pred, y = ignore_background(y_pred=y_pred, y=y) if isinstance(y, torch.Tensor): y = y.float() @@ -145,7 +142,7 @@ def compute_average_surface_distance( y_pred = y_pred.float() if y.shape != y_pred.shape: - raise ValueError("y_pred and y should have same shapes.") + raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") batch_size, n_class = y_pred.shape[:2] asd = np.empty((batch_size, n_class)) @@ -156,20 +153,10 @@ def compute_average_surface_distance( warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") if not np.any(edges_pred): warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) - if surface_distance.shape == (0,): - avg_surface_distance = np.nan - else: - avg_surface_distance = surface_distance.mean() # type: ignore - if not symmetric: - asd[b, c] = avg_surface_distance - else: + if symmetric: surface_distance_2 = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric) - if surface_distance_2.shape == (0,): - avg_surface_distance_2 = np.nan - else: - avg_surface_distance_2 = surface_distance_2.mean() # type: ignore - asd[b, c] = np.mean((avg_surface_distance, avg_surface_distance_2)) + surface_distance = np.concatenate([surface_distance, surface_distance_2]) + asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean() - return torch.from_numpy(asd) + return convert_data_type(asd, torch.Tensor)[0] diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 84de834f74..fc42100d6f 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,12 +25,10 @@ __all__ = ["ignore_background", "do_metric_reduction", "get_mask_edges", "get_surface_distance"] -def ignore_background( - y_pred: Union[np.ndarray, torch.Tensor], - y: Union[np.ndarray, torch.Tensor], -): +def ignore_background(y_pred: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]): """ This function is used to remove background (the first channel) for `y_pred` and `y`. + Args: y_pred: predictions. As for classification tasks, `y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks, @@ -38,24 +36,24 @@ def ignore_background( y: ground truth, the first dim is batch. """ + y = y[:, 1:] if y.shape[1] > 1 else y y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred return y_pred, y -def do_metric_reduction( - f: torch.Tensor, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, -): +def do_metric_reduction(f: torch.Tensor, reduction: Union[MetricReduction, str] = MetricReduction.MEAN): """ - This function is to do the metric reduction for calculated metrics of each example's each class. + This function is to do the metric reduction for calculated `not-nan` metrics of each sample's each class. The function also returns `not_nans`, which counts the number of not nans for the metric. Args: f: a tensor that contains the calculated metric scores per batch and per class. The first two dims should be batch and class. - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``}, if "none", return the input f tensor and not_nans. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. + if "none", return the input f tensor and not_nans. Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. Raises: @@ -118,7 +116,7 @@ def get_mask_edges( The input images can be binary or labelfield images. If labelfield images are supplied, they are converted to binary images using `label_idx`. - `scipy`'s binary erosion is used to to calculate the edges of the binary + `scipy`'s binary erosion is used to calculate the edges of the binary labelfield. In order to improve the computing efficiency, before getting the edges, @@ -146,7 +144,7 @@ def get_mask_edges( seg_gt = seg_gt.detach().cpu().numpy() if seg_pred.shape != seg_gt.shape: - raise ValueError("seg_pred and seg_gt should have same shapes.") + raise ValueError(f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}.") # If not binary images, convert them if seg_pred.dtype != bool: @@ -156,7 +154,7 @@ def get_mask_edges( if crop: if not np.any(seg_pred | seg_gt): - return (np.zeros_like(seg_pred), np.zeros_like(seg_gt)) + return np.zeros_like(seg_pred), np.zeros_like(seg_gt) seg_pred, seg_gt = np.expand_dims(seg_pred, 0), np.expand_dims(seg_gt, 0) box_start, box_end = generate_spatial_bounding_box(np.asarray(seg_pred | seg_gt)) @@ -167,14 +165,10 @@ def get_mask_edges( edges_pred = binary_erosion(seg_pred) ^ seg_pred edges_gt = binary_erosion(seg_gt) ^ seg_gt - return (edges_pred, edges_gt) + return edges_pred, edges_gt -def get_surface_distance( - seg_pred: np.ndarray, - seg_gt: np.ndarray, - distance_metric: str = "euclidean", -) -> np.ndarray: +def get_surface_distance(seg_pred: np.ndarray, seg_gt: np.ndarray, distance_metric: str = "euclidean") -> np.ndarray: """ This function is used to compute the surface distances from `seg_pred` to `seg_gt`. diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 3c347dad22..76223dfaef 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,14 +10,17 @@ # limitations under the License. from .utils import ( + convert_to_torchscript, copy_model_state, eval_mode, + get_state_dict, icnr_init, normal_init, normalize_transform, one_hot, pixelshuffle, predict_segmentation, + save_state, slice_channels, to_norm_affine, train_mode, diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index db723f622d..0fdc944760 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ from .aspp import SimpleASPP from .convolutions import Convolution, ResidualUnit from .crf import CRF +from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock from .downsample import MaxAvgPool from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .fcn import FCN, GCN, MCFCN, Refine diff --git a/monai/networks/blocks/acti_norm.py b/monai/networks/blocks/acti_norm.py index 593ca6baa7..65b662ac32 100644 --- a/monai/networks/blocks/acti_norm.py +++ b/monai/networks/blocks/acti_norm.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -98,4 +98,4 @@ def __init__( if item not in op_dict: raise ValueError(f"ordering must be a string of {op_dict}, got {item} in it.") if op_dict[item] is not None: - self.add_module(item, op_dict[item]) # type: ignore + self.add_module(item, op_dict[item]) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index a380f8e757..1526b37056 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,7 +19,6 @@ def monai_mish(x, inplace: bool = False): return torch.nn.functional.mish(x, inplace=inplace) - else: def monai_mish(x, inplace: bool = False): @@ -31,7 +30,6 @@ def monai_mish(x, inplace: bool = False): def monai_swish(x, inplace: bool = False): return torch.nn.functional.silu(x, inplace=inplace) - else: def monai_swish(x, inplace: bool = False): @@ -48,8 +46,7 @@ class Swish(nn.Module): Shape: - - Input: :math:`(N, *)` where `*` means, any number of additional - dimensions + - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input @@ -123,7 +120,7 @@ class MemoryEfficientSwish(nn.Module): """ def __init__(self, inplace: bool = False): - super(MemoryEfficientSwish, self).__init__() + super().__init__() # inplace only works when using torch.nn.functional.silu self.inplace = inplace @@ -143,8 +140,7 @@ class Mish(nn.Module): this class will utilize `torch.nn.functional.mish` to do the calculation if meets the version. Shape: - - Input: :math:`(N, *)` where `*` means, any number of additional - dimensions + - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input @@ -158,7 +154,7 @@ class Mish(nn.Module): """ def __init__(self, inplace: bool = False): - super(Mish, self).__init__() + super().__init__() # inplace only works when using torch.nn.functional.mish self.inplace = inplace diff --git a/monai/networks/blocks/aspp.py b/monai/networks/blocks/aspp.py index f8bf8a5ba6..8d43530fa7 100644 --- a/monai/networks/blocks/aspp.py +++ b/monai/networks/blocks/aspp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -86,7 +86,7 @@ def __init__( out_channels = conv_out_channels * len(pads) # final conv. output channels self.conv_k1 = Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=1, diff --git a/monai/networks/blocks/convolutions.py b/monai/networks/blocks/convolutions.py index 39ce60e3f8..37530668a3 100644 --- a/monai/networks/blocks/convolutions.py +++ b/monai/networks/blocks/convolutions.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,6 +18,7 @@ from monai.networks.blocks import ADN from monai.networks.layers.convutils import same_padding, stride_minus_kernel_padding from monai.networks.layers.factories import Conv +from monai.utils.deprecate_utils import deprecated_arg class Convolution(nn.Sequential): @@ -59,7 +60,7 @@ class Convolution(nn.Sequential): ) Args: - dimensions: number of spatial dimensions. + spatial_dims: number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. strides: convolution stride. Defaults to 1. @@ -69,13 +70,13 @@ class Convolution(nn.Sequential): act: activation type and arguments. Defaults to PReLU. norm: feature normalization type and arguments. Defaults to instance norm. dropout: dropout ratio. Defaults to no dropout. - dropout_dim: determine the dimensions of dropout. Defaults to 1. + dropout_dim: determine the spatial dimensions of dropout. Defaults to 1. - When dropout_dim = 1, randomly zeroes some of the elements for each channel. - When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map). - When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map). - The value of dropout_dim should be no no larger than the value of `dimensions`. + The value of dropout_dim should be no no larger than the value of `spatial_dims`. dilation: dilation rate. Defaults to 1. groups: controls the connections between inputs and outputs. Defaults to 1. bias: whether to have a bias term. Defaults to True. @@ -86,6 +87,9 @@ class Convolution(nn.Sequential): output_padding: controls the additional size added to one side of the output shape. Defaults to None. + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + See also: :py:class:`monai.networks.layers.Conv` @@ -93,9 +97,12 @@ class Convolution(nn.Sequential): """ + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: int, out_channels: int, strides: Union[Sequence[int], int] = 1, @@ -112,15 +119,16 @@ def __init__( is_transposed: bool = False, padding: Optional[Union[Sequence[int], int]] = None, output_padding: Optional[Union[Sequence[int], int]] = None, + dimensions: Optional[int] = None, ) -> None: super().__init__() - self.dimensions = dimensions + self.dimensions = spatial_dims if dimensions is None else dimensions self.in_channels = in_channels self.out_channels = out_channels self.is_transposed = is_transposed if padding is None: padding = same_padding(kernel_size, dilation) - conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions] + conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, self.dimensions] conv: nn.Module if is_transposed: @@ -159,7 +167,7 @@ def __init__( in_channels=out_channels, act=act, norm=norm, - norm_dim=dimensions, + norm_dim=self.dimensions, dropout=dropout, dropout_dim=dropout_dim, ), @@ -177,7 +185,7 @@ class ResidualUnit(nn.Module): from monai.networks.blocks import ResidualUnit convs = ResidualUnit( - dimensions=3, + spatial_dims=3, in_channels=1, out_channels=1, adn_ordering="AN", @@ -209,7 +217,7 @@ class ResidualUnit(nn.Module): ) Args: - dimensions: number of spatial dimensions. + spatial_dims: number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. strides: convolution stride. Defaults to 1. @@ -234,15 +242,19 @@ class ResidualUnit(nn.Module): padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each dimension. Defaults to None. + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + See also: :py:class:`monai.networks.blocks.Convolution` """ + @deprecated_arg(name="dimensions", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: int, out_channels: int, strides: Union[Sequence[int], int] = 1, @@ -257,9 +269,10 @@ def __init__( bias: bool = True, last_conv_only: bool = False, padding: Optional[Union[Sequence[int], int]] = None, + dimensions: Optional[int] = None, ) -> None: super().__init__() - self.dimensions = dimensions + self.dimensions = spatial_dims if dimensions is None else dimensions self.in_channels = in_channels self.out_channels = out_channels self.conv = nn.Sequential() @@ -273,7 +286,7 @@ def __init__( for su in range(subunits): conv_only = last_conv_only and su == (subunits - 1) unit = Convolution( - dimensions, + self.dimensions, schannels, out_channels, strides=sstrides, @@ -304,7 +317,7 @@ def __init__( rkernel_size = 1 rpadding = 0 - conv_type = Conv[Conv.CONV, dimensions] + conv_type = Conv[Conv.CONV, self.dimensions] self.residual = conv_type(in_channels, out_channels, rkernel_size, strides, rpadding, bias=bias) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index 49ff5bcd04..b6382adf5f 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from torch.nn.functional import softmax from monai.networks.layers.filtering import PHLFilter +from monai.networks.utils import meshgrid_ij __all__ = ["CRF"] @@ -57,7 +58,7 @@ def __init__( compatibility_matrix: a matrix describing class compatibility, should be NxN where N is the number of classes. """ - super(CRF, self).__init__() + super().__init__() self.iterations = iterations self.bilateral_weight = bilateral_weight self.gaussian_weight = gaussian_weight @@ -114,6 +115,6 @@ def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor): # helper methods def _create_coordinate_tensor(tensor): axes = [torch.arange(tensor.size(i)) for i in range(2, tensor.dim())] - grids = torch.meshgrid(axes) + grids = meshgrid_ij(axes) coords = torch.stack(grids).to(device=tensor.device, dtype=tensor.dtype) return torch.stack(tensor.size(0) * [coords], dim=0) diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py new file mode 100644 index 0000000000..f76e125fe0 --- /dev/null +++ b/monai/networks/blocks/dints_block.py @@ -0,0 +1,272 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple, Union + +import torch + +from monai.networks.layers.factories import Conv +from monai.networks.layers.utils import get_act_layer, get_norm_layer + +__all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DActiConvNormBlock", "ActiConvNormBlock"] + + +class FactorizedIncreaseBlock(torch.nn.Sequential): + """ + Up-sampling the features by two using linear interpolation and convolutions. + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + """ + Args: + in_channel: number of input channels + out_channel: number of output channels + spatial_dims: number of spatial dimensions + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._spatial_dims = spatial_dims + if self._spatial_dims not in (2, 3): + raise ValueError("spatial_dims must be 2 or 3.") + + conv_type = Conv[Conv.CONV, self._spatial_dims] + mode = "trilinear" if self._spatial_dims == 3 else "bilinear" + self.add_module("up", torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=True)) + self.add_module("acti", get_act_layer(name=act_name)) + self.add_module( + "conv", + conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel, + kernel_size=1, + stride=1, + padding=0, + groups=1, + bias=False, + dilation=1, + ), + ) + self.add_module( + "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + ) + + +class FactorizedReduceBlock(torch.nn.Module): + """ + Down-sampling the feature by 2 using stride. + The length along each spatial dimension must be a multiple of 2. + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + """ + Args: + in_channel: number of input channels + out_channel: number of output channels. + spatial_dims: number of spatial dimensions. + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._spatial_dims = spatial_dims + if self._spatial_dims not in (2, 3): + raise ValueError("spatial_dims must be 2 or 3.") + + conv_type = Conv[Conv.CONV, self._spatial_dims] + + self.act = get_act_layer(name=act_name) + self.conv_1 = conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel // 2, + kernel_size=1, + stride=2, + padding=0, + groups=1, + bias=False, + dilation=1, + ) + self.conv_2 = conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel - self._out_channel // 2, + kernel_size=1, + stride=2, + padding=0, + groups=1, + bias=False, + dilation=1, + ) + self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + The length along each spatial dimension must be a multiple of 2. + """ + x = self.act(x) + if self._spatial_dims == 3: + out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:, 1:])], dim=1) + else: + out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1) + out = self.norm(out) + return out + + +class P3DActiConvNormBlock(torch.nn.Sequential): + """ + -- (act) -- (conv) -- (norm) -- + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int, + padding: int, + mode: int = 0, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + """ + Args: + in_channel: number of input channels. + out_channel: number of output channels. + kernel_size: kernel size to be expanded to 3D. + padding: padding size to be expanded to 3D. + mode: mode for the anisotropic kernels: + + - 0: ``(k, k, 1)``, ``(1, 1, k)``, + - 1: ``(k, 1, k)``, ``(1, k, 1)``, + - 2: ``(1, k, k)``. ``(k, 1, 1)``. + + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._p3dmode = int(mode) + + conv_type = Conv[Conv.CONV, 3] + + if self._p3dmode == 0: # (k, k, 1), (1, 1, k) + kernel_size0 = (kernel_size, kernel_size, 1) + kernel_size1 = (1, 1, kernel_size) + padding0 = (padding, padding, 0) + padding1 = (0, 0, padding) + elif self._p3dmode == 1: # (k, 1, k), (1, k, 1) + kernel_size0 = (kernel_size, 1, kernel_size) + kernel_size1 = (1, kernel_size, 1) + padding0 = (padding, 0, padding) + padding1 = (0, padding, 0) + elif self._p3dmode == 2: # (1, k, k), (k, 1, 1) + kernel_size0 = (1, kernel_size, kernel_size) + kernel_size1 = (kernel_size, 1, 1) + padding0 = (0, padding, padding) + padding1 = (padding, 0, 0) + else: + raise ValueError("`mode` must be 0, 1, or 2.") + + self.add_module("acti", get_act_layer(name=act_name)) + self.add_module( + "conv", + conv_type( + in_channels=self._in_channel, + out_channels=self._in_channel, + kernel_size=kernel_size0, + stride=1, + padding=padding0, + groups=1, + bias=False, + dilation=1, + ), + ) + self.add_module( + "conv_1", + conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel, + kernel_size=kernel_size1, + stride=1, + padding=padding1, + groups=1, + bias=False, + dilation=1, + ), + ) + self.add_module("norm", get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel)) + + +class ActiConvNormBlock(torch.nn.Sequential): + """ + -- (Acti) -- (Conv) -- (Norm) -- + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int = 3, + padding: int = 1, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + """ + Args: + in_channel: number of input channels. + out_channel: number of output channels. + kernel_size: kernel size of the convolution. + padding: padding size of the convolution. + spatial_dims: number of spatial dimensions. + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._spatial_dims = spatial_dims + + conv_type = Conv[Conv.CONV, self._spatial_dims] + self.add_module("acti", get_act_layer(name=act_name)) + self.add_module( + "conv", + conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel, + kernel_size=kernel_size, + stride=1, + padding=padding, + groups=1, + bias=False, + dilation=1, + ), + ) + self.add_module( + "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + ) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 9bee4c596e..9b0d5dd4b9 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index bb654d841c..8b22cb16a9 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,6 +33,8 @@ class UnetResBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. """ @@ -44,33 +46,26 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, ): - super(UnetResBlock, self).__init__() + super().__init__() self.conv1 = get_conv_layer( spatial_dims, in_channels, out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( - spatial_dims, - out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - conv_only=True, + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True ) self.conv3 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=1, - stride=stride, - conv_only=True, + spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True ) - self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) + self.lrelu = get_act_layer(name=act_name) self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) @@ -107,6 +102,8 @@ class UnetBasicBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. """ @@ -118,25 +115,23 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, ): - super(UnetBasicBlock, self).__init__() + super().__init__() self.conv1 = get_conv_layer( spatial_dims, in_channels, out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( - spatial_dims, - out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - conv_only=True, + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True ) - self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) + self.lrelu = get_act_layer(name=act_name) self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) @@ -164,6 +159,9 @@ class UnetUpBlock(nn.Module): stride: convolution stride. upsample_kernel_size: convolution kernel size for transposed convolution layers. norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + trans_bias: transposed convolution bias. """ @@ -176,8 +174,11 @@ def __init__( stride: Union[Sequence[int], int], upsample_kernel_size: Union[Sequence[int], int], norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + trans_bias: bool = False, ): - super(UnetUpBlock, self).__init__() + super().__init__() upsample_stride = upsample_kernel_size self.transp_conv = get_conv_layer( spatial_dims, @@ -185,6 +186,8 @@ def __init__( out_channels, kernel_size=upsample_kernel_size, stride=upsample_stride, + dropout=dropout, + bias=trans_bias, conv_only=True, is_transposed=True, ) @@ -194,7 +197,9 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, norm_name=norm_name, + act_name=act_name, ) def forward(self, inp, skip): @@ -206,10 +211,12 @@ def forward(self, inp, skip): class UnetOutBlock(nn.Module): - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int): - super(UnetOutBlock, self).__init__() + def __init__( + self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None + ): + super().__init__() self.conv = get_conv_layer( - spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, bias=True, conv_only=True + spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True ) def forward(self, inp): @@ -224,6 +231,7 @@ def get_conv_layer( stride: Union[Sequence[int], int] = 1, act: Optional[Union[Tuple, str]] = Act.PRELU, norm: Union[Tuple, str] = Norm.INSTANCE, + dropout: Optional[Union[Tuple, str, float]] = None, bias: bool = False, conv_only: bool = True, is_transposed: bool = False, @@ -240,6 +248,7 @@ def get_conv_layer( kernel_size=kernel_size, act=act, norm=norm, + dropout=dropout, bias=bias, conv_only=conv_only, is_transposed=is_transposed, @@ -249,8 +258,7 @@ def get_conv_layer( def get_padding( - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] ) -> Union[Tuple[int, ...], int]: kernel_size_np = np.atleast_1d(kernel_size) @@ -264,9 +272,7 @@ def get_padding( def get_output_padding( - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - padding: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int] ) -> Union[Tuple[int, ...], int]: kernel_size_np = np.atleast_1d(kernel_size) stride_np = np.atleast_1d(stride) diff --git a/monai/networks/blocks/dynunet_block_v1.py b/monai/networks/blocks/dynunet_block_v1.py deleted file mode 100644 index d5d9bbf3dc..0000000000 --- a/monai/networks/blocks/dynunet_block_v1.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Sequence, Union - -import numpy as np -import torch.nn as nn - -from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_conv_layer -from monai.networks.layers.factories import Norm -from monai.networks.layers.utils import get_act_layer - - -class _UnetResBlockV1(UnetResBlock): - """ - UnetResBlock for backward compatibility purpose. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: str, - ): - nn.Module.__init__(self) - self.conv1 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - conv_only=True, - ) - self.conv2 = get_conv_layer( - spatial_dims, - out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - conv_only=True, - ) - self.conv3 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=1, - stride=stride, - conv_only=True, - ) - self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) - self.norm1 = _get_norm_layer(spatial_dims, out_channels, norm_name) - self.norm2 = _get_norm_layer(spatial_dims, out_channels, norm_name) - self.norm3 = _get_norm_layer(spatial_dims, out_channels, norm_name) - self.downsample = in_channels != out_channels - stride_np = np.atleast_1d(stride) - if not np.all(stride_np == 1): - self.downsample = True - - -class _UnetBasicBlockV1(UnetBasicBlock): - """ - UnetBasicBlock for backward compatibility purpose. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: str, - ): - nn.Module.__init__(self) - self.conv1 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - conv_only=True, - ) - self.conv2 = get_conv_layer( - spatial_dims, - out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - conv_only=True, - ) - self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) - self.norm1 = _get_norm_layer(spatial_dims, out_channels, norm_name) - self.norm2 = _get_norm_layer(spatial_dims, out_channels, norm_name) - - -class _UnetUpBlockV1(UnetUpBlock): - """ - UnetUpBlock for backward compatibility purpose. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: str, - ): - nn.Module.__init__(self) - upsample_stride = upsample_kernel_size - self.transp_conv = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=upsample_kernel_size, - stride=upsample_stride, - conv_only=True, - is_transposed=True, - ) - self.conv_block = _UnetBasicBlockV1( - spatial_dims, - out_channels + out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - norm_name=norm_name, - ) - - -def _get_norm_layer(spatial_dims: int, out_channels: int, norm_name: str, num_groups: int = 16): - if norm_name not in ["batch", "instance", "group"]: - raise ValueError(f"Unsupported normalization mode: {norm_name}") - if norm_name != "group": - return Norm[norm_name, spatial_dims](out_channels, affine=True) - if out_channels % num_groups != 0: - raise AssertionError("out_channels should be divisible by num_groups.") - return Norm[norm_name, spatial_dims](num_groups=num_groups, num_channels=out_channels, affine=True) diff --git a/monai/networks/blocks/fcn.py b/monai/networks/blocks/fcn.py index d84e506774..5833d4a262 100644 --- a/monai/networks/blocks/fcn.py +++ b/monai/networks/blocks/fcn.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -36,7 +36,7 @@ def __init__(self, inplanes: int, planes: int, ks: int = 7): planes: number of output channels. ks: kernel size for one dimension. Defaults to 7. """ - super(GCN, self).__init__() + super().__init__() conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] self.conv_l1 = conv2d_type(in_channels=inplanes, out_channels=planes, kernel_size=(ks, 1), padding=(ks // 2, 0)) @@ -67,7 +67,7 @@ def __init__(self, planes: int): Args: planes: number of input channels. """ - super(Refine, self).__init__() + super().__init__() relu_type: Type[nn.ReLU] = Act[Act.RELU] conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] @@ -116,7 +116,7 @@ class FCN(nn.Module): def __init__( self, out_channels: int = 1, upsample_mode: str = "bilinear", pretrained: bool = True, progress: bool = True ): - super(FCN, self).__init__() + super().__init__() conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] @@ -154,12 +154,7 @@ def __init__( self.transformer = self.conv2d_type(in_channels=256, out_channels=64, kernel_size=1) if self.upsample_mode == "transpose": - self.up_conv = UpSample( - dimensions=2, - in_channels=self.out_channels, - scale_factor=2, - mode="deconv", - ) + self.up_conv = UpSample(spatial_dims=2, in_channels=self.out_channels, scale_factor=2, mode="deconv") def forward(self, x: torch.Tensor): """ @@ -195,14 +190,7 @@ def forward(self, x: torch.Tensor): fs2 = self.refine7(F.interpolate(fs1, fm2.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm3) fs3 = self.refine8(F.interpolate(fs2, pool_x.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm4) fs4 = self.refine9(F.interpolate(fs3, conv_x.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm5) - return self.refine10( - F.interpolate( - fs4, - org_input.size()[2:], - mode=self.upsample_mode, - align_corners=True, - ) - ) + return self.refine10(F.interpolate(fs4, org_input.size()[2:], mode=self.upsample_mode, align_corners=True)) class MCFCN(FCN): @@ -231,12 +219,12 @@ def __init__( pretrained: bool = True, progress: bool = True, ): - super(MCFCN, self).__init__( + super().__init__( out_channels=out_channels, upsample_mode=upsample_mode, pretrained=pretrained, progress=progress ) self.init_proj = Convolution( - dimensions=2, + spatial_dims=2, in_channels=in_channels, out_channels=3, kernel_size=1, @@ -251,4 +239,4 @@ def forward(self, x: torch.Tensor): x: in shape (batch, in_channels, spatial_1, spatial_2). """ x = self.init_proj(x) - return super(MCFCN, self).forward(x) + return super().forward(x) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 3997d42436..41b76c7d4c 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,7 +29,7 @@ def get_conv_block( norm: Optional[Union[Tuple, str]] = "BATCH", ) -> nn.Module: padding = same_padding(kernel_size) - return Convolution( + mod: nn.Module = Convolution( spatial_dims, in_channels, out_channels, @@ -40,33 +40,22 @@ def get_conv_block( conv_only=False, padding=padding, ) + return mod def get_conv_layer( - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, + spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] = 3 ) -> nn.Module: padding = same_padding(kernel_size) - return Convolution( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - bias=False, - conv_only=True, - padding=padding, + mod: nn.Module = Convolution( + spatial_dims, in_channels, out_channels, kernel_size=kernel_size, bias=False, conv_only=True, padding=padding ) + return mod -def get_deconv_block( - spatial_dims: int, - in_channels: int, - out_channels: int, -) -> nn.Module: - return Convolution( - dimensions=spatial_dims, +def get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) -> nn.Module: + mod: nn.Module = Convolution( + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, strides=2, @@ -77,26 +66,20 @@ def get_deconv_block( padding=1, output_padding=1, ) + return mod class ResidualBlock(nn.Module): def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], + self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] ) -> None: - super(ResidualBlock, self).__init__() + super().__init__() if in_channels != out_channels: raise ValueError( f"expecting in_channels == out_channels, " f"got in_channels={in_channels}, out_channels={out_channels}" ) self.conv_block = get_conv_block( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size ) self.conv = get_conv_layer( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size @@ -110,22 +93,13 @@ def forward(self, x) -> torch.Tensor: class LocalNetResidualBlock(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - ) -> None: - super(LocalNetResidualBlock, self).__init__() + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None: + super().__init__() if in_channels != out_channels: raise ValueError( f"expecting in_channels == out_channels, " f"got in_channels={in_channels}, out_channels={out_channels}" ) - self.conv_layer = get_conv_layer( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - ) + self.conv_layer = get_conv_layer(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels) self.norm = Norm[Norm.BATCH, spatial_dims](out_channels) self.relu = nn.ReLU() @@ -147,11 +121,7 @@ class LocalNetDownSampleBlock(nn.Module): """ def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], + self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] ) -> None: """ Args: @@ -162,16 +132,14 @@ def __init__( Raises: NotImplementedError: when ``kernel_size`` is even """ - super(LocalNetDownSampleBlock, self).__init__() + super().__init__() self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size ) self.residual_block = ResidualBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size ) - self.max_pool = Pool[Pool.MAX, spatial_dims]( - kernel_size=2, - ) + self.max_pool = Pool[Pool.MAX, spatial_dims](kernel_size=2) def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -208,12 +176,7 @@ class LocalNetUpSampleBlock(nn.Module): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - ) -> None: + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None: """ Args: spatial_dims: number of spatial dimensions. @@ -222,21 +185,13 @@ def __init__( Raises: ValueError: when ``in_channels != 2 * out_channels`` """ - super(LocalNetUpSampleBlock, self).__init__() + super().__init__() self.deconv_block = get_deconv_block( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - ) - self.conv_block = get_conv_block( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels ) + self.conv_block = get_conv_block(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels) self.residual_block = LocalNetResidualBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, + spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels ) if in_channels / out_channels != 2: raise ValueError( @@ -306,7 +261,7 @@ def __init__( act: activation type and arguments. Defaults to ReLU. kernel_initializer: kernel initializer. Defaults to None. """ - super(LocalNetFeatureExtractorBlock, self).__init__() + super().__init__() self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index 11b5e6fc15..a1728365cf 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,12 +18,7 @@ class MLPBlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ - def __init__( - self, - hidden_size: int, - mlp_dim: int, - dropout_rate: float = 0.0, - ) -> None: + def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) -> None: """ Args: hidden_size: dimension of hidden layer. diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index c1fcfa9af7..4c7263c6d5 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -62,7 +62,7 @@ def __init__( """ - super(PatchEmbeddingBlock, self).__init__() + super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") @@ -80,7 +80,7 @@ def __init__( if self.pos_embed == "perceptron" and m % p != 0: raise ValueError("patch_size should be divisible by img_size for perceptron.") self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)]) - self.patch_dim = in_channels * np.prod(patch_size) + self.patch_dim = int(in_channels * np.prod(patch_size)) self.patch_embeddings: nn.Module if self.pos_embed == "conv": @@ -94,11 +94,9 @@ def __init__( to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)" axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)} self.patch_embeddings = nn.Sequential( - Rearrange(f"{from_chars} -> {to_chars}", **axes_len), - nn.Linear(self.patch_dim, hidden_size), + Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size) ) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) - self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) self.dropout = nn.Dropout(dropout_rate) self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) self.apply(self._init_weights) diff --git a/monai/networks/blocks/regunet_block.py b/monai/networks/blocks/regunet_block.py index d2cd3518b9..78e2598b4b 100644 --- a/monai/networks/blocks/regunet_block.py +++ b/monai/networks/blocks/regunet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,7 +32,7 @@ def get_conv_block( ) -> nn.Module: if padding is None: padding = same_padding(kernel_size) - conv_block = Convolution( + conv_block: nn.Module = Convolution( spatial_dims, in_channels, out_channels, @@ -59,21 +59,13 @@ def get_conv_block( def get_conv_layer( - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, + spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] = 3 ) -> nn.Module: padding = same_padding(kernel_size) - return Convolution( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - bias=False, - conv_only=True, - padding=padding, + mod: nn.Module = Convolution( + spatial_dims, in_channels, out_channels, kernel_size=kernel_size, bias=False, conv_only=True, padding=padding ) + return mod class RegistrationResidualConvBlock(nn.Module): @@ -83,12 +75,7 @@ class RegistrationResidualConvBlock(nn.Module): """ def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - num_layers: int = 2, - kernel_size: int = 3, + self, spatial_dims: int, in_channels: int, out_channels: int, num_layers: int = 2, kernel_size: int = 3 ): """ @@ -99,7 +86,7 @@ def __init__( num_layers: number of layers inside the block kernel_size: kernel_size """ - super(RegistrationResidualConvBlock, self).__init__() + super().__init__() self.num_layers = num_layers self.layers = nn.ModuleList( [ @@ -145,19 +132,14 @@ class RegistrationDownSampleBlock(nn.Module): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__( - self, - spatial_dims: int, - channels: int, - pooling: bool, - ) -> None: + def __init__(self, spatial_dims: int, channels: int, pooling: bool) -> None: """ Args: spatial_dims: number of spatial dimensions. channels: channels pooling: use MaxPool if True, strided conv if False """ - super(RegistrationDownSampleBlock, self).__init__() + super().__init__() if pooling: self.layer = Pool[Pool.MAX, spatial_dims](kernel_size=2) else: @@ -188,13 +170,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out -def get_deconv_block( - spatial_dims: int, - in_channels: int, - out_channels: int, -) -> nn.Module: - return Convolution( - dimensions=spatial_dims, +def get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) -> nn.Module: + mod: nn.Module = Convolution( + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, strides=2, @@ -205,6 +183,7 @@ def get_deconv_block( padding=1, output_padding=1, ) + return mod class RegistrationExtractionBlock(nn.Module): @@ -233,7 +212,7 @@ def __init__( kernel_initializer: kernel initializer activation: kernel activation function """ - super(RegistrationExtractionBlock, self).__init__() + super().__init__() self.extract_levels = extract_levels self.max_level = max(extract_levels) self.layers = nn.ModuleList( @@ -261,10 +240,7 @@ def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor: Tensor of shape (batch, `out_channels`, size1, size2, size3), where (size1, size2, size3) = ``image_size`` """ feature_list = [ - F.interpolate( - layer(x[self.max_level - level]), - size=image_size, - ) + F.interpolate(layer(x[self.max_level - level]), size=image_size) for layer, level in zip(self.layers, self.extract_levels) ] out: torch.Tensor = torch.mean(torch.stack(feature_list, dim=0), dim=0) diff --git a/monai/networks/blocks/segresnet_block.py b/monai/networks/blocks/segresnet_block.py index d8f6d7b268..ded270ab52 100644 --- a/monai/networks/blocks/segresnet_block.py +++ b/monai/networks/blocks/segresnet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,8 +15,7 @@ from monai.networks.blocks.convolutions import Convolution from monai.networks.blocks.upsample import UpSample -from monai.networks.layers.factories import Act -from monai.networks.layers.utils import get_norm_layer +from monai.networks.layers.utils import get_act_layer, get_norm_layer from monai.utils import InterpolateMode, UpsampleMode @@ -25,13 +24,7 @@ def get_conv_layer( ): return Convolution( - spatial_dims, - in_channels, - out_channels, - strides=stride, - kernel_size=kernel_size, - bias=bias, - conv_only=True, + spatial_dims, in_channels, out_channels, strides=stride, kernel_size=kernel_size, bias=bias, conv_only=True ) @@ -39,7 +32,7 @@ def get_upsample_layer( spatial_dims: int, in_channels: int, upsample_mode: Union[UpsampleMode, str] = "nontrainable", scale_factor: int = 2 ): return UpSample( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, scale_factor=scale_factor, @@ -62,6 +55,7 @@ def __init__( in_channels: int, norm: Union[Tuple, str], kernel_size: int = 3, + act: Union[Tuple, str] = ("RELU", {"inplace": True}), ) -> None: """ Args: @@ -69,6 +63,7 @@ def __init__( in_channels: number of input channels. norm: feature normalization type and arguments. kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3. + act: activation type and arguments. Defaults to ``RELU``. """ super().__init__() @@ -78,7 +73,7 @@ def __init__( self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels) self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels) - self.relu = Act[Act.RELU](inplace=True) + self.act = get_act_layer(act) self.conv1 = get_conv_layer(spatial_dims, in_channels=in_channels, out_channels=in_channels) self.conv2 = get_conv_layer(spatial_dims, in_channels=in_channels, out_channels=in_channels) @@ -87,11 +82,11 @@ def forward(self, x): identity = x x = self.norm1(x) - x = self.relu(x) + x = self.act(x) x = self.conv1(x) x = self.norm2(x) - x = self.relu(x) + x = self.act(x) x = self.conv2(x) x += identity diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 9dc45cccc8..db92111d14 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,7 +14,7 @@ from monai.utils import optional_import -einops, _ = optional_import("einops") +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") class SABlock(nn.Module): @@ -23,12 +23,7 @@ class SABlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ - def __init__( - self, - hidden_size: int, - num_heads: int, - dropout_rate: float = 0.0, - ) -> None: + def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0) -> None: """ Args: hidden_size: dimension of hidden layer. @@ -37,7 +32,7 @@ def __init__( """ - super(SABlock, self).__init__() + super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") @@ -48,17 +43,20 @@ def __init__( self.num_heads = num_heads self.out_proj = nn.Linear(hidden_size, hidden_size) self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False) + self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) + self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.head_dim = hidden_size // num_heads - self.scale = self.head_dim ** -0.5 + self.scale = self.head_dim**-0.5 def forward(self, x): - q, k, v = einops.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads) + output = self.input_rearrange(self.qkv(x)) + q, k, v = output[0], output[1], output[2] att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - x = einops.rearrange(x, "b h l d -> b l (h d)") + x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) return x diff --git a/monai/networks/blocks/squeeze_and_excitation.py b/monai/networks/blocks/squeeze_and_excitation.py index 4db6dc30f7..a9ac57aa4f 100644 --- a/monai/networks/blocks/squeeze_and_excitation.py +++ b/monai/networks/blocks/squeeze_and_excitation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -50,7 +50,7 @@ def __init__( :py:class:`monai.networks.layers.Act` """ - super(ChannelSELayer, self).__init__() + super().__init__() self.add_residual = add_residual @@ -181,21 +181,21 @@ def __init__( :py:class:`monai.networks.blocks.ChannelSELayer` """ - super(SEBlock, self).__init__() + super().__init__() if not conv_param_1: conv_param_1 = {"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})} self.conv1 = Convolution( - dimensions=spatial_dims, in_channels=in_channels, out_channels=n_chns_1, **conv_param_1 + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=n_chns_1, **conv_param_1 ) if not conv_param_2: conv_param_2 = {"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})} - self.conv2 = Convolution(dimensions=spatial_dims, in_channels=n_chns_1, out_channels=n_chns_2, **conv_param_2) + self.conv2 = Convolution(spatial_dims=spatial_dims, in_channels=n_chns_1, out_channels=n_chns_2, **conv_param_2) if not conv_param_3: conv_param_3 = {"kernel_size": 1, "norm": Norm.BATCH, "act": None} - self.conv3 = Convolution(dimensions=spatial_dims, in_channels=n_chns_2, out_channels=n_chns_3, **conv_param_3) + self.conv3 = Convolution(spatial_dims=spatial_dims, in_channels=n_chns_2, out_channels=n_chns_3, **conv_param_3) self.se_layer = ChannelSELayer( spatial_dims=spatial_dims, in_channels=n_chns_3, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2 @@ -264,7 +264,7 @@ def __init__( } conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False} - super(SEBottleneck, self).__init__( + super().__init__( spatial_dims=spatial_dims, in_channels=inplanes, n_chns_1=planes * 2, @@ -315,7 +315,7 @@ def __init__( } conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False} - super(SEResNetBottleneck, self).__init__( + super().__init__( spatial_dims=spatial_dims, in_channels=inplanes, n_chns_1=planes, diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index c7a948ed76..616d84e067 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,13 +21,7 @@ class TransformerBlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ - def __init__( - self, - hidden_size: int, - mlp_dim: int, - num_heads: int, - dropout_rate: float = 0.0, - ) -> None: + def __init__(self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0) -> None: """ Args: hidden_size: dimension of hidden layer. diff --git a/monai/networks/blocks/unetr_block.py b/monai/networks/blocks/unetr_block.py index a0852d05e0..a9d871a644 100644 --- a/monai/networks/blocks/unetr_block.py +++ b/monai/networks/blocks/unetr_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -46,7 +46,7 @@ def __init__( """ - super(UnetrUpBlock, self).__init__() + super().__init__() upsample_stride = upsample_kernel_size self.transp_conv = get_conv_layer( spatial_dims, diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index 5320611ce6..fa3929df20 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from monai.networks.layers.factories import Conv, Pad, Pool from monai.networks.utils import icnr_init, pixelshuffle -from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option +from monai.utils import InterpolateMode, UpsampleMode, deprecated_arg, ensure_tuple_rep, look_up_option __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"] @@ -34,9 +34,12 @@ class UpSample(nn.Sequential): (often used to map the number of features from `in_channels` to `out_channels`). """ + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: Optional[int] = None, out_channels: Optional[int] = None, scale_factor: Union[Sequence[float], float] = 2, @@ -47,10 +50,11 @@ def __init__( align_corners: Optional[bool] = True, bias: bool = True, apply_pad_pool: bool = True, + dimensions: Optional[int] = None, ) -> None: """ Args: - dimensions: number of spatial dimensions of the input image. + spatial_dims: number of spatial dimensions of the input image. in_channels: number of channels of the input image. out_channels: number of channels of the output image. Defaults to `in_channels`. scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2. @@ -68,23 +72,28 @@ def __init__( If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation. This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively. The interpolation mode. Defaults to ``"linear"``. - See also: https://pytorch.org/docs/stable/nn.html#upsample + See also: https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True. Only used in the "nontrainable" mode. bias: whether to have a bias term in the default preconv and deconv layers. Defaults to True. apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the size of `scale_factor` with a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`. Only used in the "pixelshuffle" mode. + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. """ super().__init__() - scale_factor_ = ensure_tuple_rep(scale_factor, dimensions) + if dimensions is not None: + spatial_dims = dimensions + scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims) up_mode = look_up_option(mode, UpsampleMode) if up_mode == UpsampleMode.DECONV: if not in_channels: raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.") self.add_module( "deconv", - Conv[Conv.CONVTRANS, dimensions]( + Conv[Conv.CONVTRANS, spatial_dims]( in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=scale_factor_, @@ -98,7 +107,7 @@ def __init__( raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.") self.add_module( "preconv", - Conv[Conv.CONV, dimensions]( + Conv[Conv.CONV, spatial_dims]( in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias ), ) @@ -112,7 +121,7 @@ def __init__( interp_mode = InterpolateMode(interp_mode) linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] if interp_mode in linear_mode: # choose mode based on dimensions - interp_mode = linear_mode[dimensions - 1] + interp_mode = linear_mode[spatial_dims - 1] self.add_module( "upsample_non_trainable", nn.Upsample( @@ -126,7 +135,7 @@ def __init__( self.add_module( "pixelshuffle", SubpixelUpsample( - dimensions=dimensions, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, scale_factor=scale_factor_[0], # isotropic @@ -164,19 +173,23 @@ class SubpixelUpsample(nn.Module): """ + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: Optional[int], out_channels: Optional[int] = None, scale_factor: int = 2, conv_block: Optional[Union[nn.Module, str]] = "default", apply_pad_pool: bool = True, bias: bool = True, + dimensions: Optional[int] = None, ) -> None: """ Args: - dimensions: number of spatial dimensions of the input image. + spatial_dims: number of spatial dimensions of the input image. in_channels: number of channels of the input image. out_channels: optional number of channels of the output image. scale_factor: multiplier for spatial size. Defaults to 2. @@ -190,21 +203,24 @@ def __init__( size of `scale_factor` with a stride of 1. This implements the nearest neighbour resize convolution component of subpixel convolutions described in Aitken et al. bias: whether to have a bias term in the default conv_block. Defaults to True. + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. """ super().__init__() if scale_factor <= 0: raise ValueError(f"The `scale_factor` multiplier must be an integer greater than 0, got {scale_factor}.") - self.dimensions = dimensions + self.dimensions = spatial_dims if dimensions is None else dimensions self.scale_factor = scale_factor if conv_block == "default": out_channels = out_channels or in_channels if not out_channels: raise ValueError("in_channels need to be specified.") - conv_out_channels = out_channels * (scale_factor ** dimensions) - self.conv_block = Conv[Conv.CONV, dimensions]( + conv_out_channels = out_channels * (scale_factor**self.dimensions) + self.conv_block = Conv[Conv.CONV, self.dimensions]( in_channels=in_channels, out_channels=conv_out_channels, kernel_size=3, stride=1, padding=1, bias=bias ) @@ -231,7 +247,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...). """ x = self.conv_block(x) - if x.shape[1] % (self.scale_factor ** self.dimensions) != 0: + if x.shape[1] % (self.scale_factor**self.dimensions) != 0: raise ValueError( f"Number of channels after `conv_block` ({x.shape[1]}) must be evenly " "divisible by scale_factor ** dimensions " diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index d916c026ff..5b925258b6 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,6 +18,7 @@ from monai.config.deviceconfig import USE_COMPILED from monai.networks.layers.spatial_transforms import grid_pull +from monai.networks.utils import meshgrid_ij from monai.utils import GridSampleMode, GridSamplePadMode, optional_import _C, _ = optional_import("monai._C") @@ -30,18 +31,14 @@ class Warp(nn.Module): Warp an image with given dense displacement field (DDF). """ - def __init__( - self, - mode=GridSampleMode.BILINEAR.value, - padding_mode=GridSamplePadMode.BORDER.value, - ): + def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value): """ For pytorch native APIs, the possible values are: - mode: ``"nearest"``, ``"bilinear"``, ``"bicubic"``. - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"`` - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html For MONAI C++/CUDA extensions, the possible values are: @@ -50,7 +47,7 @@ def __init__( See also: :py:class:`monai.networks.layers.grid_pull` """ - super(Warp, self).__init__() + super().__init__() # resolves _interp_mode for different methods if USE_COMPILED: @@ -88,7 +85,7 @@ def __init__( @staticmethod def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor: mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]] - grid = torch.stack(torch.meshgrid(*mesh_points), dim=0) # (spatial_dims, ...) + grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...) grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...) grid = grid.to(ddf) return grid @@ -123,13 +120,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor): ) # using csrc resampling - return grid_pull( - image, - grid, - bound=self._padding_mode, - extrapolate=True, - interpolation=self._interp_mode, - ) + return grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode) class DVF2DDF(nn.Module): @@ -143,12 +134,9 @@ class DVF2DDF(nn.Module): """ def __init__( - self, - num_steps: int = 7, - mode=GridSampleMode.BILINEAR.value, - padding_mode=GridSamplePadMode.ZEROS.value, + self, num_steps: int = 7, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.ZEROS.value ): - super(DVF2DDF, self).__init__() + super().__init__() if num_steps <= 0: raise ValueError(f"expecting positive num_steps, got {num_steps}") self.num_steps = num_steps @@ -162,7 +150,7 @@ def forward(self, dvf): Returns: a dense displacement field """ - ddf: torch.Tensor = dvf / (2 ** self.num_steps) + ddf: torch.Tensor = dvf / (2**self.num_steps) for _ in range(self.num_steps): ddf = ddf + self.warp_layer(image=ddf, ddf=ddf) return ddf diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index b2defc703d..5115c00af3 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,6 +22,7 @@ Reshape, SavitzkyGolayFilter, SkipConnection, + apply_filter, separable_filtering, ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push diff --git a/monai/networks/layers/convutils.py b/monai/networks/layers/convutils.py index 994ca05b85..1e9ce954e8 100644 --- a/monai/networks/layers/convutils.py +++ b/monai/networks/layers/convutils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -44,8 +44,7 @@ def same_padding( def stride_minus_kernel_padding( - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] ) -> Union[Tuple[int, ...], int]: kernel_size_np = np.atleast_1d(kernel_size) stride_np = np.atleast_1d(stride) @@ -116,7 +115,7 @@ def gaussian_1d( out = out.clamp(min=0) elif approx.lower() == "sampled": x = torch.arange(-tail, tail + 1, dtype=torch.float, device=sigma.device) - out = torch.exp(-0.5 / (sigma * sigma) * x ** 2) + out = torch.exp(-0.5 / (sigma * sigma) * x**2) if not normalize: # compute the normalizer out = out / (2.5066282 * sigma) elif approx.lower() == "scalespace": diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index d4de08fc50..6379f49449 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 3b2214d59a..bbf925eba9 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/layers/gmm.py b/monai/networks/layers/gmm.py index 3091f95458..eb9a3f91e4 100644 --- a/monai/networks/layers/gmm.py +++ b/monai/networks/layers/gmm.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 52f19aab29..7a0a45cb64 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. import math +from copy import deepcopy from typing import List, Sequence, Union import torch @@ -20,29 +21,30 @@ from monai.networks.layers.convutils import gaussian_1d from monai.networks.layers.factories import Conv from monai.utils import ( - PT_BEFORE_1_7, ChannelMatching, InvalidPyTorchVersionError, SkipMode, - ensure_tuple_rep, look_up_option, optional_import, + pytorch_after, ) +from monai.utils.misc import issequenceiterable _C, _ = optional_import("monai._C") -if not PT_BEFORE_1_7: +if pytorch_after(1, 7): fft, _ = optional_import("torch.fft") __all__ = [ - "SkipConnection", + "ChannelPad", "Flatten", "GaussianFilter", + "HilbertTransform", "LLTM", "Reshape", - "separable_filtering", "SavitzkyGolayFilter", - "HilbertTransform", - "ChannelPad", + "SkipConnection", + "apply_filter", + "separable_filtering", ] @@ -210,25 +212,97 @@ def separable_filtering(x: torch.Tensor, kernels: List[torch.Tensor], mode: str Args: x: the input image. must have shape (batch, channels, H[, W, ...]). kernels: kernel along each spatial dimension. - could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels. + could be a single kernel (duplicated for all spatial dimensions), or + a list of `spatial_dims` number of kernels. mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` - or ``'circular'``. Default: ``'zeros'``. Modes other than ``'zeros'`` require PyTorch version >= 1.5.1. See - torch.nn.Conv1d() for more information. + or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information. Raises: TypeError: When ``x`` is not a ``torch.Tensor``. + + Examples: + + .. code-block:: python + + >>> import torch + >>> from monai.networks.layers import separable_filtering + >>> img = torch.randn(2, 4, 32, 32) # batch_size 2, channels 4, 32x32 2D images + # applying a [-1, 0, 1] filter along each of the spatial dimensions. + # the output shape is the same as the input shape. + >>> out = separable_filtering(img, torch.tensor((-1., 0., 1.))) + # applying `[-1, 0, 1]`, `[1, 0, -1]` filters along two spatial dimensions respectively. + # the output shape is the same as the input shape. + >>> out = separable_filtering(img, [torch.tensor((-1., 0., 1.)), torch.tensor((1., 0., -1.))]) + """ if not isinstance(x, torch.Tensor): raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.") spatial_dims = len(x.shape) - 2 - _kernels = [s.float() for s in kernels] + if isinstance(kernels, torch.Tensor): + kernels = [kernels] * spatial_dims + _kernels = [s.to(x) for s in kernels] _paddings = [(k.shape[0] - 1) // 2 for k in _kernels] n_chs = x.shape[1] pad_mode = "constant" if mode == "zeros" else mode - return _separable_filtering_conv(x, kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs) + return _separable_filtering_conv(x, _kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs) + + +def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Filtering `x` with `kernel` independently for each batch and channel respectively. + + Args: + x: the input image, must have shape (batch, channels, H[, W, D]). + kernel: `kernel` must at least have the spatial shape (H_k[, W_k, D_k]). + `kernel` shape must be broadcastable to the `batch` and `channels` dimensions of `x`. + kwargs: keyword arguments passed to `conv*d()` functions. + + Returns: + The filtered `x`. + + Examples: + + .. code-block:: python + + >>> import torch + >>> from monai.networks.layers import apply_filter + >>> img = torch.rand(2, 5, 10, 10) # batch_size 2, channels 5, 10x10 2D images + >>> out = apply_filter(img, torch.rand(3, 3)) # spatial kernel + >>> out = apply_filter(img, torch.rand(5, 3, 3)) # channel-wise kernels + >>> out = apply_filter(img, torch.rand(2, 5, 3, 3)) # batch-, channel-wise kernels + + """ + if not isinstance(x, torch.Tensor): + raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.") + batch, chns, *spatials = x.shape + n_spatial = len(spatials) + if n_spatial > 3: + raise NotImplementedError(f"Only spatial dimensions up to 3 are supported but got {n_spatial}.") + k_size = len(kernel.shape) + if k_size < n_spatial or k_size > n_spatial + 2: + raise ValueError( + f"kernel must have {n_spatial} ~ {n_spatial + 2} dimensions to match the input shape {x.shape}." + ) + kernel = kernel.to(x) + # broadcast kernel size to (batch chns, spatial_kernel_size) + kernel = kernel.expand(batch, chns, *kernel.shape[(k_size - n_spatial) :]) + kernel = kernel.reshape(-1, 1, *kernel.shape[2:]) # group=1 + x = x.view(1, kernel.shape[0], *spatials) + conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1] + if "padding" not in kwargs: + if pytorch_after(1, 10): + kwargs["padding"] = "same" + else: + # even-sized kernels are not supported + kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]] + + if "stride" not in kwargs: + kwargs["stride"] = 1 + output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs) + return output.view(batch, chns, *output.shape[2:]) class SavitzkyGolayFilter(nn.Module): @@ -270,7 +344,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.to(dtype=torch.float) if (self.axis < 0) or (self.axis > len(x.shape) - 1): - raise ValueError("Invalid axis for shape of x.") + raise ValueError(f"Invalid axis for shape of x, got axis {self.axis} and shape {x.shape}.") # Create list of filter kernels (1 per spatial dimension). The kernel for self.axis will be the savgol coeffs, # while the other kernels will be set to [1]. @@ -307,12 +381,12 @@ class HilbertTransform(nn.Module): Args: axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension). - N: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``. + n: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``. """ def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None: - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) super().__init__() @@ -335,7 +409,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.to(dtype=torch.float) if (self.axis < 0) or (self.axis > len(x.shape) - 1): - raise ValueError("Invalid axis for shape of x.") + raise ValueError(f"Invalid axis for shape of x, got axis {self.axis} and shape {x.shape}.") n = x.shape[self.axis] if self.n is None else self.n if n <= 0: @@ -393,13 +467,18 @@ def __init__( (for example `parameters()` iterator could be used to get the parameters); otherwise this module will fix the kernels using `sigma` as the std. """ + if issequenceiterable(sigma): + if len(sigma) != spatial_dims: # type: ignore + raise ValueError + else: + sigma = [deepcopy(sigma) for _ in range(spatial_dims)] # type: ignore super().__init__() self.sigma = [ torch.nn.Parameter( torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None), requires_grad=requires_grad, ) - for s in ensure_tuple_rep(sigma, int(spatial_dims)) + for s in sigma # type: ignore ] self.truncated = truncated self.approx = approx @@ -449,7 +528,7 @@ class LLTM(nn.Module): """ def __init__(self, input_features: int, state_size: int): - super(LLTM, self).__init__() + super().__init__() self.input_features = input_features self.state_size = state_size self.weights = nn.Parameter(torch.empty(3 * state_size, input_features + state_size)) diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 511c24fcb0..c1bb951c4d 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -46,7 +46,9 @@ def backward(ctx, grad): return None, grads[0], None, None, None -def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True): +def grid_pull( + input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True +) -> torch.Tensor: """ Sample an image with respect to a deformation field. @@ -68,13 +70,13 @@ def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b `bound` can be an int, a string or a BoundType. Possible values are:: - - 0 or 'replicate' or 'nearest' or BoundType.replicate + - 0 or 'replicate' or 'nearest' or BoundType.replicate or 'border' - 1 or 'dct1' or 'mirror' or BoundType.dct1 - 2 or 'dct2' or 'reflect' or BoundType.dct2 - 3 or 'dst1' or 'antimirror' or BoundType.dst1 - 4 or 'dst2' or 'antireflect' or BoundType.dst2 - 5 or 'dft' or 'wrap' or BoundType.dft - - 7 or 'zero' or BoundType.zero + - 7 or 'zero' or 'zeros' or BoundType.zero A list of values can be provided, in the order [W, H, D], to specify dimension-specific boundary conditions. @@ -112,8 +114,9 @@ def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) for i in ensure_tuple(interpolation) ] - - return _GridPull.apply(input, grid, interpolation, bound, extrapolate) + out: torch.Tensor + out = _GridPull.apply(input, grid, interpolation, bound, extrapolate) + return out class _GridPush(torch.autograd.Function): @@ -443,11 +446,11 @@ def __init__( coordinates. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"zeros"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - align_corners: see also https://pytorch.org/docs/stable/nn.functional.html#grid-sample. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: see also https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html. reverse_indexing: whether to reverse the spatial indexing of image and coordinates. set to `False` if `theta` follows pytorch's default "D, H, W" convention. set to `True` if `theta` follows `scipy.ndimage` default "i, j, k" convention. diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index 380a77552c..42fac58716 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index ad1ca2418b..16686fa25c 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .ahnet import AHnet, Ahnet, AHNet, ahnet +from .ahnet import AHnet, Ahnet, AHNet +from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .classifier import Classifier, Critic, Discriminator @@ -24,13 +25,13 @@ Densenet201, DenseNet264, Densenet264, - densenet, densenet121, densenet169, densenet201, densenet264, ) -from .dynunet import DynUNet, DynUnet, Dynunet, dynunet +from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch +from .dynunet import DynUNet, DynUnet, Dynunet from .efficientnet import ( BlockArgs, EfficientNet, @@ -42,6 +43,7 @@ from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet from .generator import Generator from .highresnet import HighResBlock, HighResNet +from .milmodel import MILModel from .netadapter import NetAdapter from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet @@ -71,7 +73,6 @@ SEResNeXt101, SEresnext101, Seresnext101, - senet, senet154, seresnet50, seresnet101, @@ -80,8 +81,10 @@ seresnext101, ) from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel -from .unet import UNet, Unet, unet +from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex +from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder from .vit import ViT +from .vitautoenc import ViTAutoEnc from .vnet import VNet diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 5ca6813efe..b481374aa1 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,7 +19,7 @@ from monai.networks.blocks.fcn import FCN from monai.networks.layers.factories import Act, Conv, Norm, Pool -__all__ = ["AHnet", "Ahnet", "ahnet", "AHNet"] +__all__ = ["AHnet", "Ahnet", "AHNet"] class Bottleneck3x3x1(nn.Module): @@ -35,7 +35,7 @@ def __init__( downsample: Optional[nn.Sequential] = None, ) -> None: - super(Bottleneck3x3x1, self).__init__() + super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -87,7 +87,7 @@ def forward(self, x): class Projection(nn.Sequential): def __init__(self, spatial_dims: int, num_input_features: int, num_output_features: int): - super(Projection, self).__init__() + super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -108,7 +108,7 @@ def __init__( growth_rate: int, dropout_prob: float, ): - super(DenseBlock, self).__init__() + super().__init__() for i in range(num_layers): layer = Pseudo3DLayer( spatial_dims, num_input_features + i * growth_rate, growth_rate, bn_size, dropout_prob @@ -120,7 +120,7 @@ class UpTransition(nn.Sequential): def __init__( self, spatial_dims: int, num_input_features: int, num_output_features: int, upsample_mode: str = "transpose" ): - super(UpTransition, self).__init__() + super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -145,7 +145,7 @@ class Final(nn.Sequential): def __init__( self, spatial_dims: int, num_input_features: int, num_output_features: int, upsample_mode: str = "transpose" ): - super(Final, self).__init__() + super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -178,7 +178,7 @@ def __init__( class Pseudo3DLayer(nn.Module): def __init__(self, spatial_dims: int, num_input_features: int, growth_rate: int, bn_size: int, dropout_prob: float): - super(Pseudo3DLayer, self).__init__() + super().__init__() # 1x1x1 conv_type = Conv[Conv.CONV, spatial_dims] @@ -244,7 +244,7 @@ def forward(self, x): class PSP(nn.Module): def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_mode: str = "transpose"): - super(PSP, self).__init__() + super().__init__() self.up_modules = nn.ModuleList() conv_type = Conv[Conv.CONV, spatial_dims] pool_type: Type[Union[nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] @@ -256,13 +256,7 @@ def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_m size = (2 ** (i + 3), 2 ** (i + 3), 1)[-spatial_dims:] self.pool_modules.append(pool_type(kernel_size=size, stride=size)) self.project_modules.append( - conv_type( - in_ch, - 1, - kernel_size=(1, 1, 1)[-spatial_dims:], - stride=1, - padding=(1, 1, 0)[-spatial_dims:], - ) + conv_type(in_ch, 1, kernel_size=(1, 1, 1)[-spatial_dims:], stride=1, padding=(1, 1, 0)[-spatial_dims:]) ) self.spatial_dims = spatial_dims @@ -274,15 +268,7 @@ def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_m for i in range(psp_block_num): size = (2 ** (i + 3), 2 ** (i + 3), 1)[-spatial_dims:] pad_size = (2 ** (i + 3), 2 ** (i + 3), 0)[-spatial_dims:] - self.up_modules.append( - conv_trans_type( - 1, - 1, - kernel_size=size, - stride=size, - padding=pad_size, - ) - ) + self.up_modules.append(conv_trans_type(1, 1, kernel_size=size, stride=size, padding=pad_size)) def forward(self, x): outputs = [] @@ -356,7 +342,7 @@ def __init__( progress: bool = True, ): self.inplanes = 64 - super(AHNet, self).__init__() + super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] conv_trans_type = Conv[Conv.CONVTRANS, spatial_dims] @@ -451,13 +437,7 @@ def __init__( net2d = FCN(pretrained=True, progress=progress) self.copy_from(net2d) - def _make_layer( - self, - block: Type[Bottleneck3x3x1], - planes: int, - blocks: int, - stride: int = 1, - ) -> nn.Sequential: + def _make_layer(self, block: Type[Bottleneck3x3x1], planes: int, blocks: int, stride: int = 1) -> nn.Sequential: downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( @@ -559,4 +539,4 @@ def copy_bn_param(module2d, module3d): p3d.data[:] = p2d.data[:] # Two parameter gamma and beta -AHnet = Ahnet = ahnet = AHNet +AHnet = Ahnet = AHNet diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py new file mode 100644 index 0000000000..177a54e105 --- /dev/null +++ b/monai/networks/nets/attentionunet.py @@ -0,0 +1,257 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Union + +import torch +import torch.nn as nn + +from monai.networks.blocks.convolutions import Convolution +from monai.networks.layers.factories import Norm + +__all__ = ["AttentionUnet"] + + +class ConvBlock(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + strides: int = 1, + dropout=0.0, + ): + super().__init__() + layers = [ + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + strides=strides, + padding=None, + adn_ordering="NDA", + act="relu", + norm=Norm.BATCH, + dropout=dropout, + ), + Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + strides=1, + padding=None, + adn_ordering="NDA", + act="relu", + norm=Norm.BATCH, + dropout=dropout, + ), + ] + self.conv = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_c: torch.Tensor = self.conv(x) + return x_c + + +class UpConv(nn.Module): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size=3, strides=2, dropout=0.0): + super().__init__() + self.up = Convolution( + spatial_dims, + in_channels, + out_channels, + strides=strides, + kernel_size=kernel_size, + act="relu", + adn_ordering="NDA", + norm=Norm.BATCH, + dropout=dropout, + is_transposed=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_u: torch.Tensor = self.up(x) + return x_u + + +class AttentionBlock(nn.Module): + def __init__(self, spatial_dims: int, f_int: int, f_g: int, f_l: int, dropout=0.0): + super().__init__() + self.W_g = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=f_g, + out_channels=f_int, + kernel_size=1, + strides=1, + padding=0, + dropout=dropout, + conv_only=True, + ), + Norm[Norm.BATCH, spatial_dims](f_int), + ) + + self.W_x = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=f_l, + out_channels=f_int, + kernel_size=1, + strides=1, + padding=0, + dropout=dropout, + conv_only=True, + ), + Norm[Norm.BATCH, spatial_dims](f_int), + ) + + self.psi = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=f_int, + out_channels=1, + kernel_size=1, + strides=1, + padding=0, + dropout=dropout, + conv_only=True, + ), + Norm[Norm.BATCH, spatial_dims](1), + nn.Sigmoid(), + ) + + self.relu = nn.ReLU() + + def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + g1 = self.W_g(g) + x1 = self.W_x(x) + psi: torch.Tensor = self.relu(g1 + x1) + psi = self.psi(psi) + + return x * psi + + +class AttentionLayer(nn.Module): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0): + super().__init__() + self.attention = AttentionBlock( + spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2 + ) + self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2) + self.merge = Convolution( + spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout + ) + self.submodule = submodule + + def forward(self, x: torch.Tensor) -> torch.Tensor: + fromlower = self.upconv(self.submodule(x)) + att = self.attention(g=fromlower, x=x) + att_m: torch.Tensor = self.merge(torch.cat((att, fromlower), dim=1)) + return att_m + + +class AttentionUnet(nn.Module): + """ + Attention Unet based on + Otkay et al. "Attention U-Net: Learning Where to Look for the Pancreas" + https://arxiv.org/abs/1804.03999 + + Args: + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of the input channel. + out_channels: number of the output classes. + channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2. + strides (Sequence[int]): stride to use for convolutions. + kernel_size: convolution kernel size. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + dropout: dropout ratio. Defaults to no dropout. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + strides: Sequence[int], + kernel_size: Union[Sequence[int], int] = 3, + up_kernel_size: Union[Sequence[int], int] = 3, + dropout: float = 0.0, + ): + super().__init__() + self.dimensions = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.strides = strides + self.kernel_size = kernel_size + self.dropout = dropout + + head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout) + reduce_channels = Convolution( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=out_channels, + kernel_size=1, + strides=1, + padding=0, + conv_only=True, + ) + self.up_kernel_size = up_kernel_size + + def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module: + if len(channels) > 2: + subblock = _create_block(channels[1:], strides[1:], level=level + 1) + return AttentionLayer( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=channels[1], + submodule=nn.Sequential( + ConvBlock( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=channels[1], + strides=strides[0], + dropout=self.dropout, + ), + subblock, + ), + dropout=dropout, + ) + else: + # the next layer is the bottom so stop recursion, + # create the bottom layer as the sublock for this layer + return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1) + + encdec = _create_block(self.channels, self.strides) + self.model = nn.Sequential(head, encdec, reduce_channels) + + def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module: + return AttentionLayer( + spatial_dims=self.dimensions, + in_channels=in_channels, + out_channels=out_channels, + submodule=ConvBlock( + spatial_dims=self.dimensions, + in_channels=in_channels, + out_channels=out_channels, + strides=strides, + dropout=self.dropout, + ), + dropout=self.dropout, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_m: torch.Tensor = self.model(x) + return x_m diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py index d0a54b8148..75edde70eb 100644 --- a/monai/networks/nets/autoencoder.py +++ b/monai/networks/nets/autoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,14 +16,84 @@ from monai.networks.blocks import Convolution, ResidualUnit from monai.networks.layers.factories import Act, Norm +from monai.utils import deprecated_arg __all__ = ["AutoEncoder"] class AutoEncoder(nn.Module): + """ + Simple definition of an autoencoder and base class for the architecture implementing + :py:class:`monai.networks.nets.VarAutoEncoder`. The network is composed of an encode sequence of blocks, followed + by an intermediary sequence of blocks, and finally a decode sequence of blocks. The encode and decode blocks are + default :py:class:`monai.networks.blocks.Convolution` instances with the encode blocks having the given stride + and the decode blocks having transpose convolutions with the same stride. If `num_res_units` is given residual + blocks are used instead. + + By default the intermediary sequence is empty but if `inter_channels` is given to specify the output channels of + blocks then this will be become a sequence of Convolution blocks or of residual blocks if `num_inter_units` is + given. The optional parameter `inter_dilations` can be used to specify the dilation values of the convolutions in + these blocks, this allows a network to use dilated kernels in this middle section. Since the intermediary section + isn't meant to change the size of the output the strides for all these kernels is 1. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + channels: sequence of channels. Top block first. The length of `channels` should be no less than 2. + strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`. + kernel_size: convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + num_res_units: number of residual units. Defaults to 0. + inter_channels: sequence of channels defining the blocks in the intermediate layer between encode and decode. + inter_dilations: defines the dilation value for each block of the intermediate layer. Defaults to 1. + num_inter_units: number of residual units for each block of the intermediate layer. Defaults to 0. + act: activation type and arguments. Defaults to PReLU. + norm: feature normalization type and arguments. Defaults to instance norm. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term in convolution blocks. Defaults to True. + According to `Performance Tuning Guide `_, + if a conv layer is directly followed by a batch norm layer, bias should be False. + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + + Examples:: + + from monai.networks.nets import AutoEncoder + + # 3 layers each down/up sampling their inputs by a factor 2 with no intermediate layer + net = AutoEncoder( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(2, 4, 8), + strides=(2, 2, 2) + ) + + # 1 layer downsampling by 2, followed by a sequence of residual units with 2 convolutions defined by + # progressively increasing dilations, then final upsample layer + net = AutoEncoder( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4,), + strides=(2,), + inter_channels=(8, 8, 8), + inter_dilations=(1, 2, 4), + num_inter_units=2 + ) + + """ + + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int], @@ -38,10 +108,11 @@ def __init__( norm: Union[Tuple, str] = Norm.INSTANCE, dropout: Optional[Union[Tuple, str, float]] = None, bias: bool = True, + dimensions: Optional[int] = None, ) -> None: super().__init__() - self.dimensions = dimensions + self.dimensions = spatial_dims if dimensions is None else dimensions self.in_channels = in_channels self.out_channels = out_channels self.channels = list(channels) @@ -71,6 +142,9 @@ def __init__( def _get_encode_module( self, in_channels: int, channels: Sequence[int], strides: Sequence[int] ) -> Tuple[nn.Sequential, int]: + """ + Returns the encode part of the network by building up a sequence of layers returned by `_get_encode_layer`. + """ encode = nn.Sequential() layer_channels = in_channels @@ -82,6 +156,10 @@ def _get_encode_module( return encode, layer_channels def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tuple[nn.Module, int]: + """ + Returns the intermediate block of the network which accepts input from the encoder and whose output goes + to the decoder. + """ # Define some types intermediate: nn.Module unit: nn.Module @@ -95,7 +173,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tu for i, (dc, di) in enumerate(zip(self.inter_channels, self.inter_dilations)): if self.num_inter_units > 0: unit = ResidualUnit( - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=layer_channels, out_channels=dc, strides=1, @@ -109,7 +187,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tu ) else: unit = Convolution( - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=layer_channels, out_channels=dc, strides=1, @@ -129,6 +207,9 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tu def _get_decode_module( self, in_channels: int, channels: Sequence[int], strides: Sequence[int] ) -> Tuple[nn.Sequential, int]: + """ + Returns the decode part of the network by building up a sequence of layers returned by `_get_decode_layer`. + """ decode = nn.Sequential() layer_channels = in_channels @@ -140,10 +221,13 @@ def _get_decode_module( return decode, layer_channels def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Module: - + """ + Returns a single layer of the encoder part of the network. + """ + mod: nn.Module if self.num_res_units > 0: - return ResidualUnit( - dimensions=self.dimensions, + mod = ResidualUnit( + spatial_dims=self.dimensions, in_channels=in_channels, out_channels=out_channels, strides=strides, @@ -155,8 +239,8 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i bias=self.bias, last_conv_only=is_last, ) - return Convolution( - dimensions=self.dimensions, + mod = Convolution( + spatial_dims=self.dimensions, in_channels=in_channels, out_channels=out_channels, strides=strides, @@ -167,13 +251,16 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i bias=self.bias, conv_only=is_last, ) + return mod def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Sequential: - + """ + Returns a single layer of the decoder part of the network. + """ decode = nn.Sequential() conv = Convolution( - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=in_channels, out_channels=out_channels, strides=strides, @@ -190,7 +277,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i if self.num_res_units > 0: ru = ResidualUnit( - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=out_channels, out_channels=out_channels, strides=1, diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index 63205f45ee..1e46846576 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,27 +16,29 @@ from monai.networks.blocks import Convolution, UpSample from monai.networks.layers.factories import Conv, Pool -from monai.utils import ensure_tuple_rep +from monai.utils import deprecated_arg, ensure_tuple_rep -__all__ = ["BasicUNet", "BasicUnet", "Basicunet"] +__all__ = ["BasicUnet", "Basicunet", "basicunet", "BasicUNet"] class TwoConv(nn.Sequential): """two convolutions.""" + @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - dim: int, + spatial_dims: int, in_chns: int, out_chns: int, act: Union[str, tuple], norm: Union[str, tuple], bias: bool, dropout: Union[float, tuple] = 0.0, + dim: Optional[int] = None, ): """ Args: - dim: number of spatial dimensions. + spatial_dims: number of spatial dimensions. in_chns: number of input channels. out_chns: number of output channels. act: activation type and arguments. @@ -44,11 +46,17 @@ def __init__( bias: whether to have a bias term in convolution blocks. dropout: dropout ratio. Defaults to no dropout. + .. deprecated:: 0.6.0 + ``dim`` is deprecated, use ``spatial_dims`` instead. """ super().__init__() - conv_0 = Convolution(dim, in_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1) - conv_1 = Convolution(dim, out_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1) + if dim is not None: + spatial_dims = dim + conv_0 = Convolution(spatial_dims, in_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1) + conv_1 = Convolution( + spatial_dims, out_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1 + ) self.add_module("conv_0", conv_0) self.add_module("conv_1", conv_1) @@ -56,19 +64,21 @@ def __init__( class Down(nn.Sequential): """maxpooling downsampling and two convolutions.""" + @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - dim: int, + spatial_dims: int, in_chns: int, out_chns: int, act: Union[str, tuple], norm: Union[str, tuple], bias: bool, dropout: Union[float, tuple] = 0.0, + dim: Optional[int] = None, ): """ Args: - dim: number of spatial dimensions. + spatial_dims: number of spatial dimensions. in_chns: number of input channels. out_chns: number of output channels. act: activation type and arguments. @@ -76,11 +86,14 @@ def __init__( bias: whether to have a bias term in convolution blocks. dropout: dropout ratio. Defaults to no dropout. + .. deprecated:: 0.6.0 + ``dim`` is deprecated, use ``spatial_dims`` instead. """ super().__init__() - - max_pooling = Pool["MAX", dim](kernel_size=2) - convs = TwoConv(dim, in_chns, out_chns, act, norm, bias, dropout) + if dim is not None: + spatial_dims = dim + max_pooling = Pool["MAX", spatial_dims](kernel_size=2) + convs = TwoConv(spatial_dims, in_chns, out_chns, act, norm, bias, dropout) self.add_module("max_pooling", max_pooling) self.add_module("convs", convs) @@ -88,9 +101,10 @@ def __init__( class UpCat(nn.Module): """upsampling, concatenation with the encoder feature map, two convolutions""" + @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - dim: int, + spatial_dims: int, in_chns: int, cat_chns: int, out_chns: int, @@ -103,10 +117,11 @@ def __init__( interp_mode: str = "linear", align_corners: Optional[bool] = True, halves: bool = True, + dim: Optional[int] = None, ): """ Args: - dim: number of spatial dimensions. + spatial_dims: number of spatial dimensions. in_chns: number of input channels to be upsampled. cat_chns: number of channels from the decoder. out_chns: number of output channels. @@ -124,14 +139,19 @@ def __init__( Only used in the "nontrainable" mode. halves: whether to halve the number of channels during upsampling. This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`. + + .. deprecated:: 0.6.0 + ``dim`` is deprecated, use ``spatial_dims`` instead. """ super().__init__() + if dim is not None: + spatial_dims = dim if upsample == "nontrainable" and pre_conv is None: up_chns = in_chns else: up_chns = in_chns // 2 if halves else in_chns self.upsample = UpSample( - dim, + spatial_dims, in_chns, up_chns, 2, @@ -140,7 +160,7 @@ def __init__( interp_mode=interp_mode, align_corners=align_corners, ) - self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, bias, dropout) + self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout) def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]): """ @@ -167,9 +187,12 @@ def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]): class BasicUNet(nn.Module): + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int = 3, + spatial_dims: int = 3, in_channels: int = 1, out_channels: int = 2, features: Sequence[int] = (32, 32, 64, 128, 256, 32), @@ -178,6 +201,7 @@ def __init__( bias: bool = True, dropout: Union[float, tuple] = 0.0, upsample: str = "deconv", + dimensions: Optional[int] = None, ): """ A UNet implementation with 1D/2D/3D supports. @@ -189,7 +213,7 @@ def __init__( http://dx.doi.org/10.1038/s41592-018-0261-2 Args: - dimensions: number of spatial dimensions. Defaults to 3 for spatial 3D inputs. + spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs. in_channels: number of input channels. Defaults to 1. out_channels: number of output channels. Defaults to 2. features: six integers as numbers of features. @@ -207,16 +231,19 @@ def __init__( upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + Examples:: # for spatial 2D - >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128)) + >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128)) # for spatial 2D, with group norm - >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4})) + >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4})) # for spatial 3D - >>> net = BasicUNet(dimensions=3, features=(32, 32, 64, 128, 256, 32)) + >>> net = BasicUNet(spatial_dims=3, features=(32, 32, 64, 128, 256, 32)) See Also @@ -225,22 +252,24 @@ def __init__( """ super().__init__() + if dimensions is not None: + spatial_dims = dimensions fea = ensure_tuple_rep(features, 6) print(f"BasicUNet features: {fea}.") - self.conv_0 = TwoConv(dimensions, in_channels, features[0], act, norm, bias, dropout) - self.down_1 = Down(dimensions, fea[0], fea[1], act, norm, bias, dropout) - self.down_2 = Down(dimensions, fea[1], fea[2], act, norm, bias, dropout) - self.down_3 = Down(dimensions, fea[2], fea[3], act, norm, bias, dropout) - self.down_4 = Down(dimensions, fea[3], fea[4], act, norm, bias, dropout) + self.conv_0 = TwoConv(spatial_dims, in_channels, features[0], act, norm, bias, dropout) + self.down_1 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout) + self.down_2 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout) + self.down_3 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout) + self.down_4 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout) - self.upcat_4 = UpCat(dimensions, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample) - self.upcat_3 = UpCat(dimensions, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample) - self.upcat_2 = UpCat(dimensions, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample) - self.upcat_1 = UpCat(dimensions, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False) + self.upcat_4 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample) + self.upcat_3 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample) + self.upcat_2 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample) + self.upcat_1 = UpCat(spatial_dims, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False) - self.final_conv = Conv["conv", dimensions](fea[5], out_channels, kernel_size=1) + self.final_conv = Conv["conv", spatial_dims](fea[5], out_channels, kernel_size=1) def forward(self, x: torch.Tensor): """ diff --git a/monai/networks/nets/classifier.py b/monai/networks/nets/classifier.py index 92fee4f566..7f4e43eedb 100644 --- a/monai/networks/nets/classifier.py +++ b/monai/networks/nets/classifier.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,6 +25,19 @@ class Classifier(Regressor): Defines a classification network from Regressor by specifying the output shape as a single dimensional tensor with size equal to the number of classes to predict. The final activation function can also be specified, eg. softmax or sigmoid. + + Args: + in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) + classes: integer stating the dimension of the final output tensor + channels: tuple of integers stating the output channels of each convolutional layer + strides: tuple of integers stating the stride (downscale factor) of each convolutional layer + kernel_size: integer or tuple of integers stating size of convolutional kernels + num_res_units: integer stating number of convolutions in residual units, 0 means no residual units + act: name or type defining activation layers + norm: name or type defining normalization layers + dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout + bias: boolean stating if convolution layers should have a bias component + last_act: name defining the last activation layer """ def __init__( @@ -41,20 +54,6 @@ def __init__( bias: bool = True, last_act: Optional[str] = None, ) -> None: - """ - Args: - in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) - classes: integer stating the dimension of the final output tensor - channels: tuple of integers stating the output channels of each convolutional layer - strides: tuple of integers stating the stride (downscale factor) of each convolutional layer - kernel_size: integer or tuple of integers stating size of convolutional kernels - num_res_units: integer stating number of convolutions in residual units, 0 means no residual units - act: name or type defining activation layers - norm: name or type defining normalization layers - dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout - bias: boolean stating if convolution layers should have a bias component - last_act: name defining the last activation layer - """ super().__init__(in_shape, (classes,), channels, strides, kernel_size, num_res_units, act, norm, dropout, bias) if last_act is not None: @@ -68,6 +67,18 @@ class Discriminator(Classifier): """ Defines a discriminator network from Classifier with a single output value and sigmoid activation by default. This is meant for use with GANs or other applications requiring a generic discriminator network. + + Args: + in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) + channels: tuple of integers stating the output channels of each convolutional layer + strides: tuple of integers stating the stride (downscale factor) of each convolutional layer + kernel_size: integer or tuple of integers stating size of convolutional kernels + num_res_units: integer stating number of convolutions in residual units, 0 means no residual units + act: name or type defining activation layers + norm: name or type defining normalization layers + dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout + bias: boolean stating if convolution layers should have a bias component + last_act: name defining the last activation layer """ def __init__( @@ -83,19 +94,6 @@ def __init__( bias: bool = True, last_act=Act.SIGMOID, ) -> None: - """ - Args: - in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) - channels: tuple of integers stating the output channels of each convolutional layer - strides: tuple of integers stating the stride (downscale factor) of each convolutional layer - kernel_size: integer or tuple of integers stating size of convolutional kernels - num_res_units: integer stating number of convolutions in residual units, 0 means no residual units - act: name or type defining activation layers - norm: name or type defining normalization layers - dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout - bias: boolean stating if convolution layers should have a bias component - last_act: name defining the last activation layer - """ super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, last_act) @@ -104,6 +102,17 @@ class Critic(Classifier): Defines a critic network from Classifier with a single output value and no final activation. The final layer is `nn.Flatten` instead of `nn.Linear`, the final result is computed as the mean over the first dimension. This is meant to be used with Wasserstein GANs. + + Args: + in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) + channels: tuple of integers stating the output channels of each convolutional layer + strides: tuple of integers stating the stride (downscale factor) of each convolutional layer + kernel_size: integer or tuple of integers stating size of convolutional kernels + num_res_units: integer stating number of convolutions in residual units, 0 means no residual units + act: name or type defining activation layers + norm: name or type defining normalization layers + dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout + bias: boolean stating if convolution layers should have a bias component """ def __init__( @@ -118,18 +127,6 @@ def __init__( dropout: Optional[float] = 0.25, bias: bool = True, ) -> None: - """ - Args: - in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) - channels: tuple of integers stating the output channels of each convolutional layer - strides: tuple of integers stating the stride (downscale factor) of each convolutional layer - kernel_size: integer or tuple of integers stating size of convolutional kernels - num_res_units: integer stating number of convolutions in residual units, 0 means no residual units - act: name or type defining activation layers - norm: name or type defining normalization layers - dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout - bias: boolean stating if convolution layers should have a bias component - """ super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, None) def _get_final_layer(self, in_shape: Sequence[int]): diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index e9f3b6d33e..52bd2fa994 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,6 @@ __all__ = [ "DenseNet", - "densenet", "Densenet", "DenseNet121", "densenet121", @@ -62,7 +61,7 @@ def __init__( act: activation type and arguments. Defaults to relu. norm: feature normalization type and arguments. Defaults to batch norm. """ - super(_DenseLayer, self).__init__() + super().__init__() out_channels = bn_size * growth_rate conv_type: Callable = Conv[Conv.CONV, spatial_dims] @@ -110,7 +109,7 @@ def __init__( act: activation type and arguments. Defaults to relu. norm: feature normalization type and arguments. Defaults to batch norm. """ - super(_DenseBlock, self).__init__() + super().__init__() for i in range(layers): layer = _DenseLayer(spatial_dims, in_channels, growth_rate, bn_size, dropout_prob, act=act, norm=norm) in_channels += growth_rate @@ -134,7 +133,7 @@ def __init__( act: activation type and arguments. Defaults to relu. norm: feature normalization type and arguments. Defaults to batch norm. """ - super(_Transition, self).__init__() + super().__init__() conv_type: Callable = Conv[Conv.CONV, spatial_dims] pool_type: Callable = Pool[Pool.AVG, spatial_dims] @@ -149,6 +148,9 @@ class DenseNet(nn.Module): """ Densenet based on: `Densely Connected Convolutional Networks `_. Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16. + This network is non-determistic When `spatial_dims` is 3 and CUDA is enabled. Please check the link below + for more details: + https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms Args: spatial_dims: number of spatial dimensions of the input image. @@ -178,7 +180,7 @@ def __init__( dropout_prob: float = 0.0, ) -> None: - super(DenseNet, self).__init__() + super().__init__() conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] @@ -299,14 +301,13 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(DenseNet121, self).__init__( - init_features=init_features, - growth_rate=growth_rate, - block_config=block_config, - **kwargs, - ) + super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) if pretrained: - # it only worked when `spatial_dims` is 2 + if kwargs["spatial_dims"] > 2: + raise NotImplementedError( + "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" + "provide pretrained models for more than two spatial dimensions." + ) _load_state_dict(self, "densenet121", progress) @@ -322,14 +323,13 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(DenseNet169, self).__init__( - init_features=init_features, - growth_rate=growth_rate, - block_config=block_config, - **kwargs, - ) + super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) if pretrained: - # it only worked when `spatial_dims` is 2 + if kwargs["spatial_dims"] > 2: + raise NotImplementedError( + "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" + "provide pretrained models for more than two spatial dimensions." + ) _load_state_dict(self, "densenet169", progress) @@ -345,14 +345,13 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(DenseNet201, self).__init__( - init_features=init_features, - growth_rate=growth_rate, - block_config=block_config, - **kwargs, - ) + super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) if pretrained: - # it only worked when `spatial_dims` is 2 + if kwargs["spatial_dims"] > 2: + raise NotImplementedError( + "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" + "provide pretrained models for more than two spatial dimensions." + ) _load_state_dict(self, "densenet201", progress) @@ -363,22 +362,17 @@ def __init__( self, init_features: int = 64, growth_rate: int = 32, - block_config: Sequence[int] = (6, 12, 48, 32), + block_config: Sequence[int] = (6, 12, 64, 48), pretrained: bool = False, progress: bool = True, **kwargs, ) -> None: - super(DenseNet264, self).__init__( - init_features=init_features, - growth_rate=growth_rate, - block_config=block_config, - **kwargs, - ) + super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) if pretrained: raise NotImplementedError("Currently PyTorch Hub does not provide densenet264 pretrained models.") -Densenet = densenet = DenseNet +Densenet = DenseNet Densenet121 = densenet121 = DenseNet121 Densenet169 = densenet169 = DenseNet169 Densenet201 = densenet201 = DenseNet201 diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py new file mode 100644 index 0000000000..a4aaf32eed --- /dev/null +++ b/monai/networks/nets/dints.py @@ -0,0 +1,948 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks.dints_block import ( + ActiConvNormBlock, + FactorizedIncreaseBlock, + FactorizedReduceBlock, + P3DActiConvNormBlock, +) +from monai.networks.layers.factories import Conv +from monai.networks.layers.utils import get_act_layer, get_norm_layer +from monai.utils import optional_import + +# solving shortest path problem +csr_matrix, _ = optional_import("scipy.sparse", name="csr_matrix") +dijkstra, _ = optional_import("scipy.sparse.csgraph", name="dijkstra") + +__all__ = ["DiNTS", "TopologyConstruction", "TopologyInstance", "TopologySearch"] + + +@torch.jit.interface +class CellInterface(torch.nn.Module): + """interface for torchscriptable Cell""" + + def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + pass + + +@torch.jit.interface +class StemInterface(torch.nn.Module): + """interface for torchscriptable Stem""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pass + + +class StemTS(StemInterface): + """wrapper for torchscriptable Stem""" + + def __init__(self, *mod): + super().__init__() + self.mod = torch.nn.Sequential(*mod) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mod(x) # type: ignore + + +def _dfs(node, paths): + """use depth first search to find all path activation combination""" + if node == paths: + return [[0], [1]] + child = _dfs(node + 1, paths) + return [[0] + _ for _ in child] + [[1] + _ for _ in child] + + +class _IdentityWithRAMCost(nn.Identity): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ram_cost = 0 + + +class _CloseWithRAMCost(nn.Module): + def __init__(self): + super().__init__() + self.ram_cost = 0 + + def forward(self, x): + return torch.tensor(0.0, requires_grad=False).to(x) + + +class _ActiConvNormBlockWithRAMCost(ActiConvNormBlock): + """The class wraps monai layers with ram estimation. The ram_cost = total_ram/output_size is estimated. + Here is the estimation: + feature_size = output_size/out_channel + total_ram = ram_cost * output_size + total_ram = in_channel * feature_size (activation map) + + in_channel * feature_size (convolution map) + + out_channel * feature_size (normalization) + = (2*in_channel + out_channel) * output_size/out_channel + ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1 + """ + + def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, spatial_dims: int = 3): + super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims) + self.ram_cost = 1 + in_channel / out_channel * 2 + + +class _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock): + def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, p3dmode: int = 0): + super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode) + # 1 in_channel (activation) + 1 in_channel (convolution) + + # 1 out_channel (convolution) + 1 out_channel (normalization) + self.ram_cost = 2 + 2 * in_channel / out_channel + + +class _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock): + def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): + super().__init__(in_channel, out_channel, spatial_dims) + # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. + # 2 * in_channel * s0 (upsample + activation) + 2 * out_channel * s0 (conv + normalization) + # s0 = output_size/out_channel + self.ram_cost = 2 * in_channel / out_channel + 2 + + +class _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock): + def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): + super().__init__(in_channel, out_channel, spatial_dims) + # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. + # in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization) + # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims) + self.ram_cost = in_channel / out_channel * 2**self._spatial_dims + 3 + + +class MixedOp(nn.Module): + """ + The weighted averaging of cell operations. + Args: + c: number of output channels. + ops: a dictionary of operations. See also: ``Cell.OPS2D`` or ``Cell.OPS3D``. + arch_code_c: binary cell operation code. It represents the operation results added to the output. + """ + + def __init__(self, c: int, ops: dict, arch_code_c=None): + super().__init__() + if arch_code_c is None: + arch_code_c = np.ones(len(ops)) + self.ops = nn.ModuleList() + for arch_c, op_name in zip(arch_code_c, ops): + self.ops.append(_CloseWithRAMCost() if arch_c == 0 else ops[op_name](c)) + + def forward(self, x: torch.Tensor, weight: torch.Tensor): + """ + Args: + x: input tensor. + weight: learnable architecture weights for cell operations. arch_code_c are derived from it. + Return: + out: weighted average of the operation results. + """ + out = 0.0 + weight = weight.to(x) + for idx, _op in enumerate(self.ops): + out = out + _op(x) * weight[idx] + return out + + +class Cell(CellInterface): + """ + The basic class for cell operation search, which contains a preprocessing operation and a mixed cell operation. + Each cell is defined on a `path` in the topology search space. + Args: + c_prev: number of input channels + c: number of output channels + rate: resolution change rate. It represents the preprocessing operation before the mixed cell operation. + ``-1`` for 2x downsample, ``1`` for 2x upsample, ``0`` for no change of resolution. + arch_code_c: cell operation code + """ + + DIRECTIONS = 3 + # Possible output paths for `Cell`. + # + # - UpSample + # / + # +--+/ + # | |--- Identity or AlignChannels + # +--+\ + # \ + # - Downsample + + # Define connection operation set, parameterized by the number of channels + ConnOPS = { + "up": _FactorizedIncreaseBlockWithRAMCost, + "down": _FactorizedReduceBlockWithRAMCost, + "identity": _IdentityWithRAMCost, + "align_channels": _ActiConvNormBlockWithRAMCost, + } + + # Define 2D operation set, parameterized by the number of channels + OPS2D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=2), + } + + # Define 3D operation set, parameterized by the number of channels + OPS3D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=3), + "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=0), + "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=1), + "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=2), + } + + def __init__(self, c_prev: int, c: int, rate: int, arch_code_c=None, spatial_dims: int = 3): + super().__init__() + self._spatial_dims = spatial_dims + if rate == -1: # downsample + self.preprocess = self.ConnOPS["down"](c_prev, c, spatial_dims=self._spatial_dims) + elif rate == 1: # upsample + self.preprocess = self.ConnOPS["up"](c_prev, c, spatial_dims=self._spatial_dims) + else: + if c_prev == c: + self.preprocess = self.ConnOPS["identity"]() + else: + self.preprocess = self.ConnOPS["align_channels"](c_prev, c, 1, 0, spatial_dims=self._spatial_dims) + + self.OPS = {} + if self._spatial_dims == 2: + self.OPS = self.OPS2D + elif self._spatial_dims == 3: + self.OPS = self.OPS3D + else: + raise NotImplementedError(f"Spatial dimensions {self._spatial_dims} is not supported.") + + self.op = MixedOp(c, self.OPS, arch_code_c) + + def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Args: + x: input tensor + weight: weights for different operations. + """ + x = self.preprocess(x) + x = self.op(x, weight) + return x + + +class DiNTS(nn.Module): + """ + Reimplementation of DiNTS based on + "DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation + ". + + The model contains a pre-defined multi-resolution stem block (defined in this class) and a + DiNTS space (defined in :py:class:`monai.networks.nets.TopologyInstance` and + :py:class:`monai.networks.nets.TopologySearch`). + + The stem block is for: 1) input downsample and 2) output upsample to original size. + The model downsamples the input image by 2 (if ``use_downsample=True``). + The downsampled image is downsampled by [1, 2, 4, 8] times (``num_depths=4``) and used as input to the + DiNTS search space (``TopologySearch``) or the DiNTS instance (``TopologyInstance``). + + - ``TopologyInstance`` is the final searched model. The initialization requires the searched architecture codes. + - ``TopologySearch`` is a multi-path topology and cell operation search space. + The architecture codes will be initialized as one. + - ``TopologyConstruction`` is the parent class which constructs the instance and search space. + + To meet the requirements of the structure, the input size for each spatial dimension should be: + divisible by 2 ** (num_depths + 1). + + Args: + dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`. + in_channels: number of input image channels. + num_classes: number of output segmentation classes. + act_name: activation name, default to 'RELU'. + norm_name: normalization used in convolution blocks. Default to `InstanceNorm`. + spatial_dims: spatial 2D or 3D inputs. + use_downsample: use downsample in the stem. + If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8], + if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16]. + node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`. + +1 for multi-resolution inputs. + In model searching stage, ``node_a`` can be None. In deployment stage, ``node_a`` cannot be None. + """ + + def __init__( + self, + dints_space, + in_channels: int, + num_classes: int, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + spatial_dims: int = 3, + use_downsample: bool = True, + node_a=None, + ): + super().__init__() + + self.dints_space = dints_space + self.filter_nums = dints_space.filter_nums + self.num_blocks = dints_space.num_blocks + self.num_depths = dints_space.num_depths + if spatial_dims not in (2, 3): + raise NotImplementedError(f"Spatial dimensions {spatial_dims} is not supported.") + self._spatial_dims = spatial_dims + if node_a is None: + self.node_a = torch.ones((self.num_blocks + 1, self.num_depths)) + else: + self.node_a = node_a + + # define stem operations for every block + conv_type = Conv[Conv.CONV, spatial_dims] + self.stem_down = nn.ModuleDict() + self.stem_up = nn.ModuleDict() + self.stem_finals = nn.Sequential( + ActiConvNormBlock( + self.filter_nums[0], + self.filter_nums[0], + act_name=act_name, + norm_name=norm_name, + spatial_dims=spatial_dims, + ), + conv_type( + in_channels=self.filter_nums[0], + out_channels=num_classes, + kernel_size=1, + stride=1, + padding=0, + groups=1, + bias=True, + dilation=1, + ), + ) + mode = "trilinear" if self._spatial_dims == 3 else "bilinear" + for res_idx in range(self.num_depths): + # define downsample stems before DiNTS search + if use_downsample: + self.stem_down[str(res_idx)] = StemTS( + nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), + conv_type( + in_channels=in_channels, + out_channels=self.filter_nums[res_idx], + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), + get_act_layer(name=act_name), + conv_type( + in_channels=self.filter_nums[res_idx], + out_channels=self.filter_nums[res_idx + 1], + kernel_size=3, + stride=2, + padding=1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx + 1]), + ) + self.stem_up[str(res_idx)] = StemTS( + get_act_layer(name=act_name), + conv_type( + in_channels=self.filter_nums[res_idx + 1], + out_channels=self.filter_nums[res_idx], + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), + nn.Upsample(scale_factor=2, mode=mode, align_corners=True), + ) + + else: + self.stem_down[str(res_idx)] = StemTS( + nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), + conv_type( + in_channels=in_channels, + out_channels=self.filter_nums[res_idx], + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), + ) + self.stem_up[str(res_idx)] = StemTS( + get_act_layer(name=act_name), + conv_type( + in_channels=self.filter_nums[res_idx], + out_channels=self.filter_nums[max(res_idx - 1, 0)], + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx - 1]), + nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True), + ) + + def weight_parameters(self): + return [param for name, param in self.named_parameters()] + + def forward(self, x: torch.Tensor): + """ + Prediction based on dynamic arch_code. + + Args: + x: input tensor. + """ + inputs = [] + for d in range(self.num_depths): + # allow multi-resolution input + _mod_w: StemInterface = self.stem_down[str(d)] + x_out = _mod_w.forward(x) + if self.node_a[0][d]: + inputs.append(x_out) + else: + inputs.append(torch.zeros_like(x_out)) + + outputs = self.dints_space(inputs) + + blk_idx = self.num_blocks - 1 + start = False + _temp: torch.Tensor = torch.empty(0) + for res_idx in range(self.num_depths - 1, -1, -1): + _mod_up: StemInterface = self.stem_up[str(res_idx)] + if start: + _temp = _mod_up.forward(outputs[res_idx] + _temp) + elif self.node_a[blk_idx + 1][res_idx]: + start = True + _temp = _mod_up.forward(outputs[res_idx]) + prediction = self.stem_finals(_temp) + return prediction + + +class TopologyConstruction(nn.Module): + """ + The base class for `TopologyInstance` and `TopologySearch`. + + Args: + arch_code: `[arch_code_a, arch_code_c]`, numpy arrays. The architecture codes defining the model. + For example, for a ``num_depths=4, num_blocks=12`` search space: + + - `arch_code_a` is a 12x10 (10 paths) binary matrix representing if a path is activated. + - `arch_code_c` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used. + - `arch_code` in ``__init__()`` is used for creating the network and remove unused network blocks. If None, + + all paths and cells operations will be used, and must be in the searching stage (is_search=True). + channel_mul: adjust intermediate channel number, default is 1. + cell: operation of each node. + num_blocks: number of blocks (depth in the horizontal direction) of the DiNTS search space. + num_depths: number of image resolutions of the DiNTS search space: 1, 1/2, 1/4 ... in each dimension. + use_downsample: use downsample in the stem. If False, the search space will be in resolution [1, 1/2, 1/4, 1/8], + if True, the search space will be in resolution [1/2, 1/4, 1/8, 1/16]. + device: `'cpu'`, `'cuda'`, or device ID. + + + Predefined variables: + `filter_nums`: default to 32. Double the number of channels after downsample. + topology related variables: + + - `arch_code2in`: path activation to its incoming node index (resolution). For depth = 4, + arch_code2in = [0, 1, 0, 1, 2, 1, 2, 3, 2, 3]. The first path outputs from node 0 (top resolution), + the second path outputs from node 1 (second resolution in the search space), + the third path outputs from node 0, etc. + - `arch_code2ops`: path activation to operations of upsample 1, keep 0, downsample -1. For depth = 4, + arch_code2ops = [0, 1, -1, 0, 1, -1, 0, 1, -1, 0]. The first path does not change + resolution, the second path perform upsample, the third perform downsample, etc. + - `arch_code2out`: path activation to its output node index. + For depth = 4, arch_code2out = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], + the first and second paths connects to node 0 (top resolution), the 3,4,5 paths connects to node 1, etc. + """ + + def __init__( + self, + arch_code: Optional[list] = None, + channel_mul: float = 1.0, + cell=Cell, + num_blocks: int = 6, + num_depths: int = 3, + spatial_dims: int = 3, + use_downsample: bool = True, + device: str = "cpu", + ): + + super().__init__() + + self.filter_nums = [int(n_feat * channel_mul) for n_feat in (32, 64, 128, 256, 512)] + self.num_blocks = num_blocks + self.num_depths = num_depths + self._spatial_dims = spatial_dims + self.use_downsample = use_downsample + self.device = device + self.num_cell_ops = 0 + if self._spatial_dims == 2: + self.num_cell_ops = len(cell.OPS2D) + elif self._spatial_dims == 3: + self.num_cell_ops = len(cell.OPS3D) + + # Calculate predefined parameters for topology search and decoding + arch_code2in, arch_code2out = [], [] + for i in range(Cell.DIRECTIONS * self.num_depths - 2): + arch_code2in.append((i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS) + arch_code2ops = ([-1, 0, 1] * self.num_depths)[1:-1] + for m in range(self.num_depths): + arch_code2out.extend([m, m, m]) + arch_code2out = arch_code2out[1:-1] + self.arch_code2in = arch_code2in + self.arch_code2ops = arch_code2ops + self.arch_code2out = arch_code2out + + # define NAS search space + if arch_code is None: + arch_code_a = torch.ones((self.num_blocks, len(self.arch_code2out))).to(self.device) + arch_code_c = torch.ones((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)).to(self.device) + else: + arch_code_a = torch.from_numpy(arch_code[0]).to(self.device) + arch_code_c = F.one_hot(torch.from_numpy(arch_code[1]).to(torch.int64), self.num_cell_ops).to(self.device) + + self.arch_code_a = arch_code_a + self.arch_code_c = arch_code_c + # define cell operation on each path + self.cell_tree = nn.ModuleDict() + for blk_idx in range(self.num_blocks): + for res_idx in range(len(self.arch_code2out)): + if self.arch_code_a[blk_idx, res_idx] == 1: + self.cell_tree[str((blk_idx, res_idx))] = cell( + self.filter_nums[self.arch_code2in[res_idx] + int(use_downsample)], + self.filter_nums[self.arch_code2out[res_idx] + int(use_downsample)], + self.arch_code2ops[res_idx], + self.arch_code_c[blk_idx, res_idx], + self._spatial_dims, + ) + + def forward(self, x): + """This function to be implemented by the architecture instances or search spaces.""" + pass + + +class TopologyInstance(TopologyConstruction): + """ + Instance of the final searched architecture. Only used in re-training/inference stage. + """ + + def __init__( + self, + arch_code=None, + channel_mul: float = 1.0, + cell=Cell, + num_blocks: int = 6, + num_depths: int = 3, + spatial_dims: int = 3, + use_downsample: bool = True, + device: str = "cpu", + ): + """ + Initialize DiNTS topology search space of neural architectures. + """ + if arch_code is None: + warnings.warn("arch_code not provided when not searching.") + + super().__init__( + arch_code=arch_code, + channel_mul=channel_mul, + cell=cell, + num_blocks=num_blocks, + num_depths=num_depths, + spatial_dims=spatial_dims, + use_downsample=use_downsample, + device=device, + ) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Args: + x: input tensor. + """ + # generate path activation probability + inputs, outputs = x, [torch.tensor(0.0).to(x[0])] * self.num_depths + for blk_idx in range(self.num_blocks): + outputs = [torch.tensor(0.0).to(x[0])] * self.num_depths + for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data): + if activation: + mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))] + _out = mod.forward( + x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx]) + ) + outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out + inputs = outputs + + return inputs + + +class TopologySearch(TopologyConstruction): + """ + DiNTS topology search space of neural architectures. + + Examples: + + .. code-block:: python + + from monai.networks.nets.dints import TopologySearch + + topology_search_space = TopologySearch( + channel_mul=0.5, num_blocks=8, num_depths=4, use_downsample=True, spatial_dims=3) + topology_search_space.get_ram_cost_usage(in_size=(2, 16, 80, 80, 80), full=True) + multi_res_images = [ + torch.randn(2, 16, 80, 80, 80), + torch.randn(2, 32, 40, 40, 40), + torch.randn(2, 64, 20, 20, 20), + torch.randn(2, 128, 10, 10, 10)] + prediction = topology_search_space(image) + for x in prediction: print(x.shape) + # torch.Size([2, 16, 80, 80, 80]) + # torch.Size([2, 32, 40, 40, 40]) + # torch.Size([2, 64, 20, 20, 20]) + # torch.Size([2, 128, 10, 10, 10]) + + Class method overview: + + - ``get_prob_a()``: convert learnable architecture weights to path activation probabilities. + - ``get_ram_cost_usage()``: get estimated ram cost. + - ``get_topology_entropy()``: get topology entropy loss in searching stage. + - ``decode()``: get final binarized architecture code. + - ``gen_mtx()``: generate variables needed for topology search. + + Predefined variables: + - `tidx`: index used to convert path activation matrix T = (depth,depth) in transfer_mtx to + path activation arch_code (1,3*depth-2), for depth = 4, tidx = [0, 1, 4, 5, 6, 9, 10, 11, 14, 15], + A tidx (10 binary values) represents the path activation. + - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern. + It is used to convert path activation pattern (1, paths) to node activation (1, nodes) + - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation + patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths). + - `all_connect`: All possible path activations. For depth = 4, + all_connection has 1024 vectors of length 10 (10 paths). + The return value will exclude path activation of all 0. + """ + + def __init__( + self, + channel_mul: float = 1.0, + cell=Cell, + arch_code: Optional[list] = None, + num_blocks: int = 6, + num_depths: int = 3, + spatial_dims: int = 3, + use_downsample: bool = True, + device: str = "cpu", + ): + """ + Initialize DiNTS topology search space of neural architectures. + """ + super().__init__( + arch_code=arch_code, + channel_mul=channel_mul, + cell=cell, + num_blocks=num_blocks, + num_depths=num_depths, + spatial_dims=spatial_dims, + use_downsample=use_downsample, + device=device, + ) + + tidx = [] + _d = Cell.DIRECTIONS + for i in range(_d * self.num_depths - 2): + tidx.append((i + 1) // _d * self.num_depths + (i + 1) // _d - 1 + (i + 1) % _d) + self.tidx = tidx + transfer_mtx, node_act_list, child_list = self.gen_mtx(num_depths) + + self.node_act_list = np.asarray(node_act_list) + self.node_act_dict = {str(self.node_act_list[i]): i for i in range(len(self.node_act_list))} + self.transfer_mtx = transfer_mtx + self.child_list = np.asarray(child_list) + + self.ram_cost = np.zeros((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)) + for blk_idx in range(self.num_blocks): + for res_idx in range(len(self.arch_code2out)): + if self.arch_code_a[blk_idx, res_idx] == 1: + self.ram_cost[blk_idx, res_idx] = np.array( + [ + op.ram_cost + self.cell_tree[str((blk_idx, res_idx))].preprocess.ram_cost + for op in self.cell_tree[str((blk_idx, res_idx))].op.ops[: self.num_cell_ops] + ] + ) + + # define cell and macro architecture probabilities + self.log_alpha_c = nn.Parameter( + torch.zeros(self.num_blocks, len(self.arch_code2out), self.num_cell_ops) + .normal_(1, 0.01) + .to(self.device) + .requires_grad_() + ) + self.log_alpha_a = nn.Parameter( + torch.zeros(self.num_blocks, len(self.arch_code2out)).normal_(0, 0.01).to(self.device).requires_grad_() + ) + self._arch_param_names = ["log_alpha_a", "log_alpha_c"] + + def gen_mtx(self, depth: int): + """ + Generate elements needed in decoding and topology. + + - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern. + It is used to convert path activation pattern (1, paths) to node activation (1, nodes) + - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation + patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths). + - `all_connect`: All possible path activations. For depth = 4, + all_connection has 1024 vectors of length 10 (10 paths). + The return value will exclude path activation of all 0. + """ + # total paths in a block, each node has three output paths, + # except the two nodes at the top and the bottom scales + paths = Cell.DIRECTIONS * depth - 2 + + # for 10 paths, all_connect has 1024 possible path activations. [1 0 0 0 0 0 0 0 0 0] means the top + # path is activated. + all_connect = _dfs(0, paths - 1) + + # Save all possible connections in mtx (might be redundant and infeasible) + mtx = [] + for m in all_connect: + # convert path activation [1,paths] to path activation matrix [depth, depth] + ma = np.zeros((depth, depth)) + for i in range(paths): + ma[(i + 1) // Cell.DIRECTIONS, (i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS] = m[i] + mtx.append(ma) + + # define all possible node activation + node_act_list = _dfs(0, depth - 1)[1:] + transfer_mtx = {} + for arch_code in node_act_list: + # make sure each activated node has an active connection, inactivated node has no connection + arch_code_mtx = [_ for _ in mtx if ((np.sum(_, 0) > 0).astype(int) == np.array(arch_code)).all()] + transfer_mtx[str(np.array(arch_code))] = arch_code_mtx + + return transfer_mtx, node_act_list, all_connect[1:] + + def weight_parameters(self): + return [param for name, param in self.named_parameters() if name not in self._arch_param_names] + + def get_prob_a(self, child: bool = False): + """ + Get final path and child model probabilities from architecture weights `log_alpha_a`. + This is used in forward pass, getting training loss, and final decoding. + + Args: + child: return child probability (used in decoding) + Return: + arch_code_prob_a: the path activation probability of size: + `[number of blocks, number of paths in each block]`. + For 12 blocks, 4 depths search space, the size is [12,10] + probs_a: The probability of all child models (size 1023x10). Each child model is a path activation pattern + (1D vector of length 10 for 10 paths). In total 1023 child models (2^10 -1) + """ + _arch_code_prob_a = torch.sigmoid(self.log_alpha_a) + # remove the case where all path are zero, and re-normalize. + norm = 1 - (1 - _arch_code_prob_a).prod(-1) + arch_code_prob_a = _arch_code_prob_a / norm.unsqueeze(1) + if child: + path_activation = torch.from_numpy(self.child_list).to(self.device) + probs_a = [ + ( + path_activation * _arch_code_prob_a[blk_idx] + + (1 - path_activation) * (1 - _arch_code_prob_a[blk_idx]) + ).prod(-1) + / norm[blk_idx] + for blk_idx in range(self.num_blocks) + ] + probs_a = torch.stack(probs_a) # type: ignore + return probs_a, arch_code_prob_a + return None, arch_code_prob_a + + def get_ram_cost_usage(self, in_size, full: bool = False): + """ + Get estimated output tensor size to approximate RAM consumption. + + Args: + in_size: input image shape (4D/5D, ``[BCHW[D]]``) at the highest resolution level. + full: full ram cost usage with all probability of 1. + """ + # convert input image size to feature map size at each level + batch_size = in_size[0] + image_size = np.array(in_size[-self._spatial_dims :]) + sizes = [] + for res_idx in range(self.num_depths): + sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod()) + sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample))) + probs_a, arch_code_prob_a = self.get_prob_a(child=False) + cell_prob = F.softmax(self.log_alpha_c, dim=-1) + if full: + arch_code_prob_a = arch_code_prob_a.detach() + arch_code_prob_a.fill_(1) + ram_cost = torch.from_numpy(self.ram_cost).to(torch.float32).to(self.device) + usage = 0.0 + for blk_idx in range(self.num_blocks): + # node activation for input + # cell operation + for path_idx in range(len(self.arch_code2out)): + usage += ( + arch_code_prob_a[blk_idx, path_idx] + * (1 + (ram_cost[blk_idx, path_idx] * cell_prob[blk_idx, path_idx]).sum()) + * sizes[self.arch_code2out[path_idx]] + ) + return usage * 32 / 8 / 1024**2 + + def get_topology_entropy(self, probs): + """ + Get topology entropy loss at searching stage. + + Args: + probs: path activation probabilities + """ + if hasattr(self, "node2in"): + node2in = self.node2in + node2out = self.node2out + else: + # node activation index to feasible input child_idx + node2in = [[] for _ in range(len(self.node_act_list))] + # node activation index to feasible output child_idx + node2out = [[] for _ in range(len(self.node_act_list))] + for child_idx in range(len(self.child_list)): + _node_in, _node_out = np.zeros(self.num_depths), np.zeros(self.num_depths) + for res_idx in range(len(self.arch_code2out)): + _node_out[self.arch_code2out[res_idx]] += self.child_list[child_idx][res_idx] + _node_in[self.arch_code2in[res_idx]] += self.child_list[child_idx][res_idx] + _node_in = (_node_in >= 1).astype(int) + _node_out = (_node_out >= 1).astype(int) + node2in[self.node_act_dict[str(_node_out)]].append(child_idx) + node2out[self.node_act_dict[str(_node_in)]].append(child_idx) + self.node2in = node2in + self.node2out = node2out + # calculate entropy + ent = 0 + for blk_idx in range(self.num_blocks - 1): + blk_ent = 0 + # node activation probability + for node_idx in range(len(self.node_act_list)): + _node_p = probs[blk_idx, node2in[node_idx]].sum() + _out_probs = probs[blk_idx + 1, node2out[node_idx]].sum() + blk_ent += -(_node_p * torch.log(_out_probs + 1e-5) + (1 - _node_p) * torch.log(1 - _out_probs + 1e-5)) + ent += blk_ent + return ent + + def decode(self): + """ + Decode network log_alpha_a/log_alpha_c using dijkstra shortest path algorithm. + + `[node_a, arch_code_a, arch_code_c, arch_code_a_max]` is decoded when using ``self.decode()``. + + For example, for a ``num_depths=4``, ``num_blocks=12`` search space: + + - ``node_a`` is a 4x13 binary matrix representing if a feature node is activated + (13 because of multi-resolution inputs). + - ``arch_code_a`` is a 12x10 (10 paths) binary matrix representing if a path is activated. + - ``arch_code_c`` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used. + + Return: + arch_code with maximum probability + """ + probs, arch_code_prob_a = self.get_prob_a(child=True) + arch_code_a_max = self.child_list[torch.argmax(probs, -1).data.cpu().numpy()] + arch_code_c = torch.argmax(F.softmax(self.log_alpha_c, -1), -1).data.cpu().numpy() + probs = probs.data.cpu().numpy() + + # define adjacency matrix + amtx = np.zeros( + (1 + len(self.child_list) * self.num_blocks + 1, 1 + len(self.child_list) * self.num_blocks + 1) + ) + + # build a path activation to child index searching dictionary + path2child = {str(self.child_list[i]): i for i in range(len(self.child_list))} + + # build a submodel to submodel index + sub_amtx = np.zeros((len(self.child_list), len(self.child_list))) + for child_idx in range(len(self.child_list)): + _node_act = np.zeros(self.num_depths).astype(int) + for path_idx in range(len(self.child_list[child_idx])): + _node_act[self.arch_code2out[path_idx]] += self.child_list[child_idx][path_idx] + _node_act = (_node_act >= 1).astype(int) + for mtx in self.transfer_mtx[str(_node_act)]: + connect_child_idx = path2child[str(mtx.flatten()[self.tidx].astype(int))] + sub_amtx[child_idx, connect_child_idx] = 1 + + # fill in source to first block, add 1e-5/1e-3 to avoid log0 and negative edge weights + amtx[0, 1 : 1 + len(self.child_list)] = -np.log(probs[0] + 1e-5) + 0.001 + + # fill in the rest blocks + for blk_idx in range(1, self.num_blocks): + amtx[ + 1 + (blk_idx - 1) * len(self.child_list) : 1 + blk_idx * len(self.child_list), + 1 + blk_idx * len(self.child_list) : 1 + (blk_idx + 1) * len(self.child_list), + ] = sub_amtx * np.tile(-np.log(probs[blk_idx] + 1e-5) + 0.001, (len(self.child_list), 1)) + + # fill in the last to the sink + amtx[1 + (self.num_blocks - 1) * len(self.child_list) : 1 + self.num_blocks * len(self.child_list), -1] = 0.001 + + graph = csr_matrix(amtx) + dist_matrix, predecessors, sources = dijkstra( + csgraph=graph, directed=True, indices=0, min_only=True, return_predecessors=True + ) + index, a_idx = -1, -1 + arch_code_a = np.zeros((self.num_blocks, len(self.arch_code2out))) + node_a = np.zeros((self.num_blocks + 1, self.num_depths)) + + # decoding to paths + while True: + index = predecessors[index] + if index == 0: + break + child_idx = (index - 1) % len(self.child_list) + arch_code_a[a_idx, :] = self.child_list[child_idx] + for res_idx in range(len(self.arch_code2out)): + node_a[a_idx, self.arch_code2out[res_idx]] += arch_code_a[a_idx, res_idx] + a_idx -= 1 + for res_idx in range(len(self.arch_code2out)): + node_a[a_idx, self.arch_code2in[res_idx]] += arch_code_a[0, res_idx] + node_a = (node_a >= 1).astype(int) + return node_a, arch_code_a, arch_code_c, arch_code_a_max + + def forward(self, x): + """ + Prediction based on dynamic arch_code. + + Args: + x: a list of `num_depths` input tensors as a multi-resolution input. + tensor is of shape `BCHW[D]` where `C` must match `self.filter_nums`. + """ + # generate path activation probability + probs_a, arch_code_prob_a = self.get_prob_a(child=False) + inputs = x + for blk_idx in range(self.num_blocks): + outputs = [0.0] * self.num_depths + for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data.cpu().numpy()): + if activation: + _w = F.softmax(self.log_alpha_c[blk_idx, res_idx], dim=-1) + outputs[self.arch_code2out[res_idx]] += ( + self.cell_tree[str((blk_idx, res_idx))](inputs[self.arch_code2in[res_idx]], weight=_w) + * arch_code_prob_a[blk_idx, res_idx] + ) + inputs = outputs + + return inputs diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 4af70b22c7..337a99acd8 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,7 +18,7 @@ from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock -__all__ = ["DynUNet", "DynUnet", "Dynunet", "dynunet"] +__all__ = ["DynUNet", "DynUnet", "Dynunet"] class DynUNetSkipLayer(nn.Module): @@ -31,13 +31,13 @@ class DynUNetSkipLayer(nn.Module): forward passes of the network. """ - heads: List[torch.Tensor] + heads: Optional[List[torch.Tensor]] - def __init__(self, index, heads, downsample, upsample, super_head, next_layer): + def __init__(self, index, downsample, upsample, next_layer, heads=None, super_head=None): super().__init__() self.downsample = downsample - self.upsample = upsample self.next_layer = next_layer + self.upsample = upsample self.super_head = super_head self.heads = heads self.index = index @@ -46,8 +46,8 @@ def forward(self, x): downout = self.downsample(x) nextout = self.next_layer(downout) upout = self.upsample(nextout, downout) - - self.heads[self.index] = self.super_head(upout) + if self.super_head is not None and self.heads is not None and self.index > 0: + self.heads[self.index - 1] = self.super_head(upout) return upout @@ -57,6 +57,7 @@ class DynUNet(nn.Module): This reimplementation of a dynamic UNet (DynUNet) is based on: `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + `Optimized U-Net for Brain Tumor Segmentation `_. This model is more flexible compared with ``monai.networks.nets.UNet`` in three places: @@ -74,10 +75,15 @@ class DynUNet(nn.Module): To meet the requirements of the structure, the input size for each spatial dimension should be divisible by `2 * the product of all strides in the corresponding dimension`. The output size for each spatial dimension - equals to the input size of the correponding dimension divided by the stride in strides[0]. + equals to the input size of the corresponding dimension divided by the stride in strides[0]. For example, if `strides=((1, 2, 4), 2, 1, 1)`, the minimal spatial size of the input is `(8, 16, 32)`, and the spatial size of the output is `(8, 8, 8)`. + For backwards compatibility with old weights, please set `strict=False` when calling `load_state_dict`. + + Usage example with medical segmentation decathlon dataset is available at: + https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline. + Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. @@ -86,25 +92,32 @@ class DynUNet(nn.Module): strides: convolution strides for each blocks. upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should equal to strides[1:]. + filters: number of output channels for each blocks. Different from nnU-Net, in this implementation we add + this argument to make the network more flexible. As shown in the third reference, one way to determine + this argument is like: + ``[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)]``. + The above way is used in the network that wins task 1 in the BraTS21 Challenge. + If not specified, the way which nnUNet used will be employed. Defaults to ``None``. + dropout: dropout ratio. Defaults to no dropout. norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. + act_name: activation layer type and arguments. Defaults to ``leakyrelu``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. - If ``True``, in training mode, the forward function will output not only the last feature - map, but also the previous feature maps that come from the intermediate up sample layers. + If ``True``, in training mode, the forward function will output not only the final feature map + (from `output_block`), but also the feature maps that come from the intermediate up sample layers. In order to unify the return type (the restriction of TorchScript), all intermediate - feature maps are interpolated into the same size as the last feature map and stacked together + feature maps are interpolated into the same size as the final feature map and stacked together (with a new dimension in the first axis)into one single tensor. - For instance, if there are three feature maps with shapes: (1, 2, 32, 24), (1, 2, 16, 12) and - (1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor - will has the shape (1, 3, 2, 8, 6). + For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and + (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps + will be interpolated into (1, 2, 32, 24), and the stacked tensor will has the shape (1, 3, 2, 32, 24). When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss one by one with the ground truth, then do a weighted average for all losses to achieve the final loss. - (To be added: a corresponding tutorial link) - deep_supr_num: number of feature maps that will output during deep supervision head. The value should be larger than 0 and less than the number of up sample layers. Defaults to 1. res_block: whether to use residual connection based convolution blocks during the network. Defaults to ``False``. + trans_bias: whether to set the bias parameter in transposed convolution layers. Defaults to ``False``. """ def __init__( @@ -115,12 +128,16 @@ def __init__( kernel_size: Sequence[Union[Sequence[int], int]], strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], + filters: Optional[Sequence[int]] = None, + dropout: Optional[Union[Tuple, str, float]] = None, norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), deep_supervision: bool = False, deep_supr_num: int = 1, res_block: bool = False, + trans_bias: bool = False, ): - super(DynUNet, self).__init__() + super().__init__() self.spatial_dims = spatial_dims self.in_channels = in_channels self.out_channels = out_channels @@ -128,24 +145,32 @@ def __init__( self.strides = strides self.upsample_kernel_size = upsample_kernel_size self.norm_name = norm_name + self.act_name = act_name + self.dropout = dropout self.conv_block = UnetResBlock if res_block else UnetBasicBlock - self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] + self.trans_bias = trans_bias + if filters is not None: + self.filters = filters + self.check_filters() + else: + self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] self.input_block = self.get_input_block() self.downsamples = self.get_downsamples() self.bottleneck = self.get_bottleneck() self.upsamples = self.get_upsamples() self.output_block = self.get_output_block(0) self.deep_supervision = deep_supervision - self.deep_supervision_heads = self.get_deep_supervision_heads() self.deep_supr_num = deep_supr_num + # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on + self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num + if self.deep_supervision: + self.deep_supervision_heads = self.get_deep_supervision_heads() + self.check_deep_supr_num() + self.apply(self.initialize_weights) self.check_kernel_stride() - self.check_deep_supr_num() - - # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on - self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) - def create_skips(index, downsamples, upsamples, superheads, bottleneck): + def create_skips(index, downsamples, upsamples, bottleneck, superheads=None): """ Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is done recursively from the top down since a recursive nn.Module subclass is being used to be compatible @@ -155,64 +180,90 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): """ if len(downsamples) != len(upsamples): - raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") - if (len(downsamples) - len(superheads)) not in (1, 0): - raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}") + raise ValueError(f"{len(downsamples)} != {len(upsamples)}") if len(downsamples) == 0: # bottom of the network, pass the bottleneck block return bottleneck + + if superheads is None: + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck) + return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer) + + super_head_flag = False if index == 0: # don't associate a supervision head with self.input_block - current_head, rest_heads = nn.Identity(), superheads - elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - current_head, rest_heads = nn.Identity(), superheads[1:] + rest_heads = superheads else: - current_head, rest_heads = superheads[0], superheads[1:] + if len(superheads) > 0: + super_head_flag = True + rest_heads = superheads[1:] + else: + rest_heads = nn.ModuleList() # create the next layer down, this will stop at the bottleneck layer - next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) - - return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) - - self.skip_layers = create_skips( - 0, - [self.input_block] + list(self.downsamples), - self.upsamples[::-1], - self.deep_supervision_heads, - self.bottleneck, - ) + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck, superheads=rest_heads) + if super_head_flag: + return DynUNetSkipLayer( + index, + downsample=downsamples[0], + upsample=upsamples[0], + next_layer=next_layer, + heads=self.heads, + super_head=superheads[0], + ) + + return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer) + + if not self.deep_supervision: + self.skip_layers = create_skips( + 0, [self.input_block] + list(self.downsamples), self.upsamples[::-1], self.bottleneck + ) + else: + self.skip_layers = create_skips( + 0, + [self.input_block] + list(self.downsamples), + self.upsamples[::-1], + self.bottleneck, + superheads=self.deep_supervision_heads, + ) def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." - if not (len(kernels) == len(strides) and len(kernels) >= 3): - raise AssertionError(error_msg) + if len(kernels) != len(strides) or len(kernels) < 3: + raise ValueError(error_msg) for idx, k_i in enumerate(kernels): kernel, stride = k_i, strides[idx] if not isinstance(kernel, int): - error_msg = "length of kernel_size in block {} should be the same as spatial_dims.".format(idx) + error_msg = f"length of kernel_size in block {idx} should be the same as spatial_dims." if len(kernel) != self.spatial_dims: - raise AssertionError(error_msg) + raise ValueError(error_msg) if not isinstance(stride, int): - error_msg = "length of stride in block {} should be the same as spatial_dims.".format(idx) + error_msg = f"length of stride in block {idx} should be the same as spatial_dims." if len(stride) != self.spatial_dims: - raise AssertionError(error_msg) + raise ValueError(error_msg) def check_deep_supr_num(self): deep_supr_num, strides = self.deep_supr_num, self.strides num_up_layers = len(strides) - 1 if deep_supr_num >= num_up_layers: - raise AssertionError("deep_supr_num should be less than the number of up sample layers.") + raise ValueError("deep_supr_num should be less than the number of up sample layers.") if deep_supr_num < 1: - raise AssertionError("deep_supr_num should be larger than 0.") + raise ValueError("deep_supr_num should be larger than 0.") + + def check_filters(self): + filters = self.filters + if len(filters) < len(self.strides): + raise ValueError("length of filters should be no less than the length of strides.") + else: + self.filters = filters[: len(self.strides)] def forward(self, x): out = self.skip_layers(x) out = self.output_block(out) if self.training and self.deep_supervision: out_all = [out] - feature_maps = self.heads[1 : self.deep_supr_num + 1] - for feature_map in feature_maps: + for feature_map in self.heads: out_all.append(interpolate(feature_map, out.shape[2:])) return torch.stack(out_all, dim=1) return out @@ -225,6 +276,8 @@ def get_input_block(self): self.kernel_size[0], self.strides[0], self.norm_name, + self.act_name, + dropout=self.dropout, ) def get_bottleneck(self): @@ -235,14 +288,12 @@ def get_bottleneck(self): self.kernel_size[-1], self.strides[-1], self.norm_name, + self.act_name, + dropout=self.dropout, ) def get_output_block(self, idx: int): - return UnetOutBlock( - self.spatial_dims, - self.filters[idx], - self.out_channels, - ) + return UnetOutBlock(self.spatial_dims, self.filters[idx], self.out_channels, dropout=self.dropout) def get_downsamples(self): inp, out = self.filters[:-2], self.filters[1:-1] @@ -253,7 +304,9 @@ def get_upsamples(self): inp, out = self.filters[1:][::-1], self.filters[:-1][::-1] strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1] upsample_kernel_size = self.upsample_kernel_size[::-1] - return self.get_module_list(inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size) + return self.get_module_list( + inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size, trans_bias=self.trans_bias + ) def get_module_list( self, @@ -263,6 +316,7 @@ def get_module_list( strides: Sequence[Union[Sequence[int], int]], conv_block: nn.Module, upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None, + trans_bias: bool = False, ): layers = [] if upsample_kernel_size is not None: @@ -276,7 +330,10 @@ def get_module_list( "kernel_size": kernel, "stride": stride, "norm_name": self.norm_name, + "act_name": self.act_name, + "dropout": self.dropout, "upsample_kernel_size": up_kernel, + "trans_bias": trans_bias, } layer = conv_block(**params) layers.append(layer) @@ -289,13 +346,15 @@ def get_module_list( "kernel_size": kernel, "stride": stride, "norm_name": self.norm_name, + "act_name": self.act_name, + "dropout": self.dropout, } layer = conv_block(**params) layers.append(layer) return nn.ModuleList(layers) def get_deep_supervision_heads(self): - return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)]) + return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)]) @staticmethod def initialize_weights(module): @@ -305,4 +364,4 @@ def initialize_weights(module): module.bias = nn.init.constant_(module.bias, 0) -DynUnet = Dynunet = dynunet = DynUNet +DynUnet = Dynunet = DynUNet diff --git a/monai/networks/nets/dynunet_v1.py b/monai/networks/nets/dynunet_v1.py deleted file mode 100644 index feb05d1762..0000000000 --- a/monai/networks/nets/dynunet_v1.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Sequence, Union - -import torch -import torch.nn as nn - -from monai.networks.blocks.dynunet_block_v1 import _UnetBasicBlockV1, _UnetResBlockV1, _UnetUpBlockV1 -from monai.networks.nets.dynunet import DynUNet, DynUNetSkipLayer -from monai.utils import deprecated - -__all__ = ["DynUNetV1", "DynUnetV1", "DynunetV1"] - - -@deprecated( - since="0.6.0", - removed="0.7.0", - msg_suffix="This module is for backward compatibility purpose only. Please use `DynUNet` instead.", -) -class DynUNetV1(DynUNet): - """ - This a deprecated reimplementation of a dynamic UNet (DynUNet), please use `monai.networks.nets.DynUNet` instead. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - strides: convolution strides for each blocks. - upsample_kernel_size: convolution kernel size for transposed convolution layers. - norm_name: [``"batch"``, ``"instance"``, ``"group"``]. Defaults to "instance". - deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. - deep_supr_num: number of feature maps that will output during deep supervision head. Defaults to 1. - res_block: whether to use residual connection based convolution blocks during the network. - Defaults to ``False``. - - .. deprecated:: 0.6.0 - Use :class:`monai.networks.nets.DynUNet` instead. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Sequence[Union[Sequence[int], int]], - strides: Sequence[Union[Sequence[int], int]], - upsample_kernel_size: Sequence[Union[Sequence[int], int]], - norm_name: str = "instance", - deep_supervision: bool = False, - deep_supr_num: int = 1, - res_block: bool = False, - ): - nn.Module.__init__(self) - self.spatial_dims = spatial_dims - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.strides = strides - self.upsample_kernel_size = upsample_kernel_size - self.norm_name = norm_name - self.conv_block = _UnetResBlockV1 if res_block else _UnetBasicBlockV1 # type: ignore - self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] - self.input_block = self.get_input_block() - self.downsamples = self.get_downsamples() - self.bottleneck = self.get_bottleneck() - self.upsamples = self.get_upsamples() - self.output_block = self.get_output_block(0) - self.deep_supervision = deep_supervision - self.deep_supervision_heads = self.get_deep_supervision_heads() - self.deep_supr_num = deep_supr_num - self.apply(self.initialize_weights) - self.check_kernel_stride() - self.check_deep_supr_num() - - # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on - self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) - - def create_skips(index, downsamples, upsamples, superheads, bottleneck): - """ - Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is - done recursively from the top down since a recursive nn.Module subclass is being used to be compatible - with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads` - since the `input_block` is passed to this function as the first item in `downsamples`, however this - shouldn't be associated with a supervision head. - """ - - if len(downsamples) != len(upsamples): - raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") - if (len(downsamples) - len(superheads)) not in (1, 0): - raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}") - - if len(downsamples) == 0: # bottom of the network, pass the bottleneck block - return bottleneck - if index == 0: # don't associate a supervision head with self.input_block - current_head, rest_heads = nn.Identity(), superheads - elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - current_head, rest_heads = nn.Identity(), superheads[1:] - else: - current_head, rest_heads = superheads[0], superheads[1:] - - # create the next layer down, this will stop at the bottleneck layer - next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) - - return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) - - self.skip_layers = create_skips( - 0, - [self.input_block] + list(self.downsamples), - self.upsamples[::-1], - self.deep_supervision_heads, - self.bottleneck, - ) - - def get_upsamples(self): - inp, out = self.filters[1:][::-1], self.filters[:-1][::-1] - strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1] - upsample_kernel_size = self.upsample_kernel_size[::-1] - return self.get_module_list(inp, out, kernel_size, strides, _UnetUpBlockV1, upsample_kernel_size) - - @staticmethod - def initialize_weights(module): - name = module.__class__.__name__.lower() - if "conv3d" in name or "conv2d" in name: - nn.init.kaiming_normal_(module.weight, a=0.01) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - elif "norm" in name: - nn.init.normal_(module.weight, 1.0, 0.02) - nn.init.zeros_(module.bias) - - -DynUnetV1 = DynunetV1 = DynUNetV1 diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index 453916758a..fa5efbc4ef 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,14 @@ from monai.networks.layers.utils import get_norm_layer from monai.utils.module import look_up_option -__all__ = ["EfficientNet", "EfficientNetBN", "get_efficientnet_image_size", "drop_connect"] +__all__ = [ + "EfficientNet", + "EfficientNetBN", + "get_efficientnet_image_size", + "drop_connect", + "EfficientNetBNFeatures", + "BlockArgs", +] efficientnet_params = { # model_name: (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate) @@ -82,7 +89,7 @@ def __init__( Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. - out_classes: number of output channels. + out_channels: number of output channels. kernel_size: size of the kernel for conv ops. stride: stride to use for conv ops. image_size: input image resolution. @@ -369,10 +376,7 @@ def __init__( ) idx += 1 # increment blocks index counter - self._blocks.add_module( - str(stack_idx), - sub_stack, - ) + self._blocks.add_module(str(stack_idx), sub_stack) # sanity check to see if len(self._blocks) equal expected num_blocks if idx != num_blocks: @@ -534,7 +538,7 @@ def __init__( weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name] # create model and initialize random weights - super(EfficientNetBN, self).__init__( + super().__init__( blocks_args_str=blocks_args_str, spatial_dims=spatial_dims, in_channels=in_channels, @@ -594,7 +598,7 @@ def __init__( weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name] # create model and initialize random weights - super(EfficientNetBNFeatures, self).__init__( + super().__init__( blocks_args_str=blocks_args_str, spatial_dims=spatial_dims, in_channels=in_channels, @@ -669,7 +673,7 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor e.g. 1D activations [B, C, H], 2D activations [B, C, H, W] and 3D activations [B, C, H, W, D] Args: - input: input tensor with [B, C, dim_1, dim_2, ..., dim_N] where N=spatial_dims. + inputs: input tensor with [B, C, dim_1, dim_2, ..., dim_N] where N=spatial_dims. p: probability to use for dropping connections. training: whether in training or evaluation mode. @@ -677,7 +681,7 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor output: output tensor after applying drop connection. """ if p < 0.0 or p > 1.0: - raise ValueError("p must be in range of [0, 1], found {}".format(p)) + raise ValueError(f"p must be in range of [0, 1], found {p}") # eval mode: drop_connect is switched off - so return input without modifying if not training: @@ -708,7 +712,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool, adv_prop: bool arch = arch.split("efficientnet-")[-1] + "-ap" model_url = look_up_option(arch, url_map, None) if model_url is None: - print("pretrained weights of {} is not provided".format(arch)) + print(f"pretrained weights of {arch} is not provided") else: # load state dict from url model_url = url_map[arch] @@ -852,7 +856,7 @@ def _calculate_output_image_size(input_image_size: List[int], stride: Union[int, if isinstance(stride, tuple): all_strides_equal = all(stride[0] == s for s in stride) if not all_strides_equal: - raise ValueError("unequal strides are not possible, got {}".format(stride)) + raise ValueError(f"unequal strides are not possible, got {stride}") stride = stride[0] diff --git a/monai/networks/nets/fullyconnectednet.py b/monai/networks/nets/fullyconnectednet.py index b906bab015..810c07431b 100644 --- a/monai/networks/nets/fullyconnectednet.py +++ b/monai/networks/nets/fullyconnectednet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -30,9 +30,24 @@ def _get_adn_layer( class FullyConnectedNet(nn.Sequential): """ - Plain full-connected layer neural network + Simple full-connected layer neural network composed of a sequence of linear layers with PReLU activation and + dropout. The network accepts input with `in_channels` channels, has output with `out_channels` channels, and + hidden layer output channels given in `hidden_channels`. If `bias` is True then linear units have a bias term. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + hidden_channels: number of output channels for each hidden layer. + dropout: dropout ratio. Defaults to no dropout. + act: activation type and arguments. Defaults to PReLU. + bias: whether to have a bias term in linear units. Defaults to True. + adn_ordering: order of operations in :py:class:`monai.networks.blocks.ADN`. + + Examples:: + + # accepts 4 values and infers 3 values as output, has 3 hidden layers with 10, 20, 10 values as output + net = FullyConnectedNet(4, 3, [10, 20, 10], dropout=0.2) - The network uses dropout and, by default, PReLU activation """ def __init__( @@ -53,8 +68,11 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels self.hidden_channels = list(hidden_channels) + self.act = act + self.dropout = dropout + self.adn_ordering = adn_ordering + self.add_module("flatten", nn.Flatten()) - self.adn_layer = _get_adn_layer(act, dropout, adn_ordering) prev_channels = self.in_channels for i, c in enumerate(hidden_channels): @@ -64,13 +82,34 @@ def __init__( self.add_module("output", nn.Linear(prev_channels, out_channels, bias)) def _get_layer(self, in_channels: int, out_channels: int, bias: bool) -> nn.Sequential: - seq = nn.Sequential(nn.Linear(in_channels, out_channels, bias)) - seq.add_module("ADN", self.adn_layer) + seq = nn.Sequential( + nn.Linear(in_channels, out_channels, bias), _get_adn_layer(self.act, self.dropout, self.adn_ordering) + ) return seq class VarFullyConnectedNet(nn.Module): - """Variational fully-connected network.""" + """ + Variational fully-connected network. This is composed of an encode layer, reparameterization layer, and then a + decode layer. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + latent_size: number of latent variables to use. + encode_channels: number of output channels for each hidden layer of the encode half. + decode_channels: number of output channels for each hidden layer of the decode half. + dropout: dropout ratio. Defaults to no dropout. + act: activation type and arguments. Defaults to PReLU. + bias: whether to have a bias term in linear units. Defaults to True. + adn_ordering: order of operations in :py:class:`monai.networks.blocks.ADN`. + + Examples:: + + # accepts inputs with 4 values, uses a latent space of 2 variables, and produces outputs of 3 values + net = VarFullyConnectedNet(4, 3, 2, [5, 10], [10, 5]) + + """ def __init__( self, diff --git a/monai/networks/nets/generator.py b/monai/networks/nets/generator.py index 1f24944a63..a69cae4d7b 100644 --- a/monai/networks/nets/generator.py +++ b/monai/networks/nets/generator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,13 +25,35 @@ class Generator(nn.Module): """ Defines a simple generator network accepting a latent vector and through a sequence of convolution layers constructs an output tensor of greater size and high dimensionality. The method `_get_layer` is used to - create each of these layers, override this method to define layers beyond the default Convolution or - ResidualUnit layers. + create each of these layers, override this method to define layers beyond the default + :py:class:`monai.networks.blocks.Convolution` or :py:class:`monai.networks.blocks.ResidualUnit` layers. + + The layers are constructed using the values in the `channels` and `strides` arguments, the number being defined by + the length of these (which must match). Input is first passed through a :py:class:`torch.nn.Linear` layer to + convert the input vector to an image tensor with dimensions `start_shape`. This passes through the convolution + layers and is progressively upsampled if the `strides` values are greater than 1 using transpose convolutions. The + size of the final output is defined by the `start_shape` dimension and the amount of upsampling done through + strides. In the default definition the size of the output's spatial dimensions will be that of `start_shape` + multiplied by the product of `strides`, thus the example network below upsamples an starting size of (64, 8, 8) + to (1, 64, 64) since its `strides` are (2, 2, 2). + + Args: + latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension) + start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork + channels: tuple of integers stating the output channels of each convolutional layer + strides: tuple of integers stating the stride (upscale factor) of each convolutional layer + kernel_size: integer or tuple of integers stating size of convolutional kernels + num_res_units: integer stating number of convolutions in residual units, 0 means no residual units + act: name or type defining activation layers + norm: name or type defining normalization layers + dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout + bias: boolean stating if convolution layers should have a bias component + + Examples:: + + # 3 layers, latent input vector of shape (42, 24), output volume of shape (1, 64, 64) + net = Generator((42, 24), (64, 8, 8), (32, 16, 1), (2, 2, 2)) - For example, a generator accepting a latent vector if shape (42,24) and producing an output volume of - shape (1,64,64) can be constructed as: - - gen = Generator((42, 24), (64, 8, 8), (32, 16, 1), (2, 2, 2)) """ def __init__( @@ -47,26 +69,6 @@ def __init__( dropout: Optional[float] = None, bias: bool = True, ) -> None: - """ - Construct the generator network with the number of layers defined by `channels` and `strides`. In the - forward pass a `nn.Linear` layer relates the input latent vector to a tensor of dimensions `start_shape`, - this is then fed forward through the sequence of convolutional layers. The number of layers is defined by - the length of `channels` and `strides` which must match, each layer having the number of output channels - given in `channels` and an upsample factor given in `strides` (ie. a transpose convolution with that stride - size). - - Args: - latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension) - start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork - channels: tuple of integers stating the output channels of each convolutional layer - strides: tuple of integers stating the stride (upscale factor) of each convolutional layer - kernel_size: integer or tuple of integers stating size of convolutional kernels - num_res_units: integer stating number of convolutions in residual units, 0 means no residual units - act: name or type defining activation layers - norm: name or type defining normalization layers - dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout - bias: boolean stating if convolution layers should have a bias component - """ super().__init__() self.in_channels, *self.start_shape = ensure_tuple(start_shape) @@ -112,7 +114,7 @@ def _get_layer( strides=strides, is_transposed=True, conv_only=is_last or self.num_res_units > 0, - dimensions=self.dimensions, + spatial_dims=self.dimensions, out_channels=out_channels, kernel_size=self.kernel_size, act=self.act, @@ -126,7 +128,7 @@ def _get_layer( in_channels=out_channels, subunits=self.num_res_units, last_conv_only=is_last, - dimensions=self.dimensions, + spatial_dims=self.dimensions, out_channels=out_channels, kernel_size=self.kernel_size, act=self.act, diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index f644a7835a..891a65e67b 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -70,7 +70,7 @@ def __init__( ValueError: When ``channel_matching=pad`` and ``in_channels > out_channels``. Incompatible values. """ - super(HighResBlock, self).__init__() + super().__init__() self.chn_pad = ChannelPad( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, mode=channel_matching ) @@ -84,7 +84,7 @@ def __init__( ) layers.append( Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=_in_chns, out_channels=_out_chns, kernel_size=kernel_size, @@ -146,7 +146,7 @@ def __init__( channel_matching: Union[ChannelMatching, str] = ChannelMatching.PAD, ) -> None: - super(HighResNet, self).__init__() + super().__init__() blocks = nn.ModuleList() # initial conv layer @@ -154,7 +154,7 @@ def __init__( _in_chns, _out_chns = in_channels, params["n_features"] blocks.append( Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=_in_chns, out_channels=_out_chns, kernel_size=params["kernel_size"], @@ -168,7 +168,7 @@ def __init__( # residual blocks for (idx, params) in enumerate(layer_params[1:-2]): # res blocks except the 1st and last two conv layers. _in_chns, _out_chns = _out_chns, params["n_features"] - _dilation = 2 ** idx + _dilation = 2**idx for _ in range(params["repeat"]): blocks.append( HighResBlock( @@ -190,7 +190,7 @@ def __init__( _in_chns, _out_chns = _out_chns, params["n_features"] blocks.append( Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=_in_chns, out_channels=_out_chns, kernel_size=params["kernel_size"], @@ -206,7 +206,7 @@ def __init__( _in_chns = _out_chns blocks.append( Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=_in_chns, out_channels=out_channels, kernel_size=params["kernel_size"], diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py new file mode 100644 index 0000000000..2f4afaffbe --- /dev/null +++ b/monai/networks/nets/milmodel.py @@ -0,0 +1,244 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Union, cast + +import torch +import torch.nn as nn + +from monai.utils.module import optional_import + +models, _ = optional_import("torchvision.models") + + +class MILModel(nn.Module): + """ + Multiple Instance Learning (MIL) model, with a backbone classification model. + Currently, it only works for 2D images, a typical use case is for classification of the + digital pathology whole slide images. The expected shape of input data is `[B, N, C, H, W]`, + where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances + extracted from every original image in the batch. A tutorial example is available at: + https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning. + + Args: + num_classes: number of output classes. + mil_mode: MIL algorithm, available values (Defaults to ``"att"``): + + - ``"mean"`` - average features from all instances, equivalent to pure CNN (non MIL). + - ``"max"`` - retain only the instance with the max probability for loss calculation. + - ``"att"`` - attention based MIL https://arxiv.org/abs/1802.04712. + - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556. + - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556. + + pretrained: init backbone with pretrained weights, defaults to ``True``. + backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features, + or a string name of a torchvision model). + Defaults to ``None``, in which case ResNet50 is used. + backbone_num_features: Number of output features of the backbone CNN + Defaults to ``None`` (necessary only when using a custom backbone) + trans_blocks: number of the blocks in `TransformEncoder` layer. + trans_dropout: dropout rate in `TransformEncoder` layer. + + """ + + def __init__( + self, + num_classes: int, + mil_mode: str = "att", + pretrained: bool = True, + backbone: Optional[Union[str, nn.Module]] = None, + backbone_num_features: Optional[int] = None, + trans_blocks: int = 4, + trans_dropout: float = 0.0, + ) -> None: + + super().__init__() + + if num_classes <= 0: + raise ValueError("Number of classes must be positive: " + str(num_classes)) + + if mil_mode.lower() not in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]: + raise ValueError("Unsupported mil_mode: " + str(mil_mode)) + + self.mil_mode = mil_mode.lower() + self.attention = nn.Sequential() + self.transformer = None # type: Optional[nn.Module] + + if backbone is None: + + net = models.resnet50(pretrained=pretrained) + nfc = net.fc.in_features # save the number of final features + net.fc = torch.nn.Identity() # remove final linear layer + + self.extra_outputs = {} # type: Dict[str, torch.Tensor] + + if mil_mode == "att_trans_pyramid": + # register hooks to capture outputs of intermediate layers + def forward_hook(layer_name): + def hook(module, input, output): + self.extra_outputs[layer_name] = output + + return hook + + net.layer1.register_forward_hook(forward_hook("layer1")) + net.layer2.register_forward_hook(forward_hook("layer2")) + net.layer3.register_forward_hook(forward_hook("layer3")) + net.layer4.register_forward_hook(forward_hook("layer4")) + + elif isinstance(backbone, str): + + # assume torchvision model string is provided + torch_model = getattr(models, backbone, None) + if torch_model is None: + raise ValueError("Unknown torch vision model" + str(backbone)) + net = torch_model(pretrained=pretrained) + + if getattr(net, "fc", None) is not None: + nfc = net.fc.in_features # save the number of final features + net.fc = torch.nn.Identity() # remove final linear layer + else: + raise ValueError( + "Unable to detect FC layer for the torchvision model " + str(backbone), + ". Please initialize the backbone model manually.", + ) + + elif isinstance(backbone, nn.Module): + # use a custom backbone + net = backbone + nfc = backbone_num_features + + if backbone_num_features is None: + raise ValueError("Number of endencoder features must be provided for a custom backbone model") + + else: + raise ValueError("Unsupported backbone") + + if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]: + raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode)) + + if self.mil_mode in ["mean", "max"]: + pass + elif self.mil_mode == "att": + self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) + + elif self.mil_mode == "att_trans": + transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout) + self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks) + self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) + + elif self.mil_mode == "att_trans_pyramid": + + transformer_list = nn.ModuleList( + [ + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=trans_blocks + ), + nn.Sequential( + nn.Linear(768, 256), + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), + num_layers=trans_blocks, + ), + ), + nn.Sequential( + nn.Linear(1280, 256), + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), + num_layers=trans_blocks, + ), + ), + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout), + num_layers=trans_blocks, + ), + ] + ) + self.transformer = transformer_list + nfc = nfc + 256 + self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) + + else: + raise ValueError("Unsupported mil_mode: " + str(mil_mode)) + + self.myfc = nn.Linear(nfc, num_classes) + self.net = net + + def calc_head(self, x: torch.Tensor) -> torch.Tensor: + + sh = x.shape + + if self.mil_mode == "mean": + x = self.myfc(x) + x = torch.mean(x, dim=1) + + elif self.mil_mode == "max": + x = self.myfc(x) + x, _ = torch.max(x, dim=1) + + elif self.mil_mode == "att": + + a = self.attention(x) + a = torch.softmax(a, dim=1) + x = torch.sum(x * a, dim=1) + + x = self.myfc(x) + + elif self.mil_mode == "att_trans" and self.transformer is not None: + + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + + a = self.attention(x) + a = torch.softmax(a, dim=1) + x = torch.sum(x * a, dim=1) + + x = self.myfc(x) + + elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None: + + l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + + transformer_list = cast(nn.ModuleList, self.transformer) + + x = transformer_list[0](l1) + x = transformer_list[1](torch.cat((x, l2), dim=2)) + x = transformer_list[2](torch.cat((x, l3), dim=2)) + x = transformer_list[3](torch.cat((x, l4), dim=2)) + + x = x.permute(1, 0, 2) + + a = self.attention(x) + a = torch.softmax(a, dim=1) + x = torch.sum(x * a, dim=1) + + x = self.myfc(x) + + else: + raise ValueError("Wrong model mode" + str(self.mil_mode)) + + return x + + def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor: + + sh = x.shape + x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4]) + + x = self.net(x) + x = x.reshape(sh[0], sh[1], -1) + + if not no_head: + x = self.calc_head(x) + + return x diff --git a/monai/networks/nets/netadapter.py b/monai/networks/nets/netadapter.py index 80288f7945..425c1d5820 100644 --- a/monai/networks/nets/netadapter.py +++ b/monai/networks/nets/netadapter.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,11 +24,12 @@ class NetAdapter(torch.nn.Module): then replace the model's last two layers with an optional `pooling` and a `conv` or `linear` layer. Args: - model: a PyTorch model, support both 2D and 3D models. typically, it can be a pretrained model in Torchvision, - like: ``resnet18``, ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, etc. + model: a PyTorch model, which can be both 2D and 3D models. typically, it can be a pretrained model + in Torchvision, like: ``resnet18``, ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, etc. more details: https://pytorch.org/vision/stable/models.html. num_classes: number of classes for the last classification layer. Default to 1. - dim: number of spatial dimensions, default to 2. + dim: number of supported spatial dimensions in the specified model, depends on the model implementation. + default to 2 as most Torchvision models are for 2D image processing. in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer. use_conv: whether use convolutional layer to replace the last layer, default to False. pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer, @@ -37,6 +38,9 @@ class NetAdapter(torch.nn.Module): bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias, default to True. + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``num_classes`` instead. + """ @deprecated_arg("n_classes", since="0.6") @@ -67,32 +71,23 @@ def __init__( in_channels_ = in_channels if pool is None: - self.pool = None # remove the last layer self.features = torch.nn.Sequential(*layers[:-1]) + self.pool = None else: - self.pool = get_pool_layer(name=pool, spatial_dims=dim) # remove the last 2 layers self.features = torch.nn.Sequential(*layers[:-2]) + self.pool = get_pool_layer(name=pool, spatial_dims=dim) self.fc: Union[torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv3d] if use_conv: # add 1x1 conv (it behaves like a FC layer) - self.fc = Conv[Conv.CONV, dim]( - in_channels=in_channels_, - out_channels=num_classes, - kernel_size=1, - bias=bias, - ) + self.fc = Conv[Conv.CONV, dim](in_channels=in_channels_, out_channels=num_classes, kernel_size=1, bias=bias) else: # remove the last Linear layer (fully connected) self.features = torch.nn.Sequential(*layers[:-1]) # replace the out_features of FC layer - self.fc = torch.nn.Linear( - in_features=in_channels_, - out_features=num_classes, - bias=bias, - ) + self.fc = torch.nn.Linear(in_features=in_channels_, out_features=num_classes, bias=bias) self.use_conv = use_conv def forward(self, x): diff --git a/monai/networks/nets/regressor.py b/monai/networks/nets/regressor.py index 25acb9bfa5..0a1e6258a9 100644 --- a/monai/networks/nets/regressor.py +++ b/monai/networks/nets/regressor.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,6 +29,30 @@ class Regressor(nn.Module): This defines a network for relating large-sized input tensors to small output tensors, ie. regressing large values to a prediction. An output of a single dimension can be used as value regression or multi-label classification prediction, an output of a single value can be used as a discriminator or critic prediction. + + The network is constructed as a sequence of layers, either :py:class:`monai.networks.blocks.Convolution` or + :py:class:`monai.networks.blocks.ResidualUnit`, with a final fully-connected layer resizing the output from the + blocks to the final size. Each block is defined with a stride value typically used to downsample the input using + strided convolutions. In this way each block progressively condenses information from the input into a deep + representation the final fully-connected layer relates to a final result. + + Args: + in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) + out_shape: tuple of integers stating the dimension of the final output tensor (minus batch dimension) + channels: tuple of integers stating the output channels of each convolutional layer + strides: tuple of integers stating the stride (downscale factor) of each convolutional layer + kernel_size: integer or tuple of integers stating size of convolutional kernels + num_res_units: integer stating number of convolutions in residual units, 0 means no residual units + act: name or type defining activation layers + norm: name or type defining normalization layers + dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout + bias: boolean stating if convolution layers should have a bias component + + Examples:: + + # infers a 2-value result (eg. a 2D cartesian coordinate) from a 64x64 image + net = Regressor((1, 64, 64), (2,), (2, 4, 8), (2, 2, 2)) + """ def __init__( @@ -44,23 +68,6 @@ def __init__( dropout: Optional[float] = None, bias: bool = True, ) -> None: - """ - Construct the regressor network with the number of layers defined by `channels` and `strides`. Inputs are - first passed through the convolutional layers in the forward pass, the output from this is then pass - through a fully connected layer to relate them to the final output tensor. - - Args: - in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) - out_shape: tuple of integers stating the dimension of the final output tensor - channels: tuple of integers stating the output channels of each convolutional layer - strides: tuple of integers stating the stride (downscale factor) of each convolutional layer - kernel_size: integer or tuple of integers stating size of convolutional kernels - num_res_units: integer stating number of convolutions in residual units, 0 means no residual units - act: name or type defining activation layers - norm: name or type defining normalization layers - dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout - bias: boolean stating if convolution layers should have a bias component - """ super().__init__() self.in_channels, *self.in_shape = ensure_tuple(in_shape) @@ -107,7 +114,7 @@ def _get_layer( layer = ResidualUnit( subunits=self.num_res_units, last_conv_only=is_last, - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=in_channels, out_channels=out_channels, strides=strides, @@ -120,7 +127,7 @@ def _get_layer( else: layer = Convolution( conv_only=is_last, - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=in_channels, out_channels=out_channels, strides=strides, diff --git a/monai/networks/nets/regunet.py b/monai/networks/nets/regunet.py index 4cf747f650..6776c7ce9e 100644 --- a/monai/networks/nets/regunet.py +++ b/monai/networks/nets/regunet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,6 +21,7 @@ get_conv_block, get_deconv_block, ) +from monai.networks.utils import meshgrid_ij __all__ = ["RegUNet", "AffineHead", "GlobalNet", "LocalNet"] @@ -67,7 +68,7 @@ def __init__( concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition encode_kernel_sizes: kernel size for down-sampling """ - super(RegUNet, self).__init__() + super().__init__() if not extract_levels: extract_levels = (depth,) if max(extract_levels) != depth: @@ -91,7 +92,7 @@ def __init__( raise AssertionError self.encode_kernel_sizes: List[int] = encode_kernel_sizes - self.num_channels = [self.num_channel_initial * (2 ** d) for d in range(self.depth + 1)] + self.num_channels = [self.num_channel_initial * (2**d) for d in range(self.depth + 1)] self.min_extract_level = min(self.extract_levels) # init layers @@ -106,9 +107,7 @@ def __init__( # build layers self.build_layers() - def build_layers( - self, - ): + def build_layers(self): self.build_encode_layers() self.build_decode_layers() @@ -125,23 +124,13 @@ def build_encode_layers(self): ] ) self.encode_pools = nn.ModuleList( - [ - self.build_down_sampling_block( - channels=self.num_channels[d], - ) - for d in range(self.depth) - ] + [self.build_down_sampling_block(channels=self.num_channels[d]) for d in range(self.depth)] ) self.bottom_block = self.build_bottom_block( in_channels=self.num_channels[-2], out_channels=self.num_channels[-1] ) - def build_conv_block( - self, - in_channels, - out_channels, - kernel_size, - ): + def build_conv_block(self, in_channels, out_channels, kernel_size): return nn.Sequential( get_conv_block( spatial_dims=self.spatial_dims, @@ -157,10 +146,7 @@ def build_conv_block( ), ) - def build_down_sampling_block( - self, - channels: int, - ): + def build_down_sampling_block(self, channels: int): return RegistrationDownSampleBlock(spatial_dims=self.spatial_dims, channels=channels, pooling=self.pooling) def build_bottom_block(self, in_channels: int, out_channels: int): @@ -203,11 +189,7 @@ def build_decode_layers(self): # extraction self.output_block = self.build_output_block() - def build_up_sampling_block( - self, - in_channels: int, - out_channels: int, - ) -> nn.Module: + def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module: return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels) def build_output_block(self) -> nn.Module: @@ -255,14 +237,8 @@ def forward(self, x): class AffineHead(nn.Module): - def __init__( - self, - spatial_dims: int, - image_size: List[int], - decode_size: List[int], - in_channels: int, - ): - super(AffineHead, self).__init__() + def __init__(self, spatial_dims: int, image_size: List[int], decode_size: List[int], in_channels: int): + super().__init__() self.spatial_dims = spatial_dims if spatial_dims == 2: in_features = in_channels * decode_size[0] * decode_size[1] @@ -285,7 +261,7 @@ def __init__( @staticmethod def get_reference_grid(image_size: Union[Tuple[int], List[int]]) -> torch.Tensor: mesh_points = [torch.arange(0, dim) for dim in image_size] - grid = torch.stack(torch.meshgrid(*mesh_points), dim=0) # (spatial_dims, ...) + grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...) return grid.to(dtype=torch.float) def affine_transform(self, theta: torch.Tensor): @@ -334,14 +310,14 @@ def __init__( encode_kernel_sizes: Union[int, List[int]] = 3, ): for size in image_size: - if size % (2 ** depth) != 0: + if size % (2**depth) != 0: raise ValueError( f"given depth {depth}, " f"all input spatial dimension must be divisible by {2 ** depth}, " f"got input of size {image_size}" ) self.image_size = image_size - self.decode_size = [size // (2 ** depth) for size in image_size] + self.decode_size = [size // (2**depth) for size in image_size] super().__init__( spatial_dims=spatial_dims, in_channels=in_channels, @@ -365,13 +341,8 @@ def build_output_block(self): class AdditiveUpSampleBlock(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - ): - super(AdditiveUpSampleBlock, self).__init__() + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int): + super().__init__() self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -435,17 +406,10 @@ def __init__( def build_bottom_block(self, in_channels: int, out_channels: int): kernel_size = self.encode_kernel_sizes[self.depth] return get_conv_block( - spatial_dims=self.spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, + spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size ) - def build_up_sampling_block( - self, - in_channels: int, - out_channels: int, - ) -> nn.Module: + def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module: if self._use_additive_upsampling: return AdditiveUpSampleBlock( spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index a5e6b7ab81..c8be9f0e89 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,13 +10,15 @@ # limitations under the License. from functools import partial -from typing import Any, Callable, List, Optional, Type, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import torch import torch.nn as nn -import torch.nn.functional as F from monai.networks.layers.factories import Conv, Norm, Pool +from monai.networks.layers.utils import get_pool_layer +from monai.utils import ensure_tuple_rep +from monai.utils.module import look_up_option __all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] @@ -28,15 +30,7 @@ def get_inplanes(): def get_avgpool(): - return [(0), (1), (1, 1), (1, 1, 1)] - - -def get_conv1(conv1_t_size: int, conv1_t_stride: int): - return ( - [(0), (conv1_t_size), (conv1_t_size, 7), (conv1_t_size, 7, 7)], - [(0), (conv1_t_stride), (conv1_t_stride, 2), (conv1_t_stride, 2, 2)], - [(0), (conv1_t_size // 2), (conv1_t_size // 2, 3), (conv1_t_size // 2, 3, 3)], - ) + return [0, 1, (1, 1), (1, 1, 1)] class ResNetBlock(nn.Module): @@ -58,7 +52,7 @@ def __init__( stride: stride to use for first conv layer. downsample: which downsample layer to use. """ - super(ResNetBlock, self).__init__() + super().__init__() conv_type: Callable = Conv[Conv.CONV, spatial_dims] norm_type: Callable = Norm[Norm.BATCH, spatial_dims] @@ -110,7 +104,7 @@ def __init__( downsample: which downsample layer to use. """ - super(ResNetBottleneck, self).__init__() + super().__init__() conv_type: Callable = Conv[Conv.CONV, spatial_dims] norm_type: Callable = Norm[Norm.BATCH, spatial_dims] @@ -153,6 +147,7 @@ class ResNet(nn.Module): ResNet based on: `Deep Residual Learning for Image Recognition `_ and `Can Spatiotemporal 3D CNNs Retrace the History of 2D CNNs and ImageNet? `_. Adapted from ``_. + Args: block: which ResNet block to use, either Basic or Bottleneck. layers: how many layers to use. @@ -162,9 +157,16 @@ class ResNet(nn.Module): conv1_t_size: size of first convolution layer, determines kernel and padding. conv1_t_stride: stride of first convolution layer. no_max_pool: bool argument to determine if to use maxpool layer. - shortcut_type: which downsample block to use. + shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'. + - 'A': using `self._downsample_basic_block`. + - 'B': kernel_size 1 conv + norm. widen_factor: widen output for each layer. - num_classes: number of output (classifications) + num_classes: number of output (classifications). + feed_forward: whether to add the FC layer for the output, default to `True`. + + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``num_classes`` instead. + """ @deprecated_arg("n_classes", since="0.6") @@ -175,8 +177,8 @@ def __init__( block_inplanes: List[int], spatial_dims: int = 3, n_input_channels: int = 3, - conv1_t_size: int = 7, - conv1_t_stride: int = 1, + conv1_t_size: Union[Tuple[int], int] = 7, + conv1_t_stride: Union[Tuple[int], int] = 1, no_max_pool: bool = False, shortcut_type: str = "B", widen_factor: float = 1.0, @@ -185,7 +187,7 @@ def __init__( n_classes: Optional[int] = None, ) -> None: - super(ResNet, self).__init__() + super().__init__() # in case the new num_classes is default but you still call deprecated n_classes if n_classes is not None and num_classes == 400: num_classes = n_classes @@ -198,18 +200,20 @@ def __init__( ] block_avgpool = get_avgpool() - conv1_kernel, conv1_stride, con1_padding = get_conv1(conv1_t_size, conv1_t_stride) block_inplanes = [int(x * widen_factor) for x in block_inplanes] self.in_planes = block_inplanes[0] self.no_max_pool = no_max_pool + conv1_kernel_size = ensure_tuple_rep(conv1_t_size, spatial_dims) + conv1_stride = ensure_tuple_rep(conv1_t_stride, spatial_dims) + self.conv1 = conv_type( n_input_channels, self.in_planes, - kernel_size=conv1_kernel[spatial_dims], - stride=conv1_stride[spatial_dims], - padding=con1_padding[spatial_dims], + kernel_size=conv1_kernel_size, # type: ignore + stride=conv1_stride, # type: ignore + padding=tuple(k // 2 for k in conv1_kernel_size), # type: ignore bias=False, ) self.bn1 = norm_type(self.in_planes) @@ -220,9 +224,7 @@ def __init__( self.layer3 = self._make_layer(block, block_inplanes[2], layers[2], spatial_dims, shortcut_type, stride=2) self.layer4 = self._make_layer(block, block_inplanes[3], layers[3], spatial_dims, shortcut_type, stride=2) self.avgpool = avgp_type(block_avgpool[spatial_dims]) - - if feed_forward: - self.fc = nn.Linear(block_inplanes[3] * block.expansion, num_classes) + self.fc = nn.Linear(block_inplanes[3] * block.expansion, num_classes) if feed_forward else None for m in self.modules(): if isinstance(m, conv_type): @@ -234,14 +236,9 @@ def __init__( nn.init.constant_(torch.as_tensor(m.bias), 0) def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spatial_dims: int = 3) -> torch.Tensor: - assert spatial_dims == 3 - out: torch.Tensor = F.avg_pool3d(x, kernel_size=1, stride=stride) - zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)) - if isinstance(out.data, torch.FloatTensor): - zero_pads = zero_pads.cuda() - + out: torch.Tensor = get_pool_layer(("avg", {"kernel_size": 1, "stride": stride}), spatial_dims=spatial_dims)(x) + zero_pads = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype, device=out.device) out = torch.cat([out.data, zero_pads], dim=1) - return out def _make_layer( @@ -259,9 +256,12 @@ def _make_layer( downsample: Union[nn.Module, partial, None] = None if stride != 1 or self.in_planes != planes * block.expansion: - if shortcut_type == "A": + if look_up_option(shortcut_type, {"A", "B"}) == "A": downsample = partial( - self._downsample_basic_block, planes=planes * block.expansion, kernel_size=1, stride=stride + self._downsample_basic_block, + planes=planes * block.expansion, + stride=stride, + spatial_dims=spatial_dims, ) else: downsample = nn.Sequential( @@ -269,12 +269,12 @@ def _make_layer( norm_type(planes * block.expansion), ) - layers = [] - layers.append( + layers = [ block( in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample ) - ) + ] + self.in_planes = planes * block.expansion for _i in range(1, blocks): layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims)) @@ -296,7 +296,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.avgpool(x) x = x.view(x.size(0), -1) - x = self.fc(x) + if self.fc is not None: + x = self.fc(x) return x diff --git a/monai/networks/nets/segresnet.py b/monai/networks/nets/segresnet.py index 8be562aadd..299f1ca811 100644 --- a/monai/networks/nets/segresnet.py +++ b/monai/networks/nets/segresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -73,7 +73,7 @@ def __init__( super().__init__() if spatial_dims not in (2, 3): - raise AssertionError("spatial_dims can only be 2 or 3.") + raise ValueError("`spatial_dims` can only be 2 or 3.") self.spatial_dims = spatial_dims self.init_filters = init_filters @@ -81,7 +81,8 @@ def __init__( self.blocks_down = blocks_down self.blocks_up = blocks_up self.dropout_prob = dropout_prob - self.act = get_act_layer(act) + self.act = act # input options + self.act_mod = get_act_layer(act) if norm_name: if norm_name.lower() != "group": raise ValueError(f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead.") @@ -99,14 +100,9 @@ def __init__( def _make_down_layers(self): down_layers = nn.ModuleList() - blocks_down, spatial_dims, filters, norm = ( - self.blocks_down, - self.spatial_dims, - self.init_filters, - self.norm, - ) + blocks_down, spatial_dims, filters, norm = (self.blocks_down, self.spatial_dims, self.init_filters, self.norm) for i in range(len(blocks_down)): - layer_in_channels = filters * 2 ** i + layer_in_channels = filters * 2**i pre_conv = ( get_conv_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2) if i > 0 @@ -114,7 +110,7 @@ def _make_down_layers(self): ) down_layer = nn.Sequential( pre_conv, - *[ResBlock(spatial_dims, layer_in_channels, norm=norm) for _ in range(blocks_down[i])], + *[ResBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act) for _ in range(blocks_down[i])], ) down_layers.append(down_layer) return down_layers @@ -133,7 +129,10 @@ def _make_up_layers(self): sample_in_channels = filters * 2 ** (n_up - i) up_layers.append( nn.Sequential( - *[ResBlock(spatial_dims, sample_in_channels // 2, norm=norm) for _ in range(blocks_up[i])] + *[ + ResBlock(spatial_dims, sample_in_channels // 2, norm=norm, act=self.act) + for _ in range(blocks_up[i]) + ] ) ) up_samples.append( @@ -149,11 +148,11 @@ def _make_up_layers(self): def _make_final_conv(self, out_channels: int): return nn.Sequential( get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters), - self.act, + self.act_mod, get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True), ) - def forward(self, x): + def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: x = self.convInit(x) if self.dropout_prob is not None: x = self.dropout(x) @@ -164,14 +163,23 @@ def forward(self, x): x = down(x) down_x.append(x) - down_x.reverse() + return x, down_x + def decode(self, x: torch.Tensor, down_x: List[torch.Tensor]) -> torch.Tensor: for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)): x = up(x) + down_x[i + 1] x = upl(x) if self.use_conv_final: x = self.conv_final(x) + + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, down_x = self.encode(x) + down_x.reverse() + + x = self.decode(x, down_x) return x @@ -226,12 +234,13 @@ def __init__( blocks_up: tuple = (1, 1, 1), upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE, ): - super(SegResNetVAE, self).__init__( + super().__init__( spatial_dims=spatial_dims, init_filters=init_filters, in_channels=in_channels, out_channels=out_channels, dropout_prob=dropout_prob, + act=act, norm=norm, use_conv_final=use_conv_final, blocks_down=blocks_down, @@ -258,10 +267,10 @@ def _prepare_vae_modules(self): self.vae_down = nn.Sequential( get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=v_filters), - self.act, + self.act_mod, get_conv_layer(self.spatial_dims, v_filters, self.smallest_filters, stride=2, bias=True), get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.smallest_filters), - self.act, + self.act_mod, ) self.vae_fc1 = nn.Linear(total_elements, self.vae_nz) self.vae_fc2 = nn.Linear(total_elements, self.vae_nz) @@ -271,7 +280,7 @@ def _prepare_vae_modules(self): get_conv_layer(self.spatial_dims, self.smallest_filters, v_filters, kernel_size=1), get_upsample_layer(self.spatial_dims, v_filters, upsample_mode=self.upsample_mode), get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=v_filters), - self.act, + self.act_mod, ) def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): @@ -290,17 +299,17 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): if self.vae_estimate_std: z_sigma = self.vae_fc2(x_vae) z_sigma = F.softplus(z_sigma) - vae_reg_loss = 0.5 * torch.mean(z_mean ** 2 + z_sigma ** 2 - torch.log(1e-8 + z_sigma ** 2) - 1) + vae_reg_loss = 0.5 * torch.mean(z_mean**2 + z_sigma**2 - torch.log(1e-8 + z_sigma**2) - 1) x_vae = z_mean + z_sigma * z_mean_rand else: z_sigma = self.vae_default_std - vae_reg_loss = torch.mean(z_mean ** 2) + vae_reg_loss = torch.mean(z_mean**2) x_vae = z_mean + z_sigma * z_mean_rand x_vae = self.vae_fc3(x_vae) - x_vae = self.act(x_vae) + x_vae = self.act_mod(x_vae) x_vae = x_vae.view([-1, self.smallest_filters] + self.fc_insize) x_vae = self.vae_fc_up_sample(x_vae) @@ -315,25 +324,11 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): def forward(self, x): net_input = x - x = self.convInit(x) - if self.dropout_prob is not None: - x = self.dropout(x) - - down_x = [] - for down in self.down_layers: - x = down(x) - down_x.append(x) - + x, down_x = self.encode(x) down_x.reverse() vae_input = x - - for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)): - x = up(x) + down_x[i + 1] - x = upl(x) - - if self.use_conv_final: - x = self.conv_final(x) + x = self.decode(x, down_x) if self.training: vae_loss = self._get_vae_loss(net_input, vae_input) diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index 9b7035c259..a85d32ba5a 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,12 +17,32 @@ import torch.nn as nn from torch.hub import load_state_dict_from_url +from monai.apps.utils import download_url from monai.networks.blocks.convolutions import Convolution from monai.networks.blocks.squeeze_and_excitation import SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck from monai.networks.layers.factories import Act, Conv, Dropout, Norm, Pool from monai.utils.module import look_up_option -__all__ = ["SENet", "SENet154", "SEResNet50", "SEResNet101", "SEResNet152", "SEResNeXt50", "SEResNext101"] +__all__ = [ + "SENet", + "SENet154", + "SEResNet50", + "SEResNet101", + "SEResNet152", + "SEResNeXt50", + "SEResNext101", + "SE_NET_MODELS", +] + + +SE_NET_MODELS = { + "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", + "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", + "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", + "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", + "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", + "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", +} class SENet(nn.Module): @@ -87,7 +107,7 @@ def __init__( num_classes: int = 1000, ) -> None: - super(SENet, self).__init__() + super().__init__() relu_type: Type[nn.ReLU] = Act[Act.RELU] conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] @@ -192,7 +212,7 @@ def _make_layer( downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = Convolution( - dimensions=self.spatial_dims, + spatial_dims=self.spatial_dims, in_channels=self.inplanes, out_channels=planes * block.expansion, strides=stride, @@ -254,15 +274,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool): """ This function is used to load pretrained models. """ - model_urls = { - "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", - "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", - "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", - "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", - "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", - "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", - } - model_url = look_up_option(arch, model_urls, None) + model_url = look_up_option(arch, SE_NET_MODELS, None) if model_url is None: raise ValueError( "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', " @@ -276,7 +288,11 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool): pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") - state_dict = load_state_dict_from_url(model_url, progress=progress) + if isinstance(model_url, dict): + download_url(model_url["url"], filepath=model_url["filename"]) + state_dict = torch.load(model_url["filename"], map_location=None) + else: + state_dict = load_state_dict_from_url(model_url, progress=progress) for key in list(state_dict.keys()): new_key = None if pattern_conv.match(key): @@ -317,13 +333,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SENet154, self).__init__( - block=SEBottleneck, - layers=layers, - groups=groups, - reduction=reduction, - **kwargs, - ) + super().__init__(block=SEBottleneck, layers=layers, groups=groups, reduction=reduction, **kwargs) if pretrained: # it only worked when `spatial_dims` is 2 _load_state_dict(self, "senet154", progress) @@ -345,7 +355,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SEResNet50, self).__init__( + super().__init__( block=SEResNetBottleneck, layers=layers, groups=groups, @@ -378,7 +388,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SEResNet101, self).__init__( + super().__init__( block=SEResNetBottleneck, layers=layers, groups=groups, @@ -410,7 +420,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SEResNet152, self).__init__( + super().__init__( block=SEResNetBottleneck, layers=layers, groups=groups, @@ -443,7 +453,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SEResNext50, self).__init__( + super().__init__( block=SEResNeXtBottleneck, layers=layers, groups=groups, @@ -477,7 +487,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SEResNext101, self).__init__( + super().__init__( block=SEResNeXtBottleneck, layers=layers, groups=groups, @@ -493,7 +503,7 @@ def __init__( _load_state_dict(self, "se_resnext101_32x4d", progress) -SEnet = Senet = senet = SENet +SEnet = Senet = SENet SEnet154 = Senet154 = senet154 = SENet154 SEresnet50 = Seresnet50 = seresnet50 = SEResNet50 SEresnet101 = Seresnet101 = seresnet101 = SEResNet101 diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index 1619f877e7..e93019d050 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,11 +26,12 @@ class TorchVisionFCModel(NetAdapter): Args: model_name: name of any torchvision model with fully connected layer at the end. - ``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, + ``resnet18`` (default), ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``. model details: https://pytorch.org/vision/stable/models.html. num_classes: number of classes for the last classification layer. Default to 1. - dim: number of spatial dimensions, default to 2. + dim: number of supported spatial dimensions in the specified model, depends on the model implementation. + default to 2 as most Torchvision models are for 2D image processing. in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer. use_conv: whether use convolutional layer to replace the last layer, default to False. pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer, @@ -73,14 +74,14 @@ def __init__( ) -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="Please consider using `TorchVisionFCModel` instead.") +@deprecated(since="0.6.0", removed="0.9.0", msg_suffix="Please consider using `TorchVisionFCModel` instead.") class TorchVisionFullyConvModel(TorchVisionFCModel): """ Customize TorchVision models to replace fully connected layer by convolutional layer. Args: model_name: name of any torchvision with adaptive avg pooling and fully connected layer at the end. - ``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, + ``resnet18`` (default), ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``. num_classes: number of classes for the last classification layer. Default to 1. pool_size: the kernel size for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to (7, 7). diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py new file mode 100644 index 0000000000..b03ff5a17d --- /dev/null +++ b/monai/networks/nets/transchex.py @@ -0,0 +1,378 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import shutil +import tarfile +import tempfile +from typing import Sequence, Tuple, Union + +import torch +from torch import nn + +from monai.utils import optional_import + +transformers = optional_import("transformers") +load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") +cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] +BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] + +__all__ = ["BertPreTrainedModel", "BertAttention", "BertOutput", "BertMixedLayer", "Pooler", "MultiModal", "Transchex"] + + +class BertPreTrainedModel(nn.Module): + """Module to load BERT pre-trained weights. + Based on: + LXMERT + https://github.com/airsplay/lxmert + BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, *inputs, **kwargs) -> None: + super().__init__() + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, torch.nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained( + cls, + num_language_layers, + num_vision_layers, + num_mixed_layers, + bert_config, + state_dict=None, + cache_dir=None, + from_tf=False, + *inputs, + **kwargs, + ): + archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + tempdir = tempfile.mkdtemp() + with tarfile.open(resolved_archive_file, "r:gz") as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, "pytorch_model.bin") + state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) + if tempdir: + shutil.rmtree(tempdir) + if from_tf: + weights_path = os.path.join(serialization_dir, "model.ckpt") + return load_tf_weights_in_bert(model, weights_path) + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + start_prefix = "" + if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()): + start_prefix = "bert." + load(model, prefix=start_prefix) + return model + + +class BertAttention(nn.Module): + """BERT attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, context): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(context) + mixed_value_layer = self.value(context) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.dropout(nn.Softmax(dim=-1)(attention_scores)) + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertOutput(nn.Module): + """BERT output layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertMixedLayer(nn.Module): + """BERT cross attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super().__init__() + self.att_x = BertAttention(config) + self.output_x = BertOutput(config) + self.att_y = BertAttention(config) + self.output_y = BertOutput(config) + + def forward(self, x, y): + output_x = self.att_x(x, y) + output_y = self.att_y(y, x) + return self.output_x(output_x, x), self.output_y(output_y, y) + + +class Pooler(nn.Module): + """BERT pooler layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, hidden_size) -> None: + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class MultiModal(BertPreTrainedModel): + """ + Multimodal Transformers From Pretrained BERT Weights" + """ + + def __init__( + self, num_language_layers: int, num_vision_layers: int, num_mixed_layers: int, bert_config: dict + ) -> None: + """ + Args: + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + bert_config: configuration for bert language transformer encoder. + + """ + super().__init__() + self.config = type("obj", (object,), bert_config) + self.embeddings = BertEmbeddings(self.config) + self.language_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_language_layers)]) + self.vision_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_vision_layers)]) + self.mixed_encoder = nn.ModuleList([BertMixedLayer(self.config) for _ in range(num_mixed_layers)]) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None): + language_features = self.embeddings(input_ids, token_type_ids) + for layer in self.vision_encoder: + vision_feats = layer(vision_feats, None)[0] + for layer in self.language_encoder: + language_features = layer(language_features, attention_mask)[0] + for layer in self.mixed_encoder: + language_features, vision_feats = layer(language_features, vision_feats) + return language_features, vision_feats + + +class Transchex(torch.nn.Module): + """ + TransChex based on: "Hatamizadeh et al.,TransCheX: Self-Supervised Pretraining of Vision-Language + Transformers for Chest X-ray Analysis" + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], + patch_size: Union[int, Tuple[int, int]], + num_classes: int, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + hidden_size: int = 768, + drop_out: float = 0.0, + attention_probs_dropout_prob: float = 0.1, + gradient_checkpointing: bool = False, + hidden_act: str = "gelu", + hidden_dropout_prob: float = 0.1, + initializer_range: float = 0.02, + intermediate_size: int = 3072, + layer_norm_eps: float = 1e-12, + max_position_embeddings: int = 512, + model_type: str = "bert", + num_attention_heads: int = 12, + num_hidden_layers: int = 12, + pad_token_id: int = 0, + position_embedding_type: str = "absolute", + transformers_version: str = "4.10.2", + type_vocab_size: int = 2, + use_cache: bool = True, + vocab_size: int = 30522, + chunk_size_feed_forward: int = 0, + is_decoder: bool = False, + add_cross_attention: bool = False, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + img_size: dimension of input image. + patch_size: dimension of patch size. + num_classes: number of classes if classification is used. + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + num_mixed_layers: number of mixed transformer layers. + drop_out: faction of the input units to drop. + + The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`. + + Examples: + + .. code-block:: python + + # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, + # 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head + net = Transchex(in_channels=3, + img_size=(224, 224), + num_classes=3, + num_language_layers=2, + num_vision_layers=2, + num_mixed_layers=2, + drop_out=0.2) + + """ + super().__init__() + bert_config = { + "attention_probs_dropout_prob": attention_probs_dropout_prob, + "classifier_dropout": None, + "gradient_checkpointing": gradient_checkpointing, + "hidden_act": hidden_act, + "hidden_dropout_prob": hidden_dropout_prob, + "hidden_size": hidden_size, + "initializer_range": initializer_range, + "intermediate_size": intermediate_size, + "layer_norm_eps": layer_norm_eps, + "max_position_embeddings": max_position_embeddings, + "model_type": model_type, + "num_attention_heads": num_attention_heads, + "num_hidden_layers": num_hidden_layers, + "pad_token_id": pad_token_id, + "position_embedding_type": position_embedding_type, + "transformers_version": transformers_version, + "type_vocab_size": type_vocab_size, + "use_cache": use_cache, + "vocab_size": vocab_size, + "chunk_size_feed_forward": chunk_size_feed_forward, + "is_decoder": is_decoder, + "add_cross_attention": add_cross_attention, + } + if not (0 <= drop_out <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): # type: ignore + raise ValueError("img_size should be divisible by patch_size.") + + self.multimodal = MultiModal.from_pretrained( + num_language_layers=num_language_layers, + num_vision_layers=num_vision_layers, + num_mixed_layers=num_mixed_layers, + bert_config=bert_config, + ) + + self.patch_size = patch_size + self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore + self.vision_proj = nn.Conv2d( + in_channels=in_channels, out_channels=hidden_size, kernel_size=self.patch_size, stride=self.patch_size + ) + self.norm_vision_pos = nn.LayerNorm(hidden_size) + self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size)) + self.pooler = Pooler(hidden_size=hidden_size) + self.drop = torch.nn.Dropout(drop_out) + self.cls_head = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None): + attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) + attention_mask = (1.0 - attention_mask) * -10000.0 + vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2) + vision_feats = self.norm_vision_pos(vision_feats) + vision_feats = vision_feats + self.pos_embed_vis + hidden_state_lang, hidden_state_vis = self.multimodal( + input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask + ) + pooled_features = self.pooler(hidden_state_lang) + logits = self.cls_head(self.drop(pooled_features)) + return logits diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 70cc816fe9..25ce61ab3a 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Sequence, Tuple, Union +from typing import Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -18,17 +18,99 @@ from monai.networks.blocks.convolutions import Convolution, ResidualUnit from monai.networks.layers.factories import Act, Norm from monai.networks.layers.simplelayers import SkipConnection -from monai.utils import alias, export +from monai.utils import alias, deprecated_arg, export -__all__ = ["UNet", "Unet", "unet"] +__all__ = ["UNet", "Unet"] @export("monai.networks.nets") @alias("Unet") class UNet(nn.Module): + """ + Enhanced version of UNet which has residual units implemented with the ResidualUnit class. + The residual part uses a convolution to change the input dimensions to match the output dimensions + if this is necessary but will use nn.Identity if not. + Refer to: https://link.springer.com/chapter/10.1007/978-3-030-12029-0_40. + + Each layer of the network has a encode and decode path with a skip connection between them. Data in the encode path + is downsampled using strided convolutions (if `strides` is given values greater than 1) and in the decode path + upsampled using strided transpose convolutions. These down or up sampling operations occur at the beginning of each + block rather than afterwards as is typical in UNet implementations. + + To further explain this consider the first example network given below. This network has 3 layers with strides + of 2 for each of the middle layers (the last layer is the bottom connection which does not down/up sample). Input + data to this network is immediately reduced in the spatial dimensions by a factor of 2 by the first convolution of + the residual unit defining the first layer of the encode part. The last layer of the decode part will upsample its + input (data from the previous layer concatenated with data from the skip connection) in the first convolution. this + ensures the final output of the network has the same shape as the input. + + Padding values for the convolutions are chosen to ensure output sizes are even divisors/multiples of the input + sizes if the `strides` value for a layer is a factor of the input sizes. A typical case is to use `strides` values + of 2 and inputs that are multiples of powers of 2. An input can thus be downsampled evenly however many times its + dimensions can be divided by 2, so for the example network inputs would have to have dimensions that are multiples + of 4. In the second example network given below the input to the bottom layer will have shape (1, 64, 15, 15) for + an input of shape (1, 1, 240, 240) demonstrating the input being reduced in size spatially by 2**4. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + channels: sequence of channels. Top block first. The length of `channels` should be no less than 2. + strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`. + kernel_size: convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + num_res_units: number of residual units. Defaults to 0. + act: activation type and arguments. Defaults to PReLU. + norm: feature normalization type and arguments. Defaults to instance norm. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term in convolution blocks. Defaults to True. + According to `Performance Tuning Guide `_, + if a conv layer is directly followed by a batch norm layer, bias should be False. + + Examples:: + + from monai.networks.nets import UNet + + # 3 layer network with down/upsampling by a factor of 2 at each layer with 2-convolution residual units + net = UNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4, 8, 16), + strides=(2, 2), + num_res_units=2 + ) + + # 5 layer network with simple convolution/normalization/dropout/activation blocks defining the layers + net=UNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4, 8, 16, 32, 64), + strides=(2, 2, 2, 2), + ) + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + + Note: The acceptable spatial size of input data depends on the parameters of the network, + to set appropriate spatial size, please check the tutorial for more details: + https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb. + Typically, when using a stride of 2 in down / up sampling, the output dimensions are either half of the + input when downsampling, or twice when upsampling. In this case with N numbers of layers in the network, + the inputs must have spatial dimensions that are all multiples of 2^N. + Usually, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data. + + """ + + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int], @@ -40,40 +122,9 @@ def __init__( norm: Union[Tuple, str] = Norm.INSTANCE, dropout: float = 0.0, bias: bool = True, + dimensions: Optional[int] = None, ) -> None: - """ - Enhanced version of UNet which has residual units implemented with the ResidualUnit class. - The residual part uses a convolution to change the input dimensions to match the output dimensions - if this is necessary but will use nn.Identity if not. - Refer to: https://link.springer.com/chapter/10.1007/978-3-030-12029-0_40. - Args: - dimensions: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - channels: sequence of channels. Top block first. The length of `channels` should be no less than 2. - strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`. - kernel_size: convolution kernel size, the value(s) should be odd. If sequence, - its length should equal to dimensions. Defaults to 3. - up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence, - its length should equal to dimensions. Defaults to 3. - num_res_units: number of residual units. Defaults to 0. - act: activation type and arguments. Defaults to PReLU. - norm: feature normalization type and arguments. Defaults to instance norm. - dropout: dropout ratio. Defaults to no dropout. - bias: whether to have a bias term in convolution blocks. Defaults to True. - According to `Performance Tuning Guide `_, - if a conv layer is directly followed by a batch norm layer, bias should be False. - - Note: The acceptable spatial size of input data depends on the parameters of the network, - to set appropriate spatial size, please check the tutorial for more details: - https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb. - Typically, when using a stride of 2 in down / up sampling, the output dimensions are either half of the - input when downsampling, or twice when upsampling. In this case with N numbers of layers in the network, - the inputs must have spatial dimensions that are all multiples of 2^N. - Usually, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data. - - """ super().__init__() if len(channels) < 2: @@ -83,14 +134,16 @@ def __init__( raise ValueError("the length of `strides` should equal to `len(channels) - 1`.") if delta > 0: warnings.warn(f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.") + if dimensions is not None: + spatial_dims = dimensions if isinstance(kernel_size, Sequence): - if len(kernel_size) != dimensions: + if len(kernel_size) != spatial_dims: raise ValueError("the length of `kernel_size` should equal to `dimensions`.") if isinstance(up_kernel_size, Sequence): - if len(up_kernel_size) != dimensions: + if len(up_kernel_size) != spatial_dims: raise ValueError("the length of `up_kernel_size` should equal to `dimensions`.") - self.dimensions = dimensions + self.dimensions = spatial_dims self.in_channels = in_channels self.out_channels = out_channels self.channels = channels @@ -105,7 +158,7 @@ def __init__( def _create_block( inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool - ) -> nn.Sequential: + ) -> nn.Module: """ Builds the UNet structure from the bottom up by recursing down to the bottom block, then creating sequential blocks containing the downsample path, a skip connection around the previous block, and the upsample path. @@ -133,20 +186,39 @@ def _create_block( down = self._get_down_layer(inc, c, s, is_top) # create layer in downsampling path up = self._get_up_layer(upc, outc, s, is_top) # create layer in upsampling path - return nn.Sequential(down, SkipConnection(subblock), up) + return self._get_connection_block(down, up, subblock) self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True) + def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: + """ + Returns the block object defining a layer of the UNet structure including the implementation of the skip + between encoding (down) and and decoding (up) sides of the network. + + Args: + down_path: encoding half of the layer + up_path: decoding half of the layer + subblock: block defining the next layer in the network. + Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)` + """ + return nn.Sequential(down_path, SkipConnection(subblock), up_path) + def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: """ + Returns the encoding (down) part of a layer of the network. This typically will downsample data at some point + in its structure. Its output is used as input to the next layer down and is concatenated with output from the + next layer to form the input for the decode (up) part of the layer. + Args: in_channels: number of input channels. out_channels: number of output channels. strides: convolution stride. is_top: True if this is the top block. """ + mod: nn.Module if self.num_res_units > 0: - return ResidualUnit( + + mod = ResidualUnit( self.dimensions, in_channels, out_channels, @@ -158,7 +230,8 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_ dropout=self.dropout, bias=self.bias, ) - return Convolution( + return mod + mod = Convolution( self.dimensions, in_channels, out_channels, @@ -169,9 +242,12 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_ dropout=self.dropout, bias=self.bias, ) + return mod def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: """ + Returns the bottom or bottleneck layer at the bottom of the network linking encode to decode halves. + Args: in_channels: number of input channels. out_channels: number of output channels. @@ -180,6 +256,9 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: """ + Returns the decoding (up) part of a layer of the network. This typically will upsample data at some point + in its structure. Its output is used as input to the next layer up. + Args: in_channels: number of input channels. out_channels: number of output channels. @@ -225,4 +304,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -Unet = unet = UNet +Unet = UNet diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py index 9990cb6643..c53936d27f 100644 --- a/monai/networks/nets/unetr.py +++ b/monai/networks/nets/unetr.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -70,7 +70,7 @@ def __init__( """ - super(UNETR, self).__init__() + super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") @@ -179,24 +179,25 @@ def __init__( res_block=res_block, ) self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels) + self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims)) + self.proj_view_shape = list(self.feat_size) + [self.hidden_size] - def proj_feat(self, x, hidden_size, feat_size): - new_view = (x.size(0), *feat_size, hidden_size) + def proj_feat(self, x): + new_view = [x.size(0)] + self.proj_view_shape x = x.view(new_view) - new_axes = (0, len(x.shape) - 1) + tuple(d + 1 for d in range(len(feat_size))) - x = x.permute(new_axes).contiguous() + x = x.permute(self.proj_axes).contiguous() return x def forward(self, x_in): x, hidden_states_out = self.vit(x_in) enc1 = self.encoder1(x_in) x2 = hidden_states_out[3] - enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) + enc2 = self.encoder2(self.proj_feat(x2)) x3 = hidden_states_out[6] - enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) + enc3 = self.encoder3(self.proj_feat(x3)) x4 = hidden_states_out[9] - enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) - dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) + enc4 = self.encoder4(self.proj_feat(x4)) + dec4 = self.proj_feat(x) dec3 = self.decoder5(dec4, enc4) dec2 = self.decoder4(dec3, enc3) dec1 = self.decoder3(dec2, enc2) diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py index 7f54890992..7386883124 100644 --- a/monai/networks/nets/varautoencoder.py +++ b/monai/networks/nets/varautoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,16 +19,65 @@ from monai.networks.layers.convutils import calculate_out_shape, same_padding from monai.networks.layers.factories import Act, Norm from monai.networks.nets import AutoEncoder +from monai.utils import deprecated_arg __all__ = ["VarAutoEncoder"] class VarAutoEncoder(AutoEncoder): - """Variational Autoencoder based on the paper - https://arxiv.org/abs/1312.6114""" + """ + Variational Autoencoder based on the paper - https://arxiv.org/abs/1312.6114 + + Args: + spatial_dims: number of spatial dimensions. + in_shape: shape of input data starting with channel dimension. + out_channels: number of output channels. + latent_size: size of the latent variable. + channels: sequence of channels. Top block first. The length of `channels` should be no less than 2. + strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`. + kernel_size: convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + num_res_units: number of residual units. Defaults to 0. + inter_channels: sequence of channels defining the blocks in the intermediate layer between encode and decode. + inter_dilations: defines the dilation value for each block of the intermediate layer. Defaults to 1. + num_inter_units: number of residual units for each block of the intermediate layer. Defaults to 0. + act: activation type and arguments. Defaults to PReLU. + norm: feature normalization type and arguments. Defaults to instance norm. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term in convolution blocks. Defaults to True. + According to `Performance Tuning Guide `_, + if a conv layer is directly followed by a batch norm layer, bias should be False. + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + + Examples:: + + from monai.networks.nets import VarAutoEncoder + + # 3 layer network accepting images with dimensions (1, 32, 32) and using a latent vector with 2 values + model = VarAutoEncoder( + dimensions=2, + in_shape=(32, 32), # image spatial shape + out_channels=1, + latent_size=2, + channels=(16, 32, 64), + strides=(1, 2, 2), + ) + + see also: + - Variational autoencoder network with MedNIST Dataset + https://github.com/Project-MONAI/tutorials/blob/master/modules/varautoencoder_mednist.ipynb + """ + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_shape: Sequence[int], out_channels: int, latent_size: int, @@ -44,15 +93,18 @@ def __init__( norm: Union[Tuple, str] = Norm.INSTANCE, dropout: Optional[Union[Tuple, str, float]] = None, bias: bool = True, + dimensions: Optional[int] = None, ) -> None: self.in_channels, *self.in_shape = in_shape self.latent_size = latent_size self.final_size = np.asarray(self.in_shape, dtype=int) + if dimensions is not None: + spatial_dims = dimensions super().__init__( - dimensions, + spatial_dims, self.in_channels, out_channels, channels, diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 3a5d94cc37..a5f7963eca 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,11 +18,15 @@ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock +__all__ = ["ViT"] + class ViT(nn.Module): """ Vision Transformer (ViT), based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + ViT supports Torchscript but only works for Pytorch after 1.8. """ def __init__( @@ -68,7 +72,7 @@ def __init__( """ - super(ViT, self).__init__() + super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") @@ -97,7 +101,7 @@ def __init__( def forward(self, x): x = self.patch_embedding(x) - if self.classification: + if hasattr(self, "cls_token"): cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1) hidden_states_out = [] @@ -105,6 +109,6 @@ def forward(self, x): x = blk(x) hidden_states_out.append(x) x = self.norm(x) - if self.classification: + if hasattr(self, "classification_head"): x = self.classification_head(x[:, 0]) return x, hidden_states_out diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py new file mode 100644 index 0000000000..9e5490f9d6 --- /dev/null +++ b/monai/networks/nets/vitautoenc.py @@ -0,0 +1,121 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Sequence, Union + +import torch +import torch.nn as nn + +from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.networks.blocks.transformerblock import TransformerBlock +from monai.networks.layers import Conv +from monai.utils import ensure_tuple_rep + +__all__ = ["ViTAutoEnc"] + + +class ViTAutoEnc(nn.Module): + """ + Vision Transformer (ViT), based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Modified to also give same dimension outputs as the input size of the image + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], + patch_size: Union[Sequence[int], int], + out_channels: int = 1, + deconv_chns: int = 16, + hidden_size: int = 768, + mlp_dim: int = 3072, + num_layers: int = 12, + num_heads: int = 12, + pos_embed: str = "conv", + dropout_rate: float = 0.0, + spatial_dims: int = 3, + ) -> None: + """ + Args: + in_channels: dimension of input channels or the number of channels for input + img_size: dimension of input image. + patch_size: dimension of patch size. + hidden_size: dimension of hidden layer. + out_channels: number of output channels. + deconv_chns: number of channels for the deconvolution layers. + mlp_dim: dimension of feedforward layer. + num_layers: number of transformer blocks. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + dropout_rate: faction of the input units to drop. + spatial_dims: number of spatial dimensions. + + Examples:: + + # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone + # It will provide an output of same size as that of the input + >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv') + + # for 3-channel with image size of (128,128,128), output will be same size as of input + >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv') + + """ + + super().__init__() + + self.patch_size = ensure_tuple_rep(patch_size, spatial_dims) + self.spatial_dims = spatial_dims + + self.patch_embedding = PatchEmbeddingBlock( + in_channels=in_channels, + img_size=img_size, + patch_size=patch_size, + hidden_size=hidden_size, + num_heads=num_heads, + pos_embed=pos_embed, + dropout_rate=dropout_rate, + spatial_dims=self.spatial_dims, + ) + self.blocks = nn.ModuleList( + [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] + ) + self.norm = nn.LayerNorm(hidden_size) + + new_patch_size = [4] * self.spatial_dims + conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims] + # self.conv3d_transpose* is to be compatible with existing 3d model weights. + self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size) + self.conv3d_transpose_1 = conv_trans( + in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size + ) + + def forward(self, x): + """ + Args: + x: input tensor must have isotropic spatial dimensions, + such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``. + """ + spatial_size = x.shape[2:] + x = self.patch_embedding(x) + hidden_states_out = [] + for blk in self.blocks: + x = blk(x) + hidden_states_out.append(x) + x = self.norm(x) + x = x.transpose(1, 2) + d = [s // p for s, p in zip(spatial_size, self.patch_size)] + x = torch.reshape(x, [x.shape[0], x.shape[1], *d]) + x = self.conv3d_transpose(x) + x = self.conv3d_transpose_1(x) + return x, hidden_states_out diff --git a/monai/networks/nets/vnet.py b/monai/networks/nets/vnet.py index 72f3290a89..7669b4678e 100644 --- a/monai/networks/nets/vnet.py +++ b/monai/networks/nets/vnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -30,11 +30,11 @@ def get_acti_layer(act: Union[Tuple[str, Dict], str], nchan: int = 0): class LUConv(nn.Module): def __init__(self, spatial_dims: int, nchan: int, act: Union[Tuple[str, Dict], str], bias: bool = False): - super(LUConv, self).__init__() + super().__init__() self.act_function = get_acti_layer(act, nchan) self.conv_block = Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=nchan, out_channels=nchan, kernel_size=5, @@ -65,7 +65,7 @@ def __init__( act: Union[Tuple[str, Dict], str], bias: bool = False, ): - super(InputTransition, self).__init__() + super().__init__() if 16 % in_channels != 0: raise ValueError(f"16 should be divisible by in_channels, got in_channels={in_channels}.") @@ -74,7 +74,7 @@ def __init__( self.in_channels = in_channels self.act_function = get_acti_layer(act, 16) self.conv_block = Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=16, kernel_size=5, @@ -102,7 +102,7 @@ def __init__( dropout_dim: int = 3, bias: bool = False, ): - super(DownTransition, self).__init__() + super().__init__() conv_type: Type[Union[nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -138,7 +138,7 @@ def __init__( dropout_prob: Optional[float] = None, dropout_dim: int = 3, ): - super(UpTransition, self).__init__() + super().__init__() conv_trans_type: Type[Union[nn.ConvTranspose2d, nn.ConvTranspose3d]] = Conv[Conv.CONVTRANS, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -174,13 +174,13 @@ def __init__( act: Union[Tuple[str, Dict], str], bias: bool = False, ): - super(OutputTransition, self).__init__() + super().__init__() conv_type: Type[Union[nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] self.act_function1 = get_acti_layer(act, out_channels) self.conv_block = Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=5, @@ -213,7 +213,7 @@ class VNet(nn.Module): The value should meet the condition that ``16 % in_channels == 0``. out_channels: number of output channels for the network. Defaults to 1. act: activation type in the network. Defaults to ``("elu", {"inplace": True})``. - dropout_prob: dropout ratio. Defaults to 0.5. Defaults to 3. + dropout_prob: dropout ratio. Defaults to 0.5. dropout_dim: determine the dimensions of dropout. Defaults to 3. - ``dropout_dim = 1``, randomly zeroes some of the elements for each channel. diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 9d20d2a83b..a6b0699107 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,11 +15,16 @@ import warnings from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Callable, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union import torch import torch.nn as nn +from monai.config import PathLike +from monai.utils.deprecate_utils import deprecated, deprecated_arg +from monai.utils.misc import ensure_tuple, save_obj, set_determinism +from monai.utils.module import pytorch_after + __all__ = [ "one_hot", "slice_channels", @@ -31,7 +36,11 @@ "pixelshuffle", "eval_mode", "train_mode", + "get_state_dict", "copy_model_state", + "save_state", + "convert_to_torchscript", + "meshgrid_ij", ] @@ -88,7 +97,13 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f return labels +@deprecated(since="0.8.0", msg_suffix="use `monai.utils.misc.sample_slices` instead.") def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Tensor: + """ + .. deprecated:: 0.8.0 + Use `monai.utils.misc.sample_slices` instead. + + """ slices = [slice(None)] * len(tensor.shape) slices[1] = slice(*slicevals) @@ -225,9 +240,14 @@ def icnr_init(conv, upsample_factor, init=nn.init.kaiming_normal_): conv.weight.data.copy_(kernel) -def pixelshuffle(x: torch.Tensor, dimensions: int, scale_factor: int) -> torch.Tensor: +@deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." +) +def pixelshuffle( + x: torch.Tensor, spatial_dims: int, scale_factor: int, dimensions: Optional[int] = None +) -> torch.Tensor: """ - Apply pixel shuffle to the tensor `x` with spatial dimensions `dimensions` and scaling factor `scale_factor`. + Apply pixel shuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`. See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution Using a nEfficient Sub-Pixel Convolutional Neural Network." @@ -236,20 +256,24 @@ def pixelshuffle(x: torch.Tensor, dimensions: int, scale_factor: int) -> torch.T Args: x: Input tensor - dimensions: number of spatial dimensions, typically 2 or 3 for 2D or 3D + spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D scale_factor: factor to rescale the spatial dimensions by, must be >=1 + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + Returns: Reshuffled version of `x`. Raises: - ValueError: When input channels of `x` are not divisible by (scale_factor ** dimensions) + ValueError: When input channels of `x` are not divisible by (scale_factor ** spatial_dims) """ - - dim, factor = dimensions, scale_factor + if dimensions is not None: + spatial_dims = dimensions + dim, factor = spatial_dims, scale_factor input_size = list(x.size()) batch_size, channels = input_size[:2] - scale_divisor = factor ** dim + scale_divisor = factor**dim if channels % scale_divisor != 0: raise ValueError( @@ -336,6 +360,20 @@ def train_mode(*nets: nn.Module): n.eval() +def get_state_dict(obj: Union[torch.nn.Module, Mapping]): + """ + Get the state dict of input object if has `state_dict`, otherwise, return object directly. + For data parallel model, automatically convert it to regular model first. + + Args: + obj: input object to check and get the state_dict. + + """ + if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): + obj = obj.module + return obj.state_dict() if hasattr(obj, "state_dict") else obj # type: ignore + + def copy_model_state( dst: Union[torch.nn.Module, Mapping], src: Union[torch.nn.Module, Mapping], @@ -380,15 +418,10 @@ def copy_model_state( # Returns: an OrderedDict of the updated `dst` state, the changed, and unchanged keys. - """ - if isinstance(src, (nn.DataParallel, nn.parallel.DistributedDataParallel)): - src = src.module - if isinstance(dst, (nn.DataParallel, nn.parallel.DistributedDataParallel)): - dst = dst.module - src_dict = src.state_dict() if isinstance(src, torch.nn.Module) else src - dst_dict = dst.state_dict() if isinstance(dst, torch.nn.Module) else dst - dst_dict = OrderedDict(dst_dict) + """ + src_dict = get_state_dict(src) + dst_dict = OrderedDict(get_state_dict(dst)) to_skip = {s_key for s_key in src_dict if exclude_vars and re.compile(exclude_vars).search(s_key)} @@ -413,3 +446,110 @@ def copy_model_state( if inplace and isinstance(dst, torch.nn.Module): dst.load_state_dict(dst_dict) return dst_dict, updated_keys, unchanged_keys + + +def save_state(src: Union[torch.nn.Module, Dict], path: PathLike, **kwargs): + """ + Save the state dict of input source data with PyTorch `save`. + It can save `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`. + And automatically convert the data parallel module to regular module. + For example:: + + save_state(net, path) + save_state(net.state_dict(), path) + save_state({"net": net, "opt": opt}, path) + net_dp = torch.nn.DataParallel(net) + save_state(net_dp, path) + + Refer to: https://pytorch.org/ignite/v0.4.8/generated/ignite.handlers.DiskSaver.html. + + Args: + src: input data to save, can be `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`. + path: target file path to save the input object. + kwargs: other args for the `save_obj` except for the `obj` and `path`. + default `func` is `torch.save()`, details of the args of it: + https://pytorch.org/docs/stable/generated/torch.save.html. + + """ + + ckpt: Dict = {} + if isinstance(src, dict): + for k, v in src.items(): + ckpt[k] = get_state_dict(v) + else: + ckpt = get_state_dict(src) + + save_obj(obj=ckpt, path=path, **kwargs) + + +def convert_to_torchscript( + model: nn.Module, + filename_or_obj: Optional[Any] = None, + extra_files: Optional[Dict] = None, + verify: bool = False, + inputs: Optional[Sequence[Any]] = None, + device: Optional[torch.device] = None, + rtol: float = 1e-4, + atol: float = 0.0, + **kwargs, +): + """ + Utility to convert a model into TorchScript model and save to file, + with optional input / output data verification. + + Args: + model: source PyTorch model to save. + filename_or_obj: if not None, specify a file-like object (has to implement write and flush) + or a string containing a file path name to save the TorchScript model. + extra_files: map from filename to contents which will be stored as part of the save model file. + works for PyTorch 1.7 or later. + for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html. + verify: whether to verify the input and output of TorchScript model. + if `filename_or_obj` is not None, load the saved TorchScript model and verify. + inputs: input test data to verify model, should be a sequence of data, every item maps to a argument + of `model()` function. + device: target device to verify the model, if None, use CUDA if available. + rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. + atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. + kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: + https://pytorch.org/docs/master/generated/torch.jit.script.html. + + """ + model.eval() + with torch.no_grad(): + script_module = torch.jit.script(model, **kwargs) + if filename_or_obj is not None: + if not pytorch_after(1, 7): + torch.jit.save(m=script_module, f=filename_or_obj) + else: + torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files) + + if verify: + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if inputs is None: + raise ValueError("missing input data for verification.") + + inputs = [i.to(device) if isinstance(i, torch.Tensor) else i for i in inputs] + ts_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else script_module + ts_model.eval().to(device) + model = model.to(device) + + with torch.no_grad(): + set_determinism(seed=0) + torch_out = ensure_tuple(model(*inputs)) + set_determinism(seed=0) + torchscript_out = ensure_tuple(ts_model(*inputs)) + set_determinism(seed=None) + # compare TorchScript and PyTorch results + for r1, r2 in zip(torch_out, torchscript_out): + if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor): + torch.testing.assert_allclose(r1, r2, rtol=rtol, atol=atol) + + return script_module + + +def meshgrid_ij(*tensors): + if pytorch_after(1, 10): + return torch.meshgrid(*tensors, indexing="ij") + return torch.meshgrid(*tensors) diff --git a/monai/optimizers/__init__.py b/monai/optimizers/__init__.py index e53aa8d468..8ce5d3f925 100644 --- a/monai/optimizers/__init__.py +++ b/monai/optimizers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,5 +10,6 @@ # limitations under the License. from .lr_finder import LearningRateFinder +from .lr_scheduler import ExponentialLR, LinearLR, WarmupCosineSchedule from .novograd import Novograd from .utils import generate_param_groups diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 49d4427b3d..ce092d33ab 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pickle import warnings from functools import partial from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union @@ -17,6 +18,7 @@ import torch import torch.nn as nn from torch.optim import Optimizer +from torch.serialization import DEFAULT_PROTOCOL from torch.utils.data import DataLoader from monai.networks.utils import eval_mode @@ -120,7 +122,7 @@ def __iter__(self): def __next__(self): self.run_counter += 1 - return super(ValDataLoaderIter, self).__next__() + return super().__next__() def default_image_extractor(x: Any) -> torch.Tensor: @@ -144,30 +146,30 @@ class LearningRateFinder: and what is the optimal learning rate. Example (fastai approach): - >>> lr_finder = LearningRateFinder(net, optimizer, criterion) - >>> lr_finder.range_test(data_loader, end_lr=100, num_iter=100) - >>> lr_finder.get_steepest_gradient() - >>> lr_finder.plot() # to inspect the loss-learning rate graph + >>> lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> lr_finder.range_test(data_loader, end_lr=100, num_iter=100) + >>> lr_finder.get_steepest_gradient() + >>> lr_finder.plot() # to inspect the loss-learning rate graph Example (Leslie Smith's approach): - >>> lr_finder = LearningRateFinder(net, optimizer, criterion) - >>> lr_finder.range_test(train_loader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear") + >>> lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> lr_finder.range_test(train_loader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear") Gradient accumulation is supported; example: - >>> train_data = ... # prepared dataset - >>> desired_bs, real_bs = 32, 4 # batch size - >>> accumulation_steps = desired_bs // real_bs # required steps for accumulation - >>> data_loader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True) - >>> acc_lr_finder = LearningRateFinder(net, optimizer, criterion) - >>> acc_lr_finder.range_test(data_loader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps) + >>> train_data = ... # prepared dataset + >>> desired_bs, real_bs = 32, 4 # batch size + >>> accumulation_steps = desired_bs // real_bs # required steps for accumulation + >>> data_loader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True) + >>> acc_lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> acc_lr_finder.range_test(data_loader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps) By default, image will be extracted from data loader with x["image"] and x[0], depending on whether batch data is a dictionary or not (and similar behaviour for extracting the label). If your data loader returns something other than this, pass a callable function to extract it, e.g.: - >>> image_extractor = lambda x: x["input"] - >>> label_extractor = lambda x: x[100] - >>> lr_finder = LearningRateFinder(net, optimizer, criterion) - >>> lr_finder.range_test(train_loader, val_loader, image_extractor, label_extractor) + >>> image_extractor = lambda x: x["input"] + >>> label_extractor = lambda x: x[100] + >>> lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> lr_finder.range_test(train_loader, val_loader, image_extractor, label_extractor) References: Modified from: https://github.com/davidtvs/pytorch-lr-finder. @@ -183,6 +185,8 @@ def __init__( memory_cache: bool = True, cache_dir: Optional[str] = None, amp: bool = False, + pickle_module=pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, verbose: bool = True, ) -> None: """Constructor. @@ -202,6 +206,12 @@ def __init__( specified, system-wide temporary directory is used. Notice that this parameter will be ignored if `memory_cache` is True. amp: use Automatic Mixed Precision + pickle_module: module used for pickling metadata and objects, default to `pickle`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. + pickle_protocol: can be specified to override the default protocol, default to `2`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. verbose: verbose output Returns: None @@ -221,7 +231,9 @@ def __init__( # Save the original state of the model and optimizer so they can be restored if # needed self.model_device = next(self.model.parameters()).device - self.state_cacher = StateCacher(memory_cache, cache_dir=cache_dir) + self.state_cacher = StateCacher( + in_memory=memory_cache, cache_dir=cache_dir, pickle_module=pickle_module, pickle_protocol=pickle_protocol + ) self.state_cacher.store("model", self.model.state_dict()) self.state_cacher.store("optimizer", self.optimizer.state_dict()) @@ -328,11 +340,7 @@ def range_test( print(f"Computing optimal learning rate, iteration {iteration + 1}/{num_iter}") # Train on batch and retrieve loss - loss = self._train_batch( - train_iter, - accumulation_steps, - non_blocking_transfer=non_blocking_transfer, - ) + loss = self._train_batch(train_iter, accumulation_steps, non_blocking_transfer=non_blocking_transfer) if val_loader: loss = self._validate(val_iter, non_blocking_transfer=non_blocking_transfer) @@ -429,11 +437,7 @@ def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = T return running_loss / len(val_iter.dataset) - def get_lrs_and_losses( - self, - skip_start: int = 0, - skip_end: int = 0, - ) -> Tuple[list, list]: + def get_lrs_and_losses(self, skip_start: int = 0, skip_end: int = 0) -> Tuple[list, list]: """Get learning rates and their corresponding losses Args: @@ -454,9 +458,7 @@ def get_lrs_and_losses( return lrs, losses def get_steepest_gradient( - self, - skip_start: int = 0, - skip_end: int = 0, + self, skip_start: int = 0, skip_end: int = 0 ) -> Union[Tuple[float, float], Tuple[None, None]]: """Get learning rate which has steepest gradient and its corresponding loss @@ -476,14 +478,7 @@ def get_steepest_gradient( print("Failed to compute the gradients, there might not be enough points.") return None, None - def plot( - self, - skip_start: int = 0, - skip_end: int = 0, - log_lr: bool = True, - ax=None, - steepest_lr: bool = True, - ): + def plot(self, skip_start: int = 0, skip_end: int = 0, log_lr: bool = True, ax=None, steepest_lr: bool = True): """Plots the learning rate range test. Args: diff --git a/monai/optimizers/lr_scheduler.py b/monai/optimizers/lr_scheduler.py index 9416b583f7..83412c61ea 100644 --- a/monai/optimizers/lr_scheduler.py +++ b/monai/optimizers/lr_scheduler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,7 +33,7 @@ def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoc """ self.end_lr = end_lr self.num_iter = num_iter - super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) + super().__init__(optimizer, last_epoch) class LinearLR(_LRSchedulerMONAI): @@ -77,7 +77,7 @@ def __init__( self.warmup_steps = warmup_steps self.t_total = t_total self.cycles = cycles - super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch) + super().__init__(optimizer, self.lr_lambda, last_epoch) def lr_lambda(self, step): if step < self.warmup_steps: diff --git a/monai/optimizers/novograd.py b/monai/optimizers/novograd.py index 62e42cc9ab..07a6aff90a 100644 --- a/monai/optimizers/novograd.py +++ b/monai/optimizers/novograd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,7 +20,7 @@ class Novograd(Optimizer): Novograd based on `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks `_. The code is adapted from the implementations in `Jasper for PyTorch - `_, + `_, and `OpenSeq2Seq `_. Args: @@ -45,28 +45,23 @@ def __init__( amsgrad: bool = False, ): if 0.0 > lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if 0.0 > eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") if 0.0 > weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - grad_averaging=grad_averaging, - amsgrad=amsgrad, + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, amsgrad=amsgrad ) - super(Novograd, self).__init__(params, defaults) + super().__init__(params, defaults) def __setstate__(self, state): - super(Novograd, self).__setstate__(state) + super().__setstate__(state) for group in self.param_groups: group.setdefault("amsgrad", False) diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py index c52ab07a04..1a040927d8 100644 --- a/monai/optimizers/utils.py +++ b/monai/optimizers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -47,7 +47,7 @@ def generate_param_groups( .. code-block:: python - net = Unet(dimensions=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1]) + net = Unet(spatial_dims=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1]) print(net) # print out network components to select expected items print(net.named_parameters()) # print out all the named parameters to filter out expected items params = generate_param_groups( diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2ea7e3aa63..a3dc439a51 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,6 +18,7 @@ CenterSpatialCrop, CropForeground, DivisiblePad, + Pad, RandCropByLabelClasses, RandCropByPosNegLabel, RandScaleCrop, @@ -48,7 +49,7 @@ DivisiblePadd, DivisiblePadD, DivisiblePadDict, - NumpyPadModeSequence, + PadModeSequence, RandCropByLabelClassesd, RandCropByLabelClassesD, RandCropByLabelClassesDict, @@ -84,18 +85,21 @@ GaussianSmooth, GibbsNoise, HistogramNormalize, + IntensityRemap, KSpaceSpikeNoise, - LocalPatchShuffling, MaskIntensity, NormalizeIntensity, RandAdjustContrast, RandBiasField, RandCoarseDropout, + RandCoarseShuffle, + RandCoarseTransform, RandGaussianNoise, RandGaussianSharpen, RandGaussianSmooth, RandGibbsNoise, RandHistogramShift, + RandIntensityRemap, RandKSpaceSpikeNoise, RandRicianNoise, RandScaleIntensity, @@ -143,6 +147,9 @@ RandCoarseDropoutd, RandCoarseDropoutD, RandCoarseDropoutDict, + RandCoarseShuffled, + RandCoarseShuffleD, + RandCoarseShuffleDict, RandGaussianNoised, RandGaussianNoiseD, RandGaussianNoiseDict, @@ -173,6 +180,9 @@ RandStdShiftIntensityd, RandStdShiftIntensityD, RandStdShiftIntensityDict, + SavitzkyGolaySmoothd, + SavitzkyGolaySmoothD, + SavitzkyGolaySmoothDict, ScaleIntensityd, ScaleIntensityD, ScaleIntensityDict, @@ -192,8 +202,8 @@ ThresholdIntensityD, ThresholdIntensityDict, ) -from .inverse import InvertibleTransform -from .inverse_batch_transform import BatchInverseTransform, Decollated +from .inverse import InvertibleTransform, TraceableTransform +from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .nvtx import ( @@ -241,6 +251,8 @@ AsDiscreted, AsDiscreteDict, Ensembled, + EnsembleD, + EnsembleDict, FillHolesD, FillHolesd, FillHolesDict, @@ -269,11 +281,28 @@ VoteEnsembled, VoteEnsembleDict, ) +from .smooth_field.array import ( + RandSmoothDeform, + RandSmoothFieldAdjustContrast, + RandSmoothFieldAdjustIntensity, + SmoothField, +) +from .smooth_field.dictionary import ( + RandSmoothDeformd, + RandSmoothDeformD, + RandSmoothDeformDict, + RandSmoothFieldAdjustContrastd, + RandSmoothFieldAdjustContrastD, + RandSmoothFieldAdjustContrastDict, + RandSmoothFieldAdjustIntensityd, + RandSmoothFieldAdjustIntensityD, + RandSmoothFieldAdjustIntensityDict, +) from .spatial.array import ( - AddCoordinateChannels, Affine, AffineGrid, Flip, + GridDistortion, Orientation, Rand2DElastic, Rand3DElastic, @@ -282,26 +311,29 @@ RandAxisFlip, RandDeformGrid, RandFlip, + RandGridDistortion, RandRotate, RandRotate90, RandZoom, Resample, + ResampleToMatch, Resize, Rotate, Rotate90, Spacing, + SpatialResample, Zoom, ) from .spatial.dictionary import ( - AddCoordinateChannelsd, - AddCoordinateChannelsD, - AddCoordinateChannelsDict, Affined, AffineD, AffineDict, Flipd, FlipD, FlipDict, + GridDistortiond, + GridDistortionD, + GridDistortionDict, Orientationd, OrientationD, OrientationDict, @@ -320,6 +352,9 @@ RandFlipd, RandFlipD, RandFlipDict, + RandGridDistortiond, + RandGridDistortionD, + RandGridDistortionDict, RandRotate90d, RandRotate90D, RandRotate90Dict, @@ -329,6 +364,9 @@ RandZoomd, RandZoomD, RandZoomDict, + ResampleToMatchd, + ResampleToMatchD, + ResampleToMatchDict, Resized, ResizeD, ResizeDict, @@ -341,6 +379,9 @@ Spacingd, SpacingD, SpacingDict, + SpatialResampled, + SpatialResampleD, + SpatialResampleDict, Zoomd, ZoomD, ZoomDict, @@ -348,12 +389,14 @@ from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform from .utility.array import ( AddChannel, + AddCoordinateChannels, AddExtremePointsChannel, AsChannelFirst, AsChannelLast, CastToType, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, + CuCIM, DataStats, EnsureChannelFirst, EnsureType, @@ -363,6 +406,7 @@ LabelToMask, Lambda, MapLabelValue, + RandCuCIM, RandLambda, RemoveRepeatedChannel, RepeatChannel, @@ -381,6 +425,9 @@ AddChanneld, AddChannelD, AddChannelDict, + AddCoordinateChannelsd, + AddCoordinateChannelsD, + AddCoordinateChannelsDict, AddExtremePointsChanneld, AddExtremePointsChannelD, AddExtremePointsChannelDict, @@ -405,6 +452,9 @@ CopyItemsd, CopyItemsD, CopyItemsDict, + CuCIMd, + CuCIMD, + CuCIMDict, DataStatsd, DataStatsD, DataStatsDict, @@ -435,6 +485,9 @@ MapLabelValued, MapLabelValueD, MapLabelValueDict, + RandCuCIMd, + RandCuCIMD, + RandCuCIMDict, RandLambdad, RandLambdaD, RandLambdaDict, @@ -486,6 +539,8 @@ allow_missing_keys_mode, compute_divisible_spatial_size, convert_inverse_interp_mode, + convert_pad_mode, + convert_to_contiguous, copypaste_arrays, create_control_grid, create_grid, @@ -518,4 +573,25 @@ weighted_patch_samples, zero_margins, ) -from .utils_pytorch_numpy_unification import in1d, moveaxis +from .utils_pytorch_numpy_unification import ( + allclose, + any_np_pt, + ascontiguousarray, + clip, + concatenate, + cumsum, + floor_divide, + in1d, + isfinite, + isnan, + maximum, + mode, + moveaxis, + nonzero, + percentile, + ravel, + repeat, + stack, + unravel_index, + where, +) diff --git a/monai/transforms/adaptors.py b/monai/transforms/adaptors.py index 434d1f1c05..92fd11cf79 100644 --- a/monai/transforms/adaptors.py +++ b/monai/transforms/adaptors.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 4bf175769b..bc55af0b15 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,7 +28,7 @@ apply_transform, ) from monai.utils import MAX_SEED, ensure_tuple, get_seed -from monai.utils.enums import InverseKeys +from monai.utils.enums import TraceKeys __all__ = ["Compose", "OneOf"] @@ -100,6 +100,17 @@ class Compose(Randomizable, InvertibleTransform): Alternatively, one can create a class with a `__call__` function that calls your pre-processing functions taking into account that not all of them are called on the labels. + + Args: + transforms: sequence of callables. + map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple. + defaults to `True`. + unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. + defaults to `False`. + log_stats: whether to log the detailed information of data and applied transform when error happened, + for NumPy array and PyTorch Tensor, log the data shape and value range, + for other meta data, log the values directly. default to `False`. + """ def __init__( @@ -107,12 +118,14 @@ def __init__( transforms: Optional[Union[Sequence[Callable], Callable]] = None, map_items: bool = True, unpack_items: bool = False, + log_stats: bool = False, ) -> None: if transforms is None: transforms = [] self.transforms = ensure_tuple(transforms) self.map_items = map_items self.unpack_items = unpack_items + self.log_stats = log_stats self.set_random_state(seed=get_seed()) def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": @@ -157,7 +170,7 @@ def __len__(self): def __call__(self, input_): for _transform in self.transforms: - input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) + input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats) return input_ def inverse(self, data): @@ -167,22 +180,27 @@ def inverse(self, data): # loop backwards over transforms for t in reversed(invertible_transforms): - data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) + data = apply_transform(t.inverse, data, self.map_items, self.unpack_items, self.log_stats) return data class OneOf(Compose): """ - ``OneOf`` provides the ability to radomly choose one transform out of a - list of callables with predfined probabilities for each. + ``OneOf`` provides the ability to randomly choose one transform out of a + list of callables with pre-defined probabilities for each. Args: transforms: sequence of callables. weights: probabilities corresponding to each callable in transforms. Probabilities are normalized to sum to one. + map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple. + defaults to `True`. + unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. + defaults to `False`. + log_stats: whether to log the detailed information of data and applied transform when error happened, + for NumPy array and PyTorch Tensor, log the data shape and value range, + for other meta data, log the values directly. default to `False`. - OneOf inherits from Compose and uses args map_items and unpack_items in - the same way. """ def __init__( @@ -191,8 +209,9 @@ def __init__( weights: Optional[Union[Sequence[float], float]] = None, map_items: bool = True, unpack_items: bool = False, + log_stats: bool = False, ) -> None: - super().__init__(transforms, map_items, unpack_items) + super().__init__(transforms, map_items, unpack_items, log_stats) if len(self.transforms) == 0: weights = [] elif weights is None or isinstance(weights, float): @@ -204,14 +223,13 @@ def __init__( def _normalize_probabilities(self, weights): if len(weights) == 0: return weights - else: - weights = np.array(weights) - if np.any(weights < 0): - raise AssertionError("Probabilities must be greater than or equal to zero.") - if np.all(weights == 0): - raise AssertionError("At least one probability must be greater than zero.") - weights = weights / weights.sum() - return list(weights) + weights = np.array(weights) + if np.any(weights < 0): + raise AssertionError("Probabilities must be greater than or equal to zero.") + if np.all(weights == 0): + raise AssertionError("At least one probability must be greater than zero.") + weights = weights / weights.sum() + return list(weights) def flatten(self): transforms = [] @@ -232,16 +250,15 @@ def flatten(self): def __call__(self, data): if len(self.transforms) == 0: return data - else: - index = self.R.multinomial(1, self.weights).argmax() - _transform = self.transforms[index] - data = apply_transform(_transform, data, self.map_items, self.unpack_items) - # if the data is a mapping (dictionary), append the OneOf transform to the end - if isinstance(data, Mapping): - for key in data.keys(): - if key + InverseKeys.KEY_SUFFIX in data: - self.push_transform(data, key, extra_info={"index": index}) - return data + index = self.R.multinomial(1, self.weights).argmax() + _transform = self.transforms[index] + data = apply_transform(_transform, data, self.map_items, self.unpack_items, self.log_stats) + # if the data is a mapping (dictionary), append the OneOf transform to the end + if isinstance(data, Mapping): + for key in data.keys(): + if self.trace_key(key) in data: + self.push_transform(data, key, extra_info={"index": index}) + return data def inverse(self, data): if len(self.transforms) == 0: @@ -252,18 +269,15 @@ def inverse(self, data): # loop until we get an index and then break (since they'll all be the same) index = None for key in data.keys(): - if key + InverseKeys.KEY_SUFFIX in data: + if self.trace_key(key) in data: # get the index of the applied OneOf transform - index = self.get_most_recent_transform(data, key)[InverseKeys.EXTRA_INFO]["index"] + index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] # and then remove the OneOf transform self.pop_transform(data, key) if index is None: - raise RuntimeError("No invertible transforms have been applied") + # no invertible transforms have been applied + return data - # if applied transform is not InvertibleTransform, throw error _transform = self.transforms[index] - if not isinstance(_transform, InvertibleTransform): - raise RuntimeError(f"Applied OneOf transform is not invertible (applied index: {index}).") - # apply the inverse - return _transform.inverse(data) + return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data diff --git a/monai/transforms/croppad/__init__.py b/monai/transforms/croppad/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/croppad/__init__.py +++ b/monai/transforms/croppad/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 74f556cc1a..b05917a46c 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,11 +22,12 @@ from torch.nn.functional import pad as pad_pt from monai.config import IndexSelection -from monai.config.type_definitions import NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( compute_divisible_spatial_size, + convert_pad_mode, generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -35,11 +36,21 @@ map_classes_to_indices, weighted_patch_samples, ) -from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option +from monai.transforms.utils_pytorch_numpy_unification import floor_divide, maximum +from monai.utils import ( + Method, + NumpyPadMode, + PytorchPadMode, + ensure_tuple, + ensure_tuple_rep, + fall_back_tuple, + look_up_option, +) from monai.utils.enums import TransformBackends -from monai.utils.type_conversion import convert_data_type +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type __all__ = [ + "Pad", "SpatialPad", "BorderPad", "DivisiblePad", @@ -61,16 +72,18 @@ class Pad(Transform): """ Perform padding for a given an amount of padding in each dimension. - If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used. - Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary). - Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad - for additional details. + If input is `torch.Tensor`, `torch.nn.functional.pad` will be used, otherwise, `np.pad` will be used. + Args: to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -78,43 +91,44 @@ class Pad(Transform): def __init__( self, to_pad: List[Tuple[int, int]], - mode: Union[NumpyPadMode, str, None] = NumpyPadMode.CONSTANT, - **np_kwargs, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **kwargs, ) -> None: self.to_pad = to_pad - self.mode = mode or NumpyPadMode.CONSTANT - self.np_kwargs = np_kwargs + self.mode = mode + self.kwargs = kwargs @staticmethod - def _np_pad(img: np.ndarray, all_pad_width, mode, **np_kwargs) -> np.ndarray: - img_np, *_ = convert_data_type(img, np.ndarray) - return np.pad(img_np, all_pad_width, mode=mode, **np_kwargs) # type: ignore + def _np_pad(img: np.ndarray, all_pad_width, mode, **kwargs) -> np.ndarray: + return np.pad(img, all_pad_width, mode=mode, **kwargs) # type: ignore @staticmethod - def _pt_pad(img: torch.Tensor, all_pad_width, mode, **np_kwargs) -> torch.Tensor: - pt_pad_width = [val for sublist in all_pad_width for val in sublist[::-1]][::-1] - return pad_pt(img, pt_pad_width, mode=mode, **np_kwargs) + def _pt_pad(img: torch.Tensor, all_pad_width, mode, **kwargs) -> torch.Tensor: + pt_pad_width = [val for sublist in all_pad_width[1:] for val in sublist[::-1]][::-1] + # torch.pad expects `[B, C, H, W, [D]]` shape + return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"`` or ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + """ if not np.asarray(self.to_pad).any(): # all zeros, skip padding return img - mode = mode or self.mode - mode = mode.value if isinstance(mode, NumpyPadMode) else mode - if isinstance(img, torch.Tensor) and mode == "constant" and not self.np_kwargs: - pad = self._pt_pad - else: - pad = self._np_pad # type: ignore - return pad(img, self.to_pad, mode, **self.np_kwargs) + mode = convert_pad_mode(dst=img, mode=mode or self.mode).value + pad = self._pt_pad if isinstance(img, torch.Tensor) else self._np_pad + return pad(img, self.to_pad, mode, **self.kwargs) # type: ignore class SpatialPad(Transform): @@ -135,12 +149,14 @@ class SpatialPad(Transform): `spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ @@ -150,13 +166,13 @@ def __init__( self, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, - **np_kwargs, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **kwargs, ) -> None: self.spatial_size = spatial_size self.method: Method = look_up_option(method, Method) - self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) - self.np_kwargs = np_kwargs + self.mode = mode + self.kwargs = kwargs def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]: spatial_size = fall_back_tuple(self.spatial_size, data_shape) @@ -168,15 +184,20 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int return pad_width return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + """ data_pad_width = self._determine_data_pad_width(img.shape[1:]) all_pad_width = [(0, 0)] + data_pad_width @@ -184,8 +205,7 @@ def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] # all zeros, skip padding return img - mode = look_up_option(mode or self.mode, NumpyPadMode) - padder = Pad(all_pad_width, mode, **self.np_kwargs) + padder = Pad(all_pad_width, mode or self.mode, **self.kwargs) return padder(img) @@ -204,13 +224,14 @@ class BorderPad(Transform): for example, image shape(CHW) is [1, 4, 4], spatial_border is [1, 2, 3, 4], pad top of H dim with 1, pad bottom of H dim with 2, pad left of W dim with 3, pad right of W dim with 4. the result shape is [1, 7, 11]. - - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ @@ -219,22 +240,26 @@ class BorderPad(Transform): def __init__( self, spatial_border: Union[Sequence[int], int], - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, - **np_kwargs, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **kwargs, ) -> None: self.spatial_border = spatial_border - self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) - self.np_kwargs = np_kwargs + self.mode = mode + self.kwargs = kwargs - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html Raises: ValueError: When ``self.spatial_border`` does not contain ints. @@ -261,8 +286,7 @@ def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] ) all_pad_width = [(0, 0)] + data_pad_width - mode = look_up_option(mode or self.mode, NumpyPadMode) - padder = Pad(all_pad_width, mode, **self.np_kwargs) + padder = Pad(all_pad_width, mode or self.mode, **self.kwargs) return padder(img) @@ -276,48 +300,50 @@ class DivisiblePad(Transform): def __init__( self, k: Union[Sequence[int], int], - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, method: Union[Method, str] = Method.SYMMETRIC, - **np_kwargs, + **kwargs, ) -> None: """ Args: k: the target k for each spatial dimension. if `k` is negative or 0, the original size is preserved. if `k` is an int, the same `k` be applied to all the input spatial dimensions. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. See also :py:class:`monai.transforms.SpatialPad` """ self.k = k self.mode: NumpyPadMode = NumpyPadMode(mode) self.method: Method = Method(method) - self.np_kwargs = np_kwargs + self.kwargs = kwargs - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + """ new_size = compute_divisible_spatial_size(spatial_shape=img.shape[1:], k=self.k) - spatial_pad = SpatialPad( - spatial_size=new_size, - method=self.method, - mode=mode or self.mode, - **self.np_kwargs, - ) + spatial_pad = SpatialPad(spatial_size=new_size, method=self.method, mode=mode or self.mode, **self.kwargs) return spatial_pad(img) @@ -336,12 +362,14 @@ class SpatialCrop(Transform): - the start and end coordinates of the ROI """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, - roi_center: Union[Sequence[int], np.ndarray, None] = None, - roi_size: Union[Sequence[int], np.ndarray, None] = None, - roi_start: Union[Sequence[int], np.ndarray, None] = None, - roi_end: Union[Sequence[int], np.ndarray, None] = None, + roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None, roi_slices: Optional[Sequence[slice]] = None, ) -> None: """ @@ -354,28 +382,37 @@ def __init__( use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. """ + roi_start_torch: torch.Tensor + if roi_slices: if not all(s.step is None or s.step == 1 for s in roi_slices): raise ValueError("Only slice steps of 1/None are currently supported") self.slices = list(roi_slices) else: if roi_center is not None and roi_size is not None: - roi_center = np.asarray(roi_center, dtype=np.int16) - roi_size = np.asarray(roi_size, dtype=np.int16) - roi_start_np = np.maximum(roi_center - np.floor_divide(roi_size, 2), 0) - roi_end_np = np.maximum(roi_start_np + roi_size, roi_start_np) + roi_center, *_ = convert_data_type( + data=roi_center, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True + ) + roi_size, *_ = convert_to_dst_type(src=roi_size, dst=roi_center, wrap_sequence=True) + _zeros = torch.zeros_like(roi_center) + roi_start_torch = maximum(roi_center - floor_divide(roi_size, 2), _zeros) # type: ignore + roi_end_torch = maximum(roi_start_torch + roi_size, roi_start_torch) else: if roi_start is None or roi_end is None: raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start_np = np.maximum(np.asarray(roi_start, dtype=np.int16), 0) - roi_end_np = np.maximum(np.asarray(roi_end, dtype=np.int16), roi_start_np) - # Allow for 1D by converting back to np.array (since np.maximum will convert to int) - roi_start_np = roi_start_np if isinstance(roi_start_np, np.ndarray) else np.array([roi_start_np]) - roi_end_np = roi_end_np if isinstance(roi_end_np, np.ndarray) else np.array([roi_end_np]) - # convert to slices - self.slices = [slice(s, e) for s, e in zip(roi_start_np, roi_end_np)] - - def __call__(self, img: Union[np.ndarray, torch.Tensor]): + roi_start_torch, *_ = convert_data_type( + data=roi_start, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True + ) + roi_start_torch = maximum(roi_start_torch, torch.zeros_like(roi_start_torch)) # type: ignore + roi_end_torch, *_ = convert_to_dst_type(src=roi_end, dst=roi_start_torch, wrap_sequence=True) + roi_end_torch = maximum(roi_end_torch, roi_start_torch) + # convert to slices (accounting for 1d) + if roi_start_torch.numel() == 1: + self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))] + else: + self.slices = [slice(int(s), int(e)) for s, e in zip(roi_start_torch.tolist(), roi_end_torch.tolist())] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -400,10 +437,12 @@ class CenterSpatialCrop(Transform): the spatial size of output data will be [32, 40, 40]. """ + backend = SpatialCrop.backend + def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -424,10 +463,12 @@ class CenterScaleCrop(Transform): """ + backend = CenterSpatialCrop.backend + def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img_size = img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -459,6 +500,8 @@ class RandSpatialCrop(Randomizable, Transform): if True, the actual size is sampled from `randint(roi_size, max_roi_size + 1)`. """ + backend = CenterSpatialCrop.backend + def __init__( self, roi_size: Union[Sequence[int], int], @@ -479,19 +522,19 @@ def randomize(self, img_size: Sequence[int]) -> None: max_size = img_size if self.max_roi_size is None else fall_back_tuple(self.max_roi_size, img_size) if any(i > j for i, j in zip(self._size, max_size)): raise ValueError(f"min ROI size: {self._size} is bigger than max ROI size: {max_size}.") - self._size = tuple((self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size)))) + self._size = tuple(self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size))) if self.random_center: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ self.randomize(img.shape[1:]) if self._size is None: - raise AssertionError + raise RuntimeError("self._size not specified.") if self.random_center: return img[self._slices] cropper = CenterSpatialCrop(self._size) @@ -530,7 +573,7 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -576,6 +619,8 @@ class RandSpatialCropSamples(Randomizable, Transform): """ + backend = RandSpatialCrop.backend + def __init__( self, roi_size: Union[Sequence[int], int], @@ -591,15 +636,15 @@ def __init__( def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Randomizable": - super().set_random_state(seed=seed, state=state) - self.cropper.set_random_state(state=self.R) + ) -> "RandSpatialCropSamples": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) return self def randomize(self, data: Optional[Any] = None) -> None: pass - def __call__(self, img: np.ndarray) -> List[np.ndarray]: + def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: """ Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. @@ -639,14 +684,17 @@ def threshold_at_one(x): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, select_fn: Callable = is_positive, channel_indices: Optional[IndexSelection] = None, margin: Union[Sequence[int], int] = 0, + allow_smaller: bool = True, return_coords: bool = False, k_divisible: Union[Sequence[int], int] = 1, - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = NumpyPadMode.CONSTANT, **np_kwargs, ) -> None: """ @@ -655,13 +703,18 @@ def __init__( channel_indices: if defined, select foreground only on the specified channels of image. if None, select foreground on the whole image. margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. + allow_smaller: when computing box size with `margin`, whether allow the image size to be smaller + than box size, default to `True`. if the margined size is bigger than image size, will pad with + specified `mode`. return_coords: whether return the coordinates of spatial bounding box for foreground. k_divisible: make each spatial dimension to be divisible by k, default to 1. if `k_divisible` is an int, the same `k` be applied to all the input spatial dimensions. - mode: padding mode {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - one of the listed string values or a user supplied function. Defaults to ``"constant"``. - see also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html @@ -669,23 +722,26 @@ def __init__( self.select_fn = select_fn self.channel_indices = ensure_tuple(channel_indices) if channel_indices is not None else None self.margin = margin + self.allow_smaller = allow_smaller self.return_coords = return_coords self.k_divisible = k_divisible self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) self.np_kwargs = np_kwargs - def compute_bounding_box(self, img: np.ndarray): + def compute_bounding_box(self, img: NdarrayOrTensor): """ Compute the start points and end points of bounding box to crop. And adjust bounding box coords to be divisible by `k`. """ - box_start, box_end = generate_spatial_bounding_box(img, self.select_fn, self.channel_indices, self.margin) - box_start_ = np.asarray(box_start, dtype=np.int16) - box_end_ = np.asarray(box_end, dtype=np.int16) + box_start, box_end = generate_spatial_bounding_box( + img, self.select_fn, self.channel_indices, self.margin, self.allow_smaller + ) + box_start_, *_ = convert_data_type(box_start, output_type=np.ndarray, dtype=np.int16, wrap_sequence=True) + box_end_, *_ = convert_data_type(box_end, output_type=np.ndarray, dtype=np.int16, wrap_sequence=True) orig_spatial_size = box_end_ - box_start_ # make the spatial size divisible by `k` - spatial_size = np.asarray(compute_divisible_spatial_size(spatial_shape=orig_spatial_size, k=self.k_divisible)) + spatial_size = np.asarray(compute_divisible_spatial_size(orig_spatial_size.tolist(), k=self.k_divisible)) # update box_start and box_end box_start_ = box_start_ - np.floor_divide(np.asarray(spatial_size) - orig_spatial_size, 2) box_end_ = box_start_ + spatial_size @@ -693,10 +749,10 @@ def compute_bounding_box(self, img: np.ndarray): def crop_pad( self, - img: np.ndarray, + img: NdarrayOrTensor, box_start: np.ndarray, box_end: np.ndarray, - mode: Optional[Union[NumpyPadMode, str]] = None, + mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, ): """ Crop and pad based on the bounding box. @@ -708,7 +764,7 @@ def crop_pad( pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) return BorderPad(spatial_border=pad, mode=mode or self.mode, **self.np_kwargs)(cropped) - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None): + def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, str]] = None): """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. @@ -734,20 +790,25 @@ class RandWeightedCrop(Randomizable, Transform): It should be a single-channel array in shape, for example, `(1, spatial_dim_0, spatial_dim_1, ...)`. """ + backend = SpatialCrop.backend + def __init__( - self, spatial_size: Union[Sequence[int], int], num_samples: int = 1, weight_map: Optional[np.ndarray] = None + self, + spatial_size: Union[Sequence[int], int], + num_samples: int = 1, + weight_map: Optional[NdarrayOrTensor] = None, ): self.spatial_size = ensure_tuple(spatial_size) self.num_samples = int(num_samples) self.weight_map = weight_map self.centers: List[np.ndarray] = [] - def randomize(self, weight_map: np.ndarray) -> None: + def randomize(self, weight_map: NdarrayOrTensor) -> None: self.centers = weighted_patch_samples( spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) # using only the first channel as weight map - def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> List[np.ndarray]: + def __call__(self, img: NdarrayOrTensor, weight_map: Optional[NdarrayOrTensor] = None) -> List[NdarrayOrTensor]: """ Args: img: input image to sample patches from. assuming `img` is a channel-first array. @@ -764,9 +825,10 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> raise ValueError("weight map must be provided for weighted patch sampling.") if img.shape[1:] != weight_map.shape[1:]: raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.") + self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) - results = [] + results: List[NdarrayOrTensor] = [] for center in self.centers: cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) results.append(cropper(img)) @@ -816,6 +878,9 @@ class RandCropByPosNegLabel(Randomizable, Transform): `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `FgBgToIndices` transform first and cache the results. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). Raises: ValueError: When ``pos`` or ``neg`` are negative. @@ -823,17 +888,20 @@ class RandCropByPosNegLabel(Randomizable, Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, spatial_size: Union[Sequence[int], int], - label: Optional[np.ndarray] = None, + label: Optional[NdarrayOrTensor] = None, pos: float = 1.0, neg: float = 1.0, num_samples: int = 1, - image: Optional[np.ndarray] = None, + image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, + fg_indices: Optional[NdarrayOrTensor] = None, + bg_indices: Optional[NdarrayOrTensor] = None, + allow_smaller: bool = False, ) -> None: self.spatial_size = ensure_tuple(spatial_size) self.label = label @@ -845,16 +913,17 @@ def __init__( self.num_samples = num_samples self.image = image self.image_threshold = image_threshold - self.centers: Optional[List[List[np.ndarray]]] = None + self.centers: Optional[List[List[int]]] = None self.fg_indices = fg_indices self.bg_indices = bg_indices + self.allow_smaller = allow_smaller def randomize( self, - label: np.ndarray, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + fg_indices: Optional[NdarrayOrTensor] = None, + bg_indices: Optional[NdarrayOrTensor] = None, + image: Optional[NdarrayOrTensor] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) if fg_indices is None or bg_indices is None: @@ -867,17 +936,24 @@ def randomize( fg_indices_ = fg_indices bg_indices_ = bg_indices self.centers = generate_pos_neg_label_crop_centers( - self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R + self.spatial_size, + self.num_samples, + self.pos_ratio, + label.shape[1:], + fg_indices_, + bg_indices_, + self.R, + self.allow_smaller, ) def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, - ) -> List[np.ndarray]: + img: NdarrayOrTensor, + label: Optional[NdarrayOrTensor] = None, + image: Optional[NdarrayOrTensor] = None, + fg_indices: Optional[NdarrayOrTensor] = None, + bg_indices: Optional[NdarrayOrTensor] = None, + ) -> List[NdarrayOrTensor]: """ Args: img: input data to crop samples from based on the pos/neg ratio of `label` and `image`. @@ -900,10 +976,10 @@ def __call__( image = self.image self.randomize(label, fg_indices, bg_indices, image) - results: List[np.ndarray] = [] + results: List[NdarrayOrTensor] = [] if self.centers is not None: for center in self.centers: - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) results.append(cropper(img)) return results @@ -965,19 +1041,25 @@ class RandCropByLabelClasses(Randomizable, Transform): `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first and cache the results for better performance. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will remain + unchanged. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, spatial_size: Union[Sequence[int], int], ratios: Optional[List[Union[float, int]]] = None, - label: Optional[np.ndarray] = None, + label: Optional[NdarrayOrTensor] = None, num_classes: Optional[int] = None, num_samples: int = 1, - image: Optional[np.ndarray] = None, + image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0, - indices: Optional[List[np.ndarray]] = None, + indices: Optional[List[NdarrayOrTensor]] = None, + allow_smaller: bool = False, ) -> None: self.spatial_size = ensure_tuple(spatial_size) self.ratios = ratios @@ -986,17 +1068,18 @@ def __init__( self.num_samples = num_samples self.image = image self.image_threshold = image_threshold - self.centers: Optional[List[List[np.ndarray]]] = None + self.centers: Optional[List[List[int]]] = None self.indices = indices + self.allow_smaller = allow_smaller def randomize( self, - label: np.ndarray, - indices: Optional[List[np.ndarray]] = None, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + indices: Optional[List[NdarrayOrTensor]] = None, + image: Optional[NdarrayOrTensor] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - indices_: List[np.ndarray] + indices_: Sequence[NdarrayOrTensor] if indices is None: if self.indices is not None: indices_ = self.indices @@ -1005,16 +1088,16 @@ def randomize( else: indices_ = indices self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R + self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller ) def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, - indices: Optional[List[np.ndarray]] = None, - ) -> List[np.ndarray]: + img: NdarrayOrTensor, + label: Optional[NdarrayOrTensor] = None, + image: Optional[NdarrayOrTensor] = None, + indices: Optional[List[NdarrayOrTensor]] = None, + ) -> List[NdarrayOrTensor]: """ Args: img: input data to crop samples from based on the ratios of every class, assumes `img` is a @@ -1033,10 +1116,10 @@ def __call__( image = self.image self.randomize(label, indices, image) - results: List[np.ndarray] = [] + results: List[NdarrayOrTensor] = [] if self.centers is not None: for center in self.centers: - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) results.append(cropper(img)) return results @@ -1063,6 +1146,8 @@ class ResizeWithPadOrCrop(Transform): """ + backend = list(set(SpatialPad.backend) & set(CenterSpatialCrop.backend)) + def __init__( self, spatial_size: Union[Sequence[int], int], @@ -1073,7 +1158,7 @@ def __init__( self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **np_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayOrTensor: """ Args: img: data to pad or crop, assuming `img` is channel-first and @@ -1105,16 +1190,18 @@ class BoundingRect(Transform): Nth_spatial_dim_start, Nth_spatial_dim_end]] The bounding boxes edges are aligned with the input image edges. - This function returns [-1, -1, ...] if there's no positive intensity. + This function returns [0, 0, ...] if there's no positive intensity. Args: select_fn: function to select expected foreground, default is to select values > 0. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, select_fn: Callable = is_positive) -> None: self.select_fn = select_fn - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> np.ndarray: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 956dff7881..52d0c7be3b 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,15 +20,11 @@ import torch from monai.data.utils import list_data_collate -from monai.transforms.compose import Compose from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.inverse import InvertibleTransform -from monai.transforms.utility.array import ToTensor -from monai.utils.enums import InverseKeys, Method, NumpyPadMode +from monai.utils.enums import Method, NumpyPadMode, TraceKeys -__all__ = [ - "PadListDataCollate", -] +__all__ = ["PadListDataCollate"] def replace_element(to_replace, batch, idx, key_or_idx): @@ -49,8 +45,9 @@ class PadListDataCollate(InvertibleTransform): tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of different sizes. - This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added - to the list of invertible transforms. + This can be used on both list and dictionary data. + Note that in the case of the dictionary data, it may add the transform information to the list of invertible transforms + if input batch have different spatial shape, so need to call static method: `inverse` before inverting other transforms. Note that normally, a user won't explicitly use the `__call__` method. Rather this would be passed to the `DataLoader`. This means that `__call__` handles data as it comes out of a `DataLoader`, containing batch dimension. However, the @@ -97,20 +94,12 @@ def __call__(self, batch: Any): # If all same size, skip if np.all(np.array(max_shapes).min(axis=0) == max_shape): continue - # Do we need to convert output to Tensor? - output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor) - - # Use `SpatialPadd` or `SpatialPad` to match sizes - # Default params are central padding, padding with 0's - # If input is dictionary, use the dictionary version so that the transformation is recorded + # Use `SpatialPad` to match sizes, Default params are central padding, padding with 0's padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.np_kwargs) - transform = padder if not output_to_tensor else Compose([padder, ToTensor()]) - for idx, batch_i in enumerate(batch): - im = batch_i[key_or_idx] - orig_size = im.shape[1:] - padded = transform(batch_i[key_or_idx]) + orig_size = batch_i[key_or_idx].shape[1:] + padded = padder(batch_i[key_or_idx]) batch = replace_element(padded, batch, idx, key_or_idx) # If we have a dictionary of data, append to list @@ -127,11 +116,13 @@ def inverse(data: dict) -> Dict[Hashable, np.ndarray]: d = deepcopy(data) for key in d: - transform_key = str(key) + InverseKeys.KEY_SUFFIX + transform_key = InvertibleTransform.trace_key(key) if transform_key in d: transform = d[transform_key][-1] - if transform[InverseKeys.CLASS_NAME] == PadListDataCollate.__name__: - d[key] = CenterSpatialCrop(transform["orig_size"])(d[key]) + if not isinstance(transform, Dict): + continue + if transform.get(TraceKeys.CLASS_NAME) == PadListDataCollate.__name__: + d[key] = CenterSpatialCrop(transform.get("orig_size", -1))(d[key]) # fallback to image size # remove transform d[transform_key].pop() return d diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 9e33ab2db1..19ebe40b46 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,7 +25,7 @@ import numpy as np from monai.config import IndexSelection, KeysCollection -from monai.config.type_definitions import NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.croppad.array import ( BorderPad, @@ -33,6 +33,8 @@ CenterSpatialCrop, CropForeground, DivisiblePad, + RandCropByLabelClasses, + RandCropByPosNegLabel, ResizeWithPadOrCrop, SpatialCrop, SpatialPad, @@ -49,11 +51,11 @@ weighted_patch_samples, ) from monai.utils import ImageMetaKey as Key -from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple -from monai.utils.enums import InverseKeys +from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple +from monai.utils.enums import PostFix, TraceKeys __all__ = [ - "NumpyPadModeSequence", + "PadModeSequence", "SpatialPadd", "BorderPadd", "DivisiblePadd", @@ -96,9 +98,14 @@ "ResizeWithPadOrCropDict", "BoundingRectD", "BoundingRectDict", + "RandCropByLabelClassesd", + "RandCropByLabelClassesD", + "RandCropByLabelClassesDict", ] NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] +PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str] +DEFAULT_POST_FIX = PostFix.meta() class SpatialPadd(MapTransform, InvertibleTransform): @@ -114,9 +121,9 @@ def __init__( keys: KeysCollection, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: PadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: """ Args: @@ -129,33 +136,35 @@ def __init__( the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = SpatialPad(spatial_size, method, **np_kwargs) + self.padder = SpatialPad(spatial_size, method, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InverseKeys.ORIG_SIZE] + orig_size = transform[TraceKeys.ORIG_SIZE] if self.padder.method == Method.SYMMETRIC: current_size = d[key].shape[1:] roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] @@ -183,9 +192,9 @@ def __init__( self, keys: KeysCollection, spatial_border: Union[Sequence[int], int], - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: PadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: """ Args: @@ -202,34 +211,36 @@ def __init__( pad bottom of H dim with 2, pad left of W dim with 3, pad right of W dim with 4. the result shape is [1, 7, 11]. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = BorderPad(spatial_border=spatial_border, **np_kwargs) + self.padder = BorderPad(spatial_border=spatial_border, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) roi_start = np.array(self.padder.spatial_border) # Need to convert single value to [min1,min2,...] if roi_start.size == 1: @@ -237,7 +248,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # need to convert [min1,max1,min2,...] to [min1,min2,...] elif roi_start.size == 2 * orig_size.size: roi_start = roi_start[::2] - roi_end = np.array(transform[InverseKeys.ORIG_SIZE]) + roi_start + roi_end = np.array(transform[TraceKeys.ORIG_SIZE]) + roi_start inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) # Apply inverse transform @@ -260,10 +271,10 @@ def __init__( self, keys: KeysCollection, k: Union[Sequence[int], int], - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: PadModeSequence = NumpyPadMode.CONSTANT, method: Union[Method, str] = Method.SYMMETRIC, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: """ Args: @@ -272,38 +283,40 @@ def __init__( k: the target k for each spatial dimension. if `k` is negative or 0, the original size is preserved. if `k` is an int, the same `k` be applied to all the input spatial dimensions. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/ghttps://pytorch.orgenerated/numpy.pad.html + /docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. See also :py:class:`monai.transforms.SpatialPad` """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = DivisiblePad(k=k, method=method, **np_kwargs) + self.padder = DivisiblePad(k=k, method=method, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) roi_start = np.floor((current_size - orig_size) / 2) roi_end = orig_size + roi_start @@ -331,6 +344,8 @@ class SpatialCropd(MapTransform, InvertibleTransform): - the start and end coordinates of the ROI """ + backend = SpatialCrop.backend + def __init__( self, keys: KeysCollection, @@ -357,20 +372,20 @@ def __init__( super().__init__(keys, allow_missing_keys) self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(self.cropper.slices, orig_size)]) @@ -404,13 +419,15 @@ class CenterSpatialCropd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = CenterSpatialCrop.backend + def __init__( self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False ) -> None: super().__init__(keys, allow_missing_keys) self.cropper = CenterSpatialCrop(roi_size) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): orig_size = d[key].shape[1:] @@ -418,13 +435,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.push_transform(d, key, orig_size=orig_size) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) # in each direction, if original size is even and current size is odd, += 1 @@ -454,16 +471,22 @@ class CenterScaleCropd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = CenterSpatialCrop.backend + def __init__( self, keys: KeysCollection, roi_scale: Union[Sequence[float], float], allow_missing_keys: bool = False ) -> None: super().__init__(keys, allow_missing_keys=allow_missing_keys) self.roi_scale = roi_scale - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + # use the spatial size of first image to scale, expect all images have the same spatial size - img_size = data[self.keys[0]].shape[1:] + img_size = d[first_key].shape[1:] # type: ignore ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] cropper = CenterSpatialCrop(roi_size) @@ -473,13 +496,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) # in each direction, if original size is even and current size is odd, += 1 @@ -525,6 +548,8 @@ class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = CenterSpatialCrop.backend + def __init__( self, keys: KeysCollection, @@ -553,11 +578,15 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.randomize(d[first_key].shape[1:]) # type: ignore if self._size is None: - raise AssertionError + raise RuntimeError("self._size not specified.") for key in self.key_iterator(d): if self.random_center: self.push_transform(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: ignore @@ -568,18 +597,18 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InverseKeys.ORIG_SIZE] + orig_size = transform[TraceKeys.ORIG_SIZE] random_center = self.random_center pad_to_start = np.empty((len(orig_size)), dtype=np.int32) pad_to_end = np.empty((len(orig_size)), dtype=np.int32) if random_center: - for i, _slice in enumerate(transform[InverseKeys.EXTRA_INFO]["slices"]): + for i, _slice in enumerate(transform[TraceKeys.EXTRA_INFO]["slices"]): pad_to_start[i] = _slice[0] pad_to_end[i] = orig_size[i] - _slice[1] else: @@ -615,7 +644,7 @@ class RandScaleCropd(RandSpatialCropd): roi_scale: if `random_size` is True, it specifies the minimum crop size: `roi_scale * image spatial size`. if `random_size` is False, it specifies the expected scale of image size to crop. e.g. [0.3, 0.4, 0.5]. If its components have non-positive values, will use `1.0` instead, which means the input image size. - max_roi_size: if `random_size` is True and `roi_scale` specifies the min crop region size, `max_roi_scale` + max_roi_scale: if `random_size` is True and `roi_scale` specifies the min crop region size, `max_roi_scale` can specify the max crop region size: `max_roi_scale * image spatial size`. if None, defaults to the input image size. if its components have non-positive values, will use `1.0` instead, which means the input image size. @@ -626,6 +655,8 @@ class RandScaleCropd(RandSpatialCropd): allow_missing_keys: don't raise exception if key is missing. """ + backend = RandSpatialCropd.backend + def __init__( self, keys: KeysCollection, @@ -643,12 +674,15 @@ def __init__( random_size=random_size, allow_missing_keys=allow_missing_keys, ) - MapTransform.__init__(self, keys, allow_missing_keys) self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - img_size = data[self.keys[0]].shape[1:] + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + first_key: Union[Hashable, List] = self.first_key(data) # type: ignore + if first_key == []: + return data # type: ignore + + img_size = data[first_key].shape[1:] # type: ignore ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] if self.max_roi_scale is not None: @@ -701,7 +735,7 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. @@ -711,6 +745,8 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): """ + backend = RandSpatialCropd.backend + def __init__( self, keys: KeysCollection, @@ -720,7 +756,7 @@ def __init__( random_center: bool = True, random_size: bool = True, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -735,15 +771,15 @@ def __init__( def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Randomizable": - super().set_random_state(seed=seed, state=state) - self.cropper.set_random_state(state=self.R) + ) -> "RandSpatialCropSamplesd": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) return self def randomize(self, data: Optional[Any] = None) -> None: pass - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: ret = [] for i in range(self.num_samples): d = dict(data) @@ -753,24 +789,24 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n cropped = self.cropper(d) # self.cropper will have added RandSpatialCropd to the list. Change to RandSpatialCropSamplesd for key in self.key_iterator(cropped): - cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__ - cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self) + cropped[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore + cropped[self.trace_key(key)][-1][TraceKeys.ID] = id(self) # type: ignore # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" if meta_key not in cropped: cropped[meta_key] = {} # type: ignore - cropped[meta_key][Key.PATCH_INDEX] = i + cropped[meta_key][Key.PATCH_INDEX] = i # type: ignore ret.append(cropped) return ret - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = deepcopy(dict(data)) # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd # Need to revert that since we're calling RandSpatialCropd's inverse for key in self.key_iterator(d): - d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.cropper.__class__.__name__ - d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self.cropper) + d[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.cropper.__class__.__name__ + d[self.trace_key(key)][-1][TraceKeys.ID] = id(self.cropper) context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext with context_manager(self.cropper): return self.cropper.inverse(d) @@ -789,6 +825,8 @@ class CropForegroundd(MapTransform, InvertibleTransform): channels. And it can also add margin to every dim of the bounding box of foreground object. """ + backend = CropForeground.backend + def __init__( self, keys: KeysCollection, @@ -796,8 +834,9 @@ def __init__( select_fn: Callable = is_positive, channel_indices: Optional[IndexSelection] = None, margin: Union[Sequence[int], int] = 0, + allow_smaller: bool = True, k_divisible: Union[Sequence[int], int] = 1, - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = NumpyPadMode.CONSTANT, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", allow_missing_keys: bool = False, @@ -812,12 +851,17 @@ def __init__( channel_indices: if defined, select foreground only on the specified channels of image. if None, select foreground on the whole image. margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. + allow_smaller: when computing box size with `margin`, whether allow the image size to be smaller + than box size, default to `True`. if the margined size is bigger than image size, will pad with + specified `mode`. k_divisible: make each spatial dimension to be divisible by k, default to 1. if `k_divisible` is an int, the same `k` be applied to all the input spatial dimensions. - mode: padding mode {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - one of the listed string values or a user supplied function. Defaults to ``"constant"``. - see also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html it also can be a sequence of string, each element corresponds to a key in ``keys``. start_coord_key: key to record the start coordinate of spatial bounding box for foreground. end_coord_key: key to record the end coordinate of spatial bounding box for foreground. @@ -834,12 +878,13 @@ def __init__( select_fn=select_fn, channel_indices=channel_indices, margin=margin, + allow_smaller=allow_smaller, k_divisible=k_divisible, **np_kwargs, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) d[self.start_coord_key] = box_start @@ -849,14 +894,14 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) cur_size = np.asarray(d[key].shape[1:]) - extra_info = transform[InverseKeys.EXTRA_INFO] + extra_info = transform[TraceKeys.EXTRA_INFO] box_start = np.asarray(extra_info["box_start"]) box_end = np.asarray(extra_info["box_end"]) # first crop the padding part @@ -897,7 +942,7 @@ class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. @@ -906,6 +951,8 @@ class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): :py:class:`monai.transforms.RandWeightedCrop` """ + backend = SpatialCrop.backend + def __init__( self, keys: KeysCollection, @@ -914,7 +961,7 @@ def __init__( num_samples: int = 1, center_coord_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) @@ -928,18 +975,18 @@ def __init__( self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) self.centers: List[np.ndarray] = [] - def randomize(self, weight_map: np.ndarray) -> None: + def randomize(self, weight_map: NdarrayOrTensor) -> None: self.centers = weighted_patch_samples( spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) self.randomize(d[self.w_key]) _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)] + results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(data) for _ in range(self.num_samples)] # fill in the extra keys with unmodified data for i in range(self.num_samples): for key in set(data.keys()).difference(set(self.keys)): @@ -965,19 +1012,19 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n meta_key = meta_key or f"{key}_{meta_key_postfix}" if meta_key not in results[i]: results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i + results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) current_size = np.asarray(d[key].shape[1:]) - center = transform[InverseKeys.EXTRA_INFO]["center"] - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) + center = transform[TraceKeys.EXTRA_INFO]["center"] + cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) pad_to_end = orig_size - current_size - pad_to_start @@ -1037,9 +1084,12 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). allow_missing_keys: don't raise exception if key is missing. Raises: @@ -1048,6 +1098,8 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): """ + backend = RandCropByPosNegLabel.backend + def __init__( self, keys: KeysCollection, @@ -1061,7 +1113,8 @@ def __init__( fg_indices_key: Optional[str] = None, bg_indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, + allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1081,14 +1134,15 @@ def __init__( if len(self.keys) != len(self.meta_keys): raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: Optional[List[List[np.ndarray]]] = None + self.centers: Optional[List[List[int]]] = None + self.allow_smaller = allow_smaller def randomize( self, - label: np.ndarray, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + fg_indices: Optional[NdarrayOrTensor] = None, + bg_indices: Optional[NdarrayOrTensor] = None, + image: Optional[NdarrayOrTensor] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) if fg_indices is None or bg_indices is None: @@ -1097,10 +1151,17 @@ def randomize( fg_indices_ = fg_indices bg_indices_ = bg_indices self.centers = generate_pos_neg_label_crop_centers( - self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R + self.spatial_size, + self.num_samples, + self.pos_ratio, + label.shape[1:], + fg_indices_, + bg_indices_, + self.R, + self.allow_smaller, ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) label = d[self.label_key] image = d[self.image_key] if self.image_key else None @@ -1114,7 +1175,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, np.ndarray]] = [dict(d) for _ in range(self.num_samples)] + results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)] for i, center in enumerate(self.centers): # fill in the extra keys with unmodified data @@ -1122,7 +1183,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n results[i][key] = deepcopy(d[key]) for key in self.key_iterator(d): img = d[key] - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) orig_size = img.shape[1:] results[i][key] = cropper(img) self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) @@ -1131,18 +1192,18 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n meta_key = meta_key or f"{key}_{meta_key_postfix}" if meta_key not in results[i]: results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i + results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) current_size = np.asarray(d[key].shape[1:]) - center = transform[InverseKeys.EXTRA_INFO]["center"] + center = transform[TraceKeys.EXTRA_INFO]["center"] cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) @@ -1227,13 +1288,18 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will remain + unchanged. allow_missing_keys: don't raise exception if key is missing. """ + backend = RandCropByLabelClasses.backend + def __init__( self, keys: KeysCollection, @@ -1246,7 +1312,8 @@ def __init__( image_threshold: float = 0.0, indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, + allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1262,25 +1329,25 @@ def __init__( if len(self.keys) != len(self.meta_keys): raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: Optional[List[List[np.ndarray]]] = None + self.centers: Optional[List[List[int]]] = None + self.allow_smaller = allow_smaller def randomize( self, - label: np.ndarray, - indices: Optional[List[np.ndarray]] = None, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + indices: Optional[List[NdarrayOrTensor]] = None, + image: Optional[NdarrayOrTensor] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - indices_: List[np.ndarray] if indices is None: indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) else: indices_ = indices self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R + self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller ) - def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) label = d[self.label_key] image = d[self.image_key] if self.image_key else None @@ -1293,7 +1360,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, np.ndarray]] = [dict(d) for _ in range(self.num_samples)] + results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)] for i, center in enumerate(self.centers): # fill in the extra keys with unmodified data @@ -1301,7 +1368,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr results[i][key] = deepcopy(d[key]) for key in self.key_iterator(d): img = d[key] - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) orig_size = img.shape[1:] results[i][key] = cropper(img) self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) @@ -1310,18 +1377,18 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr meta_key = meta_key or f"{key}_{meta_key_postfix}" if meta_key not in results[i]: results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i + results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) current_size = np.asarray(d[key].shape[1:]) - center = transform[InverseKeys.EXTRA_INFO]["center"] + center = transform[TraceKeys.EXTRA_INFO]["center"] cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) @@ -1359,6 +1426,8 @@ class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): """ + backend = ResizeWithPadOrCrop.backend + def __init__( self, keys: KeysCollection, @@ -1372,27 +1441,20 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **np_kwargs) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): orig_size = d[key].shape[1:] d[key] = self.padcropper(d[key], mode=m) - self.push_transform( - d, - key, - orig_size=orig_size, - extra_info={ - "mode": m.value if isinstance(m, Enum) else m, - }, - ) + self.push_transform(d, key, orig_size=orig_size, extra_info={"mode": m.value if isinstance(m, Enum) else m}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) # Unfortunately, we can't just use ResizeWithPadOrCrop with original size because of odd/even rounding. # Instead, we first pad any smaller dimensions, and then we crop any larger dimensions. @@ -1436,6 +1498,8 @@ class BoundingRectd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = BoundingRect.backend + def __init__( self, keys: KeysCollection, @@ -1447,7 +1511,7 @@ def __init__( self.bbox = BoundingRect(select_fn=select_fn) self.bbox_key_postfix = bbox_key_postfix - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ diff --git a/monai/transforms/intensity/__init__.py b/monai/transforms/intensity/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/intensity/__init__.py +++ b/monai/transforms/intensity/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 20d306be04..da46b105e1 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,7 +13,7 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -import copy +from abc import abstractmethod from collections.abc import Iterable from functools import partial from typing import Any, Callable, List, Optional, Sequence, Tuple, Union @@ -28,8 +28,8 @@ from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array +from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where from monai.utils import ( - PT_BEFORE_1_7, InvalidPyTorchVersionError, convert_data_type, convert_to_dst_type, @@ -37,7 +37,9 @@ ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, + pytorch_after, ) +from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_to_tensor, get_equivalent_dtype @@ -69,9 +71,12 @@ "RandGibbsNoise", "KSpaceSpikeNoise", "RandKSpaceSpikeNoise", + "RandCoarseTransform", "RandCoarseDropout", + "RandCoarseShuffle", "HistogramNormalize", - "LocalPatchShuffling", + "IntensityRemap", + "RandIntensityRemap", ] @@ -83,30 +88,42 @@ class RandGaussianNoise(RandomizableTransform): prob: Probability to add Gaussian noise. mean: Mean or “centre” of the distribution. std: Standard deviation (spread) of distribution. + dtype: output data type, if None, same as input image. defaults to float32. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: float = 0.1) -> None: + def __init__(self, prob: float = 0.1, mean: float = 0.0, std: float = 0.1, dtype: DtypeLike = np.float32) -> None: RandomizableTransform.__init__(self, prob) self.mean = mean self.std = std - self._noise: np.ndarray + self.dtype = dtype + self.noise: Optional[np.ndarray] = None - def randomize(self, im_shape: Sequence[int]) -> None: + def randomize(self, img: NdarrayOrTensor, mean: Optional[float] = None) -> None: super().randomize(None) - self._noise = self.R.normal(self.mean, self.R.uniform(0, self.std), size=im_shape) + if not self._do_transform: + return None + rand_std = self.R.uniform(0, self.std) + noise = self.R.normal(self.mean if mean is None else mean, rand_std, size=img.shape) + # noise is float64 array, convert to the output dtype to save memory + self.noise, *_ = convert_data_type(noise, dtype=self.dtype) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, mean: Optional[float] = None, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - self.randomize(img.shape) - if self._noise is None: - raise RuntimeError("randomized factor should not be None.") + if randomize: + self.randomize(img=img, mean=self.mean if mean is None else mean) + if not self._do_transform: return img - noise, *_ = convert_to_dst_type(self._noise, img) + + if self.noise is None: + raise RuntimeError("please call the `randomize()` function first.") + img, *_ = convert_data_type(img, dtype=self.dtype) + noise, *_ = convert_to_dst_type(self.noise, img) return img + noise @@ -114,10 +131,10 @@ class RandRicianNoise(RandomizableTransform): """ Add Rician noise to image. Rician noise in MRI is the result of performing a magnitude operation on complex - data with Gaussian noise of the same variance in both channels, as described in `Noise in Magnitude Magnetic Resonance Images - `_. This transform is adapted from - `DIPY`_. See also: `The rician distribution of noisy mri data - `_. + data with Gaussian noise of the same variance in both channels, as described in + `Noise in Magnitude Magnetic Resonance Images `_. + This transform is adapted from `DIPY `_. + See also: `The rician distribution of noisy mri data `_. Args: prob: Probability to add Rician noise. @@ -131,6 +148,8 @@ class RandRicianNoise(RandomizableTransform): histogram. sample_std: If True, sample the spread of the Gaussian distributions uniformly from 0 to std. + dtype: output data type, if None, same as input image. defaults to float32. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -143,6 +162,7 @@ def __init__( channel_wise: bool = False, relative: bool = False, sample_std: bool = True, + dtype: DtypeLike = np.float32, ) -> None: RandomizableTransform.__init__(self, prob) self.prob = prob @@ -151,29 +171,34 @@ def __init__( self.channel_wise = channel_wise self.relative = relative self.sample_std = sample_std + self.dtype = dtype self._noise1: NdarrayOrTensor self._noise2: NdarrayOrTensor - def _add_noise(self, img: NdarrayTensor, mean: float, std: float): + def _add_noise(self, img: NdarrayOrTensor, mean: float, std: float): dtype_np = get_equivalent_dtype(img.dtype, np.ndarray) im_shape = img.shape _std = self.R.uniform(0, std) if self.sample_std else std - self._noise1 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np) - self._noise2 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np) + self._noise1 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np, copy=False) + self._noise2 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np, copy=False) if isinstance(img, torch.Tensor): n1 = torch.tensor(self._noise1, device=img.device) n2 = torch.tensor(self._noise2, device=img.device) - return torch.sqrt((img + n1) ** 2 + n2 ** 2) + return torch.sqrt((img + n1) ** 2 + n2**2) - return np.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) + return np.sqrt((img + self._noise1) ** 2 + self._noise2**2) - def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - super().randomize(None) + if randomize: + super().randomize(None) + if not self._do_transform: return img + + img, *_ = convert_data_type(img, dtype=self.dtype) if self.channel_wise: _mean = ensure_tuple_rep(self.mean, len(img)) _std = ensure_tuple_rep(self.std, len(img)) @@ -211,9 +236,9 @@ def __call__(self, img: NdarrayOrTensor, offset: Optional[float] = None) -> Ndar offset = self.offset if offset is None else offset out = img + offset - if isinstance(out, torch.Tensor): - return out.type(img.dtype) - return out.astype(img.dtype) # type: ignore + out, *_ = convert_data_type(data=out, dtype=img.dtype) + + return out class RandShiftIntensity(RandomizableTransform): @@ -241,10 +266,12 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 self._shfiter = ShiftIntensity(self._offset) def randomize(self, data: Optional[Any] = None) -> None: - self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) super().randomize(None) + if not self._do_transform: + return None + self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) - def __call__(self, img: NdarrayOrTensor, factor: Optional[float] = None) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, factor: Optional[float] = None, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. @@ -254,9 +281,12 @@ def __call__(self, img: NdarrayOrTensor, factor: Optional[float] = None) -> Ndar can be some image specific value at runtime, like: max(img), etc. """ - self.randomize() + if randomize: + self.randomize() + if not self._do_transform: return img + return self._shfiter(img, self._offset if factor is None else self._offset * factor) @@ -272,7 +302,7 @@ class StdShiftIntensity(Transform): nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. Please ensure that the first dimension represents the channel of the image if True. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -305,7 +335,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - img, *_ = convert_data_type(img, dtype=self.dtype) + if self.dtype is not None: + img, *_ = convert_data_type(img, dtype=self.dtype) if self.channel_wise: for i, d in enumerate(img): img[i] = self._stdshift(d) # type: ignore @@ -337,7 +368,7 @@ def __init__( prob: probability of std shift. nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. """ RandomizableTransform.__init__(self, prob) @@ -353,20 +384,25 @@ def __init__( self.dtype = dtype def randomize(self, data: Optional[Any] = None) -> None: - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) + if not self._do_transform: + return None + self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - self.randomize() + if randomize: + self.randomize() + if not self._do_transform: return img + shifter = StdShiftIntensity( factor=self.factor, nonzero=self.nonzero, channel_wise=self.channel_wise, dtype=self.dtype ) - return shifter(img) + return shifter(img=img) class ScaleIntensity(Transform): @@ -378,18 +414,28 @@ class ScaleIntensity(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, factor: Optional[float] = None + self, + minv: Optional[float] = 0.0, + maxv: Optional[float] = 1.0, + factor: Optional[float] = None, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, ) -> None: """ Args: minv: minimum value of output data. maxv: maximum value of output data. factor: factor scale by ``v = v * (1 + factor)``. In order to use - this parameter, please set `minv` and `maxv` into None. + this parameter, please set both `minv` and `maxv` into None. + channel_wise: if True, scale on each channel separately. Please ensure + that the first dimension represents the channel of the image if True. + dtype: output data type, if None, same as input image. defaults to float32. """ self.minv = minv self.maxv = maxv self.factor = factor + self.channel_wise = channel_wise + self.dtype = dtype def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ @@ -399,13 +445,18 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: ValueError: When ``self.minv=None`` or ``self.maxv=None`` and ``self.factor=None``. Incompatible values. """ - if self.minv is not None and self.maxv is not None: - return rescale_array(img, self.minv, self.maxv, img.dtype) - if self.factor is not None: - out = img * (1 + self.factor) - out, *_ = convert_data_type(out, dtype=img.dtype) - return out - raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") + ret: NdarrayOrTensor + if self.minv is not None or self.maxv is not None: + if self.channel_wise: + out = [rescale_array(d, self.minv, self.maxv, dtype=self.dtype) for d in img] + ret = torch.stack(out) if isinstance(img, torch.Tensor) else np.stack(out) # type: ignore + else: + ret = rescale_array(img, self.minv, self.maxv, dtype=self.dtype) + else: + ret = (img * (1 + self.factor)) if self.factor is not None else img + + ret, *_ = convert_data_type(ret, dtype=self.dtype or img.dtype) + return ret class RandScaleIntensity(RandomizableTransform): @@ -416,12 +467,15 @@ class RandScaleIntensity(RandomizableTransform): backend = ScaleIntensity.backend - def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1) -> None: + def __init__( + self, factors: Union[Tuple[float, float], float], prob: float = 0.1, dtype: DtypeLike = np.float32 + ) -> None: """ Args: factors: factor range to randomly scale by ``v = v * (1 + factor)``. if single number, factor value is picked from (-factors, factors). prob: probability of scale. + dtype: output data type, if None, same as input image. defaults to float32. """ RandomizableTransform.__init__(self, prob) @@ -432,20 +486,25 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 else: self.factors = (min(factors), max(factors)) self.factor = self.factors[0] + self.dtype = dtype def randomize(self, data: Optional[Any] = None) -> None: - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) + if not self._do_transform: + return None + self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - self.randomize() + if randomize: + self.randomize() + if not self._do_transform: return img - scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor) - return scaler(img) + + return ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img) class RandBiasField(RandomizableTransform): @@ -463,17 +522,19 @@ class RandBiasField(RandomizableTransform): degree: degree of freedom of the polynomials. The value should be no less than 1. Defaults to 3. coeff_range: range of the random coefficients. Defaults to (0.0, 0.1). - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. prob: probability to do random bias field. """ + backend = [TransformBackends.NUMPY] + def __init__( self, degree: int = 3, coeff_range: Tuple[float, float] = (0.0, 0.1), dtype: DtypeLike = np.float32, - prob: float = 1.0, + prob: float = 0.1, ) -> None: RandomizableTransform.__init__(self, prob) if degree < 1: @@ -507,18 +568,23 @@ def _generate_random_field(self, spatial_shape: Sequence[int], degree: int, coef return np.polynomial.legendre.leggrid3d(coords[0], coords[1], coords[2], coeff_mat) raise NotImplementedError("only supports 2D or 3D fields") - def randomize(self, data: np.ndarray) -> None: + def randomize(self, img_size: Sequence[int]) -> None: super().randomize(None) - n_coeff = int(np.prod([(self.degree + k) / k for k in range(1, len(data.shape[1:]) + 1)])) + if not self._do_transform: + return None + n_coeff = int(np.prod([(self.degree + k) / k for k in range(1, len(img_size) + 1)])) self._coeff = self.R.uniform(*self.coeff_range, n_coeff).tolist() - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - self.randomize(data=img) + if randomize: + self.randomize(img_size=img.shape[1:]) + if not self._do_transform: return img + num_channels, *spatial_shape = img.shape _bias_fields = np.stack( [ @@ -527,7 +593,10 @@ def __call__(self, img: np.ndarray): ], axis=0, ) - return (img * np.exp(_bias_fields)).astype(self.dtype) + img_np, *_ = convert_data_type(img, np.ndarray) + out: NdarrayOrTensor = img_np * np.exp(_bias_fields) + out, *_ = convert_to_dst_type(src=out, dst=img, dtype=self.dtype or img.dtype) + return out class NormalizeIntensity(Transform): @@ -542,9 +611,9 @@ class NormalizeIntensity(Transform): subtrahend: the amount to subtract by (usually the mean). divisor: the amount to divide by (usually the standard deviation). nonzero: whether only normalize non-zero values. - channel_wise: if using calculated mean and std, calculate on each channel separately - or calculate on the entire image directly. - dtype: output data type, defaults to float32. + channel_wise: if True, calculate on each channel separately, otherwise, calculate on + the entire image directly. default to False. + dtype: output data type, if None, same as input image. defaults to float32. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -611,6 +680,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True, """ + dtype = self.dtype or img.dtype if self.channel_wise: if self.subtrahend is not None and len(self.subtrahend) != len(img): raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.") @@ -626,7 +696,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: else: img = self._normalize(img, self.subtrahend, self.divisor) - out, *_ = convert_data_type(img, dtype=self.dtype) + out = convert_to_dst_type(img, img, dtype=dtype)[0] return out @@ -641,6 +711,8 @@ class ThresholdIntensity(Transform): cval: value to fill the remaining parts of the image, default is 0. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> None: if not isinstance(threshold, (int, float)): raise ValueError("threshold must be a float or int number.") @@ -648,13 +720,14 @@ def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> N self.above = above self.cval = cval - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.asarray( - np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval), dtype=img.dtype - ) + mask = img > self.threshold if self.above else img < self.threshold + res = where(mask, img, self.cval) + res, *_ = convert_data_type(res, dtype=img.dtype) + return res class ScaleIntensityRange(Transform): @@ -662,34 +735,55 @@ class ScaleIntensityRange(Transform): Apply specific intensity scaling to the whole numpy array. Scaling from [a_min, a_max] to [b_min, b_max] with clip option. + When `b_min` or `b_max` are `None`, `scacled_array * (b_max - b_min) + b_min` will be skipped. + If `clip=True`, when `b_min`/`b_max` is None, the clipping is not performed on the corresponding edge. + Args: a_min: intensity original range min. a_max: intensity original range max. b_min: intensity target range min. b_max: intensity target range max. clip: whether to perform clip after scaling. + dtype: output data type, if None, same as input image. defaults to float32. """ - def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False) -> None: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + a_min: float, + a_max: float, + b_min: Optional[float] = None, + b_max: Optional[float] = None, + clip: bool = False, + dtype: DtypeLike = np.float32, + ) -> None: self.a_min = a_min self.a_max = a_max self.b_min = b_min self.b_max = b_max self.clip = clip + self.dtype = dtype - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ + dtype = self.dtype or img.dtype if self.a_max - self.a_min == 0.0: warn("Divide by zero (a_min == a_max)", Warning) + if self.b_min is None: + return img - self.a_min return img - self.a_min + self.b_min img = (img - self.a_min) / (self.a_max - self.a_min) - img = img * (self.b_max - self.b_min) + self.b_min + if (self.b_min is not None) and (self.b_max is not None): + img = img * (self.b_max - self.b_min) + self.b_min if self.clip: - img = np.asarray(np.clip(img, self.b_min, self.b_max)) - return img + img = clip(img, self.b_min, self.b_max) + ret: NdarrayOrTensor = convert_data_type(img, dtype=dtype)[0] + + return ret class AdjustContrast(Transform): @@ -702,19 +796,22 @@ class AdjustContrast(Transform): gamma: gamma value to adjust the contrast as function. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, gamma: float) -> None: if not isinstance(gamma, (int, float)): raise ValueError("gamma must be a float or int number.") self.gamma = gamma - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ epsilon = 1e-7 img_min = img.min() img_range = img.max() - img_min - return np.power(((img - img_min) / float(img_range + epsilon)), self.gamma) * img_range + img_min + ret: NdarrayOrTensor = ((img - img_min) / float(img_range + epsilon)) ** self.gamma * img_range + img_min + return ret class RandAdjustContrast(RandomizableTransform): @@ -729,6 +826,8 @@ class RandAdjustContrast(RandomizableTransform): If single number, value is picked from (0.5, gamma), default is (0.5, 4.5). """ + backend = AdjustContrast.backend + def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0.5, 4.5)) -> None: RandomizableTransform.__init__(self, prob) @@ -743,34 +842,39 @@ def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0. else: self.gamma = (min(gamma), max(gamma)) - self.gamma_value: float + self.gamma_value: Optional[float] = None def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None self.gamma_value = self.R.uniform(low=self.gamma[0], high=self.gamma[1]) - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - self.randomize() - if self.gamma_value is None: - raise ValueError("gamma_value is not set.") + if randomize: + self.randomize() + if not self._do_transform: return img - adjuster = AdjustContrast(self.gamma_value) - return adjuster(img) + + if self.gamma_value is None: + raise RuntimeError("gamma_value is not set, please call `randomize` function first.") + return AdjustContrast(self.gamma_value)(img) class ScaleIntensityRangePercentiles(Transform): """ Apply range scaling to a numpy array based on the intensity distribution of the input. - By default this transform will scale from [lower_intensity_percentile, upper_intensity_percentile] to [b_min, b_max], where - {lower,upper}_intensity_percentile are the intensity values at the corresponding percentiles of ``img``. + By default this transform will scale from [lower_intensity_percentile, upper_intensity_percentile] to + `[b_min, b_max]`, where {lower,upper}_intensity_percentile are the intensity values at the corresponding + percentiles of ``img``. - The ``relative`` parameter can also be set to scale from [lower_intensity_percentile, upper_intensity_percentile] to the - lower and upper percentiles of the output range [b_min, b_max] + The ``relative`` parameter can also be set to scale from [lower_intensity_percentile, upper_intensity_percentile] + to the lower and upper percentiles of the output range [b_min, b_max]. For example: @@ -807,6 +911,9 @@ class ScaleIntensityRangePercentiles(Transform): [20., 60., 100., 140., 180.], [20., 60., 100., 140., 180.]]] + See Also: + + - :py:class:`monai.transforms.ScaleIntensityRange` Args: lower: lower intensity percentile. @@ -815,10 +922,23 @@ class ScaleIntensityRangePercentiles(Transform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max]. + channel_wise: if True, compute intensity percentile and normalize every channel separately. + default to False. + dtype: output data type, if None, same as input image. defaults to float32. """ + backend = ScaleIntensityRange.backend + def __init__( - self, lower: float, upper: float, b_min: float, b_max: float, clip: bool = False, relative: bool = False + self, + lower: float, + upper: float, + b_min: Optional[float], + b_max: Optional[float], + clip: bool = False, + relative: bool = False, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, ) -> None: if lower < 0.0 or lower > 100.0: raise ValueError("Percentiles must be in the range [0, 100]") @@ -830,25 +950,36 @@ def __init__( self.b_max = b_max self.clip = clip self.relative = relative + self.channel_wise = channel_wise + self.dtype = dtype - def __call__(self, img: np.ndarray): - """ - Apply the transform to `img`. - """ - a_min = np.percentile(img, self.lower) - a_max = np.percentile(img, self.upper) + def _normalize(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + a_min: float = percentile(img, self.lower) # type: ignore + a_max: float = percentile(img, self.upper) # type: ignore b_min = self.b_min b_max = self.b_max if self.relative: + if (self.b_min is None) or (self.b_max is None): + raise ValueError("If it is relative, b_min and b_max should not be None.") b_min = ((self.b_max - self.b_min) * (self.lower / 100.0)) + self.b_min b_max = ((self.b_max - self.b_min) * (self.upper / 100.0)) + self.b_min - scalar = ScaleIntensityRange(a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=False) + scalar = ScaleIntensityRange( + a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=self.clip, dtype=self.dtype + ) img = scalar(img) + return img - if self.clip: - img = np.asarray(np.clip(img, self.b_min, self.b_max)) + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply the transform to `img`. + """ + if self.channel_wise: + out = [self._normalize(img=d) for d in img] + img = torch.stack(out) if isinstance(img, torch.Tensor) else np.stack(out) # type: ignore + else: + img = self._normalize(img=img) return img @@ -871,11 +1002,13 @@ class MaskIntensity(Transform): """ - def __init__(self, mask_data: Optional[np.ndarray] = None, select_fn: Callable = is_positive) -> None: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, mask_data: Optional[NdarrayOrTensor] = None, select_fn: Callable = is_positive) -> None: self.mask_data = mask_data self.select_fn = select_fn - def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor, mask_data: Optional[NdarrayOrTensor] = None) -> NdarrayOrTensor: """ Args: mask_data: if mask data is single channel, apply to every channel @@ -892,14 +1025,16 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n if mask_data is None: raise ValueError("must provide the mask_data when initializing the transform or at runtime.") - mask_data = np.asarray(self.select_fn(mask_data)) - if mask_data.shape[0] != 1 and mask_data.shape[0] != img.shape[0]: + mask_data_, *_ = convert_to_dst_type(src=mask_data, dst=img) + + mask_data_ = self.select_fn(mask_data_) + if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]: raise ValueError( "When mask_data is not single channel, mask_data channels must match img, " - f"got img channels={img.shape[0]} mask_data channels={mask_data.shape[0]}." + f"got img channels={img.shape[0]} mask_data channels={mask_data_.shape[0]}." ) - return np.asarray(img * mask_data) + return img * mask_data_ class SavitzkyGolaySmooth(Transform): @@ -914,7 +1049,7 @@ class SavitzkyGolaySmooth(Transform): or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information. """ - backend = [TransformBackends.NUMPY] + backend = [TransformBackends.TORCH] def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "zeros"): @@ -927,7 +1062,7 @@ def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "z self.mode = mode self.img_t: torch.Tensor = torch.tensor(0.0) - def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: array containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. @@ -941,7 +1076,9 @@ def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: # add one to transform axis because a batch axis will be added at dimension 0 savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) # convert to Tensor and add Batch axis expected by HilbertTransform - out: torch.Tensor = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0) + smoothed = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0) + out, *_ = convert_to_dst_type(smoothed, dst=img) + return out @@ -952,14 +1089,16 @@ class DetectEnvelope(Transform): Args: axis: Axis along which to detect the envelope. Default 1, i.e. the first spatial dimension. - N: FFT size. Default img.shape[axis]. Input will be zero-padded or truncated to this size along dimension + n: FFT size. Default img.shape[axis]. Input will be zero-padded or truncated to this size along dimension ``axis``. """ + backend = [TransformBackends.TORCH] + def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) if axis < 0: @@ -968,7 +1107,7 @@ def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: self.axis = axis self.n = n - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor): """ Args: @@ -978,11 +1117,14 @@ def __call__(self, img: np.ndarray): np.ndarray containing envelope of data in img along the specified axis. """ + img_t, *_ = convert_data_type(img, torch.Tensor) # add one to transform axis because a batch axis will be added at dimension 0 hilbert_transform = HilbertTransform(self.axis + 1, self.n) # convert to Tensor and add Batch axis expected by HilbertTransform - input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) - return np.abs(hilbert_transform(input_data).squeeze(0).numpy()) + out = hilbert_transform(img_t.unsqueeze(0)).squeeze(0).abs() + out, *_ = convert_to_dst_type(src=out, dst=img) + + return out class GaussianSmooth(Transform): @@ -999,14 +1141,24 @@ class GaussianSmooth(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "erf") -> None: self.sigma = sigma self.approx = approx - def __call__(self, img: np.ndarray): - gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx) - input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) - return gaussian_filter(input_data).squeeze(0).detach().numpy() + def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) + sigma: Union[Sequence[torch.Tensor], torch.Tensor] + if isinstance(self.sigma, Sequence): + sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma] + else: + sigma = torch.as_tensor(self.sigma, device=img_t.device) + gaussian_filter = GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx) + out_t: torch.Tensor = gaussian_filter(img_t.unsqueeze(0)).squeeze(0) + out, *_ = convert_data_type(out_t, type(img), device=img.device if isinstance(img, torch.Tensor) else None) + + return out class RandGaussianSmooth(RandomizableTransform): @@ -1023,6 +1175,8 @@ class RandGaussianSmooth(RandomizableTransform): """ + backend = GaussianSmooth.backend + def __init__( self, sigma_x: Tuple[float, float] = (0.25, 1.5), @@ -1043,14 +1197,19 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1]) self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1]) self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) - def __call__(self, img: np.ndarray): - self.randomize() + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + if randomize: + self.randomize() + if not self._do_transform: return img + sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=img.ndim - 1) return GaussianSmooth(sigma=sigma, approx=self.approx)(img) @@ -1082,6 +1241,8 @@ class GaussianSharpen(Transform): """ + backend = [TransformBackends.TORCH] + def __init__( self, sigma1: Union[Sequence[float], float] = 3.0, @@ -1094,13 +1255,18 @@ def __init__( self.alpha = alpha self.approx = approx - def __call__(self, img: np.ndarray): - gaussian_filter1 = GaussianFilter(img.ndim - 1, self.sigma1, approx=self.approx) - gaussian_filter2 = GaussianFilter(img.ndim - 1, self.sigma2, approx=self.approx) - input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) - blurred_f = gaussian_filter1(input_data) - filter_blurred_f = gaussian_filter2(blurred_f) - return (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0).detach().numpy() + def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) + + gf1, gf2 = ( + GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device) + for sigma in (self.sigma1, self.sigma2) + ) + blurred_f = gf1(img_t.unsqueeze(0)) + filter_blurred_f = gf2(blurred_f) + out_t: torch.Tensor = (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0) + out, *_ = convert_data_type(out_t, type(img), device=img.device if isinstance(img, torch.Tensor) else None) + return out class RandGaussianSharpen(RandomizableTransform): @@ -1125,6 +1291,8 @@ class RandGaussianSharpen(RandomizableTransform): """ + backend = GaussianSharpen.backend + def __init__( self, sigma1_x: Tuple[float, float] = (0.5, 1.0), @@ -1146,9 +1314,18 @@ def __init__( self.sigma2_z = sigma2_z self.alpha = alpha self.approx = approx + self.x1: Optional[float] = None + self.y1: Optional[float] = None + self.z1: Optional[float] = None + self.x2: Optional[float] = None + self.y2: Optional[float] = None + self.z2: Optional[float] = None + self.a: Optional[float] = None def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None self.x1 = self.R.uniform(low=self.sigma1_x[0], high=self.sigma1_x[1]) self.y1 = self.R.uniform(low=self.sigma1_y[0], high=self.sigma1_y[1]) self.z1 = self.R.uniform(low=self.sigma1_z[0], high=self.sigma1_z[1]) @@ -1160,10 +1337,15 @@ def randomize(self, data: Optional[Any] = None) -> None: self.z2 = self.R.uniform(low=sigma2_z[0], high=sigma2_z[1]) self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1]) - def __call__(self, img: np.ndarray): - self.randomize() + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + if randomize: + self.randomize() + if not self._do_transform: return img + + if self.x2 is None or self.y2 is None or self.z2 is None or self.a is None: + raise RuntimeError("please call the `randomize()` function first.") sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=img.ndim - 1) sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=img.ndim - 1) return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img) @@ -1180,6 +1362,8 @@ class RandHistogramShift(RandomizableTransform): prob: probability of histogram shift. """ + backend = [TransformBackends.NUMPY] + def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) @@ -1193,9 +1377,13 @@ def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: f if min(num_control_points) <= 2: raise ValueError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) + self.reference_control_points: np.ndarray + self.floating_control_points: np.ndarray def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None num_control_point = self.R.randint(self.num_control_points[0], self.num_control_points[1] + 1) self.reference_control_points = np.linspace(0, 1, num_control_point) self.floating_control_points = np.copy(self.reference_control_points) @@ -1204,79 +1392,25 @@ def randomize(self, data: Optional[Any] = None) -> None: self.floating_control_points[i - 1], self.floating_control_points[i + 1] ) - def __call__(self, img: np.ndarray) -> np.ndarray: - self.randomize() + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + if randomize: + self.randomize() + if not self._do_transform: return img - img_min, img_max = img.min(), img.max() + + if self.reference_control_points is None or self.floating_control_points is None: + raise RuntimeError("please call the `randomize()` function first.") + img_np, *_ = convert_data_type(img, np.ndarray) + img_min, img_max = img_np.min(), img_np.max() reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min - return np.asarray( - np.interp(img, reference_control_points_scaled, floating_control_points_scaled), dtype=img.dtype + img_np = np.asarray( # type: ignore + np.interp(img_np, reference_control_points_scaled, floating_control_points_scaled), dtype=img_np.dtype ) - - -class RandGibbsNoise(RandomizableTransform): - """ - Naturalistic image augmentation via Gibbs artifacts. The transform - randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts - are one of the common type of type artifacts appearing in MRI scans. - - The transform is applied to all the channels in the data. - - For general information on Gibbs artifacts, please refer to: - https://pubs.rsna.org/doi/full/10.1148/rg.313105115 - https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949 - - - Args: - prob (float): probability of applying the transform. - alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes - values in the interval [0,1] with alpha = 0 acting as the identity mapping. - If a length-2 list is given as [a,b] then the value of alpha will be - sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. - """ - - def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None: - - if len(alpha) != 2: - raise ValueError("alpha length must be 2.") - if alpha[1] > 1 or alpha[0] < 0: - raise ValueError("alpha must take values in the interval [0,1]") - if alpha[0] > alpha[1]: - raise ValueError("When alpha = [a,b] we need a < b.") - - self.alpha = alpha - self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() - self.as_tensor_output = as_tensor_output - - RandomizableTransform.__init__(self, prob=prob) - - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: - - # randomize application and possibly alpha - self._randomize(None) - - if self._do_transform: - # apply transform - transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output) - img = transform(img) - else: - if isinstance(img, np.ndarray) and self.as_tensor_output: - img = torch.Tensor(img) - elif isinstance(img, torch.Tensor) and not self.as_tensor_output: - img = img.detach().cpu().numpy() + img, *_ = convert_to_dst_type(img_np, dst=img) return img - def _randomize(self, _: Any) -> None: - """ - (1) Set random variable to apply the transform. - (2) Get alpha from uniform distribution. - """ - super().randomize(None) - self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) - class GibbsNoise(Transform, Fourier): """ @@ -1296,21 +1430,20 @@ class GibbsNoise(Transform, Fourier): Args: alpha: Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. Default: True. """ - def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + @deprecated_arg(name="as_tensor_output", since="0.6") + def __init__(self, alpha: float = 0.1, as_tensor_output: bool = True) -> None: if alpha > 1 or alpha < 0: - raise ValueError("alpha must take values in the interval [0,1].") + raise ValueError("alpha must take values in the interval [0, 1].") self.alpha = alpha - self.as_tensor_output = as_tensor_output - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: n_dims = len(img.shape[1:]) - if isinstance(img, np.ndarray): - img = torch.Tensor(img) # FT k = self.shift_fourier(img, n_dims) # build and apply mask @@ -1318,13 +1451,13 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, # map back img = self.inv_shift_fourier(k, n_dims) - return img if self.as_tensor_output else img.cpu().detach().numpy() + return img - def _apply_mask(self, k: torch.Tensor) -> torch.Tensor: + def _apply_mask(self, k: NdarrayOrTensor) -> NdarrayOrTensor: """Builds and applies a mask on the spatial dimensions. Args: - k (np.ndarray): k-space version of the image. + k: k-space version of the image. Returns: masked version of the k-space image. """ @@ -1345,11 +1478,73 @@ def _apply_mask(self, k: torch.Tensor) -> torch.Tensor: # add channel dimension into mask mask = np.repeat(mask[None], k.shape[0], axis=0) + if isinstance(k, torch.Tensor): + mask, *_ = convert_data_type(mask, torch.Tensor, device=k.device) + # apply binary mask - k_masked = k * torch.tensor(mask, device=k.device) + k_masked: NdarrayOrTensor + k_masked = k * mask return k_masked +class RandGibbsNoise(RandomizableTransform): + """ + Naturalistic image augmentation via Gibbs artifacts. The transform + randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts + are one of the common type of type artifacts appearing in MRI scans. + + The transform is applied to all the channels in the data. + + For general information on Gibbs artifacts, please refer to: + https://pubs.rsna.org/doi/full/10.1148/rg.313105115 + https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949 + + + Args: + prob (float): probability of applying the transform. + alpha (Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes + values in the interval [0,1] with alpha = 0 acting as the identity mapping. + If a length-2 list is given as [a,b] then the value of alpha will be + sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1. + """ + + backend = GibbsNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") + def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None: + if len(alpha) != 2: + raise ValueError("alpha length must be 2.") + if alpha[1] > 1 or alpha[0] < 0: + raise ValueError("alpha must take values in the interval [0, 1]") + if alpha[0] > alpha[1]: + raise ValueError("When alpha = [a,b] we need a < b.") + + self.alpha = alpha + self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() + + RandomizableTransform.__init__(self, prob=prob) + + def randomize(self, data: Any) -> None: + """ + (1) Set random variable to apply the transform. + (2) Get alpha from uniform distribution. + """ + super().randomize(None) + if not self._do_transform: + return None + self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True): + if randomize: + # randomize application and possibly alpha + self.randomize(None) + + if not self._do_transform: + return img + + return GibbsNoise(self.sampled_alpha)(img) + + class KSpaceSpikeNoise(Transform, Fourier): """ Apply localized spikes in `k`-space at the given locations and intensities. @@ -1377,8 +1572,6 @@ class KSpaceSpikeNoise(Transform, Fourier): receive a sequence of intensities. This value should be tested as it is data-dependent. The default values are the 2.5 the mean of the log-intensity for each channel. - as_tensor_output: if ``True`` return torch.Tensor, else return np.array. - Default: ``True``. Example: When working with 4D data, ``KSpaceSpikeNoise(loc = ((3,60,64,32), (64,60,32)), k_intensity = (13,14))`` @@ -1387,6 +1580,9 @@ class KSpaceSpikeNoise(Transform, Fourier): with `log-intensity = 14`. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, loc: Union[Tuple, Sequence[Tuple]], @@ -1395,7 +1591,6 @@ def __init__( ): self.loc = ensure_tuple(loc) - self.as_tensor_output = as_tensor_output self.k_intensity = k_intensity # assert one-to-one relationship between factors and locations @@ -1409,7 +1604,7 @@ def __init__( if isinstance(self.loc[0], Sequence) and k_intensity is not None and not isinstance(self.k_intensity, Sequence): raise ValueError("There must be one intensity_factor value for each tuple of indices in loc.") - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: image with dimensions (C, H, W) or (C, H, W, D) @@ -1421,22 +1616,21 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, raise RuntimeError("Image needs a channel direction.") if isinstance(self.loc[0], int) and len(img.shape) == 4 and len(self.loc) == 2: raise RuntimeError("Input images of dimension 4 need location tuple to be length 3 or 4") - if isinstance(self.loc[0], Sequence) and len(img.shape) == 4 and min(map(lambda x: len(x), self.loc)) == 2: + if isinstance(self.loc[0], Sequence) and len(img.shape) == 4 and min(map(len, self.loc)) == 2: raise RuntimeError("Input images of dimension 4 need location tuple to be length 3 or 4") n_dims = len(img.shape[1:]) - if isinstance(img, np.ndarray): - img = torch.Tensor(img) # FT k = self.shift_fourier(img, n_dims) - log_abs = torch.log(torch.absolute(k) + 1e-10) - phase = torch.angle(k) + lib = np if isinstance(k, np.ndarray) else torch + log_abs = lib.log(lib.abs(k) + 1e-10) + phase = lib.angle(k) k_intensity = self.k_intensity # default log intensity if k_intensity is None: - k_intensity = tuple(torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5) + k_intensity = tuple(lib.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5) # highlight if isinstance(self.loc[0], Sequence): @@ -1445,10 +1639,10 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, else: self._set_spike(log_abs, self.loc, k_intensity) # map back - k = torch.exp(log_abs) * torch.exp(1j * phase) - img = self.inv_shift_fourier(k, n_dims) + k = lib.exp(log_abs) * lib.exp(1j * phase) + img, *_ = convert_to_dst_type(self.inv_shift_fourier(k, n_dims), dst=img) - return img if self.as_tensor_output else img.cpu().detach().numpy() + return img def _check_indices(self, img) -> None: """Helper method to check consistency of self.loc and input image. @@ -1468,7 +1662,7 @@ def _check_indices(self, img) -> None: f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image." ) - def _set_spike(self, k: torch.Tensor, idx: Tuple, val: Union[Sequence[float], float]): + def _set_spike(self, k: NdarrayOrTensor, idx: Tuple, val: Union[Sequence[float], float]): """ Helper function to introduce a given intensity at given location. @@ -1504,18 +1698,14 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): Args: prob: probability of applying the transform, either on all channels at once, or channel-wise if ``channel_wise = True``. - intensity_range: pass a tuple - (a, b) to sample the log-intensity from the interval (a, b) - uniformly for all channels. Or pass sequence of intevals + intensity_range: pass a tuple (a, b) to sample the log-intensity from the interval (a, b) + uniformly for all channels. Or pass sequence of intervals ((a0, b0), (a1, b1), ...) to sample for each respective channel. - In the second case, the number of 2-tuples must match the number of - channels. + In the second case, the number of 2-tuples must match the number of channels. Default ranges is `(0.95x, 1.10x)` where `x` is the mean log-intensity for each channel. channel_wise: treat each channel independently. True by default. - as_tensor_output: if True return torch.Tensor, else - return np.array. default: True. Example: To apply `k`-space spikes randomly with probability 0.5, and @@ -1524,17 +1714,19 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): ``RandKSpaceSpikeNoise(prob=0.5, intensity_range=(11, 12), channel_wise=True)`` """ + backend = KSpaceSpikeNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, prob: float = 0.1, intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, - channel_wise=True, + channel_wise: bool = True, as_tensor_output: bool = True, ): self.intensity_range = intensity_range self.channel_wise = channel_wise - self.as_tensor_output = as_tensor_output self.sampled_k_intensity: List = [] self.sampled_locs: List[Tuple] = [] @@ -1543,13 +1735,14 @@ def __init__( super().__init__(prob) - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True): """ Apply transform to `img`. Assumes data is in channel-first form. Args: img: image with dimensions (C, H, W) or (C, H, W, D) """ + if ( self.intensity_range is not None and isinstance(self.intensity_range[0], Sequence) @@ -1562,20 +1755,16 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, self.sampled_k_intensity = [] self.sampled_locs = [] - if not isinstance(img, torch.Tensor): - img = torch.Tensor(img) - - intensity_range = self._make_sequence(img) - self._randomize(img, intensity_range) + if randomize: + intensity_range = self._make_sequence(img) + self.randomize(img, intensity_range) - # build/appy transform only if there are spike locations - if self.sampled_locs: - transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output) - return transform(img) + if not self._do_transform: + return img - return img if self.as_tensor_output else img.detach().numpy() + return KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity)(img) - def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float]]) -> None: + def randomize(self, img: NdarrayOrTensor, intensity_range: Sequence[Sequence[float]]) -> None: # type: ignore """ Helper method to sample both the location and intensity of the spikes. When not working channel wise (channel_wise=False) it use the random @@ -1585,25 +1774,24 @@ def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float When working channel wise, the method randomly samples a location and intensity for each channel depending on ``self._do_transform``. """ - # randomizing per channel + super().randomize(None) + if not self._do_transform: + return None if self.channel_wise: + # randomizing per channel for i, chan in enumerate(img): - super().randomize(None) - if self._do_transform: - self.sampled_locs.append((i,) + tuple(self.R.randint(0, k) for k in chan.shape)) - self.sampled_k_intensity.append(self.R.uniform(intensity_range[i][0], intensity_range[i][1])) - # working with all channels together + self.sampled_locs.append((i,) + tuple(self.R.randint(0, k) for k in chan.shape)) + self.sampled_k_intensity.append(self.R.uniform(intensity_range[i][0], intensity_range[i][1])) else: - super().randomize(None) - if self._do_transform: - spatial = tuple(self.R.randint(0, k) for k in img.shape[1:]) - self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])] - if isinstance(intensity_range[0], Sequence): - self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range] - else: - self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) + # working with all channels together + spatial = tuple(self.R.randint(0, k) for k in img.shape[1:]) + self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])] + if isinstance(intensity_range[0], Sequence): + self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range] + else: + self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) - def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: + def _make_sequence(self, x: NdarrayOrTensor) -> Sequence[Sequence[float]]: """ Formats the sequence of intensities ranges to Sequence[Sequence[float]]. """ @@ -1615,7 +1803,7 @@ def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: return (ensure_tuple(self.intensity_range),) * x.shape[0] return ensure_tuple(self.intensity_range) - def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]: + def _set_default_range(self, img: NdarrayOrTensor) -> Sequence[Sequence[float]]: """ Sets default intensity ranges to be sampled. @@ -1625,18 +1813,19 @@ def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]: n_dims = len(img.shape[1:]) k = self.shift_fourier(img, n_dims) - log_abs = torch.log(torch.absolute(k) + 1e-10) - shifted_means = torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5 + mod = torch if isinstance(k, torch.Tensor) else np + log_abs = mod.log(mod.absolute(k) + 1e-10) + shifted_means = mod.mean(log_abs, tuple(range(-n_dims, 0))) * 2.5 + if isinstance(shifted_means, torch.Tensor): + shifted_means = shifted_means.to("cpu") return tuple((i * 0.95, i * 1.1) for i in shifted_means) -class RandCoarseDropout(RandomizableTransform): +class RandCoarseTransform(RandomizableTransform): """ - Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value. - Or keep the rectangular regions and fill in the other areas with specified value. - Refer to papers: https://arxiv.org/abs/1708.04552, https://arxiv.org/pdf/1604.07379 - And other implementation: https://albumentations.ai/docs/api_reference/augmentations/transforms/ - #albumentations.augmentations.transforms.CoarseDropout. + Randomly select coarse regions in the image, then execute transform operations for the regions. + It's the base class of all kinds of region transforms. + Refer to papers: https://arxiv.org/abs/1708.04552 Args: holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to @@ -1646,12 +1835,6 @@ class RandCoarseDropout(RandomizableTransform): if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. - dropout_holes: if `True`, dropout the regions of holes and fill value, if `False`, keep the holes and - dropout the outside and fill value. default to `True`. - fill_value: target value to fill the dropout regions, if providing a number, will use it as constant - value to fill all the regions. if providing a tuple for the `min` and `max`, will randomly select - value for every pixel / voxel from the range `[min, max)`. if None, will compute the `min` and `max` - value of input image then randomly select value to fill, default to None. max_holes: if not None, define the maximum number to randomly select the expected number of regions. max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. if some components of the `max_spatial_size` are non-positive values, the transform will use the @@ -1661,12 +1844,12 @@ class RandCoarseDropout(RandomizableTransform): """ + backend = [TransformBackends.NUMPY] + def __init__( self, holes: int, spatial_size: Union[Sequence[int], int], - dropout_holes: bool = True, - fill_value: Optional[Union[Tuple[float, float], float]] = None, max_holes: Optional[int] = None, max_spatial_size: Optional[Union[Sequence[int], int]] = None, prob: float = 0.1, @@ -1676,17 +1859,14 @@ def __init__( raise ValueError("number of holes must be greater than 0.") self.holes = holes self.spatial_size = spatial_size - self.dropout_holes = dropout_holes - if isinstance(fill_value, (tuple, list)): - if len(fill_value) != 2: - raise ValueError("fill value should contain 2 numbers if providing the `min` and `max`.") - self.fill_value = fill_value self.max_holes = max_holes self.max_spatial_size = max_spatial_size self.hole_coords: List = [] def randomize(self, img_size: Sequence[int]) -> None: super().randomize(None) + if not self._do_transform: + return None size = fall_back_tuple(self.spatial_size, img_size) self.hole_coords = [] # clear previously computed coords num_holes = self.holes if self.max_holes is None else self.R.randint(self.holes, self.max_holes + 1) @@ -1697,28 +1877,142 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, size) self.hole_coords.append((slice(None),) + get_random_patch(img_size, valid_size, self.R)) - def __call__(self, img: np.ndarray): - self.randomize(img.shape[1:]) - ret = img - if self._do_transform: - fill_value = (img.min(), img.max()) if self.fill_value is None else self.fill_value - - if self.dropout_holes: - for h in self.hole_coords: - if isinstance(fill_value, (tuple, list)): - ret[h] = self.R.uniform(fill_value[0], fill_value[1], size=img[h].shape) - else: - ret[h] = fill_value - else: + @abstractmethod + def _transform_holes(self, img: np.ndarray) -> np.ndarray: + """ + Transform the randomly selected `self.hole_coords` in input images. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + if randomize: + self.randomize(img.shape[1:]) + + if not self._do_transform: + return img + + img_np, *_ = convert_data_type(img, np.ndarray) + out = self._transform_holes(img=img_np) + ret, *_ = convert_to_dst_type(src=out, dst=img) + return ret + + +class RandCoarseDropout(RandCoarseTransform): + """ + Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value. + Or keep the rectangular regions and fill in the other areas with specified value. + Refer to papers: https://arxiv.org/abs/1708.04552, https://arxiv.org/pdf/1604.07379 + And other implementation: https://albumentations.ai/docs/api_reference/augmentations/transforms/ + #albumentations.augmentations.transforms.CoarseDropout. + + Args: + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + dropout_holes: if `True`, dropout the regions of holes and fill value, if `False`, keep the holes and + dropout the outside and fill value. default to `True`. + fill_value: target value to fill the dropout regions, if providing a number, will use it as constant + value to fill all the regions. if providing a tuple for the `min` and `max`, will randomly select + value for every pixel / voxel from the range `[min, max)`. if None, will compute the `min` and `max` + value of input image then randomly select value to fill, default to None. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + + """ + + def __init__( + self, + holes: int, + spatial_size: Union[Sequence[int], int], + dropout_holes: bool = True, + fill_value: Optional[Union[Tuple[float, float], float]] = None, + max_holes: Optional[int] = None, + max_spatial_size: Optional[Union[Sequence[int], int]] = None, + prob: float = 0.1, + ) -> None: + super().__init__( + holes=holes, spatial_size=spatial_size, max_holes=max_holes, max_spatial_size=max_spatial_size, prob=prob + ) + self.dropout_holes = dropout_holes + if isinstance(fill_value, (tuple, list)): + if len(fill_value) != 2: + raise ValueError("fill value should contain 2 numbers if providing the `min` and `max`.") + self.fill_value = fill_value + + def _transform_holes(self, img: np.ndarray): + """ + Fill the randomly selected `self.hole_coords` in input images. + Please note that we usually only use `self.R` in `randomize()` method, here is a special case. + + """ + fill_value = (img.min(), img.max()) if self.fill_value is None else self.fill_value + + if self.dropout_holes: + for h in self.hole_coords: if isinstance(fill_value, (tuple, list)): - ret = self.R.uniform(fill_value[0], fill_value[1], size=img.shape).astype(img.dtype) + img[h] = self.R.uniform(fill_value[0], fill_value[1], size=img[h].shape) else: - ret = np.full_like(img, fill_value) - for h in self.hole_coords: - ret[h] = img[h] + img[h] = fill_value + ret = img + else: + if isinstance(fill_value, (tuple, list)): + ret = self.R.uniform(fill_value[0], fill_value[1], size=img.shape).astype(img.dtype, copy=False) + else: + ret = np.full_like(img, fill_value) + for h in self.hole_coords: + ret[h] = img[h] return ret +class RandCoarseShuffle(RandCoarseTransform): + """ + Randomly select regions in the image, then shuffle the pixels within every region. + It shuffles every channel separately. + Refer to paper: + Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017). + https://arxiv.org/abs/1707.07103 + + Args: + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + + """ + + def _transform_holes(self, img: np.ndarray): + """ + Shuffle the content of randomly selected `self.hole_coords` in input images. + Please note that we usually only use `self.R` in `randomize()` method, here is a special case. + + """ + for h in self.hole_coords: + # shuffle every channel separately + for i, c in enumerate(img[h]): + patch_channel = c.flatten() + self.R.shuffle(patch_channel) + img[h][i] = patch_channel.reshape(c.shape) + return img + + class HistogramNormalize(Transform): """ Apply the histogram normalization to input image. @@ -1732,16 +2026,18 @@ class HistogramNormalize(Transform): mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`. only points at which `mask==True` are used for the equalization. can also provide the mask along with img at runtime. - dtype: data type of the output, default to `float32`. + dtype: data type of the output, if None, same as input image. default to `float32`. """ + backend = [TransformBackends.NUMPY] + def __init__( self, num_bins: int = 256, min: int = 0, max: int = 255, - mask: Optional[np.ndarray] = None, + mask: Optional[NdarrayOrTensor] = None, dtype: DtypeLike = np.float32, ) -> None: self.num_bins = num_bins @@ -1750,104 +2046,124 @@ def __init__( self.mask = mask self.dtype = dtype - def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray: - return equalize_hist( - img=img, - mask=mask if mask is not None else self.mask, - num_bins=self.num_bins, - min=self.min, - max=self.max, - dtype=self.dtype, - ) + def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None) -> NdarrayOrTensor: + img_np, *_ = convert_data_type(img, np.ndarray) + mask = mask if mask is not None else self.mask + mask_np: Optional[np.ndarray] = None + if mask is not None: + mask_np, *_ = convert_data_type(mask, np.ndarray) + + ret = equalize_hist(img=img_np, mask=mask_np, num_bins=self.num_bins, min=self.min, max=self.max) + out, *_ = convert_to_dst_type(src=ret, dst=img, dtype=self.dtype or img.dtype) + + return out -class LocalPatchShuffling(RandomizableTransform): +class IntensityRemap(RandomizableTransform): """ - Takes a 3D image and based on input of the local patch size, shuffles the pixels of the local patch within it. - This process is repeated a for N number of times where every time a different random block is selected for local - pixel shuffling. + Transform for intensity remapping of images. The intensity at each + pixel is replaced by a new values coming from an intensity remappping + curve. - Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017). + The remapping curve is created by uniformly sampling values from the + possible intensities for the input image and then adding a linear + component. The curve is the rescaled to the input image intensity range. + + Intended to be used as a means to data augmentation via: + :py:class:`monai.transforms.RandIntensityRemap`. + + Implementation is described in the work: + `Intensity augmentation for domain transfer of whole breast segmentation + in MRI `_. + + Args: + kernel_size: window size for averaging operation for the remapping + curve. + slope: slope of the linear component. Easiest to leave default value + and tune the kernel_size parameter instead. + return_map: set to True for the transform to return a dictionary version + of the lookup table used in the intensity remapping. The keys + correspond to the old intensities, and the values are the new + values. """ - def __init__( - self, - prob: float = 1.0, - number_blocks: int = 1000, - blocksize_ratio: int = 10, - channel_wise: bool = True, - device: Optional[torch.device] = None, - image_only: bool = False, - ) -> None: + def __init__(self, kernel_size: int = 30, slope: float = 0.7): + + super().__init__() + + self.kernel_size = kernel_size + self.slope = slope + + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Args: - prob: The chance of this transform occuring on the given volume. - number_blocks: Total number of time a random 3D block will be selected for local shuffling of pixels/voxels - contained in the block. - blocksize_ratio: This ratio can be used to estimate the local 3D block sizes that will be selected. - channel_wise: If True, treats each channel of the image separately. - device: device on which the tensor will be allocated. - image_only: if True return only the image volume, otherwise return (image, affine). + img: image to remap. """ - RandomizableTransform.__init__(self, prob) - self.prob = prob - self.number_blocks = number_blocks - self.blocksize_ratio = blocksize_ratio - self.channel_wise = channel_wise - def _local_patch_shuffle(self, img: Union[torch.Tensor, np.ndarray], number_blocks: int, blocksize_ratio: int): - im_shape = img.shape - img_copy = copy.deepcopy(img) - for _each_block in range(number_blocks): + img = img.clone() + # sample noise + vals_to_sample = torch.unique(img).tolist() + noise = torch.from_numpy(self.R.choice(vals_to_sample, len(vals_to_sample) - 1 + self.kernel_size)) + # smooth + noise = torch.nn.AvgPool1d(self.kernel_size, stride=1)(noise.unsqueeze(0)).squeeze() + # add linear component + grid = torch.arange(len(noise)) / len(noise) + noise += self.slope * grid + # rescale + noise = (noise - noise.min()) / (noise.max() - noise.min()) * img.max() + img.min() - block_size_x = self.R.randint(1, im_shape[0] // blocksize_ratio) - block_size_y = self.R.randint(1, im_shape[1] // blocksize_ratio) - block_size_z = self.R.randint(1, im_shape[2] // blocksize_ratio) + # intensity remapping function + index_img = torch.bucketize(img, torch.tensor(vals_to_sample)) + img = noise[index_img] - noise_x = self.R.randint(0, im_shape[0] - block_size_x) - noise_y = self.R.randint(0, im_shape[1] - block_size_y) - noise_z = self.R.randint(0, im_shape[2] - block_size_z) + return img - local_patch = img[ - noise_x : noise_x + block_size_x, - noise_y : noise_y + block_size_y, - noise_z : noise_z + block_size_z, - ] - local_patch = local_patch.flatten() - self.R.shuffle(local_patch) - local_patch = local_patch.reshape((block_size_x, block_size_y, block_size_z)) +class RandIntensityRemap(RandomizableTransform): + """ + Transform for intensity remapping of images. The intensity at each + pixel is replaced by a new values coming from an intensity remappping + curve. - img_copy[ - noise_x : noise_x + block_size_x, noise_y : noise_y + block_size_y, noise_z : noise_z + block_size_z - ] = local_patch + The remapping curve is created by uniformly sampling values from the + possible intensities for the input image and then adding a linear + component. The curve is the rescaled to the input image intensity range. - shuffled_image = img_copy - return shuffled_image + Implementation is described in the work: + `Intensity augmentation for domain transfer of whole breast segmentation + in MRI `_. - def __call__( - self, - img: Union[np.ndarray, torch.Tensor], - # spatial_size: Optional[Union[Sequence[int], int]] = None, - # mode: Optional[Union[GridSampleMode, str]] = None, - # padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ): + Args: + prob: probability of applying the transform. + kernel_size: window size for averaging operation for the remapping + curve. + slope: slope of the linear component. Easiest to leave default value + and tune the kernel_size parameter instead. + channel_wise: set to True to treat each channel independently. + """ + + def __init__(self, prob: float = 0.1, kernel_size: int = 30, slope: float = 0.7, channel_wise: bool = True): + + RandomizableTransform.__init__(self, prob=prob) + self.kernel_size = kernel_size + self.slope = slope + self.channel_wise = True + + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Args: - img: shape must be (num_channels, H, W[, D]), - + img: image to remap. """ - super().randomize(None) - if not self._do_transform: - return img - - if self.channel_wise: - # img = self._local_patch_shuffle(img=img) - for i, _d in enumerate(img): - img[i] = self._local_patch_shuffle( - img=img[i], blocksize_ratio=self.blocksize_ratio, number_blocks=self.number_blocks + if self._do_transform: + if self.channel_wise: + img = torch.stack( + [ + IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img[i]) + for i in range(len(img)) + ] ) - else: - raise AssertionError("If channel_wise is False, the image needs to be set to channel first") + else: + img = IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img) + return img diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index bc53fb6b7b..b0f5149456 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,13 +15,11 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from collections.abc import Iterable -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np -import torch -from monai.config import DtypeLike, KeysCollection, NdarrayTensor +from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.intensity.array import ( AdjustContrast, @@ -32,11 +30,21 @@ KSpaceSpikeNoise, MaskIntensity, NormalizeIntensity, + RandAdjustContrast, RandBiasField, RandCoarseDropout, + RandCoarseShuffle, RandGaussianNoise, + RandGaussianSharpen, + RandGaussianSmooth, + RandGibbsNoise, + RandHistogramShift, RandKSpaceSpikeNoise, RandRicianNoise, + RandScaleIntensity, + RandShiftIntensity, + RandStdShiftIntensity, + SavitzkyGolaySmooth, ScaleIntensity, ScaleIntensityRange, ScaleIntensityRangePercentiles, @@ -44,9 +52,11 @@ StdShiftIntensity, ThresholdIntensity, ) -from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform +from monai.transforms.transform import MapTransform, RandomizableTransform from monai.transforms.utils import is_positive -from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size +from monai.utils import ensure_tuple, ensure_tuple_rep +from monai.utils.deprecate_utils import deprecated_arg +from monai.utils.enums import PostFix __all__ = [ "RandGaussianNoised", @@ -65,6 +75,7 @@ "RandAdjustContrastd", "ScaleIntensityRangePercentilesd", "MaskIntensityd", + "SavitzkyGolaySmoothd", "GaussianSmoothd", "RandGaussianSmoothd", "GaussianSharpend", @@ -75,6 +86,7 @@ "RandKSpaceSpikeNoised", "RandHistogramShiftd", "RandCoarseDropoutd", + "RandCoarseShuffled", "HistogramNormalized", "RandGaussianNoiseD", "RandGaussianNoiseDict", @@ -106,6 +118,8 @@ "ScaleIntensityRangePercentilesDict", "MaskIntensityD", "MaskIntensityDict", + "SavitzkyGolaySmoothD", + "SavitzkyGolaySmoothDict", "GaussianSmoothD", "GaussianSmoothDict", "RandGaussianSmoothD", @@ -126,10 +140,16 @@ "RandRicianNoiseDict", "RandCoarseDropoutD", "RandCoarseDropoutDict", + "RandCoarseShuffleD", + "RandCoarseShuffleDict", "HistogramNormalizeD", "HistogramNormalizeDict", + "RandKSpaceSpikeNoiseD", + "RandKSpaceSpikeNoiseDict", ] +DEFAULT_POST_FIX = PostFix.meta() + class RandGaussianNoised(RandomizableTransform, MapTransform): """ @@ -143,6 +163,7 @@ class RandGaussianNoised(RandomizableTransform, MapTransform): prob: Probability to add Gaussian noise. mean: Mean or “centre” of the distribution. std: Standard deviation (spread) of distribution. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -152,34 +173,36 @@ def __init__( self, keys: KeysCollection, prob: float = 0.1, - mean: Union[Sequence[float], float] = 0.0, + mean: float = 0.0, std: float = 0.1, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.mean = ensure_tuple_rep(mean, len(self.keys)) - self.std = std - self._noise: List[np.ndarray] = [] + self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype) - def randomize(self, im_shape: Sequence[int]) -> None: - super().randomize(None) - self._noise.clear() - for m in self.mean: - self._noise.append(self.R.normal(m, self.R.uniform(0, self.std), size=im_shape)) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGaussianNoised": + super().set_random_state(seed, state) + self.rand_gaussian_noise.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - - image_shape = d[self.keys[0]].shape # image shape from the first data key - self.randomize(image_shape) - if len(self._noise) != len(self.keys): - raise RuntimeError("inconsistent noise items and keys.") + self.randomize(None) if not self._do_transform: return d - for key, noise in self.key_iterator(d, self._noise): - noise, *_ = convert_to_dst_type(noise, d[key]) - d[key] = d[key] + noise + + # all the keys share the same random noise + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.rand_gaussian_noise.randomize(d[first_key]) # type: ignore + for key in self.key_iterator(d): + d[key] = self.rand_gaussian_noise(img=d[key], randomize=False) return d @@ -192,9 +215,7 @@ class RandRicianNoised(RandomizableTransform, MapTransform): Args: keys: Keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - global_prob: Probability to add Rician noise to the dictionary. - prob: Probability to add Rician noise to each item in the dictionary, - once asserted that noise will be added to the dictionary at all. + prob: Probability to add Rician noise to the dictionary. mean: Mean or "centre" of the Gaussian distributions sampled to make up the Rician noise. std: Standard deviation (spread) of the Gaussian distributions sampled @@ -205,38 +226,52 @@ class RandRicianNoised(RandomizableTransform, MapTransform): histogram. sample_std: If True, sample the spread of the Gaussian distributions uniformly from 0 to std. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: Don't raise exception if key is missing. """ backend = RandRicianNoise.backend + @deprecated_arg("global_prob", since="0.7") def __init__( self, keys: KeysCollection, - global_prob: float = 0.1, - prob: float = 1.0, + prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: Union[Sequence[float], float] = 1.0, channel_wise: bool = False, relative: bool = False, sample_std: bool = True, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - RandomizableTransform.__init__(self, global_prob) - self.rand_rician_noise = RandRicianNoise(prob, mean, std, channel_wise, relative, sample_std) + RandomizableTransform.__init__(self, prob) + self.rand_rician_noise = RandRicianNoise( + prob=1.0, + mean=mean, + std=std, + channel_wise=channel_wise, + relative=relative, + sample_std=sample_std, + dtype=dtype, + ) - def set_random_state(self, seed=None, state=None): + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandRicianNoised": super().set_random_state(seed, state) self.rand_rician_noise.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - super().randomize(None) + self.randomize(None) if not self._do_transform: return d + for key in self.key_iterator(d): - d[key] = self.rand_rician_noise(d[key]) + d[key] = self.rand_rician_noise(d[key], randomize=True) return d @@ -253,7 +288,7 @@ def __init__( offset: float, factor_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, ) -> None: """ @@ -272,7 +307,7 @@ def __init__( the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to extract the factor value is `factor_key` is not None. allow_missing_keys: don't raise exception if key is missing. @@ -302,7 +337,7 @@ class RandShiftIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandShiftIntensity`. """ - backend = ShiftIntensity.backend + backend = RandShiftIntensity.backend def __init__( self, @@ -310,7 +345,7 @@ def __init__( offsets: Union[Tuple[float, float], float], factor_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, prob: float = 0.1, allow_missing_keys: bool = False, ) -> None: @@ -331,7 +366,7 @@ def __init__( the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to extract the factor value is `factor_key` is not None. prob: probability of rotating. @@ -341,36 +376,34 @@ def __init__( MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - if isinstance(offsets, (int, float)): - self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) - else: - if len(offsets) != 2: - raise ValueError("offsets should be a number or pair of numbers.") - self.offsets = (min(offsets), max(offsets)) - self._offset = self.offsets[0] self.factor_key = ensure_tuple_rep(factor_key, len(self.keys)) self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) if len(self.keys) != len(self.meta_keys): raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.shifter = ShiftIntensity(self._offset) + self.shifter = RandShiftIntensity(offsets=offsets, prob=1.0) - def randomize(self, data: Optional[Any] = None) -> None: - self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) - super().randomize(None) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandShiftIntensityd": + super().set_random_state(seed, state) + self.shifter.set_random_state(seed, state) + return self def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d + + # all the keys share the same random shift factor + self.shifter.randomize(None) for key, factor_key, meta_key, meta_key_postfix in self.key_iterator( d, self.factor_key, self.meta_keys, self.meta_key_postfix ): meta_key = meta_key or f"{key}_{meta_key_postfix}" factor: Optional[float] = d[meta_key].get(factor_key) if meta_key in d else None - offset = self._offset if factor is None else self._offset * factor - d[key] = self.shifter(d[key], offset=offset) + d[key] = self.shifter(d[key], factor=factor, randomize=False) return d @@ -398,7 +431,7 @@ def __init__( nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. Please ensure that the first dimension represents the channel of the image if True. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) @@ -416,7 +449,7 @@ class RandStdShiftIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandStdShiftIntensity`. """ - backend = StdShiftIntensity.backend + backend = RandStdShiftIntensity.backend def __init__( self, @@ -437,35 +470,32 @@ def __init__( prob: probability of std shift. nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.shifter = RandStdShiftIntensity( + factors=factors, nonzero=nonzero, channel_wise=channel_wise, dtype=dtype, prob=1.0 + ) - if isinstance(factors, (int, float)): - self.factors = (min(-factors, factors), max(-factors, factors)) - elif len(factors) != 2: - raise ValueError("factors should be a number or pair of numbers.") - else: - self.factors = (min(factors), max(factors)) - self.factor = self.factors[0] - self.nonzero = nonzero - self.channel_wise = channel_wise - self.dtype = dtype - - def randomize(self, data: Optional[Any] = None) -> None: - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) - super().randomize(None) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandStdShiftIntensityd": + super().set_random_state(seed, state) + self.shifter.set_random_state(seed, state) + return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d - shifter = StdShiftIntensity(self.factor, self.nonzero, self.channel_wise, self.dtype) + + # all the keys share the same random shift factor + self.shifter.randomize(None) for key in self.key_iterator(d): - d[key] = shifter(d[key]) + d[key] = self.shifter(d[key], randomize=False) return d @@ -484,6 +514,8 @@ def __init__( minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, factor: Optional[float] = None, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -493,12 +525,15 @@ def __init__( minv: minimum value of output data. maxv: maximum value of output data. factor: factor scale by ``v = v * (1 + factor)``. In order to use - this parameter, please set `minv` and `maxv` into None. + this parameter, please set both `minv` and `maxv` into None. + channel_wise: if True, scale on each channel separately. Please ensure + that the first dimension represents the channel of the image if True. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.scaler = ScaleIntensity(minv, maxv, factor) + self.scaler = ScaleIntensity(minv, maxv, factor, channel_wise, dtype) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -512,13 +547,14 @@ class RandScaleIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`. """ - backend = ScaleIntensity.backend + backend = RandScaleIntensity.backend def __init__( self, keys: KeysCollection, factors: Union[Tuple[float, float], float], prob: float = 0.1, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -529,32 +565,31 @@ def __init__( if single number, factor value is picked from (-factors, factors). prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0) - if isinstance(factors, (int, float)): - self.factors = (min(-factors, factors), max(-factors, factors)) - elif len(factors) != 2: - raise ValueError("factors should be a number or pair of numbers.") - else: - self.factors = (min(factors), max(factors)) - self.factor = self.factors[0] - - def randomize(self, data: Optional[Any] = None) -> None: - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) - super().randomize(None) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandScaleIntensityd": + super().set_random_state(seed, state) + self.scaler.set_random_state(seed, state) + return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d - scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor) + + # all the keys share the same random scale factor + self.scaler.randomize(None) for key in self.key_iterator(d): - d[key] = scaler(d[key]) + d[key] = self.scaler(d[key], randomize=False) return d @@ -563,13 +598,15 @@ class RandBiasFieldd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandBiasField`. """ + backend = RandBiasField.backend + def __init__( self, keys: KeysCollection, degree: int = 3, coeff_range: Tuple[float, float] = (0.0, 0.1), dtype: DtypeLike = np.float32, - prob: float = 1.0, + prob: float = 0.1, allow_missing_keys: bool = False, ) -> None: """ @@ -579,7 +616,7 @@ def __init__( degree: degree of freedom of the polynomials. The value should be no less than 1. Defaults to 3. coeff_range: range of the random coefficients. Defaults to (0.0, 0.1). - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. prob: probability to do random bias field. allow_missing_keys: don't raise exception if key is missing. @@ -587,18 +624,29 @@ def __init__( MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.rand_bias_field = RandBiasField(degree, coeff_range, dtype, prob) + self.rand_bias_field = RandBiasField(degree=degree, coeff_range=coeff_range, dtype=dtype, prob=1.0) - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandBiasFieldd": + super().set_random_state(seed, state) + self.rand_bias_field.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d + + # all the keys share the same random bias factor + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.rand_bias_field.randomize(img_size=d[first_key].shape[1:]) # type: ignore for key in self.key_iterator(d): - d[key] = self.rand_bias_field(d[key]) + d[key] = self.rand_bias_field(d[key], randomize=False) return d @@ -614,9 +662,9 @@ class NormalizeIntensityd(MapTransform): subtrahend: the amount to subtract by (usually the mean) divisor: the amount to divide by (usually the standard deviation) nonzero: whether only normalize non-zero values. - channel_wise: if using calculated mean and std, calculate on each channel separately - or calculate on the entire image directly. - dtype: output data type, defaults to float32. + channel_wise: if True, calculate on each channel separately, otherwise, calculate on + the entire image directly. default to False. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -655,6 +703,8 @@ class ThresholdIntensityd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = ThresholdIntensity.backend + def __init__( self, keys: KeysCollection, @@ -666,7 +716,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.filter = ThresholdIntensity(threshold, above, cval) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.filter(d[key]) @@ -685,23 +735,27 @@ class ScaleIntensityRanged(MapTransform): b_min: intensity target range min. b_max: intensity target range max. clip: whether to perform clip after scaling. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ + backend = ScaleIntensityRange.backend + def __init__( self, keys: KeysCollection, a_min: float, a_max: float, - b_min: float, - b_max: float, + b_min: Optional[float] = None, + b_max: Optional[float] = None, clip: bool = False, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip) + self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip, dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -722,11 +776,13 @@ class AdjustContrastd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = AdjustContrast.backend + def __init__(self, keys: KeysCollection, gamma: float, allow_missing_keys: bool = False) -> None: super().__init__(keys, allow_missing_keys) self.adjuster = AdjustContrast(gamma) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.adjuster(d[key]) @@ -749,6 +805,8 @@ class RandAdjustContrastd(RandomizableTransform, MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = RandAdjustContrast.backend + def __init__( self, keys: KeysCollection, @@ -758,34 +816,25 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.adjuster = RandAdjustContrast(gamma=gamma, prob=1.0) - if isinstance(gamma, (int, float)): - if gamma <= 0.5: - raise ValueError( - "if gamma is single number, must greater than 0.5 and value is picked from (0.5, gamma)" - ) - self.gamma = (0.5, gamma) - elif len(gamma) != 2: - raise ValueError("gamma should be a number or pair of numbers.") - else: - self.gamma = (min(gamma), max(gamma)) - - self.gamma_value: Optional[float] = None - - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self.gamma_value = self.R.uniform(low=self.gamma[0], high=self.gamma[1]) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandAdjustContrastd": + super().set_random_state(seed, state) + self.adjuster.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() - if self.gamma_value is None: - raise RuntimeError("gamma_value is not set.") + self.randomize(None) if not self._do_transform: return d - adjuster = AdjustContrast(self.gamma_value) + + # all the keys share the same random gamma value + self.adjuster.randomize(None) for key in self.key_iterator(d): - d[key] = adjuster(d[key]) + d[key] = self.adjuster(d[key], randomize=False) return d @@ -802,24 +851,31 @@ class ScaleIntensityRangePercentilesd(MapTransform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max] + channel_wise: if True, compute intensity percentile and normalize every channel separately. + default to False. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ + backend = ScaleIntensityRangePercentiles.backend + def __init__( self, keys: KeysCollection, lower: float, upper: float, - b_min: float, - b_max: float, + b_min: Optional[float], + b_max: Optional[float], clip: bool = False, relative: bool = False, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative) + self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -847,10 +903,12 @@ class MaskIntensityd(MapTransform): """ + backend = MaskIntensity.backend + def __init__( self, keys: KeysCollection, - mask_data: Optional[np.ndarray] = None, + mask_data: Optional[NdarrayOrTensor] = None, mask_key: Optional[str] = None, select_fn: Callable = is_positive, allow_missing_keys: bool = False, @@ -859,13 +917,50 @@ def __init__( self.converter = MaskIntensity(mask_data=mask_data, select_fn=select_fn) self.mask_key = mask_key if mask_data is None else None - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key], d[self.mask_key]) if self.mask_key is not None else self.converter(d[key]) return d +class SavitzkyGolaySmoothd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.SavitzkyGolaySmooth`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + window_length: length of the filter window, must be a positive odd integer. + order: order of the polynomial to fit to each window, must be less than ``window_length``. + axis: optional axis along which to apply the filter kernel. Default 1 (first spatial dimension). + mode: optional padding mode, passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = SavitzkyGolaySmooth.backend + + def __init__( + self, + keys: KeysCollection, + window_length: int, + order: int, + axis: int = 1, + mode: str = "zeros", + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.converter = SavitzkyGolaySmooth(window_length=window_length, order=order, axis=axis, mode=mode) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + class GaussianSmoothd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.GaussianSmooth`. @@ -882,6 +977,8 @@ class GaussianSmoothd(MapTransform): """ + backend = GaussianSmooth.backend + def __init__( self, keys: KeysCollection, @@ -892,7 +989,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = GaussianSmooth(sigma, approx=approx) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -916,6 +1013,8 @@ class RandGaussianSmoothd(RandomizableTransform, MapTransform): """ + backend = RandGaussianSmooth.backend + def __init__( self, keys: KeysCollection, @@ -928,25 +1027,27 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.sigma_x, self.sigma_y, self.sigma_z = sigma_x, sigma_y, sigma_z - self.approx = approx - - self.x, self.y, self.z = self.sigma_x[0], self.sigma_y[0], self.sigma_z[0] + self.rand_smooth = RandGaussianSmooth( + sigma_x=sigma_x, sigma_y=sigma_y, sigma_z=sigma_z, approx=approx, prob=1.0 + ) - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1]) - self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1]) - self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGaussianSmoothd": + super().set_random_state(seed, state) + self.rand_smooth.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d + + # all the keys share the same random sigma + self.rand_smooth.randomize(None) for key in self.key_iterator(d): - sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=d[key].ndim - 1) - d[key] = GaussianSmooth(sigma=sigma, approx=self.approx)(d[key]) + d[key] = self.rand_smooth(d[key], randomize=False) return d @@ -970,6 +1071,8 @@ class GaussianSharpend(MapTransform): """ + backend = GaussianSharpen.backend + def __init__( self, keys: KeysCollection, @@ -982,7 +1085,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = GaussianSharpen(sigma1, sigma2, alpha, approx=approx) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1013,6 +1116,8 @@ class RandGaussianSharpend(RandomizableTransform, MapTransform): """ + backend = RandGaussianSharpen.backend + def __init__( self, keys: KeysCollection, @@ -1029,37 +1134,35 @@ def __init__( ): MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.sigma1_x = sigma1_x - self.sigma1_y = sigma1_y - self.sigma1_z = sigma1_z - self.sigma2_x = sigma2_x - self.sigma2_y = sigma2_y - self.sigma2_z = sigma2_z - self.alpha = alpha - self.approx = approx - - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self.x1 = self.R.uniform(low=self.sigma1_x[0], high=self.sigma1_x[1]) - self.y1 = self.R.uniform(low=self.sigma1_y[0], high=self.sigma1_y[1]) - self.z1 = self.R.uniform(low=self.sigma1_z[0], high=self.sigma1_z[1]) - sigma2_x = (self.sigma2_x, self.x1) if not isinstance(self.sigma2_x, Iterable) else self.sigma2_x - sigma2_y = (self.sigma2_y, self.y1) if not isinstance(self.sigma2_y, Iterable) else self.sigma2_y - sigma2_z = (self.sigma2_z, self.z1) if not isinstance(self.sigma2_z, Iterable) else self.sigma2_z - self.x2 = self.R.uniform(low=sigma2_x[0], high=sigma2_x[1]) - self.y2 = self.R.uniform(low=sigma2_y[0], high=sigma2_y[1]) - self.z2 = self.R.uniform(low=sigma2_z[0], high=sigma2_z[1]) - self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1]) + self.rand_sharpen = RandGaussianSharpen( + sigma1_x=sigma1_x, + sigma1_y=sigma1_y, + sigma1_z=sigma1_z, + sigma2_x=sigma2_x, + sigma2_y=sigma2_y, + sigma2_z=sigma2_z, + alpha=alpha, + approx=approx, + prob=1.0, + ) - def __call__(self, data): + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGaussianSharpend": + super().set_random_state(seed, state) + self.rand_sharpen.set_random_state(seed, state) + return self + + def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d + + # all the keys share the same random sigma1, sigma2, etc. + self.rand_sharpen.randomize(None) for key in self.key_iterator(d): - sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=d[key].ndim - 1) - sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=d[key].ndim - 1) - d[key] = GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(d[key]) + d[key] = self.rand_sharpen(d[key], randomize=False) return d @@ -1078,6 +1181,8 @@ class RandHistogramShiftd(RandomizableTransform, MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = RandHistogramShift.backend + def __init__( self, keys: KeysCollection, @@ -1087,38 +1192,25 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - if isinstance(num_control_points, int): - if num_control_points <= 2: - raise ValueError("num_control_points should be greater than or equal to 3") - self.num_control_points = (num_control_points, num_control_points) - else: - if len(num_control_points) != 2: - raise ValueError("num_control points should be a number or a pair of numbers") - if min(num_control_points) <= 2: - raise ValueError("num_control_points should be greater than or equal to 3") - self.num_control_points = (min(num_control_points), max(num_control_points)) - - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - num_control_point = self.R.randint(self.num_control_points[0], self.num_control_points[1] + 1) - self.reference_control_points = np.linspace(0, 1, num_control_point) - self.floating_control_points = np.copy(self.reference_control_points) - for i in range(1, num_control_point - 1): - self.floating_control_points[i] = self.R.uniform( - self.floating_control_points[i - 1], self.floating_control_points[i + 1] - ) - - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + self.shifter = RandHistogramShift(num_control_points=num_control_points, prob=1.0) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandHistogramShiftd": + super().set_random_state(seed, state) + self.shifter.set_random_state(seed, state) + return self + + def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d + + # all the keys share the same random shift params + self.shifter.randomize(None) for key in self.key_iterator(d): - img_min, img_max = d[key].min(), d[key].max() - reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min - floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min - dtype = d[key].dtype - d[key] = np.interp(d[key], reference_control_points_scaled, floating_control_points_scaled).astype(dtype) + d[key] = self.shifter(d[key], randomize=False) return d @@ -1144,56 +1236,43 @@ class RandGibbsNoised(RandomizableTransform, MapTransform): values in the interval [0,1] with alpha = 0 acting as the identity mapping. If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. allow_missing_keys: do not raise exception if key is missing. """ + backend = RandGibbsNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), - as_tensor_output: bool = True, allow_missing_keys: bool = False, + as_tensor_output: bool = True, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob=prob) - self.alpha = alpha - self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() - self.as_tensor_output = as_tensor_output + self.rand_gibbs_noise = RandGibbsNoise(alpha=alpha, prob=1.0) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGibbsNoised": + super().set_random_state(seed, state) + self.rand_gibbs_noise.set_random_state(seed, state) + return self + def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self._randomize(None) - - for i, key in enumerate(self.key_iterator(d)): - if self._do_transform: - if i == 0: - transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output) - d[key] = transform(d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) - return d - - def _randomize(self, _: Any) -> None: - """ - (1) Set random variable to apply the transform. - (2) Get alpha from uniform distribution. - """ - super().randomize(None) - self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) + self.randomize(None) + if not self._do_transform: + return d - def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: - if isinstance(d, torch.Tensor): - d_numpy: np.ndarray = d.cpu().detach().numpy() - return d_numpy + # all the keys share the same random noise params + self.rand_gibbs_noise.randomize(None) + for key in self.key_iterator(d): + d[key] = self.rand_gibbs_noise(d[key], randomize=False) + return d class GibbsNoised(MapTransform): @@ -1212,20 +1291,20 @@ class GibbsNoised(MapTransform): you need to transform. alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. allow_missing_keys: do not raise exception if key is missing. """ + backend = GibbsNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( - self, keys: KeysCollection, alpha: float = 0.5, as_tensor_output: bool = True, allow_missing_keys: bool = False + self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False, as_tensor_output: bool = True ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - self.transform = GibbsNoise(alpha, as_tensor_output) + self.transform = GibbsNoise(alpha) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): @@ -1264,8 +1343,6 @@ class KSpaceSpikeNoised(MapTransform): receive a sequence of intensities. This value should be tested as it is data-dependent. The default values are the 2.5 the mean of the log-intensity for each channel. - as_tensor_output: if ``True`` return torch.Tensor, else return np.array. - Default: ``True``. allow_missing_keys: do not raise exception if key is missing. Example: @@ -1276,21 +1353,22 @@ class KSpaceSpikeNoised(MapTransform): with `log-intensity = 14`. """ + backend = KSpaceSpikeNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None, - as_tensor_output: bool = True, allow_missing_keys: bool = False, + as_tensor_output: bool = True, ) -> None: super().__init__(keys, allow_missing_keys) - self.transform = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + self.transform = KSpaceSpikeNoise(loc, k_intensity) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: """ Args: data: Expects image/label to have dimensions (C, H, W) or @@ -1320,110 +1398,66 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): Args: keys: "image", "label", or ["image", "label"] depending on which data you need to transform. - global_prob: probability of applying transform to the dictionary. prob: probability to add spike artifact to each item in the dictionary provided it is realized that the noise will be applied to the dictionary. - intensity_ranges: Dictionary with intensity - ranges to sample for each key. Given a dictionary value of `(a, b)` the - transform will sample the log-intensity from the interval `(a, b)` uniformly for all - channels of the respective key. If a sequence of intevals `((a0, b0), (a1, b1), ...)` - is given, then the transform will sample from each interval for each - respective channel. In the second case, the number of 2-tuples must - match the number of channels. Default ranges is `(0.95x, 1.10x)` - where `x` is the mean log-intensity for each channel. - channel_wise: treat each channel independently. True by - default. - common_sampling: If ``True`` same values for location and log-intensity - will be sampled for the image and label. - common_seed: Seed to be used in case ``common_sampling = True``. - as_tensor_output: if ``True`` return torch.Tensor, else return - np.array. Default: ``True``. + intensity_range: pass a tuple (a, b) to sample the log-intensity from the interval (a, b) + uniformly for all channels. Or pass sequence of intervals + ((a0, b0), (a1, b1), ...) to sample for each respective channel. + In the second case, the number of 2-tuples must match the number of channels. + Default ranges is `(0.95x, 1.10x)` where `x` is the mean + log-intensity for each channel. + channel_wise: treat each channel independently. True by default. allow_missing_keys: do not raise exception if key is missing. Example: To apply `k`-space spikes randomly on the image only, with probability 0.5, and log-intensity sampled from the interval [13, 15] for each channel independently, one uses - ``RandKSpaceSpikeNoised("image", prob=0.5, intensity_ranges={"image":(13,15)}, channel_wise=True)``. + ``RandKSpaceSpikeNoised("image", prob=0.5, intensity_ranges=(13, 15), channel_wise=True)``. """ + backend = RandKSpaceSpikeNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") + @deprecated_arg(name="common_sampling", since="0.6") + @deprecated_arg(name="common_seed", since="0.6") + @deprecated_arg(name="global_prob", since="0.6") def __init__( self, keys: KeysCollection, global_prob: float = 1.0, prob: float = 0.1, - intensity_ranges: Optional[Mapping[Hashable, Sequence[Union[Sequence[float], float]]]] = None, + intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, channel_wise: bool = True, common_sampling: bool = False, common_seed: int = 42, - as_tensor_output: bool = True, allow_missing_keys: bool = False, + as_tensor_output: bool = True, ): - MapTransform.__init__(self, keys, allow_missing_keys) - RandomizableTransform.__init__(self, global_prob) - - self.common_sampling = common_sampling - self.common_seed = common_seed - self.as_tensor_output = as_tensor_output - # the spikes artifact is amplitude dependent so we instantiate one per key - self.transforms = {} - if isinstance(intensity_ranges, Mapping): - for k in self.keys: - self.transforms[k] = RandKSpaceSpikeNoise( - prob, intensity_ranges[k], channel_wise, self.as_tensor_output - ) - else: - for k in self.keys: - self.transforms[k] = RandKSpaceSpikeNoise(prob, None, channel_wise, self.as_tensor_output) - - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: - """ - Args: - data: Expects image/label to have dimensions (C, H, W) or - (C, H, W, D), where C is the channel. - """ - d = dict(data) - super().randomize(None) - - # In case the same spikes are desired for both image and label. - if self.common_sampling: - for k in self.keys: - self.transforms[k].set_random_state(self.common_seed) - - for key, t in self.key_iterator(d, self.transforms): - if self._do_transform: - d[key] = self.transforms[t](d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) - return d - - def set_rand_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> None: - """ - Set the random state locally to control the randomness. - User should use this method instead of ``set_random_state``. + RandomizableTransform.__init__(self, prob=prob) + self.rand_noise = RandKSpaceSpikeNoise(prob=1.0, intensity_range=intensity_range, channel_wise=channel_wise) - Args: - seed: set the random state with an integer seed. - state: set the random state with a `np.random.RandomState` object.""" + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandKSpaceSpikeNoised": + super().set_random_state(seed, state) + self.rand_noise.set_random_state(seed, state) + return self - self.set_random_state(seed, state) - for key in self.keys: - self.transforms[key].set_random_state(seed, state) + def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if not self._do_transform: + return d - def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: - if isinstance(d, torch.Tensor): - d_numpy: np.ndarray = d.cpu().detach().numpy() - return d_numpy + for key in self.key_iterator(d): + d[key] = self.rand_noise(d[key], randomize=True) + return d -class RandCoarseDropoutd(Randomizable, MapTransform): +class RandCoarseDropoutd(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseDropout`. Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions @@ -1455,6 +1489,8 @@ class RandCoarseDropoutd(Randomizable, MapTransform): """ + backend = RandCoarseDropout.backend + def __init__( self, keys: KeysCollection, @@ -1468,6 +1504,7 @@ def __init__( allow_missing_keys: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob=prob) self.dropper = RandCoarseDropout( holes=holes, spatial_size=spatial_size, @@ -1475,19 +1512,99 @@ def __init__( fill_value=fill_value, max_holes=max_holes, max_spatial_size=max_spatial_size, - prob=prob, + prob=1.0, ) - def randomize(self, img_size: Sequence[int]) -> None: - self.dropper.randomize(img_size=img_size) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCoarseDropoutd": + super().set_random_state(seed, state) + self.dropper.set_random_state(seed, state) + return self def __call__(self, data): d = dict(data) - # expect all the specified keys have same spatial shape - self.randomize(d[self.keys[0]].shape[1:]) - if self.dropper._do_transform: - for key in self.key_iterator(d): - d[key] = self.dropper(img=d[key]) + self.randomize(None) + if not self._do_transform: + return d + + # expect all the specified keys have same spatial shape and share same random holes + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.dropper.randomize(d[first_key].shape[1:]) + for key in self.key_iterator(d): + d[key] = self.dropper(img=d[key], randomize=False) + + return d + + +class RandCoarseShuffled(RandomizableTransform, MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseShuffle`. + Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions + for every key, if want to shuffle different regions for every key, please use this transform separately. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = RandCoarseShuffle.backend + + def __init__( + self, + keys: KeysCollection, + holes: int, + spatial_size: Union[Sequence[int], int], + max_holes: Optional[int] = None, + max_spatial_size: Optional[Union[Sequence[int], int]] = None, + prob: float = 0.1, + allow_missing_keys: bool = False, + ): + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob=prob) + self.shuffle = RandCoarseShuffle( + holes=holes, spatial_size=spatial_size, max_holes=max_holes, max_spatial_size=max_spatial_size, prob=1.0 + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCoarseShuffled": + super().set_random_state(seed, state) + self.shuffle.set_random_state(seed, state) + return self + + def __call__(self, data): + d = dict(data) + self.randomize(None) + if not self._do_transform: + return d + + # expect all the specified keys have same spatial shape and share same random holes + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.shuffle.randomize(d[first_key].shape[1:]) + for key in self.key_iterator(d): + d[key] = self.shuffle(img=d[key], randomize=False) return d @@ -1507,18 +1624,20 @@ class HistogramNormalized(MapTransform): only points at which `mask==True` are used for the equalization. can also provide the mask by `mask_key` at runtime. mask_key: if mask is None, will try to get the mask with `mask_key`. - dtype: data type of the output, default to `float32`. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: do not raise exception if key is missing. """ + backend = HistogramNormalize.backend + def __init__( self, keys: KeysCollection, num_bins: int = 256, min: int = 0, max: int = 255, - mask: Optional[np.ndarray] = None, + mask: Optional[NdarrayOrTensor] = None, mask_key: Optional[str] = None, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, @@ -1527,7 +1646,7 @@ def __init__( self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, mask=mask, dtype=dtype) self.mask_key = mask_key if mask is None else None - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.transform(d[key], d[self.mask_key]) if self.mask_key is not None else self.transform(d[key]) @@ -1551,6 +1670,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda RandAdjustContrastD = RandAdjustContrastDict = RandAdjustContrastd ScaleIntensityRangePercentilesD = ScaleIntensityRangePercentilesDict = ScaleIntensityRangePercentilesd MaskIntensityD = MaskIntensityDict = MaskIntensityd +SavitzkyGolaySmoothD = SavitzkyGolaySmoothDict = SavitzkyGolaySmoothd GaussianSmoothD = GaussianSmoothDict = GaussianSmoothd RandGaussianSmoothD = RandGaussianSmoothDict = RandGaussianSmoothd GaussianSharpenD = GaussianSharpenDict = GaussianSharpend @@ -1562,3 +1682,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda RandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd HistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized +RandCoarseShuffleD = RandCoarseShuffleDict = RandCoarseShuffled diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 58f3526086..c8bfeeca05 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -8,18 +8,80 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Hashable, Optional, Tuple +import os +from typing import Hashable, Mapping, Optional, Tuple import torch -from monai.transforms.transform import RandomizableTransform, Transform -from monai.utils.enums import InverseKeys +from monai.transforms.transform import Transform +from monai.utils.enums import TraceKeys + +__all__ = ["TraceableTransform", "InvertibleTransform"] + + +class TraceableTransform(Transform): + """ + Maintains a stack of applied transforms. The stack is inserted as pairs of + `trace_key: list of transforms` to each data dictionary. + + The ``__call__`` method of this transform class must be implemented so + that the transformation information for each key is stored when + ``__call__`` is called. If the transforms were applied to keys "image" and + "label", there will be two extra keys in the dictionary: "image_transforms" + and "label_transforms" (based on `TraceKeys.KEY_SUFFIX`). Each list + contains a list of the transforms applied to that key. + + The information in ``data[key_transform]`` will be compatible with the + default collate since it only stores strings, numbers and arrays. -__all__ = ["InvertibleTransform"] + `tracing` could be enabled by `self.set_tracing` or setting + `MONAI_TRACE_TRANSFORM` when initializing the class. + """ + + tracing = False if os.environ.get("MONAI_TRACE_TRANSFORM", "1") == "0" else True + + def set_tracing(self, tracing: bool) -> None: + """Set whether to trace transforms.""" + self.tracing = tracing + + @staticmethod + def trace_key(key: Hashable = None): + """The key to store the stack of applied transforms.""" + if key is None: + return TraceKeys.KEY_SUFFIX + return str(key) + TraceKeys.KEY_SUFFIX + + def push_transform( + self, data: Mapping, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None + ) -> None: + """PUsh to a stack of applied transforms for that key.""" + if not self.tracing: + return + info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)} + if orig_size is not None: + info[TraceKeys.ORIG_SIZE] = orig_size + elif key in data and hasattr(data[key], "shape"): + info[TraceKeys.ORIG_SIZE] = data[key].shape[1:] + if extra_info is not None: + info[TraceKeys.EXTRA_INFO] = extra_info + # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) + if hasattr(self, "_do_transform"): # RandomizableTransform + info[TraceKeys.DO_TRANSFORM] = self._do_transform # type: ignore + # If this is the first, create list + if self.trace_key(key) not in data: + if not isinstance(data, dict): + data = dict(data) + data[self.trace_key(key)] = [] + data[self.trace_key(key)].append(info) + + def pop_transform(self, data: Mapping, key: Hashable = None): + """Remove the most recent applied transform.""" + if not self.tracing: + return + return data.get(self.trace_key(key), []).pop() -class InvertibleTransform(Transform): +class InvertibleTransform(TraceableTransform): """Classes for invertible transforms. This class exists so that an ``invert`` method can be implemented. This allows, for @@ -27,28 +89,21 @@ class InvertibleTransform(Transform): and after be returned to their original size before saving to file for comparison in an external viewer. - When the ``__call__`` method is called, the transformation information for each key is - stored. If the transforms were applied to keys "image" and "label", there will be two - extra keys in the dictionary: "image_transforms" and "label_transforms". Each list - contains a list of the transforms applied to that key. When the ``inverse`` method is - called, the inverse is called on each key individually, which allows for different - parameters being passed to each label (e.g., different interpolation for image and - label). + When the ``inverse`` method is called: - When the ``inverse`` method is called, the inverse transforms are applied in a last- - in-first-out order. As the inverse is applied, its entry is removed from the list - detailing the applied transformations. That is to say that during the forward pass, - the list of applied transforms grows, and then during the inverse it shrinks back - down to an empty list. + - the inverse is called on each key individually, which allows for + different parameters being passed to each label (e.g., different + interpolation for image and label). - The information in ``data[key_transform]`` will be compatible with the default collate - since it only stores strings, numbers and arrays. + - the inverse transforms are applied in a last- in-first-out order. As + the inverse is applied, its entry is removed from the list detailing + the applied transformations. That is to say that during the forward + pass, the list of applied transforms grows, and then during the + inverse it shrinks back down to an empty list. We currently check that the ``id()`` of the transform is the same in the forward and inverse directions. This is a useful check to ensure that the inverses are being - processed in the correct order. However, this may cause issues if the ``id()`` of the - object changes (such as multiprocessing on Windows). If you feel this issue affects - you, please raise a GitHub issue. + processed in the correct order. Note to developers: When converting a transform to an invertible transform, you need to: @@ -63,55 +118,25 @@ class InvertibleTransform(Transform): """ - def push_transform( - self, - data: dict, - key: Hashable, - extra_info: Optional[dict] = None, - orig_size: Optional[Tuple] = None, - ) -> None: - """Append to list of applied transforms for that key.""" - key_transform = str(key) + InverseKeys.KEY_SUFFIX - info = { - InverseKeys.CLASS_NAME: self.__class__.__name__, - InverseKeys.ID: id(self), - } - if orig_size is not None: - info[InverseKeys.ORIG_SIZE] = orig_size - elif hasattr(data[key], "shape"): - info[InverseKeys.ORIG_SIZE] = data[key].shape[1:] - if extra_info is not None: - info[InverseKeys.EXTRA_INFO] = extra_info - # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) - if isinstance(self, RandomizableTransform): - info[InverseKeys.DO_TRANSFORM] = self._do_transform - # If this is the first, create list - if key_transform not in data: - data[key_transform] = [] - data[key_transform].append(info) - - def check_transforms_match(self, transform: dict) -> None: + def check_transforms_match(self, transform: Mapping) -> None: """Check transforms are of same instance.""" - if transform[InverseKeys.ID] == id(self): + xform_name = transform.get(TraceKeys.CLASS_NAME, "") + xform_id = transform.get(TraceKeys.ID, "") + if xform_id == id(self): return # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) - if ( - torch.multiprocessing.get_start_method() in ("spawn", None) - and transform[InverseKeys.CLASS_NAME] == self.__class__.__name__ - ): + if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__: return - raise RuntimeError("Should inverse most recently applied invertible transform first") + raise RuntimeError(f"Error inverting the most recently applied invertible transform {xform_name} {xform_id}.") - def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: + def get_most_recent_transform(self, data: Mapping, key: Hashable = None): """Get most recent transform.""" - transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX][-1]) + if not self.tracing: + raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.") + transform = data[self.trace_key(key)][-1] self.check_transforms_match(transform) return transform - def pop_transform(self, data: dict, key: Hashable) -> None: - """Remove most recent transform.""" - data[str(key) + InverseKeys.KEY_SUFFIX].pop() - def inverse(self, data: dict) -> dict: """ Inverse of ``__call__``. diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index d9c6790840..cc77a199dd 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,22 +17,17 @@ from monai.config import KeysCollection from monai.data.dataloader import DataLoader -from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate, rep_scalar_to_batch +from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Transform from monai.utils import first -__all__ = ["BatchInverseTransform", "Decollated"] +__all__ = ["BatchInverseTransform", "Decollated", "DecollateD", "DecollateDict"] class _BatchInverseDataset(Dataset): - def __init__( - self, - data: Sequence[Any], - transform: InvertibleTransform, - pad_collation_used: bool, - ) -> None: + def __init__(self, data: Sequence[Any], transform: InvertibleTransform, pad_collation_used: bool) -> None: self.data = data self.invertible_transform = transform self.pad_collation_used = pad_collation_used @@ -65,6 +60,8 @@ def __init__( collate_fn: Optional[Callable] = no_collation, num_workers: Optional[int] = 0, detach: bool = True, + pad_batch: bool = True, + fill_value=None, ) -> None: """ Args: @@ -78,6 +75,10 @@ def __init__( if set to `None`, use the `num_workers` of the transform data loader. detach: whether to detach the tensors. Scalars tensors will be detached into number types instead of torch tensors. + pad_batch: when the items in a batch indicate different batch size, + whether to pad all the sequences to the longest. + If False, the batch size will be the length of the shortest sequence. + fill_value: the value to fill the padded sequences when `pad_batch=True`. """ self.transform = transform @@ -85,10 +86,12 @@ def __init__( self.num_workers = loader.num_workers if num_workers is None else num_workers self.collate_fn = collate_fn self.detach = detach + self.pad_batch = pad_batch + self.fill_value = fill_value self.pad_collation_used = loader.collate_fn.__doc__ == pad_list_data_collate.__doc__ def __call__(self, data: Dict[str, Any]) -> Any: - decollated_data = decollate_batch(data, detach=self.detach) + decollated_data = decollate_batch(data, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value) inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used) inv_loader = DataLoader( inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn @@ -99,21 +102,25 @@ def __call__(self, data: Dict[str, Any]) -> Any: re_str = str(re) if "equal size" in re_str: re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." - raise RuntimeError(re_str) + raise RuntimeError(re_str) from re class Decollated(MapTransform): """ - Decollate a batch of data, if input a dictionary, it can also support to only decollate specified keys. - Note that unlike most MapTransforms, it will delete other keys not specified and if keys=None, will decollate - all the data in the input. - And it replicates the scalar values to every item of the decollated list. + Decollate a batch of data. If input is a dictionary, it also supports to only decollate specified keys. + Note that unlike most MapTransforms, it will delete the other keys that are not specified. + if `keys=None`, it will decollate all the data in the input. + It replicates the scalar values to every item of the decollated list. Args: keys: keys of the corresponding items to decollate, note that it will delete other keys not specified. if None, will decollate all the keys. see also: :py:class:`monai.transforms.compose.MapTransform`. detach: whether to detach the tensors. Scalars tensors will be detached into number types instead of torch tensors. + pad_batch: when the items in a batch indicate different batch size, + whether to pad all the sequences to the longest. + If False, the batch size will be the length of the shortest sequence. + fill_value: the value to fill the padded sequences when `pad_batch=True`. allow_missing_keys: don't raise exception if key is missing. """ @@ -122,10 +129,14 @@ def __init__( self, keys: Optional[KeysCollection] = None, detach: bool = True, + pad_batch: bool = True, + fill_value=None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.detach = detach + self.pad_batch = pad_batch + self.fill_value = fill_value def __call__(self, data: Union[Dict, List]): d: Union[Dict, List] @@ -139,4 +150,7 @@ def __call__(self, data: Union[Dict, List]): for key in self.key_iterator(data): d[key] = data[key] - return decollate_batch(rep_scalar_to_batch(d), detach=self.detach) + return decollate_batch(d, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value) + + +DecollateD = DecollateDict = Decollated diff --git a/monai/transforms/io/__init__.py b/monai/transforms/io/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/io/__init__.py +++ b/monai/transforms/io/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 38a2861c8b..5bafd84eaf 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,22 +16,24 @@ import inspect import logging import sys +import traceback import warnings from pathlib import Path +from pydoc import locate from typing import Dict, List, Optional, Sequence, Union import numpy as np import torch -from monai.config import DtypeLike +from monai.config import DtypeLike, NdarrayOrTensor, PathLike +from monai.data import image_writer +from monai.data.folder_layout import FolderLayout from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader -from monai.data.nifti_saver import NiftiSaver -from monai.data.png_saver import PNGSaver from monai.transforms.transform import Transform +from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, ensure_tuple, optional_import -from monai.utils.module import look_up_option +from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -83,7 +85,7 @@ class LoadImage(Transform): - User-specified reader in the constructor of `LoadImage`. - Readers from the last to the first in the registered list. - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), - (npz, npy -> NumpyReader), (others -> ITKReader). + (npz, npy -> NumpyReader), (DICOM file -> ITKReader). See also: @@ -91,19 +93,28 @@ class LoadImage(Transform): """ - def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np.float32, *args, **kwargs) -> None: + def __init__( + self, + reader=None, + image_only: bool = False, + dtype: DtypeLike = np.float32, + ensure_channel_first: bool = False, + *args, + **kwargs, + ) -> None: """ Args: reader: reader to load image file and meta data - - if `reader` is None, a default set of `SUPPORTED_READERS` will be used. - - if `reader` is a string, the corresponding item in `SUPPORTED_READERS` will be used, - and a reader instance will be constructed with the `*args` and `**kwargs` parameters. - the supported reader names are: "nibabelreader", "pilreader", "itkreader", "numpyreader". + - if `reader` is a string, it's treated as a class name or dotted path + (such as ``"monai.data.ITKReader"``), the supported built-in reader classes are + ``"ITKReader"``, ``"NibabelReader"``, ``"NumpyReader"``. + a reader instance will be constructed with the `*args` and `**kwargs` parameters. - if `reader` is a reader class/instance, it will be registered to this loader accordingly. - image_only: if True return only the image volume, otherwise return image data array and header dict. dtype: if not None convert the loaded image to this data type. + ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert + the image array shape to `channel first`. default to `False`. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. @@ -113,7 +124,7 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. or a tuple of two elements containing the data array, and the meta data in a dictionary format otherwise. - If `reader` is specified, the loader will attempt to use the specified readers and the default supported readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. - In this case, it is therefore recommended to set the most appropriate reader as + In this case, it is therefore recommended setting the most appropriate reader as the last item of the `reader` parameter. """ @@ -121,11 +132,16 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. self.auto_select = reader is None self.image_only = image_only self.dtype = dtype + self.ensure_channel_first = ensure_channel_first self.readers: List[ImageReader] = [] for r in SUPPORTED_READERS: # set predefined readers as default try: self.register(SUPPORTED_READERS[r](*args, **kwargs)) + except OptionalImportError: + logging.getLogger(self.__class__.__name__).debug( + f"required package for reader {r} is not installed, or the version doesn't match requirement." + ) except TypeError: # the reader doesn't have the corresponding args/kwargs logging.getLogger(self.__class__.__name__).debug( f"{r} is not supported with the given parameters {args} {kwargs}." @@ -136,11 +152,19 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. for _r in ensure_tuple(reader): if isinstance(_r, str): - the_reader = look_up_option(_r.lower(), SUPPORTED_READERS) + the_reader, has_built_in = optional_import("monai.data", name=f"{_r}") # search built-in + if not has_built_in: + the_reader = locate(f"{_r}") # search dotted path + if the_reader is None: + the_reader = look_up_option(_r.lower(), SUPPORTED_READERS) try: self.register(the_reader(*args, **kwargs)) + except OptionalImportError: + warnings.warn( + f"required package for reader {_r} is not installed, or the version doesn't match requirement." + ) except TypeError: # the reader doesn't have the corresponding args/kwargs - warnings.warn(f"{r} is not supported with the given parameters {args} {kwargs}.") + warnings.warn(f"{_r} is not supported with the given parameters {args} {kwargs}.") self.register(the_reader()) elif inspect.isclass(_r): self.register(_r(*args, **kwargs)) @@ -160,7 +184,7 @@ def register(self, reader: ImageReader): warnings.warn(f"Preferably the reader should inherit ImageReader, but got {type(reader)}.") self.readers.append(reader) - def __call__(self, filename: Union[Sequence[str], str, Path, Sequence[Path]], reader: Optional[ImageReader] = None): + def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Optional[ImageReader] = None): """ Load image file and meta data from the given filename(s). If `reader` is not specified, this class automatically chooses readers based on the @@ -176,8 +200,8 @@ def __call__(self, filename: Union[Sequence[str], str, Path, Sequence[Path]], re reader: runtime reader to load image file and meta data. """ - filename = tuple(str(s) for s in ensure_tuple(filename)) # allow Path objects - img = None + filename = tuple(f"{Path(s).expanduser()}" for s in ensure_tuple(filename)) # allow Path objects + img, err = None, [] if reader is not None: img = reader.read(filename) # runtime specified reader else: @@ -190,153 +214,208 @@ def __call__(self, filename: Union[Sequence[str], str, Path, Sequence[Path]], re try: img = reader.read(filename) except Exception as e: - logging.getLogger(self.__class__.__name__).debug( - f"{reader.__class__.__name__}: unable to load {filename}.\n" f"Error: {e}" + err.append(traceback.format_exc()) + logging.getLogger(self.__class__.__name__).debug(e, exc_info=True) + logging.getLogger(self.__class__.__name__).info( + f"{reader.__class__.__name__}: unable to load {filename}.\n" ) else: + err = [] break if img is None or reader is None: + if isinstance(filename, tuple) and len(filename) == 1: + filename = filename[0] + msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( - f"can not find a suitable reader for file: {filename}.\n" + f"{self.__class__.__name__} cannot find a suitable reader for file: {filename}.\n" " Please install the reader libraries, see also the installation instructions:\n" " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" - f" The current registered: {self.readers}.\n" + f" The current registered: {self.readers}.\n{msg}" ) + img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) - img_array = img_array.astype(self.dtype) + img_array = img_array.astype(self.dtype, copy=False) + if not isinstance(meta_data, dict): + raise ValueError("`meta_data` must be a dict.") + # make sure all elements in metadata are little endian + meta_data = switch_endianness(meta_data, "<") + if self.ensure_channel_first: + img_array = EnsureChannelFirst()(img_array, meta_data) if self.image_only: return img_array - meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0] - # make sure all elements in metadata are little endian - meta_data = switch_endianness(meta_data, "<") + meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader return img_array, meta_data class SaveImage(Transform): """ - Save transformed data into files, support NIfTI and PNG formats. - It can work for both numpy array and PyTorch Tensor in both preprocessing transform - chain and postprocessing transform chain. - The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, - where the input image name is extracted from the provided meta data dictionary. - If no meta data provided, use index from 0 as the filename prefix. - It can also save a list of PyTorch Tensor or numpy array without `batch dim`. + Save the image (in the form of torch tensor or numpy ndarray) and metadata dictionary into files. - Note: image should be channel-first shape: [C,H,W,[D]]. + The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, + where the `input_image_name` is extracted from the provided metadata dictionary. + If no metadata provided, a running index starting from 0 will be used as the filename prefix. Args: output_dir: output image directory. output_postfix: a string appended to all output file names, default to `trans`. - output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. - resample: whether to resample before saving the data array. - if saving PNG format image, based on the `spatial_shape` from metadata. - if saving NIfTI format image, based on the `original_affine` from metadata. - mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - - - NIfTI files {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - - padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + output_ext: output file extension name. + output_dtype: data type for saving data. Defaults to ``np.float32``. + resample: whether to resample image (if needed) before saving the data array, + based on the `spatial_shape` (and `original_affine`) from metadata. + mode: This option is used when ``resample=True``. Defaults to ``"nearest"``. + Depending on the writers, the possible options are - - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - - PNG files - This option is ignored. + - {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + Possible options are {``"zeros"``, ``"border"``, ``"reflection"``} + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling - [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. - it's used for PNG format only. + [0, 255] (uint8) or [0, 65535] (uint16). Default is `None` (no scaling). dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. if None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. - it's used for NIfTI format only. - output_dtype: data type for saving data. Defaults to ``np.float32``. - it's used for NIfTI format only. squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and - then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `false`, image will always be saved as (H,W,D,C). - it's used for NIfTI format only. data_root_dir: if not empty, it specifies the beginning parts of the input file's - absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from + absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from `data_root_dir` to preserve folder structure when saving in case there are files in different - folders with the same file names. for example: - input_file_name: /foo/bar/test1/image.nii, - output_postfix: seg - output_ext: nii.gz - output_dir: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg.nii.gz - separate_folder: whether to save every file in a separate folder, for example: if input filename is - `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: - `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. - print_log: whether to print log about the saved file path, etc. default to `True`. - + folders with the same file names. For example, with the following inputs: + + - input_file_name: `/foo/bar/test1/image.nii` + - output_postfix: `seg` + - output_ext: `.nii.gz` + - output_dir: `/output` + - data_root_dir: `/foo/bar` + + The output will be: /output/test1/image/image_seg.nii.gz + + separate_folder: whether to save every file in a separate folder. For example: for the input filename + `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as: + `output/image/image_seg.nii`, if `False`, saving as `output/image_seg.nii`. Default to `True`. + print_log: whether to print logs when saving. Default to `True`. + output_format: an optional string of filename extension to specify the output image writer. + see also: `monai.data.image_writer.SUPPORTED_WRITERS`. + writer: a customised `monai.data.ImageWriter` subclass to save data arrays. + if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`. + if it's a string, it's treated as a class name or dotted path (such as ``"monai.data.ITKWriter"``); + the supported built-in writer classes are ``"NibabelWriter"``, ``"ITKWriter"``, ``"PILWriter"``. + channel_dim: the index of the channel dimension. Default to `0`. + `None` to indicate no channel dimension. """ def __init__( self, - output_dir: Union[Path, str] = "./", + output_dir: PathLike = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", + output_dtype: DtypeLike = np.float32, resample: bool = True, mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, scale: Optional[int] = None, dtype: DtypeLike = np.float64, - output_dtype: DtypeLike = np.float32, squeeze_end_dims: bool = True, - data_root_dir: str = "", + data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, + output_format: str = "", + writer: Union[image_writer.ImageWriter, str, None] = None, + channel_dim: Optional[int] = 0, ) -> None: - self.saver: Union[NiftiSaver, PNGSaver] - if output_ext in {".nii.gz", ".nii"}: - self.saver = NiftiSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=GridSampleMode(mode), - padding_mode=padding_mode, - dtype=dtype, - output_dtype=output_dtype, - squeeze_end_dims=squeeze_end_dims, - data_root_dir=data_root_dir, - separate_folder=separate_folder, - print_log=print_log, - ) - elif output_ext == ".png": - self.saver = PNGSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=InterpolateMode(mode), - scale=scale, - data_root_dir=data_root_dir, - separate_folder=separate_folder, - print_log=print_log, - ) - else: - raise ValueError(f"unsupported output extension: {output_ext}.") + self.folder_layout = FolderLayout( + output_dir=output_dir, + postfix=output_postfix, + extension=output_ext, + parent=separate_folder, + makedirs=True, + data_root_dir=data_root_dir, + ) + + self.output_ext = output_ext.lower() or output_format.lower() + if isinstance(writer, str): + writer_, has_built_in = optional_import("monai.data", name=f"{writer}") # search built-in + if not has_built_in: + writer_ = locate(f"{writer}") # search dotted path + if writer_ is None: + raise ValueError(f"writer {writer} not found") + writer = writer_ # type: ignore + self.writers = image_writer.resolve_writer(self.output_ext) if writer is None else (writer,) + self.writer_obj = None + + _output_dtype = output_dtype + if self.output_ext == ".png" and _output_dtype not in (np.uint8, np.uint16): + _output_dtype = np.uint8 + if self.output_ext == ".dcm" and _output_dtype not in (np.uint8, np.uint16): + _output_dtype = np.uint8 + self.init_kwargs = {"output_dtype": _output_dtype, "scale": scale} + self.data_kwargs = {"squeeze_end_dims": squeeze_end_dims, "channel_dim": channel_dim} + self.meta_kwargs = {"resample": resample, "mode": mode, "padding_mode": padding_mode, "dtype": dtype} + self.write_kwargs = {"verbose": print_log} + self._data_index = 0 + + def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): + """ + Set the options for the underlying writer by updating the `self.*_kwargs` dictionaries. + + The arguments correspond to the following usage: + + - `writer = ImageWriter(**init_kwargs)` + - `writer.set_data_array(array, **data_kwargs)` + - `writer.set_metadata(meta_data, **meta_kwargs)` + - `writer.write(filename, **write_kwargs)` + + """ + if init_kwargs is not None: + self.init_kwargs.update(init_kwargs) + if data_kwargs is not None: + self.data_kwargs.update(data_kwargs) + if meta_kwargs is not None: + self.meta_kwargs.update(meta_kwargs) + if write_kwargs is not None: + self.write_kwargs.update(write_kwargs) def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): """ Args: - img: target data content that save into file. - meta_data: key-value pairs of meta_data corresponding to the data. - + img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`. + meta_data: key-value pairs of metadata corresponding to the data. """ - self.saver.save(img, meta_data) - - return img + subject = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) + patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None + filename = self.folder_layout.filename(subject=f"{subject}", idx=patch_index) + if meta_data and len(ensure_tuple(meta_data.get("spatial_shape", ()))) == len(img.shape): + self.data_kwargs["channel_dim"] = None + + err = [] + for writer_cls in self.writers: + try: + writer_obj = writer_cls(**self.init_kwargs) + writer_obj.set_data_array(data_array=img, **self.data_kwargs) + writer_obj.set_metadata(meta_dict=meta_data, **self.meta_kwargs) + writer_obj.write(filename, **self.write_kwargs) + self.writer_obj = writer_obj + except Exception as e: + err.append(traceback.format_exc()) + logging.getLogger(self.__class__.__name__).debug(e, exc_info=True) + logging.getLogger(self.__class__.__name__).info( + f"{writer_cls.__class__.__name__}: unable to write {filename}.\n" + ) + else: + self._data_index += 1 + return img + msg = "\n".join([f"{e}" for e in err]) + raise RuntimeError( + f"{self.__class__.__name__} cannot find a suitable writer for {filename}.\n" + " Please install the writer libraries, see also the installation instructions:\n" + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" + f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}" + ) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 764e20f838..30dedc7810 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,19 +21,16 @@ import numpy as np from monai.config import DtypeLike, KeysCollection +from monai.data import image_writer from monai.data.image_reader import ImageReader from monai.transforms.io.array import LoadImage, SaveImage from monai.transforms.transform import MapTransform from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple, ensure_tuple_rep +from monai.utils.enums import PostFix -__all__ = [ - "LoadImaged", - "LoadImageD", - "LoadImageDict", - "SaveImaged", - "SaveImageD", - "SaveImageDict", -] +__all__ = ["LoadImaged", "LoadImageD", "LoadImageDict", "SaveImaged", "SaveImageD", "SaveImageDict"] + +DEFAULT_POST_FIX = PostFix.meta() class LoadImaged(MapTransform): @@ -52,13 +49,13 @@ class LoadImaged(MapTransform): - User-specified reader in the constructor of `LoadImage`. - Readers from the last to the first in the registered list. - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), - (npz, npy -> NumpyReader), (others -> ITKReader). + (npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader). Note: - If `reader` is specified, the loader will attempt to use the specified readers and the default supported readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. - In this case, it is therefore recommended to set the most appropriate reader as + In this case, it is therefore recommended setting the most appropriate reader as the last item of the `reader` parameter. See also: @@ -73,9 +70,10 @@ def __init__( reader: Optional[Union[ImageReader, str]] = None, dtype: DtypeLike = np.float32, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, overwriting: bool = False, image_only: bool = False, + ensure_channel_first: bool = False, allow_missing_keys: bool = False, *args, **kwargs, @@ -84,11 +82,14 @@ def __init__( Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - reader: register reader to load image file and meta data, if None, still can register readers - at runtime or use the default readers. If a string of reader name provided, will construct - a reader object with the `*args` and `**kwargs` parameters, supported reader name: "NibabelReader", - "PILReader", "ITKReader", "NumpyReader". - dtype: if not None convert the loaded image data to this data type. + reader: reader to load image file and meta data + - if `reader` is None, a default set of `SUPPORTED_READERS` will be used. + - if `reader` is a string, it's treated as a class name or dotted path + (such as ``"monai.data.ITKReader"``), the supported built-in reader classes are + ``"ITKReader"``, ``"NibabelReader"``, ``"NumpyReader"``. + a reader instance will be constructed with the `*args` and `**kwargs` parameters. + - if `reader` is a reader class/instance, it will be registered to this loader accordingly. + dtype: if not None, convert the loaded image data to this data type. meta_keys: explicitly indicate the key to store the corresponding meta data dictionary. the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. @@ -96,16 +97,18 @@ def __init__( meta_key_postfix: if meta_keys is None, use `key_{postfix}` to store the metadata of the nifti image, default is `meta_dict`. The meta data is a dictionary object. For example, load nifti file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. + overwriting: whether allow overwriting existing meta data of same key. default is False, which will raise exception if encountering existing key. image_only: if True return dictionary containing just only the image volumes, otherwise return dictionary containing image data array and header dict per input key. + ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert + the image array shape to `channel first`. default to `False`. allow_missing_keys: don't raise exception if key is missing. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. """ super().__init__(keys, allow_missing_keys) - self._loader = LoadImage(reader, image_only, dtype, *args, **kwargs) + self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, *args, **kwargs) if not isinstance(meta_key_postfix, str): raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) @@ -154,68 +157,63 @@ class SaveImaged(MapTransform): Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - meta_keys: explicitly indicate the key of the corresponding meta data dictionary. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None and `key_{postfix}` was used to store the metadata in `LoadImaged`. - need the key to extract metadata to save images, default is `meta_dict`. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the meta data is a dictionary object which contains: filename, affine, original_shape, etc. - if no corresponding metadata, set to `None`. + meta_keys: explicitly indicate the key of the corresponding metadata dictionary. + For example, for data with key `image`, the metadata by default is in `image_meta_dict`. + The metadata is a dictionary contains values such as filename, original_shape. + This argument can be a sequence of string, map to the `keys`. + If `None`, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if `meta_keys` is `None`, use `key_{meta_key_postfix}` to retrieve the metadict. output_dir: output image directory. output_postfix: a string appended to all output file names, default to `trans`. output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. - resample: whether to resample before saving the data array. - if saving PNG format image, based on the `spatial_shape` from metadata. - if saving NIfTI format image, based on the `original_affine` from metadata. - mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - - - NIfTI files {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - - padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + output_dtype: data type for saving data. Defaults to ``np.float32``. + resample: whether to resample image (if needed) before saving the data array, + based on the `spatial_shape` (and `original_affine`) from metadata. + mode: This option is used when ``resample=True``. Defaults to ``"nearest"``. + Depending on the writers, the possible options are: - - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - - PNG files - This option is ignored. + - {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + Possible options are {``"zeros"``, ``"border"``, ``"reflection"``} + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling - [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. - it's used for PNG format only. + [0, 255] (uint8) or [0, 65535] (uint16). Default is `None` (no scaling). dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. if None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. - it's used for NIfTI format only. output_dtype: data type for saving data. Defaults to ``np.float32``. it's used for NIfTI format only. allow_missing_keys: don't raise exception if key is missing. squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and - then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `false`, image will always be saved as (H,W,D,C). - it's used for NIfTI format only. data_root_dir: if not empty, it specifies the beginning parts of the input file's - absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from + absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from `data_root_dir` to preserve folder structure when saving in case there are files in different - folders with the same file names. for example: - input_file_name: /foo/bar/test1/image.nii, - output_postfix: seg - output_ext: nii.gz - output_dir: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg.nii.gz - separate_folder: whether to save every file in a separate folder, for example: if input filename is - `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: - `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. - print_log: whether to print log about the saved file path, etc. default to `True`. + folders with the same file names. For example, with the following inputs: + + - input_file_name: `/foo/bar/test1/image.nii` + - output_postfix: `seg` + - output_ext: `.nii.gz` + - output_dir: `/output` + - data_root_dir: `/foo/bar` + + The output will be: /output/test1/image/image_seg.nii.gz + + separate_folder: whether to save every file in a separate folder. For example: for the input filename + `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as: + `output/image/image_seg.nii`, if `False`, saving as `output/image_seg.nii`. Default to `True`. + print_log: whether to print logs when saving. Default to `True`. + output_format: an optional string to specify the output image writer. + see also: `monai.data.image_writer.SUPPORTED_WRITERS`. + writer: a customised `monai.data.ImageWriter` subclass to save data arrays. + if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`. + if it's a string, it's treated as a class name or dotted path; + the supported built-in writer classes are ``"NibabelWriter"``, ``"ITKWriter"``, ``"PILWriter"``. """ @@ -223,7 +221,7 @@ def __init__( self, keys: KeysCollection, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, output_dir: Union[Path, str] = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", @@ -238,11 +236,13 @@ def __init__( data_root_dir: str = "", separate_folder: bool = True, print_log: bool = True, + output_format: str = "", + writer: Union[image_writer.ImageWriter, str, None] = None, ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self._saver = SaveImage( + self.saver = SaveImage( output_dir=output_dir, output_postfix=output_postfix, output_ext=output_ext, @@ -256,15 +256,20 @@ def __init__( data_root_dir=data_root_dir, separate_folder=separate_folder, print_log=print_log, + output_format=output_format, + writer=writer, ) + def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): + self.saver.set_options(init_kwargs, data_kwargs, meta_kwargs, write_kwargs) + def __call__(self, data): d = dict(data) for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): if meta_key is None and meta_key_postfix is not None: meta_key = f"{key}_{meta_key_postfix}" meta_data = d[meta_key] if meta_key is not None else None - self._saver(img=d[key], meta_data=meta_data) + self.saver(img=d[key], meta_data=meta_data) return d diff --git a/monai/transforms/nvtx.py b/monai/transforms/nvtx.py index 6dd5c3b0a3..f00145efbc 100644 --- a/monai/transforms/nvtx.py +++ b/monai/transforms/nvtx.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/post/__init__.py b/monai/transforms/post/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/post/__init__.py +++ b/monai/transforms/post/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 631947025c..6396435aa7 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,14 +18,15 @@ import numpy as np import torch -import torch.nn.functional as F -from monai.config import NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.networks import one_hot -from monai.networks.layers import GaussianFilter +from monai.networks.layers import GaussianFilter, apply_filter from monai.transforms.transform import Transform -from monai.transforms.utils import fill_holes, get_largest_connected_component_mask -from monai.utils import deprecated_arg, ensure_tuple, look_up_option +from monai.transforms.utils import fill_holes, get_largest_connected_component_mask, get_unique_labels +from monai.transforms.utils_pytorch_numpy_unification import unravel_index +from monai.utils import TransformBackends, convert_data_type, deprecated_arg, ensure_tuple, look_up_option +from monai.utils.type_conversion import convert_to_dst_type __all__ = [ "Activations", @@ -57,6 +58,8 @@ class Activations(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Optional[Callable] = None) -> None: self.sigmoid = sigmoid self.softmax = softmax @@ -66,11 +69,11 @@ def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Optional def __call__( self, - img: torch.Tensor, + img: NdarrayOrTensor, sigmoid: Optional[bool] = None, softmax: Optional[bool] = None, other: Optional[Callable] = None, - ) -> torch.Tensor: + ) -> NdarrayOrTensor: """ Args: sigmoid: whether to execute sigmoid function on model output before transform. @@ -92,17 +95,17 @@ def __call__( raise TypeError(f"other must be None or callable but is {type(other).__name__}.") # convert to float as activation must operate on float tensor - img = img.float() + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) if sigmoid or self.sigmoid: - img = torch.sigmoid(img) + img_t = torch.sigmoid(img_t) if softmax or self.softmax: - img = torch.softmax(img, dim=0) + img_t = torch.softmax(img_t, dim=0) act_func = self.other if other is None else other if act_func is not None: - img = act_func(img) - - return img + img_t = act_func(img_t) + out, *_ = convert_to_dst_type(img_t, img) + return out class AsDiscrete(Transform): @@ -118,91 +121,137 @@ class AsDiscrete(Transform): Args: argmax: whether to execute argmax function on input data before transform. Defaults to ``False``. - to_onehot: whether to convert input data into the one-hot format. - Defaults to ``False``. - num_classes: the number of classes to convert to One-Hot format. + to_onehot: if not None, convert input data into the one-hot format with specified number of classes. + Defaults to ``None``. + threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold. Defaults to ``None``. - threshold_values: whether threshold the float value to int number 0 or 1. - Defaults to ``False``. - logit_thresh: the threshold value for thresholding operation.. - Defaults to ``0.5``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. + Example: + + >>> transform = AsDiscrete(argmax=True) + >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]]))) + # [[[1.0, 1.0]]] + + >>> transform = AsDiscrete(threshold=0.6) + >>> print(transform(np.array([[[0.0, 0.5], [0.8, 3.0]]]))) + # [[[0.0, 0.0], [1.0, 1.0]]] + + >>> transform = AsDiscrete(argmax=True, to_onehot=2, threshold=0.5) + >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]]))) + # [[[0.0, 0.0]], [[1.0, 1.0]]] + + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``to_onehot`` instead. + + .. deprecated:: 0.7.0 + ``num_classes`` is deprecated, use ``to_onehot`` instead. + ``logit_thresh`` is deprecated, use ``threshold`` instead. + ``threshold_values`` is deprecated, use ``threshold`` instead. + """ - @deprecated_arg("n_classes", since="0.6") + backend = [TransformBackends.TORCH] + + @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.") + @deprecated_arg( + name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead." + ) def __init__( self, argmax: bool = False, - to_onehot: bool = False, - num_classes: Optional[int] = None, - threshold_values: bool = False, - logit_thresh: float = 0.5, + to_onehot: Optional[int] = None, + threshold: Optional[float] = None, rounding: Optional[str] = None, - n_classes: Optional[int] = None, + n_classes: Optional[int] = None, # deprecated + num_classes: Optional[int] = None, # deprecated + logit_thresh: float = 0.5, # deprecated + threshold_values: Optional[bool] = False, # deprecated ) -> None: - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes self.argmax = argmax + if isinstance(to_onehot, bool): # for backward compatibility + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + to_onehot = num_classes if to_onehot else None self.to_onehot = to_onehot - self.num_classes = num_classes - self.threshold_values = threshold_values - self.logit_thresh = logit_thresh + + if isinstance(threshold, bool): # for backward compatibility + warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") + threshold = logit_thresh if threshold else None + self.threshold = threshold + self.rounding = rounding - @deprecated_arg("n_classes", since="0.6") + @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.") + @deprecated_arg( + name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead." + ) def __call__( self, - img: torch.Tensor, + img: NdarrayOrTensor, argmax: Optional[bool] = None, - to_onehot: Optional[bool] = None, - num_classes: Optional[int] = None, - threshold_values: Optional[bool] = None, - logit_thresh: Optional[float] = None, + to_onehot: Optional[int] = None, + threshold: Optional[float] = None, rounding: Optional[str] = None, - n_classes: Optional[int] = None, - ) -> torch.Tensor: + n_classes: Optional[int] = None, # deprecated + num_classes: Optional[int] = None, # deprecated + logit_thresh: Optional[float] = None, # deprecated + threshold_values: Optional[bool] = None, # deprecated + ) -> NdarrayOrTensor: """ Args: img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`, will automatically add it. argmax: whether to execute argmax function on input data before transform. Defaults to ``self.argmax``. - to_onehot: whether to convert input data into the one-hot format. + to_onehot: if not None, convert input data into the one-hot format with specified number of classes. Defaults to ``self.to_onehot``. - num_classes: the number of classes to convert to One-Hot format. - Defaults to ``self.num_classes``. - threshold_values: whether threshold the float value to int number 0 or 1. - Defaults to ``self.threshold_values``. - logit_thresh: the threshold value for thresholding operation.. - Defaults to ``self.logit_thresh``. + threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value. + Defaults to ``self.threshold``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``to_onehot`` instead. + + .. deprecated:: 0.7.0 + ``num_classes`` is deprecated, use ``to_onehot`` instead. + ``logit_thresh`` is deprecated, use ``threshold`` instead. + ``threshold_values`` is deprecated, use ``threshold`` instead. + """ - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes + if isinstance(to_onehot, bool): + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + to_onehot = num_classes if to_onehot else None + if isinstance(threshold, bool): + warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") + threshold = logit_thresh if threshold else None + + img_t, *_ = convert_data_type(img, torch.Tensor) if argmax or self.argmax: - img = torch.argmax(img, dim=0, keepdim=True) + img_t = torch.argmax(img_t, dim=0, keepdim=True) - if to_onehot or self.to_onehot: - _nclasses = self.num_classes if num_classes is None else num_classes - if not isinstance(_nclasses, int): - raise AssertionError("One of self.num_classes or num_classes must be an integer") - img = one_hot(img, num_classes=_nclasses, dim=0) + to_onehot = self.to_onehot if to_onehot is None else to_onehot + if to_onehot is not None: + if not isinstance(to_onehot, int): + raise AssertionError("the number of classes for One-Hot must be an integer.") + img_t = one_hot(img_t, num_classes=to_onehot, dim=0) - if threshold_values or self.threshold_values: - img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh) + threshold = self.threshold if threshold is None else threshold + if threshold is not None: + img_t = img_t >= threshold rounding = self.rounding if rounding is None else rounding if rounding is not None: - rounding = look_up_option(rounding, ["torchrounding"]) - img = torch.round(img) + look_up_option(rounding, ["torchrounding"]) + img_t = torch.round(img_t) - return img.float() + img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float) + return img class KeepLargestConnectedComponent(Transform): @@ -211,21 +260,19 @@ class KeepLargestConnectedComponent(Transform): This transform can be used as a post-processing step to clean up over-segment areas in model output. The input is assumed to be a channel-first PyTorch Tensor: - 1) With shape (1, spatial_dim1[, spatial_dim2, ...]) and the values correspond to expected labels. - 2) With shape (C, spatial_dim1[, spatial_dim2, ...]) and the values should be 0, 1 on each labels. - - Note: - For single channel data, 0 will be treated as background and the over-segment pixels will be set to 0. - For one-hot data, the over-segment pixels will be set to 0 in its channel. + 1) For not OneHot format data, the values correspond to expected labels, + 0 will be treated as background and the over-segment pixels will be set to 0. + 2) For OneHot format data, the values should be 0, 1 on each labels, + the over-segment pixels will be set to 0 in its channel. For example: - Use KeepLargestConnectedComponent with applied_labels=[1], connectivity=1:: + Use with applied_labels=[1], is_onehot=False, connectivity=1:: [1, 0, 0] [0, 0, 0] [0, 1, 1] => [0, 1 ,1] [0, 1, 1] [0, 1, 1] - Use KeepLargestConnectedComponent with applied_labels[1, 2], independent=False, connectivity=1:: + Use with applied_labels=[1, 2], is_onehot=False, independent=False, connectivity=1:: [0, 0, 1, 0 ,0] [0, 0, 1, 0 ,0] [0, 2, 1, 1 ,1] [0, 2, 1, 1 ,1] @@ -233,7 +280,7 @@ class KeepLargestConnectedComponent(Transform): [1, 2, 0, 1 ,0] [1, 2, 0, 0 ,0] [2, 2, 0, 0 ,2] [2, 2, 0, 0 ,0] - Use KeepLargestConnectedComponent with applied_labels[1, 2], independent=True, connectivity=1:: + Use with applied_labels=[1, 2], is_onehot=False, independent=True, connectivity=1:: [0, 0, 1, 0 ,0] [0, 0, 1, 0 ,0] [0, 2, 1, 1 ,1] [0, 2, 1, 1 ,1] @@ -241,7 +288,7 @@ class KeepLargestConnectedComponent(Transform): [1, 2, 0, 1 ,0] [0, 2, 0, 0 ,0] [2, 2, 0, 0 ,2] [2, 2, 0, 0 ,0] - Use KeepLargestConnectedComponent with applied_labels[1, 2], independent=False, connectivity=2:: + Use with applied_labels=[1, 2], is_onehot=False, independent=False, connectivity=2:: [0, 0, 1, 0 ,0] [0, 0, 1, 0 ,0] [0, 2, 1, 1 ,1] [0, 2, 1, 1 ,1] @@ -251,68 +298,74 @@ class KeepLargestConnectedComponent(Transform): """ + backend = [TransformBackends.NUMPY] + def __init__( - self, applied_labels: Union[Sequence[int], int], independent: bool = True, connectivity: Optional[int] = None + self, + applied_labels: Optional[Union[Sequence[int], int]] = None, + is_onehot: Optional[bool] = None, + independent: bool = True, + connectivity: Optional[int] = None, ) -> None: """ Args: - applied_labels: Labels for applying the connected component on. - If only one channel. The pixel whose value is not in this list will remain unchanged. - If the data is in one-hot format, this is used to determine what channels to apply. - independent: consider several labels as a whole or independent, default is `True`. - Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case - you want this "independent" to be specified as False. + applied_labels: Labels for applying the connected component analysis on. + If given, voxels whose value is in this list will be analyzed. + If `None`, all non-zero values will be analyzed. + is_onehot: if `True`, treat the input data as OneHot format data, otherwise, not OneHot format data. + default to None, which treats multi-channel data as OneHot and single channel data as not OneHot. + independent: whether to treat ``applied_labels`` as a union of foreground labels. + If ``True``, the connected component analysis will be performed on each foreground label independently + and return the intersection of the largest components. + If ``False``, the analysis will be performed on the union of foreground labels. + default is `True`. connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. If ``None``, a full - connectivity of ``input.ndim`` is used. + connectivity of ``input.ndim`` is used. for more details: + https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label. + """ super().__init__() - self.applied_labels = ensure_tuple(applied_labels) + self.applied_labels = ensure_tuple(applied_labels) if applied_labels is not None else None + self.is_onehot = is_onehot self.independent = independent self.connectivity = connectivity - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: shape must be (C, spatial_dim1[, spatial_dim2, ...]). Returns: - A PyTorch Tensor with shape (C, spatial_dim1[, spatial_dim2, ...]). + An array with shape (C, spatial_dim1[, spatial_dim2, ...]). """ - if img.shape[0] == 1: - img = torch.squeeze(img, dim=0) - - if self.independent: - for i in self.applied_labels: - foreground = (img == i).type(torch.uint8) - mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[foreground != mask] = 0 - else: - foreground = torch.zeros_like(img) - for i in self.applied_labels: - foreground += (img == i).type(torch.uint8) - mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[foreground != mask] = 0 - - output = torch.unsqueeze(img, dim=0) + is_onehot = img.shape[0] > 1 if self.is_onehot is None else self.is_onehot + if self.applied_labels is not None: + applied_labels = self.applied_labels else: - # one-hot data is assumed to have binary value in each channel - if self.independent: - for i in self.applied_labels: - foreground = img[i, ...].type(torch.uint8) - mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[i, ...][foreground != mask] = 0 - else: - applied_img = img[self.applied_labels, ...].type(torch.uint8) - foreground = torch.any(applied_img, dim=0) - mask = get_largest_connected_component_mask(foreground, self.connectivity) - background_mask = torch.unsqueeze(foreground != mask, dim=0) - background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0) - applied_img[background_mask] = 0 - img[self.applied_labels, ...] = applied_img.type(img.type()) - output = img + applied_labels = tuple(get_unique_labels(img, is_onehot, discard=0)) - return output + if self.independent: + for i in applied_labels: + foreground = img[i] > 0 if is_onehot else img[0] == i + mask = get_largest_connected_component_mask(foreground, self.connectivity) + if is_onehot: + img[i][foreground != mask] = 0 + else: + img[0][foreground != mask] = 0 + return img + if not is_onehot: # not one-hot, union of labels + labels, *_ = convert_to_dst_type(applied_labels, dst=img, wrap_sequence=True) + foreground = (img[..., None] == labels).any(-1)[0] + mask = get_largest_connected_component_mask(foreground, self.connectivity) + img[0][foreground != mask] = 0 + return img + # one-hot, union of labels + foreground = (img[applied_labels, ...] == 1).any(0) + mask = get_largest_connected_component_mask(foreground, self.connectivity) + for i in applied_labels: + img[i][foreground != mask] = 0 + return img class LabelFilter: @@ -333,6 +386,8 @@ class LabelFilter: [7, 8, 9] [0, 0, 9] """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, applied_labels: Union[Iterable[int], int]) -> None: """ Initialize the LabelFilter class with the labels to filter on. @@ -342,7 +397,7 @@ def __init__(self, applied_labels: Union[Iterable[int], int]) -> None: """ self.applied_labels = ensure_tuple(applied_labels) - def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Filter the image on the `applied_labels`. @@ -355,13 +410,18 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor: Returns: Pytorch tensor or numpy array of the same shape as the input. """ - if isinstance(img, np.ndarray): - return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0)) + if not isinstance(img, (np.ndarray, torch.Tensor)): + raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + if isinstance(img, torch.Tensor): - img_arr = img.detach().cpu().numpy() - img_arr = self(img_arr) - return torch.as_tensor(img_arr, device=img.device) - raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + if hasattr(torch, "isin"): # `isin` is new in torch 1.10.0 + appl_lbls = torch.as_tensor(self.applied_labels, device=img.device) + return torch.where(torch.isin(img, appl_lbls), img, torch.tensor(0.0).to(img)) + else: + out = self(img.detach().cpu().numpy()) + out, *_ = convert_to_dst_type(out, img) + return out + return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0)) class FillHoles(Transform): @@ -404,6 +464,8 @@ class FillHoles(Transform): The background label near label 2 and 3 is not fully enclosed and therefore not filled. """ + backend = [TransformBackends.NUMPY] + def __init__( self, applied_labels: Optional[Union[Iterable[int], int]] = None, connectivity: Optional[int] = None ) -> None: @@ -419,7 +481,7 @@ def __init__( self.applied_labels = ensure_tuple(applied_labels) if applied_labels else None self.connectivity = connectivity - def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Fill the holes in the provided image. @@ -435,18 +497,17 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor: Returns: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]]. """ - if isinstance(img, np.ndarray): - return fill_holes(img, self.applied_labels, self.connectivity) - if isinstance(img, torch.Tensor): - img_arr = img.detach().cpu().numpy() - img_arr = self(img_arr) - return torch.as_tensor(img_arr, device=img.device) - raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + if not isinstance(img, (np.ndarray, torch.Tensor)): + raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + img_np, *_ = convert_data_type(img, np.ndarray) + out_np: np.ndarray = fill_holes(img_np, self.applied_labels, self.connectivity) + out, *_ = convert_to_dst_type(out_np, img) + return out class LabelToContour(Transform): """ - Return the contour of binary input images that only compose of 0 and 1, with Laplace kernel + Return the contour of binary input images that only compose of 0 and 1, with Laplacian kernel set as default for edge detection. Typical usage is to plot the edge of label or segmentation output. Args: @@ -457,12 +518,14 @@ class LabelToContour(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, kernel_type: str = "Laplace") -> None: if kernel_type != "Laplace": raise NotImplementedError('Currently only kernel_type="Laplace" is supported.') self.kernel_type = kernel_type - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: torch tensor data to extract the contour, with shape: [channels, height, width[, depth]] @@ -478,25 +541,41 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: ideally the edge should be thin enough, but now it has a thickness. """ - channels = img.shape[0] - img_ = img.unsqueeze(0) - if img.ndimension() == 3: - kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32, device=img.device) - kernel = kernel.repeat(channels, 1, 1, 1) - contour_img = F.conv2d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) - elif img.ndimension() == 4: - kernel = -1 * torch.ones(3, 3, 3, dtype=torch.float32, device=img.device) - kernel[1, 1, 1] = 26 - kernel = kernel.repeat(channels, 1, 1, 1, 1) - contour_img = F.conv3d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) + img_: torch.Tensor = convert_data_type(img, torch.Tensor)[0] + spatial_dims = len(img_.shape) - 1 + img_ = img_.unsqueeze(0) # adds a batch dim + if spatial_dims == 2: + kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32) + elif spatial_dims == 3: + kernel = -1.0 * torch.ones(3, 3, 3, dtype=torch.float32) + kernel[1, 1, 1] = 26.0 else: - raise ValueError(f"Unsupported img dimension: {img.ndimension()}, available options are [4, 5].") - + raise ValueError(f"{self.__class__} can only handle 2D or 3D images.") + contour_img = apply_filter(img_, kernel) contour_img.clamp_(min=0.0, max=1.0) - return contour_img.squeeze(0) + output, *_ = convert_to_dst_type(contour_img.squeeze(0), img) + return output -class MeanEnsemble(Transform): +class Ensemble: + @staticmethod + def get_stacked_torch(img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> torch.Tensor: + """Get either a sequence or single instance of np.ndarray/torch.Tensor. Return single torch.Tensor.""" + if isinstance(img, Sequence) and isinstance(img[0], np.ndarray): + img = [torch.as_tensor(i) for i in img] + elif isinstance(img, np.ndarray): + img = torch.as_tensor(img) + out: torch.Tensor = torch.stack(img) if isinstance(img, Sequence) else img # type: ignore + return out + + @staticmethod + def post_convert(img: torch.Tensor, orig_img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> NdarrayOrTensor: + orig_img_ = orig_img[0] if isinstance(orig_img, Sequence) else orig_img + out, *_ = convert_to_dst_type(img, orig_img_) + return out + + +class MeanEnsemble(Ensemble, Transform): """ Execute mean ensemble on the input data. The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], @@ -519,11 +598,13 @@ class MeanEnsemble(Transform): """ - def __init__(self, weights: Optional[Union[Sequence[float], torch.Tensor, np.ndarray]] = None) -> None: + backend = [TransformBackends.TORCH] + + def __init__(self, weights: Optional[Union[Sequence[float], NdarrayOrTensor]] = None) -> None: self.weights = torch.as_tensor(weights, dtype=torch.float) if weights is not None else None - def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: - img_ = torch.stack(img) if isinstance(img, (tuple, list)) else torch.as_tensor(img) + def __call__(self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> NdarrayOrTensor: + img_ = self.get_stacked_torch(img) if self.weights is not None: self.weights = self.weights.to(img_.device) shape = tuple(self.weights.shape) @@ -533,10 +614,11 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te img_ = img_ * weights / weights.mean(dim=0, keepdim=True) - return torch.mean(img_, dim=0) + out_pt = torch.mean(img_, dim=0) + return self.post_convert(out_pt, img) -class VoteEnsemble(Transform): +class VoteEnsemble(Ensemble, Transform): """ Execute vote ensemble on the input data. The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], @@ -556,11 +638,14 @@ class VoteEnsemble(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, num_classes: Optional[int] = None) -> None: self.num_classes = num_classes - def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: - img_ = torch.stack(img) if isinstance(img, (tuple, list)) else torch.as_tensor(img) + def __call__(self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> NdarrayOrTensor: + img_ = self.get_stacked_torch(img) + if self.num_classes is not None: has_ch_dim = True if img_.ndimension() > 1 and img_.shape[1] > 1: @@ -575,9 +660,11 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te if self.num_classes is not None: # if not One-Hot, use "argmax" to vote the most common class - return torch.argmax(img_, dim=0, keepdim=has_ch_dim) - # for One-Hot data, round the float number to 0 or 1 - return torch.round(img_) + out_pt = torch.argmax(img_, dim=0, keepdim=has_ch_dim) + else: + # for One-Hot data, round the float number to 0 or 1 + out_pt = torch.round(img_) + return self.post_convert(out_pt, img) class ProbNMS(Transform): @@ -611,6 +698,8 @@ class ProbNMS(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, spatial_dims: int = 2, @@ -627,9 +716,9 @@ def __init__( self.prob_threshold = prob_threshold if isinstance(box_size, int): self.box_size = np.asarray([box_size] * spatial_dims) + elif len(box_size) != spatial_dims: + raise ValueError("the sequence length of box_size should be the same as spatial_dims.") else: - if len(box_size) != spatial_dims: - raise ValueError("the sequence length of box_size should be the same as spatial_dims.") self.box_size = np.asarray(box_size) if self.box_size.min() <= 0: raise ValueError("box_size should be larger than 0.") @@ -637,10 +726,7 @@ def __init__( self.box_lower_bd = self.box_size // 2 self.box_upper_bd = self.box_size - self.box_lower_bd - def __call__( - self, - prob_map: Union[np.ndarray, torch.Tensor], - ): + def __call__(self, prob_map: NdarrayOrTensor): """ prob_map: the input probabilities map, it must have shape (H[, W, ...]). """ @@ -649,24 +735,19 @@ def __call__( prob_map = torch.as_tensor(prob_map, dtype=torch.float) self.filter.to(prob_map) prob_map = self.filter(prob_map) - else: - if not isinstance(prob_map, torch.Tensor): - prob_map = prob_map.copy() - - if isinstance(prob_map, torch.Tensor): - prob_map = prob_map.detach().cpu().numpy() prob_map_shape = prob_map.shape outputs = [] - while np.max(prob_map) > self.prob_threshold: - max_idx = np.unravel_index(prob_map.argmax(), prob_map_shape) - prob_max = prob_map[max_idx] - max_idx_arr = np.asarray(max_idx) - outputs.append([prob_max] + list(max_idx_arr)) - - idx_min_range = (max_idx_arr - self.box_lower_bd).clip(0, None) - idx_max_range = (max_idx_arr + self.box_upper_bd).clip(None, prob_map_shape) + while prob_map.max() > self.prob_threshold: + max_idx = unravel_index(prob_map.argmax(), prob_map_shape) + prob_max = prob_map[tuple(max_idx)] + max_idx = max_idx.cpu().numpy() if isinstance(max_idx, torch.Tensor) else max_idx + prob_max = prob_max.item() if isinstance(prob_max, torch.Tensor) else prob_max + outputs.append([prob_max] + list(max_idx)) + + idx_min_range = (max_idx - self.box_lower_bd).clip(0, None) + idx_max_range = (max_idx + self.box_upper_bd).clip(None, prob_map_shape) # for each dimension, set values during index ranges to 0 slices = tuple(slice(idx_min_range[i], idx_max_range[i]) for i in range(self.spatial_dims)) prob_map[slices] = 0 diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 2fc3993e3e..00ffe7edf7 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,10 +19,9 @@ from copy import deepcopy from typing import Any, Callable, Dict, Hashable, Iterable, List, Mapping, Optional, Sequence, Union -import numpy as np import torch -from monai.config import KeysCollection, NdarrayTensor +from monai.config.type_definitions import KeysCollection, NdarrayOrTensor, PathLike from monai.data.csv_saver import CSVSaver from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.array import ( @@ -40,7 +39,7 @@ from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode from monai.utils import deprecated_arg, ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys +from monai.utils.enums import PostFix __all__ = [ "ActivationsD", @@ -50,6 +49,8 @@ "AsDiscreteDict", "AsDiscreted", "Ensembled", + "EnsembleD", + "EnsembleDict", "FillHolesD", "FillHolesDict", "FillHolesd", @@ -79,6 +80,8 @@ "VoteEnsembled", ] +DEFAULT_POST_FIX = PostFix.meta() + class Activationsd(MapTransform): """ @@ -86,6 +89,8 @@ class Activationsd(MapTransform): Add activation layers to the input data specified by `keys`. """ + backend = Activations.backend + def __init__( self, keys: KeysCollection, @@ -114,7 +119,7 @@ def __init__( self.other = ensure_tuple_rep(other, len(self.keys)) self.converter = Activations() - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, sigmoid, softmax, other in self.key_iterator(d, self.sigmoid, self.softmax, self.other): d[key] = self.converter(d[key], sigmoid, softmax, other) @@ -126,18 +131,26 @@ class AsDiscreted(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`. """ - @deprecated_arg("n_classes", since="0.6") + backend = AsDiscrete.backend + + @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.") + @deprecated_arg( + name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead." + ) def __init__( self, keys: KeysCollection, argmax: Union[Sequence[bool], bool] = False, - to_onehot: Union[Sequence[bool], bool] = False, - num_classes: Optional[Union[Sequence[int], int]] = None, - threshold_values: Union[Sequence[bool], bool] = False, - logit_thresh: Union[Sequence[float], float] = 0.5, + to_onehot: Union[Sequence[Optional[int]], Optional[int]] = None, + threshold: Union[Sequence[Optional[float]], Optional[float]] = None, rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, - n_classes: Optional[int] = None, + n_classes: Optional[Union[Sequence[int], int]] = None, # deprecated + num_classes: Optional[Union[Sequence[int], int]] = None, # deprecated + logit_thresh: Union[Sequence[float], float] = 0.5, # deprecated + threshold_values: Union[Sequence[bool], bool] = False, # deprecated ) -> None: """ Args: @@ -145,46 +158,55 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` argmax: whether to execute argmax function on input data before transform. it also can be a sequence of bool, each element corresponds to a key in ``keys``. - to_onehot: whether to convert input data into the one-hot format. Defaults to False. - it also can be a sequence of bool, each element corresponds to a key in ``keys``. - num_classes: the number of classes to convert to One-Hot format. it also can be a - sequence of int, each element corresponds to a key in ``keys``. - threshold_values: whether threshold the float value to int number 0 or 1, default is False. - it also can be a sequence of bool, each element corresponds to a key in ``keys``. - logit_thresh: the threshold value for thresholding operation, default is 0.5. - it also can be a sequence of float, each element corresponds to a key in ``keys``. + to_onehot: if not None, convert input data into the one-hot format with specified number of classes. + defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``. + threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold value. + defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. it also can be a sequence of str or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``to_onehot`` instead. + + .. deprecated:: 0.7.0 + ``num_classes`` is deprecated, use ``to_onehot`` instead. + ``logit_thresh`` is deprecated, use ``threshold`` instead. + ``threshold_values`` is deprecated, use ``threshold`` instead. + """ - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) - self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) - self.num_classes = ensure_tuple_rep(num_classes, len(self.keys)) - self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys)) - self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) + to_onehot_ = ensure_tuple_rep(to_onehot, len(self.keys)) + num_classes = ensure_tuple_rep(num_classes, len(self.keys)) + self.to_onehot = [] + for flag, val in zip(to_onehot_, num_classes): + if isinstance(flag, bool): + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + self.to_onehot.append(val if flag else None) + else: + self.to_onehot.append(flag) + + threshold_ = ensure_tuple_rep(threshold, len(self.keys)) + logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) + self.threshold = [] + for flag, val in zip(threshold_, logit_thresh): + if isinstance(flag, bool): + warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") + self.threshold.append(val if flag else None) + else: + self.threshold.append(flag) + self.rounding = ensure_tuple_rep(rounding, len(self.keys)) self.converter = AsDiscrete() - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key, argmax, to_onehot, num_classes, threshold_values, logit_thresh, rounding in self.key_iterator( - d, self.argmax, self.to_onehot, self.num_classes, self.threshold_values, self.logit_thresh, self.rounding + for key, argmax, to_onehot, threshold, rounding in self.key_iterator( + d, self.argmax, self.to_onehot, self.threshold, self.rounding ): - d[key] = self.converter( - d[key], - argmax, - to_onehot, - num_classes, - threshold_values, - logit_thresh, - rounding, - ) + d[key] = self.converter(d[key], argmax, to_onehot, threshold, rounding) return d @@ -193,10 +215,13 @@ class KeepLargestConnectedComponentd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.KeepLargestConnectedComponent`. """ + backend = KeepLargestConnectedComponent.backend + def __init__( self, keys: KeysCollection, - applied_labels: Union[Sequence[int], int], + applied_labels: Optional[Union[Sequence[int], int]] = None, + is_onehot: Optional[bool] = None, independent: bool = True, connectivity: Optional[int] = None, allow_missing_keys: bool = False, @@ -205,22 +230,29 @@ def __init__( Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - applied_labels: Labels for applying the connected component on. - If only one channel. The pixel whose value is not in this list will remain unchanged. - If the data is in one-hot format, this is the channel indices to apply transform. - independent: consider several labels as a whole or independent, default is `True`. - Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case - you want this "independent" to be specified as False. + applied_labels: Labels for applying the connected component analysis on. + If given, voxels whose value is in this list will be analyzed. + If `None`, all non-zero values will be analyzed. + is_onehot: if `True`, treat the input data as OneHot format data, otherwise, not OneHot format data. + default to None, which treats multi-channel data as OneHot and single channel data as not OneHot. + independent: whether to treat ``applied_labels`` as a union of foreground labels. + If ``True``, the connected component analysis will be performed on each foreground label independently + and return the intersection of the largest components. + If ``False``, the analysis will be performed on the union of foreground labels. + default is `True`. connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. If ``None``, a full - connectivity of ``input.ndim`` is used. + connectivity of ``input.ndim`` is used. for more details: + https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = KeepLargestConnectedComponent(applied_labels, independent, connectivity) + self.converter = KeepLargestConnectedComponent( + applied_labels=applied_labels, is_onehot=is_onehot, independent=independent, connectivity=connectivity + ) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -232,11 +264,10 @@ class LabelFilterd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.LabelFilter`. """ + backend = LabelFilter.backend + def __init__( - self, - keys: KeysCollection, - applied_labels: Union[Sequence[int], int], - allow_missing_keys: bool = False, + self, keys: KeysCollection, applied_labels: Union[Sequence[int], int], allow_missing_keys: bool = False ) -> None: """ Args: @@ -249,7 +280,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = LabelFilter(applied_labels) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -261,6 +292,8 @@ class FillHolesd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.FillHoles`. """ + backend = FillHoles.backend + def __init__( self, keys: KeysCollection, @@ -284,7 +317,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = FillHoles(applied_labels=applied_labels, connectivity=connectivity) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -296,6 +329,8 @@ class LabelToContourd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.LabelToContour`. """ + backend = LabelToContour.backend + def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace", allow_missing_keys: bool = False) -> None: """ Args: @@ -308,7 +343,7 @@ def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace", allow_mis super().__init__(keys, allow_missing_keys) self.converter = LabelToContour(kernel_type=kernel_type) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -321,10 +356,12 @@ class Ensembled(MapTransform): """ + backend = list(set(VoteEnsemble.backend) & set(MeanEnsemble.backend)) + def __init__( self, keys: KeysCollection, - ensemble: Callable[[Union[Sequence[torch.Tensor], torch.Tensor]], torch.Tensor], + ensemble: Callable[[Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]], NdarrayOrTensor], output_key: Optional[str] = None, allow_missing_keys: bool = False, ) -> None: @@ -350,14 +387,16 @@ def __init__( raise ValueError("Incompatible values: len(self.keys) > 1 and output_key=None.") self.output_key = output_key if output_key is not None else self.keys[0] - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - items: Union[List[torch.Tensor], torch.Tensor] - if len(self.keys) == 1: + items: Union[List[NdarrayOrTensor], NdarrayOrTensor] + if len(self.keys) == 1 and self.keys[0] in d: items = d[self.keys[0]] else: items = [d[key] for key in self.key_iterator(d)] - d[self.output_key] = self.ensemble(items) + + if len(items) > 0: + d[self.output_key] = self.ensemble(items) return d @@ -367,11 +406,13 @@ class MeanEnsembled(Ensembled): Dictionary-based wrapper of :py:class:`monai.transforms.MeanEnsemble`. """ + backend = MeanEnsemble.backend + def __init__( self, keys: KeysCollection, output_key: Optional[str] = None, - weights: Optional[Union[Sequence[float], torch.Tensor, np.ndarray]] = None, + weights: Optional[Union[Sequence[float], NdarrayOrTensor]] = None, ) -> None: """ Args: @@ -400,6 +441,8 @@ class VoteEnsembled(Ensembled): Dictionary-based wrapper of :py:class:`monai.transforms.VoteEnsemble`. """ + backend = VoteEnsemble.backend + def __init__( self, keys: KeysCollection, output_key: Optional[str] = None, num_classes: Optional[int] = None ) -> None: @@ -448,6 +491,8 @@ class ProbNMSd(MapTransform): """ + backend = ProbNMS.backend + def __init__( self, keys: KeysCollection, @@ -459,13 +504,10 @@ def __init__( ) -> None: super().__init__(keys, allow_missing_keys) self.prob_nms = ProbNMS( - spatial_dims=spatial_dims, - sigma=sigma, - prob_threshold=prob_threshold, - box_size=box_size, + spatial_dims=spatial_dims, sigma=sigma, prob_threshold=prob_threshold, box_size=box_size ) - def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]): + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]): d = dict(data) for key in self.key_iterator(d): d[key] = self.prob_nms(d[key]) @@ -475,28 +517,25 @@ def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]): class Invertd(MapTransform): """ Utility transform to automatically invert the previously applied transforms. - When applying preprocessing transforms on a orig_key(like: `image`, `label`, etc.), we record the context - information of applied transforms in a dictionary in the input data dictionary with the key - "{orig_key}_transforms". This transform will extract the transform context information of `orig_keys` - then invert the transforms(got from this context information) on the `keys` data. - Typical usage is to invert the preprocessing transforms(applied on input `image`) on the model `pred` data. - The output of the inverted data and metadata will be stored at `keys` and `meta_keys` respectively. - To correctly invert the transforms, the information of the previously applied transforms should be - available at `orig_keys`, and the original metadata at `orig_meta_keys`. - (`meta_key_postfix` is an optional string to conveniently construct "meta_keys" and/or "orig_meta_keys".) + Taking the ``transform`` previously applied on ``orig_keys``, this ``Invertd`` will apply the inverse of it + to the data stored at ``keys``. ``Invertd``'s output will also include a copy of the metadata + dictionary (originally from ``orig_meta_keys``), with the relevant fields inverted and stored at ``meta_keys``. + + A typical usage is to apply the inverse of the preprocessing on input ``image`` to the model ``pred``. A detailed usage example is available in the tutorial: https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py Note: - According to the `collate_fn`, this transform may return a list of Tensor without batch dim, - thus some following transforms may not support a list of Tensor, and users can leverage the - `post_func` arg for basic processing logic. - This transform needs to extract the context information of applied transforms and the meta data - dictionary from the input data dictionary, then use some numpy arrays in them to computes the inverse - logic, so please don't move `data["{orig_key}_transforms"]` and `data["{orig_meta_key}"]` to GPU device. + - The output of the inverted data and metadata will be stored at ``keys`` and ``meta_keys`` respectively. + - To correctly invert the transforms, the information of the previously applied transforms should be + available at ``{orig_keys}_transforms``, and the original metadata at ``orig_meta_keys``. + (``meta_key_postfix`` is an optional string to conveniently construct "meta_keys" and/or "orig_meta_keys".) + see also: :py:class:`monai.transforms.TraceableTransform`. + - The transform will not change the content in ``orig_keys`` and ``orig_meta_key``. + These keys are only used to represent the data status of ``key`` before inverting. """ @@ -507,7 +546,7 @@ def __init__( orig_keys: KeysCollection, meta_keys: Optional[KeysCollection] = None, orig_meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, nearest_interp: Union[bool, Sequence[bool]] = True, to_tensor: Union[bool, Sequence[bool]] = True, device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu", @@ -516,37 +555,32 @@ def __init__( ) -> None: """ Args: - keys: the key of expected data in the dict, invert transforms on it, in-place operation. - it also can be a list of keys, will invert transform for each of them, like: ["pred", "pred_class2"]. - transform: the previous callable transform that applied on input data. - orig_keys: the key of the original input data in the dict. will get the applied transform information - for this input data, then invert them for the expected data with `keys`. - It can also be a list of keys, each matches to the `keys` data. - meta_keys: explicitly indicate the key for the inverted meta data dictionary. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `{key}_{meta_key_postfix}`. - orig_meta_keys: the key of the meta data of original input data, will get the `affine`, `data_shape`, etc. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`. - meta data will also be inverted and stored in `meta_keys`. - meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to to fetch the - meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. - default is `meta_dict`, the meta data is a dictionary object. - For example, to handle orig_key `image`, read/write `affine` matrices from the - metadata `image_meta_dict` dictionary's `affine` field. - the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}". + keys: the key of expected data in the dict, the inverse of ``transforms`` will be applied on it in-place. + It also can be a list of keys, will apply the inverse transform respectively. + transform: the transform applied to ``orig_key``, its inverse will be applied on ``key``. + orig_keys: the key of the original input data in the dict. + the transform trace information of ``transforms`` should be stored at ``{orig_keys}_transforms``. + It can also be a list of keys, each matches the ``keys``. + meta_keys: The key to output the inverted meta data dictionary. + The meta data is a dictionary optionally containing: filename, original_shape. + It can be a sequence of strings, maps to ``keys``. + If None, will try to create a meta data dict with the default key: `{key}_{meta_key_postfix}`. + orig_meta_keys: the key of the meta data of original input data. + The meta data is a dictionary optionally containing: filename, original_shape. + It can be a sequence of strings, maps to the `keys`. + If None, will try to create a meta data dict with the default key: `{orig_key}_{meta_key_postfix}`. + This meta data dict will also be included in the inverted dict, stored in `meta_keys`. + meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to fetch the + meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. Default: ``"meta_dict"``. nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms, default to `True`. If `False`, use the same interpolation mode as the original transform. - it also can be a list of bool, each matches to the `keys` data. + It also can be a list of bool, each matches to the `keys` data. to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. - it also can be a list of bool, each matches to the `keys` data. + It also can be a list of bool, each matches to the `keys` data. device: if converted to Tensor, move the inverted results to target device before `post_func`, - default to "cpu", it also can be a list of string or `torch.device`, - each matches to the `keys` data. + default to "cpu", it also can be a list of string or `torch.device`, each matches to the `keys` data. post_func: post processing for the inverted data, should be a callable function. - it also can be a list of callable, each matches to the `keys` data. + It also can be a list of callable, each matches to the `keys` data. allow_missing_keys: don't raise exception if key is missing. """ @@ -589,7 +623,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: self.device, self.post_func, ): - transform_key = f"{orig_key}{InverseKeys.KEY_SUFFIX}" + transform_key = InvertibleTransform.trace_key(orig_key) if transform_key not in d: warnings.warn(f"transform info of `{orig_key}` is not available or no InvertibleTransform applied.") continue @@ -597,21 +631,16 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: transform_info = d[transform_key] if nearest_interp: transform_info = convert_inverse_interp_mode( - trans_info=deepcopy(transform_info), - mode="nearest", - align_corners=None, + trans_info=deepcopy(transform_info), mode="nearest", align_corners=None ) input = d[key] if isinstance(input, torch.Tensor): input = input.detach() - # construct the input dict data for BatchInverseTransform - input_dict = { - orig_key: input, - transform_key: transform_info, - } + + # construct the input dict data + input_dict = {orig_key: input, transform_key: transform_info} orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}" - meta_key = meta_key or f"{key}_{meta_key_postfix}" if orig_meta_key in d: input_dict[orig_meta_key] = d[orig_meta_key] @@ -620,8 +649,10 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: # save the inverted data d[key] = post_func(self._totensor(inverted[orig_key]).to(device) if to_tensor else inverted[orig_key]) + # save the inverted meta dict if orig_meta_key in d: + meta_key = meta_key or f"{key}_{meta_key_postfix}" d[meta_key] = inverted.get(orig_meta_key) return d @@ -637,10 +668,11 @@ def __init__( self, keys: KeysCollection, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, saver: Optional[CSVSaver] = None, - output_dir: str = "./", + output_dir: PathLike = "./", filename: str = "predictions.csv", + delimiter: str = ",", overwrite: bool = True, flush: bool = True, allow_missing_keys: bool = False, @@ -664,6 +696,8 @@ def __init__( the saver must provide `save(data, meta_data)` and `finalize()` APIs. output_dir: if `saver=None`, specify the directory to save the CSV file. filename: if `saver=None`, specify the name of the saved CSV file. + delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`. + to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter. overwrite: if `saver=None`, indicate whether to overwriting existing CSV file content, if True, will clear the file before saving. otherwise, will append new content to the CSV file. flush: if `saver=None`, indicate whether to write the cache data to CSV file immediately @@ -675,7 +709,9 @@ def __init__( super().__init__(keys, allow_missing_keys) if len(self.keys) != 1: raise ValueError("only 1 key is allowed when saving the classification result.") - self.saver = saver or CSVSaver(output_dir, filename, overwrite, flush) + self.saver = saver or CSVSaver( + output_dir=output_dir, filename=filename, overwrite=overwrite, flush=flush, delimiter=delimiter + ) self.flush = flush self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) @@ -712,3 +748,4 @@ def get_saver(self): ProbNMSD = ProbNMSDict = ProbNMSd SaveClassificationD = SaveClassificationDict = SaveClassificationd VoteEnsembleD = VoteEnsembleDict = VoteEnsembled +EnsembleD = EnsembleDict = Ensembled diff --git a/monai/transforms/smooth_field/__init__.py b/monai/transforms/smooth_field/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/transforms/smooth_field/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/smooth_field/array.py b/monai/transforms/smooth_field/array.py new file mode 100644 index 0000000000..f581687ea5 --- /dev/null +++ b/monai/transforms/smooth_field/array.py @@ -0,0 +1,459 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transforms using a smooth spatial field generated by interpolating from smaller randomized fields.""" + +from typing import Any, Optional, Sequence, Union + +import numpy as np +import torch +from torch.nn.functional import grid_sample, interpolate + +import monai +from monai.config.type_definitions import NdarrayOrTensor +from monai.networks.utils import meshgrid_ij +from monai.transforms.transform import Randomizable, RandomizableTransform +from monai.transforms.utils_pytorch_numpy_unification import moveaxis +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode +from monai.utils.enums import TransformBackends +from monai.utils.module import look_up_option +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor + +__all__ = ["SmoothField", "RandSmoothFieldAdjustContrast", "RandSmoothFieldAdjustIntensity", "RandSmoothDeform"] + + +class SmoothField(Randomizable): + """ + Generate a smooth field array by defining a smaller randomized field and then reinterpolating to the desired size. + + This exploits interpolation to create a smoothly varying field used for other applications. An initial randomized + field is defined with `rand_size` dimensions with `pad` number of values padding it along each dimension using + `pad_val` as the value. If `spatial_size` is given this is interpolated to that size, otherwise if None the random + array is produced uninterpolated. The output is always a Pytorch tensor allocated on the specified device. + + Args: + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with `pad_val` + pad_val: value with which to pad field edges + low: low value for randomized field + high: high value for randomized field + channels: number of channels of final output + spatial_size: final output size of the array, None to produce original uninterpolated field + mode: interpolation mode for resizing the field + align_corners: if True align the corners when upsampling field + device: Pytorch device to define field on + """ + + def __init__( + self, + rand_size: Sequence[int], + pad: int = 0, + pad_val: float = 0, + low: float = -1.0, + high: float = 1.0, + channels: int = 1, + spatial_size: Optional[Sequence[int]] = None, + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + device: Optional[torch.device] = None, + ): + self.rand_size = tuple(rand_size) + self.pad = pad + self.low = low + self.high = high + self.channels = channels + self.mode = mode + self.align_corners = align_corners + self.device = device + + self.spatial_size: Optional[Sequence[int]] = None + self.spatial_zoom: Optional[Sequence[float]] = None + + if low >= high: + raise ValueError("Value for `low` must be less than `high` otherwise field will be zeros") + + self.total_rand_size = tuple(rs + self.pad * 2 for rs in self.rand_size) + + self.field = torch.ones((1, self.channels) + self.total_rand_size, device=self.device) * pad_val + + self.crand_size = (self.channels,) + self.rand_size + + pad_slice = slice(None) if self.pad == 0 else slice(self.pad, -self.pad) + self.rand_slices = (0, slice(None)) + (pad_slice,) * len(self.rand_size) + + self.set_spatial_size(spatial_size) + + def randomize(self, data: Optional[Any] = None) -> None: + self.field[self.rand_slices] = torch.from_numpy(self.R.uniform(self.low, self.high, self.crand_size)) + + def set_spatial_size(self, spatial_size: Optional[Sequence[int]]) -> None: + """ + Set the `spatial_size` and `spatial_zoom` attributes used for interpolating the field to the given + dimension, or not interpolate at all if None. + + Args: + spatial_size: new size to interpolate to, or None to not interpolate + """ + if spatial_size is None: + self.spatial_size = None + self.spatial_zoom = None + else: + self.spatial_size = tuple(spatial_size) + self.spatial_zoom = tuple(s / f for s, f in zip(self.spatial_size, self.total_rand_size)) + + def set_mode(self, mode: Union[monai.utils.InterpolateMode, str]) -> None: + self.mode = mode + + def __call__(self, randomize=False) -> torch.Tensor: + if randomize: + self.randomize() + + field = self.field.clone() + + if self.spatial_zoom is not None: + resized_field = interpolate( # type: ignore + input=field, + scale_factor=self.spatial_zoom, + mode=look_up_option(self.mode, InterpolateMode).value, + align_corners=self.align_corners, + recompute_scale_factor=False, + ) + + mina = resized_field.min() + maxa = resized_field.max() + minv = self.field.min() + maxv = self.field.max() + + # faster than rescale_array, this uses in-place operations and doesn't perform unneeded range checks + norm_field = (resized_field.squeeze(0) - mina).div_(maxa - mina) + field = norm_field.mul_(maxv - minv).add_(minv) + + return field + + +class RandSmoothFieldAdjustContrast(RandomizableTransform): + """ + Randomly adjust the contrast of input images by calculating a randomized smooth field for each invocation. + + This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the + edges of the input volume of that width will be mostly unchanged. Contrast is changed by raising input + values by the power of the smooth field so the range of values given by `gamma` should be chosen with this + in mind. For example, a minimum value of 0 in `gamma` will produce white areas so this should be avoided. + Afte the contrast is adjusted the values of the result are rescaled to the range of the original input. + + Args: + spatial_size: size of input array's spatial dimensions + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 1 + mode: interpolation mode to use when upsampling + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + gamma: (min, max) range for exponential field + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + gamma: Union[Sequence[float], float] = (0.5, 4.5), + device: Optional[torch.device] = None, + ): + super().__init__(prob) + + if isinstance(gamma, (int, float)): + self.gamma = (0.5, gamma) + else: + if len(gamma) != 2: + raise ValueError("Argument `gamma` should be a number or pair of numbers.") + + self.gamma = (min(gamma), max(gamma)) + + self.sfield = SmoothField( + rand_size=rand_size, + pad=pad, + pad_val=1, + low=self.gamma[0], + high=self.gamma[1], + channels=1, + spatial_size=spatial_size, + mode=mode, + align_corners=align_corners, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothFieldAdjustContrast": + super().set_random_state(seed, state) + self.sfield.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + + if self._do_transform: + self.sfield.randomize() + + def set_mode(self, mode: Union[monai.utils.InterpolateMode, str]) -> None: + self.sfield.set_mode(mode) + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + """ + Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous. + """ + if randomize: + self.randomize() + + if not self._do_transform: + return img + + img_min = img.min() + img_max = img.max() + img_rng = img_max - img_min + + field = self.sfield() + rfield, *_ = convert_to_dst_type(field, img) + + # everything below here is to be computed using the destination type (numpy, tensor, etc.) + + img = (img - img_min) / (img_rng + 1e-10) # rescale to unit values + img = img**rfield # contrast is changed by raising image data to a power, in this case the field + + out = (img * img_rng) + img_min # rescale back to the original image value range + + return out + + +class RandSmoothFieldAdjustIntensity(RandomizableTransform): + """ + Randomly adjust the intensity of input images by calculating a randomized smooth field for each invocation. + + This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the + edges of the input volume of that width will be mostly unchanged. Intensity is changed by multiplying the + inputs by the smooth field, so the values of `gamma` should be chosen with this in mind. The default values + of `(0.1, 1.0)` are sensible in that values will not be zeroed out by the field nor multiplied greater than + the original value range. + + Args: + spatial_size: size of input array + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 1 + mode: interpolation mode to use when upsampling + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + gamma: (min, max) range of intensity multipliers + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + gamma: Union[Sequence[float], float] = (0.1, 1.0), + device: Optional[torch.device] = None, + ): + super().__init__(prob) + + if isinstance(gamma, (int, float)): + self.gamma = (0.5, gamma) + else: + if len(gamma) != 2: + raise ValueError("Argument `gamma` should be a number or pair of numbers.") + + self.gamma = (min(gamma), max(gamma)) + + self.sfield = SmoothField( + rand_size=rand_size, + pad=pad, + pad_val=1, + low=self.gamma[0], + high=self.gamma[1], + channels=1, + spatial_size=spatial_size, + mode=mode, + align_corners=align_corners, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothFieldAdjustIntensity": + super().set_random_state(seed, state) + self.sfield.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + + if self._do_transform: + self.sfield.randomize() + + def set_mode(self, mode: Union[InterpolateMode, str]) -> None: + self.sfield.set_mode(mode) + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + """ + Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous. + """ + + if randomize: + self.randomize() + + if not self._do_transform: + return img + + field = self.sfield() + rfield, *_ = convert_to_dst_type(field, img) + + # everything below here is to be computed using the destination type (numpy, tensor, etc.) + + out = img * rfield + + return out + + +class RandSmoothDeform(RandomizableTransform): + """ + Deform an image using a random smooth field and Pytorch's grid_sample. + + The amount of deformation is given by `def_range` in fractions of the size of the image. The size of each dimension + of the input image is always defined as 2 regardless of actual image voxel dimensions, that is the coordinates in + every dimension range from -1 to 1. A value of 0.1 means pixels/voxels can be moved by up to 5% of the image's size. + + Args: + spatial_size: input array size to which deformation grid is interpolated + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 0 + field_mode: interpolation mode to use when upsampling the deformation field + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + def_range: value of the deformation range in image size fractions, single min/max value or min/max pair + grid_dtype: type for the deformation grid calculated from the field + grid_mode: interpolation mode used for sampling input using deformation grid + grid_padding_mode: padding mode used for sampling input using deformation grid + grid_align_corners: if True align the corners when sampling the deformation grid + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH] + + def __init__( + self, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + field_mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + def_range: Union[Sequence[float], float] = 1.0, + grid_dtype=torch.float32, + grid_mode: Union[GridSampleMode, str] = GridSampleMode.NEAREST, + grid_padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + grid_align_corners: Optional[bool] = False, + device: Optional[torch.device] = None, + ): + super().__init__(prob) + + self.grid_dtype = grid_dtype + self.grid_mode = grid_mode + self.def_range = def_range + self.device = device + self.grid_align_corners = grid_align_corners + self.grid_padding_mode = grid_padding_mode + + if isinstance(def_range, (int, float)): + self.def_range = (-def_range, def_range) + else: + if len(def_range) != 2: + raise ValueError("Argument `def_range` should be a number or pair of numbers.") + + self.def_range = (min(def_range), max(def_range)) + + self.sfield = SmoothField( + spatial_size=spatial_size, + rand_size=rand_size, + pad=pad, + low=self.def_range[0], + high=self.def_range[1], + channels=len(rand_size), + mode=field_mode, + align_corners=align_corners, + device=device, + ) + + grid_space = spatial_size if spatial_size is not None else self.sfield.field.shape[2:] + grid_ranges = [torch.linspace(-1, 1, d) for d in grid_space] + + grid = meshgrid_ij(*grid_ranges) + + self.grid = torch.stack(grid).unsqueeze(0).to(self.device, self.grid_dtype) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "Randomizable": + super().set_random_state(seed, state) + self.sfield.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + + if self._do_transform: + self.sfield.randomize() + + def set_field_mode(self, mode: Union[monai.utils.InterpolateMode, str]) -> None: + self.sfield.set_mode(mode) + + def set_grid_mode(self, mode: Union[monai.utils.GridSampleMode, str]) -> None: + self.grid_mode = mode + + def __call__( + self, img: NdarrayOrTensor, randomize: bool = True, device: Optional[torch.device] = None + ) -> NdarrayOrTensor: + if randomize: + self.randomize() + + if not self._do_transform: + return img + + device = device if device is not None else self.device + + field = self.sfield() + + dgrid = self.grid + field.to(self.grid_dtype) + dgrid = moveaxis(dgrid, 1, -1) # type: ignore + + img_t = convert_to_tensor(img[None], torch.float32, device) + + out = grid_sample( + input=img_t, + grid=dgrid, + mode=look_up_option(self.grid_mode, GridSampleMode).value, + align_corners=self.grid_align_corners, + padding_mode=look_up_option(self.grid_padding_mode, GridSamplePadMode).value, + ) + + out_t, *_ = convert_to_dst_type(out.squeeze(0), img) + + return out_t diff --git a/monai/transforms/smooth_field/dictionary.py b/monai/transforms/smooth_field/dictionary.py new file mode 100644 index 0000000000..24890140cc --- /dev/null +++ b/monai/transforms/smooth_field/dictionary.py @@ -0,0 +1,293 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Hashable, Mapping, Optional, Sequence, Union + +import numpy as np +import torch + +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.smooth_field.array import ( + RandSmoothDeform, + RandSmoothFieldAdjustContrast, + RandSmoothFieldAdjustIntensity, +) +from monai.transforms.transform import MapTransform, RandomizableTransform +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple_rep +from monai.utils.enums import TransformBackends + +__all__ = [ + "RandSmoothFieldAdjustContrastd", + "RandSmoothFieldAdjustIntensityd", + "RandSmoothDeformd", + "RandSmoothFieldAdjustContrastD", + "RandSmoothFieldAdjustIntensityD", + "RandSmoothDeformD", + "RandSmoothFieldAdjustContrastDict", + "RandSmoothFieldAdjustIntensityDict", + "RandSmoothDeformDict", +] + + +InterpolateModeType = Union[InterpolateMode, str] +GridSampleModeType = Union[GridSampleMode, str] + + +class RandSmoothFieldAdjustContrastd(RandomizableTransform, MapTransform): + """ + Dictionary version of RandSmoothFieldAdjustContrast. + + The field is randomized once per invocation by default so the same field is applied to every selected key. The + `mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values with + one for each key in `keys`. + + Args: + keys: key names to apply the augment to + spatial_size: size of input arrays, all arrays stated in `keys` must have same dimensions + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 0 + mode: interpolation mode to use when upsampling + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + gamma: (min, max) range for exponential field + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + keys: KeysCollection, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateModeType, Sequence[InterpolateModeType]] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + gamma: Union[Sequence[float], float] = (0.5, 4.5), + device: Optional[torch.device] = None, + ): + RandomizableTransform.__init__(self, prob) + MapTransform.__init__(self, keys) + + self.mode = ensure_tuple_rep(mode, len(self.keys)) + + self.trans = RandSmoothFieldAdjustContrast( + spatial_size=spatial_size, + rand_size=rand_size, + pad=pad, + mode=self.mode[0], + align_corners=align_corners, + prob=1.0, + gamma=gamma, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothFieldAdjustContrastd": + super().set_random_state(seed, state) + self.trans.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + + if self._do_transform: + self.trans.randomize() + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + self.randomize() + + if not self._do_transform: + return data + + d = dict(data) + + for idx, key in enumerate(self.key_iterator(d)): + self.trans.set_mode(self.mode[idx % len(self.mode)]) + d[key] = self.trans(d[key], False) + + return d + + +class RandSmoothFieldAdjustIntensityd(RandomizableTransform, MapTransform): + """ + Dictionary version of RandSmoothFieldAdjustIntensity. + + The field is randomized once per invocation by default so the same field is applied to every selected key. The + `mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values with + one for each key in `keys`. + + Args: + keys: key names to apply the augment to + spatial_size: size of input arrays, all arrays stated in `keys` must have same dimensions + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 0 + mode: interpolation mode to use when upsampling + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + gamma: (min, max) range of intensity multipliers + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + keys: KeysCollection, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateModeType, Sequence[InterpolateModeType]] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + gamma: Union[Sequence[float], float] = (0.1, 1.0), + device: Optional[torch.device] = None, + ): + RandomizableTransform.__init__(self, prob) + MapTransform.__init__(self, keys) + + self.mode = ensure_tuple_rep(mode, len(self.keys)) + + self.trans = RandSmoothFieldAdjustIntensity( + spatial_size=spatial_size, + rand_size=rand_size, + pad=pad, + mode=self.mode[0], + align_corners=align_corners, + prob=1.0, + gamma=gamma, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothFieldAdjustIntensityd": + super().set_random_state(seed, state) + self.trans.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + self.trans.randomize() + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + self.randomize() + + if not self._do_transform: + return data + + d = dict(data) + + for idx, key in enumerate(self.key_iterator(d)): + self.trans.set_mode(self.mode[idx % len(self.mode)]) + d[key] = self.trans(d[key], False) + + return d + + +class RandSmoothDeformd(RandomizableTransform, MapTransform): + """ + Dictionary version of RandSmoothDeform. + + The field is randomized once per invocation by default so the same field is applied to every selected key. The + `field_mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values + with one for each key in `keys`. Similarly the `grid_mode` parameter can be one value or one per key. + + Args: + keys: key names to apply the augment to + spatial_size: input array size to which deformation grid is interpolated + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 0 + field_mode: interpolation mode to use when upsampling the deformation field + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + def_range: value of the deformation range in image size fractions + grid_dtype: type for the deformation grid calculated from the field + grid_mode: interpolation mode used for sampling input using deformation grid + grid_padding_mode: padding mode used for sampling input using deformation grid + grid_align_corners: if True align the corners when sampling the deformation grid + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + keys: KeysCollection, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + field_mode: Union[InterpolateModeType, Sequence[InterpolateModeType]] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + def_range: Union[Sequence[float], float] = 1.0, + grid_dtype=torch.float32, + grid_mode: Union[GridSampleModeType, Sequence[GridSampleModeType]] = GridSampleMode.NEAREST, + grid_padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + grid_align_corners: Optional[bool] = False, + device: Optional[torch.device] = None, + ): + RandomizableTransform.__init__(self, prob) + MapTransform.__init__(self, keys) + + self.field_mode = ensure_tuple_rep(field_mode, len(self.keys)) + self.grid_mode = ensure_tuple_rep(grid_mode, len(self.keys)) + + self.trans = RandSmoothDeform( + rand_size=rand_size, + spatial_size=spatial_size, + pad=pad, + field_mode=self.field_mode[0], + align_corners=align_corners, + prob=1.0, + def_range=def_range, + grid_dtype=grid_dtype, + grid_mode=self.grid_mode[0], + grid_padding_mode=grid_padding_mode, + grid_align_corners=grid_align_corners, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothDeformd": + super().set_random_state(seed, state) + self.trans.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + self.trans.randomize() + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + self.randomize() + + if not self._do_transform: + return data + + d = dict(data) + + for idx, key in enumerate(self.key_iterator(d)): + self.trans.set_field_mode(self.field_mode[idx % len(self.field_mode)]) + self.trans.set_grid_mode(self.grid_mode[idx % len(self.grid_mode)]) + + d[key] = self.trans(d[key], False, self.trans.device) + + return d + + +RandSmoothDeformD = RandSmoothDeformDict = RandSmoothDeformd +RandSmoothFieldAdjustIntensityD = RandSmoothFieldAdjustIntensityDict = RandSmoothFieldAdjustIntensityd +RandSmoothFieldAdjustContrastD = RandSmoothFieldAdjustContrastDict = RandSmoothFieldAdjustContrastd diff --git a/monai/transforms/spatial/__init__.py b/monai/transforms/spatial/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/spatial/__init__.py +++ b/monai/transforms/spatial/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c3bd4a3433..f7327aa07b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,16 +13,18 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ import warnings -from typing import Any, List, Optional, Sequence, Tuple, Union +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor -from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine +from monai.data.utils import AFFINE_TOL, compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.networks.utils import meshgrid_ij, normalize_transform +from monai.transforms.croppad.array import CenterSpatialCrop, Pad from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform from monai.transforms.utils import ( create_control_grid, @@ -33,27 +35,35 @@ create_translate, map_spatial_axes, ) +from monai.transforms.utils_pytorch_numpy_unification import allclose, moveaxis from monai.utils import ( GridSampleMode, GridSamplePadMode, InterpolateMode, NumpyPadMode, + PytorchPadMode, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, issequenceiterable, optional_import, + pytorch_after, ) +from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import TransformBackends from monai.utils.module import look_up_option +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type -nib, _ = optional_import("nibabel") +nib, has_nib = optional_import("nibabel") __all__ = [ + "SpatialResample", + "ResampleToMatch", "Spacing", "Orientation", "Flip", + "GridDistortion", "Resize", "Rotate", "Zoom", @@ -61,6 +71,7 @@ "RandRotate90", "RandRotate", "RandFlip", + "RandGridDistortion", "RandAxisFlip", "RandZoom", "AffineGrid", @@ -71,25 +82,248 @@ "RandAffine", "Rand2DElastic", "Rand3DElastic", - "AddCoordinateChannels", ] RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] +class SpatialResample(Transform): + """ + Resample input image from the orientation/spacing defined by ``src_affine`` affine matrix into + the ones specified by ``dst_affine`` affine matrix. + + Internally this transform computes the affine transform matrix from ``src_affine`` to ``dst_affine``, + by ``xform = linalg.solve(src_affine, dst_affine)``, and call ``monai.transforms.Affine`` with ``xform``. + """ + + backend = [TransformBackends.TORCH] + + def __init__( + self, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + ): + """ + Args: + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + """ + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + self.dtype = dtype + + def __call__( + self, + img: NdarrayOrTensor, + src_affine: Optional[NdarrayOrTensor] = None, + dst_affine: Optional[NdarrayOrTensor] = None, + spatial_size: Optional[Union[Sequence[int], np.ndarray, int]] = None, + mode: Union[GridSampleMode, str, None] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str, None] = GridSamplePadMode.BORDER, + align_corners: Optional[bool] = False, + dtype: DtypeLike = None, + ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: + """ + Args: + img: input image to be resampled. It currently supports channel-first arrays with + at most three spatial dimensions. + src_affine: source affine matrix. Defaults to ``None``, which means the identity matrix. + the shape should be `(r+1, r+1)` where `r` is the spatial rank of ``img``. + dst_affine: destination affine matrix. Defaults to ``None``, which means the same as `src_affine`. + the shape should be `(r+1, r+1)` where `r` is the spatial rank of ``img``. + when `dst_affine` and `spatial_size` are None, the input will be returned without resampling, + but the data type will be `float32`. + spatial_size: output image spatial size. + if `spatial_size` and `self.spatial_size` are not defined, + the transform will compute a spatial size automatically containing the previous field of view. + if `spatial_size` is ``-1`` are the transform will use the corresponding input img size. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``self.dtype`` or + ``np.float64`` (for best precision). If ``None``, use the data type of input data. + To be compatible with other modules, the output data type is always `float32`. + + The spatial rank is determined by the smallest among ``img.ndim -1``, ``len(src_affine) - 1``, and ``3``. + + When both ``monai.config.USE_COMPILED`` and ``align_corners`` are set to ``True``, + MONAI's resampling implementation will be used. + Set `dst_affine` and `spatial_size` to `None` to turn off the resampling step. + """ + if src_affine is None: + src_affine = np.eye(4, dtype=np.float64) + spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3) + if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None: + spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size + src_affine = to_affine_nd(spatial_rank, src_affine) + dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine + dst_affine, *_ = convert_to_dst_type(dst_affine, dst_affine, dtype=torch.float32) + + in_spatial_size = np.asarray(img.shape[1 : spatial_rank + 1]) + if isinstance(spatial_size, int) and (spatial_size == -1): # using the input spatial size + spatial_size = in_spatial_size + elif spatial_size is None and spatial_rank > 1: # auto spatial size + spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine) # type: ignore + spatial_size = np.asarray(fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size)) + + if ( + allclose(src_affine, dst_affine, atol=AFFINE_TOL) + and allclose(spatial_size, in_spatial_size) + or spatial_rank == 1 + ): + # no significant change, return original image + output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) + return output_data, dst_affine + + if has_nib and isinstance(img, np.ndarray): + spatial_ornt, dst_r = reorient_spatial_axes(img.shape[1 : spatial_rank + 1], src_affine, dst_affine) + if allclose(dst_r, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): + # simple reorientation achieves the desired affine + spatial_ornt[:, 0] += 1 + spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) + img_ = nib.orientations.apply_orientation(img, spatial_ornt) + output_data, *_ = convert_to_dst_type(img_, img, dtype=torch.float32) + return output_data, dst_affine + + try: + src_affine, *_ = convert_to_dst_type(src_affine, dst_affine) + if isinstance(src_affine, np.ndarray): + xform = np.linalg.solve(src_affine, dst_affine) + else: + xform = ( + torch.linalg.solve(src_affine, dst_affine) + if pytorch_after(1, 8, 0) + else torch.solve(dst_affine, src_affine).solution # type: ignore + ) + except (np.linalg.LinAlgError, RuntimeError) as e: + raise ValueError(f"src affine is not invertible: {src_affine}") from e + xform = to_affine_nd(spatial_rank, xform) + # no resampling if it's identity transform + if allclose(xform, np.diag(np.ones(len(xform))), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): + output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) + return output_data, dst_affine + + _dtype = dtype or self.dtype or img.dtype + in_spatial_size = in_spatial_size.tolist() + chns, additional_dims = img.shape[0], img.shape[spatial_rank + 1 :] # beyond three spatial dims + # resample + img_ = convert_data_type(img, torch.Tensor, dtype=_dtype)[0] + xform = convert_to_dst_type(xform, img_)[0] + align_corners = self.align_corners if align_corners is None else align_corners + mode = mode or self.mode + padding_mode = padding_mode or self.padding_mode + if additional_dims: + xform_shape = [-1] + in_spatial_size + img_ = img_.reshape(xform_shape) + if align_corners: + _t_r = torch.diag(torch.ones(len(xform), dtype=xform.dtype, device=xform.device)) # type: ignore + for idx, d_dst in enumerate(spatial_size[:spatial_rank]): + _t_r[idx, -1] = (max(d_dst, 2) - 1.0) / 2.0 + xform = xform @ _t_r + if not USE_COMPILED: + _t_l = normalize_transform( + in_spatial_size, xform.device, xform.dtype, align_corners=True # type: ignore + ) + xform = _t_l @ xform # type: ignore + affine_xform = Affine( + affine=xform, spatial_size=spatial_size, norm_coords=False, image_only=True, dtype=_dtype + ) + output_data = affine_xform(img_, mode=mode, padding_mode=padding_mode) + else: + affine_xform = AffineTransform( + normalized=False, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + reverse_indexing=True, + ) + output_data = affine_xform(img_.unsqueeze(0), theta=xform, spatial_size=spatial_size).squeeze(0) + if additional_dims: + full_shape = (chns, *spatial_size, *additional_dims) + output_data = output_data.reshape(full_shape) + # output dtype float + output_data, *_ = convert_to_dst_type(output_data, img, dtype=torch.float32) + return output_data, dst_affine + + +class ResampleToMatch(SpatialResample): + """Resample an image to match given meta data. The affine matrix will be aligned, + and the size of the output image will match.""" + + def __call__( # type: ignore + self, + img: NdarrayOrTensor, + src_meta: Optional[Dict] = None, + dst_meta: Optional[Dict] = None, + mode: Union[GridSampleMode, str, None] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str, None] = GridSamplePadMode.BORDER, + align_corners: Optional[bool] = False, + dtype: DtypeLike = None, + ): + if src_meta is None: + raise RuntimeError("`in_meta` is missing") + if dst_meta is None: + raise RuntimeError("`out_meta` is missing") + mode = mode or self.mode + padding_mode = padding_mode or self.padding_mode + align_corners = self.align_corners if align_corners is None else align_corners + dtype = dtype or self.dtype + src_affine = src_meta.get("affine") + dst_affine = dst_meta.get("affine") + img, updated_affine = super().__call__( + img=img, + src_affine=src_affine, + dst_affine=dst_affine, + spatial_size=dst_meta.get("spatial_shape"), + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, + ) + dst_meta = deepcopy(dst_meta) + dst_meta["affine"] = updated_affine + return img, dst_meta + + class Spacing(Transform): """ Resample input image into the specified `pixdim`. """ + backend = SpatialResample.backend + def __init__( self, - pixdim: Union[Sequence[float], float], + pixdim: Union[Sequence[float], float, np.ndarray], diagonal: bool = False, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike = np.float64, + image_only: bool = False, ) -> None: """ Args: @@ -113,46 +347,57 @@ def __init__( of the original data. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. + image_only: return just the image or the image, the old affine and new affine. Default is `False`. """ self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64) self.diagonal = diagonal - self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) - self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) - self.align_corners = align_corners - self.dtype = dtype + self.image_only = image_only + + self.sp_resample = SpatialResample( + mode=look_up_option(mode, GridSampleMode), + padding_mode=look_up_option(padding_mode, GridSamplePadMode), + align_corners=align_corners, + dtype=dtype, + ) def __call__( self, - data_array: np.ndarray, - affine: Optional[np.ndarray] = None, + data_array: NdarrayOrTensor, + affine: Optional[NdarrayOrTensor] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - output_spatial_shape: Optional[np.ndarray] = None, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None, + ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: """ Args: data_array: in shape (num_channels, H[, W, ...]). affine (matrix): (N+1)x(N+1) original affine matrix for spatially ND `data_array`. Defaults to identity. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. @@ -168,16 +413,16 @@ def __call__( data_array (resampled into `self.pixdim`), original affine, current affine. """ - _dtype = dtype or self.dtype or data_array.dtype - sr = data_array.ndim - 1 + sr = int(data_array.ndim - 1) if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") if affine is None: # default to identity - affine = np.eye(sr + 1, dtype=np.float64) + affine_np = affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_np, *_ = convert_data_type(affine, np.ndarray) + affine_ = to_affine_nd(sr, affine_np) out_d = self.pixdim[:sr] if out_d.size < sr: @@ -187,33 +432,21 @@ def __call__( new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal) output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine) new_affine[:sr, -1] = offset[:sr] - transform = np.linalg.inv(affine_) @ new_affine - # adapt to the actual rank - transform = to_affine_nd(sr, transform) - - # no resampling if it's identity transform - if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): - output_data = data_array.copy().astype(np.float32) - new_affine = to_affine_nd(affine, new_affine) - return output_data, affine, new_affine - - # resample - affine_xform = AffineTransform( - normalized=False, - mode=look_up_option(mode or self.mode, GridSampleMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), - align_corners=self.align_corners if align_corners is None else align_corners, - reverse_indexing=True, + output_data, new_affine = self.sp_resample( + data_array, + src_affine=affine, + dst_affine=new_affine, + spatial_size=list(output_shape) if output_spatial_shape is None else output_spatial_shape, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, ) - output_data = affine_xform( - # AffineTransform requires a batch dim - torch.as_tensor(np.ascontiguousarray(data_array).astype(_dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), - spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, - ) - output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore - new_affine = to_affine_nd(affine, new_affine) + new_affine = to_affine_nd(affine_np, new_affine) + new_affine, *_ = convert_to_dst_type(src=new_affine, dst=affine, dtype=torch.float32) + if self.image_only: + return output_data return output_data, affine, new_affine @@ -222,11 +455,14 @@ class Orientation(Transform): Change the input image's orientation into the specified based on `axcodes`. """ + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + def __init__( self, axcodes: Optional[str] = None, as_closest_canonical: bool = False, labels: Optional[Sequence[Tuple[str, str]]] = tuple(zip("LPI", "RAS")), + image_only: bool = False, ) -> None: """ Args: @@ -239,6 +475,7 @@ def __init__( labels: optional, None or sequence of (2,) sequences (2,) sequences are labels for (beginning, end) of output axis. Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``. + image_only: if True return only the image volume, otherwise return (image, affine, new_affine). Raises: ValueError: When ``axcodes=None`` and ``as_closest_canonical=True``. Incompatible values. @@ -253,10 +490,11 @@ def __init__( self.axcodes = axcodes self.as_closest_canonical = as_closest_canonical self.labels = labels + self.image_only = image_only def __call__( - self, data_array: np.ndarray, affine: Optional[np.ndarray] = None - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + self, data_array: NdarrayOrTensor, affine: Optional[NdarrayOrTensor] = None + ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: """ original orientation of `data_array` is defined by `affine`. @@ -269,38 +507,64 @@ def __call__( ValueError: When ``axcodes`` spatiality differs from ``data_array``. Returns: - data_array (reoriented in `self.axcodes`), original axcodes, current axcodes. + data_array [reoriented in `self.axcodes`] if `self.image_only`, else + (data_array [reoriented in `self.axcodes`], original axcodes, current axcodes). """ - sr = data_array.ndim - 1 + spatial_shape = data_array.shape[1:] + sr = len(spatial_shape) if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") + affine_: np.ndarray if affine is None: - affine = np.eye(sr + 1, dtype=np.float64) + # default to identity + affine_np = affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_np, *_ = convert_data_type(affine, np.ndarray) + affine_ = to_affine_nd(sr, affine_np) + src = nib.io_orientation(affine_) if self.as_closest_canonical: spatial_ornt = src else: if self.axcodes is None: - raise AssertionError + raise ValueError("Incompatible values: axcodes=None and as_closest_canonical=True.") + if sr < len(self.axcodes): + warnings.warn( + f"axcodes ('{self.axcodes}') length is smaller than the number of input spatial dimensions D={sr}.\n" + f"{self.__class__.__name__}: input spatial shape is {spatial_shape}, num. channels is {data_array.shape[0]}," + "please make sure the input is in the channel-first format." + ) dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels) if len(dst) < sr: raise ValueError( f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D" ) spatial_ornt = nib.orientations.ornt_transform(src, dst) - ornt = spatial_ornt.copy() - ornt[:, 0] += 1 # skip channel dim - ornt = np.concatenate([np.array([[0, 1]]), ornt]) - shape = data_array.shape[1:] - data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) - new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) - new_affine = to_affine_nd(affine, new_affine) + new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, spatial_shape) + _is_tensor = isinstance(data_array, torch.Tensor) + spatial_ornt[:, 0] += 1 # skip channel dim + spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) + axes = [ax for ax, flip in enumerate(spatial_ornt[:, 1]) if flip == -1] + if axes: + data_array = ( + torch.flip(data_array, dims=axes) if _is_tensor else np.flip(data_array, axis=axes) # type: ignore + ) + full_transpose = np.arange(len(data_array.shape)) + full_transpose[: len(spatial_ornt)] = np.argsort(spatial_ornt[:, 0]) + if not np.all(full_transpose == np.arange(len(data_array.shape))): + if _is_tensor: + data_array = data_array.permute(full_transpose.tolist()) # type: ignore + else: + data_array = data_array.transpose(full_transpose) # type: ignore + out, *_ = convert_to_dst_type(src=data_array, dst=data_array) + new_affine = to_affine_nd(affine_np, new_affine) + new_affine, *_ = convert_to_dst_type(src=new_affine, dst=affine, dtype=torch.float32) - return data_array, affine, new_affine + if self.image_only: + return out + return out, affine, new_affine class Flip(Transform): @@ -330,8 +594,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ if isinstance(img, np.ndarray): return np.ascontiguousarray(np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis))) - else: - return torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) + return torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) class Resize(Transform): @@ -351,12 +614,14 @@ class Resize(Transform): #albumentations.augmentations.geometric.resize.LongestMaxSize. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html """ + backend = [TransformBackends.TORCH] + def __init__( self, spatial_size: Union[Sequence[int], int], @@ -371,50 +636,51 @@ def __init__( def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, align_corners: Optional[bool] = None, - ) -> np.ndarray: + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html Raises: ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ + img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) if self.size_mode == "all": - input_ndim = img.ndim - 1 # spatial ndim + input_ndim = img_.ndim - 1 # spatial ndim output_ndim = len(ensure_tuple(self.spatial_size)) if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) - img = img.reshape(input_shape) + input_shape = ensure_tuple_size(img_.shape, output_ndim + 1, 1) + img_ = img_.reshape(input_shape) elif output_ndim < input_ndim: raise ValueError( "len(spatial_size) must be greater or equal to img spatial dimensions, " f"got spatial_size={output_ndim} img={input_ndim}." ) - spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) + spatial_size_ = fall_back_tuple(self.spatial_size, img_.shape[1:]) else: # for the "longest" mode - img_size = img.shape[1:] + img_size = img_.shape[1:] if not isinstance(self.spatial_size, int): raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) - resized = torch.nn.functional.interpolate( # type: ignore - input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), + resized = torch.nn.functional.interpolate( + input=img_.unsqueeze(0), size=spatial_size_, mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - resized = resized.squeeze(0).detach().cpu().numpy() - return np.asarray(resized) + out, *_ = convert_to_dst_type(resized.squeeze(0), img) + return out class Rotate(Transform, ThreadUnsafe): @@ -428,17 +694,19 @@ class Rotate(Transform, ThreadUnsafe): input array is contained completely in the output. Default is True. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to False. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``np.float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. """ + backend = [TransformBackends.TORCH] + def __init__( self, angle: Union[Sequence[float], float], @@ -446,7 +714,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: DtypeLike = np.float64, + dtype: Union[DtypeLike, torch.dtype] = np.float32, ) -> None: self.angle = angle self.keep_size = keep_size @@ -454,29 +722,29 @@ def __init__( self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.dtype = dtype - self._rotation_matrix: Optional[np.ndarray] = None + self._rotation_matrix: Optional[NdarrayOrTensor] = None def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None, - ) -> np.ndarray: + dtype: Union[DtypeLike, torch.dtype] = None, + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. @@ -486,7 +754,10 @@ def __call__( """ _dtype = dtype or self.dtype or img.dtype - im_shape = np.asarray(img.shape[1:]) # spatial dimensions + + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype) + + im_shape = np.asarray(img_t.shape[1:]) # spatial dimensions input_ndim = len(im_shape) if input_ndim not in (2, 3): raise ValueError(f"Unsupported img dimension: {input_ndim}, available options are [2, 3].") @@ -499,11 +770,13 @@ def __call__( corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape( (len(im_shape), -1) ) - corners = transform[:-1, :-1] @ corners + corners = transform[:-1, :-1] @ corners # type: ignore output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist()) transform = shift @ transform @ shift_1 + transform_t, *_ = convert_to_dst_type(transform, img_t) + xform = AffineTransform( normalized=False, mode=look_up_option(mode or self.mode, GridSampleMode), @@ -511,15 +784,13 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, reverse_indexing=True, ) - output = xform( - torch.as_tensor(np.ascontiguousarray(img).astype(_dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), - spatial_size=output_shape, - ) + output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).float().squeeze(0) self._rotation_matrix = transform - return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + out: NdarrayOrTensor + out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype) + return out - def get_rotation_matrix(self) -> Optional[np.ndarray]: + def get_rotation_matrix(self) -> Optional[NdarrayOrTensor]: """ Get the most recently applied rotation matrix This is not thread-safe. @@ -530,7 +801,7 @@ def get_rotation_matrix(self) -> Optional[np.ndarray]: class Zoom(Transform): """ Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. - For details, please see https://pytorch.org/docs/stable/nn.functional.html#interpolate. + For details, please see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html. Different from :py:class:`monai.transforms.resize`, this transform takes scaling factors as input, and provides an option of preserving the input spatial size. @@ -541,83 +812,96 @@ class Zoom(Transform): If a sequence, zoom should contain one value for each spatial axis. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html keep_size: Should keep original size (padding/slicing if needed), default is True. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ + backend = [TransformBackends.TORCH] + def __init__( self, zoom: Union[Sequence[float], float], mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - padding_mode: Union[NumpyPadMode, str] = NumpyPadMode.EDGE, + padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: bool = True, - **np_kwargs, + **kwargs, ) -> None: self.zoom = zoom self.mode: InterpolateMode = InterpolateMode(mode) - self.padding_mode: NumpyPadMode = NumpyPadMode(padding_mode) + self.padding_mode = padding_mode self.align_corners = align_corners self.keep_size = keep_size - self.np_kwargs = np_kwargs + self.kwargs = kwargs def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, - padding_mode: Optional[Union[NumpyPadMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, - ): + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} - The mode to pad data after zooming, default to ``self.padding_mode``. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + The mode to pad data after zooming. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html """ + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) + _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim - zoomed = torch.nn.functional.interpolate( # type: ignore + zoomed: NdarrayOrTensor = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, - input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), + input=img_t.unsqueeze(0), scale_factor=list(_zoom), mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - zoomed = zoomed.squeeze(0).detach().cpu().numpy() - if not self.keep_size or np.allclose(img.shape, zoomed.shape): - return zoomed + zoomed = zoomed.squeeze(0) + + if self.keep_size and not np.allclose(img_t.shape, zoomed.shape): + + pad_vec = [(0, 0)] * len(img_t.shape) + slice_vec = [slice(None)] * len(img_t.shape) + for idx, (od, zd) in enumerate(zip(img_t.shape, zoomed.shape)): + diff = od - zd + half = abs(diff) // 2 + if diff > 0: # need padding + pad_vec[idx] = (half, diff - half) + elif diff < 0: # need slicing + slice_vec[idx] = slice(half, half + od) - pad_vec = [[0, 0]] * len(img.shape) - slice_vec = [slice(None)] * len(img.shape) - for idx, (od, zd) in enumerate(zip(img.shape, zoomed.shape)): - diff = od - zd - half = abs(diff) // 2 - if diff > 0: # need padding - pad_vec[idx] = [half, diff - half] - elif diff < 0: # need slicing - slice_vec[idx] = slice(half, half + od) + padder = Pad(pad_vec, padding_mode or self.padding_mode) + zoomed = padder(zoomed) + zoomed = zoomed[tuple(slice_vec)] - padding_mode = look_up_option(self.padding_mode if padding_mode is None else padding_mode, NumpyPadMode) - zoomed = np.pad(zoomed, pad_vec, mode=padding_mode.value, **self.np_kwargs) # type: ignore - return zoomed[tuple(slice_vec)] + out, *_ = convert_to_dst_type(zoomed, dst=img) + return out class Rotate90(Transform): @@ -628,6 +912,8 @@ class Rotate90(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: """ Args: @@ -642,14 +928,15 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - - result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) - return result.astype(img.dtype) + rot90: Callable = torch.rot90 if isinstance(img, torch.Tensor) else np.rot90 # type: ignore + out: NdarrayOrTensor = rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) + out, *_ = convert_data_type(out, dtype=img.dtype) + return out class RandRotate90(RandomizableTransform): @@ -658,6 +945,8 @@ class RandRotate90(RandomizableTransform): in the plane specified by `spatial_axes`. """ + backend = Rotate90.backend + def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, int] = (0, 1)) -> None: """ Args: @@ -674,19 +963,24 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i self._rand_k = 0 def randomize(self, data: Optional[Any] = None) -> None: - self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) + if not self._do_transform: + return None + self._rand_k = self.R.randint(self.max_k) + 1 - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + randomize: whether to execute `randomize()` function first, default to True. """ - self.randomize() + if randomize: + self.randomize() + if not self._do_transform: return img - rotator = Rotate90(self._rand_k, self.spatial_axes) - return rotator(img) + + return Rotate90(self._rand_k, self.spatial_axes)(img) class RandRotate(RandomizableTransform): @@ -706,17 +1000,19 @@ class RandRotate(RandomizableTransform): If it is True, the output shape is the same as the input. Default is True. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to False. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``np.float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. """ + backend = Rotate.backend + def __init__( self, range_x: Union[Tuple[float, float], float] = 0.0, @@ -727,7 +1023,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: DtypeLike = np.float64, + dtype: Union[DtypeLike, torch.dtype] = np.float32, ) -> None: RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) @@ -752,36 +1048,45 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None, - ) -> np.ndarray: + dtype: Union[DtypeLike, torch.dtype] = None, + randomize: bool = True, + get_matrix: bool = False, + ): """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. + randomize: whether to execute `randomize()` function first, default to True. + get_matrix: whether to return the rotated image and rotate matrix together, default to False. """ - self.randomize() + if randomize: + self.randomize() + if not self._do_transform: return img + rotator = Rotate( angle=self.x if img.ndim == 3 else (self.x, self.y, self.z), keep_size=self.keep_size, @@ -790,7 +1095,8 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) - return np.array(rotator(img)) + img = rotator(img) + return (img, rotator.get_rotation_matrix()) if get_matrix else img class RandFlip(RandomizableTransform): @@ -810,14 +1116,18 @@ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int] RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + randomize: whether to execute `randomize()` function first, default to True. """ - self.randomize(None) + if randomize: + self.randomize(None) + if not self._do_transform: return img + return self.flipper(img) @@ -840,18 +1150,23 @@ def __init__(self, prob: float = 0.1) -> None: def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) + if not self._do_transform: + return None self._axis = self.R.randint(data.ndim - 1) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + randomize: whether to execute `randomize()` function first, default to True. """ - self.randomize(data=img) + if randomize: + self.randomize(data=img) + if not self._do_transform: return img - flipper = Flip(spatial_axis=self._axis) - return flipper(img) + + return Flip(spatial_axis=self._axis)(img) class RandZoom(RandomizableTransform): @@ -872,30 +1187,35 @@ class RandZoom(RandomizableTransform): If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html keep_size: Should keep original size (pad if needed), default is True. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ + backend = Zoom.backend + def __init__( self, prob: float = 0.1, min_zoom: Union[Sequence[float], float] = 0.9, max_zoom: Union[Sequence[float], float] = 1.1, mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - padding_mode: Union[NumpyPadMode, str] = NumpyPadMode.EDGE, + padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: bool = True, - **np_kwargs, + **kwargs, ) -> None: RandomizableTransform.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) @@ -903,59 +1223,67 @@ def __init__( if len(self.min_zoom) != len(self.max_zoom): raise AssertionError("min_zoom and max_zoom must have same length.") self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) - self.padding_mode: NumpyPadMode = look_up_option(padding_mode, NumpyPadMode) + self.padding_mode = padding_mode self.align_corners = align_corners self.keep_size = keep_size - self.np_kwargs = np_kwargs + self.kwargs = kwargs self._zoom: Sequence[float] = [1.0] - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, img: NdarrayOrTensor) -> None: super().randomize(None) + if not self._do_transform: + return None self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] + if len(self._zoom) == 1: + # to keep the spatial shape ratio, use same random zoom factor for all dims + self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1) + elif len(self._zoom) == 2 and img.ndim > 3: + # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim + self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1]) def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, - padding_mode: Optional[Union[NumpyPadMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, - ) -> np.ndarray: + randomize: bool = True, + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} - The mode to pad data after zooming, default to ``self.padding_mode``. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + The mode to pad data after zooming. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + randomize: whether to execute `randomize()` function first, default to True. + """ # match the spatial image dim - self.randomize() - _dtype = np.float32 + if randomize: + self.randomize(img=img) + if not self._do_transform: - return img.astype(_dtype) - if len(self._zoom) == 1: - # to keep the spatial shape ratio, use same random zoom factor for all dims - self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1) - elif len(self._zoom) == 2 and img.ndim > 3: - # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim - self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1]) - zoomer = Zoom(self._zoom, keep_size=self.keep_size, **self.np_kwargs) - return np.asarray( - zoomer( - img, - mode=look_up_option(mode or self.mode, InterpolateMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, NumpyPadMode), - align_corners=self.align_corners if align_corners is None else align_corners, - ), - dtype=_dtype, - ) + return img + + return Zoom( + self._zoom, + keep_size=self.keep_size, + mode=look_up_option(mode or self.mode, InterpolateMode), + padding_mode=padding_mode or self.padding_mode, + align_corners=align_corners or self.align_corners, + **self.kwargs, + )(img) class AffineGrid(Transform): @@ -979,14 +1307,21 @@ class AffineGrid(Transform): pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. - as_tensor_output: whether to output tensor instead of numpy array, defaults to True. - device: device to store the output grid data. + dtype: data type for the grid computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data (if `grid` is provided). + device: device on which the tensor will be allocated, if a new grid is generated. affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_params: Optional[Union[Sequence[float], float]] = None, @@ -995,24 +1330,25 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, - affine: Optional[Union[np.ndarray, torch.Tensor]] = None, + dtype: DtypeLike = np.float32, + affine: Optional[NdarrayOrTensor] = None, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params self.translate_params = translate_params self.scale_params = scale_params - - self.as_tensor_output = as_tensor_output self.device = device - + self.dtype = dtype self.affine = affine def __call__( - self, - spatial_size: Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: + self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[NdarrayOrTensor] = None + ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ + The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`. + Therefore, either `spatial_size` or `grid` must be provided. + When initialising from `spatial_size`, the backend "torch" will be used. + Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. @@ -1021,38 +1357,36 @@ def __call__( ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. """ - if grid is None: - if spatial_size is not None: - grid = create_grid(spatial_size) - else: + if grid is None: # create grid from spatial_size + if spatial_size is None: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - - affine: Union[torch.Tensor, np.ndarray] + grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=self.dtype) + _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY + _device = grid.device if isinstance(grid, torch.Tensor) else self.device + affine: NdarrayOrTensor if self.affine is None: spatial_dims = len(grid.shape) - 1 - affine = np.eye(spatial_dims + 1) + affine = ( + torch.eye(spatial_dims + 1, device=_device) + if _b == TransformBackends.TORCH + else np.eye(spatial_dims + 1) + ) if self.rotate_params: - affine = affine @ create_rotate(spatial_dims, self.rotate_params) + affine = affine @ create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b) if self.shear_params: - affine = affine @ create_shear(spatial_dims, self.shear_params) + affine = affine @ create_shear(spatial_dims, self.shear_params, device=_device, backend=_b) if self.translate_params: - affine = affine @ create_translate(spatial_dims, self.translate_params) + affine = affine @ create_translate(spatial_dims, self.translate_params, device=_device, backend=_b) if self.scale_params: - affine = affine @ create_scale(spatial_dims, self.scale_params) + affine = affine @ create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) else: affine = self.affine - if isinstance(affine, np.ndarray): - affine = torch.as_tensor(np.ascontiguousarray(affine)) + grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=self.dtype or grid.dtype) + affine, *_ = convert_to_dst_type(affine, grid) - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - affine = affine.to(self.device) - grid = grid.to(self.device) - grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) - if grid is None or not isinstance(grid, torch.Tensor): - raise ValueError("Unknown grid.") - return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine + grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) + return grid, affine class RandAffineGrid(Randomizable, Transform): @@ -1061,6 +1395,9 @@ class RandAffineGrid(Randomizable, Transform): """ + backend = AffineGrid.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_range: RandRange = None, @@ -1094,8 +1431,6 @@ def __init__( scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select the scale factor to translate for every spatial dims. A value of 1.0 is added to the result. This allows 0 to correspond to no change (i.e., a scaling of 1.0). - as_tensor_output: whether to output tensor instead of numpy array. - defaults to True. device: device to store the output grid data. See also: @@ -1103,6 +1438,10 @@ def __init__( - :py:meth:`monai.transforms.utils.create_shear` - :py:meth:`monai.transforms.utils.create_translate` - :py:meth:`monai.transforms.utils.create_scale` + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ self.rotate_range = ensure_tuple(rotate_range) self.shear_range = ensure_tuple(shear_range) @@ -1114,9 +1453,8 @@ def __init__( self.translate_params: Optional[List[float]] = None self.scale_params: Optional[List[float]] = None - self.as_tensor_output = as_tensor_output self.device = device - self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None + self.affine: Optional[NdarrayOrTensor] = None def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] @@ -1136,10 +1474,8 @@ def randomize(self, data: Optional[Any] = None) -> None: self.scale_params = self._get_rand_param(self.scale_range, 1.0) def __call__( - self, - spatial_size: Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[NdarrayOrTensor] = None + ) -> NdarrayOrTensor: """ Args: spatial_size: output grid size. @@ -1154,13 +1490,13 @@ def __call__( shear_params=self.shear_params, translate_params=self.translate_params, scale_params=self.scale_params, - as_tensor_output=self.as_tensor_output, device=self.device, ) - grid, self.affine = affine_grid(spatial_size, grid) - return grid + _grid: NdarrayOrTensor + _grid, self.affine = affine_grid(spatial_size, grid) + return _grid - def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: + def get_transformation_matrix(self) -> Optional[NdarrayOrTensor]: """Get the most recently applied transformation matrix""" return self.affine @@ -1170,6 +1506,8 @@ class RandDeformGrid(Randomizable, Transform): Generate random deformation grid. """ + backend = [TransformBackends.TORCH] + def __init__( self, spacing: Union[Sequence[float], float], @@ -1198,7 +1536,7 @@ def __init__( self.device = device def randomize(self, grid_size: Sequence[int]) -> None: - self.random_offset = self.R.normal(size=([len(grid_size)] + list(grid_size))).astype(np.float32) + self.random_offset = self.R.normal(size=([len(grid_size)] + list(grid_size))).astype(np.float32, copy=False) self.rand_mag = self.R.uniform(self.magnitude[0], self.magnitude[1]) def __call__(self, spatial_size: Sequence[int]): @@ -1207,21 +1545,28 @@ def __call__(self, spatial_size: Sequence[int]): spatial_size: spatial size of the grid. """ self.spacing = fall_back_tuple(self.spacing, (1.0,) * len(spatial_size)) - control_grid = create_control_grid(spatial_size, self.spacing) + control_grid = create_control_grid(spatial_size, self.spacing, device=self.device, backend="torch") self.randomize(control_grid.shape[1:]) - control_grid[: len(spatial_size)] += self.rand_mag * self.random_offset - if self.as_tensor_output: - control_grid = torch.as_tensor(np.ascontiguousarray(control_grid), device=self.device) + _offset, *_ = convert_to_dst_type(self.rand_mag * self.random_offset, control_grid) + control_grid[: len(spatial_size)] += _offset + if not self.as_tensor_output: + control_grid, *_ = convert_data_type(control_grid, output_type=np.ndarray, dtype=np.float32) return control_grid class Resample(Transform): + + backend = [TransformBackends.TORCH] + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, - as_tensor_output: bool = False, + as_tensor_output: bool = True, + norm_coords: bool = True, device: Optional[torch.device] = None, + dtype: DtypeLike = np.float64, ) -> None: """ computes output image using values from `img`, locations from `grid` using pytorch. @@ -1230,85 +1575,106 @@ def __init__( Args: mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: whether to return a torch tensor. Defaults to False. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull + norm_coords: whether to normalize the coordinates from `[-(size-1)/2, (size-1)/2]` to + `[0, size - 1]` (for ``monai/csrc`` implementation) or + `[-1, 1]` (for torch ``grid_sample`` implementation) to be compatible with the underlying + resampling API. device: device on which the tensor will be allocated. + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) - self.as_tensor_output = as_tensor_output + self.norm_coords = norm_coords self.device = device + self.dtype = dtype def __call__( self, - img: Union[np.ndarray, torch.Tensor], - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + img: NdarrayOrTensor, + grid: Optional[NdarrayOrTensor] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + dtype: DtypeLike = None, + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W[, D]). grid: shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. + if ``norm_coords`` is True, the grid values must be in `[-(size-1)/2, (size-1)/2]`. + if ``USE_COMPILED=True`` and ``norm_coords=False``, grid values must be in `[0, size-1]`. + if ``USE_COMPILED=False`` and ``norm_coords=False``, grid values must be in `[-1, 1]`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - """ + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``self.dtype``. + To be compatible with other modules, the output data type is always `float32`. - if not isinstance(img, torch.Tensor): - img = torch.as_tensor(np.ascontiguousarray(img)) + See also: + :py:const:`monai.config.USE_COMPILED` + """ if grid is None: - raise AssertionError("Error, grid argument must be supplied as an ndarray or tensor ") - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - img = img.to(self.device) - grid = grid.to(self.device) + raise ValueError("Unknown grid.") + _device = img.device if isinstance(img, torch.Tensor) else self.device + _dtype = dtype or self.dtype or img.dtype + img_t, *_ = convert_data_type(img, torch.Tensor, device=_device, dtype=_dtype) + grid_t = convert_to_dst_type(grid, img_t)[0] + if grid_t is grid: # copy if needed (convert_data_type converts to contiguous) + grid_t = grid_t.clone(memory_format=torch.contiguous_format) + sr = min(len(img_t.shape[1:]), 3) if USE_COMPILED: - for i, dim in enumerate(img.shape[1:]): - grid[i] += (dim - 1.0) / 2.0 - grid = grid[:-1] / grid[-1:] - grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) - _padding_mode = look_up_option( - self.padding_mode if padding_mode is None else padding_mode, GridSamplePadMode - ).value - if _padding_mode == "zeros": - bound = 7 - elif _padding_mode == "border": - bound = 0 + if self.norm_coords: + for i, dim in enumerate(img_t.shape[1 : 1 + sr]): + grid_t[i] = (max(dim, 2) / 2.0 - 0.5 + grid_t[i]) / grid_t[-1:] + grid_t = moveaxis(grid_t[:sr], 0, -1) # type: ignore + _padding_mode = self.padding_mode if padding_mode is None else padding_mode + _padding_mode = _padding_mode.value if isinstance(_padding_mode, GridSamplePadMode) else _padding_mode + bound = 1 if _padding_mode == "reflection" else _padding_mode + _interp_mode = self.mode if mode is None else mode + _interp_mode = _interp_mode.value if isinstance(_interp_mode, GridSampleMode) else _interp_mode + if _interp_mode == "bicubic": + interp = 3 + elif _interp_mode == "bilinear": + interp = 1 else: - bound = 1 - _interp_mode = look_up_option(self.mode if mode is None else mode, GridSampleMode).value + interp = _interp_mode # type: ignore out = grid_pull( - img.unsqueeze(0).float(), - grid.unsqueeze(0).float(), - bound=bound, - extrapolate=True, - interpolation=1 if _interp_mode == "bilinear" else _interp_mode, + img_t.unsqueeze(0), grid_t.unsqueeze(0), bound=bound, extrapolate=True, interpolation=interp )[0] else: - for i, dim in enumerate(img.shape[1:]): - grid[i] = 2.0 * grid[i] / (dim - 1.0) - grid = grid[:-1] / grid[-1:] - index_ordering: List[int] = list(range(img.ndimension() - 2, -1, -1)) - grid = grid[index_ordering] - grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) + if self.norm_coords: + for i, dim in enumerate(img_t.shape[1 : 1 + sr]): + grid_t[i] = 2.0 / (max(2, dim) - 1.0) * grid_t[i] / grid_t[-1:] + index_ordering: List[int] = list(range(sr - 1, -1, -1)) + grid_t = moveaxis(grid_t[index_ordering], 0, -1) # type: ignore out = torch.nn.functional.grid_sample( - img.unsqueeze(0).float(), - grid.unsqueeze(0).float(), + img_t.unsqueeze(0), + grid_t.unsqueeze(0), mode=self.mode.value if mode is None else GridSampleMode(mode).value, padding_mode=self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value, align_corners=True, )[0] - if self.as_tensor_output: - return torch.as_tensor(out) - return np.asarray(out.cpu().numpy()) + out_val, *_ = convert_to_dst_type(out, dst=img, dtype=np.float32) + return out_val class Affine(Transform): @@ -1318,17 +1684,23 @@ class Affine(Transform): """ + backend = list(set(AffineGrid.backend) & set(Resample.backend)) + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_params: Optional[Union[Sequence[float], float]] = None, shear_params: Optional[Union[Sequence[float], float]] = None, translate_params: Optional[Union[Sequence[float], float]] = None, scale_params: Optional[Union[Sequence[float], float]] = None, + affine: Optional[NdarrayOrTensor] = None, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, + norm_coords: bool = True, + as_tensor_output: bool = True, device: Optional[torch.device] = None, + dtype: DtypeLike = np.float32, image_only: bool = False, ) -> None: """ @@ -1351,6 +1723,9 @@ def __init__( pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. + affine: If applied, ignore the params (`rotate_params`, etc.) and use the + supplied matrix. Should be square with each side = num of image spatial + dimensions + 1. spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. @@ -1359,36 +1734,50 @@ def __init__( to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + norm_coords: whether to normalize the coordinates from `[-(size-1)/2, (size-1)/2]` to + `[0, size - 1]` or `[-1, 1]` to be compatible with the underlying resampling API. + If the coordinates are generated by ``monai.transforms.utils.create_grid`` + and the ``affine`` doesn't include the normalization, this argument should be set to ``True``. + If the output `self.affine_grid` is already normalized, this argument should be set to ``False``. device: device on which the tensor will be allocated. + dtype: data type for resampling computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. image_only: if True return only the image volume, otherwise return (image, affine). + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ self.affine_grid = AffineGrid( rotate_params=rotate_params, shear_params=shear_params, translate_params=translate_params, scale_params=scale_params, - as_tensor_output=True, + affine=affine, + dtype=dtype, device=device, ) self.image_only = image_only - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(norm_coords=norm_coords, device=device, dtype=dtype) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ): + ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1399,10 +1788,13 @@ def __call__( if `img` has three spatial dimensions, `spatial_size` should have 3 elements [h, w, d]. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) grid, affine = self.affine_grid(spatial_size=sp_size) @@ -1418,6 +1810,9 @@ class RandAffine(RandomizableTransform): """ + backend = Affine.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, prob: float = 0.1, @@ -1466,20 +1861,22 @@ def __init__( to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ RandomizableTransform.__init__(self, prob) @@ -1488,10 +1885,9 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.cache_grid = cache_grid @@ -1519,7 +1915,7 @@ def _init_identity_cache(self): f"'spatial_size={self.spatial_size}', please specify 'spatial_size'." ) return None - return torch.tensor(create_grid(spatial_size=_sp_size)).to(self.rand_affine_grid.device) + return create_grid(spatial_size=_sp_size, device=self.rand_affine_grid.device, backend="torch") def get_identity_grid(self, spatial_size: Sequence[int]): """ @@ -1533,7 +1929,11 @@ def get_identity_grid(self, spatial_size: Sequence[int]): spatial_size, [2] * ndim ): raise RuntimeError(f"spatial_size should not be dynamic, got {spatial_size}.") - return create_grid(spatial_size=spatial_size) if self._cached_grid is None else self._cached_grid + return ( + create_grid(spatial_size=spatial_size, device=self.rand_affine_grid.device, backend="torch") + if self._cached_grid is None + else self._cached_grid + ) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -1544,15 +1944,18 @@ def set_random_state( def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None self.rand_affine_grid.randomize() def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + randomize: bool = True, + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1563,25 +1966,29 @@ def __call__( if `img` has three spatial dimensions, `spatial_size` should have 3 elements [h, w, d]. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + randomize: whether to execute `randomize()` function first, default to True. + """ - self.randomize() + if randomize: + self.randomize() + # if not doing transform and spatial size doesn't change, nothing to do - # except convert to float and convert numpy/torch + # except convert to float and device sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) do_resampling = self._do_transform or (sp_size != ensure_tuple(img.shape[1:])) if not do_resampling: - img = img.float() if isinstance(img, torch.Tensor) else img.astype("float32") - return torch.Tensor(img) if self.resampler.as_tensor_output else np.array(img) + img, *_ = convert_data_type(img, dtype=torch.float32, device=self.resampler.device) grid = self.get_identity_grid(sp_size) if self._do_transform: grid = self.rand_affine_grid(grid=grid) - return self.resampler( + out: NdarrayOrTensor = self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) + return out class Rand2DElastic(RandomizableTransform): @@ -1591,6 +1998,9 @@ class Rand2DElastic(RandomizableTransform): """ + backend = Resample.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, spacing: Union[Tuple[float, float], float], @@ -1641,17 +2051,19 @@ def __init__( to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html device: device on which the tensor will be allocated. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ RandomizableTransform.__init__(self, prob) self.deform_grid = RandDeformGrid( @@ -1662,11 +2074,11 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) + self.device = device self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) @@ -1681,16 +2093,19 @@ def set_random_state( def randomize(self, spatial_size: Sequence[int]) -> None: super().randomize(None) + if not self._do_transform: + return None self.deform_grid.randomize(spatial_size) self.rand_affine_grid.randomize() def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Tuple[int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + randomize: bool = True, + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W), @@ -1699,27 +2114,34 @@ def __call__( the transform will use the spatial size of `img`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + randomize: whether to execute `randomize()` function first, default to True. """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - self.randomize(spatial_size=sp_size) + if randomize: + self.randomize(spatial_size=sp_size) + if self._do_transform: grid = self.deform_grid(spatial_size=sp_size) grid = self.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, - input=torch.as_tensor(grid).unsqueeze(0), + input=grid.unsqueeze(0), scale_factor=list(ensure_tuple(self.deform_grid.spacing)), mode=InterpolateMode.BICUBIC.value, align_corners=False, ) grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: - grid = create_grid(spatial_size=sp_size) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) + _device = img.device if isinstance(img, torch.Tensor) else self.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") + out: NdarrayOrTensor = self.resampler( + img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + ) + return out class Rand3DElastic(RandomizableTransform): @@ -1729,6 +2151,9 @@ class Rand3DElastic(RandomizableTransform): """ + backend = Resample.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, sigma_range: Tuple[float, float], @@ -1782,21 +2207,29 @@ def __init__( to `(32, 32, 64)` if the third spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html device: device on which the tensor will be allocated. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ RandomizableTransform.__init__(self, prob) - self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.rand_affine_grid = RandAffineGrid( + rotate_range=rotate_range, + shear_range=shear_range, + translate_range=translate_range, + scale_range=scale_range, + device=device, + ) + self.resampler = Resample(device=device) self.sigma_range = sigma_range self.magnitude_range = magnitude_range @@ -1818,19 +2251,21 @@ def set_random_state( def randomize(self, grid_size: Sequence[int]) -> None: super().randomize(None) - if self._do_transform: - self.rand_offset = self.R.uniform(-1.0, 1.0, [3] + list(grid_size)).astype(np.float32) + if not self._do_transform: + return None + self.rand_offset = self.R.uniform(-1.0, 1.0, [3] + list(grid_size)).astype(np.float32, copy=False) self.magnitude = self.R.uniform(self.magnitude_range[0], self.magnitude_range[1]) self.sigma = self.R.uniform(self.sigma_range[0], self.sigma_range[1]) self.rand_affine_grid.randomize() def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + randomize: bool = True, + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W, D), @@ -1839,63 +2274,189 @@ def __call__( the transform will use the spatial size of `img`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + randomize: whether to execute `randomize()` function first, default to True. """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - self.randomize(grid_size=sp_size) - grid = create_grid(spatial_size=sp_size) + if randomize: + self.randomize(grid_size=sp_size) + + _device = img.device if isinstance(img, torch.Tensor) else self.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") if self._do_transform: if self.rand_offset is None: - raise AssertionError - grid = torch.as_tensor(np.ascontiguousarray(grid), device=self.device) - gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=self.device) - offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) + raise RuntimeError("rand_offset is not initialized.") + gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=_device) + offset = torch.as_tensor(self.rand_offset, device=_device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude grid = self.rand_affine_grid(grid=grid) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) + out: NdarrayOrTensor = self.resampler( + img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + ) + return out -class AddCoordinateChannels(Transform): - """ - Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling, - to allow feeding of the patch's location into the network. +class GridDistortion(Transform): - This can be seen as a input-only version of CoordConv: + backend = [TransformBackends.TORCH] + + def __init__( + self, + num_cells: Union[Tuple[int], int], + distort_steps: Sequence[Sequence[float]], + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, + ) -> None: + """ + Grid distortion transform. Refer to: + https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py + + Args: + num_cells: number of grid cells on each dimension. + distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the + corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`. + Each value in the tuple represents the distort step of the related cell. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + device: device on which the tensor will be allocated. + + """ + self.resampler = Resample(mode=mode, padding_mode=padding_mode, device=device) + self.num_cells = num_cells + self.distort_steps = distort_steps + self.device = device + + def __call__( + self, + img: NdarrayOrTensor, + distort_steps: Optional[Sequence[Sequence]] = None, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + ) -> NdarrayOrTensor: + """ + Args: + img: shape must be (num_channels, H, W[, D]). + distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the + corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`. + Each value in the tuple represents the distort step of the related cell. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + + """ + distort_steps = self.distort_steps if distort_steps is None else distort_steps + if len(img.shape) != len(distort_steps) + 1: + raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`") + + all_ranges = [] + num_cells = ensure_tuple_rep(self.num_cells, len(img.shape) - 1) + for dim_idx, dim_size in enumerate(img.shape[1:]): + dim_distort_steps = distort_steps[dim_idx] + ranges = torch.zeros(dim_size, dtype=torch.float32) + cell_size = dim_size // num_cells[dim_idx] + prev = 0 + for idx in range(num_cells[dim_idx] + 1): + start = int(idx * cell_size) + end = start + cell_size + if end > dim_size: + end = dim_size + cur = dim_size + else: + cur = prev + cell_size * dim_distort_steps[idx] + ranges[start:end] = torch.linspace(prev, cur, end - start) + prev = cur + ranges = ranges - (dim_size - 1.0) / 2.0 + all_ranges.append(ranges) + + coords = meshgrid_ij(*all_ranges) + grid = torch.stack([*coords, torch.ones_like(coords[0])]) + + return self.resampler(img, grid=grid, mode=mode, padding_mode=padding_mode) # type: ignore - Liu, R. et al. An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution, NeurIPS 2018. - """ + +class RandGridDistortion(RandomizableTransform): + + backend = [TransformBackends.TORCH] def __init__( self, - spatial_channels: Sequence[int], + num_cells: Union[Tuple[int], int] = 5, + prob: float = 0.1, + distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, ) -> None: """ + Random grid distortion transform. Refer to: + https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py + Args: - spatial_channels: the spatial dimensions that are to have their coordinates encoded in a channel and - appended to the input. E.g., `(1,2,3)` will append three channels to the input, encoding the - coordinates of the input's three spatial dimensions (0 is reserved for the channel dimension). + num_cells: number of grid cells on each dimension. + prob: probability of returning a randomized grid distortion transform. Defaults to 0.1. + distort_limit: range to randomly distort. + If single number, distort_limit is picked from (-distort_limit, distort_limit). + Defaults to (-0.03, 0.03). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + device: device on which the tensor will be allocated. + """ - self.spatial_channels = spatial_channels + RandomizableTransform.__init__(self, prob) + self.num_cells = num_cells + if isinstance(distort_limit, (int, float)): + self.distort_limit = (min(-distort_limit, distort_limit), max(-distort_limit, distort_limit)) + else: + self.distort_limit = (min(distort_limit), max(distort_limit)) + self.distort_steps: Sequence[Sequence[float]] = ((1.0,),) + self.grid_distortion = GridDistortion( + num_cells=num_cells, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode, device=device + ) - def __call__(self, img: Union[np.ndarray, torch.Tensor]): + def randomize(self, spatial_shape: Sequence[int]) -> None: + super().randomize(None) + if not self._do_transform: + return + self.distort_steps = tuple( + tuple(1.0 + self.R.uniform(low=self.distort_limit[0], high=self.distort_limit[1], size=n_cells + 1)) + for n_cells in ensure_tuple_rep(self.num_cells, len(spatial_shape)) + ) + + def __call__( + self, + img: NdarrayOrTensor, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + randomize: bool = True, + ) -> NdarrayOrTensor: """ Args: - img: data to be transformed, assuming `img` is channel first. + img: shape must be (num_channels, H, W[, D]). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + randomize: whether to shuffle the random factors using `randomize()`, default to True. """ - if max(self.spatial_channels) > img.ndim - 1: - raise ValueError( - f"input has {img.ndim-1} spatial dimensions, cannot add AddCoordinateChannels channel for " - f"dim {max(self.spatial_channels)}." - ) - if 0 in self.spatial_channels: - raise ValueError("cannot add AddCoordinateChannels channel for dimension 0, as 0 is channel dim.") - - spatial_dims = img.shape[1:] - coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_dims), indexing="ij")) - # only keep required dimensions. need to subtract 1 since im will be 0-based - # but user input is 1-based (because channel dim is 0) - coord_channels = coord_channels[[s - 1 for s in self.spatial_channels]] - return np.concatenate((img, coord_channels), axis=0) + if randomize: + self.randomize(img.shape[1:]) + if not self._do_transform: + return img + return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index b0558a6556..a6d6eba27f 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,30 +17,38 @@ from copy import deepcopy from enum import Enum -from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor +from monai.data.utils import affine_to_spacing from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import ( - AddCoordinateChannels, Affine, AffineGrid, Flip, + GridDistortion, Orientation, Rand2DElastic, Rand3DElastic, RandAffine, + RandAxisFlip, + RandFlip, + RandGridDistortion, + RandRotate, + RandZoom, + ResampleToMatch, Resize, Rotate, Rotate90, Spacing, + SpatialResample, Zoom, ) from monai.transforms.transform import MapTransform, RandomizableTransform @@ -50,16 +58,21 @@ GridSamplePadMode, InterpolateMode, NumpyPadMode, + PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, ) -from monai.utils.enums import InverseKeys +from monai.utils.deprecate_utils import deprecated_arg +from monai.utils.enums import PostFix, TraceKeys from monai.utils.module import optional_import +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type nib, _ = optional_import("nibabel") __all__ = [ + "SpatialResampled", + "ResampleToMatchd", "Spacingd", "Orientationd", "Rotate90d", @@ -71,11 +84,15 @@ "Rand3DElasticd", "Flipd", "RandFlipd", + "GridDistortiond", + "RandGridDistortiond", "RandAxisFlipd", "Rotated", "RandRotated", "Zoomd", "RandZoomd", + "SpatialResampleD", + "SpatialResampleDict", "SpacingD", "SpacingDict", "OrientationD", @@ -98,6 +115,10 @@ "FlipDict", "RandFlipD", "RandFlipDict", + "GridDistortionD", + "GridDistortionDict", + "RandGridDistortionD", + "RandGridDistortionDict", "RandAxisFlipD", "RandAxisFlipDict", "RotateD", @@ -108,14 +129,280 @@ "ZoomDict", "RandZoomD", "RandZoomDict", - "AddCoordinateChannelsD", - "AddCoordinateChannelsDict", ] GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] InterpolateModeSequence = Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] -NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] +PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str] +DEFAULT_POST_FIX = PostFix.meta() + + +class SpatialResampled(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.SpatialResample`. + + This transform assumes the ``data`` dictionary has a key for the input + data's metadata and contains ``src_affine`` and ``dst_affine`` required by + `SpatialResample`. The key is formed by ``key_{meta_key_postfix}``. The + transform will swap ``src_affine`` and ``dst_affine`` affine (with potential data type + changes) in the dictionary so that ``src_affine`` always refers to the current + status of affine. + + See also: + :py:class:`monai.transforms.SpatialResample` + """ + + backend = SpatialResample.backend + + def __init__( + self, + keys: KeysCollection, + mode: GridSampleModeSequence = GridSampleMode.BILINEAR, + padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, + align_corners: Union[Sequence[bool], bool] = False, + dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + meta_keys: Optional[KeysCollection] = None, + meta_key_postfix: str = DEFAULT_POST_FIX, + meta_src_keys: Optional[KeysCollection] = "src_affine", + meta_dst_keys: Optional[KeysCollection] = "dst_affine", + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of bool, each element corresponds to a key in ``keys``. + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + It also can be a sequence of dtypes, each element corresponds to a key in ``keys``. + meta_keys: explicitly indicate the key of the corresponding meta data dictionary. + for example, for data with key `image`, the metadata by default is in `image_meta_dict`. + the meta data is a dictionary object which contains: filename, affine, original_shape, etc. + it can be a sequence of string, map to the `keys`. + if None, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if meta_keys=None, use `key_{postfix}` to fetch the meta data according + to the key data, default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + meta_src_keys: the key of the corresponding ``src_affine`` in the metadata dictionary. + meta_dst_keys: the key of the corresponding ``dst_affine`` in the metadata dictionary. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.sp_transform = SpatialResample() + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.dtype = ensure_tuple_rep(dtype, len(self.keys)) + self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) + if len(self.keys) != len(self.meta_keys): + raise ValueError("meta_keys should have the same length as keys.") + self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) + self.meta_src_keys = ensure_tuple_rep(meta_src_keys, len(self.keys)) + self.meta_dst_keys = ensure_tuple_rep(meta_dst_keys, len(self.keys)) + + def __call__( + self, data: Mapping[Union[Hashable, str], Dict[str, NdarrayOrTensor]] + ) -> Dict[Hashable, NdarrayOrTensor]: + d: Dict = dict(data) + for (key, mode, padding_mode, align_corners, dtype, *metakeyinfo) in self.key_iterator( + d, + self.mode, + self.padding_mode, + self.align_corners, + self.dtype, + self.meta_keys, + self.meta_key_postfix, + self.meta_src_keys, + self.meta_dst_keys, + ): + meta_key, meta_key_postfix, meta_src_key, meta_dst_key = metakeyinfo + meta_key = meta_key or f"{key}_{meta_key_postfix}" + # create metadata if necessary + if meta_key not in d: + d[meta_key] = {meta_src_key: None, meta_dst_key: None} + meta_data = d[meta_key] + original_spatial_shape = d[key].shape[1:] + d[key], meta_data[meta_dst_key] = self.sp_transform( # write dst affine because the dtype might change + img=d[key], + src_affine=meta_data[meta_src_key], + dst_affine=meta_data[meta_dst_key], + spatial_size=None, # None means shape auto inferred + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, + ) + meta_data[meta_dst_key], meta_data[meta_src_key] = meta_data[meta_src_key], meta_data[meta_dst_key] + self.push_transform( + d, + key, + extra_info={ + "meta_key": meta_key, + "meta_src_key": meta_src_key, + "meta_dst_key": meta_dst_key, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + }, + orig_size=original_spatial_shape, + ) + return d + + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = deepcopy(dict(data)) + for key, dtype in self.key_iterator(d, self.dtype): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + meta_data = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] + src_key = transform[TraceKeys.EXTRA_INFO]["meta_src_key"] + dst_key = transform[TraceKeys.EXTRA_INFO]["meta_dst_key"] + src_affine = meta_data[src_key] + dst_affine = meta_data[dst_key] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + orig_size = transform[TraceKeys.ORIG_SIZE] + inverse_transform = SpatialResample() + # Apply inverse + d[key], dst_affine = inverse_transform( + img=d[key], + src_affine=src_affine, + dst_affine=dst_affine, + mode=mode, + padding_mode=padding_mode, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, + dtype=dtype, + spatial_size=orig_size, + ) + meta_data[src_key], meta_data[dst_key] = dst_affine, meta_data[src_key] # type: ignore + # Remove the applied transform + self.pop_transform(d, key) + return d + + +class ResampleToMatchd(MapTransform, InvertibleTransform): + """Dictionary-based wrapper of :py:class:`monai.transforms.ResampleToMatch`.""" + + backend = ResampleToMatch.backend + + def __init__( + self, + keys: KeysCollection, + template_key: str, + mode: GridSampleModeSequence = GridSampleMode.BILINEAR, + padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, + align_corners: Union[Sequence[bool], bool] = False, + dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + allow_missing_keys: bool = False, + ): + """ + Args: + keys: keys of the corresponding items to be transformed. + template_key: key to meta data that output should be resampled to match. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of bool, each element corresponds to a key in ``keys``. + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + It also can be a sequence of dtypes, each element corresponds to a key in ``keys``. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.template_key = template_key + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.dtype = ensure_tuple_rep(dtype, len(self.keys)) + self.resampler = ResampleToMatch() + + def __call__(self, data): + d = deepcopy(dict(data)) + dst_meta = d[self.template_key] + for (key, mode, padding_mode, align_corners, dtype) in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): + src_meta_key = PostFix.meta(key) + src_meta = d[src_meta_key] + + orig_spatial_shape = d[key].shape[1:] + orig_meta = deepcopy(src_meta) + + img, new_meta = self.resampler( + img=d[key], + src_meta=src_meta, + dst_meta=dst_meta, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, + ) + d[key] = img + d[src_meta_key] = new_meta + + # track the transform for the inverse + self.push_transform( + d, + key, + extra_info={ + "orig_meta": orig_meta, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + }, + orig_size=orig_spatial_shape, + ) + + return d + + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = deepcopy(dict(data)) + for key, dtype in self.key_iterator(d, self.dtype): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_meta = transform[TraceKeys.EXTRA_INFO]["orig_meta"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + + src_meta_key = PostFix.meta(key) + src_meta = d[src_meta_key] + + img, new_meta = self.resampler( + img=d[key], + src_meta=src_meta, # type: ignore + dst_meta=orig_meta, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, + ) + d[key] = img + d[src_meta_key] = new_meta + + # Remove the applied transform + self.pop_transform(d, key) + return d class Spacingd(MapTransform, InvertibleTransform): @@ -132,6 +419,8 @@ class Spacingd(MapTransform, InvertibleTransform): :py:class:`monai.transforms.Spacing` """ + backend = Spacing.backend + def __init__( self, keys: KeysCollection, @@ -140,9 +429,9 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Optional[Union[Sequence[DtypeLike], DtypeLike]] = np.float64, + dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, ) -> None: """ @@ -168,14 +457,14 @@ def __init__( axes against the original ones. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of bool, each element corresponds to a key in ``keys``. dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, @@ -186,7 +475,7 @@ def __init__( the meta data is a dictionary object which contains: filename, affine, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys=None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys=None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -208,8 +497,8 @@ def __init__( self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) def __call__( - self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] - ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: + self, data: Mapping[Union[Hashable, str], Dict[str, NdarrayOrTensor]] + ) -> Dict[Hashable, NdarrayOrTensor]: d: Dict = dict(data) for key, mode, padding_mode, align_corners, dtype, meta_key, meta_key_postfix in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.meta_keys, self.meta_key_postfix @@ -223,7 +512,7 @@ def __call__( # using affine fetched from d[affine_key] original_spatial_shape = d[key].shape[1:] d[key], old_affine, new_affine = self.spacing_transform( - data_array=np.asarray(d[key]), + data_array=d[key], affine=meta_data["affine"], mode=mode, padding_mode=padding_mode, @@ -238,7 +527,7 @@ def __call__( "old_affine": old_affine, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, orig_size=original_spatial_shape, ) @@ -246,7 +535,7 @@ def __call__( meta_data["affine"] = new_affine return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -256,25 +545,25 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar + "Please raise a github issue if you need this feature" ) # Create inverse transform - meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_key"]] - old_affine = np.array(transform[InverseKeys.EXTRA_INFO]["old_affine"]) - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] - orig_size = transform[InverseKeys.ORIG_SIZE] - orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] + meta_data = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] + old_affine = np.array(transform[TraceKeys.EXTRA_INFO]["old_affine"]) + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + orig_size = transform[TraceKeys.ORIG_SIZE] + orig_pixdim = affine_to_spacing(old_affine, -1) inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse d[key], _, new_affine = inverse_transform( - data_array=np.asarray(d[key]), - affine=meta_data["affine"], + data_array=d[key], + affine=meta_data["affine"], # type: ignore mode=mode, padding_mode=padding_mode, - align_corners=False if align_corners == "none" else align_corners, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, dtype=dtype, output_spatial_shape=orig_size, ) - meta_data["affine"] = new_affine + meta_data["affine"] = new_affine # type: ignore # Remove the applied transform self.pop_transform(d, key) @@ -290,8 +579,14 @@ class Orientationd(MapTransform, InvertibleTransform): After reorienting the input array, this transform will write the new affine to the `affine` field of metadata which is formed by ``key_{meta_key_postfix}``. + + This transform assumes the channel-first input format. + In the case of using this transform for normalizing the orientations of images, + it should be used before any anisotropic spatial transforms. """ + backend = Orientation.backend + def __init__( self, keys: KeysCollection, @@ -299,7 +594,7 @@ def __init__( as_closest_canonical: bool = False, labels: Optional[Sequence[Tuple[str, str]]] = tuple(zip("LPI", "RAS")), meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, ) -> None: """ @@ -318,7 +613,7 @@ def __init__( the meta data is a dictionary object which contains: filename, affine, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -341,8 +636,8 @@ def __init__( self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) def __call__( - self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] - ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: + self, data: Mapping[Union[Hashable, str], Dict[str, NdarrayOrTensor]] + ) -> Dict[Hashable, NdarrayOrTensor]: d: Dict = dict(data) for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" @@ -355,18 +650,16 @@ def __call__( d[meta_key]["affine"] = new_affine return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_key"]] - orig_affine = transform[InverseKeys.EXTRA_INFO]["old_affine"] + meta_data: Dict = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] # type: ignore + orig_affine = transform[TraceKeys.EXTRA_INFO]["old_affine"] orig_axcodes = nib.orientations.aff2axcodes(orig_affine) inverse_transform = Orientation( - axcodes=orig_axcodes, - as_closest_canonical=False, - labels=self.ornt_transform.labels, + axcodes=orig_axcodes, as_closest_canonical=False, labels=self.ornt_transform.labels ) # Apply inverse d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) @@ -382,6 +675,8 @@ class Rotate90d(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ + backend = Rotate90.backend + def __init__( self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1), allow_missing_keys: bool = False ) -> None: @@ -395,14 +690,14 @@ def __init__( super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.rotator(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) @@ -411,9 +706,6 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar num_times_rotated = self.rotator.k num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -429,6 +721,8 @@ class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): in the plane specified by `spatial_axes`. """ + backend = Rotate90.backend + def __init__( self, keys: KeysCollection, @@ -461,10 +755,12 @@ def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: self.randomize() d = dict(data) + # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need + # to be compatible with the random status of some previous integration tests rotator = Rotate90(self._rand_k, self.spatial_axes) for key in self.key_iterator(d): if self._do_transform: @@ -472,19 +768,16 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Create inverse transform - num_times_rotated = transform[InverseKeys.EXTRA_INFO]["rand_k"] + num_times_rotated = transform[TraceKeys.EXTRA_INFO]["rand_k"] num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -511,15 +804,17 @@ class Resized(MapTransform, InvertibleTransform): #albumentations.augmentations.geometric.resize.LongestMaxSize. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of string, each element corresponds to a key in ``keys``. align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. """ + backend = Resize.backend + def __init__( self, keys: KeysCollection, @@ -534,7 +829,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): self.push_transform( @@ -542,24 +837,24 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda key, extra_info={ "mode": mode.value if isinstance(mode, Enum) else mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InverseKeys.ORIG_SIZE] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + orig_size = transform[TraceKeys.ORIG_SIZE] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] # Create inverse transform inverse_transform = Resize( spatial_size=orig_size, mode=mode, - align_corners=None if align_corners == "none" else align_corners, + align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Apply inverse transform d[key] = inverse_transform(d[key]) @@ -574,6 +869,9 @@ class Affined(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Affine`. """ + backend = Affine.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -581,11 +879,13 @@ def __init__( shear_params: Optional[Union[Sequence[float], float]] = None, translate_params: Optional[Union[Sequence[float], float]] = None, scale_params: Optional[Union[Sequence[float], float]] = None, + affine: Optional[NdarrayOrTensor] = None, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, + as_tensor_output: bool = True, device: Optional[torch.device] = None, + dtype: Union[DtypeLike, torch.dtype] = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -607,6 +907,9 @@ def __init__( pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. + affine: if applied, ignore the params (`rotate_params`, etc.) and use the + supplied matrix. Should be square with each side = num of image spatial + dimensions + 1. spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. @@ -615,20 +918,25 @@ def __init__( to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. + dtype: data type for resampling computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`monai.transforms.compose.MapTransform` - :py:class:`RandAffineGrid` for the random affine parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ MapTransform.__init__(self, keys, allow_missing_keys) self.affine = Affine( @@ -636,16 +944,15 @@ def __init__( shear_params=shear_params, translate_params=translate_params, scale_params=scale_params, + affine=affine, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, + dtype=dtype, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] @@ -662,26 +969,23 @@ def __call__( ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InverseKeys.ORIG_SIZE] + orig_size = transform[TraceKeys.ORIG_SIZE] # Create inverse transform - fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) - grid, _ = affine_grid(orig_size) # type: ignore + grid, _ = affine_grid(orig_size) # Apply inverse transform - out = self.affine.resampler(d[key], grid, mode, padding_mode) - - # Convert to numpy - d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + d[key] = self.affine.resampler(d[key], grid, mode, padding_mode) # Remove the applied transform self.pop_transform(d, key) @@ -694,6 +998,9 @@ class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. """ + backend = RandAffine.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -745,23 +1052,25 @@ def __init__( This allows 0 to correspond to no change (i.e., a scaling of 1.0). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`monai.transforms.compose.MapTransform` - :py:class:`RandAffineGrid` for the random affine parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) @@ -773,7 +1082,6 @@ def __init__( scale_range=scale_range, spatial_size=spatial_size, cache_grid=cache_grid, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -786,28 +1094,29 @@ def set_random_state( super().set_random_state(seed, state) return self - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self.rand_affine.randomize() - - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d - sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) - # change image size or do random transform - do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:])) + self.randomize(None) + # all the keys share the same random Affine factor + self.rand_affine.randomize() - # to be consistent with the self._do_transform case (dtype and device) - affine = torch.as_tensor(np.eye(len(sp_size) + 1), device=self.rand_affine.rand_affine_grid.device) + device = self.rand_affine.resampler.device + spatial_size = d[first_key].shape[1:] # type: ignore + sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) + # change image size or do random transform + do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size)) + affine: torch.Tensor = torch.eye(len(sp_size) + 1, dtype=torch.float64, device=device) + # converting affine to tensor because the resampler currently only support torch backend grid = None if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) if self._do_transform: # add some random factors grid = self.rand_affine.rand_affine_grid(grid=grid) - affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() # type: ignore[assignment] + affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): self.push_transform( @@ -822,39 +1131,28 @@ def __call__( # do the transform if do_resampling: d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) - # if not doing transform and and spatial size is unchanged, only need to do numpy/torch conversion - else: - if self.rand_affine.resampler.as_tensor_output and not isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]) - elif not self.rand_affine.resampler.as_tensor_output and isinstance(d[key], torch.Tensor): - d[key] = d[key].detach().cpu().numpy() # type: ignore[union-attr] return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # if transform was not performed and spatial size is None, nothing to do. - if not transform[InverseKeys.DO_TRANSFORM] and self.rand_affine.spatial_size is None: - out: Union[np.ndarray, torch.Tensor] = d[key] - else: - orig_size = transform[InverseKeys.ORIG_SIZE] + if transform[TraceKeys.DO_TRANSFORM] or self.rand_affine.spatial_size is not None: + orig_size = transform[TraceKeys.ORIG_SIZE] # Create inverse transform - fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) - grid, _ = affine_grid(orig_size) # type: ignore + grid, _ = affine_grid(orig_size) # Apply inverse transform - out = self.rand_affine.resampler(d[key], grid, mode, padding_mode) - - # Convert to numpy - d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + d[key] = self.rand_affine.resampler(d[key], grid, mode, padding_mode) # Remove the applied transform self.pop_transform(d, key) @@ -867,6 +1165,9 @@ class Rand2DElasticd(RandomizableTransform, MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. """ + backend = Rand2DElastic.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -921,20 +1222,22 @@ def __init__( This allows 0 to correspond to no change (i.e., a scaling of 1.0). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) @@ -947,7 +1250,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -960,17 +1262,17 @@ def set_random_state( super().set_random_state(seed, state) return self - def randomize(self, spatial_size: Sequence[int]) -> None: - super().randomize(None) - self.rand_2d_elastic.randomize(spatial_size) - - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d - sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) - self.randomize(spatial_size=sp_size) + self.randomize(None) + + sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first_key].shape[1:]) # type: ignore + # all the keys share the same random elastic factor + self.rand_2d_elastic.randomize(sp_size) if self._do_transform: grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) @@ -984,7 +1286,8 @@ def __call__( ) grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: - grid = create_grid(spatial_size=sp_size) + _device = self.rand_2d_elastic.deform_grid.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) @@ -996,6 +1299,9 @@ class Rand3DElasticd(RandomizableTransform, MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`. """ + backend = Rand3DElastic.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -1052,20 +1358,22 @@ def __init__( This allows 0 to correspond to no change (i.e., a scaling of 1.0). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) @@ -1078,7 +1386,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -1091,23 +1398,24 @@ def set_random_state( super().set_random_state(seed, state) return self - def randomize(self, grid_size: Sequence[int]) -> None: - super().randomize(None) - self.rand_3d_elastic.randomize(grid_size) - - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.randomize(None) + + sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first_key].shape[1:]) # type: ignore + # all the keys share the same random elastic factor + self.rand_3d_elastic.randomize(sp_size) - self.randomize(grid_size=sp_size) - grid = create_grid(spatial_size=sp_size) + _device = self.rand_3d_elastic.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") if self._do_transform: device = self.rand_3d_elastic.device - grid = torch.tensor(grid).to(device) gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) - offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) + offset = torch.as_tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) @@ -1173,7 +1481,7 @@ class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ - backend = Flip.backend + backend = RandFlip.backend def __init__( self, @@ -1184,16 +1492,22 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.spatial_axis = spatial_axis + self.flipper = RandFlip(prob=1.0, spatial_axis=spatial_axis) - self.flipper = Flip(spatial_axis=spatial_axis) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandFlipd": + super().set_random_state(seed, state) + self.flipper.set_random_state(seed, state) + return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - self.randomize(None) d = dict(data) + self.randomize(None) + for key in self.key_iterator(d): if self._do_transform: - d[key] = self.flipper(d[key]) + d[key] = self.flipper(d[key], randomize=False) self.push_transform(d, key) return d @@ -1202,9 +1516,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Inverse is same as forward - d[key] = self.flipper(d[key]) + d[key] = self.flipper(d[key], randomize=False) # Remove the applied transform self.pop_transform(d, key) return d @@ -1224,26 +1538,34 @@ class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ - backend = Flip.backend + backend = RandAxisFlip.backend def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self._axis: Optional[int] = None + self.flipper = RandAxisFlip(prob=1.0) - def randomize(self, data: NdarrayOrTensor) -> None: - super().randomize(None) - self._axis = self.R.randint(data.ndim - 1) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandAxisFlipd": + super().set_random_state(seed, state) + self.flipper.set_random_state(seed, state) + return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - self.randomize(data=data[self.keys[0]]) - flipper = Flip(spatial_axis=self._axis) - d = dict(data) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.randomize(None) + + # all the keys share the same random selected axis + self.flipper.randomize(d[first_key]) # type: ignore for key in self.key_iterator(d): if self._do_transform: - d[key] = flipper(d[key]) - self.push_transform(d, key, extra_info={"axis": self._axis}) + d[key] = self.flipper(d[key], randomize=False) + self.push_transform(d, key, extra_info={"axis": self.flipper._axis}) return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: @@ -1251,8 +1573,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: - flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO]["axis"]) + if transform[TraceKeys.DO_TRANSFORM]: + flipper = Flip(spatial_axis=transform[TraceKeys.EXTRA_INFO]["axis"]) # Inverse is same as forward d[key] = flipper(d[key]) # Remove the applied transform @@ -1272,22 +1594,24 @@ class Rotated(MapTransform, InvertibleTransform): If it is True, the output shape is the same as the input. Default is True. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. align_corners: Defaults to False. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of bool, each element corresponds to a key in ``keys``. - dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + dtype: data type for resampling computation. Defaults to ``np.float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. """ + backend = Rotate.backend + def __init__( self, keys: KeysCollection, @@ -1296,7 +1620,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1307,18 +1631,14 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): orig_size = d[key].shape[1:] d[key] = self.rotator( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - dtype=dtype, + d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype ) rot_mat = self.rotator.get_rotation_matrix() self.push_transform( @@ -1329,35 +1649,35 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda "rot_mat": rot_mat, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) # Create inverse transform - fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + fwd_rot_mat = transform[TraceKeys.EXTRA_INFO]["rot_mat"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( normalized=False, mode=mode, padding_mode=padding_mode, - align_corners=False if align_corners == "none" else align_corners, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, reverse_indexing=True, ) - output = xform( - torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), - spatial_size=transform[InverseKeys.ORIG_SIZE], - ) - d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) + + out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[TraceKeys.ORIG_SIZE]).squeeze(0) + out, *_ = convert_to_dst_type(out, dst=d[key], dtype=out.dtype) + d[key] = out # Remove the applied transform self.pop_transform(d, key) @@ -1383,14 +1703,14 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): If it is True, the output shape is the same as the input. Default is True. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. align_corners: Defaults to False. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of bool, each element corresponds to a key in ``keys``. dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, @@ -1399,6 +1719,8 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = RandRotate.backend + def __init__( self, keys: KeysCollection, @@ -1410,99 +1732,84 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.range_x = ensure_tuple(range_x) - if len(self.range_x) == 1: - self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) - self.range_y = ensure_tuple(range_y) - if len(self.range_y) == 1: - self.range_y = tuple(sorted([-self.range_y[0], self.range_y[0]])) - self.range_z = ensure_tuple(range_z) - if len(self.range_z) == 1: - self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) - - self.keep_size = keep_size + self.rand_rotate = RandRotate(range_x=range_x, range_y=range_y, range_z=range_z, prob=1.0, keep_size=keep_size) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self.x = 0.0 - self.y = 0.0 - self.z = 0.0 - - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) - self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) - self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandRotated": + super().set_random_state(seed, state) + self.rand_rotate.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - self.randomize() + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) - rotator = Rotate( - angle=angle, - keep_size=self.keep_size, - ) + self.randomize(None) + + # all the keys share the same random rotate angle + self.rand_rotate.randomize() for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): - orig_size = d[key].shape[1:] if self._do_transform: - d[key] = rotator( + d[key], rot_mat = self.rand_rotate( d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + randomize=False, + get_matrix=True, ) - rot_mat = rotator.get_rotation_matrix() else: rot_mat = np.eye(d[key].ndim) self.push_transform( d, key, - orig_size=orig_size, + orig_size=d[key].shape[1:], extra_info={ "rot_mat": rot_mat, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Create inverse transform - fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + fwd_rot_mat = transform[TraceKeys.EXTRA_INFO]["rot_mat"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( normalized=False, mode=mode, padding_mode=padding_mode, - align_corners=False if align_corners == "none" else align_corners, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, reverse_indexing=True, ) - output = xform( - torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), - spatial_size=transform[InverseKeys.ORIG_SIZE], - ) - d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) + output: torch.Tensor + out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[TraceKeys.ORIG_SIZE]).squeeze(0) + out, *_ = convert_to_dst_type(out, dst=d[key], dtype=out.dtype) + d[key] = out # Remove the applied transform self.pop_transform(d, key) @@ -1520,41 +1827,46 @@ class Zoomd(MapTransform, InvertibleTransform): If a sequence, zoom should contain one value for each spatial axis. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of string, each element corresponds to a key in ``keys``. - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ + backend = Zoom.backend + def __init__( self, keys: KeysCollection, zoom: Union[Sequence[float], float], mode: InterpolateModeSequence = InterpolateMode.AREA, - padding_mode: NumpyPadModeSequence = NumpyPadMode.EDGE, + padding_mode: PadModeSequence = NumpyPadMode.EDGE, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **np_kwargs) + self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners @@ -1565,36 +1877,31 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda extra_info={ "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) - d[key] = self.zoomer( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - ) + d[key] = self.zoomer(d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform zoom = np.array(self.zoomer.zoom) inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.zoomer.keep_size) - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] # Apply inverse d[key] = inverse_transform( d[key], mode=mode, padding_mode=padding_mode, - align_corners=None if align_corners == "none" else align_corners, + align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) + d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # Remove the applied transform self.pop_transform(d, key) @@ -1620,23 +1927,28 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of string, each element corresponds to a key in ``keys``. - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ + backend = RandZoom.backend + def __init__( self, keys: KeysCollection, @@ -1644,121 +1956,201 @@ def __init__( min_zoom: Union[Sequence[float], float] = 0.9, max_zoom: Union[Sequence[float], float] = 1.1, mode: InterpolateModeSequence = InterpolateMode.AREA, - padding_mode: NumpyPadModeSequence = NumpyPadMode.EDGE, + padding_mode: PadModeSequence = NumpyPadMode.EDGE, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.min_zoom = ensure_tuple(min_zoom) - self.max_zoom = ensure_tuple(max_zoom) - if len(self.min_zoom) != len(self.max_zoom): - raise AssertionError("min_zoom and max_zoom must have same length.") - + self.rand_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, **kwargs) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.keep_size = keep_size - self.np_kwargs = np_kwargs - - self._zoom: Sequence[float] = [1.0] - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandZoomd": + super().set_random_state(seed, state) + self.rand_zoom.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - # match the spatial dim of first item - self.randomize() + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.randomize(None) - img_dims = data[self.keys[0]].ndim - if len(self._zoom) == 1: - # to keep the spatial shape ratio, use same random zoom factor for all dims - self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 1) - elif len(self._zoom) == 2 and img_dims > 3: - # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim - self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 2) + ensure_tuple(self._zoom[-1]) - zoomer = Zoom(self._zoom, keep_size=self.keep_size, **self.np_kwargs) + # all the keys share the same random zoom factor + self.rand_zoom.randomize(d[first_key]) # type: ignore for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): + if self._do_transform: + d[key] = self.rand_zoom( + d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, randomize=False + ) self.push_transform( d, key, extra_info={ - "zoom": self._zoom, + "zoom": self.rand_zoom._zoom, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) - if self._do_transform: - d[key] = zoomer( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Create inverse transform - zoom = np.array(transform[InverseKeys.EXTRA_INFO]["zoom"]) - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] - inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.keep_size) + zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"]) + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.rand_zoom.keep_size) # Apply inverse d[key] = inverse_transform( d[key], mode=mode, padding_mode=padding_mode, - align_corners=None if align_corners == "none" else align_corners, + align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) + d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # Remove the applied transform self.pop_transform(d, key) return d -class AddCoordinateChannelsd(MapTransform): +class GridDistortiond(MapTransform): """ - Dictionary-based wrapper of :py:class:`monai.transforms.AddCoordinateChannels`. + Dictionary-based wrapper of :py:class:`monai.transforms.GridDistortion`. """ - def __init__(self, keys: KeysCollection, spatial_channels: Sequence[int], allow_missing_keys: bool = False) -> None: + backend = GridDistortion.backend + + def __init__( + self, + keys: KeysCollection, + num_cells: Union[Tuple[int], int], + distort_steps: List[Tuple], + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` + num_cells: number of grid cells on each dimension. + distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the + corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`. + Each value in the tuple represents the distort step of the related cell. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + It also can be a sequence of string, each element corresponds to a key in ``keys``. + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"reflection"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + It also can be a sequence of string, each element corresponds to a key in ``keys``. + device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. - spatial_channels: the spatial dimensions that are to have their coordinates encoded in a channel and - appended to the input. E.g., `(1,2,3)` will append three channels to the input, encoding the - coordinates of the input's three spatial dimensions. It is assumed dimension 0 is the channel. """ super().__init__(keys, allow_missing_keys) - self.add_coordinate_channels = AddCoordinateChannels(spatial_channels) + self.grid_distortion = GridDistortion(num_cells=num_cells, distort_steps=distort_steps, device=device) + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key in self.key_iterator(d): - d[key] = self.add_coordinate_channels(d[key]) + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + d[key] = self.grid_distortion(d[key], mode=mode, padding_mode=padding_mode) + return d + + +class RandGridDistortiond(RandomizableTransform, MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandGridDistortion`. + """ + + backend = RandGridDistortion.backend + + def __init__( + self, + keys: KeysCollection, + num_cells: Union[Tuple[int], int] = 5, + prob: float = 0.1, + distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + num_cells: number of grid cells on each dimension. + prob: probability of returning a randomized grid distortion transform. Defaults to 0.1. + distort_limit: range to randomly distort. + If single number, distort_limit is picked from (-distort_limit, distort_limit). + Defaults to (-0.03, 0.03). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + It also can be a sequence of string, each element corresponds to a key in ``keys``. + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"reflection"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + It also can be a sequence of string, each element corresponds to a key in ``keys``. + device: device on which the tensor will be allocated. + allow_missing_keys: don't raise exception if key is missing. + + """ + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.rand_grid_distortion = RandGridDistortion( + num_cells=num_cells, prob=1.0, distort_limit=distort_limit, device=device + ) + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGridDistortiond": + super().set_random_state(seed, state) + self.rand_grid_distortion.set_random_state(seed, state) + return self + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if not self._do_transform: + return d + + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.rand_grid_distortion.randomize(d[first_key].shape[1:]) # type: ignore + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False) return d +SpatialResampleD = SpatialResampleDict = SpatialResampled +ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd SpacingD = SpacingDict = Spacingd OrientationD = OrientationDict = Orientationd Rotate90D = Rotate90Dict = Rotate90d @@ -1770,9 +2162,10 @@ def __call__( Rand3DElasticD = Rand3DElasticDict = Rand3DElasticd FlipD = FlipDict = Flipd RandFlipD = RandFlipDict = RandFlipd +GridDistortionD = GridDistortionDict = GridDistortiond +RandGridDistortionD = RandGridDistortionDict = RandGridDistortiond RandAxisFlipD = RandAxisFlipDict = RandAxisFlipd RotateD = RotateDict = Rotated RandRotateD = RandRotateDict = RandRotated ZoomD = ZoomDict = Zoomd RandZoomD = RandZoomDict = RandZoomd -AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index ef49bc706c..8537f7eb89 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,17 +21,10 @@ from monai import transforms from monai.config import KeysCollection -from monai.utils import MAX_SEED, ensure_tuple +from monai.utils import MAX_SEED, ensure_tuple, first from monai.utils.enums import TransformBackends -__all__ = [ - "ThreadUnsafe", - "apply_transform", - "Randomizable", - "RandomizableTransform", - "Transform", - "MapTransform", -] +__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] ReturnType = TypeVar("ReturnType") @@ -47,9 +40,9 @@ def _apply_transform( Otherwise `parameters` is considered as single argument to `transform`. Args: - transform (Callable[..., ReturnType]): a callable to be used to transform `data`. - parameters (Any): parameters for the `transform`. - unpack_parameters (bool, optional): whether to unpack parameters for `transform`. Defaults to False. + transform: a callable to be used to transform `data`. + parameters: parameters for the `transform`. + unpack_parameters: whether to unpack parameters for `transform`. Defaults to False. Returns: ReturnType: The return type of `transform`. @@ -65,6 +58,7 @@ def apply_transform( data: Any, map_items: bool = True, unpack_items: bool = False, + log_stats: bool = False, ) -> Union[List[ReturnType], ReturnType]: """ Transform `data` with `transform`. @@ -74,11 +68,14 @@ def apply_transform( otherwise transform will be applied once with `data` as the argument. Args: - transform (Callable[..., ReturnType]): a callable to be used to transform `data`. - data (Any): an object to be transformed. - map_items (bool, optional): whether to apply transform to each item in `data`, + transform: a callable to be used to transform `data`. + data: an object to be transformed. + map_items: whether to apply transform to each item in `data`, if `data` is a list or tuple. Defaults to True. - unpack_items (bool, optional): [description]. Defaults to False. + unpack_items: whether to unpack parameters using `*`. Defaults to False. + log_stats: whether to log the detailed information of data and applied transform when error happened, + for NumPy array and PyTorch Tensor, log the data shape and value range, + for other meta data, log the values directly. default to `False`. Raises: Exception: When ``transform`` raises an exception. @@ -92,7 +89,7 @@ def apply_transform( return _apply_transform(transform, data, unpack_items) except Exception as e: - if not isinstance(transform, transforms.compose.Compose): + if log_stats and not isinstance(transform, transforms.compose.Compose): # log the input data information of exact transform in the transform chain datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False) logger = logging.getLogger(datastats._logger_name) @@ -103,7 +100,7 @@ def apply_transform( def _log_stats(data, prefix: Optional[str] = "Data"): if isinstance(data, (np.ndarray, torch.Tensor)): # log data type, shape, range for array - datastats(img=data, data_shape=True, value_range=True, prefix=prefix) # type: ignore + datastats(img=data, data_shape=True, value_range=True, prefix=prefix) else: # log data type and value for other meta data datastats(img=data, data_value=True, prefix=prefix) @@ -213,10 +210,10 @@ class Transform(ABC): :py:class:`monai.transforms.Compose` """ + # Transforms should add data types to this list if they are capable of performing a transform without + # modifying the input type. For example, ["torch.Tensor", "np.ndarray"] means that no copies of the data + # are required if the input is either `torch.Tensor` or `np.ndarray`. backend: List[TransformBackends] = [] - """Transforms should add data types to this list if they are capable of performing a transform without - modifying the input type. For example, [\"torch.Tensor\", \"np.ndarray\"] means that no copies of the data - are required if the input is either \"torch.Tensor\" or \"np.ndarray\".""" @abstractmethod def __call__(self, data: Any): @@ -226,17 +223,15 @@ def __call__(self, data: Any): return an updated version of ``data``. To simplify the input validations, most of the transforms assume that - - ``data`` is a Numpy ndarray, PyTorch Tensor or string + - ``data`` is a Numpy ndarray, PyTorch Tensor or string, - the data shape can be: - #. string data without shape, `LoadImage` transform expects file paths - #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, - except that `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and - `AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) - #. most of the post-processing transforms expect - ``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])`` + #. string data without shape, `LoadImage` transform expects file paths, + #. most of the pre-/post-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, + except for example: `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and + `AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels), - - the channel dimension is not omitted even if number of channels is one + - the channel dimension is often not omitted even if number of channels is one. This method can optionally take additional arguments to help execute transformation operation. @@ -333,18 +328,16 @@ def __call__(self, data): To simplify the input validations, this method assumes: - - ``data`` is a Python dictionary + - ``data`` is a Python dictionary, - ``data[key]`` is a Numpy ndarray, PyTorch Tensor or string, where ``key`` is an element of ``self.keys``, the data shape can be: - #. string data without shape, `LoadImaged` transform expects file paths - #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, - except that `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and + #. string data without shape, `LoadImaged` transform expects file paths, + #. most of the pre-/post-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, + except for example: `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and `AsChannelFirstd` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) - #. most of the post-processing transforms expect - ``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])`` - - the channel dimension is not omitted even if number of channels is one + - the channel dimension is often not omitted even if number of channels is one. Raises: NotImplementedError: When the subclass does not override this method. @@ -355,11 +348,7 @@ def __call__(self, data): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def key_iterator( - self, - data: Dict[Hashable, Any], - *extra_iterables: Optional[Iterable], - ) -> Generator: + def key_iterator(self, data: Dict[Hashable, Any], *extra_iterables: Optional[Iterable]) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. @@ -378,4 +367,18 @@ def key_iterator( if key in data: yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: - raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False") + raise KeyError( + f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data" + " and allow_missing_keys==False." + ) + + def first_key(self, data: Dict[Hashable, Any]): + """ + Get the first available key of `self.keys` in the input `data` dictionary. + If no available key, return an empty list `[]`. + + Args: + data: data that the transform will be applied to. + + """ + return first(self.key_iterator(data), []) diff --git a/monai/transforms/utility/__init__.py b/monai/transforms/utility/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/utility/__init__.py +++ b/monai/transforms/utility/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2eb6c447c6..72f50f6894 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -31,22 +31,33 @@ map_binary_to_indices, map_classes_to_indices, ) -from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis -from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import +from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices +from monai.utils import ( + convert_data_type, + convert_to_cupy, + convert_to_numpy, + convert_to_tensor, + deprecated_arg, + ensure_tuple, + look_up_option, + min_version, + optional_import, +) from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least -from monai.utils.type_conversion import convert_data_type +from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") cp, has_cp = optional_import("cupy") -cp_ndarray, _ = optional_import("cupy", name="ndarray") + __all__ = [ "Identity", "AsChannelFirst", "AsChannelLast", "AddChannel", + "AddCoordinateChannels", "EnsureChannelFirst", "EnsureType", "RepeatChannel", @@ -71,6 +82,9 @@ "MapLabelValue", "IntensityStats", "ToDevice", + "CuCIM", + "RandCuCIM", + "ToCupy", ] @@ -188,6 +202,7 @@ def __init__(self, strict_check: bool = True): strict_check: whether to raise an error when the meta information is insufficient. """ self.strict_check = strict_check + self.add_channel = AddChannel() def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: """ @@ -209,7 +224,7 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> warnings.warn(msg) return img if channel_dim == "no_channel": - return AddChannel()(img) + return self.add_channel(img) return AsChannelFirst(channel_dim=channel_dim)(img) @@ -234,8 +249,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. """ - repeeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat - return repeeat_fn(img, self.repeats, 0) # type: ignore + repeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat + return repeat_fn(img, self.repeats, 0) # type: ignore class RemoveRepeatedChannel(Transform): @@ -310,7 +325,7 @@ def __init__(self, dtype=np.float32) -> None: """ self.dtype = dtype - def __call__(self, img: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, dtype: Union[DtypeLike, torch.dtype] = None) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor. @@ -321,87 +336,138 @@ def __call__(self, img: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch. TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``. """ - if not isinstance(img, (torch.Tensor, np.ndarray)): - raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") - img_out, *_ = convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype) - return img_out + return convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype)[0] # type: ignore class ToTensor(Transform): """ Converts the input image to a tensor without applying any other transformations. + Input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. + Will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original. + For dictionary, list or tuple, convert every item to a Tensor if applicable and `wrap_sequence=False`. + + Args: + dtype: target data type to when converting to Tensor. + device: target device to put the converted Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: + def __init__( + self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = True + ) -> None: + super().__init__() + self.dtype = dtype + self.device = device + self.wrap_sequence = wrap_sequence + + def __call__(self, img: NdarrayOrTensor): """ Apply the transform to `img` and make it contiguous. """ - return convert_to_tensor(img, wrap_sequence=True) # type: ignore + return convert_to_tensor(img, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence) class EnsureType(Transform): """ Ensure the input data to be a PyTorch Tensor or numpy array, support: `numpy array`, `PyTorch Tensor`, `float`, `int`, `bool`, `string` and `object` keep the original. - If passing a dictionary, list or tuple, still return dictionary, list or tuple and recursively convert - every item to the expected data type. + If passing a dictionary, list or tuple, still return dictionary, list or tuple will recursively convert + every item to the expected data type if `wrap_sequence=False`. Args: data_type: target data type to convert, should be "tensor" or "numpy". + dtype: target data content type to convert, for example: np.float32, torch.float, etc. + device: for Tensor data type, specify the target device. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, data_type: str = "tensor") -> None: - data_type = data_type.lower() - if data_type not in ("tensor", "numpy"): - raise ValueError("`data type` must be 'tensor' or 'numpy'.") - - self.data_type = data_type + def __init__( + self, + data_type: str = "tensor", + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + device: Optional[torch.device] = None, + wrap_sequence: bool = True, + ) -> None: + self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"}) + self.dtype = dtype + self.device = device + self.wrap_sequence = wrap_sequence - def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, data: NdarrayOrTensor): """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and objects keep the original. for dictionary, list or tuple, ensure every item as expected type - if applicable. + if applicable and `wrap_sequence=False`. """ - return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore + output_type = torch.Tensor if self.data_type == "tensor" else np.ndarray + out: NdarrayOrTensor + out, *_ = convert_data_type( + data=data, output_type=output_type, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence + ) + return out class ToNumpy(Transform): """ Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor. + + Args: + dtype: target data type when converting to numpy array. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, img: NdarrayOrTensor) -> np.ndarray: + def __init__(self, dtype: DtypeLike = None, wrap_sequence: bool = True) -> None: + super().__init__() + self.dtype = dtype + self.wrap_sequence = wrap_sequence + + def __call__(self, img: NdarrayOrTensor): """ Apply the transform to `img` and make it contiguous. """ - return convert_to_numpy(img) # type: ignore + return convert_to_numpy(img, dtype=self.dtype, wrap_sequence=self.wrap_sequence) class ToCupy(Transform): """ Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor. + + Args: + dtype: data type specifier. It is inferred from the input by default. + if not None, must be an argument of `numpy.dtype`, for more details: + https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __init__(self, dtype: Optional[np.dtype] = None, wrap_sequence: bool = True) -> None: + super().__init__() + self.dtype = dtype + self.wrap_sequence = wrap_sequence + + def __call__(self, data: NdarrayOrTensor): """ - Apply the transform to `img` and make it contiguous. + Create a CuPy array from `data` and make it contiguous """ - if isinstance(img, torch.Tensor): - img = img.detach().cpu().numpy() - return cp.ascontiguousarray(cp.asarray(img)) # type: ignore + return convert_to_cupy(data, dtype=self.dtype, wrap_sequence=self.wrap_sequence) class ToPIL(Transform): @@ -481,6 +547,11 @@ class DataStats(Transform): It can be inserted into any place of a transform chain and check results of previous transforms. It support both `numpy.ndarray` and `torch.tensor` as input data, so it can be used in pre-processing and post-processing. + + It gets logger from `logging.getLogger(name)`, we can setup a logger outside first with the same `name`. + If the log level of `logging.RootLogger` is higher than `INFO`, will add a separate `StreamHandler` + log handler with `INFO` level and record to `stdout`. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -493,7 +564,7 @@ def __init__( value_range: bool = True, data_value: bool = False, additional_info: Optional[Callable] = None, - logger_handler: Optional[logging.Handler] = None, + name: str = "DataStats", ) -> None: """ Args: @@ -504,9 +575,7 @@ def __init__( data_value: whether to show the raw value of input data. a typical example is to print some properties of Nifti image: affine, pixdim, etc. additional_info: user can define callable function to extract additional info from input data. - logger_handler: add additional handler to output data: save to file, etc. - add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html - the handler should have a logging level of at least `INFO`. + name: identifier of `logging.logger` to use, defaulting to "DataStats". Raises: TypeError: When ``additional_info`` is not an ``Optional[Callable]``. @@ -522,14 +591,14 @@ def __init__( if additional_info is not None and not callable(additional_info): raise TypeError(f"additional_info must be None or callable but is {type(additional_info).__name__}.") self.additional_info = additional_info - self._logger_name = "DataStats" + self._logger_name = name _logger = logging.getLogger(self._logger_name) _logger.setLevel(logging.INFO) - console = logging.StreamHandler(sys.stdout) # always stdout - console.setLevel(logging.INFO) - _logger.addHandler(console) - if logger_handler is not None: - _logger.addHandler(logger_handler) + if logging.root.getEffectiveLevel() > logging.INFO: + # if the root log level is higher than INFO, set a separate stream handler to record + console = logging.StreamHandler(sys.stdout) + console.setLevel(logging.INFO) + _logger.addHandler(console) def __call__( self, @@ -547,7 +616,7 @@ def __call__( lines = [f"{prefix or self.prefix} statistics:"] if self.data_type if data_type is None else data_type: - lines.append(f"Type: {type(img)}") + lines.append(f"Type: {type(img)} {img.dtype if hasattr(img, 'dtype') else None}") if self.data_shape if data_shape is None else data_shape: lines.append(f"Shape: {img.shape}") if self.value_range if value_range is None else value_range: @@ -697,9 +766,7 @@ class LabelToMask(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( # pytype: disable=annotation-type-mismatch - self, - select_labels: Union[Sequence[int], int], - merge_channels: bool = False, + self, select_labels: Union[Sequence[int], int], merge_channels: bool = False ) -> None: # pytype: disable=annotation-type-mismatch self.select_labels = ensure_tuple(select_labels) self.merge_channels = merge_channels @@ -726,7 +793,7 @@ def __call__( if img.shape[0] > 1: data = img[[*select_labels]] else: - where = np.where if isinstance(img, np.ndarray) else torch.where + where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): data = where(in1d(img, select_labels), True, False).reshape(img.shape) # pre pytorch 1.8.0, need to use 1/0 instead of True/False @@ -759,16 +826,18 @@ class FgBgToIndices(Transform): """ + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None: self.image_threshold = image_threshold self.output_shape = output_shape def __call__( self, - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + image: Optional[NdarrayOrTensor] = None, output_shape: Optional[Sequence[int]] = None, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Args: label: input data to compute foreground and background indices. @@ -781,13 +850,15 @@ def __call__( output_shape = self.output_shape fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) if output_shape is not None: - fg_indices = np.stack([np.unravel_index(i, output_shape) for i in fg_indices]) - bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices]) - + fg_indices = unravel_indices(fg_indices, output_shape) + bg_indices = unravel_indices(bg_indices, output_shape) return fg_indices, bg_indices class ClassesToIndices(Transform): + + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + def __init__( self, num_classes: Optional[int] = None, @@ -814,10 +885,10 @@ def __init__( def __call__( self, - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + image: Optional[NdarrayOrTensor] = None, output_shape: Optional[Sequence[int]] = None, - ) -> List[np.ndarray]: + ) -> List[NdarrayOrTensor]: """ Args: label: input data to compute the indices of every class. @@ -826,11 +897,13 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. """ + if output_shape is None: output_shape = self.output_shape + indices: List[NdarrayOrTensor] indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) if output_shape is not None: - indices = [np.stack([np.unravel_index(i, output_shape) for i in array]) for array in indices] + indices = [unravel_indices(cls_indices, output_shape) for cls_indices in indices] return indices @@ -845,19 +918,17 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): and ET (Enhancing tumor). """ - def __call__(self, img: np.ndarray) -> np.ndarray: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: - img = np.squeeze(img, axis=0) + img = img.squeeze(0) - result = [] - # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC - result.append(np.logical_or(img == 1, img == 4)) + result = [(img == 1) | (img == 4), (img == 1) | (img == 4) | (img == 2), img == 4] # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT - result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2)) # label 4 is ET - result.append(img == 4) - return np.stack(result, axis=0) + return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0) class AddExtremePointsChannel(Randomizable, Transform): @@ -880,22 +951,24 @@ class AddExtremePointsChannel(Randomizable, Transform): ValueError: When label image is not single channel. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, background: int = 0, pert: float = 0.0) -> None: self._background = background self._pert = pert self._points: List[Tuple[int, ...]] = [] - def randomize(self, label: np.ndarray) -> None: + def randomize(self, label: NdarrayOrTensor) -> None: self._points = get_extreme_points(label, rand_state=self.R, background=self._background, pert=self._pert) def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, + img: NdarrayOrTensor, + label: Optional[NdarrayOrTensor] = None, sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0, rescale_min: float = -1.0, rescale_max: float = 1.0, - ): + ) -> NdarrayOrTensor: """ Args: img: the image that we want to add new channel to. @@ -918,8 +991,8 @@ def __call__( points_image = extreme_points_to_image( points=self._points, label=label, sigma=sigma, rescale_min=rescale_min, rescale_max=rescale_max ) - - return np.concatenate([img, points_image], axis=0) + points_image, *_ = convert_to_dst_type(points_image, img) # type: ignore + return concatenate((img, points_image), axis=0) class TorchVision: @@ -930,6 +1003,8 @@ class TorchVision: """ + backend = [TransformBackends.TORCH] + def __init__(self, name: str, *args, **kwargs) -> None: """ Args: @@ -939,16 +1014,20 @@ def __init__(self, name: str, *args, **kwargs) -> None: """ super().__init__() + self.name = name transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) self.trans = transform(*args, **kwargs) - def __call__(self, img: torch.Tensor): + def __call__(self, img: NdarrayOrTensor): """ Args: img: PyTorch Tensor data for the TorchVision transform. """ - return self.trans(img) + img_t, *_ = convert_data_type(img, torch.Tensor) + out = self.trans(img_t) + out, *_ = convert_to_dst_type(src=out, dst=img) + return out class MapLabelValue: @@ -960,6 +1039,8 @@ class MapLabelValue: """ + backend = [TransformBackends.NUMPY] + def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None: """ Args: @@ -975,13 +1056,13 @@ def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeL self.orig_labels = orig_labels self.target_labels = target_labels - self.dtype = dtype + self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray) - def __call__(self, img: np.ndarray): - img = np.asarray(img) - img_flat = img.flatten() + def __call__(self, img: NdarrayOrTensor): + img_np, *_ = convert_data_type(img, np.ndarray) + img_flat = img_np.flatten() try: - out_flat = np.copy(img_flat).astype(self.dtype) + out_flat = np.array(img_flat, dtype=self.dtype) except ValueError: # can't copy unchanged labels as the expected dtype is not supported, must map all the label values out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype) @@ -991,7 +1072,9 @@ def __call__(self, img: np.ndarray): continue np.place(out_flat, img_flat == o, t) - return out_flat.reshape(img.shape) + reshaped = out_flat.reshape(img_np.shape) + out, *_ = convert_to_dst_type(src=reshaped, dst=img, dtype=self.dtype) + return out class IntensityStats(Transform): @@ -1013,17 +1096,16 @@ class IntensityStats(Transform): """ + backend = [TransformBackends.NUMPY] + def __init__(self, ops: Sequence[Union[str, Callable]], key_prefix: str, channel_wise: bool = False) -> None: self.ops = ensure_tuple(ops) self.key_prefix = key_prefix self.channel_wise = channel_wise def __call__( - self, - img: np.ndarray, - meta_data: Optional[Dict] = None, - mask: Optional[np.ndarray] = None, - ) -> Tuple[np.ndarray, Dict]: + self, img: NdarrayOrTensor, meta_data: Optional[Dict] = None, mask: Optional[np.ndarray] = None + ) -> Tuple[NdarrayOrTensor, Dict]: """ Compute statistics for the intensity of input image. @@ -1034,21 +1116,23 @@ def __call__( mask must have the same shape as input `img`. """ + img_np, *_ = convert_data_type(img, np.ndarray) if meta_data is None: meta_data = {} - img_: np.ndarray = img if mask is not None: - if mask.shape != img.shape or mask.dtype != bool: - raise TypeError("mask must be bool array with the same shape as input `img`.") - img_ = img[mask] + if mask.shape != img_np.shape: + raise ValueError(f"mask must have the same shape as input `img`, got {mask.shape} and {img_np.shape}.") + if mask.dtype != bool: + raise TypeError(f"mask must be bool array, got type {mask.dtype}.") + img_np = img_np[mask] supported_ops = { - "mean": lambda x: np.nanmean(x), - "median": lambda x: np.nanmedian(x), - "max": lambda x: np.nanmax(x), - "min": lambda x: np.nanmin(x), - "std": lambda x: np.nanstd(x), + "mean": np.nanmean, + "median": np.nanmedian, + "max": np.nanmax, + "min": np.nanmin, + "std": np.nanstd, } def _compute(op: Callable, data: np.ndarray): @@ -1060,9 +1144,9 @@ def _compute(op: Callable, data: np.ndarray): for o in self.ops: if isinstance(o, str): o = look_up_option(o, supported_ops.keys()) - meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_) + meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_np) # type: ignore elif callable(o): - meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_) + meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_np) custom_index += 1 else: raise ValueError("ops must be key string for predefined operations or callable function.") @@ -1083,6 +1167,8 @@ class ToDevice(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, device: Union[torch.device, str], **kwargs) -> None: """ Args: @@ -1099,3 +1185,121 @@ def __call__(self, img: torch.Tensor): raise ValueError("img must be PyTorch Tensor, consider converting img by `EnsureType` transform first.") return img.to(self.device, **self.kwargs) + + +class CuCIM(Transform): + """ + Wrap a non-randomized cuCIM transform, defined based on the transform name and args. + For randomized transforms (or randomly applying a transform) use :py:class:`monai.transforms.RandCuCIM`. + + Args: + name: the transform name in CuCIM package + args: parameters for the CuCIM transform + kwargs: parameters for the CuCIM transform + + Note: + CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`. + Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array. + """ + + def __init__(self, name: str, *args, **kwargs) -> None: + super().__init__() + self.name = name + self.transform, _ = optional_import("cucim.core.operations.expose.transform", name=name) + self.args = args + self.kwargs = kwargs + + def __call__(self, data): + """ + Args: + data: a CuPy array (`cupy.ndarray`) for the cuCIM transform + + Returns: + `cupy.ndarray` + + """ + return self.transform(data, *self.args, **self.kwargs) + + +class RandCuCIM(CuCIM, RandomizableTransform): + """ + Wrap a randomized cuCIM transform, defined based on the transform name and args, + or randomly apply a non-randomized transform. + For deterministic non-randomized transforms use :py:class:`monai.transforms.CuCIM`. + + Args: + name: the transform name in CuCIM package. + apply_prob: the probability to apply the transform (default=1.0) + args: parameters for the CuCIM transform. + kwargs: parameters for the CuCIM transform. + + Note: + - CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`. + Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array. + - If the cuCIM transform is already randomized the `apply_prob` argument has nothing to do with + the randomness of the underlying cuCIM transform. `apply_prob` defines if the transform (either randomized + or non-randomized) being applied randomly, so it can apply non-randomized transforms randomly but be careful + with setting `apply_prob` to anything than 1.0 when using along with cuCIM's randomized transforms. + - If the random factor of the underlying cuCIM transform is not derived from `self.R`, + the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`. + """ + + def __init__(self, name: str, apply_prob: float = 1.0, *args, **kwargs) -> None: + CuCIM.__init__(self, name, *args, **kwargs) + RandomizableTransform.__init__(self, prob=apply_prob) + + def __call__(self, data): + """ + Args: + data: a CuPy array (`cupy.ndarray`) for the cuCIM transform + + Returns: + `cupy.ndarray` + + """ + self.randomize(data) + if not self._do_transform: + return data + return super().__call__(data) + + +class AddCoordinateChannels(Transform): + """ + Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling, + to allow feeding of the patch's location into the network. + + This can be seen as a input-only version of CoordConv: + + Liu, R. et al. An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution, NeurIPS 2018. + + Args: + spatial_dims: the spatial dimensions that are to have their coordinates encoded in a channel and + appended to the input image. E.g., `(0, 1, 2)` represents `H, W, D` dims and append three channels + to the input image, encoding the coordinates of the input's three spatial dimensions. + + .. deprecated:: 0.8.0 + ``spatial_channels`` is deprecated, use ``spatial_dims`` instead. + + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + @deprecated_arg( + name="spatial_channels", new_name="spatial_dims", since="0.8", msg_suffix="please use `spatial_dims` instead." + ) + def __init__(self, spatial_dims: Sequence[int]) -> None: + self.spatial_dims = spatial_dims + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Args: + img: data to be transformed, assuming `img` is channel first. + """ + if max(self.spatial_dims) > img.ndim - 2 or min(self.spatial_dims) < 0: + raise ValueError(f"`spatial_dims` values must be within [0, {img.ndim - 2}]") + + spatial_size = img.shape[1:] + coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_size), indexing="ij")) + coord_channels, *_ = convert_to_dst_type(coord_channels, img) # type: ignore + coord_channels = coord_channels[list(self.spatial_dims)] + return concatenate((img, coord_channels), axis=0) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e9bcce93b0..ecf9aaffa4 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,8 +15,7 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -import copy -import logging +import re from copy import deepcopy from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union @@ -30,11 +29,14 @@ from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utility.array import ( AddChannel, + AddCoordinateChannels, + AddExtremePointsChannel, AsChannelFirst, AsChannelLast, CastToType, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, + CuCIM, DataStats, EnsureChannelFirst, EnsureType, @@ -58,13 +60,18 @@ Transpose, ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points -from monai.utils import convert_to_numpy, ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys, TransformBackends +from monai.transforms.utils_pytorch_numpy_unification import concatenate +from monai.utils import convert_to_numpy, deprecated_arg, ensure_tuple, ensure_tuple_rep +from monai.utils.enums import PostFix, TraceKeys, TransformBackends +from monai.utils.type_conversion import convert_to_dst_type __all__ = [ "AddChannelD", "AddChannelDict", "AddChanneld", + "AddCoordinateChannelsD", + "AddCoordinateChannelsDict", + "AddCoordinateChannelsd", "AddExtremePointsChannelD", "AddExtremePointsChannelDict", "AddExtremePointsChanneld", @@ -86,6 +93,9 @@ "CopyItemsD", "CopyItemsDict", "CopyItemsd", + "CuCIMd", + "CuCIMD", + "CuCIMDict", "DataStatsD", "DataStatsDict", "DataStatsd", @@ -116,6 +126,9 @@ "MapLabelValueD", "MapLabelValueDict", "MapLabelValued", + "RandCuCIMd", + "RandCuCIMD", + "RandCuCIMDict", "RandLambdaD", "RandLambdaDict", "RandLambdad", @@ -161,8 +174,13 @@ "Transposed", "TransposeDict", "TransposeD", + "ClassesToIndicesd", + "ClassesToIndicesD", + "ClassesToIndicesDict", ] +DEFAULT_POST_FIX = PostFix.meta() + class Identityd(MapTransform): """ @@ -274,7 +292,7 @@ def __init__( self, keys: KeysCollection, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, strict_check: bool = True, ) -> None: """ @@ -442,15 +460,26 @@ class ToTensord(MapTransform, InvertibleTransform): backend = ToTensor.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + wrap_sequence: bool = True, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: target data content type to convert, for example: torch.float, etc. + device: specify the target device to put the Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = ToTensor() + self.converter = ToTensor(dtype=dtype, device=device, wrap_sequence=wrap_sequence) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -478,7 +507,7 @@ class EnsureTyped(MapTransform, InvertibleTransform): Ensure the input data to be a PyTorch Tensor or numpy array, support: `numpy array`, `PyTorch Tensor`, `float`, `int`, `bool`, `string` and `object` keep the original. If passing a dictionary, list or tuple, still return dictionary, list or tuple and recursively convert - every item to the expected data type. + every item to the expected data type if `wrap_sequence=False`. Note: Currently, we only convert tensor data to numpy array or scalar number in the inverse operation. @@ -486,16 +515,28 @@ class EnsureTyped(MapTransform, InvertibleTransform): backend = EnsureType.backend - def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + data_type: str = "tensor", + dtype: Union[DtypeLike, torch.dtype] = None, + device: Optional[torch.device] = None, + wrap_sequence: bool = True, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` data_type: target data type to convert, should be "tensor" or "numpy". + dtype: target data content type to convert, for example: np.float32, torch.float, etc. + device: for Tensor data type, specify the target device. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = EnsureType(data_type=data_type) + self.converter = EnsureType(data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -522,15 +563,24 @@ class ToNumpyd(MapTransform): backend = ToNumpy.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + dtype: DtypeLike = None, + wrap_sequence: bool = True, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: target data type when converting to numpy array. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = ToNumpy() + self.converter = ToNumpy(dtype=dtype, wrap_sequence=wrap_sequence) def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = dict(data) @@ -542,19 +592,29 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: class ToCupyd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ToCupy`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: data type specifier. It is inferred from the input by default. + if not None, must be an argument of `numpy.dtype`, for more details: + https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`. + allow_missing_keys: don't raise exception if key is missing. """ backend = ToCupy.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - allow_missing_keys: don't raise exception if key is missing. - """ + def __init__( + self, + keys: KeysCollection, + dtype: Optional[np.dtype] = None, + wrap_sequence: bool = True, + allow_missing_keys: bool = False, + ) -> None: super().__init__(keys, allow_missing_keys) - self.converter = ToCupy() + self.converter = ToCupy(dtype=dtype, wrap_sequence=wrap_sequence) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -614,7 +674,7 @@ def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - fwd_indices = np.array(transform[InverseKeys.EXTRA_INFO]["indices"]) + fwd_indices = np.array(transform[TraceKeys.EXTRA_INFO]["indices"]) inv_indices = np.argsort(fwd_indices) inverse_transform = Transpose(inv_indices.tolist()) # Apply inverse @@ -630,8 +690,35 @@ class DeleteItemsd(MapTransform): It will remove the key-values and copy the others to construct a new dictionary. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, keys: KeysCollection, sep: str = ".", use_re: Union[Sequence[bool], bool] = False) -> None: + """ + Args: + keys: keys of the corresponding items to delete, can be "A{sep}B{sep}C" + to delete key `C` in nested dictionary, `C` can be regular expression. + See also: :py:class:`monai.transforms.compose.MapTransform` + sep: the separator tag to define nested dictionary keys, default to ".". + use_re: whether the specified key is a regular expression, it also can be + a list of bool values, map the to keys. + """ + super().__init__(keys) + self.sep = sep + self.use_re = ensure_tuple_rep(use_re, len(self.keys)) + def __call__(self, data): - return {key: val for key, val in data.items() if key not in self.key_iterator(data)} + def _delete_item(keys, d, use_re: bool = False): + key = keys[0] + if len(keys) > 1: + d[key] = _delete_item(keys[1:], d[key], use_re) + return d + return {k: v for k, v in d.items() if (use_re and not re.search(key, k)) or (not use_re and k != key)} + + d = dict(data) + for key, use_re in zip(self.keys, self.use_re): + d = _delete_item(key.split(self.sep), d, use_re) + + return d class SelectItemsd(MapTransform): @@ -640,9 +727,10 @@ class SelectItemsd(MapTransform): It will copy the selected key-values and construct and new dictionary. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __call__(self, data): - result = {key: data[key] for key in self.key_iterator(data)} - return result + return {key: data[key] for key in self.key_iterator(data)} class SqueezeDimd(MapTransform): @@ -686,7 +774,7 @@ def __init__( value_range: Union[Sequence[bool], bool] = True, data_value: Union[Sequence[bool], bool] = False, additional_info: Optional[Union[Sequence[Callable], Callable]] = None, - logger_handler: Optional[logging.Handler] = None, + name: str = "DataStats", allow_missing_keys: bool = False, ) -> None: """ @@ -707,9 +795,7 @@ def __init__( additional_info: user can define callable function to extract additional info from input data. it also can be a sequence of string, each element corresponds to a key in ``keys``. - logger_handler: add additional handler to output data: save to file, etc. - add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html - the handler should have a logging level of at least `INFO`. + name: identifier of `logging.logger` to use, defaulting to "DataStats". allow_missing_keys: don't raise exception if key is missing. """ @@ -720,23 +806,14 @@ def __init__( self.value_range = ensure_tuple_rep(value_range, len(self.keys)) self.data_value = ensure_tuple_rep(data_value, len(self.keys)) self.additional_info = ensure_tuple_rep(additional_info, len(self.keys)) - self.logger_handler = logger_handler - self.printer = DataStats(logger_handler=logger_handler) + self.printer = DataStats(name=name) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info ): - d[key] = self.printer( - d[key], - prefix, - data_type, - data_shape, - value_range, - data_value, - additional_info, - ) + d[key] = self.printer(d[key], prefix, data_type, data_shape, value_range, data_value, additional_info) return d @@ -779,17 +856,22 @@ class CopyItemsd(MapTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, keys: KeysCollection, times: int, names: KeysCollection, allow_missing_keys: bool = False + self, + keys: KeysCollection, + times: int = 1, + names: Optional[KeysCollection] = None, + allow_missing_keys: bool = False, ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` times: expected copy times, for example, if keys is "img", times is 3, - it will add 3 copies of "img" data to the dictionary. + it will add 3 copies of "img" data to the dictionary, default to 1. names: the names corresponding to the newly copied data, the length should match `len(keys) x times`. for example, if keys is ["img", "seg"] and times is 2, names can be: ["img_1", "seg_1", "img_2", "seg_2"]. + if None, use "{key}_{index}" as key for copy times `N`, index from `0` to `N-1`. allow_missing_keys: don't raise exception if key is missing. Raises: @@ -801,7 +883,7 @@ def __init__( if times < 1: raise ValueError(f"times must be positive, got {times}.") self.times = times - names = ensure_tuple(names) + names = [f"{k}_{i}" for k in self.keys for i in range(self.times)] if names is None else ensure_tuple(names) if len(names) != (len(self.keys) * times): raise ValueError( "len(names) must match len(keys) * times, " @@ -825,7 +907,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if isinstance(val, torch.Tensor): d[new_key] = val.detach().clone() else: - d[new_key] = copy.deepcopy(val) + d[new_key] = deepcopy(val) return d @@ -866,6 +948,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N elif not isinstance(d[key], data_type): raise TypeError("All items in data must have the same type.") output.append(d[key]) + + if len(output) == 0: + return d + if data_type is np.ndarray: d[self.name] = np.concatenate(output, axis=self.dim) elif data_type is torch.Tensor: @@ -1000,7 +1086,7 @@ def __call__(self, data): return super().__call__(data) def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable): - return self._lambd(data, func=func) if transform_info[InverseKeys.DO_TRANSFORM] else data + return self._lambd(data, func=func) if transform_info[TraceKeys.DO_TRANSFORM] else data class LabelToMaskd(MapTransform): @@ -1059,6 +1145,8 @@ class FgBgToIndicesd(MapTransform): """ + backend = FgBgToIndices.backend + def __init__( self, keys: KeysCollection, @@ -1075,7 +1163,7 @@ def __init__( self.image_key = image_key self.converter = FgBgToIndices(image_threshold, output_shape) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) image = d[self.image_key] if self.image_key else None for key in self.key_iterator(d): @@ -1103,6 +1191,8 @@ class ClassesToIndicesd(MapTransform): """ + backend = ClassesToIndices.backend + def __init__( self, keys: KeysCollection, @@ -1118,7 +1208,7 @@ def __init__( self.image_key = image_key self.converter = ClassesToIndices(num_classes, image_threshold, output_shape) - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, Any]): d = dict(data) image = d[self.image_key] if self.image_key else None for key in self.key_iterator(d): @@ -1138,11 +1228,13 @@ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): and ET (Enhancing tumor). """ + backend = ConvertToMultiChannelBasedOnBratsClasses.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) self.converter = ConvertToMultiChannelBasedOnBratsClasses() - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1168,6 +1260,8 @@ class AddExtremePointsChanneld(Randomizable, MapTransform): """ + backend = AddExtremePointsChannel.backend + def __init__( self, keys: KeysCollection, @@ -1188,10 +1282,10 @@ def __init__( self.rescale_min = rescale_min self.rescale_max = rescale_max - def randomize(self, label: np.ndarray) -> None: + def randomize(self, label: NdarrayOrTensor) -> None: self.points = get_extreme_points(label, rand_state=self.R, background=self.background, pert=self.pert) - def __call__(self, data): + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) label = d[self.label_key] if label.shape[0] != 1: @@ -1209,7 +1303,8 @@ def __call__(self, data): rescale_min=self.rescale_min, rescale_max=self.rescale_max, ) - d[key] = np.concatenate([img, points_image], axis=0) + points_image, *_ = convert_to_dst_type(points_image, img) # type: ignore + d[key] = concatenate([img, points_image], axis=0) return d @@ -1223,14 +1318,9 @@ class TorchVisiond(MapTransform): data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor. """ - def __init__( - self, - keys: KeysCollection, - name: str, - allow_missing_keys: bool = False, - *args, - **kwargs, - ) -> None: + backend = TorchVision.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -1242,9 +1332,10 @@ def __init__( """ super().__init__(keys, allow_missing_keys) + self.name = name self.trans = TorchVision(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.trans(d[key]) @@ -1267,14 +1358,9 @@ class RandTorchVisiond(Randomizable, MapTransform): """ - def __init__( - self, - keys: KeysCollection, - name: str, - allow_missing_keys: bool = False, - *args, - **kwargs, - ) -> None: + backend = TorchVision.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -1286,9 +1372,10 @@ def __init__( """ MapTransform.__init__(self, keys, allow_missing_keys) + self.name = name self.trans = TorchVision(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.trans(d[key]) @@ -1300,6 +1387,8 @@ class MapLabelValued(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`. """ + backend = MapLabelValue.backend + def __init__( self, keys: KeysCollection, @@ -1321,7 +1410,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.mapper(d[key]) @@ -1356,13 +1445,15 @@ class IntensityStatsd(MapTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to store the computed statistics to the meta dict. allow_missing_keys: don't raise exception if key is missing. """ + backend = IntensityStats.backend + def __init__( self, keys: KeysCollection, @@ -1371,7 +1462,7 @@ def __init__( mask_keys: Optional[KeysCollection] = None, channel_wise: bool = False, meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1382,16 +1473,14 @@ def __init__( raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__(self, data) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mask_key, meta_key, meta_key_postfix in self.key_iterator( d, self.mask_keys, self.meta_keys, self.meta_key_postfix ): meta_key = meta_key or f"{key}_{meta_key_postfix}" d[key], d[meta_key] = self.stats( - img=d[key], - meta_data=d.get(meta_key), - mask=d.get(mask_key) if mask_key is not None else None, + img=d[key], meta_data=d.get(meta_key), mask=d.get(mask_key) if mask_key is not None else None ) return d @@ -1401,12 +1490,10 @@ class ToDeviced(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`. """ + backend = ToDevice.backend + def __init__( - self, - keys: KeysCollection, - device: Union[torch.device, str], - allow_missing_keys: bool = False, - **kwargs, + self, keys: KeysCollection, device: Union[torch.device, str], allow_missing_keys: bool = False, **kwargs ) -> None: """ Args: @@ -1427,6 +1514,121 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class CuCIMd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.CuCIM` for non-randomized transforms. + For randomized transforms of CuCIM use :py:class:`monai.transforms.RandCuCIMd`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in CuCIM package. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the CuCIM transform. + kwargs: parameters for the CuCIM transform. + + Note: + CuCIM transforms only work with CuPy arrays, this transform expects input data to be `cupy.ndarray`. + Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array. + """ + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: + super().__init__(keys=keys, allow_missing_keys=allow_missing_keys) + self.name = name + self.trans = CuCIM(name, *args, **kwargs) + + def __call__(self, data): + """ + Args: + data: Dict[Hashable, `cupy.ndarray`] + + Returns: + Dict[Hashable, `cupy.ndarray`] + + """ + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.trans(d[key]) + return d + + +class RandCuCIMd(CuCIMd, RandomizableTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.CuCIM` for randomized transforms. + For deterministic non-randomized transforms of CuCIM use :py:class:`monai.transforms.CuCIMd`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in CuCIM package. + apply_prob: the probability to apply the transform (default=1.0) + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the CuCIM transform. + kwargs: parameters for the CuCIM transform. + + Note: + - CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`. + Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array. + - If the cuCIM transform is already randomized the `apply_prob` argument has nothing to do with + the randomness of the underlying cuCIM transform. `apply_prob` defines if the transform (either randomized + or non-randomized) being applied randomly, so it can apply non-randomized transforms randomly but be careful + with setting `apply_prob` to anything than 1.0 when using along with cuCIM's randomized transforms. + - If the random factor of the underlying cuCIM transform is not derived from `self.R`, + the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`. + """ + + def __init__(self, apply_prob: float = 1.0, *args, **kwargs) -> None: + CuCIMd.__init__(self, *args, **kwargs) + RandomizableTransform.__init__(self, prob=apply_prob) + + def __call__(self, data): + """ + Args: + data: Dict[Hashable, `cupy.ndarray`] + + Returns: + Dict[Hashable, `cupy.ndarray`] + + """ + self.randomize(data) + if not self._do_transform: + return dict(data) + return super().__call__(data) + + +class AddCoordinateChannelsd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.AddCoordinateChannels`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + spatial_dims: the spatial dimensions that are to have their coordinates encoded in a channel and + appended to the input image. E.g., `(0, 1, 2)` represents `H, W, D` dims and append three channels + to the input image, encoding the coordinates of the input's three spatial dimensions. + allow_missing_keys: don't raise exception if key is missing. + + .. deprecated:: 0.8.0 + ``spatial_channels`` is deprecated, use ``spatial_dims`` instead. + + """ + + backend = AddCoordinateChannels.backend + + @deprecated_arg( + name="spatial_channels", new_name="spatial_dims", since="0.8", msg_suffix="please use `spatial_dims` instead." + ) + def __init__(self, keys: KeysCollection, spatial_dims: Sequence[int], allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) + self.add_coordinate_channels = AddCoordinateChannels(spatial_dims=spatial_dims) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.add_coordinate_channels(d[key]) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -1463,3 +1665,6 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc MapLabelValueD = MapLabelValueDict = MapLabelValued IntensityStatsD = IntensityStatsDict = IntensityStatsd ToDeviceD = ToDeviceDict = ToDeviced +CuCIMD = CuCIMDict = CuCIMd +RandCuCIMD = RandCuCIMDict = RandCuCIMd +AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 30aa5e7b99..847614adfe 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,37 +14,55 @@ import warnings from contextlib import contextmanager from inspect import getmembers, isclass -from typing import Any, Callable, Hashable, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Hashable, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union import numpy as np import torch import monai -import monai.transforms.transform from monai.config import DtypeLike, IndexSelection -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.networks.layers import GaussianFilter +from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose, OneOf -from monai.transforms.transform import MapTransform, Transform +from monai.transforms.transform import MapTransform, Transform, apply_transform +from monai.transforms.utils_pytorch_numpy_unification import ( + any_np_pt, + ascontiguousarray, + cumsum, + isfinite, + nonzero, + ravel, + searchsorted, + unique, + unravel_index, + where, +) from monai.utils import ( GridSampleMode, InterpolateMode, - InverseKeys, + NumpyPadMode, + PytorchPadMode, + TraceKeys, + deprecated_arg, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, + get_equivalent_dtype, issequenceiterable, + look_up_option, min_version, optional_import, ) from monai.utils.enums import TransformBackends -from monai.utils.type_conversion import convert_data_type +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type measure, _ = optional_import("skimage.measure", "0.14.2", min_version) ndimage, _ = optional_import("scipy.ndimage") cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") +cucim, has_cucim = optional_import("cucim") exposure, has_skimage = optional_import("skimage.exposure") __all__ = [ @@ -84,6 +102,9 @@ "get_number_image_type_conversions", "get_transform_backends", "print_transform_backends", + "convert_pad_mode", + "convert_to_contiguous", + "get_unique_labels", ] @@ -135,31 +156,43 @@ def zero_margins(img: np.ndarray, margin: int) -> bool: def rescale_array( - arr: NdarrayOrTensor, minv: float = 0.0, maxv: float = 1.0, dtype: Union[DtypeLike, torch.dtype] = np.float32 + arr: NdarrayOrTensor, + minv: Optional[float] = 0.0, + maxv: Optional[float] = 1.0, + dtype: Union[DtypeLike, torch.dtype] = np.float32, ) -> NdarrayOrTensor: """ Rescale the values of numpy array `arr` to be from `minv` to `maxv`. + If either `minv` or `maxv` is None, it returns `(a - min_a) / (max_a - min_a)`. + + Args: + arr: input array to rescale. + minv: minimum value of target rescaled array. + maxv: maxmum value of target rescaled array. + dtype: if not None, convert input array to dtype before computation. + """ if dtype is not None: arr, *_ = convert_data_type(arr, dtype=dtype) - mina = arr.min() maxa = arr.max() if mina == maxa: - return arr * minv + return arr * minv if minv is not None else arr norm = (arr - mina) / (maxa - mina) # normalize the array first + if (minv is None) or (maxv is None): + return norm return (norm * (maxv - minv)) + minv # rescale by minv and maxv, which is the normalized array by default def rescale_instance_array( - arr: np.ndarray, minv: float = 0.0, maxv: float = 1.0, dtype: DtypeLike = np.float32 + arr: np.ndarray, minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, dtype: DtypeLike = np.float32 ) -> np.ndarray: """ Rescale each array slice along the first dimension of `arr` independently. """ - out: np.ndarray = np.zeros(arr.shape, dtype) + out: np.ndarray = np.zeros(arr.shape, dtype or arr.dtype) for i in range(arr.shape[0]): out[i] = rescale_array(arr[i], minv, maxv, dtype) @@ -170,16 +203,12 @@ def rescale_array_int_max(arr: np.ndarray, dtype: DtypeLike = np.uint16) -> np.n """ Rescale the array `arr` to be between the minimum and maximum values of the type `dtype`. """ - info: np.iinfo = np.iinfo(dtype) - return np.asarray(rescale_array(arr, info.min, info.max), dtype=dtype) + info: np.iinfo = np.iinfo(dtype or arr.dtype) + return np.asarray(rescale_array(arr, info.min, info.max), dtype=dtype or arr.dtype) def copypaste_arrays( - src_shape, - dest_shape, - srccenter: Sequence[int], - destcenter: Sequence[int], - dims: Sequence[Optional[int]], + src_shape, dest_shape, srccenter: Sequence[int], destcenter: Sequence[int], dims: Sequence[Optional[int]] ) -> Tuple[Tuple[slice, ...], Tuple[slice, ...]]: """ Calculate the slices to copy a sliced area of array in `src_shape` into array in `dest_shape`. @@ -256,10 +285,8 @@ def resize_center(img: np.ndarray, *resize_dims: Optional[int], fill_value: floa def map_binary_to_indices( - label: np.ndarray, - image: Optional[np.ndarray] = None, - image_threshold: float = 0.0, -) -> Tuple[np.ndarray, np.ndarray]: + label: NdarrayOrTensor, image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0 +) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Compute the foreground and background of input label data, return the indices after fattening. For example: @@ -272,28 +299,32 @@ def map_binary_to_indices( to define background. so the output items will not map to all the voxels in the label. image_threshold: if enabled `image`, use ``image > image_threshold`` to determine the valid image content area and select background only in this area. - """ + # Prepare fg/bg indices if label.shape[0] > 1: label = label[1:] # for One-Hot format data, remove the background channel - label_flat = np.any(label, axis=0).ravel() # in case label has multiple dimensions - fg_indices = np.nonzero(label_flat)[0] + label_flat = ravel(any_np_pt(label, 0)) # in case label has multiple dimensions + fg_indices = nonzero(label_flat) if image is not None: - img_flat = np.any(image > image_threshold, axis=0).ravel() - bg_indices = np.nonzero(np.logical_and(img_flat, ~label_flat))[0] + img_flat = ravel(any_np_pt(image > image_threshold, 0)) + img_flat, *_ = convert_to_dst_type(img_flat, label, dtype=img_flat.dtype) + bg_indices = nonzero(img_flat & ~label_flat) else: - bg_indices = np.nonzero(~label_flat)[0] + bg_indices = nonzero(~label_flat) + # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices + fg_indices, *_ = convert_data_type(fg_indices, device=torch.device("cpu")) + bg_indices, *_ = convert_data_type(bg_indices, device=torch.device("cpu")) return fg_indices, bg_indices def map_classes_to_indices( - label: np.ndarray, + label: NdarrayOrTensor, num_classes: Optional[int] = None, - image: Optional[np.ndarray] = None, + image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0, -) -> List[np.ndarray]: +) -> List[NdarrayOrTensor]: """ Filter out indices of every class of the input label data, return the indices after fattening. It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for @@ -313,11 +344,11 @@ def map_classes_to_indices( determine the valid image content area and select class indices only in this area. """ - img_flat: Optional[np.ndarray] = None + img_flat: Optional[NdarrayOrTensor] = None if image is not None: - img_flat = np.any(image > image_threshold, axis=0).ravel() + img_flat = ravel((image > image_threshold).any(0)) - indices: List[np.ndarray] = [] + indices: List[NdarrayOrTensor] = [] # assuming the first dimension is channel channels = len(label) @@ -328,16 +359,18 @@ def map_classes_to_indices( num_classes_ = num_classes for c in range(num_classes_): - label_flat = np.any(label[c : c + 1] if channels > 1 else label == c, axis=0).ravel() - label_flat = np.logical_and(img_flat, label_flat) if img_flat is not None else label_flat - indices.append(np.nonzero(label_flat)[0]) + label_flat = ravel(any_np_pt(label[c : c + 1] if channels > 1 else label == c, 0)) + label_flat = img_flat & label_flat if img_flat is not None else label_flat + # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices + cls_indices: NdarrayOrTensor = convert_data_type(nonzero(label_flat), device=torch.device("cpu"))[0] + indices.append(cls_indices) return indices def weighted_patch_samples( spatial_size: Union[int, Sequence[int]], - w: np.ndarray, + w: NdarrayOrTensor, n_samples: int = 1, r_state: Optional[np.random.RandomState] = None, ) -> List: @@ -366,34 +399,45 @@ def weighted_patch_samples( s = tuple(slice(w // 2, m - w + w // 2) if m > w else slice(m // 2, m // 2 + 1) for w, m in zip(win_size, img_size)) v = w[s] # weight map in the 'valid' mode v_size = v.shape - v = v.ravel() - if np.any(v < 0): - v -= np.min(v) # shifting to non-negative - v = v.cumsum() - if not v[-1] or not np.isfinite(v[-1]) or v[-1] < 0: # uniform sampling + v = ravel(v) + if (v < 0).any(): + v -= v.min() # shifting to non-negative + v = cumsum(v) + if not v[-1] or not isfinite(v[-1]) or v[-1] < 0: # uniform sampling idx = r_state.randint(0, len(v), size=n_samples) else: - idx = v.searchsorted(r_state.random(n_samples) * v[-1], side="right") + r, *_ = convert_to_dst_type(r_state.random(n_samples), v) + idx = searchsorted(v, r * v[-1], right=True) # type: ignore + idx, *_ = convert_to_dst_type(idx, v, dtype=torch.int) # type: ignore # compensate 'valid' mode diff = np.minimum(win_size, img_size) // 2 - return [np.unravel_index(i, v_size) + diff for i in np.asarray(idx, dtype=int)] + diff, *_ = convert_to_dst_type(diff, v) # type: ignore + return [unravel_index(i, v_size) + diff for i in idx] def correct_crop_centers( - centers: List[np.ndarray], spatial_size: Union[Sequence[int], int], label_spatial_shape: Sequence[int] -) -> List[np.ndarray]: + centers: List[int], + spatial_size: Union[Sequence[int], int], + label_spatial_shape: Sequence[int], + allow_smaller: bool = False, +): """ - Utility to correct the crop center if the crop size is bigger than the image size. + Utility to correct the crop center if the crop size and centers are not compatible with the image size. Args: - ceters: pre-computed crop centers, will correct based on the valid region. + centers: pre-computed crop centers of every dim, will correct based on the valid region. spatial_size: spatial size of the ROIs to be sampled. label_spatial_shape: spatial shape of the original label data to compare with ROI. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). """ spatial_size = fall_back_tuple(spatial_size, default=label_spatial_shape) - if not (np.subtract(label_spatial_shape, spatial_size) >= 0).all(): - raise ValueError("The size of the proposed random crop ROI is larger than the image size.") + if any(np.subtract(label_spatial_shape, spatial_size) < 0): + if not allow_smaller: + raise ValueError("The size of the proposed random crop ROI is larger than the image size.") + spatial_size = tuple(min(l, s) for l, s in zip(label_spatial_shape, spatial_size)) # Select subregion to assure valid roi valid_start = np.floor_divide(spatial_size, 2) @@ -405,16 +449,11 @@ def correct_crop_centers( # need this because np.random.randint does not work with same start and end if valid_s == valid_end[i]: valid_end[i] += 1 - - for i, c in enumerate(centers): - center_i = c - if c < valid_start[i]: - center_i = valid_start[i] - if c >= valid_end[i]: - center_i = valid_end[i] - 1 - centers[i] = center_i - - return centers + valid_centers = [] + for c, v_s, v_e in zip(centers, valid_start, valid_end): + center_i = min(max(c, v_s), v_e - 1) + valid_centers.append(int(center_i)) + return valid_centers def generate_pos_neg_label_crop_centers( @@ -422,10 +461,11 @@ def generate_pos_neg_label_crop_centers( num_samples: int, pos_ratio: float, label_spatial_shape: Sequence[int], - fg_indices: np.ndarray, - bg_indices: np.ndarray, + fg_indices: NdarrayOrTensor, + bg_indices: NdarrayOrTensor, rand_state: Optional[np.random.RandomState] = None, -) -> List[List[np.ndarray]]: + allow_smaller: bool = False, +) -> List[List[int]]: """ Generate valid sample locations based on the label with option for specifying foreground ratio Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] @@ -438,6 +478,9 @@ def generate_pos_neg_label_crop_centers( fg_indices: pre-computed foreground indices in 1 dimension. bg_indices: pre-computed background indices in 1 dimension. rand_state: numpy randomState object to align with other modules. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). Raises: ValueError: When the proposed roi is larger than the image. @@ -448,11 +491,12 @@ def generate_pos_neg_label_crop_centers( rand_state = np.random.random.__self__ # type: ignore centers = [] - fg_indices, bg_indices = np.asarray(fg_indices), np.asarray(bg_indices) - if fg_indices.size == 0 and bg_indices.size == 0: + fg_indices = np.asarray(fg_indices) if isinstance(fg_indices, Sequence) else fg_indices + bg_indices = np.asarray(bg_indices) if isinstance(bg_indices, Sequence) else bg_indices + if len(fg_indices) == 0 and len(bg_indices) == 0: raise ValueError("No sampling location available.") - if fg_indices.size == 0 or bg_indices.size == 0: + if len(fg_indices) == 0 or len(bg_indices) == 0: warnings.warn( f"N foreground {len(fg_indices)}, N background {len(bg_indices)}," "unable to generate class balanced samples." @@ -462,10 +506,10 @@ def generate_pos_neg_label_crop_centers( for _ in range(num_samples): indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices random_int = rand_state.randint(len(indices_to_use)) - center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + idx = indices_to_use[random_int] + center = unravel_index(idx, label_spatial_shape).tolist() # shift center to range of valid centers - center_ori = list(center) - centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) + centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller)) return centers @@ -474,10 +518,11 @@ def generate_label_classes_crop_centers( spatial_size: Union[Sequence[int], int], num_samples: int, label_spatial_shape: Sequence[int], - indices: List[np.ndarray], + indices: Sequence[NdarrayOrTensor], ratios: Optional[List[Union[float, int]]] = None, rand_state: Optional[np.random.RandomState] = None, -) -> List[List[np.ndarray]]: + allow_smaller: bool = False, +) -> List[List[int]]: """ Generate valid sample locations based on the specified ratios of label classes. Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] @@ -490,6 +535,9 @@ def generate_label_classes_crop_centers( ratios: ratios of every class in the label to generate crop centers, including background class. if None, every class will have the same ratio to generate crop centers. rand_state: numpy randomState object to align with other modules. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). """ if rand_state is None: @@ -499,12 +547,10 @@ def generate_label_classes_crop_centers( raise ValueError("num_samples must be an int number and greater than 0.") ratios_: List[Union[float, int]] = ([1] * len(indices)) if ratios is None else ratios if len(ratios_) != len(indices): - raise ValueError("random crop radios must match the number of indices of classes.") + raise ValueError("random crop ratios must match the number of indices of classes.") if any(i < 0 for i in ratios_): raise ValueError("ratios should not contain negative number.") - # ensure indices are numpy array - indices = [np.asarray(i) for i in indices] for i, array in enumerate(indices): if len(array) == 0: warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.") @@ -516,10 +562,9 @@ def generate_label_classes_crop_centers( # randomly select the indices of a class based on the ratios indices_to_use = indices[i] random_int = rand_state.randint(len(indices_to_use)) - center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + center = unravel_index(indices_to_use[random_int], label_spatial_shape).tolist() # shift center to range of valid centers - center_ori = list(center) - centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) + centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller)) return centers @@ -528,7 +573,9 @@ def create_grid( spatial_size: Sequence[int], spacing: Optional[Sequence[float]] = None, homogeneous: bool = True, - dtype: DtypeLike = float, + dtype: Union[DtypeLike, torch.dtype] = float, + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, ): """ compute a `spatial_size` mesh. @@ -537,33 +584,95 @@ def create_grid( spatial_size: spatial size of the grid. spacing: same len as ``spatial_size``, defaults to 1.0 (dense grid). homogeneous: whether to make homogeneous coordinates. - dtype: output grid data type. + dtype: output grid data type, defaults to `float`. + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + + """ + _backend = look_up_option(backend, TransformBackends) + _dtype = dtype or float + if _backend == TransformBackends.NUMPY: + return _create_grid_numpy(spatial_size, spacing, homogeneous, _dtype) + if _backend == TransformBackends.TORCH: + return _create_grid_torch(spatial_size, spacing, homogeneous, _dtype, device) + raise ValueError(f"backend {backend} is not supported") + + +def _create_grid_numpy( + spatial_size: Sequence[int], + spacing: Optional[Sequence[float]] = None, + homogeneous: bool = True, + dtype: Union[DtypeLike, torch.dtype] = float, +): + """ + compute a `spatial_size` mesh with the numpy API. """ spacing = spacing or tuple(1.0 for _ in spatial_size) ranges = [np.linspace(-(d - 1.0) / 2.0 * s, (d - 1.0) / 2.0 * s, int(d)) for d, s in zip(spatial_size, spacing)] - coords = np.asarray(np.meshgrid(*ranges, indexing="ij"), dtype=dtype) + coords = np.asarray(np.meshgrid(*ranges, indexing="ij"), dtype=get_equivalent_dtype(dtype, np.ndarray)) if not homogeneous: return coords return np.concatenate([coords, np.ones_like(coords[:1])]) +def _create_grid_torch( + spatial_size: Sequence[int], + spacing: Optional[Sequence[float]] = None, + homogeneous: bool = True, + dtype=torch.float32, + device: Optional[torch.device] = None, +): + """ + compute a `spatial_size` mesh with the torch API. + """ + spacing = spacing or tuple(1.0 for _ in spatial_size) + ranges = [ + torch.linspace( + -(d - 1.0) / 2.0 * s, + (d - 1.0) / 2.0 * s, + int(d), + device=device, + dtype=get_equivalent_dtype(dtype, torch.Tensor), + ) + for d, s in zip(spatial_size, spacing) + ] + coords = meshgrid_ij(*ranges) + if not homogeneous: + return torch.stack(coords) + return torch.stack([*coords, torch.ones_like(coords[0])]) + + def create_control_grid( - spatial_shape: Sequence[int], spacing: Sequence[float], homogeneous: bool = True, dtype: DtypeLike = float + spatial_shape: Sequence[int], + spacing: Sequence[float], + homogeneous: bool = True, + dtype: DtypeLike = float, + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, ): """ control grid with two additional point in each direction """ + torch_backend = look_up_option(backend, TransformBackends) == TransformBackends.TORCH + ceil_func: Callable = torch.ceil if torch_backend else np.ceil # type: ignore grid_shape = [] for d, s in zip(spatial_shape, spacing): - d = int(d) + d = torch.as_tensor(d, device=device) if torch_backend else int(d) # type: ignore if d % 2 == 0: - grid_shape.append(np.ceil((d - 1.0) / (2.0 * s) + 0.5) * 2.0 + 2.0) + grid_shape.append(ceil_func((d - 1.0) / (2.0 * s) + 0.5) * 2.0 + 2.0) else: - grid_shape.append(np.ceil((d - 1.0) / (2.0 * s)) * 2.0 + 3.0) - return create_grid(grid_shape, spacing, homogeneous, dtype) + grid_shape.append(ceil_func((d - 1.0) / (2.0 * s)) * 2.0 + 3.0) + return create_grid( + spatial_size=grid_shape, spacing=spacing, homogeneous=homogeneous, dtype=dtype, device=device, backend=backend + ) -def create_rotate(spatial_dims: int, radians: Union[Sequence[float], float]) -> np.ndarray: +def create_rotate( + spatial_dims: int, + radians: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, +) -> NdarrayOrTensor: """ create a 2D or 3D rotation matrix @@ -572,48 +681,83 @@ def create_rotate(spatial_dims: int, radians: Union[Sequence[float], float]) -> radians: rotation radians when spatial_dims == 3, the `radians` sequence corresponds to rotation in the 1st, 2nd, and 3rd dim respectively. + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. Raises: ValueError: When ``radians`` is empty. ValueError: When ``spatial_dims`` is not one of [2, 3]. """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_rotate( + spatial_dims=spatial_dims, radians=radians, sin_func=np.sin, cos_func=np.cos, eye_func=np.eye + ) + if _backend == TransformBackends.TORCH: + return _create_rotate( + spatial_dims=spatial_dims, + radians=radians, + sin_func=lambda th: torch.sin(torch.as_tensor(th, dtype=torch.float32, device=device)), + cos_func=lambda th: torch.cos(torch.as_tensor(th, dtype=torch.float32, device=device)), + eye_func=lambda rank: torch.eye(rank, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_rotate( + spatial_dims: int, + radians: Union[Sequence[float], float], + sin_func: Callable = np.sin, + cos_func: Callable = np.cos, + eye_func: Callable = np.eye, +) -> NdarrayOrTensor: radians = ensure_tuple(radians) if spatial_dims == 2: if len(radians) >= 1: - sin_, cos_ = np.sin(radians[0]), np.cos(radians[0]) - return np.array([[cos_, -sin_, 0.0], [sin_, cos_, 0.0], [0.0, 0.0, 1.0]]) + sin_, cos_ = sin_func(radians[0]), cos_func(radians[0]) + out = eye_func(3) + out[0, 0], out[0, 1] = cos_, -sin_ + out[1, 0], out[1, 1] = sin_, cos_ + return out # type: ignore raise ValueError("radians must be non empty.") if spatial_dims == 3: affine = None if len(radians) >= 1: - sin_, cos_ = np.sin(radians[0]), np.cos(radians[0]) - affine = np.array( - [[1.0, 0.0, 0.0, 0.0], [0.0, cos_, -sin_, 0.0], [0.0, sin_, cos_, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) + sin_, cos_ = sin_func(radians[0]), cos_func(radians[0]) + affine = eye_func(4) + affine[1, 1], affine[1, 2] = cos_, -sin_ + affine[2, 1], affine[2, 2] = sin_, cos_ if len(radians) >= 2: - sin_, cos_ = np.sin(radians[1]), np.cos(radians[1]) + sin_, cos_ = sin_func(radians[1]), cos_func(radians[1]) if affine is None: raise ValueError("Affine should be a matrix.") - affine = affine @ np.array( - [[cos_, 0.0, sin_, 0.0], [0.0, 1.0, 0.0, 0.0], [-sin_, 0.0, cos_, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) + _affine = eye_func(4) + _affine[0, 0], _affine[0, 2] = cos_, sin_ + _affine[2, 0], _affine[2, 2] = -sin_, cos_ + affine = affine @ _affine if len(radians) >= 3: - sin_, cos_ = np.sin(radians[2]), np.cos(radians[2]) + sin_, cos_ = sin_func(radians[2]), cos_func(radians[2]) if affine is None: raise ValueError("Affine should be a matrix.") - affine = affine @ np.array( - [[cos_, -sin_, 0.0, 0.0], [sin_, cos_, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) + _affine = eye_func(4) + _affine[0, 0], _affine[0, 1] = cos_, -sin_ + _affine[1, 0], _affine[1, 1] = sin_, cos_ + affine = affine @ _affine if affine is None: raise ValueError("radians must be non empty.") - return affine + return affine # type: ignore raise ValueError(f"Unsupported spatial_dims: {spatial_dims}, available options are [2, 3].") -def create_shear(spatial_dims: int, coefs: Union[Sequence[float], float]) -> np.ndarray: +def create_shear( + spatial_dims: int, + coefs: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, +) -> NdarrayOrTensor: """ create a shearing matrix @@ -629,61 +773,120 @@ def create_shear(spatial_dims: int, coefs: Union[Sequence[float], float]) -> np. [0.0, 0.0, 0.0, 1.0], ] + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + Raises: NotImplementedError: When ``spatial_dims`` is not one of [2, 3]. """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_shear(spatial_dims=spatial_dims, coefs=coefs, eye_func=np.eye) + if _backend == TransformBackends.TORCH: + return _create_shear( + spatial_dims=spatial_dims, coefs=coefs, eye_func=lambda rank: torch.eye(rank, device=device) + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_shear(spatial_dims: int, coefs: Union[Sequence[float], float], eye_func=np.eye) -> NdarrayOrTensor: if spatial_dims == 2: coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0) - return np.array([[1, coefs[0], 0.0], [coefs[1], 1.0, 0.0], [0.0, 0.0, 1.0]]) + out = eye_func(3) + out[0, 1], out[1, 0] = coefs[0], coefs[1] + return out # type: ignore if spatial_dims == 3: coefs = ensure_tuple_size(coefs, dim=6, pad_val=0.0) - return np.array( - [ - [1.0, coefs[0], coefs[1], 0.0], - [coefs[2], 1.0, coefs[3], 0.0], - [coefs[4], coefs[5], 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) + out = eye_func(4) + out[0, 1], out[0, 2] = coefs[0], coefs[1] + out[1, 0], out[1, 2] = coefs[2], coefs[3] + out[2, 0], out[2, 1] = coefs[4], coefs[5] + return out # type: ignore raise NotImplementedError("Currently only spatial_dims in [2, 3] are supported.") -def create_scale(spatial_dims: int, scaling_factor: Union[Sequence[float], float]): +def create_scale( + spatial_dims: int, + scaling_factor: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, +) -> NdarrayOrTensor: """ create a scaling matrix Args: spatial_dims: spatial rank scaling_factor: scaling factors for every spatial dim, defaults to 1. - """ + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_scale(spatial_dims=spatial_dims, scaling_factor=scaling_factor, array_func=np.diag) + if _backend == TransformBackends.TORCH: + return _create_scale( + spatial_dims=spatial_dims, + scaling_factor=scaling_factor, + array_func=lambda x: torch.diag(torch.as_tensor(x, device=device)), + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_scale( + spatial_dims: int, scaling_factor: Union[Sequence[float], float], array_func=np.diag +) -> NdarrayOrTensor: scaling_factor = ensure_tuple_size(scaling_factor, dim=spatial_dims, pad_val=1.0) - return np.diag(scaling_factor[:spatial_dims] + (1.0,)) + return array_func(scaling_factor[:spatial_dims] + (1.0,)) # type: ignore -def create_translate(spatial_dims: int, shift: Union[Sequence[float], float]) -> np.ndarray: +def create_translate( + spatial_dims: int, + shift: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, +) -> NdarrayOrTensor: """ create a translation matrix Args: spatial_dims: spatial rank shift: translate pixel/voxel for every spatial dim, defaults to 0. - """ + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_translate(spatial_dims=spatial_dims, shift=shift, eye_func=np.eye, array_func=np.asarray) + if _backend == TransformBackends.TORCH: + return _create_translate( + spatial_dims=spatial_dims, + shift=shift, + eye_func=lambda x: torch.eye(torch.as_tensor(x), device=device), # type: ignore + array_func=lambda x: torch.as_tensor(x, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_translate( + spatial_dims: int, shift: Union[Sequence[float], float], eye_func=np.eye, array_func=np.asarray +) -> NdarrayOrTensor: shift = ensure_tuple(shift) - affine = np.eye(spatial_dims + 1) + affine = eye_func(spatial_dims + 1) for i, a in enumerate(shift[:spatial_dims]): affine[i, spatial_dims] = a - return np.asarray(affine) + return array_func(affine) # type: ignore def generate_spatial_bounding_box( - img: np.ndarray, + img: NdarrayOrTensor, select_fn: Callable = is_positive, channel_indices: Optional[IndexSelection] = None, margin: Union[Sequence[int], int] = 0, + allow_smaller: bool = True, ) -> Tuple[List[int], List[int]]: """ - generate the spatial bounding box of foreground in the image with start-end positions. + Generate the spatial bounding box of foreground in the image with start-end positions (inclusive). Users can define arbitrary function to select expected foreground from the whole image or specified channels. And it can also add margin to every dim of the bounding box. The output format of the coordinates is: @@ -691,18 +894,21 @@ def generate_spatial_bounding_box( [1st_spatial_dim_start, 2nd_spatial_dim_start, ..., Nth_spatial_dim_start], [1st_spatial_dim_end, 2nd_spatial_dim_end, ..., Nth_spatial_dim_end] - The bounding boxes edges are aligned with the input image edges. - This function returns [-1, -1, ...], [-1, -1, ...] if there's no positive intensity. + If `allow_smaller`, the bounding boxes edges are aligned with the input image edges. + This function returns [0, 0, ...], [0, 0, ...] if there's no positive intensity. Args: - img: source image to generate bounding box from. + img: a "channel-first" image of shape (C, spatial_dim1[, spatial_dim2, ...]) to generate bounding box from. select_fn: function to select expected foreground, default is to select values > 0. channel_indices: if defined, select foreground only on the specified channels of image. if None, select foreground on the whole image. margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. + allow_smaller: when computing box size with `margin`, whether allow the image size to be smaller + than box size, default to `True`. """ + spatial_size = img.shape[1:] data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img - data = np.any(select_fn(data), axis=0) + data = select_fn(data).any(0) ndim = len(data.shape) margin = ensure_tuple_rep(margin, ndim) for m in margin: @@ -713,19 +919,28 @@ def generate_spatial_bounding_box( box_end = [0] * ndim for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)): - dt = data.any(axis=ax) - if not np.any(dt): + dt = data + if len(ax) != 0: + dt = any_np_pt(dt, ax) + + if not dt.any(): # if no foreground, return all zero bounding box coords return [0] * ndim, [0] * ndim - min_d = max(np.argmax(dt) - margin[di], 0) - max_d = max(data.shape[di] - max(np.argmax(dt[::-1]) - margin[di], 0), min_d + 1) - box_start[di], box_end[di] = min_d, max_d + arg_max = where(dt == dt.max())[0] + min_d = arg_max[0] - margin[di] + max_d = arg_max[-1] + margin[di] + 1 + if allow_smaller: + min_d = max(min_d, 0) + max_d = min(max_d, spatial_size[di]) + + box_start[di] = min_d.detach().cpu().item() if isinstance(min_d, torch.Tensor) else min_d # type: ignore + box_end[di] = max_d.detach().cpu().item() if isinstance(max_d, torch.Tensor) else max_d # type: ignore return box_start, box_end -def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor: +def get_largest_connected_component_mask(img: NdarrayTensor, connectivity: Optional[int] = None) -> NdarrayTensor: """ Gets the largest connected component mask of an image. @@ -733,15 +948,54 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option img: Image to get largest connected component from. Shape is (spatial_dim1 [, spatial_dim2, ...]) connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. If ``None``, a full - connectivity of ``input.ndim`` is used. + connectivity of ``input.ndim`` is used. for more details: + https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label. """ - img_arr = img.detach().cpu().numpy() - largest_cc = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype) + if isinstance(img, torch.Tensor) and has_cp and has_cucim: + x_cupy = monai.transforms.ToCupy()(img.short()) + x_label = cucim.skimage.measure.label(x_cupy, connectivity=connectivity) + vals, counts = cp.unique(x_label[cp.nonzero(x_label)], return_counts=True) + comp = x_label == vals[cp.ndarray.argmax(counts)] + out_tensor = monai.transforms.ToTensor(device=img.device)(comp) + out_tensor = out_tensor.bool() + + return out_tensor # type: ignore + + img_arr = convert_data_type(img, np.ndarray)[0] + largest_cc: np.ndarray = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype) img_arr = measure.label(img_arr, connectivity=connectivity) if img_arr.max() != 0: largest_cc[...] = img_arr == (np.argmax(np.bincount(img_arr.flat)[1:]) + 1) - return torch.as_tensor(largest_cc, device=img.device) + return convert_to_dst_type(largest_cc, dst=img, dtype=largest_cc.dtype)[0] + + +def get_unique_labels( + img: NdarrayOrTensor, is_onehot: bool, discard: Optional[Union[int, Iterable[int]]] = None +) -> Set[int]: + """Get list of non-background labels in an image. + + Args: + img: Image to be processed. Shape should be [C, W, H, [D]] with C=1 if not onehot else `num_classes`. + is_onehot: Boolean as to whether input image is one-hotted. If one-hotted, only return channels with + discard: Can be used to remove labels (e.g., background). Can be any value, sequence of values, or + `None` (nothing is discarded). + + Returns: + Set of labels + """ + applied_labels: Set[int] + n_channels = img.shape[0] + if is_onehot: + applied_labels = {i for i, s in enumerate(img) if s.sum() > 0} + else: + if n_channels != 1: + raise ValueError("If input not one-hotted, should only be 1 channel.") + applied_labels = set(unique(img).tolist()) + if discard is not None: + for i in ensure_tuple(discard): + applied_labels.discard(i) + return applied_labels def fill_holes( @@ -780,7 +1034,7 @@ def fill_holes( structure = ndimage.generate_binary_structure(spatial_dims, connectivity or spatial_dims) # Get labels if not provided. Exclude background label. - applied_labels = set(applied_labels or (range(num_channels) if is_one_hot else np.unique(img_arr))) + applied_labels = set(applied_labels) if applied_labels is not None else get_unique_labels(img_arr, is_one_hot) background_label = 0 applied_labels.discard(background_label) @@ -804,7 +1058,7 @@ def fill_holes( def get_extreme_points( - img: np.ndarray, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 + img: NdarrayOrTensor, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 ) -> List[Tuple[int, ...]]: """ Generate extreme points from an image. These are used to generate initial segmentation @@ -828,7 +1082,7 @@ def get_extreme_points( """ if rand_state is None: rand_state = np.random.random.__self__ # type: ignore - indices = np.where(img != background) + indices = where(img != background) if np.size(indices[0]) == 0: raise ValueError("get_extreme_points: no foreground object in mask!") @@ -840,7 +1094,9 @@ def _get_point(val, dim): val : value for comparison dim : dimension in which to look for value """ - idx = rand_state.choice(np.where(indices[dim] == val)[0]) + idx = where(indices[dim] == val)[0] + idx = idx.cpu() if isinstance(idx, torch.Tensor) else idx + idx = rand_state.choice(idx) pt = [] for j in range(img.ndim): # add +- pert to each dimension @@ -852,19 +1108,19 @@ def _get_point(val, dim): points = [] for i in range(img.ndim): - points.append(tuple(_get_point(np.min(indices[i][...]), i))) - points.append(tuple(_get_point(np.max(indices[i][...]), i))) + points.append(tuple(_get_point(indices[i].min(), i))) + points.append(tuple(_get_point(indices[i].max(), i))) return points def extreme_points_to_image( points: List[Tuple[int, ...]], - label: np.ndarray, + label: NdarrayOrTensor, sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, rescale_min: float = -1.0, rescale_max: float = 1.0, -): +) -> torch.Tensor: """ Please refer to :py:class:`monai.transforms.AddExtremePointsChannel` for the usage. @@ -882,27 +1138,30 @@ def extreme_points_to_image( rescale_max: maximum value of output data. """ # points to image - points_image = torch.zeros(label.shape[1:], dtype=torch.float) + # points_image = torch.zeros(label.shape[1:], dtype=torch.float) + points_image = torch.zeros_like(torch.as_tensor(label[0]), dtype=torch.float) for p in points: points_image[p] = 1.0 + if isinstance(sigma, Sequence): + sigma = [torch.as_tensor(s, device=points_image.device) for s in sigma] + else: + sigma = torch.as_tensor(sigma, device=points_image.device) + # add channel and add batch points_image = points_image.unsqueeze(0).unsqueeze(0) gaussian_filter = GaussianFilter(label.ndim - 1, sigma=sigma) - points_image = gaussian_filter(points_image).squeeze(0).detach().numpy() + points_image = gaussian_filter(points_image).squeeze(0).detach() # rescale the points image to [rescale_min, rescale_max] - min_intensity = np.min(points_image) - max_intensity = np.max(points_image) + min_intensity = points_image.min() + max_intensity = points_image.max() points_image = (points_image - min_intensity) / (max_intensity - min_intensity) - points_image = points_image * (rescale_max - rescale_min) + rescale_min - return points_image + return points_image * (rescale_max - rescale_min) + rescale_min def map_spatial_axes( - img_ndim: int, - spatial_axes: Optional[Union[Sequence[int], int]] = None, - channel_first: bool = True, + img_ndim: int, spatial_axes: Optional[Union[Sequence[int], int]] = None, channel_first: bool = True ) -> List[int]: """ Utility to map the spatial axes to real axes in channel first/last shape. @@ -989,7 +1248,7 @@ def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTra def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_corners: Optional[bool] = None): """ Change the interpolation mode when inverting spatial transforms, default to "nearest". - This function modifies trans_info's `InverseKeys.EXTRA_INFO`. + This function modifies trans_info's `TraceKeys.EXTRA_INFO`. See also: :py:class:`monai.transform.inverse.InvertibleTransform` @@ -1002,21 +1261,21 @@ def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_c interp_modes = [i.value for i in InterpolateMode] + [i.value for i in GridSampleMode] # set to string for DataLoader collation - align_corners_ = "none" if align_corners is None else align_corners + align_corners_ = TraceKeys.NONE if align_corners is None else align_corners for item in ensure_tuple(trans_info): - if InverseKeys.EXTRA_INFO in item: - orig_mode = item[InverseKeys.EXTRA_INFO].get("mode", None) + if TraceKeys.EXTRA_INFO in item: + orig_mode = item[TraceKeys.EXTRA_INFO].get("mode", None) if orig_mode is not None: if orig_mode[0] in interp_modes: - item[InverseKeys.EXTRA_INFO]["mode"] = [mode for _ in range(len(mode))] + item[TraceKeys.EXTRA_INFO]["mode"] = [mode for _ in range(len(mode))] elif orig_mode in interp_modes: - item[InverseKeys.EXTRA_INFO]["mode"] = mode - if "align_corners" in item[InverseKeys.EXTRA_INFO]: - if issequenceiterable(item[InverseKeys.EXTRA_INFO]["align_corners"]): - item[InverseKeys.EXTRA_INFO]["align_corners"] = [align_corners_ for _ in range(len(mode))] + item[TraceKeys.EXTRA_INFO]["mode"] = mode + if "align_corners" in item[TraceKeys.EXTRA_INFO]: + if issequenceiterable(item[TraceKeys.EXTRA_INFO]["align_corners"]): + item[TraceKeys.EXTRA_INFO]["align_corners"] = [align_corners_ for _ in range(len(mode))] else: - item[InverseKeys.EXTRA_INFO]["align_corners"] = align_corners_ + item[TraceKeys.EXTRA_INFO]["align_corners"] = align_corners_ return trans_info @@ -1041,12 +1300,7 @@ def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Union[Sequen def equalize_hist( - img: np.ndarray, - mask: Optional[np.ndarray] = None, - num_bins: int = 256, - min: int = 0, - max: int = 255, - dtype: DtypeLike = np.float32, + img: np.ndarray, mask: Optional[np.ndarray] = None, num_bins: int = 256, min: int = 0, max: int = 255 ) -> np.ndarray: """ Utility to equalize input image based on the histogram. @@ -1061,9 +1315,9 @@ def equalize_hist( https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. min: the min value to normalize input image, default to `0`. max: the max value to normalize input image, default to `255`. - dtype: data type of the output, default to `float32`. """ + orig_shape = img.shape hist_img = img[np.array(mask, dtype=bool)] if mask is not None else img if has_skimage: @@ -1078,8 +1332,7 @@ def equalize_hist( # apply linear interpolation img = np.interp(img.flatten(), bins, cum) - - return img.reshape(orig_shape).astype(dtype) + return img.reshape(orig_shape) class Fourier: @@ -1088,38 +1341,70 @@ class Fourier: """ @staticmethod - def shift_fourier(x: torch.Tensor, n_dims: int) -> torch.Tensor: + @deprecated_arg( + name="n_dims", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) + def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] = None) -> NdarrayOrTensor: """ Applies fourier transform and shifts the zero-frequency component to the center of the spectrum. Only the spatial dimensions get transformed. Args: x: Image to transform. - n_dims: Number of spatial dimensions. + spatial_dims: Number of spatial dimensions. + + .. deprecated:: 0.6.0 + ``n_dims`` is deprecated, use ``spatial_dims`` instead. + Returns k: K-space data. """ - k: torch.Tensor = torch.fft.fftshift( - torch.fft.fftn(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) - ) + if n_dims is not None: + spatial_dims = n_dims + dims = tuple(range(-spatial_dims, 0)) + k: NdarrayOrTensor + if isinstance(x, torch.Tensor): + if hasattr(torch.fft, "fftshift"): # `fftshift` is new in torch 1.8.0 + k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims) + else: + # if using old PyTorch, will convert to numpy array and return + k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims) + else: + k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims) return k @staticmethod - def inv_shift_fourier(k: torch.Tensor, n_dims: int) -> torch.Tensor: + @deprecated_arg( + name="n_dims", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) + def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] = None) -> NdarrayOrTensor: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. Args: k: K-space data. - n_dims: Number of spatial dimensions. + spatial_dims: Number of spatial dimensions. + + .. deprecated:: 0.6.0 + ``n_dims`` is deprecated, use ``spatial_dims`` instead. + Returns: x: Tensor in image space. """ - x: torch.Tensor = torch.fft.ifftn( - torch.fft.ifftshift(k, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) - ).real - return x + if n_dims is not None: + spatial_dims = n_dims + dims = tuple(range(-spatial_dims, 0)) + out: NdarrayOrTensor + if isinstance(k, torch.Tensor): + if hasattr(torch.fft, "ifftshift"): # `ifftshift` is new in torch 1.8.0 + out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward").real + else: + # if using old PyTorch, will convert to numpy array and return + out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real + else: + out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real + return out def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Optional[Hashable] = None) -> int: @@ -1149,9 +1434,7 @@ def _get_data(obj, key): prev_data = _get_data(test_data, key) prev_type = type(prev_data) prev_device = prev_data.device if isinstance(prev_data, torch.Tensor) else None - test_data = monai.transforms.transform.apply_transform( - _transform, test_data, transform.map_items, transform.unpack_items - ) + test_data = apply_transform(_transform, test_data, transform.map_items, transform.unpack_items) # every time the type or device changes, increment the counter curr_data = _get_data(test_data, key) curr_device = curr_data.device if isinstance(curr_data, torch.Tensor) else None @@ -1178,24 +1461,30 @@ def get_transform_backends(): continue unique_transforms.append(obj) - if isclass(obj) and issubclass(obj, Transform): - if n in [ - "Transform", + if ( + isclass(obj) + and issubclass(obj, Transform) + and n + not in [ + "BatchInverseTransform", + "Compose", + "Decollated", + "InvertD", "InvertibleTransform", "Lambda", "LambdaD", - "Compose", - "RandomizableTransform", + "MapTransform", "OneOf", - "BatchInverseTransform", - "InverteD", - ]: - continue - - backends[n] = [ - TransformBackends.TORCH in obj.backend, - TransformBackends.NUMPY in obj.backend, + "PadListDataCollate", + "RandLambda", + "RandLambdaD", + "RandTorchVisionD", + "RandomizableTransform", + "TorchVisionD", + "Transform", ] + ): + backends[n] = [TransformBackends.TORCH in obj.backend, TransformBackends.NUMPY in obj.backend] return backends @@ -1212,7 +1501,7 @@ def print_color(t, color): print(f"\033[{color}m{t}\033[00m") def print_table_column(name, torch, numpy, color=Colors.none): - print_color("{:<50} {:<8} {:<8}".format(name, torch, numpy), color) + print_color(f"{name:<50} {torch:<8} {numpy:<8}", color) backends = get_transform_backends() n_total = len(backends) @@ -1240,5 +1529,49 @@ def print_table_column(name, torch, numpy, color=Colors.none): print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) +def convert_pad_mode(dst: NdarrayOrTensor, mode: Union[NumpyPadMode, PytorchPadMode, str]): + """ + Utility to convert padding mode between numpy array and PyTorch Tensor. + + Args: + dst: target data to convert padding mode for, should be numpy array or PyTorch Tensor. + mode: current padding mode. + + """ + mode = mode.value if isinstance(mode, (NumpyPadMode, PytorchPadMode)) else mode + if isinstance(dst, torch.Tensor): + if mode == "wrap": + mode = "circular" + if mode == "edge": + mode = "replicate" + return look_up_option(mode, PytorchPadMode) + if isinstance(dst, np.ndarray): + if mode == "circular": + mode = "wrap" + if mode == "replicate": + mode = "edge" + return look_up_option(mode, NumpyPadMode) + raise ValueError(f"unsupported data type: {type(dst)}.") + + +def convert_to_contiguous(data, **kwargs): + """ + Check and ensure the numpy array or PyTorch Tensor in data to be contuguous in memory. + + Args: + data: input data to convert, will recursively convert the numpy array or PyTorch Tensor in dict and sequence. + kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details: + https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous. + + """ + if isinstance(data, (np.ndarray, torch.Tensor, str, bytes)): + return ascontiguousarray(data, **kwargs) + if isinstance(data, Mapping): + return {k: convert_to_contiguous(v, **kwargs) for k, v in data.items()} + if isinstance(data, Sequence): + return [convert_to_contiguous(i, **kwargs) for i in data] + return data + + if __name__ == "__main__": print_transform_backends() diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py new file mode 100644 index 0000000000..b096e1b93d --- /dev/null +++ b/monai/transforms/utils_create_transform_ims.py @@ -0,0 +1,735 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib +import tempfile +import textwrap +from copy import deepcopy +from glob import glob +from typing import TYPE_CHECKING, Callable + +import numpy as np +import torch + +from monai.apps import download_and_extract +from monai.transforms import ( + AddChanneld, + Affine, + Affined, + AsDiscrete, + Compose, + Flip, + Flipd, + LoadImaged, + MapTransform, + Orientation, + Orientationd, + Rand3DElastic, + Rand3DElasticd, + RandFlip, + RandFlipd, + Randomizable, + RandRotate, + RandRotated, + RandZoom, + RandZoomd, + Rotate, + Rotate90, + Rotate90d, + Rotated, + ScaleIntensity, + ScaleIntensityd, + SpatialPadd, + Zoom, + Zoomd, +) +from monai.transforms.croppad.array import ( + BorderPad, + CenterScaleCrop, + CenterSpatialCrop, + CropForeground, + DivisiblePad, + RandCropByLabelClasses, + RandCropByPosNegLabel, + RandScaleCrop, + RandSpatialCrop, + RandSpatialCropSamples, + RandWeightedCrop, + ResizeWithPadOrCrop, + SpatialCrop, + SpatialPad, +) +from monai.transforms.croppad.dictionary import ( + BorderPadd, + CenterScaleCropd, + CenterSpatialCropd, + CropForegroundd, + DivisiblePadd, + RandCropByLabelClassesd, + RandCropByPosNegLabeld, + RandScaleCropd, + RandSpatialCropd, + RandSpatialCropSamplesd, + RandWeightedCropd, + ResizeWithPadOrCropd, + SpatialCropd, +) +from monai.transforms.intensity.array import ( + AdjustContrast, + GaussianSharpen, + GaussianSmooth, + GibbsNoise, + HistogramNormalize, + KSpaceSpikeNoise, + MaskIntensity, + NormalizeIntensity, + RandAdjustContrast, + RandBiasField, + RandCoarseDropout, + RandCoarseShuffle, + RandGaussianNoise, + RandGaussianSharpen, + RandGaussianSmooth, + RandGibbsNoise, + RandHistogramShift, + RandKSpaceSpikeNoise, + RandRicianNoise, + RandScaleIntensity, + RandShiftIntensity, + RandStdShiftIntensity, + SavitzkyGolaySmooth, + ScaleIntensityRange, + ScaleIntensityRangePercentiles, + ShiftIntensity, + StdShiftIntensity, + ThresholdIntensity, +) +from monai.transforms.intensity.dictionary import ( + AdjustContrastd, + GaussianSharpend, + GaussianSmoothd, + GibbsNoised, + HistogramNormalized, + KSpaceSpikeNoised, + MaskIntensityd, + NormalizeIntensityd, + RandAdjustContrastd, + RandBiasFieldd, + RandCoarseDropoutd, + RandCoarseShuffled, + RandGaussianNoised, + RandGaussianSharpend, + RandGaussianSmoothd, + RandGibbsNoised, + RandHistogramShiftd, + RandKSpaceSpikeNoised, + RandRicianNoised, + RandScaleIntensityd, + RandShiftIntensityd, + RandStdShiftIntensityd, + SavitzkyGolaySmoothd, + ScaleIntensityRanged, + ScaleIntensityRangePercentilesd, + ShiftIntensityd, + StdShiftIntensityd, + ThresholdIntensityd, +) +from monai.transforms.post.array import KeepLargestConnectedComponent, LabelFilter, LabelToContour +from monai.transforms.post.dictionary import AsDiscreted, KeepLargestConnectedComponentd, LabelFilterd, LabelToContourd +from monai.transforms.smooth_field.array import ( + RandSmoothDeform, + RandSmoothFieldAdjustContrast, + RandSmoothFieldAdjustIntensity, +) +from monai.transforms.smooth_field.dictionary import ( + RandSmoothDeformd, + RandSmoothFieldAdjustContrastd, + RandSmoothFieldAdjustIntensityd, +) +from monai.transforms.spatial.array import ( + GridDistortion, + Rand2DElastic, + RandAffine, + RandAxisFlip, + RandGridDistortion, + RandRotate90, + Resize, + Spacing, +) +from monai.transforms.spatial.dictionary import ( + GridDistortiond, + Rand2DElasticd, + RandAffined, + RandAxisFlipd, + RandGridDistortiond, + RandRotate90d, + Resized, + Spacingd, +) +from monai.utils.enums import CommonKeys +from monai.utils.module import optional_import + +if TYPE_CHECKING: + import matplotlib.pyplot as plt + + has_matplotlib = True + +else: + plt, has_matplotlib = optional_import("matplotlib.pyplot") + + +def get_data(keys): + """Get the example data to be used. + + Use MarsAtlas as it only contains 1 image for quick download and + that image is parcellated. + """ + cache_dir = os.environ.get("MONAI_DATA_DIRECTORY") or tempfile.mkdtemp() + fname = "MarsAtlas-MNI-Colin27.zip" + url = "https://www.dropbox.com/s/ndz8qtqblkciole/" + fname + "?dl=1" + out_path = os.path.join(cache_dir, "MarsAtlas-MNI-Colin27") + zip_path = os.path.join(cache_dir, fname) + + download_and_extract(url, zip_path, out_path) + + image, label = sorted(glob(os.path.join(out_path, "*.nii"))) + + data = {CommonKeys.IMAGE: image, CommonKeys.LABEL: label} + + transforms = Compose( + [LoadImaged(keys), AddChanneld(keys), ScaleIntensityd(CommonKeys.IMAGE), Rotate90d(keys, spatial_axes=[0, 2])] + ) + data = transforms(data) + max_size = max(data[keys[0]].shape) + padder = SpatialPadd(keys, (max_size, max_size, max_size)) + return padder(data) + + +def update_docstring(code_path, transform_name): + """ + Find the documentation for a given transform and if it's missing, + add a pointer to the transform's example image. + """ + with open(code_path) as f: + contents = f.readlines() + doc_start = None + for i, line in enumerate(contents): + # find the line containing start of the transform documentation + if "`" + transform_name + "`" in line: + doc_start = i + break + if doc_start is None: + raise RuntimeError("Couldn't find transform documentation") + + # if image is already in docs, nothing to do + image_line = doc_start + 2 + if ".. image" in contents[image_line]: + return + + # add the line for the image and the alt text + contents_orig = deepcopy(contents) + contents.insert( + image_line, + ".. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/" + transform_name + ".png\n", + ) + contents.insert(image_line + 1, " :alt: example of " + transform_name + "\n") + + # check that we've only added two lines + if len(contents) != len(contents_orig) + 2: + raise AssertionError + + # write the updated doc to overwrite the original + with open(code_path, "w") as f: + f.writelines(contents) + + +def pre_process_data(data, ndim, is_map, is_post): + """If transform requires 2D data, then convert to 2D""" + if ndim == 2: + for k in keys: + data[k] = data[k][..., data[k].shape[-1] // 2] + + if is_map: + return data + return data[CommonKeys.LABEL] if is_post else data[CommonKeys.IMAGE] + + +def get_2d_slice(image, view, is_label): + """If image is 3d, get the central slice. If is already 2d, return as-is. + If image is label, set 0 to np.nan. + """ + if image.ndim == 2: + out = image + else: + shape = image.shape + slices = [slice(0, s) for s in shape] + _slice = shape[view] // 2 + slices[view] = slice(_slice, _slice + 1) + slices = tuple(slices) + out = np.squeeze(image[slices], view) + if is_label: + out[out == 0] = np.nan + return out + + +def get_stacked_2d_ims(im, is_label): + """Get the 3 orthogonal views and stack them into 1 image. + Requires that all images be same size, but this is taken care + of by the `SpatialPadd` earlier. + """ + return [get_2d_slice(im, i, is_label) for i in range(3)] + + +def get_stacked_before_after(before, after, is_label=False): + """Stack before and after images into 1 image if 3d. + Requires that before and after images be the same size. + """ + return [get_stacked_2d_ims(d, is_label) for d in (before, after)] + + +def save_image(images, labels, filename, transform_name, transform_args, shapes, colorbar=False): + """Save image to file, ensuring there's no whitespace around the edge.""" + plt.rcParams.update({"font.family": "monospace"}) + plt.style.use("dark_background") + nrow = len(images) # before and after (should always be 2) + ncol = len(images[0]) # num orthogonal views (either 1 or 3) + # roughly estimate the height_ratios of the first:second row + hs = [float(r[0].shape[0]) for r in images] + fig = plt.figure(tight_layout=True) + spec = fig.add_gridspec(nrow, ncol, hspace=0, wspace=0, height_ratios=hs) + for row in range(nrow): + vmin = min(i.min() for i in images[row]) + vmax = max(i.max() for i in images[row]) + for col in range(ncol): + ax = fig.add_subplot(spec[row, col]) + imshow = ax.imshow(images[row][col], cmap="gray", vmin=vmin, vmax=vmax) + ax.set_aspect("equal") + if colorbar and col == ncol - 1: + plt.colorbar(imshow, ax=ax) + if col == 0: + y_label = "After" if row else "Before" + y_label += ("\n" + shapes[row]) if shapes[0] != shapes[1] else "" + ax.set_ylabel(y_label) + # print yticks for the right most column + if col != ncol - 1 or colorbar: + ax.set_yticks([]) + else: + ax.yaxis.tick_right() + for n, label in enumerate(ax.yaxis.get_ticklabels()): + if n > 2: + label.set_visible(False) + ax.set_xticks([]) + ax.set_frame_on(False) + if labels is not None: + ax.imshow(labels[row][col], cmap="hsv", alpha=0.9, interpolation="nearest") + # title is e.g., Flipd(keys=keys, spatial_axis=0) + title = transform_name + "(" + for k, v in transform_args.items(): + title += k + "=" + if isinstance(v, str): + title += "'" + v + "'" + elif isinstance(v, (np.ndarray, torch.Tensor)): + title += "[array]" + elif isinstance(v, Callable): + title += "[callable]" + else: + title += str(v) + title += ", " + if len(transform_args) > 0: + title = title[:-2] + title += ")" + # shorten the lines + title = textwrap.fill(title, 50, break_long_words=False, subsequent_indent=" " * (len(transform_name) + 1)) + fig.suptitle(title, x=0.1, horizontalalignment="left") + fig.savefig(filename) + plt.close(fig) + + +def get_images(data, is_label=False): + """Get image. If is dictionary, extract key. If is list, stack. If both dictionary and list, do both. + Also return the image size as string to be used im the imshow. If it's a list, return `N x (H,W,D)`. + """ + # If not a list, convert + if not isinstance(data, list): + data = [data] + key = CommonKeys.LABEL if is_label else CommonKeys.IMAGE + is_map = isinstance(data[0], dict) + # length of the list will be equal to number of samples produced. This will be 1 except for transforms that + # produce `num_samples`. + data = [d[key] if is_map else d for d in data] + data = [d[0] for d in data] # remove channel component + + # for each sample, create a list of the orthogonal views. If image is 2d, length will be 1. If 3d, there + # will be three orthogonal views + num_samples = len(data) + num_orthog_views = 3 if data[0].ndim == 3 else 1 + shape_str = (f"{num_samples} x " if num_samples > 1 else "") + str(data[0].shape) + for i in range(num_samples): + data[i] = [get_2d_slice(data[i], view, is_label) for view in range(num_orthog_views)] + + out = [] + if num_samples == 1: + out = data[0] + else: + # we might need to panel the images. this happens if a transform produces e.g. 4 output images. + # In this case, we create a 2-by-2 grid from them. Output will be a list containing n_orthog_views, + # each element being either the image (if num_samples is 1) or the panelled image. + nrows = int(np.floor(num_samples**0.5)) + for view in range(num_orthog_views): + result = np.asarray([d[view] for d in data]) + nindex, height, width = result.shape + ncols = nindex // nrows + # only implemented for square number of images (e.g. 4 images goes to a 2-by-2 panel) + if nindex != nrows * ncols: + raise NotImplementedError + # want result.shape = (height*nrows, width*ncols), have to be careful about striding + result = result.reshape(nrows, ncols, height, width).swapaxes(1, 2).reshape(height * nrows, width * ncols) + out.append(result) + return out, shape_str + + +def create_transform_im( + transform, transform_args, data, ndim=3, colorbar=False, update_doc=True, seed=0, is_post=False +): + """Create an image with the before and after of the transform. + Also update the transform's documentation to point to this image.""" + + transform = transform(**transform_args) + + if not has_matplotlib: + raise RuntimeError + + if isinstance(transform, Randomizable): + # increment the seed for map transforms so they're different to the array versions. + seed = seed + 1 if isinstance(transform, MapTransform) else seed + transform.set_random_state(seed) + + out_dir = os.environ.get("MONAI_DOC_IMAGES") + if out_dir is None: + raise RuntimeError( + "Please git clone https://github.com/Project-MONAI/DocImages" + + " and then set the environment variable `MONAI_DOC_IMAGES`" + ) + out_dir = os.path.join(out_dir, "transforms") + + # Path is transform name + transform_name = transform.__class__.__name__ + out_fname = transform_name + ".png" + out_file = os.path.join(out_dir, out_fname) + + is_map = isinstance(transform, MapTransform) + data_in = pre_process_data(deepcopy(data), ndim, is_map, is_post) + + data_tr = transform(deepcopy(data_in)) + + images_before, before_shape = get_images(data_in) + images_after, after_shape = get_images(data_tr) + images = (images_before, images_after) + shapes = (before_shape, after_shape) + + labels = None + if is_map: + labels_before, *_ = get_images(data_in, is_label=True) + labels_after, *_ = get_images(data_tr, is_label=True) + labels = (labels_before, labels_after) + + save_image(images, labels, out_file, transform_name, transform_args, shapes, colorbar) + + if update_doc: + base_dir = pathlib.Path(__file__).parent.parent.parent + rst_path = os.path.join(base_dir, "docs", "source", "transforms.rst") + update_docstring(rst_path, transform_name) + + +if __name__ == "__main__": + + keys = [CommonKeys.IMAGE, CommonKeys.LABEL] + data = get_data(keys) + create_transform_im(RandFlip, dict(prob=1, spatial_axis=1), data) + create_transform_im(RandFlipd, dict(keys=keys, prob=1, spatial_axis=2), data) + create_transform_im(Flip, dict(spatial_axis=1), data) + create_transform_im(Flipd, dict(keys=keys, spatial_axis=2), data) + create_transform_im(Flipd, dict(keys=keys, spatial_axis=2), data) + create_transform_im(Orientation, dict(axcodes="RPI", image_only=True), data) + create_transform_im(Orientationd, dict(keys=keys, axcodes="RPI"), data) + create_transform_im( + Rand3DElastic, dict(prob=1.0, sigma_range=(1, 2), magnitude_range=(0.5, 0.5), shear_range=(1, 1, 1)), data + ) + create_transform_im(Affine, dict(shear_params=(0, 0.5, 0), image_only=True, padding_mode="zeros"), data) + create_transform_im( + Affined, dict(keys=keys, shear_params=(0, 0.5, 0), mode=["bilinear", "nearest"], padding_mode="zeros"), data + ) + create_transform_im(RandAffine, dict(prob=1, shear_range=(0.5, 0.5), padding_mode="zeros"), data) + create_transform_im( + RandAffined, + dict(keys=keys, prob=1, shear_range=(0.5, 0.5), mode=["bilinear", "nearest"], padding_mode="zeros"), + data, + ) + create_transform_im( + Rand3DElastic, dict(sigma_range=(5, 7), magnitude_range=(50, 150), prob=1, padding_mode="zeros"), data + ) + create_transform_im( + Rand2DElastic, dict(prob=1, spacing=(20, 20), magnitude_range=(1, 2), padding_mode="zeros"), data, 2 + ) + create_transform_im( + Rand2DElasticd, + dict( + keys=keys, + prob=1, + spacing=(20, 20), + magnitude_range=(1, 2), + padding_mode="zeros", + mode=["bilinear", "nearest"], + ), + data, + 2, + ) + create_transform_im( + Rand3DElasticd, + dict( + keys=keys, + sigma_range=(5, 7), + magnitude_range=(50, 150), + prob=1, + padding_mode="zeros", + mode=["bilinear", "nearest"], + ), + data, + ) + create_transform_im(Rotate90, dict(spatial_axes=(1, 2)), data) + create_transform_im(Rotate90d, dict(keys=keys, spatial_axes=(1, 2)), data) + create_transform_im(RandRotate90, dict(prob=1), data) + create_transform_im(RandRotate90d, dict(keys=keys, prob=1), data) + create_transform_im(Rotate, dict(angle=0.1), data) + create_transform_im(Rotated, dict(keys=keys, angle=0.1, mode=["bilinear", "nearest"]), data) + create_transform_im(RandRotate, dict(prob=1, range_x=[0.4, 0.4]), data) + create_transform_im(RandRotated, dict(keys=keys, prob=1, range_x=[0.4, 0.4], mode=["bilinear", "nearest"]), data) + create_transform_im(Zoom, dict(zoom=0.6), data) + create_transform_im(Zoomd, dict(keys=keys, zoom=1.3, mode=["area", "nearest"]), data) + create_transform_im(RandZoom, dict(prob=1, min_zoom=0.6, max_zoom=0.8), data) + create_transform_im(RandZoomd, dict(keys=keys, prob=1, min_zoom=1.3, max_zoom=1.5, mode=["area", "nearest"]), data) + create_transform_im(ScaleIntensity, dict(minv=0, maxv=10), data, colorbar=True) + create_transform_im(ScaleIntensityd, dict(keys=CommonKeys.IMAGE, minv=0, maxv=10), data, colorbar=True) + create_transform_im(RandScaleIntensity, dict(prob=1.0, factors=(5, 10)), data, colorbar=True) + create_transform_im( + RandScaleIntensityd, dict(keys=CommonKeys.IMAGE, prob=1.0, factors=(5, 10)), data, colorbar=True + ) + create_transform_im(DivisiblePad, dict(k=64), data) + create_transform_im(DivisiblePadd, dict(keys=keys, k=64), data) + create_transform_im(CropForeground, dict(), data) + create_transform_im(CropForegroundd, dict(keys=keys, source_key=CommonKeys.IMAGE), data) + create_transform_im(RandGaussianNoise, dict(prob=1, mean=0, std=0.1), data) + create_transform_im(RandGaussianNoised, dict(keys=CommonKeys.IMAGE, prob=1, mean=0, std=0.1), data) + create_transform_im(KSpaceSpikeNoise, dict(loc=(100, 100, 100), k_intensity=13), data) + create_transform_im(KSpaceSpikeNoised, dict(keys=CommonKeys.IMAGE, loc=(100, 100, 100), k_intensity=13), data) + create_transform_im(RandKSpaceSpikeNoise, dict(prob=1, intensity_range=(10, 13)), data) + create_transform_im( + RandKSpaceSpikeNoised, + dict(keys=CommonKeys.IMAGE, global_prob=1, prob=1, common_sampling=True, intensity_range=(13, 15)), + data, + ) + create_transform_im(RandRicianNoise, dict(prob=1.0, mean=1, std=0.5), data) + create_transform_im(RandRicianNoised, dict(keys=CommonKeys.IMAGE, prob=1.0, mean=1, std=0.5), data) + create_transform_im(SavitzkyGolaySmooth, dict(window_length=5, order=1), data) + create_transform_im(SavitzkyGolaySmoothd, dict(keys=CommonKeys.IMAGE, window_length=5, order=1), data) + create_transform_im(GibbsNoise, dict(alpha=0.8), data) + create_transform_im(GibbsNoised, dict(keys=CommonKeys.IMAGE, alpha=0.8), data) + create_transform_im(RandGibbsNoise, dict(prob=1.0, alpha=(0.6, 0.8)), data) + create_transform_im(RandGibbsNoised, dict(keys=CommonKeys.IMAGE, prob=1.0, alpha=(0.6, 0.8)), data) + create_transform_im(ShiftIntensity, dict(offset=1), data, colorbar=True) + create_transform_im(ShiftIntensityd, dict(keys=CommonKeys.IMAGE, offset=1), data, colorbar=True) + create_transform_im(RandShiftIntensity, dict(prob=1.0, offsets=(10, 20)), data, colorbar=True) + create_transform_im( + RandShiftIntensityd, dict(keys=CommonKeys.IMAGE, prob=1.0, offsets=(10, 20)), data, colorbar=True + ) + create_transform_im(StdShiftIntensity, dict(factor=10), data, colorbar=True) + create_transform_im(StdShiftIntensityd, dict(keys=CommonKeys.IMAGE, factor=10), data, colorbar=True) + create_transform_im(RandStdShiftIntensity, dict(prob=1.0, factors=(5, 10)), data, colorbar=True) + create_transform_im( + RandStdShiftIntensityd, dict(keys=CommonKeys.IMAGE, prob=1.0, factors=(5, 10)), data, colorbar=True + ) + create_transform_im(RandBiasField, dict(prob=1, coeff_range=(0.2, 0.3)), data) + create_transform_im(RandBiasFieldd, dict(keys=CommonKeys.IMAGE, prob=1, coeff_range=(0.2, 0.3)), data) + create_transform_im(NormalizeIntensity, dict(subtrahend=0, divisor=10), data, colorbar=True) + create_transform_im(NormalizeIntensityd, dict(keys=CommonKeys.IMAGE, subtrahend=0, divisor=10), data, colorbar=True) + create_transform_im(ThresholdIntensity, dict(threshold=0.4, above=False, cval=0.9), data, colorbar=True) + create_transform_im( + ThresholdIntensityd, dict(keys=CommonKeys.IMAGE, threshold=0.4, above=False, cval=0.9), data, colorbar=True + ) + create_transform_im(ScaleIntensityRange, dict(a_min=0, a_max=1, b_min=1, b_max=10), data, colorbar=True) + create_transform_im( + ScaleIntensityRanged, dict(keys=CommonKeys.IMAGE, a_min=0, a_max=1, b_min=1, b_max=10), data, colorbar=True + ) + create_transform_im(ScaleIntensityRangePercentiles, dict(lower=5, upper=95, b_min=1, b_max=10), data, colorbar=True) + create_transform_im( + ScaleIntensityRangePercentilesd, + dict(keys=CommonKeys.IMAGE, lower=5, upper=95, b_min=1, b_max=10), + data, + colorbar=True, + ) + create_transform_im(AdjustContrast, dict(gamma=2), data, colorbar=True) + create_transform_im(AdjustContrastd, dict(keys=CommonKeys.IMAGE, gamma=2), data, colorbar=True) + create_transform_im(RandAdjustContrast, dict(prob=1, gamma=(1.5, 2)), data, colorbar=True) + create_transform_im(RandAdjustContrastd, dict(keys=CommonKeys.IMAGE, prob=1, gamma=(1.5, 2)), data, colorbar=True) + create_transform_im(MaskIntensity, dict(mask_data=data[CommonKeys.IMAGE], select_fn=lambda x: x > 0.3), data) + create_transform_im( + MaskIntensityd, dict(keys=CommonKeys.IMAGE, mask_key=CommonKeys.IMAGE, select_fn=lambda x: x > 0.3), data + ) + create_transform_im(GaussianSmooth, dict(sigma=2), data) + create_transform_im(GaussianSmoothd, dict(keys=CommonKeys.IMAGE, sigma=2), data) + create_transform_im(RandGaussianSmooth, dict(prob=1.0, sigma_x=(1, 2)), data) + create_transform_im(RandGaussianSmoothd, dict(keys=CommonKeys.IMAGE, prob=1.0, sigma_x=(1, 2)), data) + create_transform_im(GaussianSharpen, dict(), GaussianSmoothd(CommonKeys.IMAGE, 2)(data)) + create_transform_im(GaussianSharpend, dict(keys=CommonKeys.IMAGE), GaussianSmoothd(CommonKeys.IMAGE, 2)(data)) + create_transform_im(RandGaussianSharpen, dict(prob=1), GaussianSmoothd(CommonKeys.IMAGE, 2)(data)) + create_transform_im( + RandGaussianSharpend, dict(keys=CommonKeys.IMAGE, prob=1), GaussianSmoothd(CommonKeys.IMAGE, 2)(data) + ) + create_transform_im(RandHistogramShift, dict(prob=1, num_control_points=3), data, colorbar=True) + create_transform_im( + RandHistogramShiftd, dict(keys=CommonKeys.IMAGE, prob=1, num_control_points=3), data, colorbar=True + ) + create_transform_im(RandCoarseDropout, dict(prob=1, holes=200, spatial_size=20, fill_value=0), data) + create_transform_im( + RandCoarseDropoutd, dict(keys=CommonKeys.IMAGE, prob=1, holes=200, spatial_size=20, fill_value=0), data + ) + create_transform_im(RandCoarseShuffle, dict(prob=1, holes=200, spatial_size=20), data) + create_transform_im(RandCoarseShuffled, dict(keys=CommonKeys.IMAGE, prob=1, holes=200, spatial_size=20), data) + create_transform_im(HistogramNormalize, dict(num_bins=10), data) + create_transform_im(HistogramNormalized, dict(keys=CommonKeys.IMAGE, num_bins=10), data) + create_transform_im(SpatialPad, dict(spatial_size=(300, 300, 300)), data) + create_transform_im(SpatialPadd, dict(keys=keys, spatial_size=(300, 300, 300)), data) + create_transform_im(BorderPad, dict(spatial_border=10), data) + create_transform_im(BorderPadd, dict(keys=keys, spatial_border=10), data) + create_transform_im(SpatialCrop, dict(roi_center=(75, 75, 75), roi_size=(100, 100, 100)), data) + create_transform_im(SpatialCropd, dict(keys=keys, roi_center=(75, 75, 75), roi_size=(100, 100, 100)), data) + create_transform_im(CenterSpatialCrop, dict(roi_size=(100, 100, 100)), data) + create_transform_im(CenterSpatialCropd, dict(keys=keys, roi_size=(100, 100, 100)), data) + create_transform_im(RandSpatialCrop, dict(roi_size=(100, 100, 100), random_size=False), data) + create_transform_im(RandSpatialCropd, dict(keys=keys, roi_size=(100, 100, 100), random_size=False), data) + create_transform_im(RandSpatialCropSamples, dict(num_samples=4, roi_size=(100, 100, 100), random_size=False), data) + create_transform_im( + RandSpatialCropSamplesd, dict(keys=keys, num_samples=4, roi_size=(100, 100, 100), random_size=False), data + ) + create_transform_im( + RandWeightedCrop, dict(spatial_size=(100, 100, 100), num_samples=4, weight_map=data[CommonKeys.IMAGE] > 0), data + ) + create_transform_im( + RandWeightedCropd, dict(keys=keys, spatial_size=(100, 100, 100), num_samples=4, w_key=CommonKeys.IMAGE), data + ) + create_transform_im( + RandCropByPosNegLabel, + dict(spatial_size=(100, 100, 100), label=data[CommonKeys.LABEL], neg=0, num_samples=4), + data, + ) + create_transform_im( + RandCropByPosNegLabeld, + dict(keys=keys, spatial_size=(100, 100, 100), label_key=CommonKeys.LABEL, neg=0, num_samples=4), + data, + ) + create_transform_im( + RandCropByLabelClasses, + dict( + spatial_size=(100, 100, 100), label=data[CommonKeys.LABEL] > 0, num_classes=2, ratios=[0, 1], num_samples=4 + ), + data, + ) + create_transform_im( + RandCropByLabelClassesd, + dict( + keys=keys, + spatial_size=(100, 100, 100), + label_key=CommonKeys.LABEL, + num_classes=2, + ratios=[0, 1], + num_samples=4, + ), + data, + ) + create_transform_im(ResizeWithPadOrCrop, dict(spatial_size=(100, 100, 100)), data) + create_transform_im(ResizeWithPadOrCropd, dict(keys=keys, spatial_size=(100, 100, 100)), data) + create_transform_im(RandScaleCrop, dict(roi_scale=0.4), data) + create_transform_im(RandScaleCropd, dict(keys=keys, roi_scale=0.4), data) + create_transform_im(CenterScaleCrop, dict(roi_scale=0.4), data) + create_transform_im(CenterScaleCropd, dict(keys=keys, roi_scale=0.4), data) + create_transform_im(AsDiscrete, dict(to_onehot=None, threshold=10), data, is_post=True, colorbar=True) + create_transform_im(AsDiscreted, dict(keys=CommonKeys.LABEL, to_onehot=None, threshold=10), data, is_post=True) + create_transform_im(LabelFilter, dict(applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True) + create_transform_im( + LabelFilterd, dict(keys=CommonKeys.LABEL, applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True + ) + create_transform_im(LabelToContour, dict(), data, is_post=True) + create_transform_im(LabelToContourd, dict(keys=CommonKeys.LABEL), data, is_post=True) + create_transform_im(Spacing, dict(pixdim=(5, 5, 5), image_only=True), data) + create_transform_im(Spacingd, dict(keys=keys, pixdim=(5, 5, 5), mode=["bilinear", "nearest"]), data) + create_transform_im(RandAxisFlip, dict(prob=1), data) + create_transform_im(RandAxisFlipd, dict(keys=keys, prob=1), data) + create_transform_im(Resize, dict(spatial_size=(100, 100, 100)), data) + create_transform_im(Resized, dict(keys=keys, spatial_size=(100, 100, 100), mode=["area", "nearest"]), data) + data_binary = deepcopy(data) + data_binary[CommonKeys.LABEL] = (data_binary[CommonKeys.LABEL] > 0).astype(np.float32) + create_transform_im(KeepLargestConnectedComponent, dict(applied_labels=1), data_binary, is_post=True, ndim=2) + create_transform_im( + KeepLargestConnectedComponentd, dict(keys=CommonKeys.LABEL, applied_labels=1), data_binary, is_post=True, ndim=2 + ) + create_transform_im( + GridDistortion, dict(num_cells=3, distort_steps=[(1.5,) * 4] * 3, mode="nearest", padding_mode="zeros"), data + ) + create_transform_im( + GridDistortiond, + dict( + keys=keys, num_cells=3, distort_steps=[(1.5,) * 4] * 3, mode=["bilinear", "nearest"], padding_mode="zeros" + ), + data, + ) + create_transform_im(RandGridDistortion, dict(num_cells=3, prob=1.0, distort_limit=(-0.1, 0.1)), data) + create_transform_im( + RandGridDistortiond, + dict(keys=keys, num_cells=4, prob=1.0, distort_limit=(-0.2, 0.2), mode=["bilinear", "nearest"]), + data, + ) + create_transform_im( + RandSmoothFieldAdjustContrast, dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0), data + ) + create_transform_im( + RandSmoothFieldAdjustContrastd, + dict(keys=keys, spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0), + data, + ) + create_transform_im( + RandSmoothFieldAdjustIntensity, + dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, gamma=(0.5, 4.5)), + data, + ) + create_transform_im( + RandSmoothFieldAdjustIntensityd, + dict(keys=keys, spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, gamma=(0.5, 4.5)), + data, + ) + + create_transform_im( + RandSmoothDeform, + dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, def_range=0.05, grid_mode="blinear"), + data, + ) + create_transform_im( + RandSmoothDeformd, + dict( + keys=keys, + spatial_size=(217, 217, 217), + rand_size=(10, 10, 10), + prob=1.0, + def_range=0.05, + grid_mode="blinear", + ), + data, + ) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 2eebe3eda3..2103ccff58 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,39 +9,73 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Sequence, Union + import numpy as np import torch -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.utils.misc import ensure_tuple, is_module_ver_at_least +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type __all__ = [ + "allclose", "moveaxis", "in1d", + "clip", + "percentile", + "where", + "nonzero", + "floor_divide", + "unravel_index", + "unravel_indices", + "ravel", + "any_np_pt", + "maximum", + "concatenate", + "cumsum", + "isfinite", + "searchsorted", + "repeat", + "isnan", + "ascontiguousarray", + "stack", + "mode", + "unique", ] -def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: - """`moveaxis` for pytorch and numpy, using `permute` for pytorch ver < 1.8""" +def allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_nan=False) -> bool: + """`np.allclose` with equivalent implementation for torch.""" + b, *_ = convert_to_dst_type(b, a) + if isinstance(a, np.ndarray): + return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) # type: ignore + + +def moveaxis(x: NdarrayOrTensor, src: Union[int, Sequence[int]], dst: Union[int, Sequence[int]]) -> NdarrayOrTensor: + """`moveaxis` for pytorch and numpy, using `permute` for pytorch version < 1.7""" if isinstance(x, torch.Tensor): - if hasattr(torch, "moveaxis"): - return torch.moveaxis(x, src, dst) - return _moveaxis_with_permute(x, src, dst) # type: ignore - if isinstance(x, np.ndarray): - return np.moveaxis(x, src, dst) - raise RuntimeError() + if hasattr(torch, "movedim"): # `movedim` is new in torch 1.7.0 + # torch.moveaxis is a recent alias since torch 1.8.0 + return torch.movedim(x, src, dst) # type: ignore + return _moveaxis_with_permute(x, src, dst) + return np.moveaxis(x, src, dst) -def _moveaxis_with_permute(x, src, dst): +def _moveaxis_with_permute( + x: torch.Tensor, src: Union[int, Sequence[int]], dst: Union[int, Sequence[int]] +) -> torch.Tensor: # get original indices indices = list(range(x.ndim)) - # make src and dst positive - if src < 0: - src = len(indices) + src - if dst < 0: - dst = len(indices) + dst - # remove desired index and insert it in new position - indices.pop(src) - indices.insert(dst, src) + len_indices = len(indices) + for s, d in zip(ensure_tuple(src), ensure_tuple(dst)): + # make src and dst positive + # remove desired index and insert it in new position + pos_s = len_indices + s if s < 0 else s + pos_d = len_indices + d if d < 0 else d + indices.pop(pos_s) + indices.insert(pos_d, pos_s) return x.permute(indices) @@ -50,3 +84,346 @@ def in1d(x, y): if isinstance(x, np.ndarray): return np.in1d(x, y) return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1) + + +def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor: + """`np.clip` with equivalent implementation for torch.""" + result: NdarrayOrTensor + if isinstance(a, np.ndarray): + result = np.clip(a, a_min, a_max) + else: + result = torch.clamp(a, a_min, a_max) + return result + + +def percentile( + x: NdarrayOrTensor, q, dim: Optional[int] = None, keepdim: bool = False, **kwargs +) -> Union[NdarrayOrTensor, float, int]: + """`np.percentile` with equivalent implementation for torch. + + Pytorch uses `quantile`, but this functionality is only available from v1.7. + For earlier methods, we calculate it ourselves. This doesn't do interpolation, + so is the equivalent of ``numpy.percentile(..., interpolation="nearest")``. + For more details, please refer to: + https://pytorch.org/docs/stable/generated/torch.quantile.html. + https://numpy.org/doc/stable/reference/generated/numpy.percentile.html. + + Args: + x: input data + q: percentile to compute (should in range 0 <= q <= 100) + dim: the dim along which the percentiles are computed. default is to compute the percentile + along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0. + keepdim: whether the output data has dim retained or not. + kwargs: if `x` is numpy array, additional args for `np.percentile`, more details: + https://numpy.org/doc/stable/reference/generated/numpy.percentile.html. + + Returns: + Resulting value (scalar) + """ + if np.isscalar(q): + if not 0 <= q <= 100: # type: ignore + raise ValueError + elif any(q < 0) or any(q > 100): + raise ValueError + result: Union[NdarrayOrTensor, float, int] + if isinstance(x, np.ndarray): + result = np.percentile(x, q, axis=dim, keepdims=keepdim, **kwargs) + else: + q = torch.tensor(q, device=x.device) + if hasattr(torch, "quantile"): # `quantile` is new in torch 1.7.0 + result = torch.quantile(x, q / 100.0, dim=dim, keepdim=keepdim) + else: + # Note that ``kthvalue()`` works one-based, i.e., the first sorted value + # corresponds to k=1, not k=0. Thus, we need the `1 +`. + k = 1 + (0.01 * q * (x.numel() - 1)).round().int() + if k.numel() > 1: + r = [x.view(-1).kthvalue(int(_k)).values.item() for _k in k] + result = torch.tensor(r, device=x.device) + else: + result = x.view(-1).kthvalue(int(k)).values.item() + + return result + + +def where(condition: NdarrayOrTensor, x=None, y=None) -> NdarrayOrTensor: + """ + Note that `torch.where` may convert y.dtype to x.dtype. + """ + result: NdarrayOrTensor + if isinstance(condition, np.ndarray): + if x is not None: + result = np.where(condition, x, y) + else: + result = np.where(condition) # type: ignore + else: + if x is not None: + x = torch.as_tensor(x, device=condition.device) + y = torch.as_tensor(y, device=condition.device, dtype=x.dtype) + result = torch.where(condition, x, y) + else: + result = torch.where(condition) # type: ignore + return result + + +def nonzero(x: NdarrayOrTensor) -> NdarrayOrTensor: + """`np.nonzero` with equivalent implementation for torch. + + Args: + x: array/tensor + + Returns: + Index unravelled for given shape + """ + if isinstance(x, np.ndarray): + return np.nonzero(x)[0] + return torch.nonzero(x).flatten() + + +def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor: + """`np.floor_divide` with equivalent implementation for torch. + + As of pt1.8, use `torch.div(..., rounding_mode="floor")`, and + before that, use `torch.floor_divide`. + + Args: + a: first array/tensor + b: scalar to divide by + + Returns: + Element-wise floor division between two arrays/tensors. + """ + if isinstance(a, torch.Tensor): + if is_module_ver_at_least(torch, (1, 8, 0)): + return torch.div(a, b, rounding_mode="floor") + return torch.floor_divide(a, b) + return np.floor_divide(a, b) + + +def unravel_index(idx, shape) -> NdarrayOrTensor: + """`np.unravel_index` with equivalent implementation for torch. + + Args: + idx: index to unravel + shape: shape of array/tensor + + Returns: + Index unravelled for given shape + """ + if isinstance(idx, torch.Tensor): + coord = [] + for dim in reversed(shape): + coord.append(idx % dim) + idx = floor_divide(idx, dim) + return torch.stack(coord[::-1]) + return np.asarray(np.unravel_index(idx, shape)) + + +def unravel_indices(idx, shape) -> NdarrayOrTensor: + """Computing unravel coordinates from indices. + + Args: + idx: a sequence of indices to unravel + shape: shape of array/tensor + + Returns: + Stacked indices unravelled for given shape + """ + lib_stack = torch.stack if isinstance(idx[0], torch.Tensor) else np.stack + return lib_stack([unravel_index(i, shape) for i in idx]) # type: ignore + + +def ravel(x: NdarrayOrTensor) -> NdarrayOrTensor: + """`np.ravel` with equivalent implementation for torch. + + Args: + x: array/tensor to ravel + + Returns: + Return a contiguous flattened array/tensor. + """ + if isinstance(x, torch.Tensor): + if hasattr(torch, "ravel"): # `ravel` is new in torch 1.8.0 + return x.ravel() + return x.flatten().contiguous() + return np.ravel(x) + + +def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]) -> NdarrayOrTensor: + """`np.any` with equivalent implementation for torch. + + For pytorch, convert to boolean for compatibility with older versions. + + Args: + x: input array/tensor + axis: axis to perform `any` over + + Returns: + Return a contiguous flattened array/tensor. + """ + if isinstance(x, np.ndarray): + return np.any(x, axis) # type: ignore + + # pytorch can't handle multiple dimensions to `any` so loop across them + axis = [axis] if not isinstance(axis, Sequence) else axis + for ax in axis: + try: + x = torch.any(x, ax) + except RuntimeError: + # older versions of pytorch require the input to be cast to boolean + x = torch.any(x.bool(), ax) + return x + + +def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor: + """`np.maximum` with equivalent implementation for torch. + + `torch.maximum` only available from pt>1.6, else use `torch.stack` and `torch.max`. + + Args: + a: first array/tensor + b: second array/tensor + + Returns: + Element-wise maximum between two arrays/tensors. + """ + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + # is torch and has torch.maximum (pt>1.6) + if hasattr(torch, "maximum"): # `maximum` is new in torch 1.7.0 + return torch.maximum(a, b) + return torch.stack((a, b)).max(dim=0)[0] + return np.maximum(a, b) + + +def concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> NdarrayOrTensor: + """`np.concatenate` with equivalent implementation for torch (`torch.cat`).""" + if isinstance(to_cat[0], np.ndarray): + return np.concatenate(to_cat, axis, out) # type: ignore + return torch.cat(to_cat, dim=axis, out=out) # type: ignore + + +def cumsum(a: NdarrayOrTensor, axis=None, **kwargs) -> NdarrayOrTensor: + """ + `np.cumsum` with equivalent implementation for torch. + + Args: + a: input data to compute cumsum. + axis: expected axis to compute cumsum. + kwargs: if `a` is PyTorch Tensor, additional args for `torch.cumsum`, more details: + https://pytorch.org/docs/stable/generated/torch.cumsum.html. + + """ + + if isinstance(a, np.ndarray): + return np.cumsum(a, axis) + if axis is None: + return torch.cumsum(a[:], 0, **kwargs) + return torch.cumsum(a, dim=axis, **kwargs) + + +def isfinite(x: NdarrayOrTensor) -> NdarrayOrTensor: + """`np.isfinite` with equivalent implementation for torch.""" + if not isinstance(x, torch.Tensor): + return np.isfinite(x) + return torch.isfinite(x) + + +def searchsorted(a: NdarrayTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs) -> NdarrayTensor: + """ + `np.searchsorted` with equivalent implementation for torch. + + Args: + a: numpy array or tensor, containing monotonically increasing sequence on the innermost dimension. + v: containing the search values. + right: if False, return the first suitable location that is found, if True, return the last such index. + sorter: if `a` is numpy array, optional array of integer indices that sort array `a` into ascending order. + kwargs: if `a` is PyTorch Tensor, additional args for `torch.searchsorted`, more details: + https://pytorch.org/docs/stable/generated/torch.searchsorted.html. + + """ + side = "right" if right else "left" + if isinstance(a, np.ndarray): + return np.searchsorted(a, v, side, sorter) # type: ignore + return torch.searchsorted(a, v, right=right, **kwargs) # type: ignore + + +def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None, **kwargs) -> NdarrayOrTensor: + """ + `np.repeat` with equivalent implementation for torch (`repeat_interleave`). + + Args: + a: input data to repeat. + repeats: number of repetitions for each element, repeats is broadcasted to fit the shape of the given axis. + axis: axis along which to repeat values. + kwargs: if `a` is PyTorch Tensor, additional args for `torch.repeat_interleave`, more details: + https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html. + + """ + if isinstance(a, np.ndarray): + return np.repeat(a, repeats, axis) + return torch.repeat_interleave(a, repeats, dim=axis, **kwargs) + + +def isnan(x: NdarrayOrTensor) -> NdarrayOrTensor: + """`np.isnan` with equivalent implementation for torch. + + Args: + x: array/tensor + + """ + if isinstance(x, np.ndarray): + return np.isnan(x) + return torch.isnan(x) + + +def ascontiguousarray(x: NdarrayTensor, **kwargs) -> NdarrayOrTensor: + """`np.ascontiguousarray` with equivalent implementation for torch (`contiguous`). + + Args: + x: array/tensor + kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details: + https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html. + + """ + if isinstance(x, np.ndarray): + if x.ndim == 0: + return x + return np.ascontiguousarray(x) + if isinstance(x, torch.Tensor): + return x.contiguous(**kwargs) + return x + + +def stack(x: Sequence[NdarrayTensor], dim: int) -> NdarrayTensor: + """`np.stack` with equivalent implementation for torch. + + Args: + x: array/tensor + dim: dimension along which to perform the stack (referred to as `axis` by numpy) + """ + if isinstance(x[0], np.ndarray): + return np.stack(x, dim) # type: ignore + return torch.stack(x, dim) # type: ignore + + +def mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor: + """`torch.mode` with equivalent implementation for numpy. + + Args: + x: array/tensor + dim: dimension along which to perform `mode` (referred to as `axis` by numpy) + to_long: convert input to long before performing mode. + """ + dtype = torch.int64 if to_long else None + x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype) + o_t = torch.mode(x_t, dim).values + o, *_ = convert_to_dst_type(o_t, x) + return o + + +def unique(x: NdarrayTensor) -> NdarrayTensor: + """`torch.unique` with equivalent implementation for numpy. + + Args: + x: array/tensor + """ + return torch.unique(x) if isinstance(x, torch.Tensor) else np.unique(x) # type: ignore diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index aa8f02f815..429183b1a0 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,24 +12,28 @@ # have to explicitly bring these in here to resolve circular import issues from .aliases import alias, resolve_name from .decorators import MethodReplacer, RestartGenerator -from .deprecated import DeprecatedError, deprecated, deprecated_arg +from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg from .dist import evenly_divisible_all_gather, get_dist_device, string_list_all_gather from .enums import ( Average, BlendMode, ChannelMatching, CommonKeys, + DiceCEReduction, ForwardMode, GridSampleMode, GridSamplePadMode, InterpolateMode, InverseKeys, + JITMetadataKeys, LossReduction, Method, MetricReduction, NumpyPadMode, + PostFix, PytorchPadMode, SkipMode, + TraceKeys, TransformBackends, UpsampleMode, Weight, @@ -38,6 +42,7 @@ from .misc import ( MAX_SEED, ImageMetaKey, + check_parent_dir, copy_to_device, ensure_tuple, ensure_tuple_rep, @@ -52,12 +57,13 @@ issequenceiterable, list_to_dict, progress_bar, + sample_slices, + save_obj, set_determinism, star_zip_with, zip_with, ) from .module import ( - PT_BEFORE_1_7, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, @@ -66,10 +72,13 @@ get_full_type_name, get_package_version, get_torch_version_tuple, + instantiate, load_submodules, look_up_option, min_version, optional_import, + pytorch_after, + require_pkg, version_leq, ) from .nvtx import Range @@ -77,6 +86,7 @@ from .state_cacher import StateCacher from .type_conversion import ( convert_data_type, + convert_to_cupy, convert_to_dst_type, convert_to_numpy, convert_to_tensor, diff --git a/monai/utils/aliases.py b/monai/utils/aliases.py index 2b7b29eeb5..0ae79e26ff 100644 --- a/monai/utils/aliases.py +++ b/monai/utils/aliases.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -70,8 +70,8 @@ def resolve_name(name): try: mod = importlib.import_module(modname) obj = getattr(mod, declname, None) - except ModuleNotFoundError: - raise ValueError(f"Module {modname!r} not found.") + except ModuleNotFoundError as not_found_err: + raise ValueError(f"Module {modname!r} not found.") from not_found_err if obj is None: raise ValueError(f"Module {modname!r} does not have member {declname!r}.") diff --git a/monai/utils/decorators.py b/monai/utils/decorators.py index 1931d703c9..0856c0fc1a 100644 --- a/monai/utils/decorators.py +++ b/monai/utils/decorators.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/utils/deprecated.py b/monai/utils/deprecate_utils.py similarity index 82% rename from monai/utils/deprecated.py rename to monai/utils/deprecate_utils.py index 3a4568b06c..a6092c1b63 100644 --- a/monai/utils/deprecated.py +++ b/monai/utils/deprecate_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. import inspect +import sys import warnings from functools import wraps from types import FunctionType @@ -60,6 +61,9 @@ def deprecated( Decorated definition which warns or raises exception when used """ + # if version_val.startswith("0+"): + # # version unknown, set version_val to a large value (assuming the latest version) + # version_val = f"{sys.maxsize}" if since is not None and removed is not None and not version_leq(since, removed): raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.") is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) @@ -116,6 +120,7 @@ def deprecated_arg( removed: Optional[str] = None, msg_suffix: str = "", version_val: str = __version__, + new_name: Optional[str] = None, ): """ Marks a particular named argument of a callable as deprecated. The same conditions for `since` and `removed` as @@ -130,6 +135,8 @@ def deprecated_arg( using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`. https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded + In the current implementation type annotations are not preserved. + Args: name: name of position or keyword argument to mark as deprecated. @@ -137,17 +144,23 @@ def deprecated_arg( removed: version at which the argument was removed and no longer usable. msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead. version_val: (used for testing) version to compare since and removed against, default is MONAI version. + new_name: name of position or keyword argument to replace the deprecated argument. + if it is specified and the signature of the decorated function has a `kwargs`, the value to the + deprecated argument `name` will be removed. Returns: - Decorated callable which warns or raises exception when deprecated argument used + Decorated callable which warns or raises exception when deprecated argument used. """ + + if version_val.startswith("0+") or not f"{version_val}".strip()[0].isdigit(): + # version unknown, set version_val to a large value (assuming the latest version) + version_val = f"{sys.maxsize}" if since is not None and removed is not None and not version_leq(since, removed): raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.") is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) if is_not_yet_deprecated: # smaller than `since`, do nothing return lambda obj: obj - if since is None and removed is None: # raise a DeprecatedError directly is_removed = True @@ -157,9 +170,6 @@ def deprecated_arg( is_deprecated = since is not None and version_leq(since, version_val) is_removed = removed is not None and version_leq(removed, version_val) - if is_not_yet_deprecated: - return lambda obj: obj - def _decorator(func): argname = f"{func.__name__}_{name}" @@ -180,10 +190,23 @@ def _decorator(func): @wraps(func) def _wrapper(*args, **kwargs): + if new_name is not None and name in kwargs and new_name not in kwargs: + # replace the deprecated arg "name" with "new_name" + # if name is specified and new_name is not specified + kwargs[new_name] = kwargs[name] + try: + sig.bind(*args, **kwargs).arguments + except TypeError: + # multiple values for new_name using both args and kwargs + kwargs.pop(new_name, None) binding = sig.bind(*args, **kwargs).arguments - positional_found = name in binding - kw_found = "kwargs" in binding and name in binding["kwargs"] + kw_found = False + for k, param in sig.parameters.items(): + if param.kind == inspect.Parameter.VAR_KEYWORD and k in binding and name in binding[k]: + kw_found = True + # if the deprecated arg is found in the **kwargs, it should be removed + kwargs.pop(name, None) if positional_found or kw_found: if is_removed: diff --git a/monai/utils/dist.py b/monai/utils/dist.py index beb958a5c8..37536bfe83 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 847df9e2d3..4bc3d6ee84 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,9 @@ # limitations under the License. from enum import Enum +from typing import Optional + +from monai.utils.deprecate_utils import deprecated __all__ = [ "NumpyPadMode", @@ -22,12 +25,15 @@ "Average", "MetricReduction", "LossReduction", + "DiceCEReduction", "Weight", "ChannelMatching", "SkipMode", "Method", + "TraceKeys", "InverseKeys", "CommonKeys", + "PostFix", "ForwardMode", "TransformBackends", ] @@ -53,7 +59,7 @@ class NumpyPadMode(Enum): class GridSampleMode(Enum): """ - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html interpolation mode of `torch.nn.functional.grid_sample` @@ -71,7 +77,7 @@ class GridSampleMode(Enum): class InterpolateMode(Enum): """ - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html """ NEAREST = "nearest" @@ -103,7 +109,7 @@ class BlendMode(Enum): class PytorchPadMode(Enum): """ - See also: https://pytorch.org/docs/stable/nn.functional.html#pad + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html """ CONSTANT = "constant" @@ -114,7 +120,7 @@ class PytorchPadMode(Enum): class GridSamplePadMode(Enum): """ - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html """ ZEROS = "zeros" @@ -161,6 +167,16 @@ class LossReduction(Enum): SUM = "sum" +class DiceCEReduction(Enum): + """ + See also: + - :py:class:`monai.losses.dice.DiceCELoss` + """ + + MEAN = "mean" + SUM = "sum" + + class Weight(Enum): """ See also: :py:class:`monai.losses.dice.GeneralizedDiceLoss` @@ -208,8 +224,27 @@ class ForwardMode(Enum): EVAL = "eval" +class TraceKeys: + """Extra meta data keys used for traceable transforms.""" + + CLASS_NAME = "class" + ID = "id" + ORIG_SIZE = "orig_size" + EXTRA_INFO = "extra_info" + DO_TRANSFORM = "do_transforms" + KEY_SUFFIX = "_transforms" + NONE = "none" + + +@deprecated(since="0.8.0", msg_suffix="use monai.utils.enums.TraceKeys instead.") class InverseKeys: - """Extra meta data keys used for inverse transforms.""" + """ + Extra meta data keys used for inverse transforms. + + .. deprecated:: 0.8.0 + Use :class:`monai.utils.enums.TraceKeys` instead. + + """ CLASS_NAME = "class" ID = "id" @@ -217,6 +252,7 @@ class InverseKeys: EXTRA_INFO = "extra_info" DO_TRANSFORM = "do_transforms" KEY_SUFFIX = "_transforms" + NONE = "none" class CommonKeys: @@ -236,6 +272,22 @@ class CommonKeys: LOSS = "loss" +class PostFix: + """Post-fixes.""" + + @staticmethod + def _get_str(prefix, suffix): + return suffix if prefix is None else f"{prefix}_{suffix}" + + @staticmethod + def meta(key: Optional[str] = None): + return PostFix._get_str(key, "meta_dict") + + @staticmethod + def orig_meta(key: Optional[str] = None): + return PostFix._get_str(key, "orig_meta_dict") + + class TransformBackends(Enum): """ Transform backends. @@ -243,3 +295,15 @@ class TransformBackends(Enum): TORCH = "torch" NUMPY = "numpy" + + +class JITMetadataKeys(Enum): + """ + Keys stored in the metadata file for saved Torchscript models. Some of these are generated by the routines + and others are optionally provided by users. + """ + + NAME = "name" + TIMESTAMP = "timestamp" + VERSION = "version" + DESCRIPTION = "description" diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index 26487083b1..366d11ebd8 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,11 +16,14 @@ from enum import Enum from threading import RLock, Thread -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch +from monai.config import IgniteInfo +from monai.utils.module import min_version, optional_import + try: import matplotlib.pyplot as plt @@ -28,14 +31,11 @@ except ImportError: has_matplotlib = False -try: +if TYPE_CHECKING: from ignite.engine import Engine, Events - - has_ignite = True -except ImportError: - Engine = object - Events = object - has_ignite = False +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") LOSS_NAME = "loss" @@ -128,7 +128,7 @@ def plot_metric_images( else: im.imshow(np.squeeze(imagemap[n]), cmap="gray") - im.set_title("%s\n%.3g -> %.3g" % (n, imagemap[n].min(), imagemap[n].max())) + im.set_title(f"{n}\n{imagemap[n].min():.3g} -> {imagemap[n].max():.3g}") im.axis("off") axes.append(im) @@ -161,6 +161,7 @@ def plot_engine_status( window_fraction: int = 20, image_fn: Optional[Callable] = tensor_to_images, fig=None, + selected_inst: int = 0, ) -> Tuple: """ Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics @@ -177,6 +178,7 @@ def plot_engine_status( window_fraction: for metric plot, what fraction of the graph value length to use as the running average window image_fn: callable converting tensors keyed to a name in the Engine to a tuple of images to plot fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing + selected_inst: index of the instance to show in the image plot Returns: Figure object (or `fig` if given), list of Axes objects for graph and images @@ -189,22 +191,36 @@ def plot_engine_status( graphmap = {LOSS_NAME: logger.loss} graphmap.update(logger.metrics) - imagemap = {} + imagemap: Dict = {} if image_fn is not None and engine.state is not None and engine.state.batch is not None: for src in (engine.state.batch, engine.state.output): + label = "Batch" if src is engine.state.batch else "Output" + batch_selected_inst = selected_inst # selected batch index, set to 0 when src is decollated + + # if the src object is a list of elements, ie. a decollated batch, select an element and keep it as + # a dictionary of tensors with a batch dimension added if isinstance(src, list): - for i, s in enumerate(src): - if isinstance(s, dict): - for k, v in s.items(): - if isinstance(v, torch.Tensor): - image = image_fn(k, v) - if image is not None: - imagemap[f"{k}_{i}"] = image - elif isinstance(s, torch.Tensor): - label = "Batch" if src is engine.state.batch else "Output" - image = image_fn(label, s) + selected_dict = src[selected_inst] # select this element + batch_selected_inst = 0 # set the selection to be the single index in the batch dimension + # store each tensor that is interpretable as an image with an added batch dimension + src = {k: v[None] for k, v in selected_dict.items() if isinstance(v, torch.Tensor) and v.ndim >= 3} + + # images will be generated from the batch item selected above only, or from the single item given as `src` + + if isinstance(src, dict): + for k, v in src.items(): + if isinstance(v, torch.Tensor) and v.ndim >= 4: + image = image_fn(k, v[batch_selected_inst]) + + # if we have images add each one separately to the map if image is not None: - imagemap[f"{label}_{i}"] = image + for i, im in enumerate(image): + imagemap[f"{k}_{i}"] = im + + elif isinstance(src, torch.Tensor): + image = image_fn(label, src) + if image is not None: + imagemap[f"{label}_{i}"] = image axes = plot_metric_images(fig, title, graphmap, imagemap, yscale, avg_keys, window_fraction) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index a31452f6ae..36ba7722b8 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,17 +12,22 @@ import collections.abc import inspect import itertools +import os import random +import shutil +import tempfile import types import warnings from ast import literal_eval from distutils.util import strtobool +from pathlib import Path from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast import numpy as np import torch -from monai.utils.module import get_torch_version_tuple, version_leq +from monai.config.type_definitions import NdarrayOrTensor, PathLike +from monai.utils.module import version_leq __all__ = [ "zip_with", @@ -43,6 +48,10 @@ "copy_to_device", "ImageMetaKey", "is_module_ver_at_least", + "has_option", + "sample_slices", + "check_parent_dir", + "save_obj", ] _seed = None @@ -88,7 +97,7 @@ def ensure_tuple(vals: Any) -> Tuple[Any, ...]: Returns a tuple of `vals`. """ if not issequenceiterable(vals): - vals = (vals,) + return (vals,) return tuple(vals) @@ -97,8 +106,8 @@ def ensure_tuple_size(tup: Any, dim: int, pad_val: Any = 0) -> Tuple[Any, ...]: """ Returns a copy of `tup` with `dim` values by either shortened or padded with `pad_val` as necessary. """ - tup = ensure_tuple(tup) + (pad_val,) * dim - return tuple(tup[:dim]) + new_tup = ensure_tuple(tup) + (pad_val,) * dim + return new_tup[:dim] def ensure_tuple_rep(tup: Any, dim: int) -> Tuple[Any, ...]: @@ -231,6 +240,13 @@ def set_determinism( use_deterministic_algorithms: Set whether PyTorch operations must use "deterministic" algorithms. additional_settings: additional settings that need to set random seed. + Note: + + This function will not affect the randomizable objects in :py:class:`monai.transforms.Randomizable`, which + have independent random states. For those objects, the ``set_random_state()`` method should be used to + ensure the deterministic behavior (alternatively, :py:class:`monai.data.DataLoader` by default sets the seeds + according to the global random state, please see also: :py:class:`monai.data.utils.worker_init_fn` and + :py:class:`monai.data.utils.set_rnd`). """ if seed is None: # cast to 32 bit seed for CUDA @@ -250,19 +266,21 @@ def set_determinism( for func in additional_settings: func(seed) + if torch.backends.flags_frozen(): + warnings.warn("PyTorch global flag support of backends is disabled, enable it to set global `cudnn` flags.") + torch.backends.__allow_nonbracketed_mutation_flag = True + if seed is not None: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False else: # restore the original flags torch.backends.cudnn.deterministic = _flag_deterministic torch.backends.cudnn.benchmark = _flag_cudnn_benchmark - if use_deterministic_algorithms is not None: - torch_ver = get_torch_version_tuple() - if torch_ver >= (1, 9): + if hasattr(torch, "use_deterministic_algorithms"): # `use_deterministic_algorithms` is new in torch 1.8.0 torch.use_deterministic_algorithms(use_deterministic_algorithms) - elif torch_ver >= (1, 7): - torch.set_deterministic(use_deterministic_algorithms) # beta feature + elif hasattr(torch, "set_deterministic"): # `set_deterministic` is new in torch 1.7.0 + torch.set_deterministic(use_deterministic_algorithms) # type: ignore else: warnings.warn("use_deterministic_algorithms=True, but PyTorch version is too old to set the mode.") @@ -279,9 +297,7 @@ def list_to_dict(items): def _parse_var(s): items = s.split("=", maxsplit=1) key = items[0].strip(" \n\r\t'") - value = None - if len(items) > 1: - value = items[1].strip(" \n\r\t'") + value = items[1].strip(" \n\r\t'") if len(items) > 1 else None return key, value d = {} @@ -302,10 +318,7 @@ def _parse_var(s): def copy_to_device( - obj: Any, - device: Optional[Union[str, torch.device]], - non_blocking: bool = True, - verbose: bool = False, + obj: Any, device: Optional[Union[str, torch.device]], non_blocking: bool = True, verbose: bool = False ) -> Any: """ Copy object or tuple/list/dictionary of objects to ``device``. @@ -314,7 +327,7 @@ def copy_to_device( obj: object or tuple/list/dictionary of objects to move to ``device``. device: move ``obj`` to this device. Can be a string (e.g., ``cpu``, ``cuda``, ``cuda:0``, etc.) or of type ``torch.device``. - non_blocking_transfer: when `True`, moves data to device asynchronously if + non_blocking: when `True`, moves data to device asynchronously if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. verbose: when `True`, will print a warning for any elements of incompatible type not copied to ``device``. @@ -368,3 +381,83 @@ def is_module_ver_at_least(module, version): """ test_ver = ".".join(map(str, version)) return module.__version__ != test_ver and version_leq(test_ver, module.__version__) + + +def sample_slices(data: NdarrayOrTensor, dim: int = 1, as_indices: bool = True, *slicevals: int) -> NdarrayOrTensor: + """sample several slices of input numpy array or Tensor on specified `dim`. + + Args: + data: input data to sample slices, can be numpy array or PyTorch Tensor. + dim: expected dimension index to sample slices, default to `1`. + as_indices: if `True`, `slicevals` arg will be treated as the expected indices of slice, like: `1, 3, 5` + means `data[..., [1, 3, 5], ...]`, if `False`, `slicevals` arg will be treated as args for `slice` func, + like: `1, None` means `data[..., [1:], ...]`, `1, 5` means `data[..., [1: 5], ...]`. + slicevals: indices of slices or start and end indices of expected slices, depends on `as_indices` flag. + + """ + slices = [slice(None)] * len(data.shape) + slices[dim] = slicevals if as_indices else slice(*slicevals) # type: ignore + + return data[tuple(slices)] + + +def check_parent_dir(path: PathLike, create_dir: bool = True): + """ + Utility to check whether the parent directory of the `path` exists. + + Args: + path: input path to check the parent directory. + create_dir: if True, when the parent directory doesn't exist, create the directory, + otherwise, raise exception. + + """ + path = Path(path) + path_dir = path.parent + if not path_dir.exists(): + if create_dir: + path_dir.mkdir(parents=True) + else: + raise ValueError(f"the directory of specified path does not exist: `{path_dir}`.") + + +def save_obj( + obj, path: PathLike, create_dir: bool = True, atomic: bool = True, func: Optional[Callable] = None, **kwargs +): + """ + Save an object to file with specified path. + Support to serialize to a temporary file first, then move to final destination, + so that files are guaranteed to not be damaged if exception occurs. + + Args: + obj: input object data to save. + path: target file path to save the input object. + create_dir: whether to create dictionary of the path if not existng, default to `True`. + atomic: if `True`, state is serialized to a temporary file first, then move to final destination. + so that files are guaranteed to not be damaged if exception occurs. default to `True`. + func: the function to save file, if None, default to `torch.save`. + kwargs: other args for the save `func` except for the checkpoint and filename. + default `func` is `torch.save()`, details of other args: + https://pytorch.org/docs/stable/generated/torch.save.html. + + """ + path = Path(path) + check_parent_dir(path=path, create_dir=create_dir) + if path.exists(): + # remove the existing file + os.remove(path) + + if func is None: + func = torch.save + + if not atomic: + func(obj=obj, f=path, **kwargs) + return + try: + # writing to a temporary directory and then using a nearly atomic rename operation + with tempfile.TemporaryDirectory() as tempdir: + temp_path: Path = Path(tempdir) / path.name + func(obj=obj, f=temp_path, **kwargs) + if temp_path.is_file(): + shutil.move(str(temp_path), path) + except PermissionError: # project-monai/monai issue #3613 + pass diff --git a/monai/utils/module.py b/monai/utils/module.py index 33314fb0e3..065cc8f7c8 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -8,13 +8,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import enum +import os +import re import sys import warnings +from functools import partial, wraps from importlib import import_module +from inspect import isclass, isfunction, ismethod from pkgutil import walk_packages +from pydoc import locate from re import match -from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, cast +from types import FunctionType +from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, Union, cast import torch @@ -29,16 +36,18 @@ "look_up_option", "min_version", "optional_import", + "require_pkg", "load_submodules", + "instantiate", "get_full_type_name", "get_package_version", "get_torch_version_tuple", - "PT_BEFORE_1_7", "version_leq", + "pytorch_after", ] -def look_up_option(opt_str, supported: Collection, default="no_default"): +def look_up_option(opt_str, supported: Union[Collection, enum.EnumMeta], default="no_default", print_all_options=True): """ Look up the option in the supported collection and return the matched item. Raise a value error possibly with a guess of the closest match. @@ -49,6 +58,7 @@ def look_up_option(opt_str, supported: Collection, default="no_default"): default: If it is given, this method will return `default` when `opt_str` is not found, instead of raising a `ValueError`. Otherwise, it defaults to `"no_default"`, so that the method may raise a `ValueError`. + print_all_options: whether to print all available options when `opt_str` is not found. Defaults to True Examples: @@ -104,12 +114,12 @@ class Color(Enum): if edit_dist <= 3: edit_dists[key] = edit_dist - supported_msg = f"Available options are {set_to_check}.\n" + supported_msg = f"Available options are {set_to_check}.\n" if print_all_options else "" if edit_dists: guess_at_spelling = min(edit_dists, key=edit_dists.get) # type: ignore raise ValueError( f"By '{opt_str}', did you mean '{guess_at_spelling}'?\n" - + f"'{opt_str}' is not a valid option.\n" + + f"'{opt_str}' is not a valid value.\n" + supported_msg ) raise ValueError(f"Unsupported option '{opt_str}', " + supported_msg) @@ -136,9 +146,7 @@ def damerau_levenshtein_distance(s1: str, s2: str): for j, s2j in enumerate(s2): cost = 0 if s1i == s2j else 1 d[(i, j)] = min( - d[(i - 1, j)] + 1, # deletion - d[(i, j - 1)] + 1, # insertion - d[(i - 1, j - 1)] + cost, # substitution + d[(i - 1, j)] + 1, d[(i, j - 1)] + 1, d[(i - 1, j - 1)] + cost # deletion # insertion # substitution ) if i and j and s1i == s2[j - 1] and s1[i - 1] == s2j: d[(i, j)] = min(d[(i, j)], d[i - 2, j - 2] + cost) # transposition @@ -189,7 +197,37 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ return submodules, err_mod +def instantiate(path: str, **kwargs): + """ + Create an object instance or partial function from a class or function represented by string. + `kwargs` will be part of the input arguments to the class constructor or function. + The target component must be a class or a function, if not, return the component directly. + + Args: + path: full path of the target class or function component. + kwargs: arguments to initialize the class instance or set default args + for `partial` function. + + """ + + component = locate(path) + if component is None: + raise ModuleNotFoundError(f"Cannot locate class or function path: '{path}'.") + if isclass(component): + return component(**kwargs) + # support regular function, static method and class method + if isfunction(component) or (ismethod(component) and isclass(getattr(component, "__self__", None))): + return partial(component, **kwargs) + + warnings.warn(f"Component to instantiate must represent a valid class or function, but got {path}.") + return component + + def get_full_type_name(typeobj): + """ + Utility to get the full path name of a class or object type. + + """ module = typeobj.__module__ if module is None or module == str.__class__.__module__: return typeobj.__name__ # Avoid reporting __builtin__ @@ -349,6 +387,45 @@ def __call__(self, *_args, **_kwargs): return _LazyRaise(), False +def require_pkg( + pkg_name: str, version: str = "", version_checker: Callable[..., bool] = min_version, raise_error: bool = True +): + """ + Decorator function to check the required package installation. + + Args: + pkg_name: required package name, like: "itk", "nibabel", etc. + version: required version string used by the version_checker. + version_checker: a callable to check the module version, defaults to `monai.utils.min_version`. + raise_error: if True, raise `OptionalImportError` error if the required package is not installed + or the version doesn't match requirement, if False, print the error in a warning. + + """ + + def _decorator(obj): + is_func = isinstance(obj, FunctionType) + call_obj = obj if is_func else obj.__init__ + _, has = optional_import(module=pkg_name, version=version, version_checker=version_checker) + + @wraps(call_obj) + def _wrapper(*args, **kwargs): + if not has: + err_msg = f"required package `{pkg_name}` is not installed or the version doesn't match requirement." + if raise_error: + raise OptionalImportError(err_msg) + else: + warnings.warn(err_msg) + + return call_obj(*args, **kwargs) + + if is_func: + return _wrapper + obj.__init__ = _wrapper + return obj + + return _decorator + + def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."): """ Try to load package and get version. If not found, return `default`. @@ -364,17 +441,28 @@ def get_torch_version_tuple(): Returns: tuple of ints represents the pytorch major/minor version. """ - return tuple((int(x) for x in torch.__version__.split(".")[:2])) + return tuple(int(x) for x in torch.__version__.split(".")[:2]) + + +def version_leq(lhs: str, rhs: str): + """ + Returns True if version `lhs` is earlier or equal to `rhs`. + Args: + lhs: version name to compare with `rhs`, return True if earlier or equal to `rhs`. + rhs: version name to compare with `lhs`, return True if later or equal to `lhs`. -def version_leq(lhs, rhs): - """Returns True if version `lhs` is earlier or equal to `rhs`.""" + """ - ver, has_ver = optional_import("pkg_resources", name="parse_version") + lhs, rhs = str(lhs), str(rhs) + pkging, has_ver = optional_import("pkg_resources", name="packaging") if has_ver: - return ver(lhs) <= ver(rhs) + try: + return pkging.version.Version(lhs) <= pkging.version.Version(rhs) + except pkging.version.InvalidVersion: + return True - def _try_cast(val): + def _try_cast(val: str): val = val.strip() try: m = match("(\\d+)(.*)", val) @@ -390,10 +478,10 @@ def _try_cast(val): rhs = rhs.split("+", 1)[0] # parse the version strings in this basic way without `packaging` package - lhs = map(_try_cast, lhs.split(".")) - rhs = map(_try_cast, rhs.split(".")) + lhs_ = map(_try_cast, lhs.split(".")) + rhs_ = map(_try_cast, rhs.split(".")) - for l, r in zip(lhs, rhs): + for l, r in zip(lhs_, rhs_): if l != r: if isinstance(l, int) and isinstance(r, int): return l < r @@ -402,7 +490,51 @@ def _try_cast(val): return True -try: - PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0") -except (AttributeError, TypeError): - PT_BEFORE_1_7 = True +def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool: + """ + Compute whether the current pytorch version is after or equal to the specified version. + The current system pytorch version is determined by `torch.__version__` or + via system environment variable `PYTORCH_VER`. + + Args: + major: major version number to be compared with + minor: minor version number to be compared with + patch: patch version number to be compared with + current_ver_string: if None, `torch.__version__` will be used. + + Returns: + True if the current pytorch version is greater than or equal to the specified version. + """ + + try: + if current_ver_string is None: + _env_var = os.environ.get("PYTORCH_VER", "") + current_ver_string = _env_var if _env_var else torch.__version__ + ver, has_ver = optional_import("pkg_resources", name="parse_version") + if has_ver: + return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore + parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3) + while len(parts) < 3: + parts += ["0"] + c_major, c_minor, c_patch = parts[:3] + except (AttributeError, ValueError, TypeError): + c_major, c_minor = get_torch_version_tuple() + c_patch = "0" + c_mn = int(c_major), int(c_minor) + mn = int(major), int(minor) + if c_mn != mn: + return c_mn > mn + is_prerelease = ("a" in f"{c_patch}".lower()) or ("rc" in f"{c_patch}".lower()) + c_p = 0 + try: + p_reg = re.search(r"\d+", f"{c_patch}") + if p_reg: + c_p = int(p_reg.group()) + except (AttributeError, TypeError, ValueError): + is_prerelease = True + patch = int(patch) + if c_p != patch: + return c_p > patch # type: ignore + if is_prerelease: + return False + return True diff --git a/monai/utils/nvtx.py b/monai/utils/nvtx.py index 1980ceef71..691f900c7d 100644 --- a/monai/utils/nvtx.py +++ b/monai/utils/nvtx.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -62,8 +62,15 @@ def __call__(self, obj: Any): # Define the name to be associated to the range if not provided if self.name is None: name = type(obj).__name__ + # If CuCIM or TorchVision transform wrappers are being used, + # append the underlying transform to the name for more clarity + if "CuCIM" in name or "TorchVision" in name: + name = f"{name}_{obj.name}" self.name_counter[name] += 1 - self.name = f"{name}_{self.name_counter[name]}" + if self.name_counter[name] > 1: + self.name = f"{name}_{self.name_counter[name]}" + else: + self.name = name # Define the methods to be wrapped if not provided if self.methods is None: @@ -138,7 +145,7 @@ def _get_method(self, obj: Any) -> tuple: if len(method_list) < 1: raise ValueError( f"The method to be wrapped for this object [{type(obj)}] is not recognized." - "The name of the method should be provied or the object should have one of these methods:" + "The name of the method should be provided or the object should have one of these methods:" f"{default_methods}" ) return ensure_tuple(method_list) diff --git a/monai/utils/profiling.py b/monai/utils/profiling.py index 695653e897..8e0742268f 100644 --- a/monai/utils/profiling.py +++ b/monai/utils/profiling.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -56,7 +56,7 @@ def wrapper(*args, **kwargs): cpu_time = torch.autograd.profiler.format_time(cpu_time) gpu_time = torch.autograd.profiler.format_time(gpu_time) - print("cpu time: {}, gpu time: {}".format(cpu_time, gpu_time), flush=True) + print(f"cpu time: {cpu_time}, gpu time: {gpu_time}", flush=True) return result @@ -83,7 +83,7 @@ def wrapper(*args, **kwargs): total_time = (end - start) * 1e6 total_time_str = torch.autograd.profiler.format_time(total_time) - print("end to end time: {}".format(total_time_str), flush=True) + print(f"end to end time: {total_time_str}", flush=True) return result diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index 94943a8c37..3e392ab979 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,10 +11,14 @@ import copy import os +import pickle import tempfile from typing import Dict, Optional import torch +from torch.serialization import DEFAULT_PROTOCOL + +from monai.config.type_definitions import PathLike __all__ = ["StateCacher"] @@ -37,8 +41,10 @@ class StateCacher: def __init__( self, in_memory: bool, - cache_dir: Optional[str] = None, + cache_dir: Optional[PathLike] = None, allow_overwrite: bool = True, + pickle_module=pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, ) -> None: """Constructor. @@ -51,20 +57,39 @@ def __init__( allow_overwrite: allow the cache to be overwritten. If set to `False`, an error will be thrown if a matching already exists in the list of cached objects. + pickle_module: module used for pickling metadata and objects, default to `pickle`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. + pickle_protocol: can be specified to override the default protocol, default to `2`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. + """ self.in_memory = in_memory - self.cache_dir = cache_dir + self.cache_dir = tempfile.gettempdir() if cache_dir is None else cache_dir + if not os.path.isdir(self.cache_dir): + raise ValueError("Given `cache_dir` is not a valid directory.") + self.allow_overwrite = allow_overwrite + self.pickle_module = pickle_module + self.pickle_protocol = pickle_protocol + self.cached: Dict = {} - if self.cache_dir is None: - self.cache_dir = tempfile.gettempdir() - elif not os.path.isdir(self.cache_dir): - raise ValueError("Given `cache_dir` is not a valid directory.") + def store(self, key, data_obj, pickle_module=None, pickle_protocol: Optional[int] = None): + """ + Store a given object with the given key name. - self.cached: Dict[str, str] = {} + Args: + key: key of the data object to store. + data_obj: data object to store. + pickle_module: module used for pickling metadata and objects, default to `self.pickle_module`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. + pickle_protocol: can be specified to override the default protocol, default to `self.pickle_protocol`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. - def store(self, key, data_obj): - """Store a given object with the given key name.""" + """ if key in self.cached and not self.allow_overwrite: raise RuntimeError("Cached key already exists and overwriting is disabled.") if self.in_memory: @@ -72,7 +97,12 @@ def store(self, key, data_obj): else: fn = os.path.join(self.cache_dir, f"state_{key}_{id(self)}.pt") self.cached.update({key: {"obj": fn}}) - torch.save(data_obj, fn) + torch.save( + obj=data_obj, + f=fn, + pickle_module=self.pickle_module if pickle_module is None else pickle_module, + pickle_protocol=self.pickle_protocol if pickle_protocol is None else pickle_protocol, + ) # store object's device if relevant if hasattr(data_obj, "device"): self.cached[key]["device"] = data_obj.device diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index b0ce187e38..d5944e265b 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -1,11 +1,23 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Type, Union import numpy as np import torch -from monai.config.type_definitions import DtypeLike, NdarrayOrTensor +from monai.config.type_definitions import DtypeLike, NdarrayTensor from monai.utils import optional_import +from monai.utils.module import look_up_option cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") @@ -16,6 +28,7 @@ "get_equivalent_dtype", "convert_data_type", "get_dtype", + "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", "convert_to_dst_type", @@ -40,31 +53,34 @@ def dtype_torch_to_numpy(dtype): """Convert a torch dtype to its numpy equivalent.""" - if dtype not in _torch_to_np_dtype: - raise ValueError(f"Unsupported torch to numpy dtype '{dtype}'.") - return _torch_to_np_dtype[dtype] + return look_up_option(dtype, _torch_to_np_dtype) def dtype_numpy_to_torch(dtype): """Convert a numpy dtype to its torch equivalent.""" # np dtypes can be given as np.float32 and np.dtype(np.float32) so unify them - dtype = np.dtype(dtype) if type(dtype) is type else dtype - if dtype not in _np_to_torch_dtype: - raise ValueError(f"Unsupported numpy to torch dtype '{dtype}'.") - return _np_to_torch_dtype[dtype] + dtype = np.dtype(dtype) if isinstance(dtype, (type, str)) else dtype + return look_up_option(dtype, _np_to_torch_dtype) def get_equivalent_dtype(dtype, data_type): """Convert to the `dtype` that corresponds to `data_type`. - Example: + + Example:: + im = torch.tensor(1) dtype = get_equivalent_dtype(np.float32, type(im)) + """ + if dtype is None: + return None if data_type is torch.Tensor: - if type(dtype) is torch.dtype: + if isinstance(dtype, torch.dtype): + # already a torch dtype and target `data_type` is torch.Tensor return dtype return dtype_numpy_to_torch(dtype) - if type(dtype) is not torch.dtype: + if not isinstance(dtype, torch.dtype): + # assuming the dtype is ok if it is not a torch dtype and target `data_type` is not torch.Tensor return dtype return dtype_torch_to_numpy(dtype) @@ -83,43 +99,49 @@ def get_dtype(data: Any): return type(data) -def convert_to_tensor(data, wrap_sequence: bool = False): +def convert_to_tensor( + data, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = False +): """ Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor. Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. - will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original. + will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original. for dictionary, list or tuple, convert every item to a Tensor if applicable. - wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. - If `True`, then `[1, 2]` -> `tensor([1, 2])`. + dtype: target data type to when converting to Tensor. + device: target device to put the converted Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. """ if isinstance(data, torch.Tensor): - return data.contiguous() + return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore if isinstance(data, np.ndarray): # skip array of string classes and object, refer to: # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13 if re.search(r"[SaUO]", data.dtype.str) is None: # numpy array with 0 dims is also sequence iterable, # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) - elif isinstance(data, (float, int, bool)): - return torch.as_tensor(data) - elif isinstance(data, Sequence) and wrap_sequence: - return torch.as_tensor(data) + if data.ndim > 0: + data = np.ascontiguousarray(data) + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore + elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)): + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore elif isinstance(data, list): - return [convert_to_tensor(i) for i in data] + list_ret = [convert_to_tensor(i, dtype=dtype, device=device) for i in data] + return torch.as_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret # type: ignore elif isinstance(data, tuple): - return tuple(convert_to_tensor(i) for i in data) + tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data) + return torch.as_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret # type: ignore elif isinstance(data, dict): - return {k: convert_to_tensor(v) for k, v in data.items()} + return {k: convert_to_tensor(v, dtype=dtype, device=device) for k, v in data.items()} return data -def convert_to_numpy(data, wrap_sequence: bool = False): +def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array. @@ -128,23 +150,24 @@ def convert_to_numpy(data, wrap_sequence: bool = False): data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original. for dictionary, list or tuple, convert every item to a numpy array if applicable. - wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. - If `True`, then `[1, 2]` -> `array([1, 2])`. + dtype: target data type when converting to numpy array. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. """ if isinstance(data, torch.Tensor): - data = data.detach().cpu().numpy() + data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy() elif has_cp and isinstance(data, cp_ndarray): - data = cp.asnumpy(data) - elif isinstance(data, (float, int, bool)): - data = np.asarray(data) - elif isinstance(data, Sequence) and wrap_sequence: - return np.asarray(data) + data = cp.asnumpy(data).astype(dtype, copy=False) + elif isinstance(data, (np.ndarray, float, int, bool)): + data = np.asarray(data, dtype=dtype) elif isinstance(data, list): - return [convert_to_numpy(i) for i in data] + list_ret = [convert_to_numpy(i, dtype=dtype) for i in data] + return np.asarray(list_ret) if wrap_sequence else list_ret elif isinstance(data, tuple): - return tuple(convert_to_numpy(i) for i in data) + tuple_ret = tuple(convert_to_numpy(i, dtype=dtype) for i in data) + return np.asarray(tuple_ret) if wrap_sequence else tuple_ret elif isinstance(data, dict): - return {k: convert_to_numpy(v) for k, v in data.items()} + return {k: convert_to_numpy(v, dtype=dtype) for k, v in data.items()} if isinstance(data, np.ndarray) and data.ndim > 0: data = np.ascontiguousarray(data) @@ -152,30 +175,79 @@ def convert_to_numpy(data, wrap_sequence: bool = False): return data +def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool = False): + """ + Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple, + recursively check every item and convert it to cupy array. + + Args: + data: input data can be PyTorch Tensor, numpy array, cupy array, list, dictionary, int, float, bool, str, etc. + Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays, + for dictionary, list or tuple, convert every item to a numpy array if applicable. + dtype: target data type when converting to Cupy array, tt must be an argument of `numpy.dtype`, + for more details: https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. + """ + + # direct calls + if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)): + data = cp.asarray(data, dtype) + elif isinstance(data, list): + list_ret = [convert_to_cupy(i, dtype) for i in data] + return cp.asarray(list_ret) if wrap_sequence else list_ret + elif isinstance(data, tuple): + tuple_ret = tuple(convert_to_cupy(i, dtype) for i in data) + return cp.asarray(tuple_ret) if wrap_sequence else tuple_ret + elif isinstance(data, dict): + return {k: convert_to_cupy(v, dtype) for k, v in data.items()} + # make it contiguous + if not isinstance(data, cp.ndarray): + raise ValueError(f"The input data type [{type(data)}] cannot be converted into cupy arrays!") + + if data.ndim > 0: + data = cp.ascontiguousarray(data) + return data + + def convert_data_type( data: Any, - output_type: Optional[type] = None, + output_type: Optional[Type[NdarrayTensor]] = None, device: Optional[torch.device] = None, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None, -) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: + dtype: Union[DtypeLike, torch.dtype] = None, + wrap_sequence: bool = False, +) -> Tuple[NdarrayTensor, type, Optional[torch.device]]: """ Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc. Args: data: data to be converted - output_type: `torch.Tensor` or `np.ndarray` (if blank, unchanged) - device: if output is `torch.Tensor`, select device (if blank, unchanged) + output_type: `torch.Tensor` or `np.ndarray` (if `None`, unchanged) + device: if output is `torch.Tensor`, select device (if `None`, unchanged) dtype: dtype of output data. Converted to correct library type (e.g., `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`). If left blank, it remains unchanged. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. Returns: modified data, orig_type, orig_device + + Note: + When both `output_type` and `dtype` are specified with different backend + (e.g., `torch.Tensor` and `np.float32`), the `output_type` will be used as the primary type, + for example:: + + >>> convert_data_type(1, torch.Tensor, dtype=np.float32) + (1.0, , None) + """ - orig_type: Any + orig_type: type if isinstance(data, torch.Tensor): orig_type = torch.Tensor elif isinstance(data, np.ndarray): orig_type = np.ndarray + elif has_cp and isinstance(data, cp.ndarray): + orig_type = cp.ndarray else: orig_type = type(data) @@ -183,33 +255,49 @@ def convert_data_type( output_type = output_type or orig_type - dtype = get_equivalent_dtype(dtype or get_dtype(data), output_type) - - if output_type is torch.Tensor: - if orig_type is not torch.Tensor: - data = convert_to_tensor(data) - if dtype != data.dtype: - data = data.to(dtype) - if device is not None: - data = data.to(device) - elif output_type is np.ndarray: - if orig_type is not np.ndarray: - data = convert_to_numpy(data) - if data is not None and dtype != data.dtype: - data = data.astype(dtype) - else: - raise ValueError(f"Unsupported output type: {output_type}") - return data, orig_type, orig_device + dtype_ = get_equivalent_dtype(dtype, output_type) + data_: NdarrayTensor + if issubclass(output_type, torch.Tensor): + data_ = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence) + return data_, orig_type, orig_device + if issubclass(output_type, np.ndarray): + data_ = convert_to_numpy(data, dtype=dtype_, wrap_sequence=wrap_sequence) + return data_, orig_type, orig_device + elif has_cp and issubclass(output_type, cp.ndarray): + data_ = convert_to_cupy(data, dtype=dtype_, wrap_sequence=wrap_sequence) + return data_, orig_type, orig_device + raise ValueError(f"Unsupported output type: {output_type}") -def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: + +def convert_to_dst_type( + src: Any, dst: NdarrayTensor, dtype: Union[DtypeLike, torch.dtype, None] = None, wrap_sequence: bool = False +) -> Tuple[NdarrayTensor, type, Optional[torch.device]]: """ - Convert `src` to the same `torch.Tensor`/`np.ndarray` and data type as `dst`. + Convert source data to the same data type and device as the destination data. + If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, + if `dst` is an instance of `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`, + otherwise, convert to the type of `dst` directly. + + Args: + src: source data to convert type. + dst: destination data that convert to the same data type as it. + dtype: an optional argument if the target `dtype` is different from the original `dst`'s data type. + wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. + If `True`, then `[1, 2]` -> `array([1, 2])`. See Also: :func:`convert_data_type` """ - device = None + device = dst.device if isinstance(dst, torch.Tensor) else None + if dtype is None: + dtype = dst.dtype + + output_type: Any if isinstance(dst, torch.Tensor): - device = dst.device - return convert_data_type(data=src, output_type=type(dst), device=device, dtype=dst.dtype) + output_type = torch.Tensor + elif isinstance(dst, np.ndarray): + output_type = np.ndarray + else: + output_type = type(dst) + return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence) diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py index 9ad61fa3f2..cd980846b3 100644 --- a/monai/visualize/__init__.py +++ b/monai/visualize/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,11 +10,7 @@ # limitations under the License. from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer -from .img2tensorboard import ( - add_animated_gif, - add_animated_gif_no_channels, - make_animated_gif_summary, - plot_2d_or_3d_image, -) +from .img2tensorboard import add_animated_gif, make_animated_gif_summary, plot_2d_or_3d_image from .occlusion_sensitivity import OcclusionSensitivity +from .utils import blend_images, matshow3d from .visualizer import default_upsampler diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 992eaecdac..16fb64cb46 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,7 +19,7 @@ from monai.config import NdarrayTensor from monai.transforms import ScaleIntensity -from monai.utils import ensure_tuple, get_torch_version_tuple +from monai.utils import ensure_tuple, pytorch_after from monai.visualize.visualizer import default_upsampler __all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks", "default_normalizer"] @@ -39,9 +39,9 @@ def _compute(data: np.ndarray) -> np.ndarray: return np.stack([scaler(i) for i in data], axis=0) if isinstance(x, torch.Tensor): - return torch.as_tensor(_compute(x.detach().cpu().numpy()), device=x.device) + return torch.as_tensor(_compute(x.detach().cpu().numpy()), device=x.device) # type: ignore - return _compute(x) + return _compute(x) # type: ignore class ModelWithHooks: @@ -80,13 +80,13 @@ def __init__( continue _registered.append(name) if self.register_backward: - if get_torch_version_tuple() < (1, 8): - mod.register_backward_hook(self.backward_hook(name)) - else: + if pytorch_after(1, 8): if "inplace" in mod.__dict__ and mod.__dict__["inplace"]: # inplace=True causes errors for register_full_backward_hook mod.__dict__["inplace"] = False mod.register_full_backward_hook(self.backward_hook(name)) + else: + mod.register_backward_hook(self.backward_hook(name)) if self.register_forward: mod.register_forward_hook(self.forward_hook(name)) if len(_registered) != len(self.target_layers): @@ -137,6 +137,11 @@ def __call__(self, x, class_idx=None, retain_graph=False): self.score = self.class_score(logits, self.class_idx) self.model.zero_grad() self.score.sum().backward(retain_graph=retain_graph) + for layer in self.target_layers: + if layer not in self.gradients: + raise RuntimeError( + f"Backward hook for {layer} is not triggered; `requires_grad` of {layer} should be `True`." + ) grad = tuple(self.gradients[layer] for layer in self.target_layers) if train: self.model.train() @@ -221,6 +226,8 @@ class CAM(CAMBase): .. code-block:: python + import torch + # densenet 2d from monai.networks.nets import DenseNet121 from monai.visualize import CAM @@ -319,6 +326,8 @@ class GradCAM(CAMBase): .. code-block:: python + import torch + # densenet 2d from monai.networks.nets import DenseNet121 from monai.visualize import GradCAM diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index ccdbdc2396..0af05adf32 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,17 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, List, Optional, Union import numpy as np import torch from monai.config import NdarrayTensor from monai.transforms import rescale_array -from monai.utils import optional_import +from monai.utils import convert_data_type, optional_import PIL, _ = optional_import("PIL") GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image") +SummaryX, _ = optional_import("tensorboardX.proto.summary_pb2", name="Summary") +SummaryWriterX, has_tensorboardx = optional_import("tensorboardX", name="SummaryWriter") if TYPE_CHECKING: from tensorboard.compat.proto.summary_pb2 import Summary @@ -28,23 +30,27 @@ Summary, _ = optional_import("tensorboard.compat.proto.summary_pb2", name="Summary") SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") +__all__ = ["make_animated_gif_summary", "add_animated_gif", "plot_2d_or_3d_image"] -__all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"] - -def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0) -> Summary: +def _image3_animated_gif( + tag: str, image: Union[np.ndarray, torch.Tensor], writer, frame_dim: int = 0, scale_factor: float = 1.0 +): """Function to actually create the animated gif. Args: tag: Data identifier image: 3D image tensors expected to be in `HWD` format + writer: the tensorboard writer to plot image + frame_dim: the dimension used as frames for GIF image, expect data shape as `HWD`, default to `0`. scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will scale it to displayable range """ if len(image.shape) != 3: raise AssertionError("3D image tensors expected to be in `HWD` format, len(image.shape) != 3") - ims = [(np.asarray((image[:, :, i])) * scale_factor).astype(np.uint8) for i in range(image.shape[2])] + image_np, *_ = convert_data_type(image, output_type=np.ndarray) + ims = [(i * scale_factor).astype(np.uint8, copy=False) for i in np.moveaxis(image_np, frame_dim, 0)] ims = [GifImage.fromarray(im) for im in ims] img_str = b"" for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]: @@ -54,52 +60,46 @@ def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale for b_data in PIL.GifImagePlugin.getdata(i): img_str += b_data img_str += b"\x3B" - summary_image_str = Summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str) - image_summary = Summary.Value(tag=tag, image=summary_image_str) - return Summary(value=[image_summary]) + + summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary + summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str) + image_summary = summary.Value(tag=tag, image=summary_image_str) + return summary(value=[image_summary]) def make_animated_gif_summary( tag: str, image: Union[np.ndarray, torch.Tensor], + writer=None, max_out: int = 3, - animation_axes: Sequence[int] = (3,), - image_axes: Sequence[int] = (1, 2), - other_indices: Optional[Dict] = None, + frame_dim: int = -3, scale_factor: float = 1.0, ) -> Summary: """Creates an animated gif out of an image tensor in 'CHWD' format and returns Summary. Args: tag: Data identifier - image: The image, expected to be in CHWD format - max_out: maximum number of slices to animate through - animation_axes: axis to animate on (not currently used) - image_axes: axes of image (not currently used) - other_indices: (not currently used) + image: The image, expected to be in `CHWD` format + writer: the tensorboard writer to plot image + max_out: maximum number of image channels to animate through + frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`, + default to `-3` (the first spatial dim) scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will scale it to displayable range """ suffix = "/image" if max_out == 1 else "/image/{}" - if other_indices is None: - other_indices = {} - axis_order = [0] + list(animation_axes) + list(image_axes) - - slicing = [] - for i in range(len(image.shape)): - if i in axis_order: - slicing.append(slice(None)) - else: - other_ind = other_indices.get(i, 0) - slicing.append(slice(other_ind, other_ind + 1)) - image = image[tuple(slicing)] + # GIF image has no channel dim, reduce the spatial dim index if positive + frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim + summary_op = [] for it_i in range(min(max_out, list(image.shape)[0])): one_channel_img: Union[torch.Tensor, np.ndarray] = ( image[it_i, :, :, :].squeeze(dim=0) if isinstance(image, torch.Tensor) else image[it_i, :, :, :] ) - summary_op = _image3_animated_gif(tag + suffix.format(it_i), one_channel_img, scale_factor) + summary_op.append( + _image3_animated_gif(tag + suffix.format(it_i), one_channel_img, writer, frame_dim, scale_factor) + ) return summary_op @@ -107,8 +107,9 @@ def add_animated_gif( writer: SummaryWriter, tag: str, image_tensor: Union[np.ndarray, torch.Tensor], - max_out: int, - scale_factor: float, + max_out: int = 3, + frame_dim: int = -3, + scale_factor: float = 1.0, global_step: Optional[int] = None, ) -> None: """Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter. @@ -116,47 +117,20 @@ def add_animated_gif( Args: writer: Tensorboard SummaryWriter to write to tag: Data identifier - image_tensor: tensor for the image to add, expected to be in CHWD format - max_out: maximum number of slices to animate through + image_tensor: tensor for the image to add, expected to be in `CHWD` format + max_out: maximum number of image channels to animate through + frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`, + default to `-3` (the first spatial dim) scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will scale it to displayable range global_step: Global step value to record """ - writer._get_file_writer().add_summary( - make_animated_gif_summary( - tag, image_tensor, max_out=max_out, animation_axes=[1], image_axes=[2, 3], scale_factor=scale_factor - ), - global_step, - ) - - -def add_animated_gif_no_channels( - writer: SummaryWriter, - tag: str, - image_tensor: Union[np.ndarray, torch.Tensor], - max_out: int, - scale_factor: float, - global_step: Optional[int] = None, -) -> None: - """Creates an animated gif out of an image tensor in 'HWD' format that does not have - a channel dimension and writes it with SummaryWriter. This is similar to the "add_animated_gif" - after inserting a channel dimension of 1. - - Args: - writer: Tensorboard SummaryWriter to write to - tag: Data identifier - image_tensor: tensor for the image to add, expected to be in HWD format - max_out: maximum number of slices to animate through - scale_factor: amount to multiply values by. If the image data is between 0 and 1, - using 255 for this value will scale it to displayable range - global_step: Global step value to record - """ - writer._get_file_writer().add_summary( - make_animated_gif_summary( - tag, image_tensor, max_out=max_out, animation_axes=[1], image_axes=[1, 2], scale_factor=scale_factor - ), - global_step, + summary = make_animated_gif_summary( + tag=tag, image=image_tensor, writer=writer, max_out=max_out, frame_dim=frame_dim, scale_factor=scale_factor ) + for s in summary: + # add GIF for every channel separately + writer._get_file_writer().add_summary(s, global_step) def plot_2d_or_3d_image( @@ -165,26 +139,33 @@ def plot_2d_or_3d_image( writer: SummaryWriter, index: int = 0, max_channels: int = 1, - max_frames: int = 64, + frame_dim: int = -3, + max_frames: int = 24, tag: str = "output", ) -> None: """Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image. Note: Plot 3D or 2D image(with more than 3 channels) as separate images. + And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video. Args: data: target data to be plotted as image on the TensorBoard. The data is expected to have 'NCHW[D]' dimensions or a list of data with `CHW[D]` dimensions, and only plot the first in the batch. step: current step to plot in a chart. - writer: specify TensorBoard SummaryWriter to plot the image. + writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image. index: plot which element in the input data batch, default is the first element. max_channels: number of channels to plot. - max_frames: number of frames for 2D-t plot. + frame_dim: if plotting 3D image as GIF, specify the dimension used as frames, + expect input data shape as `NCHWD`, default to `-3` (the first spatial dim) + max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`. tag: tag of the plotted image on TensorBoard. """ data_index = data[index] + # as the `d` data has no batch dim, reduce the spatial dim index if positive + frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim + d: np.ndarray = data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else data_index if d.ndim == 2: @@ -206,7 +187,15 @@ def plot_2d_or_3d_image( if d.ndim >= 4: spatial = d.shape[-3:] - for j, d3 in enumerate(d.reshape([-1] + list(spatial))[:max_channels]): - d3 = rescale_array(d3, 0, 255) - add_animated_gif(writer, f"{tag}_HWD_{j}", d3[None], max_frames, 1.0, step) + d = d.reshape([-1] + list(spatial)) + if d.shape[0] == 3 and max_channels == 3 and has_tensorboardx and isinstance(writer, SummaryWriterX): # RGB + # move the expected frame dim to the end as `T` dim for video + d = np.moveaxis(d, frame_dim, -1) + writer.add_video(tag, d[None], step, fps=max_frames, dataformats="NCHWT") + return + # scale data to 0 - 255 for visualization + max_channels = min(max_channels, d.shape[0]) + d = np.stack([rescale_array(i, 0, 255) for i in d[:max_channels]], axis=0) + # will plot every channel as a separate GIF image + add_animated_gif(writer, f"{tag}_HWD", d, max_out=max_channels, frame_dim=frame_dim, global_step=step) return diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 61b84bb406..d87b93396a 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -149,6 +149,7 @@ def __init__( mask_size: Union[int, Sequence] = 15, n_batch: int = 128, stride: Union[int, Sequence] = 1, + per_channel: bool = True, upsampler: Optional[Callable] = default_upsampler, verbose: bool = True, ) -> None: @@ -163,11 +164,14 @@ def __init__( n_batch: Number of images in a batch for inference. stride: Stride in spatial directions for performing occlusions. Can be single value or sequence (for varying stride in the different directions). - Should be >= 1. Striding in the channel direction will always be 1. + Should be >= 1. Striding in the channel direction depends on the `per_channel` argument. + per_channel: If `True`, `mask_size` and `stride` both equal 1 in the channel dimension. If `False`, + then both `mask_size` equals the number of channels in the image. If `True`, the output image will be: + `[B, C, H, W, D, num_seg_classes]`. Else, will be `[B, 1, H, W, D, num_seg_classes]` upsampler: An upsampling method to upsample the output image. Default is N-dimensional linear (bilinear, trilinear, etc.) depending on num spatial dimensions of input. - verbose: Use ``tdqm.trange`` output (if available). + verbose: Use ``tqdm.trange`` output (if available). """ self.nn_module = nn_module @@ -176,6 +180,7 @@ def __init__( self.mask_size = mask_size self.n_batch = n_batch self.stride = stride + self.per_channel = per_channel self.verbose = verbose def _compute_occlusion_sensitivity(self, x, b_box): @@ -201,32 +206,39 @@ def _compute_occlusion_sensitivity(self, x, b_box): output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1 # Get the stride and mask_size as numpy arrays - self.stride = _get_as_np_array(self.stride, len(im_shape)) - self.mask_size = _get_as_np_array(self.mask_size, len(im_shape)) + stride = _get_as_np_array(self.stride, len(im_shape)) + mask_size = _get_as_np_array(self.mask_size, len(im_shape)) + + # If not doing it on a per-channel basis, then the output image will have 1 output channel + # (since all will be occluded together) + if not self.per_channel: + output_im_shape[0] = 1 + stride[0] = x.shape[1] + mask_size[0] = x.shape[1] # For each dimension, ... - for o, s in zip(output_im_shape, self.stride): + for o, s in zip(output_im_shape, stride): # if the size is > 1, then check that the stride is a factor of the output image shape if o > 1 and o % s != 0: raise ValueError( "Stride should be a factor of the image shape. Im shape " - + f"(taking bounding box into account): {output_im_shape}, stride: {self.stride}" + + f"(taking bounding box into account): {output_im_shape}, stride: {stride}" ) # to ensure the occluded area is nicely centred if stride is even, ensure that so is the mask_size - if np.any(self.mask_size % 2 != self.stride % 2): + if np.any(mask_size % 2 != stride % 2): raise ValueError( "Stride and mask size should both be odd or even (element-wise). " - + f"``stride={self.stride}``, ``mask_size={self.mask_size}``" + + f"``stride={stride}``, ``mask_size={mask_size}``" ) - downsampled_im_shape = (output_im_shape / self.stride).astype(np.int32) + downsampled_im_shape = (output_im_shape / stride).astype(np.int32) downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1 num_required_predictions = np.prod(downsampled_im_shape) # Get bottom left and top right corners of occluded region - lower_corner = (self.stride - self.mask_size) // 2 - upper_corner = (self.stride + self.mask_size) // 2 + lower_corner = (stride - mask_size) // 2 + upper_corner = (stride + mask_size) // 2 # Loop 1D over image verbose_range = trange if self.verbose else range @@ -234,7 +246,7 @@ def _compute_occlusion_sensitivity(self, x, b_box): # Get corresponding ND index idx = np.unravel_index(i, downsampled_im_shape) # Multiply by stride - idx *= self.stride + idx *= stride # If a bounding box is being used, we need to add on # the min to shift to start of region of interest if b_box_min is not None: @@ -264,11 +276,7 @@ def _compute_occlusion_sensitivity(self, x, b_box): return sensitivity_ims, output_im_shape - def __call__( # type: ignore - self, - x: torch.Tensor, - b_box: Optional[Sequence] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __call__(self, x: torch.Tensor, b_box: Optional[Sequence] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: Image to use for inference. Should be a tensor consisting of 1 batch. @@ -285,6 +293,7 @@ def __call__( # type: ignore Hence, more -ve values imply that region was important in the decision process. * The map will have shape ``BCHW(D)N``, where N is the number of classes to be inferred by the network. Hence, the occlusion for class ``i`` can be seen with ``map[...,i]``. + * If `per_channel==False`, output ``C`` will equal 1: ``B1HW(D)N`` * Most probable class: * The most probable class when the corresponding part of the image is occluded (``argmax(dim=-1)``). Both images will be cropped if a bounding box used, but voxel sizes will always match the input. diff --git a/monai/visualize/utils.py b/monai/visualize/utils.py new file mode 100644 index 0000000000..c1111abd82 --- /dev/null +++ b/monai/visualize/utils.py @@ -0,0 +1,198 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import numpy as np + +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.croppad.array import SpatialPad +from monai.transforms.utils import rescale_array +from monai.transforms.utils_pytorch_numpy_unification import repeat, where +from monai.utils.module import optional_import +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type + +plt, _ = optional_import("matplotlib", name="pyplot") +cm, _ = optional_import("matplotlib", name="cm") + +__all__ = ["matshow3d", "blend_images"] + + +def matshow3d( + volume, + fig=None, + title: Optional[str] = None, + figsize=(10, 10), + frames_per_row: Optional[int] = None, + frame_dim: int = -3, + channel_dim: Optional[int] = None, + vmin=None, + vmax=None, + every_n: int = 1, + interpolation: str = "none", + show=False, + fill_value=np.nan, + margin: int = 1, + dtype=np.float32, + **kwargs, +): + """ + Create a 3D volume figure as a grid of images. + + Args: + volume: 3D volume to display. data shape can be `BCHWD`, `CHWD` or `HWD`. + Higher dimensional arrays will be reshaped into (-1, H, W, [C]), `C` depends on `channel_dim` arg. + A list of channel-first (C, H[, W, D]) arrays can also be passed in, + in which case they will be displayed as a padded and stacked volume. + fig: matplotlib figure to use. If None, a new figure will be created. + title: title of the figure. + figsize: size of the figure. + frames_per_row: number of frames to display in each row. If None, sqrt(firstdim) will be used. + frame_dim: for higher dimensional arrays, which dimension from (`-1`, `-2`, `-3`) is moved to + the `-3` dimension. dim and reshape to (-1, H, W) shape to construct frames, default to `-3`. + channel_dim: if not None, explicitly specify the channel dimension to be transposed to the + last dimensionas shape (-1, H, W, C). this can be used to plot RGB color image. + if None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as shape (-1, H, W). + note that it can only support 3D input image. default is None. + vmin: `vmin` for the matplotlib `imshow`. + vmax: `vmax` for the matplotlib `imshow`. + every_n: factor to subsample the frames so that only every n-th frame is displayed. + interpolation: interpolation to use for the matplotlib `matshow`. + show: if True, show the figure. + fill_value: value to use for the empty part of the grid. + margin: margin to use for the grid. + dtype: data type of the output stacked frames. + kwargs: additional keyword arguments to matplotlib `matshow` and `imshow`. + + See Also: + - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html + - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.matshow.html + + Example: + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from monai.visualize import matshow3d + # create a figure of a 3D volume + >>> volume = np.random.rand(10, 10, 10) + >>> fig = plt.figure() + >>> matshow3d(volume, fig=fig, title="3D Volume") + >>> plt.show() + # create a figure of a list of channel-first 3D volumes + >>> volumes = [np.random.rand(1, 10, 10, 10), np.random.rand(1, 10, 10, 10)] + >>> fig = plt.figure() + >>> matshow3d(volumes, fig=fig, title="List of Volumes") + >>> plt.show() + + """ + vol = convert_data_type(data=volume, output_type=np.ndarray)[0] + if channel_dim is not None: + if channel_dim not in [0, 1] or vol.shape[channel_dim] not in [1, 3, 4]: + raise ValueError("channel_dim must be: None, 0 or 1, and channels of image must be 1, 3 or 4.") + + if isinstance(vol, (list, tuple)): + # a sequence of channel-first volumes + if not isinstance(vol[0], np.ndarray): + raise ValueError("volume must be a list of arrays.") + pad_size = np.max(np.asarray([v.shape for v in vol]), axis=0) + pad = SpatialPad(pad_size[1:]) # assuming channel-first for item in vol + vol = np.concatenate([pad(v) for v in vol], axis=0) + else: # ndarray + while len(vol.shape) < 3: + vol = np.expand_dims(vol, 0) # type: ignore # so that we display 2d as well + + if channel_dim is not None: # move the expected dim to construct frames with `B` dim + vol = np.moveaxis(vol, frame_dim, -4) # type: ignore + vol = vol.reshape((-1, vol.shape[-3], vol.shape[-2], vol.shape[-1])) + else: + vol = np.moveaxis(vol, frame_dim, -3) # type: ignore + vol = vol.reshape((-1, vol.shape[-2], vol.shape[-1])) + vmin = np.nanmin(vol) if vmin is None else vmin + vmax = np.nanmax(vol) if vmax is None else vmax + + # subsample every_n-th frame of the 3D volume + vol = vol[:: max(every_n, 1)] + if not frames_per_row: + frames_per_row = int(np.ceil(np.sqrt(len(vol)))) + # create the grid of frames + cols = max(min(len(vol), frames_per_row), 1) + rows = int(np.ceil(len(vol) / cols)) + width = [[0, cols * rows - len(vol)]] + if channel_dim is not None: + width += [[0, 0]] # add pad width for the channel dim + width += [[margin, margin]] * 2 + vol = np.pad(vol.astype(dtype, copy=False), width, mode="constant", constant_values=fill_value) # type: ignore + im = np.block([[vol[i * cols + j] for j in range(cols)] for i in range(rows)]) + if channel_dim is not None: + # move channel dim to the end + im = np.moveaxis(im, 0, -1) + + # figure related configurations + if fig is None: + fig = plt.figure(tight_layout=True) + if not fig.axes: + fig.add_subplot(111) + ax = fig.axes[0] + ax.matshow(im, vmin=vmin, vmax=vmax, interpolation=interpolation, **kwargs) + ax.axis("off") + + if title is not None: + ax.set_title(title) + if figsize is not None: + fig.set_size_inches(figsize) + if show: + plt.show() + return fig, im + + +def blend_images( + image: NdarrayOrTensor, label: NdarrayOrTensor, alpha: float = 0.5, cmap: str = "hsv", rescale_arrays: bool = True +): + """ + Blend an image and a label. Both should have the shape CHW[D]. + The image may have C==1 or 3 channels (greyscale or RGB). + The label is expected to have C==1. + + Args: + image: the input image to blend with label data. + label: the input label to blend with image data. + alpha: when blending image and label, `alpha` is the weight for the image region mapping to `label != 0`, + and `1 - alpha` is the weight for the label region that `label != 0`, default to `0.5`. + cmap: specify colormap in the matplotlib, default to `hsv`, for more details, please refer to: + https://matplotlib.org/2.0.2/users/colormaps.html. + rescale_arrays: whether to rescale the array to [0, 1] first, default to `True`. + + """ + + if label.shape[0] != 1: + raise ValueError("Label should have 1 channel") + if image.shape[0] not in (1, 3): + raise ValueError("Image should have 1 or 3 channels") + # rescale arrays to [0, 1] if desired + if rescale_arrays: + image = rescale_array(image) + label = rescale_array(label) + # convert image to rgb (if necessary) and then rgba + if image.shape[0] == 1: + image = repeat(image, 3, axis=0) + + def get_label_rgb(cmap: str, label: NdarrayOrTensor): + _cmap = cm.get_cmap(cmap) + label_np, *_ = convert_data_type(label, np.ndarray) + label_rgb_np = _cmap(label_np[0]) + label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3] + label_rgb, *_ = convert_to_dst_type(label_rgb_np, label) + return label_rgb + + label_rgb = get_label_rgb(cmap, label) + w_image = where(label == 0, 1.0, alpha) + w_label = where(label == 0, 0.0, 1 - alpha) + return w_image * image + w_label * label_rgb diff --git a/monai/visualize/visualizer.py b/monai/visualize/visualizer.py index bbb01f5c5e..5f19e4f63f 100644 --- a/monai/visualize/visualizer.py +++ b/monai/visualize/visualizer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/pyproject.toml b/pyproject.toml index 008af45f97..03e9f49ab5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,13 +2,13 @@ requires = [ "wheel", "setuptools", - "torch>=1.5", + "torch>=1.6", "ninja", ] [tool.black] line-length = 120 -target-version = ['py36', 'py37', 'py38'] +target-version = ['py37', 'py38', 'py39'] include = '\.pyi?$' exclude = ''' ( diff --git a/requirements-dev.txt b/requirements-dev.txt index 785454ad5d..4d2829f930 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ # Full requirements for developments -r requirements-min.txt -pytorch-ignite==0.4.5 +pytorch-ignite==0.4.8 gdown>=3.6.4 scipy itk>=5.2 @@ -15,6 +15,7 @@ flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi +pylint mccabe pep8-naming pycodestyle @@ -31,8 +32,18 @@ Sphinx==3.5.3 recommonmark==0.6.0 sphinx-autodoc-typehints==1.11.1 sphinx-rtd-theme==0.5.2 -cucim~=0.19.0; platform_system == "Linux" +cucim>=21.8.2; platform_system == "Linux" openslide-python==1.1.2 +imagecodecs; platform_system == "Linux" +tifffile; platform_system == "Linux" pandas requests einops +transformers +mlflow +matplotlib!=3.5.0 +tensorboardX +types-PyYAML +pyyaml +fire +jsonschema diff --git a/requirements-min.txt b/requirements-min.txt index 5db219c840..63906b4a94 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -1,5 +1,5 @@ # Requirements for minimal tests -r requirements.txt -setuptools>=50.3.0 +setuptools>=50.3.0,!=60.0.0,!=60.6.0 coverage>=5.5 parameterized diff --git a/requirements.txt b/requirements.txt index 5d96284307..e4ea34b5d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch>=1.5 +torch>=1.6 numpy>=1.17 diff --git a/runtests.sh b/runtests.sh index f10e888543..5464f3d020 100755 --- a/runtests.sh +++ b/runtests.sh @@ -1,6 +1,6 @@ #! /bin/bash -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,12 +38,15 @@ doNetTests=false doDryRun=false doZooTests=false doUnitTests=false +doBuild=false doBlackFormat=false doBlackFix=false doIsortFormat=false doIsortFix=false doFlake8Format=false +doPylintFormat=false doClangFormat=false +doCopyRight=false doPytypeFormat=false doMypyFormat=false doCleanup=false @@ -54,8 +57,9 @@ NUM_PARALLEL=1 PY_EXE=${MONAI_PY_EXE:-$(which python)} function print_usage { - echo "runtests.sh [--codeformat] [--autofix] [--black] [--isort] [--flake8] [--clangformat] [--pytype] [--mypy]" - echo " [--unittests] [--disttests] [--coverage] [--quick] [--min] [--net] [--dryrun] [-j number] [--clean] [--help] [--version]" + echo "runtests.sh [--codeformat] [--autofix] [--black] [--isort] [--flake8] [--pylint] [--clangformat] [--pytype] [--mypy]" + echo " [--unittests] [--disttests] [--coverage] [--quick] [--min] [--net] [--dryrun] [-j number] [--list_tests]" + echo " [--copyright] [--build] [--clean] [--help] [--version]" echo "" echo "MONAI unit testing utilities." echo "" @@ -72,6 +76,7 @@ function print_usage { echo " --autofix : format code using \"isort\" and \"black\"" echo " --isort : perform \"isort\" import sort checks" echo " --flake8 : perform \"flake8\" code format checks" + echo " --pylint : perform \"pylint\" code format checks" echo " --clangformat : format csrc code using \"clang-format\"" echo "" echo "Python type check options:" @@ -86,10 +91,12 @@ function print_usage { echo " -q, --quick : skip long running unit tests and integration tests" echo " -m, --min : only run minimal unit tests which do not require optional packages" echo " --net : perform integration testing" + echo " -b, --build : compile and install the source code folder an editable release." echo " --list_tests : list unit tests and exit" echo "" echo "Misc. options:" echo " --dryrun : display the commands to the screen without running" + echo " --copyright : check whether every source code has a copyright header" echo " -f, --codeformat : shorthand to run all code style and static analysis tests" echo " -c, --clean : clean temporary files from tests and exit" echo " -h, --help : show this help message and exit" @@ -103,8 +110,8 @@ function print_usage { } function check_import { - echo "python: ${PY_EXE}" - ${cmdPrefix}${PY_EXE} -c "import monai" + echo "Python: ${PY_EXE}" + ${cmdPrefix}${PY_EXE} -W error -W ignore::DeprecationWarning -c "import monai" } function print_version { @@ -236,8 +243,10 @@ do doBlackFormat=true doIsortFormat=true doFlake8Format=true + doPylintFormat=true doPytypeFormat=true doMypyFormat=true + doCopyRight=true ;; --disttests) doDistTests=true @@ -250,6 +259,7 @@ do doBlackFix=true doIsortFormat=true doBlackFormat=true + doCopyRight=true ;; --clangformat) doClangFormat=true @@ -260,6 +270,9 @@ do --flake8) doFlake8Format=true ;; + --pylint) + doPylintFormat=true + ;; --pytype) doPytypeFormat=true ;; @@ -270,6 +283,12 @@ do NUM_PARALLEL=$2 shift ;; + --copyright) + doCopyRight=true + ;; + -b|--build) + doBuild=true + ;; -c|--clean) doCleanup=true ;; @@ -314,6 +333,14 @@ else check_import fi +if [ $doBuild = true ] +then + echo "${separator}${blue}compile and install${noColor}" + # try to compile MONAI cpp + compile_cpp + + echo "${green}done! (to uninstall and clean up, please use \"./runtests.sh --clean\")${noColor}" +fi if [ $doCleanup = true ] then @@ -335,12 +362,33 @@ then exit fi -# try to compile MONAI cpp -compile_cpp - # unconditionally report on the state of monai print_version +if [ $doCopyRight = true ] +then + # check copyright headers + copyright_bad=0 + copyright_all=0 + while read -r fname; do + copyright_all=$((copyright_all + 1)) + if ! grep "http://www.apache.org/licenses/LICENSE-2.0" "$fname" > /dev/null; then + print_error_msg "Missing the license header in file: $fname" + copyright_bad=$((copyright_bad + 1)) + fi + done <<< "$(find "$(pwd)/monai" "$(pwd)/tests" -type f \ + ! -wholename "*_version.py" -and -name "*.py" -or -name "*.cpp" -or -name "*.cu" -or -name "*.h")" + if [[ ${copyright_bad} -eq 0 ]]; + then + echo "${green}Source code copyright headers checked ($copyright_all).${noColor}" + else + echo "Please add the licensing header to the file ($copyright_bad of $copyright_all files)." + echo " See also: https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md#checking-the-coding-style" + echo "" + exit 1 + fi +fi + if [ $doIsortFormat = true ] then @@ -397,9 +445,9 @@ then if [ $doBlackFix = true ] then - ${cmdPrefix}${PY_EXE} -m black "$(pwd)" + ${cmdPrefix}${PY_EXE} -m black --skip-magic-trailing-comma "$(pwd)" else - ${cmdPrefix}${PY_EXE} -m black --check "$(pwd)" + ${cmdPrefix}${PY_EXE} -m black --skip-magic-trailing-comma --check "$(pwd)" fi black_status=$? @@ -421,7 +469,7 @@ then # ensure that the necessary packages for code format testing are installed if ! is_pip_installed flake8 - then + then install_deps fi ${cmdPrefix}${PY_EXE} -m flake8 --version @@ -439,19 +487,47 @@ then set -e # enable exit on failure fi +if [ $doPylintFormat = true ] +then + set +e # disable exit on failure so that diagnostics can be given on failure + echo "${separator}${blue}pylint${noColor}" + + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed flake8 + then + install_deps + fi + ${cmdPrefix}${PY_EXE} -m pylint --version + + ignore_codes="E1101,E1102,E0601,E1130,E1123,E0102,E1120,E1137,E1136" + ${cmdPrefix}${PY_EXE} -m pylint monai tests -E --disable=$ignore_codes -j $NUM_PARALLEL + pylint_status=$? + + if [ ${pylint_status} -ne 0 ] + then + print_style_fail_msg + exit ${pylint_status} + else + echo "${green}passed!${noColor}" + fi + set -e # enable exit on failure +fi + if [ $doPytypeFormat = true ] then set +e # disable exit on failure so that diagnostics can be given on failure echo "${separator}${blue}pytype${noColor}" - if [[ "$OSTYPE" == "darwin"* ]]; then - echo "${red}pytype not working on macOS (https://github.com/Project-MONAI/MONAI/issues/2391), skipping the tests.${noColor}" + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed pytype + then + install_deps + fi + pytype_ver=$(${cmdPrefix}${PY_EXE} -m pytype --version) + if [[ "$OSTYPE" == "darwin"* && "$pytype_ver" == "2021."* ]]; then + echo "${red}pytype not working on macOS 2021 (https://github.com/Project-MONAI/MONAI/issues/2391). Please upgrade to 2022*.${noColor}" + exit 1 else - # ensure that the necessary packages for code format testing are installed - if ! is_pip_installed pytype - then - install_deps - fi ${cmdPrefix}${PY_EXE} -m pytype --version ${cmdPrefix}${PY_EXE} -m pytype -j ${NUM_PARALLEL} --python-version="$(${PY_EXE} -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")" @@ -535,7 +611,7 @@ if [ $doUnitTests = true ] then echo "${separator}${blue}unittests${noColor}" torch_validate - ${cmdPrefix}${cmd} ./tests/runner.py -p "test_((?!integration).)" + ${cmdPrefix}${cmd} ./tests/runner.py -p "^(?!test_integration).*(?= 3.6 +python_requires = >= 3.7 # for compiling and develop setup only # no need to specify the versions so that we could # compile for multiple targeted versions. @@ -24,7 +24,7 @@ setup_requires = torch ninja install_requires = - torch>=1.5 + torch>=1.6 numpy>=1.17 [options.extras_require] @@ -34,16 +34,25 @@ all = pillow tensorboard gdown>=3.6.4 - pytorch-ignite==0.4.5 + pytorch-ignite==0.4.8 torchvision itk>=5.2 tqdm>=4.47.0 lmdb psutil - cucim~=0.19.0 + cucim>=21.8.2 openslide-python==1.1.2 + tifffile + imagecodecs pandas einops + transformers + mlflow + matplotlib + tensorboardX + pyyaml + fire + jsonschema nibabel = nibabel skimage = @@ -55,7 +64,7 @@ tensorboard = gdown = gdown>=3.6.4 ignite = - pytorch-ignite==0.4.5 + pytorch-ignite==0.4.8 torchvision = torchvision itk = @@ -67,23 +76,46 @@ lmdb = psutil = psutil cucim = - cucim~=0.19.0 + cucim>=21.8.2 openslide = openslide-python==1.1.2 +tifffile = + tifffile +imagecodecs = + imagecodecs pandas = pandas einops = einops +transformers = + transformers +mlflow = + mlflow +matplotlib = + matplotlib +tensorboardX = + tensorboardX +pyyaml = + pyyaml +fire = + fire +jsonschema = + jsonschema + [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 # C408 ignored because we like the dict keyword argument syntax # E501 is not flexible enough, we're using B950 instead ignore = - E203,E305,E402,E501,E721,E741,F821,F841,F999,W503,W504,C408,E302,W291,E303, - # N812 lowercase 'torch.nn.functional' imported as non lowercase 'F' - N812 -per_file_ignores = __init__.py: F401 + E203 + E501 + E741 + W503 + W504 + C408 + N812 # lowercase 'torch.nn.functional' imported as non lowercase 'F' +per_file_ignores = __init__.py: F401, __main__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py [isort] diff --git a/setup.py b/setup.py index eeaffb7823..219f0eb957 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,7 +38,7 @@ BUILD_CUDA = (CUDA_HOME is not None) if torch.cuda.is_available() else FORCE_CUDA - _pt_version = pkg_resources.parse_version(torch.__version__).release # type: ignore[attr-defined] + _pt_version = pkg_resources.parse_version(torch.__version__).release if _pt_version is None or len(_pt_version) < 3: raise AssertionError("unknown torch version") TORCH_VERSION = int(_pt_version[0]) * 10000 + int(_pt_version[1]) * 100 + int(_pt_version[2]) @@ -53,11 +53,7 @@ def torch_parallel_backend(): try: - match = re.search( - "^ATen parallel backend: (?P.*)$", - torch._C._parallel_info(), - re.MULTILINE, - ) + match = re.search("^ATen parallel backend: (?P.*)$", torch._C._parallel_info(), re.MULTILINE) if match is None: return None backend = match.group("backend") diff --git a/tests/__init__.py b/tests/__init__.py index 5093d1f72d..4639a58496 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/clang_format_utils.py b/tests/clang_format_utils.py index 41902eb272..1b13ce0ac3 100644 --- a/tests/clang_format_utils.py +++ b/tests/clang_format_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,8 +34,8 @@ # This dictionary maps each platform to a relative path to a file containing its reference hash. # github/pytorch/pytorch/tree/63d62d3e44a0a4ec09d94f30381d49b78cc5b095/tools/clang_format_hash PLATFORM_TO_HASH = { - "Darwin": "b24cc8972344c4e01afbbae78d6a414f7638ff6f", - "Linux": "9073602de1c4e1748f2feea5a0782417b20e3043", + "Darwin": "1485a242a96c737ba7cdd9f259114f2201accdb46d87ac7a8650b1a814cd4d4d", + "Linux": "e1c8b97b919541a99e0a355df5c3f9e8abebc64259dbee6f8c68e1ef90582856", } # Directory and file paths for the clang-format binary. @@ -50,15 +50,15 @@ def get_and_check_clang_format(): """ # If the host platform is not in PLATFORM_TO_HASH, it is unsupported. if HOST_PLATFORM not in PLATFORM_TO_HASH: - print("Unsupported platform: {}".format(HOST_PLATFORM)) + print(f"Unsupported platform: {HOST_PLATFORM}") return False if HOST_PLATFORM not in PLATFORM_TO_CF_URL: - print("Unsupported platform: {}".format(HOST_PLATFORM)) + print(f"Unsupported platform: {HOST_PLATFORM}") return False try: download_url( - PLATFORM_TO_CF_URL[HOST_PLATFORM], CLANG_FORMAT_PATH, PLATFORM_TO_HASH[HOST_PLATFORM], hash_type="sha1" + PLATFORM_TO_CF_URL[HOST_PLATFORM], CLANG_FORMAT_PATH, PLATFORM_TO_HASH[HOST_PLATFORM], hash_type="sha256" ) except Exception as e: print(f"Download {CLANG_FORMAT_PATH} failed: {e}") @@ -69,7 +69,7 @@ def get_and_check_clang_format(): mode = os.stat(CLANG_FORMAT_PATH).st_mode mode |= stat.S_IXUSR os.chmod(CLANG_FORMAT_PATH, mode) - print("Using clang-format located at {}".format(CLANG_FORMAT_PATH)) + print(f"Using clang-format located at {CLANG_FORMAT_PATH}") return True diff --git a/tests/hvd_evenly_divisible_all_gather.py b/tests/hvd_evenly_divisible_all_gather.py index 42b2e9530d..cf8254b614 100644 --- a/tests/hvd_evenly_divisible_all_gather.py +++ b/tests/hvd_evenly_divisible_all_gather.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,10 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import horovod.torch as hvd import torch from monai.utils import evenly_divisible_all_gather +from monai.utils.module import optional_import + +hvd, has_hvd = optional_import("horovod", name="torch") class HvdEvenlyDivisibleAllGather: diff --git a/tests/min_tests.py b/tests/min_tests.py index 5b376d7b57..c0d4f36430 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,8 +33,11 @@ def run_testsuit(): "test_cachedataset_parallel", "test_cachedataset_persistent_workers", "test_cachentransdataset", + "test_contrastive_loss", + "test_check_missing_files", "test_csv_dataset", "test_csv_iterable_dataset", + "test_cumulative_average_dist", "test_dataset", "test_dataset_summary", "test_deepedit_transforms", @@ -42,11 +45,14 @@ def run_testsuit(): "test_deepgrow_interaction", "test_deepgrow_transforms", "test_detect_envelope", + "test_dints_network", "test_efficientnet", "test_ensemble_evaluator", "test_ensure_channel_first", "test_ensure_channel_firstd", "test_fill_holes", + "test_fill_holesd", + "test_global_mutual_information_loss", "test_handler_checkpoint_loader", "test_handler_checkpoint_saver", "test_handler_classification_saver", @@ -61,6 +67,7 @@ def run_testsuit(): "test_handler_mean_dice", "test_handler_metrics_saver", "test_handler_metrics_saver_dist", + "test_handler_mlflow", "test_handler_nvtx", "test_handler_parameter_scheduler", "test_handler_post_processing", @@ -75,18 +82,20 @@ def run_testsuit(): "test_handler_surface_distance", "test_handler_tb_image", "test_handler_tb_stats", - "test_handler_transform_inverter", "test_handler_validation", "test_hausdorff_distance", "test_header_correct", "test_hilbert_transform", "test_image_dataset", + "test_image_rw", "test_img2tensorboard", + "test_integration_fast_train", "test_integration_segmentation_3d", "test_integration_sliding_window", "test_integration_unet_2d", "test_integration_workflows", "test_integration_workflows_gan", + "test_integration_bundle_run", "test_invertd", "test_iterable_dataset", "test_keep_largest_connected_component", @@ -94,10 +103,12 @@ def run_testsuit(): "test_label_filter", "test_lltm", "test_lmdbdataset", + "test_lmdbdataset_dist", "test_load_image", "test_load_imaged", "test_load_spacing_orientation", "test_mednistdataset", + "test_milmodel", "test_mlp", "test_nifti_header_revise", "test_nifti_rw", @@ -112,6 +123,8 @@ def run_testsuit(): "test_plot_2d_or_3d_image", "test_png_rw", "test_png_saver", + "test_prepare_batch_default", + "test_prepare_batch_extra_input", "test_rand_rotate", "test_rand_rotated", "test_rand_zoom", @@ -119,6 +132,8 @@ def run_testsuit(): "test_randtorchvisiond", "test_resize", "test_resized", + "test_resample_to_match", + "test_resample_to_matchd", "test_rotate", "test_rotated", "test_save_image", @@ -132,14 +147,21 @@ def run_testsuit(): "test_testtimeaugmentation", "test_torchvision", "test_torchvisiond", + "test_transchex", "test_transformerblock", "test_unetr", "test_unetr_block", "test_vit", + "test_vitautoenc", "test_write_metrics_reports", + "test_wsireader", "test_zoom", "test_zoom_affine", "test_zoomd", + "test_prepare_batch_default_dist", + "test_parallel_execution_dist", + "test_bundle_verify_metadata", + "test_bundle_verify_net", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/ngc_mmar_loading.py b/tests/ngc_mmar_loading.py new file mode 100644 index 0000000000..df48ebc564 --- /dev/null +++ b/tests/ngc_mmar_loading.py @@ -0,0 +1,43 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.mmars import MODEL_DESC, load_from_mmar +from monai.config import print_debug_info + + +class TestAllDownloadingMMAR(unittest.TestCase): + def setUp(self): + print_debug_info() + self.test_dir = "./" + + @parameterized.expand((item,) for item in MODEL_DESC) + def test_loading_mmar(self, item): + if item["name"] == "clara_pt_fed_learning_brain_tumor_mri_segmentation": + default_model_file = os.path.join("models", "server", "best_FL_global_model.pt") + else: + default_model_file = None + pretrained_model = load_from_mmar( + item=item["name"], mmar_dir="./", map_location="cpu", api=True, model_file=default_model_file + ) + self.assertTrue(isinstance(pretrained_model, torch.nn.Module)) + + def tearDown(self): + print(os.listdir(self.test_dir)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/runner.py b/tests/runner.py index b340d60719..7356581365 100644 --- a/tests/runner.py +++ b/tests/runner.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_acn_block.py b/tests/test_acn_block.py new file mode 100644 index 0000000000..4c12155fd8 --- /dev/null +++ b/tests/test_acn_block.py @@ -0,0 +1,38 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import ActiConvNormBlock + +TEST_CASES = [ + [{"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1}, (7, 32, 16, 31, 7), (7, 16, 16, 31, 7)], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1, "spatial_dims": 2}, + (7, 32, 13, 32), + (7, 16, 13, 32), + ], +] + + +class TestACNBlock(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_acn_block(self, input_param, input_shape, expected_shape): + net = ActiConvNormBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_activations.py b/tests/test_activations.py index 7d8b3e4c38..a67e6f8cb6 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,27 +16,36 @@ from monai.networks.layers.factories import Act from monai.transforms import Activations - -TEST_CASE_1 = [ - {"sigmoid": True, "softmax": False, "other": None}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.5000, 0.7311], [0.8808, 0.9526]]]), - (1, 2, 2), -] - -TEST_CASE_2 = [ - {"sigmoid": False, "softmax": True, "other": None}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), - (2, 1, 2), -] - -TEST_CASE_3 = [ - {"sigmoid": False, "softmax": False, "other": torch.tanh}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), - (1, 2, 2), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"sigmoid": True, "softmax": False, "other": None}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + p([[[0.5000, 0.7311], [0.8808, 0.9526]]]), + (1, 2, 2), + ] + ) + + TEST_CASES.append( + [ + {"sigmoid": False, "softmax": True, "other": None}, + p([[[0.0, 1.0]], [[2.0, 3.0]]]), + p([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), + (2, 1, 2), + ] + ) + + TEST_CASES.append( + [ + {"sigmoid": False, "softmax": False, "other": torch.tanh}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + p([[[0.0000, 0.7616], [0.9640, 0.9951]]]), + (1, 2, 2), + ] + ) TEST_CASE_4 = [ "swish", @@ -67,12 +76,12 @@ class TestActivations(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TEST_CASES[:3]) def test_value_shape(self, input_param, img, out, expected_shape): result = Activations(**input_param)(img) def _compare(ret, out, shape): - torch.testing.assert_allclose(ret, out) + assert_allclose(ret, out, rtol=1e-3) self.assertTupleEqual(ret.shape, shape) if isinstance(result, (list, tuple)): diff --git a/tests/test_activationsd.py b/tests/test_activationsd.py index 355c50f389..557d68de90 100644 --- a/tests/test_activationsd.py +++ b/tests/test_activationsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,43 +15,46 @@ from parameterized import parameterized from monai.transforms import Activationsd - -TEST_CASE_1 = [ - {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None}, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])}, - { - "pred": torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), - "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - }, - (2, 1, 2), -] - -TEST_CASE_2 = [ - {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [torch.tanh, None]}, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])}, - { - "pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), - "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - }, - (1, 2, 2), -] - -TEST_CASE_3 = [ - {"keys": "pred", "sigmoid": False, "softmax": False, "other": torch.tanh}, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])}, - {"pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]])}, - (1, 2, 2), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None}, + {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": p([[[0.0, 1.0]], [[2.0, 3.0]]])}, + {"pred": p([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), "label": p([[[0.0, 1.0]], [[2.0, 3.0]]])}, + (2, 1, 2), + ] + ) + + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [torch.tanh, None]}, + {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0.0, 1.0], [2.0, 3.0]]])}, + {"pred": p([[[0.0000, 0.7616], [0.9640, 0.9951]]]), "label": p([[[0.0, 1.0], [2.0, 3.0]]])}, + (1, 2, 2), + ] + ) + + TEST_CASES.append( + [ + {"keys": "pred", "sigmoid": False, "softmax": False, "other": torch.tanh}, + {"pred": p([[[0.0, 1.0], [2.0, 3.0]]])}, + {"pred": p([[[0.0000, 0.7616], [0.9640, 0.9951]]])}, + (1, 2, 2), + ] + ) class TestActivationsd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TEST_CASES) def test_value_shape(self, input_param, test_input, output, expected_shape): result = Activationsd(**input_param)(test_input) - torch.testing.assert_allclose(result["pred"], output["pred"]) + assert_allclose(result["pred"], output["pred"], rtol=1e-3) self.assertTupleEqual(result["pred"].shape, expected_shape) if "label" in result: - torch.testing.assert_allclose(result["label"], output["label"]) + assert_allclose(result["label"], output["label"], rtol=1e-3) self.assertTupleEqual(result["label"].shape, expected_shape) diff --git a/tests/test_adaptors.py b/tests/test_adaptors.py index 9bcd01feb7..f59bdaa15e 100644 --- a/tests/test_adaptors.py +++ b/tests/test_adaptors.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,7 +20,7 @@ def test_function_signature(self): def foo(image, label=None, *a, **kw): pass - f = FunctionSignature(foo) + _ = FunctionSignature(foo) def test_single_in_single_out(self): def foo(image): diff --git a/tests/test_add_channeld.py b/tests/test_add_channeld.py index 8bdd89a4ae..9dc984aff3 100644 --- a/tests/test_add_channeld.py +++ b/tests/test_add_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_add_coordinate_channels.py b/tests/test_add_coordinate_channels.py index 3399008e02..5a483e25b9 100644 --- a/tests/test_add_coordinate_channels.py +++ b/tests/test_add_coordinate_channels.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,32 +12,36 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddCoordinateChannels +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [{"spatial_channels": (1, 2, 3)}, np.random.randint(0, 2, size=(1, 3, 3, 3)), (4, 3, 3, 3)] - -TEST_CASE_2 = [{"spatial_channels": (1,)}, np.random.randint(0, 2, size=(1, 3, 3, 3)), (2, 3, 3, 3)] - -TEST_CASE_ERROR_3 = [{"spatial_channels": (3,)}, np.random.randint(0, 2, size=(1, 3, 3))] - -TEST_CASE_ERROR_4 = [{"spatial_channels": (0, 1, 2)}, np.random.randint(0, 2, size=(1, 3, 3))] +TESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], [] +for p in TEST_NDARRAYS: + TESTS.append([{"spatial_dims": (0, 1, 2)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (4, 3, 3, 3)]) + TESTS.append([{"spatial_dims": (0,)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (2, 3, 3, 3)]) + TEST_CASES_ERROR_1.append([{"spatial_dims": (2,)}, p(np.random.randint(0, 2, size=(1, 3, 3)))]) + TEST_CASES_ERROR_2.append([{"spatial_dims": (-1, 0, 1)}, p(np.random.randint(0, 2, size=(1, 3, 3)))]) class TestAddCoordinateChannels(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input, expected_shape): result = AddCoordinateChannels(**input_param)(input) + self.assertEqual(type(result), type(input)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input.device) self.assertEqual(list(result.shape), list(expected_shape)) - np.testing.assert_array_equal(input[0, ...], result[0, ...]) + assert_allclose(input[0, ...], result[0, ...]) - @parameterized.expand([TEST_CASE_ERROR_3]) + @parameterized.expand(TEST_CASES_ERROR_1) def test_max_channel(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannels(**input_param)(input) - @parameterized.expand([TEST_CASE_ERROR_4]) + @parameterized.expand(TEST_CASES_ERROR_2) def test_channel_dim(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannels(**input_param)(input) diff --git a/tests/test_add_coordinate_channelsd.py b/tests/test_add_coordinate_channelsd.py index 0fa6aae1c9..c14ff0ba64 100644 --- a/tests/test_add_coordinate_channelsd.py +++ b/tests/test_add_coordinate_channelsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,40 +12,50 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddCoordinateChannelsd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"spatial_channels": (1, 2, 3), "keys": ["img"]}, - {"img": np.random.randint(0, 2, size=(1, 3, 3, 3))}, - (4, 3, 3, 3), -] +TESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"spatial_dims": (0, 1, 2), "keys": ["img"]}, + {"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))}, + (4, 3, 3, 3), + ] + ) + TESTS.append( + [{"spatial_dims": (0,), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))}, (2, 3, 3, 3)] + ) -TEST_CASE_2 = [ - {"spatial_channels": (1,), "keys": ["img"]}, - {"img": np.random.randint(0, 2, size=(1, 3, 3, 3))}, - (2, 3, 3, 3), -] - -TEST_CASE_ERROR_3 = [{"spatial_channels": (3,), "keys": ["img"]}, {"img": np.random.randint(0, 2, size=(1, 3, 3))}] - -TEST_CASE_ERROR_4 = [{"spatial_channels": (0, 1, 2), "keys": ["img"]}, {"img": np.random.randint(0, 2, size=(1, 3, 3))}] + TEST_CASES_ERROR_1.append( + [{"spatial_dims": (2,), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}] + ) + TEST_CASES_ERROR_2.append( + [{"spatial_dims": (-1, 0, 1), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}] + ) class TestAddCoordinateChannels(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input, expected_shape): - result = AddCoordinateChannelsd(**input_param)(input) - self.assertEqual(list(result["img"].shape), list(expected_shape)) - np.testing.assert_array_equal(input["img"][0, ...], result["img"][0, ...]) + result = AddCoordinateChannelsd(**input_param)(input)["img"] + input = input["img"] + self.assertEqual(type(result), type(input)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input.device) + self.assertEqual(result.shape, expected_shape) + assert_allclose(input[0, ...], result[0, ...]) - @parameterized.expand([TEST_CASE_ERROR_3]) + @parameterized.expand(TEST_CASES_ERROR_1) def test_max_channel(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannelsd(**input_param)(input) - @parameterized.expand([TEST_CASE_ERROR_4]) + @parameterized.expand(TEST_CASES_ERROR_2) def test_channel_dim(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannelsd(**input_param)(input) diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py index ecf2c83d3c..d2c8a627b6 100644 --- a/tests/test_add_extreme_points_channel.py +++ b/tests/test_add_extreme_points_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,52 +15,63 @@ from parameterized import parameterized from monai.transforms import AddExtremePointsChannel +from tests.utils import TEST_NDARRAYS, assert_allclose IMG_CHANNEL = 3 +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])), + "sigma": 1.0, + "rescale_min": 0.0, + "rescale_max": 1.0, + }, + p( + np.array( + [ + [0.38318458, 0.98615628, 0.85551184], + [0.35422316, 0.94430935, 1.0], + [0.46000731, 0.57319659, 0.46000722], + [0.64577687, 0.38318464, 0.0], + ] + ) + ), + ] + ) -TEST_CASE_1 = [ - { - "img": np.zeros((IMG_CHANNEL, 4, 3)), - "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]]), - "sigma": 1.0, - "rescale_min": 0.0, - "rescale_max": 1.0, - }, - np.array( - [ - [0.38318458, 0.98615628, 0.85551184], - [0.35422316, 0.94430935, 1.0], - [0.46000731, 0.57319659, 0.46000722], - [0.64577687, 0.38318464, 0.0], - ] - ), -] - -TEST_CASE_2 = [ - { - "img": np.zeros((IMG_CHANNEL, 4, 3)), - "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]]), - "sigma": 1.0, - "rescale_min": 0.0, - "rescale_max": 1.0, - }, - np.array( - [ - [0.44628328, 0.80495411, 0.44628328], - [0.6779086, 1.0, 0.67790854], - [0.33002687, 0.62079221, 0.33002687], - [0.0, 0.31848389, 0.0], - ] - ), -] + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])), + "sigma": 1.0, + "rescale_min": 0.0, + "rescale_max": 1.0, + }, + p( + np.array( + [ + [0.44628328, 0.80495411, 0.44628328], + [0.6779086, 1.0, 0.67790854], + [0.33002687, 0.62079221, 0.33002687], + [0.0, 0.31848389, 0.0], + ] + ) + ), + ] + ) class TestAddExtremePointsChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChannel() result = add_extreme_points_channel(**input_data) - np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4) + assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py index e33bb0838c..39d221596f 100644 --- a/tests/test_add_extreme_points_channeld.py +++ b/tests/test_add_extreme_points_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,42 +15,60 @@ from parameterized import parameterized from monai.transforms import AddExtremePointsChanneld +from tests.utils import TEST_NDARRAYS, assert_allclose IMG_CHANNEL = 3 -TEST_CASE_1 = [ - {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])}, - np.array( - [ - [0.38318458, 0.98615628, 0.85551184], - [0.35422316, 0.94430935, 1.0], - [0.46000731, 0.57319659, 0.46000722], - [0.64577687, 0.38318464, 0.0], - ] - ), -] - -TEST_CASE_2 = [ - {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])}, - np.array( - [ - [0.44628328, 0.80495411, 0.44628328], - [0.6779086, 1.0, 0.67790854], - [0.33002687, 0.62079221, 0.33002687], - [0.0, 0.31848389, 0.0], - ] - ), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])), + }, + p( + np.array( + [ + [0.38318458, 0.98615628, 0.85551184], + [0.35422316, 0.94430935, 1.0], + [0.46000731, 0.57319659, 0.46000722], + [0.64577687, 0.38318464, 0.0], + ] + ) + ), + ] + ) + + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])), + }, + p( + np.array( + [ + [0.44628328, 0.80495411, 0.44628328], + [0.6779086, 1.0, 0.67790854], + [0.33002687, 0.62079221, 0.33002687], + [0.0, 0.31848389, 0.0], + ] + ) + ), + ] + ) class TestAddExtremePointsChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChanneld( keys="img", label_key="label", sigma=1.0, rescale_min=0.0, rescale_max=1.0 ) result = add_extreme_points_channel(input_data) - np.testing.assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4) + assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_adjust_contrast.py b/tests/test_adjust_contrast.py index 8e78698360..2f6c4e2259 100644 --- a/tests/test_adjust_contrast.py +++ b/tests/test_adjust_contrast.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import AdjustContrast -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_1 = [1.0] @@ -28,15 +28,16 @@ class TestAdjustContrast(NumpyImageTestCase2D): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_correct_results(self, gamma): adjuster = AdjustContrast(gamma=gamma) - result = adjuster(self.imt) - if gamma == 1.0: - expected = self.imt - else: - epsilon = 1e-7 - img_min = self.imt.min() - img_range = self.imt.max() - img_min - expected = np.power(((self.imt - img_min) / float(img_range + epsilon)), gamma) * img_range + img_min - np.testing.assert_allclose(expected, result, rtol=1e-05) + for p in TEST_NDARRAYS: + result = adjuster(p(self.imt)) + if gamma == 1.0: + expected = self.imt + else: + epsilon = 1e-7 + img_min = self.imt.min() + img_range = self.imt.max() - img_min + expected = np.power(((self.imt - img_min) / float(img_range + epsilon)), gamma) * img_range + img_min + assert_allclose(expected, result, rtol=1e-05, type_test=False) if __name__ == "__main__": diff --git a/tests/test_adjust_contrastd.py b/tests/test_adjust_contrastd.py index 65647607e4..a7224b643b 100644 --- a/tests/test_adjust_contrastd.py +++ b/tests/test_adjust_contrastd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import AdjustContrastd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_1 = [1.0] @@ -28,15 +28,16 @@ class TestAdjustContrastd(NumpyImageTestCase2D): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_correct_results(self, gamma): adjuster = AdjustContrastd("img", gamma=gamma) - result = adjuster({"img": self.imt}) - if gamma == 1.0: - expected = self.imt - else: - epsilon = 1e-7 - img_min = self.imt.min() - img_range = self.imt.max() - img_min - expected = np.power(((self.imt - img_min) / float(img_range + epsilon)), gamma) * img_range + img_min - np.testing.assert_allclose(expected, result["img"], rtol=1e-05) + for p in TEST_NDARRAYS: + result = adjuster({"img": p(self.imt)}) + if gamma == 1.0: + expected = self.imt + else: + epsilon = 1e-7 + img_min = self.imt.min() + img_range = self.imt.max() - img_min + expected = np.power(((self.imt - img_min) / float(img_range + epsilon)), gamma) * img_range + img_min + assert_allclose(expected, result["img"], rtol=1e-05, type_test=False) if __name__ == "__main__": diff --git a/tests/test_adn.py b/tests/test_adn.py index 2130ebc005..2352f5c1e2 100644 --- a/tests/test_adn.py +++ b/tests/test_adn.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_affine.py b/tests/test_affine.py index dd82d72e23..d681d2941b 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,78 +16,150 @@ from parameterized import parameterized from monai.transforms import Affine +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None, image_only=True), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (-1, 0, 0)}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device, image_only=True), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict( + affine=p(torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])), + padding_mode="zeros", + device=device, + ), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (-1, 0, 0)}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affine(**input_param) result = g(**input_data) if isinstance(result, tuple): result = result[0] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 24772b9a21..6f6364feda 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,88 +16,130 @@ from parameterized import parameterized from monai.transforms import AffineGrid +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [ - {"as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (2, 2)}, - np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [ - {"as_tensor_output": True, "device": None}, - {"spatial_size": (2, 2)}, - torch.tensor([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [{"as_tensor_output": False, "device": None}, {"grid": np.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [{"as_tensor_output": True, "device": torch.device("cpu:0")}, {"grid": np.ones((3, 3, 3))}, torch.ones((3, 3, 3))], - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"as_tensor_output": True, "device": torch.device("cpu:0")}, - {"grid": torch.ones((3, 3, 3))}, - torch.ones((3, 3, 3)), - ], - [ - { - "rotate_params": (1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((3, 3, 3))}, - torch.tensor( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208]], - [[-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + {"device": device}, + {"spatial_size": (2, 2)}, + np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), ] - ), - ], - [ - { - "rotate_params": (1.0, 1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((4, 3, 3, 3))}, - torch.tensor( + ) + + TESTS.append([{"device": device}, {"grid": p(np.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( + [ + {"rotate_params": (1.0, 1.0), "scale_params": (-20, 10), "device": device}, + {"grid": p(torch.ones((3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + ], + [ + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + ], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ) + ), + ] + ) + TESTS.append( [ - [ - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - ], - [ - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - ], - [ - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - ], - [ - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - ], + { + "affine": p( + torch.tensor( + [[-10.8060, -8.4147, 0.0000], [-16.8294, 5.4030, 0.0000], [0.0000, 0.0000, 1.0000]] + ) + ) + }, + {"grid": p(torch.ones((3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + ], + [ + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + ], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ) + ), ] - ), - ], -] + ) + TESTS.append( + [ + {"rotate_params": (1.0, 1.0, 1.0), "scale_params": (-20, 10), "device": device}, + {"grid": p(torch.ones((4, 3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + ], + [ + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + ], + [ + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + ], + [ + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + ], + ] + ) + ), + ] + ) + + +_rtol = 5e-2 if is_tf32_env() else 1e-4 class TestAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine_grid(self, input_param, input_data, expected_val): g = AffineGrid(**input_param) result, _ = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data[device]) + assert_allclose(result, expected_val, type_test=False, rtol=_rtol) if __name__ == "__main__": diff --git a/tests/test_affine_transform.py b/tests/test_affine_transform.py index 42af58be73..5170ab4260 100644 --- a/tests/test_affine_transform.py +++ b/tests/test_affine_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,6 +17,9 @@ from monai.networks import normalize_transform, to_norm_affine from monai.networks.layers import AffineTransform +from tests.utils import is_tf32_env + +_rtol = 1e-4 if not is_tf32_env() else 5e-3 TEST_NORM_CASES = [ [(4, 5), True, [[[0.666667, 0, -1], [0, 0.5, -1], [0, 0, 1]]]], @@ -95,7 +98,7 @@ def test_to_norm_affine(self, affine, src_size, dst_size, align_corners, expecte affine = torch.as_tensor(affine, device=torch.device("cuda:0"), dtype=torch.float32) new_affine = to_norm_affine(affine, src_size, dst_size, align_corners) new_affine = new_affine.detach().cpu().numpy() - np.testing.assert_allclose(new_affine, expected, atol=1e-4) + np.testing.assert_allclose(new_affine, expected, atol=1e-5, rtol=_rtol) @parameterized.expand(TEST_ILL_TO_NORM_AFFINE_CASES) def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners): @@ -113,7 +116,7 @@ def test_affine_shift(self): out = AffineTransform()(image, affine) out = out.detach().cpu().numpy() expected = [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_affine_shift_1(self): affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]) @@ -121,7 +124,7 @@ def test_affine_shift_1(self): out = AffineTransform()(image, affine) out = out.detach().cpu().numpy() expected = [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_affine_shift_2(self): affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]) @@ -129,28 +132,28 @@ def test_affine_shift_2(self): out = AffineTransform()(image, affine) out = out.detach().cpu().numpy() expected = [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_zoom(self): affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]) image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0")) out = AffineTransform((3, 2))(image, affine) expected = [[[[1, 3], [5, 7], [9, 11]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_zoom_1(self): affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0")) out = AffineTransform()(image, affine, (1, 4)) expected = [[[[1, 2, 3, 4]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=_rtol) def test_zoom_2(self): affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32) image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0")) out = AffineTransform((1, 2))(image, affine) expected = [[[[1, 3]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_affine_transform_minimum(self): t = np.pi / 3 @@ -169,7 +172,7 @@ def test_affine_transform_minimum(self): ] ] ] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-3, rtol=_rtol) def test_affine_transform_2d(self): t = np.pi / 3 @@ -188,7 +191,7 @@ def test_affine_transform_2d(self): ] ] ] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-3, rtol=_rtol) if torch.cuda.is_available(): affine = torch.as_tensor(affine, device=torch.device("cuda:0"), dtype=torch.float32) @@ -205,7 +208,7 @@ def test_affine_transform_2d(self): ] ] ] - np.testing.assert_allclose(out, expected, atol=1e-4) + np.testing.assert_allclose(out, expected, atol=5e-3) def test_affine_transform_3d(self): t = np.pi / 3 @@ -231,7 +234,7 @@ def test_affine_transform_3d(self): ] ], ] - np.testing.assert_allclose(out, expected, atol=1e-4) + np.testing.assert_allclose(out, expected, atol=1e-4, rtol=_rtol) if torch.cuda.is_available(): affine = torch.as_tensor(affine, device=torch.device("cuda:0"), dtype=torch.float32) @@ -255,7 +258,7 @@ def test_affine_transform_3d(self): ] ], ] - np.testing.assert_allclose(out, expected, atol=1e-4) + np.testing.assert_allclose(out, expected, atol=5e-3) def test_ill_affine_transform(self): with self.assertRaises(ValueError): # image too small diff --git a/tests/test_affined.py b/tests/test_affined.py index 850f12905d..665c93d23f 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,85 +16,152 @@ from parameterized import parameterized from monai.transforms import Affined +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, spatial_size=(-1, 0), device=None), - {"img": np.arange(9).reshape((1, 3, 3))}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3))}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0), device=device), + {"img": p(np.arange(9).reshape((1, 3, 3)))}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0), device=device, dtype=None), + {"img": p(np.arange(9, dtype=float).reshape((1, 3, 3)))}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(keys="img", rotate_params=[np.pi / 2], padding_mode="zeros", spatial_size=(4, 4), device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict( + keys="img", + affine=p(torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])), + padding_mode="zeros", + spatial_size=(4, 4), + device=device, + ), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3)))}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict( + keys="img", rotate_params=[np.pi / 2], padding_mode="zeros", spatial_size=(4, 4, 4), device=device + ), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affined(**input_param) result = g(input_data)["img"] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 777e2637a7..dba8eaf72b 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -162,26 +162,14 @@ def test_mcfcn_shape(self, input_param, input_shape, expected_shape): class TestAHNET(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_AHNET_2D_1, - TEST_CASE_AHNET_2D_2, - TEST_CASE_AHNET_2D_3, - ] - ) + @parameterized.expand([TEST_CASE_AHNET_2D_1, TEST_CASE_AHNET_2D_2, TEST_CASE_AHNET_2D_3]) def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape): net = AHNet(**input_param).to(device) with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - @parameterized.expand( - [ - TEST_CASE_AHNET_3D_1, - TEST_CASE_AHNET_3D_2, - TEST_CASE_AHNET_3D_3, - ] - ) + @parameterized.expand([TEST_CASE_AHNET_3D_1, TEST_CASE_AHNET_3D_2, TEST_CASE_AHNET_3D_3]) @skip_if_quick def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape): net = AHNet(**input_param).to(device) @@ -203,11 +191,7 @@ def test_script(self): class TestAHNETWithPretrain(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_AHNET_3D_WITH_PRETRAIN_1, - TEST_CASE_AHNET_3D_WITH_PRETRAIN_2, - TEST_CASE_AHNET_3D_WITH_PRETRAIN_3, - ] + [TEST_CASE_AHNET_3D_WITH_PRETRAIN_1, TEST_CASE_AHNET_3D_WITH_PRETRAIN_2, TEST_CASE_AHNET_3D_WITH_PRETRAIN_3] ) def test_ahnet_shape(self, input_param, input_shape, expected_shape, fcn_input_param): net = AHNet(**input_param).to(device) diff --git a/tests/test_alias.py b/tests/test_alias.py new file mode 100644 index 0000000000..49f9fa56fe --- /dev/null +++ b/tests/test_alias.py @@ -0,0 +1,39 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import inspect +import os +import unittest + +from monai.utils import optional_import + + +class TestModuleAlias(unittest.TestCase): + """check that 'import monai.xx.file_name' returns a module""" + + def test_files(self): + src_dir = os.path.dirname(os.path.dirname(__file__)) + monai_dir = os.path.join(src_dir, "monai") + py_files = glob.glob(os.path.join(monai_dir, "**", "*.py"), recursive=True) + for x in py_files: + if os.path.basename(x).startswith("_"): + continue + mod_name = x[len(src_dir) : -3] # create relative path + mod_name = mod_name[1:].replace(mod_name[0], ".") + mod, cls = mod_name.rsplit(".", 1) + obj, exist = optional_import(mod, name=cls) + if exist: + self.assertTrue(inspect.ismodule(obj), msg=mod_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_apply_filter.py b/tests/test_apply_filter.py new file mode 100644 index 0000000000..3174211f34 --- /dev/null +++ b/tests/test_apply_filter.py @@ -0,0 +1,87 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from monai.networks.layers import apply_filter + + +class ApplyFilterTestCase(unittest.TestCase): + def test_1d(self): + a = torch.tensor([[list(range(10))]], dtype=torch.float) + out = apply_filter(a, torch.tensor([-1, 0, 1]), stride=1) + expected = np.array([[[1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, -8.0]]]) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + if torch.cuda.is_available(): + out = apply_filter(a.cuda(), torch.tensor([-1, 0, 1]).cuda()) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + + def test_2d(self): + a = torch.tensor([[[list(range(7)), list(range(7, 0, -1)), list(range(7))]]], dtype=torch.float) + expected = np.array( + [ + [14.0, 21.0, 21.0, 21.0, 21.0, 21.0, 14.0], + [15.0, 24.0, 27.0, 30.0, 33.0, 36.0, 25.0], + [14.0, 21.0, 21.0, 21.0, 21.0, 21.0, 14.0], + ] + ) + expected = expected[None][None] + out = apply_filter(a, torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]])) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + if torch.cuda.is_available(): + out = apply_filter(a.cuda(), torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).cuda()) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + + def test_3d(self): + a = torch.tensor( + [[list(range(7)), list(range(7)), list(range(7))], [list(range(7)), list(range(7)), list(range(7))]], + dtype=torch.float, + ) + a = a[None][None] + a = a.expand(2, 3, -1, -1, -1) + expected = np.array( + [ + [ + [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0], + [3.0, 9.0, 18.0, 27.0, 36.0, 45.0, 33.0], + [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0], + ], + [ + [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0], + [3.0, 9.0, 18.0, 27.0, 36.0, 45.0, 33.0], + [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0], + ], + ] + ) + expected = expected + # testing shapes + k = torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) + for kernel in (k, k[None], k[None][None]): + out = apply_filter(a, kernel) + np.testing.assert_allclose(out.cpu().numpy()[1][2], expected, rtol=1e-4) + if torch.cuda.is_available(): + out = apply_filter(a.cuda(), kernel.cuda()) + np.testing.assert_allclose(out.cpu().numpy()[0][1], expected, rtol=1e-4) + + def test_wrong_args(self): + with self.assertRaisesRegex(ValueError, ""): + apply_filter(torch.ones((1, 2, 3, 2)), torch.ones((2,))) + with self.assertRaisesRegex(NotImplementedError, ""): + apply_filter(torch.ones((1, 1, 1, 2, 3, 2)), torch.ones((2,))) + with self.assertRaisesRegex(TypeError, ""): + apply_filter(((1, 1, 1, 2, 3, 2)), torch.ones((2,))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index f6459cc88c..ee1a92cf97 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index 0d1b1c7d3a..a2d56295b8 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,7 +34,7 @@ def test_value(self, in_type, input_param, expected_shape): if isinstance(test_data, torch.Tensor): test_data = test_data.cpu().numpy() expected = np.moveaxis(test_data, input_param["channel_dim"], 0) - assert_allclose(expected, result) + assert_allclose(result, expected, type_test=False) if __name__ == "__main__": diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index 68d33434c1..91086f9299 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index 55a7a08676..e6446ab7a6 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index 350f639f3f..a6d94d216a 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index bb9457a357..a68e6431ec 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,52 +11,62 @@ import unittest -import torch from parameterized import parameterized from monai.transforms import AsDiscrete +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"argmax": True, "to_onehot": False, "num_classes": None, "threshold_values": False, "logit_thresh": 0.5}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[1.0, 1.0]]]), - (1, 1, 2), -] +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"argmax": True, "to_onehot": None, "threshold": 0.5}, + p([[[0.0, 1.0]], [[2.0, 3.0]]]), + p([[[1.0, 1.0]]]), + (1, 1, 2), + ] + ) -TEST_CASE_2 = [ - {"argmax": True, "to_onehot": True, "num_classes": 2, "threshold_values": False, "logit_thresh": 0.5}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), - (2, 1, 2), -] + TEST_CASES.append( + [ + {"argmax": True, "to_onehot": 2, "threshold": 0.5}, + p([[[0.0, 1.0]], [[2.0, 3.0]]]), + p([[[0.0, 0.0]], [[1.0, 1.0]]]), + (2, 1, 2), + ] + ) -TEST_CASE_3 = [ - {"argmax": False, "to_onehot": False, "num_classes": None, "threshold_values": True, "logit_thresh": 0.6}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), - (1, 2, 2), -] + TEST_CASES.append( + [ + {"argmax": False, "to_onehot": None, "threshold": 0.6}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + p([[[0.0, 1.0], [1.0, 1.0]]]), + (1, 2, 2), + ] + ) -TEST_CASE_4 = [ - {"argmax": False, "to_onehot": True, "num_classes": 3}, - torch.tensor(1), - torch.tensor([0.0, 1.0, 0.0]), - (3,), -] + # test threshold = 0.0 + TEST_CASES.append( + [ + {"argmax": False, "to_onehot": None, "threshold": 0.0}, + p([[[0.0, -1.0], [-2.0, 3.0]]]), + p([[[1.0, 0.0], [0.0, 1.0]]]), + (1, 2, 2), + ] + ) -TEST_CASE_5 = [ - {"rounding": "torchrounding"}, - torch.tensor([[[0.123, 1.345], [2.567, 3.789]]]), - torch.tensor([[[0.0, 1.0], [3.0, 4.0]]]), - (1, 2, 2), -] + TEST_CASES.append([{"argmax": False, "to_onehot": 3}, p(1), p([0.0, 1.0, 0.0]), (3,)]) + + TEST_CASES.append( + [{"rounding": "torchrounding"}, p([[[0.123, 1.345], [2.567, 3.789]]]), p([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2)] + ) class TestAsDiscrete(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TEST_CASES) def test_value_shape(self, input_param, img, out, expected_shape): result = AsDiscrete(**input_param)(img) - torch.testing.assert_allclose(result, out) + assert_allclose(result, out, rtol=1e-3) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index 90e98b297b..21825c2d6c 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,69 +11,84 @@ import unittest -import torch from parameterized import parameterized from monai.transforms import AsDiscreted +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - { - "keys": ["pred", "label"], - "argmax": [True, False], - "to_onehot": True, - "num_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": torch.tensor([[[0, 1]]])}, - {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])}, - (2, 1, 2), -] +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2, "threshold": 0.5}, + {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": p([[[0, 1]]])}, + {"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": p([[[1.0, 0.0]], [[0.0, 1.0]]])}, + (2, 1, 2), + ] + ) -TEST_CASE_2 = [ - { - "keys": ["pred", "label"], - "argmax": False, - "to_onehot": False, - "num_classes": None, - "threshold_values": [True, False], - "logit_thresh": 0.6, - }, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), "label": torch.tensor([[[0, 1], [1, 1]]])}, - {"pred": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), "label": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]])}, - (1, 2, 2), -] + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "argmax": False, "to_onehot": None, "threshold": [0.6, None]}, + {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, + {"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) -TEST_CASE_3 = [ - { - "keys": ["pred"], - "argmax": True, - "to_onehot": True, - "num_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])}, - {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]])}, - (2, 1, 2), -] + TEST_CASES.append( + [ + {"keys": ["pred"], "argmax": True, "to_onehot": 2, "threshold": 0.5}, + {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]])}, + {"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]])}, + (2, 1, 2), + ] + ) -TEST_CASE_4 = [ - {"keys": "pred", "rounding": "torchrounding"}, - {"pred": torch.tensor([[[0.123, 1.345], [2.567, 3.789]]])}, - {"pred": torch.tensor([[[0.0, 1.0], [3.0, 4.0]]])}, - (1, 2, 2), -] + TEST_CASES.append( + [ + {"keys": "pred", "rounding": "torchrounding"}, + {"pred": p([[[0.123, 1.345], [2.567, 3.789]]])}, + {"pred": p([[[0.0, 1.0], [3.0, 4.0]]])}, + (1, 2, 2), + ] + ) + + # test compatible with previous versions + TEST_CASES.append( + [ + { + "keys": ["pred", "label"], + "argmax": False, + "to_onehot": None, + "threshold": [True, None], + "logit_thresh": 0.6, + }, + {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, + {"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) + + # test threshold = 0.0 + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "argmax": False, "to_onehot": None, "threshold": [0.0, None]}, + {"pred": p([[[0.0, -1.0], [-2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, + {"pred": p([[[1.0, 0.0], [0.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) class TestAsDiscreted(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TEST_CASES) def test_value_shape(self, input_param, test_input, output, expected_shape): result = AsDiscreted(**input_param)(test_input) - torch.testing.assert_allclose(result["pred"], output["pred"]) + assert_allclose(result["pred"], output["pred"], rtol=1e-3) self.assertTupleEqual(result["pred"].shape, expected_shape) if "label" in result: - torch.testing.assert_allclose(result["label"], output["label"]) + assert_allclose(result["label"], output["label"], rtol=1e-3) self.assertTupleEqual(result["label"].shape, expected_shape) diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py new file mode 100644 index 0000000000..b2f53f9c16 --- /dev/null +++ b/tests/test_attentionunet.py @@ -0,0 +1,65 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +import monai.networks.nets.attentionunet as att +from tests.utils import skip_if_no_cuda + + +class TestAttentionUnet(unittest.TestCase): + def test_attention_block(self): + for dims in [2, 3]: + block = att.AttentionBlock(dims, f_int=2, f_g=6, f_l=6) + shape = (4, 6) + (30,) * dims + x = torch.rand(*shape, dtype=torch.float32) + output = block(x, x) + self.assertEqual(output.shape, x.shape) + + block = att.AttentionBlock(dims, f_int=2, f_g=3, f_l=6) + xshape = (4, 6) + (30,) * dims + x = torch.rand(*xshape, dtype=torch.float32) + gshape = (4, 3) + (30,) * dims + g = torch.rand(*gshape, dtype=torch.float32) + output = block(g, x) + self.assertEqual(output.shape, x.shape) + + def test_attentionunet(self): + for dims in [2, 3]: + shape = (3, 1) + (92,) * dims + input = torch.rand(*shape) + model = att.AttentionUnet( + spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2) + ) + output = model(input) + self.assertEqual(output.shape[2:], input.shape[2:]) + self.assertEqual(output.shape[0], input.shape[0]) + self.assertEqual(output.shape[1], 2) + + @skip_if_no_cuda + def test_attentionunet_gpu(self): + for dims in [2, 3]: + shape = (3, 1) + (92,) * dims + input = torch.rand(*shape).to("cuda:0") + model = att.AttentionUnet( + spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2) + ).to("cuda:0") + with torch.no_grad(): + output = model(input) + self.assertEqual(output.shape[2:], input.shape[2:]) + self.assertEqual(output.shape[0], input.shape[0]) + self.assertEqual(output.shape[1], 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py index 54d6832c8d..bed5a198ff 100644 --- a/tests/test_autoencoder.py +++ b/tests/test_autoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ TEST_CASE_0 = [ # single channel 2D, batch 4, no residual { - "dimensions": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 1, "channels": (4, 8, 16), @@ -35,20 +35,14 @@ ] TEST_CASE_1 = [ # single channel 2D, batch 4 - { - "dimensions": 2, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 8, 16), - "strides": (2, 2, 2), - }, + {"spatial_dims": 2, "in_channels": 1, "out_channels": 1, "channels": (4, 8, 16), "strides": (2, 2, 2)}, (1, 1, 128, 128), (1, 1, 128, 128), ] TEST_CASE_2 = [ # 3-channel 2D, batch 4, LeakyReLU activation { - "dimensions": 2, + "spatial_dims": 2, "in_channels": 3, "out_channels": 3, "channels": (4, 8, 16), @@ -60,13 +54,7 @@ ] TEST_CASE_3 = [ # 4-channel 3D, batch 4 - { - "dimensions": 3, - "in_channels": 4, - "out_channels": 3, - "channels": (4, 8, 16), - "strides": (2, 2, 2), - }, + {"spatial_dims": 3, "in_channels": 4, "out_channels": 3, "channels": (4, 8, 16), "strides": (2, 2, 2)}, (1, 4, 128, 128, 128), (1, 3, 128, 128, 128), ] @@ -75,7 +63,7 @@ TEST_CASE_FAIL = { # 2-channel 2D, should fail because of stride/channel mismatch. - "dimensions": 2, + "spatial_dims": 2, "in_channels": 2, "out_channels": 2, "channels": (4, 8, 16), @@ -92,7 +80,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_script(self): - net = AutoEncoder(dimensions=2, in_channels=1, out_channels=1, channels=(4, 8), strides=(2, 2)) + net = AutoEncoder(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8), strides=(2, 2)) test_data = torch.randn(2, 1, 32, 32) test_script_save(net, test_data) diff --git a/tests/test_basic_unet.py b/tests/test_basic_unet.py index 09d7f72d0e..a4f88367dd 100644 --- a/tests/test_basic_unet.py +++ b/tests/test_basic_unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,20 +20,10 @@ CASES_1D = [] for mode in ["pixelshuffle", "nontrainable", "deconv", None]: - kwargs = { - "dimensions": 1, - "in_channels": 5, - "out_channels": 8, - } + kwargs = {"spatial_dims": 1, "in_channels": 5, "out_channels": 8} if mode is not None: kwargs["upsample"] = mode # type: ignore - CASES_1D.append( - [ - kwargs, - (10, 5, 33), - (10, 8, 33), - ] - ) + CASES_1D.append([kwargs, (10, 5, 33), (10, 8, 33)]) CASES_2D = [] for mode in ["pixelshuffle", "nontrainable", "deconv"]: @@ -43,7 +33,7 @@ CASES_2D.append( [ { - "dimensions": 2, + "spatial_dims": 2, "in_channels": in_channels, "out_channels": out_channels, "features": (12, 12, 13, 14, 15, 16), @@ -56,7 +46,7 @@ CASES_3D = [ [ # single channel 3D, batch 2 { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 1, "out_channels": 2, "features": (16, 20, 21, 22, 23, 11), @@ -67,7 +57,7 @@ ], [ # 2-channel 3D, batch 3 { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 2, "out_channels": 7, "features": (14, 15, 16, 17, 18, 11), @@ -78,7 +68,7 @@ ], [ # 4-channel 3D, batch 5 { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 4, "out_channels": 2, "features": (14, 15, 16, 17, 18, 10), @@ -101,7 +91,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_script(self): - net = BasicUNet(dimensions=2, in_channels=1, out_channels=3) + net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=3) test_data = torch.randn(16, 1, 32, 32) test_script_save(net, test_data) diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py index 8f1fb43535..318b1905df 100644 --- a/tests/test_bending_energy.py +++ b/tests/test_bending_energy.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,31 +20,30 @@ device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASES = [ + [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], + [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0], [ - {}, - {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, - 0.0, - ], - [ - {}, - {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, - 0.0, - ], - [ - {}, + {"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, 4.0, ], [ - {}, - {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, 4.0, ], + [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 4.0], [ - {}, - {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 3, 5) ** 2}, - 4.0, + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 100.0, + ], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 100.0, ], + [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 100.0], ] @@ -57,18 +56,24 @@ def test_shape(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = BendingEnergyLoss() # not in 3-d, 4-d, 5-d - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): loss.forward(torch.ones((1, 3), device=device)) - with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3, 5, 5, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) # spatial_dim < 5 - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 4, 5, 5), device=device)) - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 5, 4, 5))) - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 5, 5, 4))) + # number of vector components unequal to number of spatial dims + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + def test_ill_opts(self): pred = torch.rand(1, 3, 5, 5, 5).to(device=device) with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_bilateral_approx_cpu.py b/tests/test_bilateral_approx_cpu.py index 7960f76591..b3f4f5c3be 100644 --- a/tests/test_bilateral_approx_cpu.py +++ b/tests/test_bilateral_approx_cpu.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -306,7 +306,7 @@ # Frame 4 [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], ] - ], + ] ], # Expected [ diff --git a/tests/test_bilateral_approx_cuda.py b/tests/test_bilateral_approx_cuda.py index 345a920f3c..db34b0ff71 100644 --- a/tests/test_bilateral_approx_cuda.py +++ b/tests/test_bilateral_approx_cuda.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -306,7 +306,7 @@ # Frame 4 [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], ] - ], + ] ], # Expected [ diff --git a/tests/test_bilateral_precise.py b/tests/test_bilateral_precise.py index dfa3ca107d..b19369d758 100644 --- a/tests/test_bilateral_precise.py +++ b/tests/test_bilateral_precise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -306,7 +306,7 @@ # Frame 4 [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], ] - ], + ] ], # Expected [ diff --git a/tests/test_blend_images.py b/tests/test_blend_images.py new file mode 100644 index 0000000000..6fea53ac30 --- /dev/null +++ b/tests/test_blend_images.py @@ -0,0 +1,53 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.case import skipUnless + +import torch +from parameterized import parameterized + +from monai.data.synthetic import create_test_image_2d, create_test_image_3d +from monai.transforms.utils_pytorch_numpy_unification import moveaxis +from monai.utils.module import optional_import +from monai.visualize.utils import blend_images +from tests.utils import TEST_NDARRAYS + +plt, has_matplotlib = optional_import("matplotlib.pyplot") + +TESTS = [] +for p in TEST_NDARRAYS: + image, label = create_test_image_2d(100, 101) + TESTS.append((p(image), p(label))) + + image, label = create_test_image_3d(100, 101, 102) + TESTS.append((p(image), p(label))) + + +@skipUnless(has_matplotlib, "Matplotlib required") +class TestBlendImages(unittest.TestCase): + @parameterized.expand(TESTS) + def test_blend(self, image, label): + blended = blend_images(image[None], label[None]) + self.assertEqual(type(image), type(blended)) + if isinstance(blended, torch.Tensor): + self.assertEqual(blended.device, image.device) + blended = blended.cpu().numpy() + self.assertEqual((3,) + image.shape, blended.shape) + + blended = moveaxis(blended, 0, -1) # move RGB component to end + if blended.ndim > 3: + blended = blended[blended.shape[0] // 2] + plt.imshow(blended) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py index 9e6a8a6a08..b632ff831f 100644 --- a/tests/test_border_pad.py +++ b/tests/test_border_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,17 +18,9 @@ from monai.utils import NumpyPadMode from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"spatial_border": 2, "mode": "constant"}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 12, 12, 8)), -] +TEST_CASE_1 = [{"spatial_border": 2, "mode": "constant"}, np.zeros((3, 8, 8, 4)), np.zeros((3, 12, 12, 8))] -TEST_CASE_2 = [ - {"spatial_border": [1, 2, 3], "mode": "constant"}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 10, 12, 10)), -] +TEST_CASE_2 = [{"spatial_border": [1, 2, 3], "mode": "constant"}, np.zeros((3, 8, 8, 4)), np.zeros((3, 10, 12, 10))] TEST_CASE_3 = [ {"spatial_border": [1, 2, 3, 4, 5, 6], "mode": "constant"}, diff --git a/tests/test_border_padd.py b/tests/test_border_padd.py index b48629fc98..e4b8dd20ea 100644 --- a/tests/test_border_padd.py +++ b/tests/test_border_padd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py index 38585cba18..a7c2648f1e 100644 --- a/tests/test_bounding_rect.py +++ b/tests/test_bounding_rect.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,6 +16,7 @@ import monai from monai.transforms import BoundingRect +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [(2, 3), [[0, 0], [1, 2]]] @@ -35,14 +36,16 @@ def tearDown(self): def test_shape(self, input_shape, expected): test_data = np.random.randint(0, 8, size=input_shape) test_data = test_data == 7 - result = BoundingRect()(test_data) - np.testing.assert_allclose(result, expected) + for p in TEST_NDARRAYS: + result = BoundingRect()(p(test_data)) + np.testing.assert_allclose(result, expected) def test_select_fn(self): test_data = np.random.randint(0, 8, size=(2, 3)) test_data = test_data == 7 - bbox = BoundingRect(select_fn=lambda x: x < 1)(test_data) - np.testing.assert_allclose(bbox, [[0, 3], [0, 3]]) + for p in TEST_NDARRAYS: + bbox = BoundingRect(select_fn=lambda x: x < 1)(p(test_data)) + np.testing.assert_allclose(bbox, [[0, 3], [0, 3]]) if __name__ == "__main__": diff --git a/tests/test_bounding_rectd.py b/tests/test_bounding_rectd.py index 6e725ff583..47ed854263 100644 --- a/tests/test_bounding_rectd.py +++ b/tests/test_bounding_rectd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,6 +16,7 @@ import monai from monai.transforms import BoundingRectD +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [(2, 3), [[0, 0], [1, 2]]] @@ -35,14 +36,15 @@ def tearDown(self): def test_shape(self, input_shape, expected): test_data = np.random.randint(0, 8, size=input_shape) test_data = test_data == 7 - result = BoundingRectD("image")({"image": test_data}) - np.testing.assert_allclose(result["image_bbox"], expected) + for p in TEST_NDARRAYS: + result = BoundingRectD("image")({"image": p(test_data)}) + np.testing.assert_allclose(result["image_bbox"], expected) - result = BoundingRectD("image", "cc")({"image": test_data}) - np.testing.assert_allclose(result["image_cc"], expected) + result = BoundingRectD("image", "cc")({"image": p(test_data)}) + np.testing.assert_allclose(result["image_cc"], expected) - with self.assertRaises(KeyError): - BoundingRectD("image", "cc")({"image": test_data, "image_cc": None}) + with self.assertRaises(KeyError): + BoundingRectD("image", "cc")({"image": p(test_data), "image_cc": None}) if __name__ == "__main__": diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py new file mode 100644 index 0000000000..b018c9a568 --- /dev/null +++ b/tests/test_bundle_verify_metadata.py @@ -0,0 +1,69 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import subprocess +import sys +import tempfile +import unittest + +from parameterized import parameterized + +from monai.bundle import ConfigParser +from tests.utils import download_url_or_skip_test, skip_if_windows, testing_data_config + +SCHEMA_FILE = os.path.join(os.path.dirname(__file__), "testing_data", "schema.json") + +TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"), SCHEMA_FILE] + + +@skip_if_windows +class TestVerifyMetaData(unittest.TestCase): + def setUp(self): + self.config = testing_data_config("configs", "test_meta_file") + download_url_or_skip_test( + url=self.config["url"], + filepath=SCHEMA_FILE, + hash_val=self.config.get("hash_val"), + hash_type=self.config.get("hash_type", "sha256"), + ) + + @parameterized.expand([TEST_CASE_1]) + def test_verify(self, meta_file, schema_file): + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"meta_file": "will be replaced by `meta_file` arg"} + def_args_file = os.path.join(tempdir, "def_args.json") + ConfigParser.export_config_file(config=def_args, filepath=def_args_file) + + cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file] + cmd += ["--filepath", schema_file, "--hash_val", self.config["hash_val"], "--args_file", def_args_file] + ret = subprocess.check_call(cmd) + self.assertEqual(ret, 0) + + def test_verify_error(self): + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "schema.json") + metafile = os.path.join(tempdir, "metadata.json") + meta_dict = {"schema": self.config["url"], "wrong_meta": "wrong content"} + with open(metafile, "w") as f: + json.dump(meta_dict, f) + + cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", metafile, "--filepath", filepath] + ret = subprocess.check_call(cmd) + self.assertEqual(ret, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py new file mode 100644 index 0000000000..62f99aab99 --- /dev/null +++ b/tests/test_bundle_verify_net.py @@ -0,0 +1,49 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import sys +import tempfile +import unittest + +from parameterized import parameterized + +from monai.bundle import ConfigParser +from tests.utils import skip_if_windows + +TEST_CASE_1 = [ + os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"), + os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), +] + + +@skip_if_windows +class TestVerifyNetwork(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_verify(self, meta_file, config_file): + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"meta_file": "will be replaced by `meta_file` arg", "p": 2} + def_args_file = os.path.join(tempdir, "def_args.json") + ConfigParser.export_config_file(config=def_args, filepath=def_args_file) + + cmd = [sys.executable, "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file", meta_file] + cmd += ["--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file] + cmd += ["--_meta_#network_data_format#inputs#image#spatial_shape", "[32,'*','4**p*n']"] + + test_env = os.environ.copy() + print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) + ret = subprocess.check_call(cmd, env=test_env) + self.assertEqual(ret, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index bbb8143631..7227f53e04 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,8 +19,8 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader, PersistentDataset, SmartCacheDataset -from monai.transforms import Compose, Lambda, LoadImaged, ThreadUnsafe, Transform -from monai.utils import get_torch_version_tuple +from monai.transforms import Compose, Lambda, LoadImaged, RandLambda, ThreadUnsafe, Transform +from monai.utils.module import pytorch_after TEST_CASE_1 = [Compose([LoadImaged(keys=["image", "label", "extra"])]), (128, 128, 128)] @@ -42,25 +42,13 @@ class TestCacheDataset(unittest.TestCase): def test_shape(self, transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: - nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) - test_data = [ - { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - "extra": os.path.join(tempdir, "test_extra2.nii.gz"), - }, - ] - dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5) + test_data = [] + for i in ["1", "2"]: + for k in ["image", "label", "extra"]: + nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz")) + test_data.append({k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in ["image", "label", "extra"]}) + + dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5, as_contiguous=True) data1 = dataset[0] data2 = dataset[1] data3 = dataset[0:-1] @@ -68,9 +56,9 @@ def test_shape(self, transform, expected_shape): self.assertEqual(len(data3), 1) if transform is None: - self.assertEqual(data1["image"], os.path.join(tempdir, "test_image1.nii.gz")) - self.assertEqual(data2["label"], os.path.join(tempdir, "test_label2.nii.gz")) - self.assertEqual(data4["image"], os.path.join(tempdir, "test_image2.nii.gz")) + self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz")) + self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz")) + self.assertEqual(data4["image"], os.path.join(tempdir, "image2.nii.gz")) else: self.assertTupleEqual(data1["image"].shape, expected_shape) self.assertTupleEqual(data1["label"].shape, expected_shape) @@ -84,7 +72,7 @@ def test_shape(self, transform, expected_shape): def test_set_data(self): data_list1 = list(range(10)) - transform = Lambda(func=lambda x: np.array([x * 10])) + transform = Compose([Lambda(func=lambda x: np.array([x * 10])), RandLambda(func=lambda x: x + 1)]) dataset = CacheDataset( data=data_list1, @@ -92,19 +80,23 @@ def test_set_data(self): cache_rate=1.0, num_workers=4, progress=True, + copy_cache=False if sys.platform == "linux" else True, ) num_workers = 2 if sys.platform == "linux" else 0 dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1) for i, d in enumerate(dataloader): - np.testing.assert_allclose([[data_list1[i] * 10]], d) + np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d) + # simulate another epoch, the cache content should not be modified + for i, d in enumerate(dataloader): + np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d) # update the datalist and fill the cache content data_list2 = list(range(-10, 0)) dataset.set_data(data=data_list2) # rerun with updated cache content for i, d in enumerate(dataloader): - np.testing.assert_allclose([[data_list2[i] * 10]], d) + np.testing.assert_allclose([[data_list2[i] * 10 + 1]], d) class _StatefulTransform(Transform, ThreadUnsafe): @@ -130,14 +122,10 @@ class TestCacheThread(unittest.TestCase): @parameterized.expand(TEST_DS) def test_thread_safe(self, persistent_workers, cache_workers, loader_workers): expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002] - _kwg = {"persistent_workers": persistent_workers} if get_torch_version_tuple() > (1, 7) else {} + _kwg = {"persistent_workers": persistent_workers} if pytorch_after(1, 8) else {} data_list = list(range(1, 11)) dataset = CacheDataset( - data=data_list, - transform=_StatefulTransform(), - cache_rate=1.0, - num_workers=cache_workers, - progress=False, + data=data_list, transform=_StatefulTransform(), cache_rate=1.0, num_workers=cache_workers, progress=False ) self.assertListEqual(expected, list(dataset)) loader = DataLoader( @@ -195,6 +183,46 @@ def test_thread_safe(self, persistent_workers, cache_workers, loader_workers): self.assertListEqual(expected, [y.item() for y in loader]) self.assertListEqual(expected, [y.item() for y in loader]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_hash_as_key(self, transform, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + test_data = [] + for i in ["1", "2", "2", "3", "3"]: + for k in ["image", "label", "extra"]: + nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz")) + test_data.append({k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in ["image", "label", "extra"]}) + + dataset = CacheDataset(data=test_data, transform=transform, cache_num=4, num_workers=2, hash_as_key=True) + self.assertEqual(len(dataset), 5) + # ensure no duplicated cache content + self.assertEqual(len(dataset._cache), 3) + self.assertEqual(dataset.cache_num, 3) + data1 = dataset[0] + data2 = dataset[1] + data3 = dataset[-1] + # test slice indices + data4 = dataset[0:-1] + self.assertEqual(len(data4), 4) + + if transform is None: + self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz")) + self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz")) + self.assertEqual(data3["image"], os.path.join(tempdir, "image3.nii.gz")) + else: + self.assertTupleEqual(data1["image"].shape, expected_shape) + self.assertTupleEqual(data2["label"].shape, expected_shape) + self.assertTupleEqual(data3["image"].shape, expected_shape) + for d in data4: + self.assertTupleEqual(d["image"].shape, expected_shape) + + test_data2 = test_data[:3] + dataset.set_data(data=test_data2) + self.assertEqual(len(dataset), 3) + # ensure no duplicated cache content + self.assertEqual(len(dataset._cache), 2) + self.assertEqual(dataset.cache_num, 2) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_cachedataset_parallel.py b/tests/test_cachedataset_parallel.py index 0be3ba085b..f409e17787 100644 --- a/tests/test_cachedataset_parallel.py +++ b/tests/test_cachedataset_parallel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,12 +42,7 @@ def test_shape(self, num_workers, dataset_size, transform): "extra": os.path.join(tempdir, "test_extra1.nii.gz"), } ] * dataset_size - dataset = CacheDataset( - data=test_data, - transform=transform, - cache_rate=1, - num_workers=num_workers, - ) + dataset = CacheDataset(data=test_data, transform=transform, cache_rate=1, num_workers=num_workers) self.assertEqual(len(dataset._cache), dataset.cache_num) for i in range(dataset.cache_num): diff --git a/tests/test_cachedataset_persistent_workers.py b/tests/test_cachedataset_persistent_workers.py index 584a053614..4bea0486bc 100644 --- a/tests/test_cachedataset_persistent_workers.py +++ b/tests/test_cachedataset_persistent_workers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,20 +19,16 @@ @SkipIfBeforePyTorchVersion((1, 7)) class TestTransformsWCacheDatasetAndPersistentWorkers(unittest.TestCase): def test_duplicate_transforms(self): - im, _ = create_test_image_2d(128, 128, num_seg_classes=1, channel_dim=0) - data = [{"img": im} for _ in range(2)] + data = [{"img": create_test_image_2d(128, 128, num_seg_classes=1, channel_dim=0)[0]} for _ in range(2)] # at least 1 deterministic followed by at least 1 random - transform = Compose( - [ - Spacingd("img", pixdim=(1, 1)), - RandAffined("img", prob=1.0), - ] - ) + transform = Compose([Spacingd("img", pixdim=(1, 1)), RandAffined("img", prob=1.0)]) # cachedataset and data loader w persistent_workers train_ds = CacheDataset(data, transform, cache_num=1) - train_loader = DataLoader(train_ds, num_workers=2, persistent_workers=True) + # num_workers > 1 may fail randomly with 21.09 on A100 test node + # https://github.com/Project-MONAI/MONAI/issues/3283 + train_loader = DataLoader(train_ds, num_workers=1, persistent_workers=True) b1 = next(iter(train_loader)) b2 = next(iter(train_loader)) diff --git a/tests/test_cachentransdataset.py b/tests/test_cachentransdataset.py index 492db8b16f..99ca0e0c3d 100644 --- a/tests/test_cachentransdataset.py +++ b/tests/test_cachentransdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,19 +42,13 @@ def test_n_trans(self, transform, expected_shape): cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") dataset_precached = CacheNTransDataset( - data=test_data, - transform=transform, - cache_dir=cache_dir, - cache_n_trans=2, + data=test_data, transform=transform, cache_dir=cache_dir, cache_n_trans=2 ) data_precached = dataset_precached[0] self.assertTupleEqual(data_precached["image"].shape, expected_shape) dataset_postcached = CacheNTransDataset( - data=test_data, - transform=transform, - cache_dir=cache_dir, - cache_n_trans=2, + data=test_data, transform=transform, cache_dir=cache_dir, cache_n_trans=2 ) data_postcached = dataset_postcached[0] self.assertTupleEqual(data_postcached["image"].shape, expected_shape) diff --git a/tests/test_distcall.py b/tests/test_call_dist.py similarity index 96% rename from tests/test_distcall.py rename to tests/test_call_dist.py index 1830a85654..bed8289506 100644 --- a/tests/test_distcall.py +++ b/tests/test_call_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_cast_to_type.py b/tests/test_cast_to_type.py index 0ef25cbafa..82daabc4e7 100644 --- a/tests/test_cast_to_type.py +++ b/tests/test_cast_to_type.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,14 +16,23 @@ from parameterized import parameterized from monai.transforms import CastToType +from monai.utils import optional_import from monai.utils.type_conversion import get_equivalent_dtype -from tests.utils import TEST_NDARRAYS +from tests.utils import HAS_CUPY, TEST_NDARRAYS + +cp, _ = optional_import("cupy") TESTS = [] for p in TEST_NDARRAYS: for out_dtype in (np.float64, torch.float64): TESTS.append([out_dtype, p(np.array([[0, 1], [1, 2]], dtype=np.float32)), out_dtype]) +TESTS_CUPY = [ + [np.float32, np.array([[0, 1], [1, 2]], dtype=np.float32), np.float32], + [np.float32, np.array([[0, 1], [1, 2]], dtype=np.uint8), np.float32], + [np.uint8, np.array([[0, 1], [1, 2]], dtype=np.float32), np.uint8], +] + class TestCastToType(unittest.TestCase): @parameterized.expand(TESTS) @@ -35,6 +44,19 @@ def test_type(self, out_dtype, input_data, expected_type): result = CastToType()(input_data, out_dtype) self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result))) + @parameterized.expand(TESTS_CUPY) + @unittest.skipUnless(HAS_CUPY, "Requires CuPy") + def test_type_cupy(self, out_dtype, input_data, expected_type): + input_data = cp.asarray(input_data) + + result = CastToType(dtype=out_dtype)(input_data) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result))) + + result = CastToType()(input_data, out_dtype) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result))) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_cast_to_typed.py b/tests/test_cast_to_typed.py index be495564fb..4c7623a9e0 100644 --- a/tests/test_cast_to_typed.py +++ b/tests/test_cast_to_typed.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,6 +16,10 @@ from parameterized import parameterized from monai.transforms import CastToTyped +from monai.utils import optional_import +from tests.utils import HAS_CUPY + +cp, _ = optional_import("cupy") TEST_CASE_1 = [ {"keys": ["img"], "dtype": np.float64}, @@ -33,6 +37,20 @@ ] +TESTS_CUPY = [ + [ + {"keys": "image", "dtype": np.uint8}, + {"image": np.array([[0, 1], [1, 2]], dtype=np.float32), "label": np.array([[0, 1], [1, 1]], dtype=np.float32)}, + {"image": np.uint8, "label": np.float32}, + ], + [ + {"keys": ["image", "label"], "dtype": np.float32}, + {"image": np.array([[0, 1], [1, 2]], dtype=np.uint8), "label": np.array([[0, 1], [1, 1]], dtype=np.uint8)}, + {"image": np.float32, "label": np.float32}, + ], +] + + class TestCastToTyped(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_type(self, input_param, input_data, expected_type): @@ -40,6 +58,16 @@ def test_type(self, input_param, input_data, expected_type): for k, v in result.items(): self.assertEqual(v.dtype, expected_type[k]) + @parameterized.expand(TESTS_CUPY) + @unittest.skipUnless(HAS_CUPY, "Requires CuPy") + def test_type_cupy(self, input_param, input_data, expected_type): + input_data = {k: cp.asarray(v) for k, v in input_data.items()} + + result = CastToTyped(**input_param)(input_data) + for k, v in result.items(): + self.assertTrue(isinstance(v, cp.ndarray)) + self.assertEqual(v.dtype, expected_type[k]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_center_scale_crop.py b/tests/test_center_scale_crop.py index e28849ce90..f22651e3e0 100644 --- a/tests/test_center_scale_crop.py +++ b/tests/test_center_scale_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,11 +38,13 @@ class TestCenterScaleCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result.shape, expected_shape) @parameterized.expand([TEST_CASE_2]) def test_value(self, input_param, input_data, expected_value): result = CenterScaleCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result, expected_value) diff --git a/tests/test_center_scale_cropd.py b/tests/test_center_scale_cropd.py index 313e8e7f7e..8aef2dbe5b 100644 --- a/tests/test_center_scale_cropd.py +++ b/tests/test_center_scale_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,9 +33,15 @@ (3, 2, 2, 2), ] +TEST_CASE_4 = [ + {"keys": "test", "roi_scale": 0.6, "allow_missing_keys": True}, + np.random.randint(0, 2, size=[3, 3, 3, 3]), + (3, 3, 3, 3), +] + class TestCenterScaleCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCropd(**input_param)({"img": input_data}) np.testing.assert_allclose(result["img"].shape, expected_shape) diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index 3e828176a5..09f61be2f1 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,11 +38,13 @@ class TestCenterSpatialCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) def test_shape(self, input_param, input_data, expected_shape): result = CenterSpatialCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result.shape, expected_shape) @parameterized.expand([TEST_CASE_2]) def test_value(self, input_param, input_data, expected_value): result = CenterSpatialCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result, expected_value) diff --git a/tests/test_center_spatial_cropd.py b/tests/test_center_spatial_cropd.py index 349253ab56..bdbc1a5031 100644 --- a/tests/test_center_spatial_cropd.py +++ b/tests/test_center_spatial_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,36 +15,43 @@ from parameterized import parameterized from monai.transforms import CenterSpatialCropd - -TEST_CASE_0 = [ - {"keys": "img", "roi_size": [2, -1, -1]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 3, 3), -] - -TEST_CASE_1 = [ - {"keys": "img", "roi_size": [2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), -] - -TEST_CASE_2 = [ - {"keys": "img", "roi_size": [2, 2]}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[1, 2], [2, 3]]]), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_SHAPES = [] +for p in TEST_NDARRAYS: + TEST_SHAPES.append( + [{"keys": "img", "roi_size": [2, -1, -1]}, {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, (3, 2, 3, 3)] + ) + + TEST_SHAPES.append( + [{"keys": "img", "roi_size": [2, 2, 2]}, {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, (3, 2, 2, 2)] + ) + +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"keys": "img", "roi_size": [2, 2]}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[1, 2], [2, 3]]])), + ] + ) class TestCenterSpatialCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TEST_SHAPES) def test_shape(self, input_param, input_data, expected_shape): result = CenterSpatialCropd(**input_param)(input_data) self.assertTupleEqual(result["img"].shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) + @parameterized.expand(TEST_CASES) def test_value(self, input_param, input_data, expected_value): result = CenterSpatialCropd(**input_param)(input_data) - np.testing.assert_allclose(result["img"], expected_value) + assert_allclose(result["img"], expected_value, type_test=False) if __name__ == "__main__": diff --git a/tests/test_channel_pad.py b/tests/test_channel_pad.py index ebc731c321..bde0f18d83 100644 --- a/tests/test_channel_pad.py +++ b/tests/test_channel_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_check_hash.py b/tests/test_check_hash.py index 0126b3c1a3..5297021540 100644 --- a/tests/test_check_hash.py +++ b/tests/test_check_hash.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -41,7 +41,7 @@ def test_result(self, md5_value, t, expected_result): self.assertTrue(result == expected_result) def test_hash_type_error(self): - with self.assertRaises(NotImplementedError): + with self.assertRaises(ValueError): with tempfile.TemporaryDirectory() as tempdir: check_hash(tempdir, "test_hash", "test_type") diff --git a/tests/test_check_missing_files.py b/tests/test_check_missing_files.py new file mode 100644 index 0000000000..ecd0b52a63 --- /dev/null +++ b/tests/test_check_missing_files.py @@ -0,0 +1,56 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from pathlib import Path + +import nibabel as nib +import numpy as np + +from monai.data import check_missing_files + + +class TestCheckMissingFiles(unittest.TestCase): + def test_content(self): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + + datalist = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": [os.path.join(tempdir, "test_label1.nii.gz"), os.path.join(tempdir, "test_extra1.nii.gz")], + }, + { + "image": Path(os.path.join(tempdir, "test_image2.nii.gz")), + "label": Path(os.path.join(tempdir, "test_label_missing.nii.gz")), + }, + ] + + missings = check_missing_files(datalist=datalist, keys=["image", "label"]) + self.assertEqual(len(missings), 1) + self.assertEqual(str(missings[0]), os.path.join(tempdir, "test_label_missing.nii.gz")) + + # test with missing key and relative path + datalist = [{"image": "test_image1.nii.gz", "label": "test_label_missing.nii.gz"}] + missings = check_missing_files( + datalist=datalist, keys=["image", "label", "test"], root_dir=tempdir, allow_missing_keys=True + ) + self.assertEqual(f"{missings[0]}", os.path.join(tempdir, "test_label_missing.nii.gz")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_classes_to_indices.py b/tests/test_classes_to_indices.py index 0ba3dd094a..1f39e0f480 100644 --- a/tests/test_classes_to_indices.py +++ b/tests/test_classes_to_indices.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,68 +11,80 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import ClassesToIndices +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - # test Argmax data - {"num_classes": 3, "image_threshold": 0.0}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - None, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS_CASES = [] +for p in TEST_NDARRAYS: + TESTS_CASES.append( + [ + # test Argmax data + {"num_classes": 3, "image_threshold": 0.0}, + p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + None, + [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])], + ] + ) -TEST_CASE_2 = [ - {"num_classes": 3, "image_threshold": 60}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS_CASES.append( + [ + {"num_classes": 3, "image_threshold": 60}, + p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + [p([0, 8]), p([1, 5, 6]), p([3])], + ] + ) -TEST_CASE_3 = [ - # test One-Hot data - {"image_threshold": 0.0}, - np.array( + TESTS_CASES.append( [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + # test One-Hot data + {"image_threshold": 0.0}, + p( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + None, + [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])], ] - ), - None, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + ) -TEST_CASE_4 = [ - {"num_classes": None, "image_threshold": 60}, - np.array( + TESTS_CASES.append( [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + {"num_classes": None, "image_threshold": 60}, + p( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + [p([0, 8]), p([1, 5, 6]), p([3])], ] - ), - np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + ) -TEST_CASE_5 = [ - # test output_shape - {"num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - None, - [np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 1], [1, 2], [2, 0]]), np.array([[0, 2], [1, 0], [2, 1]])], -] + TESTS_CASES.append( + [ + # test output_shape + {"num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, + p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + None, + [p([[0, 0], [1, 1], [2, 2]]), p([[0, 1], [1, 2], [2, 0]]), p([[0, 2], [1, 0], [2, 1]])], + ] + ) class TestClassesToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS_CASES) def test_value(self, input_args, label, image, expected_indices): indices = ClassesToIndices(**input_args)(label, image) for i, e in zip(indices, expected_indices): - np.testing.assert_allclose(i, e) + assert_allclose(i, e) if __name__ == "__main__": diff --git a/tests/test_classes_to_indicesd.py b/tests/test_classes_to_indicesd.py index 67fac95c8c..398620d304 100644 --- a/tests/test_classes_to_indicesd.py +++ b/tests/test_classes_to_indicesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,73 +11,91 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import ClassesToIndicesd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - # test Argmax data - {"keys": "label", "num_classes": 3, "image_threshold": 0.0}, - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS_CASES = [] +for p in TEST_NDARRAYS: + TESTS_CASES.append( + [ + # test Argmax data + {"keys": "label", "num_classes": 3, "image_threshold": 0.0}, + {"label": p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, + [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])], + ] + ) -TEST_CASE_2 = [ - {"keys": "label", "image_key": "image", "num_classes": 3, "image_threshold": 60}, - { - "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS_CASES.append( + [ + {"keys": "label", "image_key": "image", "num_classes": 3, "image_threshold": 60}, + { + "label": p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + "image": p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + }, + [p([0, 8]), p([1, 5, 6]), p([3])], + ] + ) -TEST_CASE_3 = [ - # test One-Hot data - {"keys": "label", "image_threshold": 0.0}, - { - "label": np.array( - [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ) - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + TESTS_CASES.append( + [ + # test One-Hot data + {"keys": "label", "image_threshold": 0.0}, + { + "label": p( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + }, + [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])], + ] + ) -TEST_CASE_4 = [ - {"keys": "label", "image_key": "image", "num_classes": None, "image_threshold": 60}, - { - "label": np.array( - [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS_CASES.append( + [ + {"keys": "label", "image_key": "image", "num_classes": None, "image_threshold": 60}, + { + "label": p( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + "image": p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + }, + [p([0, 8]), p([1, 5, 6]), p([3])], + ] + ) -TEST_CASE_5 = [ - # test output_shape - {"keys": "label", "indices_postfix": "cls", "num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, - [np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 1], [1, 2], [2, 0]]), np.array([[0, 2], [1, 0], [2, 1]])], -] + TESTS_CASES.append( + [ + # test output_shape + { + "keys": "label", + "indices_postfix": "cls", + "num_classes": 3, + "image_threshold": 0.0, + "output_shape": [3, 3], + }, + {"label": p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, + [p([[0, 0], [1, 1], [2, 2]]), p([[0, 1], [1, 2], [2, 0]]), p([[0, 2], [1, 0], [2, 1]])], + ] + ) class TestClassesToIndicesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS_CASES) def test_value(self, input_args, input_data, expected_indices): result = ClassesToIndicesd(**input_args)(input_data) key_postfix = input_args.get("indices_postfix") key_postfix = "_cls_indices" if key_postfix is None else key_postfix for i, e in zip(result["label" + key_postfix], expected_indices): - np.testing.assert_allclose(i, e) + assert_allclose(i, e) if __name__ == "__main__": diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py new file mode 100644 index 0000000000..ebb2cca7b3 --- /dev/null +++ b/tests/test_component_locator.py @@ -0,0 +1,35 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from pydoc import locate + +from monai.bundle import ComponentLocator +from monai.utils import optional_import + +_, has_ignite = optional_import("ignite") + + +class TestComponentLocator(unittest.TestCase): + def test_locate(self): + locator = ComponentLocator(excludes=None if has_ignite else ["monai.handlers"]) + # test init mapping table and get the module path of component + self.assertEqual(locator.get_component_module_name("LoadImage"), "monai.transforms.io.array") + self.assertGreater(len(locator._components_table), 0) + for _, mods in locator._components_table.items(): + for i in mods: + self.assertGreater(len(mods), 0) + # ensure we can locate all the items by `name` + self.assertIsNotNone(locate(i), msg=f"can not locate target: {i}.") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_compose.py b/tests/test_compose.py index 28783cad23..4d1bcfe01c 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -154,7 +154,7 @@ def __call__(self, data): c.randomize() def test_err_msg(self): - transforms = Compose([abs, AddChannel(), round]) + transforms = Compose([abs, AddChannel(), round], log_stats=False) with self.assertRaisesRegex(Exception, "AddChannel"): transforms(42.1) diff --git a/tests/test_compose_get_number_conversions.py b/tests/test_compose_get_number_conversions.py index eb10c7d5ef..fca5bc727d 100644 --- a/tests/test_compose_get_number_conversions.py +++ b/tests/test_compose_get_number_conversions.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py index 69a95e0c8b..1212715548 100644 --- a/tests/test_compute_confusion_matrix.py +++ b/tests/test_compute_confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -174,22 +174,10 @@ # 3. test metric with compute_sample, denominator may have zeros TEST_CASES_COMPUTE_SAMPLE_NAN = [] metric_names = ["tpr", "tnr"] -result_sum = [ - torch.tensor([0.5000]), - torch.tensor([4.8333]), -] -not_nans_sum = [ - torch.tensor([6]), - torch.tensor([8]), -] -result_sum_batch = [ - torch.tensor([0.0000, 0.5000, 0.0000]), - torch.tensor([1.6667, 2.5000, 0.6667]), -] -not_nans_sum_batch = [ - torch.tensor([3.0, 2.0, 1.0]), - torch.tensor([2.0, 3.0, 3.0]), -] +result_sum = [torch.tensor([0.5000]), torch.tensor([4.8333])] +not_nans_sum = [torch.tensor([6]), torch.tensor([8])] +result_sum_batch = [torch.tensor([0.0000, 0.5000, 0.0000]), torch.tensor([1.6667, 2.5000, 0.6667])] +not_nans_sum_batch = [torch.tensor([3.0, 2.0, 1.0]), torch.tensor([2.0, 3.0, 3.0])] for idx in range(2): for reduction in ["sum", "sum_batch"]: TEST_CASE = [data_nan.copy()] diff --git a/tests/test_compute_froc.py b/tests/test_compute_froc.py index 70de836dd9..d68f3f7fb4 100644 --- a/tests/test_compute_froc.py +++ b/tests/test_compute_froc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index f9e494efc7..ad66ed672a 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -168,10 +168,7 @@ ] TEST_CASE_10 = [ - { - "y": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))], - "y_pred": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))], - }, + {"y": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))], "y_pred": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))]}, [[1.0000, 1.0000], [1.0000, 1.0000]], ] diff --git a/tests/test_compute_regression_metrics.py b/tests/test_compute_regression_metrics.py index 126eab3f07..65ca73a4ec 100644 --- a/tests/test_compute_regression_metrics.py +++ b/tests/test_compute_regression_metrics.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index 1cec357b93..887db08c7c 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), torch.tensor([[0], [1], [0], [1]]), True, - True, + 2, "macro", 0.75, ] @@ -32,34 +32,20 @@ torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([[0], [1], [0], [1]]), False, - False, + None, "macro", 0.875, ] -TEST_CASE_3 = [ - torch.tensor([[0.5], [0.5], [0.2], [8.3]]), - torch.tensor([0, 1, 0, 1]), - False, - False, - "macro", - 0.875, -] +TEST_CASE_3 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.875] -TEST_CASE_4 = [ - torch.tensor([0.5, 0.5, 0.2, 8.3]), - torch.tensor([0, 1, 0, 1]), - False, - False, - "macro", - 0.875, -] +TEST_CASE_4 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.875] TEST_CASE_5 = [ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), torch.tensor([[0], [1], [0], [1]]), True, - True, + 2, "none", [0.75, 0.75], ] @@ -68,7 +54,7 @@ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), True, - False, + None, "weighted", 0.56667, ] @@ -77,26 +63,79 @@ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), True, - False, + None, "micro", 0.62, ] +TEST_CASE_8 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[0], [0], [0], [0]]), + True, + 2, + "macro", + float("nan"), +] + +TEST_CASE_9 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[1], [1], [1], [1]]), + True, + 2, + "macro", + float("nan"), +] + +TEST_CASE_10 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]]), + True, + None, + "macro", + float("nan"), +] + class TestComputeROCAUC(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand( + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + ] + ) def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, num_classes=2)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)]) y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0) y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0) result = compute_roc_auc(y_pred=y_pred, y=y, average=average) np.testing.assert_allclose(expected_value, result, rtol=1e-5) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand( + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + ] + ) def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, num_classes=2)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)]) y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)] y = [y_trans(i) for i in decollate_batch(y)] metric = ROCAUCMetric(average=average) diff --git a/tests/test_concat_itemsd.py b/tests/test_concat_itemsd.py index 9c51e1efea..2f98738233 100644 --- a/tests/test_concat_itemsd.py +++ b/tests/test_concat_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_config_item.py b/tests/test_config_item.py new file mode 100644 index 0000000000..fbd76e7be7 --- /dev/null +++ b/tests/test_config_item.py @@ -0,0 +1,124 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial +from typing import Callable + +import torch +from parameterized import parameterized + +import monai +from monai.bundle import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from monai.data import DataLoader, Dataset +from monai.transforms import LoadImaged, RandTorchVisiond +from monai.utils import min_version, optional_import + +_, has_tv = optional_import("torchvision", "0.8.0", min_version) + +TEST_CASE_1 = [{"lr": 0.001}, 0.0001] + +TEST_CASE_2 = [{"_target_": "LoadImaged", "keys": ["image"]}, LoadImaged] +# test full module path +TEST_CASE_3 = [{"_target_": "monai.transforms.LoadImaged", "keys": ["image"]}, LoadImaged] +# test `_disabled_` +TEST_CASE_4 = [{"_target_": "LoadImaged", "_disabled_": True, "keys": ["image"]}, dict] +# test `_disabled_` with string +TEST_CASE_5 = [{"_target_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict] +# test non-monai modules and excludes +TEST_CASE_6 = [{"_target_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam] +TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True}, partial] +# test args contains "name" field +TEST_CASE_8 = [ + {"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25}, + RandTorchVisiond, +] +# test execute some function in args, test pre-imported global packages `monai` +TEST_CASE_9 = ["collate_fn", "$monai.data.list_data_collate"] +# test lambda function, should not execute the lambda function, just change the string +TEST_CASE_10 = ["collate_fn", "$lambda x: monai.data.list_data_collate(x) + torch.tensor(var)"] + + +class TestConfigItem(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_item(self, test_input, expected): + item = ConfigItem(config=test_input) + conf = item.get_config() + conf["lr"] = 0.0001 + item.update_config(config=conf) + self.assertEqual(item.get_config()["lr"], expected) + + @parameterized.expand( + [TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] + + ([TEST_CASE_8] if has_tv else []) + ) + def test_component(self, test_input, output_type): + locator = ComponentLocator(excludes=["metrics"]) + configer = ConfigComponent(id="test", config=test_input, locator=locator) + ret = configer.instantiate() + if test_input.get("_disabled_", False): + # test `_disabled_` works fine + self.assertEqual(ret, None) + return + self.assertTrue(isinstance(ret, output_type)) + if isinstance(ret, LoadImaged): + self.assertEqual(ret.keys[0], "image") + + @parameterized.expand([TEST_CASE_9, TEST_CASE_10]) + def test_expression(self, id, test_input): + configer = ConfigExpression(id=id, config=test_input, globals={"monai": monai, "torch": torch}) + var = 100 + ret = configer.evaluate(locals={"var": var}) + self.assertTrue(isinstance(ret, Callable)) + + def test_lazy_instantiation(self): + config = {"_target_": "DataLoader", "dataset": Dataset(data=[1, 2]), "batch_size": 2} + configer = ConfigComponent(config=config, locator=None) + init_config = configer.get_config() + # modify config content at runtime + init_config["batch_size"] = 4 + configer.update_config(config=init_config) + + ret = configer.instantiate() + self.assertTrue(isinstance(ret, DataLoader)) + self.assertEqual(ret.batch_size, 4) + + @parameterized.expand([("$import json", "json"), ("$import json as j", "j")]) + def test_import(self, stmt, mod_name): + test_globals = {} + ConfigExpression(id="", config=stmt, globals=test_globals).evaluate() + self.assertTrue(callable(test_globals[mod_name].dump)) + + @parameterized.expand( + [ + ("$from json import dump", "dump"), + ("$from json import dump, dumps", "dump"), + ("$from json import dump as jd", "jd"), + ("$from json import dump as jd, dumps as ds", "jd"), + ] + ) + def test_import_from(self, stmt, mod_name): + test_globals = {} + ConfigExpression(id="", config=stmt, globals=test_globals).evaluate() + self.assertTrue(callable(test_globals[mod_name])) + self.assertTrue(ConfigExpression.is_import_statement(ConfigExpression(id="", config=stmt).config)) + + @parameterized.expand( + [("$from json import dump", True), ("$print()", False), ("$import json", True), ("import json", False)] + ) + def test_is_import_stmt(self, stmt, expected): + expr = ConfigExpression(id="", config=stmt) + flag = expr.is_import_statement(expr.config) + self.assertEqual(flag, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py new file mode 100644 index 0000000000..8b1076b1f7 --- /dev/null +++ b/tests/test_config_parser.py @@ -0,0 +1,147 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import skipUnless + +from parameterized import parameterized + +from monai.bundle.config_parser import ConfigParser +from monai.data import DataLoader, Dataset +from monai.transforms import Compose, LoadImaged, RandTorchVisiond +from monai.utils import min_version, optional_import + +_, has_tv = optional_import("torchvision", "0.8.0", min_version) + +# test the resolved and parsed instances +TEST_CASE_1 = [ + { + "transform": { + "_target_": "Compose", + "transforms": [ + {"_target_": "LoadImaged", "keys": "image"}, + # test relative id in `keys` + {"_target_": "RandTorchVisiond", "keys": "@##0#keys", "name": "ColorJitter", "brightness": 0.25}, + ], + }, + "dataset": {"_target_": "Dataset", "data": [1, 2], "transform": "@transform"}, + "dataloader": { + "_target_": "DataLoader", + # test relative id in `dataset` + "dataset": "@##dataset", + "batch_size": 2, + "collate_fn": "$monai.data.list_data_collate", + }, + }, + ["transform", "transform#transforms#0", "transform#transforms#1", "dataset", "dataloader"], + [Compose, LoadImaged, RandTorchVisiond, Dataset, DataLoader], +] + + +class TestClass: + @staticmethod + def compute(a, b, func=lambda x, y: x + y): + return func(a, b) + + @classmethod + def cls_compute(cls, a, b, func=lambda x, y: x + y): + return cls.compute(a, b, func) + + def __call__(self, a, b): + return self.compute(a, b) + + +TEST_CASE_2 = [ + { + "basic_func": "$lambda x, y: x + y", + "static_func": "$TestClass.compute", + "cls_func": "$TestClass.cls_compute", + "lambda_static_func": "$lambda x, y: TestClass.compute(x, y)", + "lambda_cls_func": "$lambda x, y: TestClass.cls_compute(x, y)", + "compute": {"_target_": "tests.test_config_parser.TestClass.compute", "func": "@basic_func"}, + "cls_compute": {"_target_": "tests.test_config_parser.TestClass.cls_compute", "func": "@basic_func"}, + "call_compute": {"_target_": "tests.test_config_parser.TestClass"}, + "error_func": "$TestClass.__call__", + "": "$lambda x, y: x + y", + } +] + + +TEST_CASE_3 = [ + { + "A": 1, + "B": "@A", + "C": "@#A", + "D": {"key": "@##A", "value1": 2, "value2": "%#value1", "value3": [3, 4, "@#1", "$100 + @#0 + @##value1"]}, + } +] + + +class TestConfigParser(unittest.TestCase): + def test_config_content(self): + test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}} + parser = ConfigParser(config=test_config) + # test `get`, `set`, `__getitem__`, `__setitem__` + self.assertEqual(str(parser.get()), str(test_config)) + parser.set(config=test_config) + self.assertListEqual(parser["preprocessing"], test_config["preprocessing"]) + parser["dataset"] = {"_target_": "CacheDataset"} + self.assertEqual(parser["dataset"]["_target_"], "CacheDataset") + # test nested ids + parser["dataset#_target_"] = "Dataset" + self.assertEqual(parser["dataset#_target_"], "Dataset") + # test int id + parser.set(["test1", "test2", "test3"]) + parser[1] = "test4" + self.assertEqual(parser[1], "test4") + + @parameterized.expand([TEST_CASE_1]) + @skipUnless(has_tv, "Requires torchvision >= 0.8.0.") + def test_parse(self, config, expected_ids, output_types): + parser = ConfigParser(config=config, globals={"monai": "monai"}) + # test lazy instantiation with original config content + parser["transform"]["transforms"][0]["keys"] = "label1" + self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label1") + # test nested id + parser["transform#transforms#0#keys"] = "label2" + self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label2") + for id, cls in zip(expected_ids, output_types): + self.assertTrue(isinstance(parser.get_parsed_content(id), cls)) + # test root content + root = parser.get_parsed_content(id="") + for v, cls in zip(root.values(), [Compose, Dataset, DataLoader]): + self.assertTrue(isinstance(v, cls)) + + @parameterized.expand([TEST_CASE_2]) + def test_function(self, config): + parser = ConfigParser(config=config, globals={"TestClass": TestClass}) + for id in config: + func = parser.get_parsed_content(id=id) + self.assertTrue(id in parser.ref_resolver.resolved_content) + if id == "error_func": + with self.assertRaises(TypeError): + func(1, 2) + continue + self.assertEqual(func(1, 2), 3) + + @parameterized.expand([TEST_CASE_3]) + def test_relative_id(self, config): + parser = ConfigParser(config=config) + for id in config: + item = parser.get_parsed_content(id=id) + if isinstance(item, int): + self.assertEqual(item, 1) + if isinstance(item, dict): + self.assertEqual(str(item), str({"key": 1, "value1": 2, "value2": 2, "value3": [3, 4, 4, 105]})) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_contrastive_loss.py b/tests/test_contrastive_loss.py new file mode 100644 index 0000000000..5dce860486 --- /dev/null +++ b/tests/test_contrastive_loss.py @@ -0,0 +1,79 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import ContrastiveLoss + +TEST_CASES = [ + [ # shape: (1, 4), (1, 4) + {"temperature": 0.5, "batch_size": 1}, + {"input": torch.tensor([[1.0, 1.0, 0.0, 0.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, + 0.0, + ], + [ # shape: (2, 4), (2, 4) + {"temperature": 0.5, "batch_size": 2}, + { + "input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + }, + 1.0986, + ], + [ # shape: (1, 4), (1, 4) + {"temperature": 0.5, "batch_size": 2}, + { + "input": torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 0.0, 0.0]]), + "target": torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + }, + 0.8719, + ], + [ # shape: (1, 4), (1, 4) + {"temperature": 0.5, "batch_size": 1}, + {"input": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, + 0.0, + ], + [ # shape: (1, 4), (1, 4) + {"temperature": 0.05, "batch_size": 1}, + {"input": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, + 0.0, + ], +] + + +class TestContrastiveLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_result(self, input_param, input_data, expected_val): + contrastiveloss = ContrastiveLoss(**input_param) + result = contrastiveloss(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + + def test_ill_shape(self): + loss = ContrastiveLoss(temperature=0.5, batch_size=1) + with self.assertRaisesRegex(ValueError, ""): + loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_with_cuda(self): + loss = ContrastiveLoss(temperature=0.5, batch_size=1) + i = torch.ones((1, 10)) + j = torch.ones((1, 10)) + if torch.cuda.is_available(): + i = i.cuda() + j = j.cuda() + output = loss(i, j) + np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py index a7fc64f950..1818f500f9 100644 --- a/tests/test_convert_data_type.py +++ b/tests/test_convert_data_type.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,6 +24,24 @@ for out_type in TEST_NDARRAYS: TESTS.append((in_type(np.array(1.0)), out_type(np.array(1.0)))) # type: ignore +TESTS_LIST: List[Tuple] = [] +for in_type in TEST_NDARRAYS + (int, float): + for out_type in TEST_NDARRAYS: + TESTS_LIST.append( + ([in_type(np.array(1.0)), in_type(np.array(1.0))], out_type(np.array([1.0, 1.0])), True) # type: ignore + ) + TESTS_LIST.append( + ( + [in_type(np.array(1.0)), in_type(np.array(1.0))], # type: ignore + [out_type(np.array(1.0)), out_type(np.array(1.0))], + False, + ) + ) + + +class TestTensor(torch.Tensor): + pass + class TestConvertDataType(unittest.TestCase): @parameterized.expand(TESTS) @@ -42,14 +60,23 @@ def test_convert_data_type(self, in_image, im_out): def test_neg_stride(self): _ = convert_data_type(np.array((1, 2))[::-1], torch.Tensor) - def test_ill_arg(self): - with self.assertRaises(ValueError): - convert_data_type(None, torch.Tensor) - convert_data_type(None, np.ndarray) + @parameterized.expand(TESTS_LIST) + def test_convert_list(self, in_image, im_out, wrap): + output_type = type(im_out) if wrap else type(im_out[0]) + converted_im, *_ = convert_data_type(in_image, output_type, wrap_sequence=wrap) + # check output is desired type + if not wrap: + converted_im = converted_im[0] + im_out = im_out[0] + self.assertEqual(type(converted_im), type(im_out)) + # check dtype is unchanged + if isinstance(in_type, (np.ndarray, torch.Tensor)): + self.assertEqual(converted_im.dtype, im_out.dtype) class TestConvertDataSame(unittest.TestCase): - @parameterized.expand(TESTS) + # add test for subclass of Tensor + @parameterized.expand(TESTS + [(np.array(1.0), TestTensor(np.array(1.0)))]) def test_convert_data_type(self, in_image, im_out): converted_im, orig_type, orig_device = convert_to_dst_type(in_image, im_out) # check input is unchanged @@ -57,7 +84,11 @@ def test_convert_data_type(self, in_image, im_out): if isinstance(in_image, torch.Tensor): self.assertEqual(in_image.device, orig_device) # check output is desired type - self.assertEqual(type(converted_im), type(im_out)) + if isinstance(im_out, torch.Tensor): + output_type = torch.Tensor + else: + output_type = np.ndarray + self.assertEqual(type(converted_im), output_type) # check dtype is unchanged if isinstance(in_type, (np.ndarray, torch.Tensor)): self.assertEqual(converted_im.dtype, im_out.dtype) diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py index 2f7a38e6e4..b606fee04f 100644 --- a/tests/test_convert_to_multi_channel.py +++ b/tests/test_convert_to_multi_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,34 +11,46 @@ import unittest -import numpy as np +import torch from parameterized import parameterized from monai.transforms import ConvertToMultiChannelBasedOnBratsClasses +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]]), - np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), -] - -TEST_CASE_2 = [ - np.array([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]), - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( [ - [[[0, 1], [1, 0]], [[0, 1], [1, 1]]], - [[[0, 1], [1, 1]], [[1, 1], [1, 1]]], - [[[0, 0], [0, 0]], [[0, 1], [1, 1]]], + [ + p([[0, 1, 2], [1, 2, 4], [0, 1, 4]]), + p( + [ + [[0, 1, 0], [1, 0, 1], [0, 1, 1]], + [[0, 1, 1], [1, 1, 1], [0, 1, 1]], + [[0, 0, 0], [0, 0, 1], [0, 0, 1]], + ] + ), + ], + [ + p([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]), + p( + [ + [[[0, 1], [1, 0]], [[0, 1], [1, 1]]], + [[[0, 1], [1, 1]], [[1, 1], [1, 1]]], + [[[0, 0], [0, 0]], [[0, 1], [1, 1]]], + ] + ), + ], ] - ), -] + ) class TestConvertToMultiChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_type_shape(self, data, expected_result): result = ConvertToMultiChannelBasedOnBratsClasses()(data) - np.testing.assert_equal(result, expected_result) - self.assertEqual(f"{result.dtype}", "bool") + assert_allclose(result, expected_result) + self.assertTrue(result.dtype in (bool, torch.bool)) if __name__ == "__main__": diff --git a/tests/test_convert_to_multi_channeld.py b/tests/test_convert_to_multi_channeld.py index 945e07e1cd..7525f8d7e2 100644 --- a/tests/test_convert_to_multi_channeld.py +++ b/tests/test_convert_to_multi_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_convert_to_torchscript.py b/tests/test_convert_to_torchscript.py new file mode 100644 index 0000000000..a1c1471463 --- /dev/null +++ b/tests/test_convert_to_torchscript.py @@ -0,0 +1,43 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch + +from monai.networks import convert_to_torchscript +from monai.networks.nets import UNet + + +class TestConvertToTorchScript(unittest.TestCase): + def test_value(self): + model = UNet( + spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 + ) + with tempfile.TemporaryDirectory() as tempdir: + torchscript_model = convert_to_torchscript( + model=model, + filename_or_obj=os.path.join(tempdir, "model.ts"), + extra_files={"foo.txt": b"bar"}, + verify=True, + inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)], + device="cuda" if torch.cuda.is_available() else "cpu", + rtol=1e-3, + atol=1e-4, + optimize=None, + ) + self.assertTrue(isinstance(torchscript_model, torch.nn.Module)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_convolutions.py b/tests/test_convolutions.py index 97c01dd659..5c1681ad71 100644 --- a/tests/test_convolutions.py +++ b/tests/test_convolutions.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_copy_itemsd.py b/tests/test_copy_itemsd.py index a0a1ad412b..11f920cf6b 100644 --- a/tests/test_copy_itemsd.py +++ b/tests/test_copy_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -39,13 +39,20 @@ def test_numpy_values(self, keys, times, names): np.testing.assert_allclose(result["img_1"], np.array([[1, 2], [2, 3]])) np.testing.assert_allclose(result["img"], np.array([[0, 1], [1, 2]])) + def test_default_names(self): + input_data = {"img": np.array([[0, 1], [1, 2]]), "seg": np.array([[3, 4], [4, 5]])} + result = CopyItemsd(keys=["img", "seg"], times=2, names=None)(input_data) + for name in ["img_0", "seg_0", "img_1", "seg_1"]: + self.assertTrue(name in result) + def test_tensor_values(self): device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0") input_data = { "img": torch.tensor([[0, 1], [1, 2]], device=device), "seg": torch.tensor([[0, 1], [1, 2]], device=device), } - result = CopyItemsd(keys="img", times=1, names="img_1")(input_data) + # test default `times=1` + result = CopyItemsd(keys="img", names="img_1")(input_data) self.assertTrue("img_1" in result) result["img_1"] += 1 torch.testing.assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]], device=device)) diff --git a/tests/test_copy_model_state.py b/tests/test_copy_model_state.py index 6330a1918a..bc7b116e1f 100644 --- a/tests/test_copy_model_state.py +++ b/tests/test_copy_model_state.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,7 +21,7 @@ class _TestModelOne(torch.nn.Module): def __init__(self, n_n, n_m, n_class): - super(_TestModelOne, self).__init__() + super().__init__() self.layer = torch.nn.Linear(n_n, n_m) self.class_layer = torch.nn.Linear(n_m, n_class) @@ -33,7 +33,7 @@ def forward(self, x): class _TestModelTwo(torch.nn.Module): def __init__(self, n_n, n_m, n_d, n_class): - super(_TestModelTwo, self).__init__() + super().__init__() self.layer = torch.nn.Linear(n_n, n_m) self.layer_1 = torch.nn.Linear(n_m, n_d) self.class_layer = torch.nn.Linear(n_d, n_class) diff --git a/tests/test_correct_crop_centers.py b/tests/test_correct_crop_centers.py new file mode 100644 index 0000000000..50478c7d5d --- /dev/null +++ b/tests/test_correct_crop_centers.py @@ -0,0 +1,34 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms.utils import correct_crop_centers +from tests.utils import assert_allclose + +TESTS = [[[1, 5, 0], [2, 2, 2], [10, 10, 10]], [[4, 4, 4], [2, 2, 1], [10, 10, 10]]] + + +class TestCorrectCropCenters(unittest.TestCase): + @parameterized.expand(TESTS) + def test_torch(self, spatial_size, centers, label_spatial_shape): + result1 = correct_crop_centers(centers, spatial_size, label_spatial_shape) + centers = [torch.tensor(i) for i in centers] + result2 = correct_crop_centers(centers, spatial_size, label_spatial_shape) + assert_allclose(result1, result2) + self.assertEqual(type(result1[0]), type(result2[0])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_create_cross_validation_datalist.py b/tests/test_create_cross_validation_datalist.py new file mode 100644 index 0000000000..3a3e8481ea --- /dev/null +++ b/tests/test_create_cross_validation_datalist.py @@ -0,0 +1,56 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from pathlib import Path + +from monai.data import create_cross_validation_datalist, load_decathlon_datalist + + +class TestCreateCrossValidationDatalist(unittest.TestCase): + def test_content(self): + with tempfile.TemporaryDirectory() as tempdir: + datalist = [] + for i in range(5): + image = os.path.join(tempdir, f"test_image{i}.nii.gz") + label = os.path.join(tempdir, f"test_label{i}.nii.gz") + Path(image).touch() + Path(label).touch() + datalist.append({"image": image, "label": label}) + + filename = os.path.join(tempdir, "test_datalist.json") + result = create_cross_validation_datalist( + datalist=datalist, + nfolds=5, + train_folds=[0, 1, 2, 3], + val_folds=4, + train_key="test_train", + val_key="test_val", + filename=Path(filename), + shuffle=True, + seed=123, + check_missing=True, + keys=["image", "label"], + root_dir=None, + allow_missing_keys=False, + raise_error=True, + ) + + loaded = load_decathlon_datalist(filename, data_list_key="test_train") + for r, l in zip(result["test_train"], loaded): + self.assertEqual(r["image"], l["image"]) + self.assertEqual(r["label"], l["label"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index 0c0e52e04a..d70db45468 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from monai.transforms import ( create_control_grid, @@ -21,6 +22,7 @@ create_shear, create_translate, ) +from tests.utils import assert_allclose, is_tf32_env class TestCreateGrid(unittest.TestCase): @@ -32,50 +34,47 @@ def test_create_grid(self): with self.assertRaisesRegex(TypeError, ""): create_grid((1, 1), spacing=2.0) - g = create_grid((1, 1)) - expected = np.array([[[0.0]], [[0.0]], [[1.0]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1),), np.array([[[0.0]], [[0.0]], [[1.0]]])) - g = create_grid((1, 1), homogeneous=False) - expected = np.array([[[0.0]], [[0.0]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1), None, False), np.array([[[0.0]], [[0.0]]])) - g = create_grid((1, 1), spacing=(1.2, 1.3)) - expected = np.array([[[0.0]], [[0.0]], [[1.0]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1), (1.2, 1.3)), np.array([[[0.0]], [[0.0]], [[1.0]]])) - g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0)) - expected = np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[1.0]]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1, 1), (1.2, 1.3, 1.0)), np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[1.0]]]])) - g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), homogeneous=False) - expected = np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1, 1), (1.2, 1.3, 1.0), False), np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]]])) g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), dtype=np.int32) np.testing.assert_equal(g.dtype, np.int32) - g = create_grid((2, 2, 2)) - expected = np.array( - [ - [[[-0.5, -0.5], [-0.5, -0.5]], [[0.5, 0.5], [0.5, 0.5]]], - [[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, -0.5], [0.5, 0.5]]], - [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], - [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], - ] + g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), dtype=torch.float64, backend="torch") + np.testing.assert_equal(g.dtype, torch.float64) + + test_assert( + create_grid, + ((2, 2, 2),), + np.array( + [ + [[[-0.5, -0.5], [-0.5, -0.5]], [[0.5, 0.5], [0.5, 0.5]]], + [[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, -0.5], [0.5, 0.5]]], + [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], + [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_grid((2, 2, 2), spacing=(1.2, 1.3, 1.0)) - expected = np.array( - [ - [[[-0.6, -0.6], [-0.6, -0.6]], [[0.6, 0.6], [0.6, 0.6]]], - [[[-0.65, -0.65], [0.65, 0.65]], [[-0.65, -0.65], [0.65, 0.65]]], - [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], - [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], - ] + test_assert( + create_grid, + ((2, 2, 2), (1.2, 1.3, 1.0)), + np.array( + [ + [[[-0.6, -0.6], [-0.6, -0.6]], [[0.6, 0.6], [0.6, 0.6]]], + [[[-0.65, -0.65], [0.65, 0.65]], [[-0.65, -0.65], [0.65, 0.65]]], + [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], + [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + ] + ), ) - np.testing.assert_allclose(g, expected) def test_create_control_grid(self): with self.assertRaisesRegex(TypeError, ""): @@ -83,72 +82,87 @@ def test_create_control_grid(self): with self.assertRaisesRegex(TypeError, ""): create_control_grid((1, 1), 2.0) - g = create_control_grid((1.0, 1.0), (1.0, 1.0)) - expected = np.array( - [ - [[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], - [[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((1.0, 1.0), (1.0, 1.0)), + np.array( + [ + [[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], + [[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((1.0, 1.0), (2.0, 2.0)) - expected = np.array( - [ - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((1.0, 1.0), (2.0, 2.0)), + np.array( + [ + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((2.0, 2.0), (1.0, 1.0)) - expected = np.array( - [ - [[-1.5, -1.5, -1.5, -1.5], [-0.5, -0.5, -0.5, -0.5], [0.5, 0.5, 0.5, 0.5], [1.5, 1.5, 1.5, 1.5]], - [[-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5]], - [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((2.0, 2.0), (1.0, 1.0)), + np.array( + [ + [[-1.5, -1.5, -1.5, -1.5], [-0.5, -0.5, -0.5, -0.5], [0.5, 0.5, 0.5, 0.5], [1.5, 1.5, 1.5, 1.5]], + [[-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5]], + [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((2.0, 2.0), (2.0, 2.0)) - expected = np.array( - [ - [[-3.0, -3.0, -3.0, -3.0], [-1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]], - [[-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0]], - [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((2.0, 2.0), (2.0, 2.0)), + np.array( + [ + [[-3.0, -3.0, -3.0, -3.0], [-1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]], + [[-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0]], + [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((1.0, 1.0, 1.0), (2.0, 2.0, 2.0), homogeneous=False) - expected = np.array( - [ - [ - [[-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], - ], - [ - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - ], + test_assert( + create_control_grid, + ((1.0, 1.0, 1.0), (2.0, 2.0, 2.0), False), + np.array( [ - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - ], - ] + [ + [[-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], + ], + [ + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + ], + [ + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + ], + ] + ), ) - np.testing.assert_allclose(g, expected) def test_assert(func, params, expected): - m = func(*params) - np.testing.assert_allclose(m, expected, atol=1e-7) + gpu_test = ("torch_gpu",) if torch.cuda.is_available() else () + for b in ("torch", "numpy") + gpu_test: + if b == "torch_gpu": + m = func(*params, device="cuda:0", backend="torch") + else: + m = func(*params, backend=b) + assert_allclose(m, expected, type_test=False, rtol=1e-2 if is_tf32_env() else 1e-5, atol=1e-5) class TestCreateAffine(unittest.TestCase): diff --git a/tests/test_crf_cpu.py b/tests/test_crf_cpu.py index ed1860943f..46da3298bc 100644 --- a/tests/test_crf_cpu.py +++ b/tests/test_crf_cpu.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -55,12 +55,12 @@ # Batch 0 [ # Channel 0 - [1, 1, 1, 0.5, 0], + [1, 1, 1, 0.5, 0] ], # Batch 1 [ # Channel 0 - [1, 1, 0.5, 0, 0], + [1, 1, 0.5, 0, 0] ], ], # Expected @@ -117,12 +117,12 @@ # Batch 0 [ # Channel 0 - [1, 1, 1, 0.5, 0], + [1, 1, 1, 0.5, 0] ], # Batch 1 [ # Channel 0 - [1, 1, 0.5, 0, 0], + [1, 1, 0.5, 0, 0] ], ], # Expected @@ -185,7 +185,7 @@ [1.0, 1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0], ], - ], + ] ], # Features [ @@ -207,7 +207,7 @@ [0.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0], ], - ], + ] ], # Expected [ @@ -237,7 +237,7 @@ [0.688815, 0.687855, 0.687076, 0.228579, 0.227552], [0.687434, 0.686453, 0.445019, 0.229047, 0.227588], ], - ], + ] ], ], [ @@ -344,7 +344,7 @@ [0.0, 0.0, 0.0, 1.0, 1.0], ], ], - ], + ] ], # Features [ @@ -392,8 +392,8 @@ [0.0, 0.0, 1.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0, 1.0], ], - ], - ], + ] + ] ], # Expected [ @@ -485,7 +485,7 @@ [0.500533, 0.500745, 0.553344, 0.771576, 0.772222], ], ], - ], + ] ], ], ] diff --git a/tests/test_crf_cuda.py b/tests/test_crf_cuda.py index adf8c440c0..ca25fe2de9 100644 --- a/tests/test_crf_cuda.py +++ b/tests/test_crf_cuda.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -55,12 +55,12 @@ # Batch 0 [ # Channel 0 - [1, 1, 1, 0.5, 0], + [1, 1, 1, 0.5, 0] ], # Batch 1 [ # Channel 0 - [1, 1, 0.5, 0, 0], + [1, 1, 0.5, 0, 0] ], ], # Expected @@ -117,12 +117,12 @@ # Batch 0 [ # Channel 0 - [1, 1, 1, 0.5, 0], + [1, 1, 1, 0.5, 0] ], # Batch 1 [ # Channel 0 - [1, 1, 0.5, 0, 0], + [1, 1, 0.5, 0, 0] ], ], # Expected @@ -185,7 +185,7 @@ [0.5, 1.0, 0.5, 0.0, 0.0], [1.0, 0.5, 0.0, 0.0, 0.0], ], - ], + ] ], # Features [ @@ -207,7 +207,7 @@ [0.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0], ], - ], + ] ], # Expected [ @@ -237,7 +237,7 @@ [0.492602, 0.609557, 0.480947, 0.161909, 0.161476], [0.610678, 0.480516, 0.352479, 0.159380, 0.158274], ], - ], + ] ], ], [ @@ -344,7 +344,7 @@ [0.0, 0.0, 0.0, 1.0, 1.0], ], ], - ], + ] ], # Features [ @@ -392,8 +392,8 @@ [0.0, 0.0, 1.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0, 1.0], ], - ], - ], + ] + ] ], # Expected [ @@ -485,7 +485,7 @@ [0.500663, 0.500887, 0.556332, 0.773597, 0.775210], ], ], - ], + ] ], ], ] diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 71e488cac8..af945673fe 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,60 +12,87 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForeground +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_2 = [ - {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[3]]]), -] - -TEST_CASE_3 = [ - {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_4 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_5 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_6 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), -] - -TEST_CASE_7 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10, "constant_values": 2}, - np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.zeros((1, 0, 0)), -] +TEST_COORDS, TESTS = [], [] + +for p in TEST_NDARRAYS: + TEST_COORDS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[3]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1], "allow_smaller": True}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1], "allow_smaller": False}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, + p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p(np.zeros((1, 0, 0), dtype=np.int64)), + ] + ) class TestCropForeground(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand(TEST_COORDS + TESTS) def test_value(self, argments, image, expected_data): result = CropForeground(**argments)(image) - np.testing.assert_allclose(result, expected_data) + torch.testing.assert_allclose(result, expected_data, rtol=1e-7, atol=0) - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TEST_COORDS) def test_return_coords(self, argments, image, _): argments["return_coords"] = True _, start_coord, end_coord = CropForeground(**argments)(image) diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index efe6b65b4b..fa69143827 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,85 +12,161 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForegroundd -from monai.utils import NumpyPadMode +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - { - "keys": ["img", "label"], - "source_key": "label", - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - "mode": "constant", - "constant_values": 2, - }, - { - "img": np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]), - "label": np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]), - }, - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] +TEST_POSITION, TESTS = [], [] +for p in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[3]]]), -] - -TEST_CASE_3 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_4 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_5 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_6 = [ - { - "keys": ["img", "seg"], - "source_key": "img", - "select_fn": lambda x: x > 0, - "channel_indices": 0, - "margin": 0, - "k_divisible": [4, 6], - "mode": ["edge", NumpyPadMode.CONSTANT], - }, - { - "img": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]]), - "seg": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]]), - }, - np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]]), -] + TEST_POSITION.append( + [ + { + "keys": ["img", "label"], + "source_key": "label", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + }, + { + "img": p( + np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]) + ), + "label": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]) + ), + }, + p(np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]])), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[3]]])), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]])), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]])), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 1], + "allow_smaller": True, + }, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 1], + "allow_smaller": False, + }, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p( + np.array( + [ + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 1, 2, 1, 0], + [0, 2, 3, 2, 0], + [0, 1, 2, 1, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 0, + "channel_indices": 0, + "margin": 0, + "k_divisible": [4, 6], + "mode": "edge", + }, + { + "img": p( + np.array( + [[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]], + dtype=np.float32, + ) + ) + }, + p(np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]])), + ] + ) class TestCropForegroundd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_value(self, argments, image, expected_data): - result = CropForegroundd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data) + @parameterized.expand(TEST_POSITION + TESTS) + def test_value(self, argments, input_data, expected_data): + result = CropForegroundd(**argments)(input_data) + r, i = result["img"], input_data["img"] + self.assertEqual(type(r), type(i)) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, i.device) + assert_allclose(r, expected_data) - @parameterized.expand([TEST_CASE_1]) - def test_foreground_position(self, argments, image, _): - result = CropForegroundd(**argments)(image) + @parameterized.expand(TEST_POSITION) + def test_foreground_position(self, argments, input_data, _): + result = CropForegroundd(**argments)(input_data) np.testing.assert_allclose(result["foreground_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["foreground_end_coord"], np.array([4, 4])) argments["start_coord_key"] = "test_start_coord" argments["end_coord_key"] = "test_end_coord" - result = CropForegroundd(**argments)(image) + result = CropForegroundd(**argments)(input_data) np.testing.assert_allclose(result["test_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["test_end_coord"], np.array([4, 4])) diff --git a/tests/test_cross_validation.py b/tests/test_cross_validation.py index 33e10a6a40..c378a52f78 100644 --- a/tests/test_cross_validation.py +++ b/tests/test_cross_validation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,18 +11,18 @@ import os import unittest -from urllib.error import ContentTooShortError, HTTPError from monai.apps import CrossValidation, DecathlonDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord -from tests.utils import skip_if_quick +from monai.utils.enums import PostFix +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestCrossValidation(unittest.TestCase): @skip_if_quick def test_values(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") - transform = Compose( + train_transform = Compose( [ LoadImaged(keys=["image", "label"]), AddChanneld(keys=["image", "label"]), @@ -30,12 +30,13 @@ def test_values(self): ToTensord(keys=["image", "label"]), ] ) + val_transform = LoadImaged(keys=["image", "label"]) def _test_dataset(dataset): self.assertEqual(len(dataset), 52) self.assertTrue("image" in dataset[0]) self.assertTrue("label" in dataset[0]) - self.assertTrue("image_meta_dict" in dataset[0]) + self.assertTrue(PostFix.meta("image") in dataset[0]) self.assertTupleEqual(dataset[0]["image"].shape, (1, 34, 49, 41)) cvdataset = CrossValidation( @@ -45,28 +46,23 @@ def _test_dataset(dataset): root_dir=testing_dir, task="Task04_Hippocampus", section="validation", - transform=transform, + transform=train_transform, download=True, ) - try: # will start downloading if testing_dir doesn't have the Decathlon files + with skip_if_downloading_fails(): data = cvdataset.get_dataset(folds=0) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) - # test training data for fold 0 of 5 splits + # test training data for fold [1, 2, 3, 4] of 5 splits data = cvdataset.get_dataset(folds=[1, 2, 3, 4]) self.assertTupleEqual(data[0]["image"].shape, (1, 35, 52, 33)) self.assertEqual(len(data), 208) # test train / validation for fold 4 of 5 splits - data = cvdataset.get_dataset(folds=[4]) - self.assertTupleEqual(data[0]["image"].shape, (1, 38, 53, 30)) + data = cvdataset.get_dataset(folds=[4], transform=val_transform, download=False) + # val_transform doesn't add the channel dim to shape + self.assertTupleEqual(data[0]["image"].shape, (38, 53, 30)) self.assertEqual(len(data), 52) data = cvdataset.get_dataset(folds=[0, 1, 2, 3]) self.assertTupleEqual(data[0]["image"].shape, (1, 34, 49, 41)) diff --git a/tests/test_csv_dataset.py b/tests/test_csv_dataset.py index d187f4e64d..f288ac4b95 100644 --- a/tests/test_csv_dataset.py +++ b/tests/test_csv_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import unittest import numpy as np +import pandas as pd from monai.data import CSVDataset from monai.transforms import ToNumpyd @@ -57,6 +58,7 @@ def prepare_csv_file(data, filepath): filepath1 = os.path.join(tempdir, "test_data1.csv") filepath2 = os.path.join(tempdir, "test_data2.csv") filepath3 = os.path.join(tempdir, "test_data3.csv") + filepaths = [filepath1, filepath2, filepath3] prepare_csv_file(test_data1, filepath1) prepare_csv_file(test_data2, filepath2) prepare_csv_file(test_data3, filepath3) @@ -76,7 +78,7 @@ def prepare_csv_file(data, filepath): ) # test multiple CSV files, join tables with kwargs - dataset = CSVDataset([filepath1, filepath2, filepath3], on="subject_id") + dataset = CSVDataset(filepaths, on="subject_id") self.assertDictEqual( {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in dataset[3].items()}, { @@ -102,7 +104,7 @@ def prepare_csv_file(data, filepath): # test selected rows and columns dataset = CSVDataset( - filename=[filepath1, filepath2, filepath3], + src=filepaths, row_indices=[[0, 2], 3], # load row: 0, 1, 3 col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"], ) @@ -120,7 +122,7 @@ def prepare_csv_file(data, filepath): # test group columns dataset = CSVDataset( - filename=[filepath1, filepath2, filepath3], + src=filepaths, row_indices=[1, 3], # load row: 1, 3 col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"], col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]}, @@ -133,9 +135,7 @@ def prepare_csv_file(data, filepath): # test transform dataset = CSVDataset( - filename=[filepath1, filepath2, filepath3], - col_groups={"ehr": [f"ehr_{i}" for i in range(5)]}, - transform=ToNumpyd(keys="ehr"), + src=filepaths, col_groups={"ehr": [f"ehr_{i}" for i in range(5)]}, transform=ToNumpyd(keys="ehr") ) self.assertEqual(len(dataset), 5) expected = [ @@ -151,7 +151,7 @@ def prepare_csv_file(data, filepath): # test default values and dtype dataset = CSVDataset( - filename=[filepath1, filepath2, filepath3], + src=filepaths, col_names=["subject_id", "image", "ehr_1", "ehr_9", "meta_1"], col_types={"image": {"type": str, "default": "No image"}, "ehr_1": {"type": int, "default": 0}}, how="outer", # generate NaN values in this merge mode @@ -161,6 +161,29 @@ def prepare_csv_file(data, filepath): self.assertEqual(type(dataset[-1]["ehr_1"]), int) np.testing.assert_allclose(dataset[-1]["ehr_9"], 3.3537, rtol=1e-2) + # test pre-loaded DataFrame + df = pd.read_csv(filepath1) + dataset = CSVDataset(src=df) + self.assertDictEqual( + {k: round(v, 4) if not isinstance(v, str) else v for k, v in dataset[2].items()}, + { + "subject_id": "s000002", + "label": 4, + "image": "./imgs/s000002.png", + "ehr_0": 3.7725, + "ehr_1": 4.2118, + "ehr_2": 4.6353, + }, + ) + + # test pre-loaded multiple DataFrames, join tables with kwargs + dfs = [pd.read_csv(i) for i in filepaths] + dataset = CSVDataset(src=dfs, on="subject_id") + self.assertEqual(dataset[3]["subject_id"], "s000003") + self.assertEqual(dataset[3]["label"], 1) + self.assertEqual(round(dataset[3]["ehr_0"], 4), 3.3333) + self.assertEqual(dataset[3]["meta_0"], False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_csv_iterable_dataset.py b/tests/test_csv_iterable_dataset.py index c7a3f31dc6..d6b84074ba 100644 --- a/tests/test_csv_iterable_dataset.py +++ b/tests/test_csv_iterable_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ import unittest import numpy as np +import pandas as pd from monai.data import CSVIterableDataset, DataLoader from monai.transforms import ToNumpyd @@ -58,14 +59,17 @@ def prepare_csv_file(data, filepath): filepath1 = os.path.join(tempdir, "test_data1.csv") filepath2 = os.path.join(tempdir, "test_data2.csv") filepath3 = os.path.join(tempdir, "test_data3.csv") + filepaths = [filepath1, filepath2, filepath3] prepare_csv_file(test_data1, filepath1) prepare_csv_file(test_data2, filepath2) prepare_csv_file(test_data3, filepath3) # test single CSV file - dataset = CSVIterableDataset(filepath1) - for i, item in enumerate(dataset): - if i == 2: + dataset = CSVIterableDataset(filepath1, shuffle=False) + count = 0 + for item in dataset: + count += 1 + if count == 3: self.assertDictEqual( {k: round(v, 4) if not isinstance(v, str) else v for k, v in item.items()}, { @@ -78,16 +82,25 @@ def prepare_csv_file(data, filepath): }, ) break + self.assertEqual(count, 3) + dataset.close() + # test reset iterables - dataset.reset(filename=filepath3) + dataset.reset(src=filepath3) + count = 0 for i, item in enumerate(dataset): - if i == 3: + count += 1 + if i == 4: self.assertEqual(item["meta_0"], False) + self.assertEqual(count, 5) + dataset.close() # test multiple CSV files, join tables with kwargs - dataset = CSVIterableDataset([filepath1, filepath2, filepath3], on="subject_id") - for i, item in enumerate(dataset): - if i == 3: + dataset = CSVIterableDataset(filepaths, on="subject_id", shuffle=False) + count = 0 + for item in dataset: + count += 1 + if count == 4: self.assertDictEqual( {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in item.items()}, { @@ -110,15 +123,17 @@ def prepare_csv_file(data, filepath): "meta_2": True, }, ) + self.assertEqual(count, 5) + dataset.close() # test selected columns and chunk size dataset = CSVIterableDataset( - filename=[filepath1, filepath2, filepath3], - chunksize=2, - col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"], + src=filepaths, chunksize=2, col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"], shuffle=False ) - for i, item in enumerate(dataset): - if i == 3: + count = 0 + for item in dataset: + count += 1 + if count == 4: self.assertDictEqual( {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in item.items()}, { @@ -129,49 +144,106 @@ def prepare_csv_file(data, filepath): "meta_1": False, }, ) + self.assertEqual(count, 5) + dataset.close() # test group columns dataset = CSVIterableDataset( - filename=[filepath1, filepath2, filepath3], + src=filepaths, col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"], col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]}, + shuffle=False, ) - for i, item in enumerate(dataset): - if i == 3: + count = 0 + for item in dataset: + count += 1 + if count == 4: np.testing.assert_allclose( [round(i, 4) for i in item["ehr"]], [3.3333, 3.2353, 3.4000, 3.1647, 3.0863, 3.7255, 3.6980, 3.6980, 3.7020, 3.3098, 3.7294], ) np.testing.assert_allclose(item["meta12"], [False, True]) + self.assertEqual(count, 5) + dataset.close() # test transform dataset = CSVIterableDataset( - filename=[filepath1, filepath2, filepath3], + chunksize=2, + buffer_size=4, + src=filepaths, col_groups={"ehr": [f"ehr_{i}" for i in range(5)]}, transform=ToNumpyd(keys="ehr"), + shuffle=True, + seed=123, ) expected = [ - [2.0078, 2.2902, 2.0549, 3.0196, 3.8078], [6.8392, 6.4745, 5.8627, 5.1922, 5.2745], - [3.7725, 4.2118, 4.6353, 5.2980, 9.5451], [3.3333, 3.2353, 3.4000, 3.1647, 3.0863], + [3.7725, 4.2118, 4.6353, 5.298, 9.5451], [6.4275, 6.2549, 5.9765, 6.2627, 7.7176], + [2.0078, 2.2902, 2.0549, 3.0196, 3.8078], ] + count = 0 for item, exp in zip(dataset, expected): + count += 1 self.assertTrue(isinstance(item["ehr"], np.ndarray)) np.testing.assert_allclose(np.around(item["ehr"], 4), exp) + self.assertEqual(count, 5) + dataset.close() # test multiple processes loading - dataset = CSVIterableDataset(filepath1, transform=ToNumpyd(keys="label")) + dataset = CSVIterableDataset(filepath1, transform=ToNumpyd(keys="label"), shuffle=False) # set num workers = 0 for mac / win num_workers = 2 if sys.platform == "linux" else 0 dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=2) + count = 0 for item in dataloader: + count += 1 # test the last item which only has 1 data if len(item) == 1: self.assertListEqual(item["subject_id"], ["s000002"]) np.testing.assert_allclose(item["label"], [4]) self.assertListEqual(item["image"], ["./imgs/s000002.png"]) + self.assertEqual(count, 3) + dataset.close() + + # test iterable stream + iters = pd.read_csv(filepath1, chunksize=1000) + dataset = CSVIterableDataset(src=iters, shuffle=False) + count = 0 + for item in dataset: + count += 1 + if count == 3: + self.assertDictEqual( + {k: round(v, 4) if not isinstance(v, str) else v for k, v in item.items()}, + { + "subject_id": "s000002", + "label": 4, + "image": "./imgs/s000002.png", + "ehr_0": 3.7725, + "ehr_1": 4.2118, + "ehr_2": 4.6353, + }, + ) + break + self.assertEqual(count, 3) + dataset.close() + + # test multiple iterable streams, join tables with kwargs + iters = [pd.read_csv(i, chunksize=1000) for i in filepaths] + dataset = CSVIterableDataset(src=iters, on="subject_id", shuffle=False) + count = 0 + for item in dataset: + count += 1 + if count == 4: + self.assertEqual(item["subject_id"], "s000003") + self.assertEqual(item["label"], 1) + self.assertEqual(round(item["ehr_0"], 4), 3.3333) + self.assertEqual(item["meta_0"], False) + self.assertEqual(count, 5) + # manually close the pre-loaded iterables instead of `dataset.close()` + for i in iters: + i.close() if __name__ == "__main__": diff --git a/tests/test_csv_saver.py b/tests/test_csv_saver.py index 6dd0159322..01796da00c 100644 --- a/tests/test_csv_saver.py +++ b/tests/test_csv_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,14 +23,14 @@ class TestCSVSaver(unittest.TestCase): def test_saved_content(self): with tempfile.TemporaryDirectory() as tempdir: - saver = CSVSaver(output_dir=tempdir, filename="predictions.csv") + saver = CSVSaver(output_dir=tempdir, filename="predictions.csv", delimiter="\t") meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]} saver.save_batch(torch.zeros(8), meta_data) saver.finalize() filepath = os.path.join(tempdir, "predictions.csv") self.assertTrue(os.path.exists(filepath)) - with open(filepath, "r") as f: - reader = csv.reader(f) + with open(filepath) as f: + reader = csv.reader(f, delimiter="\t") i = 0 for row in reader: self.assertEqual(row[0], "testfile" + str(i)) diff --git a/tests/test_cucim_dict_transform.py b/tests/test_cucim_dict_transform.py new file mode 100644 index 0000000000..f8b54c3147 --- /dev/null +++ b/tests/test_cucim_dict_transform.py @@ -0,0 +1,141 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import CuCIMd +from monai.utils import optional_import, set_determinism +from tests.utils import HAS_CUPY, skip_if_no_cuda + +_, has_cut = optional_import("cucim.core.operations.expose.transform") +cp, _ = optional_import("cupy") + +set_determinism(seed=0) + +TEST_CASE_COLOR_JITTER_1 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32), + np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_COLOR_JITTER_2 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8), + np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8), +] + +TEST_CASE_FLIP_1 = [ + {"name": "image_flip", "spatial_axis": -1}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32), +] + + +TEST_CASE_ROTATE_1 = [ + {"name": "image_rotate_90", "k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_SCALE_INTENSITY_1 = [ + {"name": "scale_intensity_range", "a_min": 0.0, "a_max": 4.0, "b_min": 0.0, "b_max": 1.0, "clip": False}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32), +] + +TEST_CASE_ZOOM_1 = [ + {"name": "zoom", "zoom_factor": (0.5, 0.5)}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + + +@skip_if_no_cuda +@unittest.skipUnless(HAS_CUPY, "CuPy is required.") +@unittest.skipUnless(has_cut, "cuCIM transforms are required.") +class TestCuCIMDict(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_numpy_single(self, params, input, expected): + input = {"image": input} + output = CuCIMd(keys="image", **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_numpy_batch(self, params, input, expected): + input = {"image": input[cp.newaxis, ...]} + expected = expected[cp.newaxis, ...] + output = CuCIMd(keys="image", **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_cupy_single(self, params, input, expected): + input = {"image": cp.asarray(input)} + expected = cp.asarray(expected) + output = CuCIMd(keys="image", **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_cupy_batch(self, params, input, expected): + input = {"image": cp.asarray(input)[cp.newaxis, ...]} + expected = cp.asarray(expected)[cp.newaxis, ...] + output = CuCIMd(keys="image", **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cucim_transform.py b/tests/test_cucim_transform.py new file mode 100644 index 0000000000..2bf9791bce --- /dev/null +++ b/tests/test_cucim_transform.py @@ -0,0 +1,140 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import CuCIM +from monai.utils import optional_import, set_determinism +from tests.utils import HAS_CUPY, skip_if_no_cuda + +_, has_cut = optional_import("cucim.core.operations.expose.transform") +cp, _ = optional_import("cupy") + +set_determinism(seed=0) + +TEST_CASE_COLOR_JITTER_1 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32), + np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_COLOR_JITTER_2 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8), + np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8), +] + +TEST_CASE_FLIP_1 = [ + {"name": "image_flip", "spatial_axis": -1}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32), +] + + +TEST_CASE_ROTATE_1 = [ + {"name": "image_rotate_90", "k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_SCALE_INTENSITY_1 = [ + {"name": "scale_intensity_range", "a_min": 0.0, "a_max": 4.0, "b_min": 0.0, "b_max": 1.0, "clip": False}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32), +] + +TEST_CASE_ZOOM_1 = [ + {"name": "zoom", "zoom_factor": (0.5, 0.5)}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + + +@skip_if_no_cuda +@unittest.skipUnless(HAS_CUPY, "CuPy is required.") +@unittest.skipUnless(has_cut, "cuCIM transforms are required.") +class TestCuCIM(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_numpy_single(self, params, input, expected): + output = CuCIM(**params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_numpy_batch(self, params, input, expected): + input = input[cp.newaxis, ...] + expected = expected[cp.newaxis, ...] + output = CuCIM(**params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_cupy_single(self, params, input, expected): + input = cp.asarray(input) + expected = cp.asarray(expected) + output = CuCIM(**params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_cupy_batch(self, params, input, expected): + input = cp.asarray(input)[cp.newaxis, ...] + expected = cp.asarray(expected)[cp.newaxis, ...] + output = CuCIM(**params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cuimage_reader.py b/tests/test_cuimage_reader.py deleted file mode 100644 index 2cbfaec113..0000000000 --- a/tests/test_cuimage_reader.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import unittest -from unittest import skipUnless - -import numpy as np -from numpy.testing import assert_array_equal -from parameterized import parameterized - -from monai.apps.utils import download_url -from monai.data.image_reader import WSIReader -from monai.utils import optional_import - -_, has_cim = optional_import("cucim") -PILImage, has_pil = optional_import("PIL.Image") - -FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" -FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + os.path.basename(FILE_URL)) - -HEIGHT = 32914 -WIDTH = 46000 - -TEST_CASE_0 = [FILE_PATH, (3, HEIGHT, WIDTH)] - -TEST_CASE_1 = [ - FILE_PATH, - {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0}, - np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), -] - -TEST_CASE_2 = [ - FILE_PATH, - {"location": (0, 0), "size": (2, 1), "level": 2}, - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), -] - -TEST_CASE_3 = [ - FILE_PATH, - { - "location": (0, 0), - "size": (8, 8), - "level": 2, - "grid_shape": (2, 1), - "patch_size": 2, - }, - np.array( - [ - [[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]], - [[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]], - ] - ), -] - -TEST_CASE_4 = [ - FILE_PATH, - { - "location": (0, 0), - "size": (8, 8), - "level": 2, - "grid_shape": (2, 1), - "patch_size": 1, - }, - np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), -] - -TEST_CASE_RGB_0 = [ - np.ones((3, 2, 2), dtype=np.uint8), # CHW -] - -TEST_CASE_RGB_1 = [ - np.ones((3, 100, 100), dtype=np.uint8), # CHW -] - - -class TestCuCIMReader(unittest.TestCase): - @skipUnless(has_cim, "Requires CuCIM") - def setUp(self): - download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") - - @parameterized.expand([TEST_CASE_0]) - def test_read_whole_image(self, file_path, expected_shape): - reader = WSIReader("cuCIM") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj)[0] - self.assertTupleEqual(img.shape, expected_shape) - - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_read_region(self, file_path, patch_info, expected_img): - reader = WSIReader("cuCIM") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) - - @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) - def test_read_patches(self, file_path, patch_info, expected_img): - reader = WSIReader("cuCIM") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) - - @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) - @skipUnless(has_pil, "Requires PIL") - def test_read_rgba(self, img_expected): - image = {} - reader = WSIReader("cuCIM") - for mode in ["RGB", "RGBA"]: - file_path = self.create_rgba_image(img_expected, "temp_cu_tiff_image", mode=mode) - img_obj = reader.read(file_path) - image[mode], _ = reader.get_data(img_obj) - - self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) - self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) - - def create_rgba_image(self, array: np.ndarray, filename_prefix: str, mode: str): - file_path = os.path.join(os.path.dirname(__file__), "testing_data", f"{filename_prefix}_{mode}.tiff") - - if mode == "RGBA": - array = np.concatenate([array, 255 * np.ones_like(array[0])[np.newaxis]]).astype(np.uint8) - - img_rgb = array.transpose(1, 2, 0) - - image = PILImage.fromarray(img_rgb, mode=mode) - image.save(file_path) - - return file_path - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_cumulative.py b/tests/test_cumulative.py new file mode 100644 index 0000000000..12a6a5e5e7 --- /dev/null +++ b/tests/test_cumulative.py @@ -0,0 +1,59 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.metrics import Cumulative +from tests.utils import assert_allclose + + +class TestCumulative(unittest.TestCase): + def test_single(self): + c = Cumulative() + c.extend([2, 3]) + c.append(1) + assert_allclose(c.get_buffer(), torch.tensor([2, 3, 1])) + + def test_multi(self): + c = Cumulative() + c.extend([2, 3], [4, 6]) + c.append(1) + assert_allclose(c.get_buffer()[0], torch.tensor([2, 3, 1])) + assert_allclose(c.get_buffer()[1], torch.tensor([4, 6])) + + c.reset() + c.append() + c.extend() + self.assertEqual(c.get_buffer(), []) + + c.reset() + + def test_ill(self): + c = Cumulative() + with self.assertRaises(TypeError): + c.extend(None) + with self.assertRaises(TypeError): + c.extend([]) + with self.assertRaises(TypeError): + c.extend(1) + with self.assertRaises(TypeError): + c.append([]) + c.append([1, 2]) + c.get_buffer() + with self.assertRaises(TypeError): + c.append(None) + c.get_buffer() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cumulative_average.py b/tests/test_cumulative_average.py new file mode 100644 index 0000000000..4e7e4ff5d9 --- /dev/null +++ b/tests/test_cumulative_average.py @@ -0,0 +1,63 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.metrics import CumulativeAverage + +# single class value +TEST_CASE_1 = [[torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[0.3]])], torch.as_tensor([0.2])] + +# multi-class value +TEST_CASE_2 = [ + [torch.as_tensor([[0.1, 0.2]]), torch.as_tensor([[0.2, 0.3]]), torch.as_tensor([[0.3, 0.4]])], + torch.as_tensor([0.2, 0.3]), +] + +# Nan value +TEST_CASE_3 = [ + [torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[float("nan")]])], + torch.as_tensor([0.15]), +] + +# different input shape +TEST_CASE_4 = [[torch.as_tensor(0.1), torch.as_tensor(0.2), torch.as_tensor(0.3)], torch.as_tensor(0.2)] + + +class TestCumulativeAverage(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_value(self, input_data, expected_value): + average = CumulativeAverage() + func = average.append if input_data[0].ndim < 2 else average.extend + func(input_data[0]) + func(input_data[1]) + result = average.aggregate() + # continue to update new data + func(input_data[2]) + result = average.aggregate() + torch.testing.assert_allclose(result, expected_value) + + def test_numpy_array(self): + class TestCumulativeAverage(CumulativeAverage): + def get_buffer(self): + return np.array([[1, 2], [3, np.nan]]) + + average = TestCumulativeAverage() + result = average.aggregate() + np.testing.assert_allclose(result, np.array([2.0, 2.0])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cumulative_average_dist.py b/tests/test_cumulative_average_dist.py new file mode 100644 index 0000000000..5de139e9ac --- /dev/null +++ b/tests/test_cumulative_average_dist.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +import torch.distributed as dist + +from monai.metrics import CumulativeAverage +from tests.utils import DistCall, DistTestCase + + +class DistributedCumulativeAverage(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + def test_value(self): + rank = dist.get_rank() + input_data = [ + [torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[0.3]])], + [torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[float("nan")]])], + [torch.as_tensor([[0.1, 0.2]]), torch.as_tensor([[0.2, 0.3]]), torch.as_tensor([[0.3, 0.4]])], + [torch.as_tensor(0.1), torch.as_tensor(0.2), torch.as_tensor(0.3)], + ] + expected = [torch.as_tensor([0.2]), torch.as_tensor([0.15]), torch.as_tensor([0.2, 0.3]), torch.as_tensor(0.2)] + average = CumulativeAverage() + + for i, e in zip(input_data, expected): + func = average.append if i[0].ndim < 2 else average.extend + if rank == 0: + func(i[0]) + func(i[1]) + else: + func(i[2]) + result = average.aggregate() + torch.testing.assert_allclose(result, e) + average.reset() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 50536f2a5c..c18abfcedc 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,7 @@ import logging import os +import sys import tempfile import unittest @@ -28,7 +29,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), "test data statistics:", @@ -42,7 +43,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), "test data statistics:\nType: ", @@ -56,7 +57,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), "test data statistics:\nType: \nShape: (2, 2)", @@ -70,7 +71,7 @@ "value_range": True, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)", @@ -84,7 +85,7 @@ "value_range": True, "data_value": True, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]", @@ -98,7 +99,7 @@ "value_range": True, "data_value": True, "additional_info": np.mean, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), ( @@ -115,7 +116,7 @@ "value_range": True, "data_value": True, "additional_info": lambda x: torch.mean(x.float()), - "logger_handler": None, + "name": "DataStats", }, torch.tensor([[0, 1], [1, 2]]).to("cuda" if torch.cuda.is_available() else "cpu"), ( @@ -126,7 +127,7 @@ TEST_CASE_8 = [ np.array([[0, 1], [1, 2]]), - "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\n" + "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] @@ -144,6 +145,9 @@ def test_file(self, input_data, expected_print): filename = os.path.join(tempdir, "test_data_stats.log") handler = logging.FileHandler(filename, mode="w") handler.setLevel(logging.INFO) + name = "DataStats" + logger = logging.getLogger(name) + logger.addHandler(handler) input_param = { "prefix": "test data", "data_type": True, @@ -151,17 +155,17 @@ def test_file(self, input_data, expected_print): "value_range": True, "data_value": True, "additional_info": np.mean, - "logger_handler": handler, + "name": name, } transform = DataStats(**input_param) _ = transform(input_data) - _logger = logging.getLogger(transform._logger_name) - for h in _logger.handlers[:]: + for h in logger.handlers[:]: h.close() - _logger.removeHandler(h) - with open(filename, "r") as f: + logger.removeHandler(h) + with open(filename) as f: content = f.read() - self.assertEqual(content, expected_print) + if sys.platform != "win32": + self.assertEqual(content, expected_print) if __name__ == "__main__": diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index aea0f1e721..28da936cd0 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,7 @@ import logging import os +import sys import tempfile import unittest @@ -29,7 +30,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:", @@ -44,7 +45,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: ", @@ -59,7 +60,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: \nShape: (2, 2)", @@ -74,7 +75,7 @@ "value_range": True, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)", @@ -89,7 +90,7 @@ "value_range": True, "data_value": True, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]", @@ -104,7 +105,7 @@ "value_range": True, "data_value": True, "additional_info": np.mean, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, ( @@ -122,7 +123,7 @@ "value_range": True, "data_value": True, "additional_info": lambda x: torch.mean(x.float()), - "logger_handler": None, + "name": "DataStats", }, {"img": torch.tensor([[0, 1], [1, 2]]).to("cuda" if torch.cuda.is_available() else "cpu")}, ( @@ -140,6 +141,7 @@ "value_range": (True, False), "data_value": (False, True), "additional_info": (np.mean, None), + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]]), "affine": np.eye(2, 2)}, "affine statistics:\nType: \nShape: (2, 2)\nValue: [[1. 0.]\n [0. 1.]]", @@ -147,23 +149,14 @@ TEST_CASE_9 = [ {"img": np.array([[0, 1], [1, 2]])}, - "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\n" + "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] class TestDataStatsd(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - ] + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] ) def test_value(self, input_param, input_data, expected_print): transform = DataStatsd(**input_param) @@ -176,6 +169,9 @@ def test_file(self, input_data, expected_print): filename = os.path.join(tempdir, "test_stats.log") handler = logging.FileHandler(filename, mode="w") handler.setLevel(logging.INFO) + name = "DataStats" + logger = logging.getLogger(name) + logger.addHandler(handler) input_param = { "keys": "img", "prefix": "test data", @@ -183,18 +179,18 @@ def test_file(self, input_data, expected_print): "value_range": True, "data_value": True, "additional_info": np.mean, - "logger_handler": handler, + "name": name, } transform = DataStatsd(**input_param) _ = transform(input_data) - _logger = logging.getLogger(transform.printer._logger_name) - for h in _logger.handlers[:]: + for h in logger.handlers[:]: h.close() - _logger.removeHandler(h) + logger.removeHandler(h) del handler - with open(filename, "r") as f: + with open(filename) as f: content = f.read() - self.assertEqual(content, expected_print) + if sys.platform != "win32": + self.assertEqual(content, expected_print) if __name__ == "__main__": diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 3b159fb5b8..79126b2dbb 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,19 +20,9 @@ from monai.transforms import Compose, DataStatsd, Randomizable, SimulateDelayd from monai.utils import set_determinism -TEST_CASE_1 = [ - [ - {"image": np.asarray([1, 2, 3])}, - {"image": np.asarray([4, 5])}, - ] -] - -TEST_CASE_2 = [ - [ - {"label": torch.as_tensor([[3], [2]])}, - {"label": np.asarray([[1], [2]])}, - ] -] +TEST_CASE_1 = [[{"image": np.asarray([1, 2, 3])}, {"image": np.asarray([4, 5])}]] + +TEST_CASE_2 = [[{"label": torch.as_tensor([[3], [2]])}, {"label": np.asarray([[1], [2]])}]] class TestDataLoader(unittest.TestCase): diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 491b777550..f8d4ed2104 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_dataset_func.py b/tests/test_dataset_func.py new file mode 100644 index 0000000000..b5871d7de1 --- /dev/null +++ b/tests/test_dataset_func.py @@ -0,0 +1,52 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile +import unittest + +from monai.data import Dataset, DatasetFunc, load_decathlon_datalist, partition_dataset + + +class TestDatasetFunc(unittest.TestCase): + def test_seg_values(self): + with tempfile.TemporaryDirectory() as tempdir: + # prepare test datalist file + test_data = { + "name": "Spleen", + "description": "Spleen Segmentation", + "labels": {"0": "background", "1": "spleen"}, + "training": [ + {"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz"}, + {"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz"}, + ], + "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"], + } + json_str = json.dumps(test_data) + file_path = os.path.join(tempdir, "test_data.json") + with open(file_path, "w") as json_file: + json_file.write(json_str) + + data_list = DatasetFunc( + data=file_path, func=load_decathlon_datalist, data_list_key="training", base_dir=tempdir + ) + # partition dataset for train / validation + data_partition = DatasetFunc( + data=data_list, func=lambda x, **kwargs: partition_dataset(x, **kwargs)[0], num_partitions=2 + ) + dataset = Dataset(data=data_partition, transform=None) + self.assertEqual(dataset[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) + self.assertEqual(dataset[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py index 5307bc7e66..51840f77ea 100644 --- a/tests/test_dataset_summary.py +++ b/tests/test_dataset_summary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,6 +20,16 @@ from monai.data import Dataset, DatasetSummary, create_test_image_3d from monai.transforms import LoadImaged from monai.utils import set_determinism +from monai.utils.enums import PostFix + + +def test_collate(batch): + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, np.ndarray): + return np.stack(batch, 0) + elif isinstance(elem, dict): + return elem_type({key: test_collate([d[key] for d in batch]) for key in elem}) class TestDatasetSummary(unittest.TestCase): @@ -40,9 +50,12 @@ def test_spacing_intensity(self): {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ] - dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"])) + dataset = Dataset( + data=data_dicts, transform=LoadImaged(keys=["image", "label"], meta_keys=["test1", "test2"]) + ) - calculator = DatasetSummary(dataset, num_workers=4) + # test **kwargs of `DatasetSummary` for `DataLoader` + calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1", collate_fn=test_collate) target_spacing = calculator.get_target_spacing() self.assertEqual(target_spacing, (1.0, 1.0, 1.0)) @@ -56,13 +69,7 @@ def test_spacing_intensity(self): def test_anisotropic_spacing(self): with tempfile.TemporaryDirectory() as tempdir: - pixdims = [ - [1.0, 1.0, 5.0], - [1.0, 1.0, 4.0], - [1.0, 1.0, 4.5], - [1.0, 1.0, 2.0], - [1.0, 1.0, 1.0], - ] + pixdims = [[1.0, 1.0, 5.0], [1.0, 1.0, 4.0], [1.0, 1.0, 4.5], [1.0, 1.0, 2.0], [1.0, 1.0, 1.0]] for i in range(5): im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0) n = nib.Nifti1Image(im, np.eye(4)) @@ -80,7 +87,7 @@ def test_anisotropic_spacing(self): dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"])) - calculator = DatasetSummary(dataset, num_workers=4) + calculator = DatasetSummary(dataset, num_workers=4, meta_key_postfix=PostFix.meta()) target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0) np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8)) diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index 15dbceb8ad..744dccefaa 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,11 +12,12 @@ import os import shutil import unittest -from urllib.error import ContentTooShortError, HTTPError +from pathlib import Path from monai.apps import DecathlonDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord -from tests.utils import skip_if_quick +from monai.utils.enums import PostFix +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestDecathlonDataset(unittest.TestCase): @@ -36,29 +37,26 @@ def _test_dataset(dataset): self.assertEqual(len(dataset), 52) self.assertTrue("image" in dataset[0]) self.assertTrue("label" in dataset[0]) - self.assertTrue("image_meta_dict" in dataset[0]) + self.assertTrue(PostFix.meta("image") in dataset[0]) self.assertTupleEqual(dataset[0]["image"].shape, (1, 36, 47, 44)) - try: # will start downloading if testing_dir doesn't have the Decathlon files + with skip_if_downloading_fails(): data = DecathlonDataset( root_dir=testing_dir, task="Task04_Hippocampus", transform=transform, section="validation", download=True, + copy_cache=False, ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) data = DecathlonDataset( root_dir=testing_dir, task="Task04_Hippocampus", transform=transform, section="validation", download=False ) _test_dataset(data) + self.assertTrue(data[0][PostFix.meta("image")]["filename_or_obj"].endswith("hippocampus_163.nii.gz")) + self.assertTrue(data[0][PostFix.meta("label")]["filename_or_obj"].endswith("hippocampus_163.nii.gz")) # test validation without transforms data = DecathlonDataset(root_dir=testing_dir, task="Task04_Hippocampus", section="validation", download=False) self.assertTupleEqual(data[0]["image"].shape, (36, 47, 44)) @@ -69,17 +67,14 @@ def _test_dataset(dataset): # test dataset properties data = DecathlonDataset( - root_dir=testing_dir, - task="Task04_Hippocampus", - section="validation", - download=False, + root_dir=Path(testing_dir), task="Task04_Hippocampus", section="validation", download=False ) properties = data.get_properties(keys="labels") self.assertDictEqual(properties["labels"], {"0": "background", "1": "Anterior", "2": "Posterior"}) shutil.rmtree(os.path.join(testing_dir, "Task04_Hippocampus")) try: - data = DecathlonDataset( + DecathlonDataset( root_dir=testing_dir, task="Task04_Hippocampus", transform=transform, diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 521d263663..adeaa73337 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,7 +38,7 @@ from monai.transforms.inverse_batch_transform import Decollated from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d from monai.utils import optional_import, set_determinism -from monai.utils.enums import InverseKeys +from monai.utils.enums import PostFix, TraceKeys from tests.utils import make_nifti_image _, has_nib = optional_import("nibabel") @@ -75,6 +75,7 @@ [[None, None], [None, None]], [["test"], ["test"]], [[], []], + [[("ch1", "ch2"), ("ch3",)], [["ch1", "ch3"], ["ch2", None]]], # default pad None ] @@ -105,7 +106,7 @@ def check_match(self, in1, in2): k1, k2 = k1.value, k2.value self.check_match(k1, k2) # Transform ids won't match for windows with multiprocessing, so don't check values - if k1 == InverseKeys.ID and sys.platform in ["darwin", "win32"]: + if k1 == TraceKeys.ID and sys.platform in ["darwin", "win32"]: continue self.check_match(v1, v2) elif isinstance(in1, (list, tuple)): @@ -120,7 +121,7 @@ def check_match(self, in1, in2): def check_decollate(self, dataset): batch_size = 2 - num_workers = 2 + num_workers = 2 if sys.platform == "linux" else 0 loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) @@ -170,13 +171,10 @@ def test_decollation_examples(self, input_val, expected_out): self.assertListEqual(expected_out, out) def test_dict_examples(self): - test_case = { - "meta": {"out": ["test", "test"]}, - "image_meta_dict": {"scl_slope": torch.Tensor((0.0, 0.0))}, - } + test_case = {"meta": {"out": ["test", "test"]}, PostFix.meta("image"): {"scl_slope": torch.Tensor((0.0, 0.0))}} out = decollate_batch(test_case) self.assertEqual(out[0]["meta"]["out"], "test") - self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) + self.assertEqual(out[0][PostFix.meta("image")]["scl_slope"], 0.0) test_case = [torch.ones((2, 1, 10, 10)), torch.ones((2, 3, 5, 5))] out = decollate_batch(test_case) @@ -210,38 +208,52 @@ def test_dict_examples(self): out = decollate_batch(test_case, detach=False) self.assertEqual(out[0]["out"], "test") + test_case = { + "image": torch.tensor([[[1, 2, 3]], [[3, 4, 5]]]), + "label": torch.tensor([[[5]], [[7]]]), + "out": ["test"], + } + out = decollate_batch(test_case, detach=False, pad=False) + self.assertEqual(len(out), 1) # no padding + out = decollate_batch(test_case, detach=False, pad=True, fill_value=0) + self.assertEqual(out[1]["out"], 0) # verify padding fill_value + def test_decollated(self): test_case = { "image": torch.tensor([[[1, 2]], [[3, 4]]]), "meta": {"out": ["test", "test"]}, - "image_meta_dict": {"scl_slope": torch.Tensor((0.0, 0.0))}, + PostFix.meta("image"): {"scl_slope": torch.Tensor((0.0, 0.0))}, "loss": 0.85, } - transform = Decollated(keys=["meta", "image_meta_dict"], detach=False) + transform = Decollated(keys=["meta", PostFix.meta("image")], detach=False) out = transform(test_case) self.assertFalse("loss" in out) self.assertEqual(out[0]["meta"]["out"], "test") - self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) - self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], torch.Tensor)) + self.assertEqual(out[0][PostFix.meta("image")]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0][PostFix.meta("image")]["scl_slope"], torch.Tensor)) # decollate all data with keys=None transform = Decollated(keys=None, detach=True) out = transform(test_case) self.assertEqual(out[1]["loss"], 0.85) self.assertEqual(out[0]["meta"]["out"], "test") - self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) - self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], float)) + self.assertEqual(out[0][PostFix.meta("image")]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0][PostFix.meta("image")]["scl_slope"], float)) # test list input test_case = [ torch.tensor([[[1, 2]], [[3, 4]]]), {"out": ["test", "test"]}, {"scl_slope": torch.Tensor((0.0, 0.0))}, + {"out2": ["test1"]}, 0.85, + [], ] - transform = Decollated(keys=None, detach=False) + transform = Decollated(keys=None, detach=False, fill_value=-1) out = transform(test_case) - # the 4th item in the list is scalar loss value - self.assertEqual(out[1][3], 0.85) + + self.assertEqual(out[0][-2], 0.85) # scalar replicates + self.assertEqual(out[1][-2], 0.85) # scalar replicates + self.assertEqual(out[1][-3], -1) # fill value for the dictionary item self.assertEqual(out[0][1]["out"], "test") self.assertEqual(out[0][2]["scl_slope"], 0.0) self.assertTrue(isinstance(out[0][2]["scl_slope"], torch.Tensor)) diff --git a/tests/test_deepedit_transforms.py b/tests/test_deepedit_transforms.py index c2b11e8ee7..a5c5f0fe2f 100644 --- a/tests/test_deepedit_transforms.py +++ b/tests/test_deepedit_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.apps.deepedit.transforms import ClickRatioAddRandomGuidanced, DiscardAddGuidanced, ResizeGuidanceCustomd +from monai.utils.enums import PostFix IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]) LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]) @@ -22,17 +23,13 @@ DATA_1 = { "image": IMAGE, "label": LABEL, - "image_meta_dict": {"dim": IMAGE.shape}, - "label_meta_dict": {}, + PostFix.meta("image"): {"dim": IMAGE.shape}, + PostFix.meta("label"): {}, "foreground": [0, 0, 0], "background": [0, 0, 0], } -DISCARD_ADD_GUIDANCE_TEST_CASE = [ - {"image": IMAGE, "label": LABEL}, - DATA_1, - (3, 1, 5, 5), -] +DISCARD_ADD_GUIDANCE_TEST_CASE = [{"image": IMAGE, "label": LABEL}, DATA_1, (3, 1, 5, 5)] DATA_2 = { "image": IMAGE, @@ -55,7 +52,7 @@ DATA_3 = { "image": np.arange(1000).reshape((1, 5, 10, 20)), - "image_meta_dict": {"foreground_cropped_shape": (1, 10, 20, 40), "dim": [3, 512, 512, 128]}, + PostFix.meta("image"): {"foreground_cropped_shape": (1, 10, 20, 40), "dim": [3, 512, 512, 128]}, "guidance": [[[6, 10, 14], [8, 10, 14]], [[8, 10, 16]]], "foreground": [[10, 14, 6], [10, 14, 8]], "background": [[10, 16, 8]], diff --git a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py index 147d8e7099..ff8de87b81 100644 --- a/tests/test_deepgrow_dataset.py +++ b/tests/test_deepgrow_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py index 016ba17251..b040348b62 100644 --- a/tests/test_deepgrow_interaction.py +++ b/tests/test_deepgrow_interaction.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index f50e92d146..436bef0c5b 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -27,16 +27,12 @@ SpatialCropForegroundd, SpatialCropGuidanced, ) +from monai.utils.enums import PostFix IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]) LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]) -DATA_1 = { - "image": IMAGE, - "label": LABEL, - "image_meta_dict": {}, - "label_meta_dict": {}, -} +DATA_1 = {"image": IMAGE, "label": LABEL, PostFix.meta("image"): {}, PostFix.meta("label"): {}} DATA_2 = { "image": np.array( @@ -79,21 +75,21 @@ DATA_5 = { "image": np.arange(25).reshape((1, 5, 5)), - "image_meta_dict": {"spatial_shape": [5, 5, 1]}, + PostFix.meta("image"): {"spatial_shape": [5, 5, 1]}, "foreground": [[2, 2, 0]], "background": [], } DATA_6 = { "image": np.arange(25).reshape((1, 5, 5)), - "image_meta_dict": {"spatial_shape": [5, 2, 1]}, + PostFix.meta("image"): {"spatial_shape": [5, 2, 1]}, "foreground": [[2, 1, 0]], "background": [[1, 0, 0]], } DATA_7 = { "image": np.arange(500).reshape((5, 10, 10)), - "image_meta_dict": {"spatial_shape": [20, 20, 10]}, + PostFix.meta("image"): {"spatial_shape": [20, 20, 10]}, "foreground": [[10, 14, 6], [10, 14, 8]], "background": [[10, 16, 8]], "slice": 6, @@ -101,19 +97,19 @@ DATA_8 = { "image": np.arange(500).reshape((1, 5, 10, 10)), - "image_meta_dict": {"spatial_shape": [20, 20, 10]}, + PostFix.meta("image"): {"spatial_shape": [20, 20, 10]}, "guidance": [[[3, 5, 7], [4, 5, 7]], [[4, 5, 8]]], } DATA_9 = { "image": np.arange(1000).reshape((1, 5, 10, 20)), - "image_meta_dict": {"foreground_cropped_shape": (1, 10, 20, 40)}, + PostFix.meta("image"): {"foreground_cropped_shape": (1, 10, 20, 40)}, "guidance": [[[6, 10, 14], [8, 10, 14]], [[8, 10, 16]]], } DATA_10 = { "image": np.arange(9).reshape((1, 1, 3, 3)), - "image_meta_dict": { + PostFix.meta("image"): { "spatial_shape": [3, 3, 1], "foreground_start_coord": np.array([0, 0, 0]), "foreground_end_coord": np.array([1, 3, 3]), @@ -128,7 +124,7 @@ DATA_11 = { "image": np.arange(500).reshape((1, 5, 10, 10)), - "image_meta_dict": { + PostFix.meta("image"): { "spatial_shape": [20, 20, 10], "foreground_start_coord": np.array([2, 2, 2]), "foreground_end_coord": np.array([4, 4, 4]), @@ -141,23 +137,11 @@ "pred": np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]), } -DATA_12 = { - "image": np.arange(27).reshape(3, 3, 3), - "image_meta_dict": {}, - "guidance": [[0, 0, 0], [0, 1, 1], 1], -} +DATA_12 = {"image": np.arange(27).reshape(3, 3, 3), PostFix.meta("image"): {}, "guidance": [[0, 0, 0], [0, 1, 1], 1]} -FIND_SLICE_TEST_CASE_1 = [ - {"label": "label", "sids": "sids"}, - DATA_1, - [0], -] +FIND_SLICE_TEST_CASE_1 = [{"label": "label", "sids": "sids"}, DATA_1, [0]] -FIND_SLICE_TEST_CASE_2 = [ - {"label": "label", "sids": "sids"}, - DATA_2, - [0, 1], -] +FIND_SLICE_TEST_CASE_2 = [{"label": "label", "sids": "sids"}, DATA_2, [0, 1]] CROP_TEST_CASE_1 = [ { @@ -338,14 +322,10 @@ [[1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0]], [[5.0, 5.0, 6.0, 6.0], [5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0], [7.0, 7.0, 8.0, 8.0]], [[5.0, 5.0, 6.0, 6.0], [5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0], [7.0, 7.0, 8.0, 8.0]], - ], + ] ) -RESTORE_LABEL_TEST_CASE_2 = [ - {"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, - DATA_11, - RESULT, -] +RESTORE_LABEL_TEST_CASE_2 = [{"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, DATA_11, RESULT] FETCH_2D_SLICE_TEST_CASE_1 = [ {"keys": ["image"], "guidance": "guidance"}, @@ -375,14 +355,14 @@ def test_correct_shape(self, arguments, input_data, expected_shape): @parameterized.expand([CROP_TEST_CASE_1]) def test_foreground_position(self, arguments, input_data, _): result = SpatialCropForegroundd(**arguments)(input_data) - np.testing.assert_allclose(result["image_meta_dict"]["foreground_start_coord"], np.array([0, 1, 1])) - np.testing.assert_allclose(result["image_meta_dict"]["foreground_end_coord"], np.array([1, 4, 4])) + np.testing.assert_allclose(result[PostFix.meta("image")]["foreground_start_coord"], np.array([0, 1, 1])) + np.testing.assert_allclose(result[PostFix.meta("image")]["foreground_end_coord"], np.array([1, 4, 4])) arguments["start_coord_key"] = "test_start_coord" arguments["end_coord_key"] = "test_end_coord" result = SpatialCropForegroundd(**arguments)(input_data) - np.testing.assert_allclose(result["image_meta_dict"]["test_start_coord"], np.array([0, 1, 1])) - np.testing.assert_allclose(result["image_meta_dict"]["test_end_coord"], np.array([1, 4, 4])) + np.testing.assert_allclose(result[PostFix.meta("image")]["test_start_coord"], np.array([0, 1, 1])) + np.testing.assert_allclose(result[PostFix.meta("image")]["test_end_coord"], np.array([1, 4, 4])) class TestAddInitialSeedPointd(unittest.TestCase): diff --git a/tests/test_delete_itemsd.py b/tests/test_delete_itemsd.py index 7426e39ff0..450abb40db 100644 --- a/tests/test_delete_itemsd.py +++ b/tests/test_delete_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,22 +16,40 @@ from parameterized import parameterized from monai.transforms import DeleteItemsd +from monai.utils.enums import PostFix TEST_CASE_1 = [{"keys": [str(i) for i in range(30)]}, 20] +TEST_CASE_2 = [{"keys": ["image/" + str(i) for i in range(30)], "sep": "/"}, 20] + +TEST_CASE_3 = [{"keys": "meta_dict%0008\\|[0-9]", "sep": "%", "use_re": True}] + class TestDeleteItemsd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_memory(self, input_param, expected_key_size): - input_data = {} + input_data = {"image": {}} if "sep" in input_param else {} for i in range(50): - input_data[str(i)] = [time.time()] * 100000 + if "sep" in input_param: + input_data["image"][str(i)] = [time.time()] * 100000 + else: + input_data[str(i)] = [time.time()] * 100000 result = DeleteItemsd(**input_param)(input_data) - self.assertEqual(len(result.keys()), expected_key_size) + if "sep" in input_param: + self.assertEqual(len(result["image"].keys()), expected_key_size) + else: + self.assertEqual(len(result.keys()), expected_key_size) self.assertGreaterEqual( sys.getsizeof(input_data) * float(expected_key_size) / len(input_data), sys.getsizeof(result) ) + @parameterized.expand([TEST_CASE_3]) + def test_re(self, input_param): + input_data = {"image": [1, 2, 3], PostFix.meta(): {"0008|0005": 1, "0008|1050": 2, "0008test": 3}} + result = DeleteItemsd(**input_param)(input_data) + self.assertEqual(result[PostFix.meta()]["0008test"], 3) + self.assertTrue(len(result[PostFix.meta()]), 1) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_densenet.py b/tests/test_densenet.py index ba4b7afcb4..47f584297e 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 429d5ee767..545031321b 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,7 +29,7 @@ def test_warning(self): def foo2(): pass - print(foo2()) + foo2() # should not raise any warnings def test_warning_milestone(self): """Test deprecated decorator with `since` and `removed` set for a milestone version""" @@ -172,7 +172,7 @@ def test_arg_warn2(self): """Test deprecated_arg decorator with just `since` set.""" @deprecated_arg("b", since=self.prev_version, version_val=self.test_version) - def afoo2(a, **kwargs): + def afoo2(a, **kw): pass afoo2(1) # ok when no b provided @@ -222,3 +222,72 @@ def future1(): warnings.warn("fake warning", DeprecationWarning) self.assertEqual(aw.warning.args[0], "fake warning") + + def test_arg_except2_unknown(self): + """ + Test deprecated_arg decorator raises exception with `removed` set in the past. + with unknown version + """ + + @deprecated_arg("b", removed=self.prev_version, version_val="0+untagged.1.g3131155") + def afoo4(a, b=None): + pass + + self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2)) + + def test_arg_except3_unknown(self): + """ + Test deprecated_arg decorator raises exception with `removed` set in the past. + with unknown version and kwargs + """ + + @deprecated_arg("b", removed=self.prev_version, version_val="0+untagged.1.g3131155") + def afoo4(a, b=None, **kwargs): + pass + + self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2)) + self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2, c=3)) + + def test_replacement_arg(self): + """ + Test deprecated arg being replaced. + """ + + @deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version) + def afoo4(a, b=None): + return a + + self.assertEqual(afoo4(b=2), 2) + self.assertEqual(afoo4(1, b=2), 1) # new name is in use + self.assertEqual(afoo4(a=1, b=2), 1) # prefers the new arg + + def test_replacement_arg1(self): + """ + Test deprecated arg being replaced with kwargs. + """ + + @deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version) + def afoo4(a, *args, **kwargs): + return a + + self.assertEqual(afoo4(b=2), 2) + self.assertEqual(afoo4(1, b=2, c=3), 1) # new name is in use + self.assertEqual(afoo4(a=1, b=2, c=3), 1) # prefers the new arg + + def test_replacement_arg2(self): + """ + Test deprecated arg (with a default value) being replaced. + """ + + @deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version) + def afoo4(a, b=None, **kwargs): + return a, kwargs + + self.assertEqual(afoo4(b=2, c=3), (2, {"c": 3})) + self.assertEqual(afoo4(1, b=2, c=3), (1, {"c": 3})) # new name is in use + self.assertEqual(afoo4(a=1, b=2, c=3), (1, {"c": 3})) # prefers the new arg + self.assertEqual(afoo4(1, 2, c=3), (1, {"c": 3})) # prefers the new positional arg + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index ded0290de2..f1b9c7ad1a 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import DetectEnvelope @@ -70,9 +71,9 @@ TEST_CASE_2_CHAN_3D_SINE = [ {}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1) # Create 100 identical windowed sine waves as a (n_samples x 10 x 10) 3D numpy array, twice (2 channels) - np.stack([np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2)] * 2, axis=0), + torch.as_tensor(np.stack([np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2)] * 2, axis=0)), # Expected output: Set of 100 identical Hann windows in (n_samples x 10 x 10) 3D numpy array, twice (2 channels) - np.stack([np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2)] * 2, axis=0), + torch.as_tensor(np.stack([np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2)] * 2, axis=0)), 1e-4, # absolute tolerance ] diff --git a/tests/test_dev_collate.py b/tests/test_dev_collate.py new file mode 100644 index 0000000000..83dbd71d28 --- /dev/null +++ b/tests/test_dev_collate.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data.utils import dev_collate + +TEST_CASES = [ + [ + [ + {"img": 2, "meta": {"shape": [torch.tensor(1.0)]}}, + {"img": 3, "meta": {"shape": [np.asarray(1.0)]}}, + {"img": 4, "meta": {"shape": [torch.tensor(1.0)]}}, + ], + "got numpy.ndarray", + ], + [[["img", np.array([2])], ["img", np.array([3, 4])], ["img", np.array([4])]], "size"], + [[["img", [2]], ["img", [3, 4]], ["img", 4]], "type"], + [[["img", [2, 2]], ["img", [3, 4]], ["img", 4]], "type"], +] + + +class DevCollateTest(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_dev_collate(self, inputs, msg): + with self.assertLogs(level=logging.CRITICAL) as log: + dev_collate(inputs) + self.assertRegex(" ".join(log.output), f"{msg}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 66cfb36e99..ff2cd00b02 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -80,6 +80,11 @@ def test_ill_shape(self): with self.assertRaisesRegex(ValueError, ""): loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + def test_ill_reduction(self): + with self.assertRaisesRegex(ValueError, ""): + loss = DiceCELoss(reduction="none") + loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + @SkipIfBeforePyTorchVersion((1, 7, 0)) def test_script(self): loss = DiceCELoss() diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index 920994f8de..c611fe4160 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,11 +24,7 @@ def test_result_onehot_target_include_bg(self): label = torch.randint(low=0, high=2, size=size) pred = torch.randn(size) for reduction in ["sum", "mean", "none"]: - common_params = { - "include_background": True, - "to_onehot_y": False, - "reduction": reduction, - } + common_params = {"include_background": True, "to_onehot_y": False, "reduction": reduction} for focal_weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]: for lambda_focal in [0.5, 1.0, 1.5]: dice_focal = DiceFocalLoss( @@ -46,11 +42,7 @@ def test_result_no_onehot_no_bg(self): label = torch.argmax(label, dim=1, keepdim=True) pred = torch.randn(size) for reduction in ["sum", "mean", "none"]: - common_params = { - "include_background": False, - "to_onehot_y": True, - "reduction": reduction, - } + common_params = {"include_background": False, "to_onehot_y": True, "reduction": reduction} for focal_weight in [2.0, torch.tensor([1.0, 2.0]), (2.0, 1)]: for lambda_focal in [0.5, 1.0, 1.5]: dice_focal = DiceFocalLoss(focal_weight=focal_weight, lambda_focal=lambda_focal, **common_params) diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index ef0a51eb15..4e45393de6 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,10 +21,7 @@ TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) @@ -64,7 +61,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - [[0.296529, 0.415136], [0.599976, 0.428559]], + [[[0.296529], [0.415136]], [[0.599976], [0.428559]]], ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -91,26 +88,17 @@ ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "squared_pred": True}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.178337, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "jaccard": True}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.470451, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) diff --git a/tests/test_dints_cell.py b/tests/test_dints_cell.py new file mode 100644 index 0000000000..d480235b70 --- /dev/null +++ b/tests/test_dints_cell.py @@ -0,0 +1,77 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets.dints import Cell + +TEST_CASES_3D = [ + [ + {"c_prev": 8, "c": 8, "rate": 1, "arch_code_c": None}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([1, 1, 1, 1, 1]), + (2, 8, 32, 16, 8), + (2, 8, 64, 32, 16), + ], + [ + {"c_prev": 8, "c": 4, "rate": 1, "arch_code_c": [1, 1, 0, 0, 1]}, + torch.tensor([1, 1, 0, 0, 1]), + torch.tensor([1, 0.2, 1.3, 0, 1]), + (2, 8, 32, 16, 8), + (2, 4, 64, 32, 16), + ], + [ + {"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([0, 0, 0, 1, 0]), + (2, 8, 32, 16, 8), + (2, 8, 32, 16, 8), + ], + [ + {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": None}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([1, 1, 1, 1, 1]), + (2, 8, 32, 16, 8), + (2, 8, 16, 8, 4), + ], + [ + {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1]}, + torch.tensor([1, 0, 0, 0, 1]), + torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]), + (2, 8, 32, 16, 8), + (2, 8, 16, 8, 4), + ], +] + +TEST_CASES_2D = [ + [ + {"c_prev": 8, "c": 7, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1], "spatial_dims": 2}, + torch.tensor([1, 0]), + torch.tensor([0.2, 0.2]), + (2, 8, 16, 8), + (2, 7, 8, 4), + ] +] + + +class TestCell(unittest.TestCase): + @parameterized.expand(TEST_CASES_2D + TEST_CASES_3D) + def test_cell_3d(self, input_param, ops, weight, input_shape, expected_shape): + net = Cell(**input_param) + result = net(torch.randn(input_shape), weight=weight) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dints_mixop.py b/tests/test_dints_mixop.py new file mode 100644 index 0000000000..b686069173 --- /dev/null +++ b/tests/test_dints_mixop.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets.dints import Cell, MixedOp +from tests.utils import test_script_save + +TEST_CASES_3D = [ + [ + {"c": 8, "arch_code_c": None}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([1, 1, 1, 1, 1]), + (2, 8, 32, 16, 8), + (2, 8, 32, 16, 8), + ], + [ + {"c": 8, "arch_code_c": [1, 1, 0, 0, 1]}, + torch.tensor([1, 1, 0, 0, 1]), + torch.tensor([1, 0.2, 1.3, 0, 1]), + (2, 8, 64, 32, 16), + (2, 8, 64, 32, 16), + ], + [ + {"c": 8, "arch_code_c": None}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([0, 0, 0, 1, 0]), + (2, 8, 32, 16, 8), + (2, 8, 32, 16, 8), + ], + [ + {"c": 8, "arch_code_c": [1, 1, 1, 0, 1]}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([0, 0, 0, 1, 0]), + (2, 8, 32, 16, 8), + (2, 8, 32, 16, 8), + ], +] +TEST_CASES_2D = [ + [ + {"c": 32, "arch_code_c": [1, 1, 1, 0, 1]}, + torch.tensor([1, 1]), + torch.tensor([0, 0]), + (2, 32, 16, 8), + (2, 32, 16, 8), + ] +] + + +class TestMixOP(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_mixop_3d(self, input_param, ops, weight, input_shape, expected_shape): + net = MixedOp(ops=Cell.OPS3D, **input_param) + result = net(torch.randn(input_shape), weight=weight) + self.assertEqual(result.shape, expected_shape) + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CASES_2D) + def test_mixop_2d(self, input_param, ops, weight, input_shape, expected_shape): + net = MixedOp(ops=Cell.OPS2D, **input_param) + result = net(torch.randn(input_shape), weight=weight) + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CASES_3D) + def test_script(self, input_param, ops, weight, input_shape, expected_shape): + net = MixedOp(ops=Cell.OPS3D, **input_param) + test_script_save(net, torch.randn(input_shape), weight) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dints_network.py b/tests/test_dints_network.py new file mode 100644 index 0000000000..8be5eb7ccd --- /dev/null +++ b/tests/test_dints_network.py @@ -0,0 +1,165 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.nets import DiNTS, TopologyInstance, TopologySearch +from monai.networks.nets.dints import Cell +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save + +TEST_CASES_3D = [ + [ + { + "channel_mul": 0.2, + "num_blocks": 6, + "num_depths": 3, + "device": "cpu", + "use_downsample": False, + "spatial_dims": 3, + }, + { + "in_channels": 1, + "num_classes": 3, + "act_name": "RELU", + "norm_name": "INSTANCE", + "use_downsample": False, + "spatial_dims": 3, + }, + (3, 1, 32, 32, 16), + (3, 3, 32, 32, 16), + ] +] +if torch.cuda.is_available(): + TEST_CASES_3D += [ + [ + { + "channel_mul": 0.5, + "num_blocks": 7, + "num_depths": 4, + "device": "cuda", + "use_downsample": True, + "spatial_dims": 3, + }, + { + "in_channels": 2, + "num_classes": 2, + "act_name": "PRELU", + "norm_name": "BATCH", + "use_downsample": True, + "spatial_dims": 3, + }, + (3, 2, 32, 32, 16), + (3, 2, 32, 32, 16), + ] + ] +TEST_CASES_2D = [ + [ + { + "channel_mul": 1, + "num_blocks": 7, + "num_depths": 4, + "device": "cpu", + "use_downsample": True, + "spatial_dims": 2, + }, + { + "in_channels": 2, + "num_classes": 2, + "act_name": "PRELU", + "norm_name": "BATCH", + "use_downsample": True, + "spatial_dims": 2, + }, + (2, 2, 32, 16), + (2, 2, 32, 16), + ] +] +if torch.cuda.is_available(): + TEST_CASES_2D += [ + [ + { + "channel_mul": 0.5, + "num_blocks": 8, + "num_depths": 4, + "device": "cuda", + "use_downsample": False, + "spatial_dims": 2, + }, + { + "in_channels": 1, + "num_classes": 4, + "act_name": "RELU", + "norm_name": "INSTANCE", + "use_downsample": False, + "spatial_dims": 2, + }, + (2, 1, 32, 16), + (2, 4, 32, 16), + ] + ] + + +class TestDints(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D) + def test_dints_inference(self, dints_grid_params, dints_params, input_shape, expected_shape): + grid = TopologySearch(**dints_grid_params) + dints_params["dints_space"] = grid + net = DiNTS(**dints_params).to(dints_grid_params["device"]) + result = net(torch.randn(input_shape).to(dints_grid_params["device"])) + self.assertEqual(result.shape, expected_shape) + # test functions + grid.get_ram_cost_usage(in_size=input_shape, full=True) + grid.get_ram_cost_usage(in_size=input_shape, full=False) + probs_a, _ = grid.get_prob_a(child=True) + grid.get_topology_entropy(probs_a) + grid.decode() + grid.gen_mtx(depth=4) + + @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D) + def test_dints_search(self, dints_grid_params, dints_params, input_shape, expected_shape): + num_blocks = dints_grid_params["num_blocks"] + num_depths = dints_grid_params["num_depths"] + # init a Cell to obtain cell operation number + _cell = Cell(1, 1, 0, spatial_dims=dints_grid_params["spatial_dims"]) + num_cell_ops = len(_cell.OPS) + # define archtecture codes + node_a = torch.ones((num_blocks + 1, num_depths)) + arch_code_a = np.ones((num_blocks, 3 * num_depths - 2)) + arch_code_c = np.random.randint(num_cell_ops, size=(num_blocks, 3 * num_depths - 2)) + # initialize with codes + dints_grid_params["arch_code"] = [arch_code_a, arch_code_c] + grid = TopologyInstance(**dints_grid_params) + # set as deploy stage + dints_params["dints_space"] = grid + dints_params["node_a"] = node_a + net = DiNTS(**dints_params).to(dints_grid_params["device"]) + result = net(torch.randn(input_shape).to(dints_grid_params["device"])) + self.assertEqual(result.shape, expected_shape) + self.assertTrue(isinstance(net.weight_parameters(), list)) + + +@SkipIfBeforePyTorchVersion((1, 9)) +class TestDintsTS(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D) + def test_script(self, dints_grid_params, dints_params, input_shape, _): + grid = TopologyInstance(**dints_grid_params) + dints_grid_params["device"] = "cpu" + dints_params["dints_space"] = grid + net = DiNTS(**dints_params).to(dints_grid_params["device"]) + test_script_save(net, torch.randn(input_shape).to(dints_grid_params["device"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py index 52b9a10dd5..aa9b9720c4 100644 --- a/tests/test_discriminator.py +++ b/tests/test_discriminator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index ca15b4b347..f940636fa8 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,21 +22,11 @@ for p in TEST_NDARRAYS: # pad first dim to be divisible by 7, the second unchanged. - TESTS.append( - [ - {"k": (7, -1), "mode": "constant"}, - p(np.zeros((3, 8, 7))), - p(np.zeros((3, 14, 7))), - ] - ) + TESTS.append([{"k": (7, -1), "mode": "constant"}, p(np.zeros((3, 8, 7))), p(np.zeros((3, 14, 7)))]) # pad all dimensions to be divisible by 5 TESTS.append( - [ - {"k": 5, "mode": "constant", "method": "end"}, - p(np.zeros((3, 10, 5, 17))), - p(np.zeros((3, 10, 5, 20))), - ] + [{"k": 5, "mode": "constant", "method": "end"}, p(np.zeros((3, 10, 5, 17))), p(np.zeros((3, 10, 5, 20)))] ) @@ -50,11 +40,13 @@ def test_pad_shape(self, input_param, input_data, expected_val): self.assertAlmostEqual(result.shape, expected_val.shape) def test_pad_kwargs(self): - padder = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) for p in TEST_NDARRAYS: - result = padder(p(np.zeros((3, 8, 4)))) - result = result.cpu() if isinstance(result, torch.Tensor) else result - torch.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)), rtol=1e-7, atol=0) + input_data = p(np.zeros((3, 8, 4))) + if isinstance(input_data, np.ndarray): + result = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2)))(input_data) + np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)), rtol=1e-7, atol=0) + else: + result = DivisiblePad(k=5, mode="constant", value=2)(input_data).cpu() torch.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1, rtol=1e-7, atol=0) diff --git a/tests/test_divisible_padd.py b/tests/test_divisible_padd.py index c834adac6d..61fe917421 100644 --- a/tests/test_divisible_padd.py +++ b/tests/test_divisible_padd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,11 +28,7 @@ np.zeros((3, 14, 7)), ] -TEST_CASE_3 = [ - {"keys": ["img"], "k": 0, "mode": {"constant"}}, - {"img": np.zeros((3, 8))}, - np.zeros((3, 8)), -] +TEST_CASE_3 = [{"keys": ["img"], "k": 0, "mode": {"constant"}}, {"img": np.zeros((3, 8))}, np.zeros((3, 8))] class TestDivisiblePadd(unittest.TestCase): diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index b02e4ff86f..e6045cada9 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,39 +12,37 @@ import os import tempfile import unittest +from pathlib import Path from urllib.error import ContentTooShortError, HTTPError +from parameterized import parameterized + from monai.apps import download_and_extract, download_url, extractall -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick, testing_data_config class TestDownloadAndExtract(unittest.TestCase): @skip_if_quick def test_actions(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") - url = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE" - filepath = os.path.join(testing_dir, "MedNIST.tar.gz") - output_dir = testing_dir - md5_value = "0bc7306e7427e00ad1c5526a6677552d" - try: - download_and_extract(url, filepath, output_dir, md5_value) - download_and_extract(url, filepath, output_dir, md5_value) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors + config_dict = testing_data_config("images", "mednist") + url = config_dict["url"] + filepath = Path(testing_dir) / "MedNIST.tar.gz" + output_dir = Path(testing_dir) + hash_val, hash_type = config_dict["hash_val"], config_dict["hash_type"] + with skip_if_downloading_fails(): + download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) + download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) wrong_md5 = "0" - try: - download_url(url, filepath, wrong_md5) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors + with self.assertLogs(logger="monai.apps", level="ERROR"): + try: + download_url(url, filepath, wrong_md5) + except (ContentTooShortError, HTTPError, RuntimeError) as e: + if isinstance(e, RuntimeError): + # FIXME: skip MD5 check as current downloading method may fail + self.assertTrue(str(e).startswith("md5 check")) + return # skipping this test due the network connection errors try: extractall(filepath, output_dir, wrong_md5) @@ -52,29 +50,18 @@ def test_actions(self): self.assertTrue(str(e).startswith("md5 check")) @skip_if_quick - def test_default(self): + @parameterized.expand((("icon", "tar"), ("favicon", "zip"))) + def test_default(self, key, file_type): with tempfile.TemporaryDirectory() as tmp_dir: - try: - # icon.tar.gz https://drive.google.com/file/d/1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn/view?usp=sharing + with skip_if_downloading_fails(): + img_spec = testing_data_config("images", key) download_and_extract( - "https://drive.google.com/uc?id=1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn", + img_spec["url"], output_dir=tmp_dir, - hash_val="a55d11ad26ed9eb7277905d796205531", - file_type="tar", + hash_val=img_spec["hash_val"], + hash_type=img_spec["hash_type"], + file_type=file_type, ) - # favicon.ico.zip https://drive.google.com/file/d/1TqBTJap621NO9arzXRrYi04lr9NTVF8H/view?usp=sharing - download_and_extract( - "https://drive.google.com/uc?id=1TqBTJap621NO9arzXRrYi04lr9NTVF8H", - output_dir=tmp_dir, - hash_val="ac6e167ee40803577d98237f2b0241e5", - file_type="zip", - ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors if __name__ == "__main__": diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index f4ae30198f..ac2acb0845 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,11 +20,7 @@ TEST_CASES = [ [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 [{"spatial_dims": 1, "kernel_size": 4}, (16, 4, 63), (16, 8, 15)], # 4-channel 1D, batch 16 - [ # 4-channel 1D, batch 16 - {"spatial_dims": 1, "kernel_size": 4, "padding": 1}, - (16, 4, 63), - (16, 8, 16), - ], + [{"spatial_dims": 1, "kernel_size": 4, "padding": 1}, (16, 4, 63), (16, 8, 16)], # 4-channel 1D, batch 16 [ # 4-channel 3D, batch 16 {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": True}, (16, 4, 32, 24, 48), diff --git a/tests/test_dvf2ddf.py b/tests/test_dvf2ddf.py index cc3323cf13..d061cca7ff 100644 --- a/tests/test_dvf2ddf.py +++ b/tests/test_dvf2ddf.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 81ed239461..36ac9d0309 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,14 +26,14 @@ expected_shape: Sequence[Any] TEST_CASE_DYNUNET_2D = [] +out_channels = 2 +in_size = 64 +spatial_dims = 2 for kernel_size in [(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))]: for strides in [(1, 1, 1, 1), (2, 2, 2, 1)]: + expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) for in_channels in [2, 3]: for res_block in [True, False]: - out_channels = 2 - in_size = 64 - spatial_dims = 2 - expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) test_case = [ { "spatial_dims": spatial_dims, @@ -43,8 +43,10 @@ "strides": strides, "upsample_kernel_size": strides[1:], "norm_name": "batch", + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.2}), "deep_supervision": False, "res_block": res_block, + "dropout": None, }, (1, in_channels, in_size, in_size), expected_shape, @@ -52,11 +54,11 @@ TEST_CASE_DYNUNET_2D.append(test_case) TEST_CASE_DYNUNET_3D = [] # in 3d cases, also test anisotropic kernel/strides +in_channels = 1 +in_size = 64 for out_channels in [2, 3]: + expected_shape = (1, out_channels, 64, 32, 64) for res_block in [True, False]: - in_channels = 1 - in_size = 64 - expected_shape = (1, out_channels, 64, 32, 64) test_case = [ { "spatial_dims": 3, @@ -65,9 +67,11 @@ "kernel_size": (3, (1, 1, 3), 3, 3), "strides": ((1, 2, 1), 2, 2, 1), "upsample_kernel_size": (2, 2, 1), + "filters": (64, 96, 128, 192), "norm_name": ("INSTANCE", {"affine": True}), - "deep_supervision": False, + "deep_supervision": True, "res_block": res_block, + "dropout": ("alphadropout", {"p": 0.25}), }, (1, in_channels, in_size, in_size, in_size), expected_shape, diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py index 7e832f6d81..1c83552766 100644 --- a/tests/test_dynunet_block.py +++ b/tests/test_dynunet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -35,6 +35,7 @@ "out_channels": 16, "kernel_size": kernel_size, "norm_name": norm_name, + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.1}), "stride": stride, }, (1, 16, *([in_size] * spatial_dims)), @@ -49,22 +50,24 @@ for stride in [1, 2]: for norm_name in ["batch", "instance"]: for in_size in [15, 16]: - out_size = in_size * stride - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - "norm_name": norm_name, - "stride": stride, - "upsample_kernel_size": stride, - }, - (1, in_channels, *([in_size] * spatial_dims)), - (1, out_channels, *([out_size] * spatial_dims)), - (1, out_channels, *([in_size * stride] * spatial_dims)), - ] - TEST_UP_BLOCK.append(test_case) + for trans_bias in [True, False]: + out_size = in_size * stride + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "kernel_size": kernel_size, + "norm_name": norm_name, + "stride": stride, + "upsample_kernel_size": stride, + "trans_bias": trans_bias, + }, + (1, in_channels, *([in_size] * spatial_dims)), + (1, out_channels, *([out_size] * spatial_dims)), + (1, out_channels, *([in_size * stride] * spatial_dims)), + ] + TEST_UP_BLOCK.append(test_case) class TestResBasicBlock(unittest.TestCase): diff --git a/tests/test_dynunet_v1.py b/tests/test_dynunet_v1.py deleted file mode 100644 index fc216c145b..0000000000 --- a/tests/test_dynunet_v1.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from typing import Any, Sequence, Union - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets.dynunet_v1 import DynUNetV1 -from tests.utils import skip_if_quick, test_script_save - -device = "cuda" if torch.cuda.is_available() else "cpu" - -strides: Sequence[Union[Sequence[int], int]] -kernel_size: Sequence[Any] -expected_shape: Sequence[Any] - -TEST_CASE_DYNUNET_2D = [] -for kernel_size in [(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))]: - for strides in [(1, 1, 1, 1), (2, 2, 2, 1)]: - for in_channels in [2, 3]: - for res_block in [True, False]: - out_channels = 2 - in_size = 64 - spatial_dims = 2 - expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - "strides": strides, - "upsample_kernel_size": strides[1:], - "norm_name": "batch", - "deep_supervision": False, - "res_block": res_block, - }, - (1, in_channels, in_size, in_size), - expected_shape, - ] - TEST_CASE_DYNUNET_2D.append(test_case) - -TEST_CASE_DYNUNET_3D = [] # in 3d cases, also test anisotropic kernel/strides -for out_channels in [2, 3]: - for res_block in [True, False]: - in_channels = 1 - in_size = 64 - expected_shape = (1, out_channels, 64, 32, 64) - test_case = [ - { - "spatial_dims": 3, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": (3, (1, 1, 3), 3, 3), - "strides": ((1, 2, 1), 2, 2, 1), - "upsample_kernel_size": (2, 2, 1), - "norm_name": "instance", - "deep_supervision": False, - "res_block": res_block, - }, - (1, in_channels, in_size, in_size, in_size), - expected_shape, - ] - TEST_CASE_DYNUNET_3D.append(test_case) - -TEST_CASE_DEEP_SUPERVISION = [] -for spatial_dims in [2, 3]: - for res_block in [True, False]: - for deep_supr_num in [1, 2]: - for strides in [(1, 2, 1, 2, 1), (2, 2, 2, 1), (2, 1, 1, 2, 2)]: - scale = strides[0] - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": 1, - "out_channels": 2, - "kernel_size": [3] * len(strides), - "strides": strides, - "upsample_kernel_size": strides[1:], - "norm_name": "group", - "deep_supervision": True, - "deep_supr_num": deep_supr_num, - "res_block": res_block, - }, - (1, 1, *[in_size] * spatial_dims), - (1, 1 + deep_supr_num, 2, *[in_size // scale] * spatial_dims), - ] - TEST_CASE_DEEP_SUPERVISION.append(test_case) - - -@skip_if_quick -class TestDynUNet(unittest.TestCase): - @parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D) - def test_shape(self, input_param, input_shape, expected_shape): - net = DynUNetV1(**input_param).to(device) - with eval_mode(net): - result = net(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - - def test_script(self): - input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] - net = DynUNetV1(**input_param) - test_data = torch.randn(input_shape) - test_script_save(net, test_data) - - -class TestDynUNetDeepSupervision(unittest.TestCase): - @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) - def test_shape(self, input_param, input_shape, expected_shape): - net = DynUNetV1(**input_param).to(device) - with torch.no_grad(): - results = net(torch.randn(input_shape).to(device)) - self.assertEqual(results.shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py index 6befba108a..a2a5e30750 100644 --- a/tests/test_efficientnet.py +++ b/tests/test_efficientnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,7 +26,7 @@ get_efficientnet_image_size, ) from monai.utils import optional_import -from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save +from tests.utils import skip_if_downloading_fails, skip_if_quick, test_pretrained_networks, test_script_save if TYPE_CHECKING: import torchvision @@ -44,7 +44,7 @@ def get_model_names(): - return ["efficientnet-b{}".format(d) for d in range(8)] + return [f"efficientnet-b{d}" for d in range(8)] def get_expected_model_shape(model_name): @@ -107,11 +107,7 @@ def make_shape_cases( ret_tests.append( [ kwargs, - ( - batch, - in_channels, - ) - + (get_expected_model_shape(model),) * spatial_dim, + (batch, in_channels) + (get_expected_model_shape(model),) * spatial_dim, (batch, num_classes), ] ) @@ -245,7 +241,7 @@ def make_shape_cases( }, [1, 2, 224, 224], ([1, 32, 112, 112], [1, 56, 56, 56], [1, 88, 28, 28], [1, 248, 14, 14], [1, 704, 7, 7]), - ), + ) ] @@ -254,8 +250,8 @@ class TestEFFICIENTNET(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" - # initialize model - net = EfficientNetBN(**input_param).to(device) + with skip_if_downloading_fails(): + net = EfficientNetBN(**input_param).to(device) # run inference with random tensor with eval_mode(net): @@ -268,8 +264,8 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_non_default_shapes(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" - # initialize model - net = EfficientNetBN(**input_param).to(device) + with skip_if_downloading_fails(): + net = EfficientNetBN(**input_param).to(device) # override input shape with different variations num_dims = len(input_shape) - 2 @@ -382,8 +378,8 @@ class TestExtractFeatures(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shapes): device = "cuda" if torch.cuda.is_available() else "cpu" - # initialize model - net = EfficientNetBNFeatures(**input_param).to(device) + with skip_if_downloading_fails(): + net = EfficientNetBNFeatures(**input_param).to(device) # run inference with random tensor with eval_mode(net): diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py index 7f63cb6401..dab46f366f 100644 --- a/tests/test_ensemble_evaluator.py +++ b/tests/test_ensemble_evaluator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,12 +13,18 @@ import torch from ignite.engine import EventEnum, Events +from parameterized import parameterized from monai.engines import EnsembleEvaluator +TEST_CASE_1 = [["pred_0", "pred_1", "pred_2", "pred_3", "pred_4"]] + +TEST_CASE_2 = [None] + class TestEnsembleEvaluator(unittest.TestCase): - def test_content(self): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_content(self, pred_keys): device = torch.device("cpu:0") class TestDataset(torch.utils.data.Dataset): @@ -52,7 +58,7 @@ class CustomEvents(EventEnum): device=device, val_data_loader=val_loader, networks=[net0, net1, net2, net3, net4], - pred_keys=["pred0", "pred1", "pred2", "pred3", "pred4"], + pred_keys=pred_keys, event_names=["bwd_event", "opt_event", CustomEvents], event_to_attr={CustomEvents.FOO_EVENT: "foo", "opt_event": "opt"}, ) @@ -61,7 +67,7 @@ class CustomEvents(EventEnum): def run_transform(engine): for i in range(5): expected_value = engine.state.iteration + i - torch.testing.assert_allclose(engine.state.output[0][f"pred{i}"].item(), expected_value) + torch.testing.assert_allclose(engine.state.output[0][f"pred_{i}"].item(), expected_value) @val_engine.on(Events.EPOCH_COMPLETED) def trigger_custom_event(): diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index 23126d326f..dd6168ec75 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -27,11 +27,7 @@ TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], -1] -TEST_CASE_3 = [ - {"image_only": False}, - ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - None, -] +TEST_CASE_3 = [{"image_only": False}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None] TEST_CASE_4 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], None] @@ -43,11 +39,7 @@ None, ] -TEST_CASE_7 = [ - {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, - "tests/testing_data/CT_DICOM", - None, -] +TEST_CASE_7 = [{"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", None] class TestEnsureChannelFirst(unittest.TestCase): diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index b4cde02a8f..7f1a57a207 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,17 +19,14 @@ from PIL import Image from monai.transforms import EnsureChannelFirstd, LoadImaged +from monai.utils.enums import PostFix from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [{"keys": "img"}, ["test_image.nii.gz"], None] TEST_CASE_2 = [{"keys": "img"}, ["test_image.nii.gz"], -1] -TEST_CASE_3 = [ - {"keys": "img"}, - ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - None, -] +TEST_CASE_3 = [{"keys": "img"}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None] class TestEnsureChannelFirstd(unittest.TestCase): @@ -62,12 +59,14 @@ def test_load_png(self): def test_exceptions(self): with self.assertRaises(ValueError): # no meta - EnsureChannelFirstd("img")({"img": np.zeros((1, 2, 3)), "img_meta_dict": None}) + EnsureChannelFirstd("img")({"img": np.zeros((1, 2, 3)), PostFix.meta("img"): None}) with self.assertRaises(ValueError): # no meta channel - EnsureChannelFirstd("img")({"img": np.zeros((1, 2, 3)), "img_meta_dict": {"original_channel_dim": None}}) - EnsureChannelFirstd("img", strict_check=False)({"img": np.zeros((1, 2, 3)), "img_meta_dict": None}) + EnsureChannelFirstd("img")( + {"img": np.zeros((1, 2, 3)), PostFix.meta("img"): {"original_channel_dim": None}} + ) + EnsureChannelFirstd("img", strict_check=False)({"img": np.zeros((1, 2, 3)), PostFix.meta("img"): None}) EnsureChannelFirstd("img", strict_check=False)( - {"img": np.zeros((1, 2, 3)), "img_meta_dict": {"original_channel_dim": None}} + {"img": np.zeros((1, 2, 3)), PostFix.meta("img"): {"original_channel_dim": None}} ) diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 8feb96ed37..f8a6ee30ff 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,9 +25,11 @@ def test_array_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "NUMPY"): - result = EnsureType(data_type=dtype)(test_data) + result = EnsureType(dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu")(test_data) + if dtype == "NUMPY": + self.assertTrue(result.dtype == np.float32) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): @@ -36,12 +38,12 @@ def test_single_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "numpy"): - result = EnsureType(data_type=dtype)(test_data) + result = EnsureType(data_type=dtype, device="cpu")(test_data) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) if isinstance(test_data, bool): self.assertFalse(result) else: - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertEqual(result.ndim, 0) def test_string(self): @@ -57,12 +59,12 @@ def test_string(self): def test_list_tuple(self): for dtype in ("tensor", "numpy"): - result = EnsureType(data_type=dtype)([[1, 2], [3, 4]]) + result = EnsureType(data_type=dtype, wrap_sequence=False)([[1, 2], [3, 4]]) self.assertTrue(isinstance(result, list)) self.assertTrue(isinstance(result[0][1], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result[1][0], torch.as_tensor(3)) # tuple of numpy arrays - result = EnsureType(data_type=dtype)((np.array([1, 2]), np.array([3, 4]))) + result = EnsureType(data_type=dtype, wrap_sequence=False)((np.array([1, 2]), np.array([3, 4]))) self.assertTrue(isinstance(result, tuple)) self.assertTrue(isinstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result[1], torch.as_tensor([3, 4])) diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 96f482afc2..cadab9bd56 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,9 +25,13 @@ def test_array_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "NUMPY"): - result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] + result = EnsureTyped( + keys="data", data_type=dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu" + )({"data": test_data})["data"] + if dtype == "NUMPY": + self.assertTrue(result.dtype == np.float32) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): @@ -41,7 +45,7 @@ def test_single_input(self): if isinstance(test_data, bool): self.assertFalse(result) else: - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertEqual(result.ndim, 0) def test_string(self): @@ -57,12 +61,14 @@ def test_string(self): def test_list_tuple(self): for dtype in ("tensor", "numpy"): - result = EnsureTyped(keys="data", data_type=dtype)({"data": [[1, 2], [3, 4]]})["data"] + result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False)({"data": [[1, 2], [3, 4]]})["data"] self.assertTrue(isinstance(result, list)) self.assertTrue(isinstance(result[0][1], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result[1][0], torch.as_tensor(3)) # tuple of numpy arrays - result = EnsureTyped(keys="data", data_type=dtype)({"data": (np.array([1, 2]), np.array([3, 4]))})["data"] + result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False)( + {"data": (np.array([1, 2]), np.array([3, 4]))} + )["data"] self.assertTrue(isinstance(result, tuple)) self.assertTrue(isinstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result[1], torch.as_tensor([3, 4])) @@ -75,7 +81,7 @@ def test_dict(self): "extra": None, } for dtype in ("tensor", "numpy"): - result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] + result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"] self.assertTrue(isinstance(result, dict)) self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result["img"], torch.as_tensor([1.0, 2.0])) diff --git a/tests/test_enum_bound_interp.py b/tests/test_enum_bound_interp.py index f788f8ba17..7607619e7a 100644 --- a/tests/test_enum_bound_interp.py +++ b/tests/test_enum_bound_interp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_eval_mode.py b/tests/test_eval_mode.py index 45c551c209..bc9c97d238 100644 --- a/tests/test_eval_mode.py +++ b/tests/test_eval_mode.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py index bf3bd1bacc..1bb3d887a0 100644 --- a/tests/test_evenly_divisible_all_gather_dist.py +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_factorized_increase.py b/tests/test_factorized_increase.py new file mode 100644 index 0000000000..a86f5a2db9 --- /dev/null +++ b/tests/test_factorized_increase.py @@ -0,0 +1,34 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import FactorizedIncreaseBlock + +TEST_CASES_3D = [ + [{"in_channel": 32, "out_channel": 16}, (7, 32, 24, 16, 8), (7, 16, 48, 32, 16)], + [{"in_channel": 1, "out_channel": 2}, (1, 1, 1, 1, 1), (1, 2, 2, 2, 2)], +] + + +class TestFactInc(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_factorized_increase_3d(self, input_param, input_shape, expected_shape): + net = FactorizedIncreaseBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_factorized_reduce.py b/tests/test_factorized_reduce.py new file mode 100644 index 0000000000..d14418233e --- /dev/null +++ b/tests/test_factorized_reduce.py @@ -0,0 +1,34 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import FactorizedReduceBlock + +TEST_CASES_3D = [ + [{"in_channel": 32, "out_channel": 16}, (7, 32, 24, 16, 8), (7, 16, 12, 8, 4)], + [{"in_channel": 16, "out_channel": 32}, (7, 16, 22, 14, 6), (7, 32, 11, 7, 3)], +] + + +class TestFactRed(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_factorized_reduce_3d(self, input_param, input_shape, expected_shape): + net = FactorizedReduceBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_fg_bg_to_indices.py b/tests/test_fg_bg_to_indices.py index 98626c7028..03eb770d6d 100644 --- a/tests/test_fg_bg_to_indices.py +++ b/tests/test_fg_bg_to_indices.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,58 +11,70 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import FgBgToIndices +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"image_threshold": 0.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - None, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] +TESTS_CASES = [] +for p in TEST_NDARRAYS: + TESTS_CASES.append( + [ + {"image_threshold": 0.0, "output_shape": None}, + p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), + None, + p([1, 2, 3, 5, 6, 7]), + p([0, 4, 8]), + ] + ) -TEST_CASE_2 = [ - {"image_threshold": 0.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS_CASES.append( + [ + {"image_threshold": 0.0, "output_shape": None}, + p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), + p([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]), + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_3 = [ - {"image_threshold": 1.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS_CASES.append( + [ + {"image_threshold": 1.0, "output_shape": None}, + p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), + p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_4 = [ - {"image_threshold": 1.0, "output_shape": None}, - np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS_CASES.append( + [ + {"image_threshold": 1.0, "output_shape": None}, + p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), + p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_5 = [ - {"image_threshold": 1.0, "output_shape": [3, 3]}, - np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), - np.array([[0, 0], [2, 2]]), -] + TESTS_CASES.append( + [ + {"image_threshold": 1.0, "output_shape": [3, 3]}, + p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), + p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), + p([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), + p([[0, 0], [2, 2]]), + ] + ) class TestFgBgToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS_CASES) def test_type_shape(self, input_data, label, image, expected_fg, expected_bg): fg_indices, bg_indices = FgBgToIndices(**input_data)(label, image) - np.testing.assert_allclose(fg_indices, expected_fg) - np.testing.assert_allclose(bg_indices, expected_bg) + assert_allclose(fg_indices, expected_fg) + assert_allclose(bg_indices, expected_bg) if __name__ == "__main__": diff --git a/tests/test_fg_bg_to_indicesd.py b/tests/test_fg_bg_to_indicesd.py index ce6ca30f1b..3be795919f 100644 --- a/tests/test_fg_bg_to_indicesd.py +++ b/tests/test_fg_bg_to_indicesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,53 +11,66 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import FgBgToIndicesd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] +TEST_CASES = [] +for p in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, + {"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 4, 8]), + ] + ) -TEST_CASE_3 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None}, + {"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": p([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_4 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, - {"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, + {"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_5 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]}, - {"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), - np.array([[0, 0], [2, 2]]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, + {"label": p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) + + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]}, + {"label": p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, + p([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), + p([[0, 0], [2, 2]]), + ] + ) class TestFgBgToIndicesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TEST_CASES) def test_type_shape(self, input_data, data, expected_fg, expected_bg): result = FgBgToIndicesd(**input_data)(data) - np.testing.assert_allclose(result["label_fg_indices"], expected_fg) - np.testing.assert_allclose(result["label_bg_indices"], expected_bg) + assert_allclose(result["label_fg_indices"], expected_fg) + assert_allclose(result["label_bg_indices"], expected_bg) if __name__ == "__main__": diff --git a/tests/test_file_basename.py b/tests/test_file_basename.py index 77e77fabc5..dc8b1316a2 100644 --- a/tests/test_file_basename.py +++ b/tests/test_file_basename.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path from monai.data.utils import create_file_basename @@ -69,6 +70,15 @@ def test_value(self): expected = os.path.join(output_tmp, "test", "test_post_8") self.assertEqual(result, expected) + result = create_file_basename("post", Path("test.tar.gz"), Path(output_tmp), Path("foo"), True, 8) + expected = os.path.join(output_tmp, "test", "test_post_8") + self.assertEqual(result, expected) + + def test_relative_path(self): + output = create_file_basename("", "test.txt", "output", "", makedirs=False) + expected = os.path.join("output", "test", "test") + self.assertEqual(output, expected) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py index 6ea83c239b..9f9dc1fc2e 100644 --- a/tests/test_fill_holes.py +++ b/tests/test_fill_holes.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,31 +16,15 @@ from parameterized import parameterized from monai.transforms import FillHoles -from tests.utils import assert_allclose, clone +from tests.utils import TEST_NDARRAYS, assert_allclose, clone -grid_1_raw = [ - [1, 1, 1], - [1, 0, 1], - [1, 1, 1], -] +grid_1_raw = [[1, 1, 1], [1, 0, 1], [1, 1, 1]] -grid_2_raw = [ - [0, 1, 0], - [1, 0, 1], - [0, 1, 0], -] +grid_2_raw = [[0, 1, 0], [1, 0, 1], [0, 1, 0]] -grid_3_raw = [ - [1, 1, 1], - [1, 1, 1], - [1, 1, 1], -] +grid_3_raw = [[1, 1, 1], [1, 1, 1], [1, 1, 1]] -grid_4_raw = [ - [0, 1, 0], - [1, 1, 1], - [0, 1, 0], -] +grid_4_raw = [[0, 1, 0], [1, 1, 1], [0, 1, 0]] grid_1 = torch.tensor([grid_1_raw]) @@ -50,49 +34,15 @@ grid_4 = torch.tensor([grid_4_raw]) -grid_5 = torch.tensor( - [ - [ - [1, 1, 1], - [1, 0, 0], - [1, 1, 1], - ] - ] -) - -grid_6 = torch.tensor( - [ - [ - [1, 1, 2, 2, 2], - [1, 0, 2, 0, 2], - [1, 1, 2, 2, 2], - ] - ] -) - -grid_7 = torch.tensor( - [ - [ - [1, 1, 2, 2, 2], - [1, 0, 2, 2, 2], - [1, 1, 2, 2, 2], - ] - ] -) - -TEST_CASE_0 = [ - "enclosed_default_full_connectivity_default_applied_labels", - {}, - grid_1, - grid_3, -] +grid_5 = torch.tensor([[[1, 1, 1], [1, 0, 0], [1, 1, 1]]]) -TEST_CASE_1 = [ - "enclosed_full_connectivity_default_applied_labels", - {"connectivity": 2}, - grid_1, - grid_3, -] +grid_6 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 0, 2], [1, 1, 2, 2, 2]]]) + +grid_7 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 2, 2], [1, 1, 2, 2, 2]]]) + +TEST_CASE_0 = ["enclosed_default_full_connectivity_default_applied_labels", {}, grid_1, grid_3] + +TEST_CASE_1 = ["enclosed_full_connectivity_default_applied_labels", {"connectivity": 2}, grid_1, grid_3] TEST_CASE_2 = [ "enclosed_full_connectivity_applied_labels_same_single", @@ -129,40 +79,15 @@ grid_3, ] -TEST_CASE_7 = [ - "enclosed_connectivity_1_default_applied_labels", - {"connectivity": 1}, - grid_1, - grid_3, -] +TEST_CASE_7 = ["enclosed_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_1, grid_3] -TEST_CASE_8 = [ - "enclosed_connectivity_1_default_applied_labels", - {"connectivity": 1}, - grid_2, - grid_4, -] +TEST_CASE_8 = ["enclosed_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_2, grid_4] -TEST_CASE_9 = [ - "open_full_connectivity_default_applied_labels", - {"connectivity": 2}, - grid_2, - grid_2, -] +TEST_CASE_9 = ["open_full_connectivity_default_applied_labels", {"connectivity": 2}, grid_2, grid_2] -TEST_CASE_10 = [ - "open_to_edge_connectivity_1_default_applied_labels", - {"connectivity": 1}, - grid_5, - grid_5, -] +TEST_CASE_10 = ["open_to_edge_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_5, grid_5] -TEST_CASE_11 = [ - "open_to_other_label_connectivity_1_default_applied_labels", - {"connectivity": 1}, - grid_6, - grid_7, -] +TEST_CASE_11 = ["open_to_other_label_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_6, grid_7] TEST_CASE_12 = [ "open_to_other_label_connectivity_1_applied_labels_other", @@ -276,11 +201,9 @@ class TestFillHoles(unittest.TestCase): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, args, input_image, expected): converter = FillHoles(**args) - if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - result = converter(clone(input_image).cuda()) - else: - result = converter(clone(input_image)) - assert_allclose(result, expected) + for p in TEST_NDARRAYS: + result = converter(p(clone(input_image))) + assert_allclose(result, p(expected)) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_fill_holesd.py b/tests/test_fill_holesd.py new file mode 100644 index 0000000000..f7aa9f6108 --- /dev/null +++ b/tests/test_fill_holesd.py @@ -0,0 +1,222 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import FillHolesd +from monai.utils.enums import CommonKeys +from tests.utils import TEST_NDARRAYS, assert_allclose, clone + +grid_1_raw = [[1, 1, 1], [1, 0, 1], [1, 1, 1]] + +grid_2_raw = [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + +grid_3_raw = [[1, 1, 1], [1, 1, 1], [1, 1, 1]] + +grid_4_raw = [[0, 1, 0], [1, 1, 1], [0, 1, 0]] + +grid_1 = torch.tensor([grid_1_raw]) + +grid_2 = torch.tensor([grid_2_raw]) + +grid_3 = torch.tensor([grid_3_raw]) + +grid_4 = torch.tensor([grid_4_raw]) + +grid_5 = torch.tensor([[[1, 1, 1], [1, 0, 0], [1, 1, 1]]]) + +grid_6 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 0, 2], [1, 1, 2, 2, 2]]]) + +grid_7 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 2, 2], [1, 1, 2, 2, 2]]]) + +TEST_CASE_0 = ["enclosed_default_full_connectivity_default_applied_labels", {}, grid_1, grid_3] + +TEST_CASE_1 = ["enclosed_full_connectivity_default_applied_labels", {"connectivity": 2}, grid_1, grid_3] + +TEST_CASE_2 = [ + "enclosed_full_connectivity_applied_labels_same_single", + {"connectivity": 2, "applied_labels": 1}, + grid_1, + grid_3, +] + +TEST_CASE_3 = [ + "enclosed_full_connectivity_applied_labels_same_list", + {"connectivity": 2, "applied_labels": [1]}, + grid_1, + grid_3, +] + +TEST_CASE_4 = [ + "enclosed_full_connectivity_applied_labels_other_single", + {"connectivity": 2, "applied_labels": 2}, + grid_1, + grid_1, +] + +TEST_CASE_5 = [ + "enclosed_full_connectivity_applied_labels_other_list", + {"connectivity": 2, "applied_labels": [2]}, + grid_1, + grid_1, +] + +TEST_CASE_6 = [ + "enclosed_full_connectivity_applied_labels_same_and_other", + {"connectivity": 2, "applied_labels": [1, 2]}, + grid_1, + grid_3, +] + +TEST_CASE_7 = ["enclosed_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_1, grid_3] + +TEST_CASE_8 = ["enclosed_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_2, grid_4] + +TEST_CASE_9 = ["open_full_connectivity_default_applied_labels", {"connectivity": 2}, grid_2, grid_2] + +TEST_CASE_10 = ["open_to_edge_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_5, grid_5] + +TEST_CASE_11 = ["open_to_other_label_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_6, grid_7] + +TEST_CASE_12 = [ + "open_to_other_label_connectivity_1_applied_labels_other", + {"connectivity": 1, "applied_labels": 1}, + grid_6, + grid_6, +] + +TEST_CASE_13 = [ + "numpy_enclosed_default_full_connectivity_default_applied_labels", + {}, + grid_1.cpu().numpy(), + grid_3.cpu().numpy(), +] + +TEST_CASE_14 = [ + "3D_enclosed_full_connectivity_default_applied_labels", + {"connectivity": 3}, + torch.tensor([[grid_3_raw, grid_1_raw, grid_3_raw]]), + torch.tensor([[grid_3_raw, grid_3_raw, grid_3_raw]]), +] + +TEST_CASE_15 = [ + "3D_enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), + torch.tensor([[grid_4_raw, grid_4_raw, grid_4_raw]]), +] + +TEST_CASE_16 = [ + "3D_open_full_connectivity_default_applied_labels", + {"connectivity": 3}, + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), +] + +TEST_CASE_17 = [ + "3D_open_to_edge_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]), + torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]), +] + +TEST_CASE_18 = [ + "enclosed_full_connectivity_applied_labels_with_background", + {"connectivity": 2, "applied_labels": [0, 1]}, + grid_1, + grid_3, +] + +TEST_CASE_19 = [ + "enclosed_full_connectivity_applied_labels_only_background", + {"connectivity": 2, "applied_labels": [0]}, + grid_1, + grid_1, +] + +TEST_CASE_20 = [ + "one-hot_enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_3_raw, grid_4_raw]), +] + +TEST_CASE_21 = [ + "one-hot_enclosed_connectivity_1_applied_labels_2", + {"connectivity": 1, "applied_labels": [2]}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_1_raw, grid_4_raw]), +] + +TEST_CASE_22 = [ + "one-hot_full_connectivity_applied_labels_2", + {"connectivity": 2}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_3_raw, grid_2_raw]), +] + +VALID_CASES = [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + TEST_CASE_11, + TEST_CASE_12, + TEST_CASE_13, + TEST_CASE_14, + TEST_CASE_15, + TEST_CASE_16, + TEST_CASE_17, + TEST_CASE_18, + TEST_CASE_19, + TEST_CASE_20, + TEST_CASE_21, + TEST_CASE_22, +] + +ITEST_CASE_1 = ["invalid_image_data_type", {}, [[[[1, 1, 1]]]], NotImplementedError] + +INVALID_CASES = [ITEST_CASE_1] + + +class TestFillHoles(unittest.TestCase): + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, args, input_image, expected): + key = CommonKeys.IMAGE + converter = FillHolesd(keys=key, **args) + for p in TEST_NDARRAYS: + result = converter({key: p(clone(input_image))})[key] + assert_allclose(result, p(expected)) + + @parameterized.expand(INVALID_CASES) + def test_raise_exception(self, _, args, input_image, expected_error): + key = CommonKeys.IMAGE + with self.assertRaises(expected_error): + converter = FillHolesd(keys=key, **args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + _ = converter({key: clone(input_image).cuda()})[key] + else: + _ = converter({key: clone(input_image)})[key] + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_flip.py b/tests/test_flip.py index 404a3def7d..17cf0d2c39 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,12 +34,10 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: im = p(self.imt[0]) flip = Flip(spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) result = flip(im) - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_flipd.py b/tests/test_flipd.py index 1676723800..900779f4e0 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,12 +33,10 @@ def test_invalid_cases(self, _, spatial_axis, raises): def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: flip = Flipd(keys="img", spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) result = flip({"img": p(self.imt[0])})["img"] - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 1314fe3841..d8a9c8ab5b 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -44,6 +44,35 @@ def test_consistency_with_cross_entropy_2d(self): max_error = abs(a - b) self.assertAlmostEqual(max_error, 0.0, places=3) + def test_consistency_with_cross_entropy_2d_no_reduction(self): + """For gamma=0 the focal loss reduces to the cross entropy loss""" + import numpy as np + + focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="none", weight=1.0) + ce = nn.BCEWithLogitsLoss(reduction="none") + max_error = 0 + class_num = 10 + batch_size = 128 + for _ in range(100): + # Create a random tensor of shape (batch_size, class_num, 8, 4) + x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True) + # Create a random batch of classes + l = torch.randint(low=0, high=2, size=(batch_size, class_num, 8, 4)).float() + if torch.cuda.is_available(): + x = x.cuda() + l = l.cuda() + output0 = focal_loss(x, l) + output1 = ce(x, l) + a = output0.cpu().detach().numpy() + b = output1.cpu().detach().numpy() + error = np.abs(a - b) + max_error = np.maximum(error, max_error) + # if np.all(np.abs(a - b) > max_error): + # max_error = np.abs(a - b) + + assert np.allclose(max_error, 0) + # self.assertAlmostEqual(max_error, 0.0, places=3) + def test_consistency_with_cross_entropy_2d_onehot_label(self): """For gamma=0 the focal loss reduces to the cross entropy loss""" focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean") diff --git a/tests/test_folder_layout.py b/tests/test_folder_layout.py new file mode 100644 index 0000000000..f7291933a3 --- /dev/null +++ b/tests/test_folder_layout.py @@ -0,0 +1,75 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from pathlib import Path + +from parameterized import parameterized + +from monai.data.folder_layout import FolderLayout + +TEST_CASES = [ + ({"output_dir": ""}, {}, "subject"), + ({"output_dir": Path(".")}, {}, "subject"), + ({"output_dir": Path(".")}, {"idx": 1}, "subject_1"), + (dict(output_dir=Path("/test_run_1"), extension=".seg", makedirs=False), {}, "/test_run_1/subject.seg"), + (dict(output_dir=Path("/test_run_1"), extension=None, makedirs=False), {}, "/test_run_1/subject"), + ( + dict(output_dir=Path("/test_run_1"), postfix="seg", extension=".test", makedirs=False), + {}, # using the default subject name + "/test_run_1/subject_seg.test", + ), + ( + dict(output_dir=Path("/test_run_1"), postfix="seg", extension=".test", makedirs=False), + {"subject": "test.abc"}, + "/test_run_1/test_seg.test", # subject's extension is ignored + ), + ( + dict(output_dir=Path("/test_run_1/dest/test1/"), data_root_dir="/test_run", makedirs=False), + {"subject": "/test_run/source/test.abc"}, + "/test_run_1/dest/test1/source/test", # preserves the structure from `subject` + ), + ( + dict(output_dir=Path("/test_run_1/dest/test1/"), makedirs=False), + {"subject": "/test_run/source/test.abc"}, + "/test_run_1/dest/test1/test", # data_root_dir used + ), + ( + dict(output_dir=Path("/test_run_1/dest/test1/"), makedirs=False), + {"subject": "/test_run/source/test.abc", "key": "value"}, + "/test_run_1/dest/test1/test_key-value", # data_root_dir used + ), + ( + dict(output_dir=Path("/test_run_1/"), postfix="seg", extension=".nii", makedirs=False), + dict(subject=Path("Sub-A"), idx="00", modality="T1"), + "/test_run_1/Sub-A_seg_00_modality-T1.nii", # test the code example + ), +] + + +class TestFolderLayout(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_value(self, con_params, f_params, expected): + fname = FolderLayout(**con_params).filename(**f_params) + self.assertEqual(Path(fname), Path(expected)) + + def test_mkdir(self): + """mkdir=True should create the directory if it does not exist.""" + with tempfile.TemporaryDirectory() as tempdir: + output_tmp = os.path.join(tempdir, "output") + FolderLayout(output_tmp, makedirs=True).filename("subject_test", "001") + self.assertTrue(os.path.exists(os.path.join(output_tmp))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_fourier.py b/tests/test_fourier.py index 488bf0cbf9..e1b5f3089d 100644 --- a/tests/test_fourier.py +++ b/tests/test_fourier.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_fullyconnectednet.py b/tests/test_fullyconnectednet.py index ec91a99c3e..6378ec9718 100644 --- a/tests/test_fullyconnectednet.py +++ b/tests/test_fullyconnectednet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_gaussian.py b/tests/test_gaussian.py index e2659abb0c..461b11d076 100644 --- a/tests/test_gaussian.py +++ b/tests/test_gaussian.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -241,16 +241,8 @@ def test_gaussian(self): rtol=1e-4, ) - np.testing.assert_allclose( - gaussian_1d(1, 1), - torch.tensor([0.24173, 0.382925, 0.24173]), - rtol=1e-4, - ) - np.testing.assert_allclose( - gaussian_1d(1, 1, normalize=True), - torch.tensor([0.2790, 0.4420, 0.2790]), - rtol=1e-4, - ) + np.testing.assert_allclose(gaussian_1d(1, 1), torch.tensor([0.24173, 0.382925, 0.24173]), rtol=1e-4) + np.testing.assert_allclose(gaussian_1d(1, 1, normalize=True), torch.tensor([0.2790, 0.4420, 0.2790]), rtol=1e-4) def test_scalespace_gaussian(self): np.testing.assert_allclose( @@ -272,15 +264,11 @@ def test_scalespace_gaussian(self): ) np.testing.assert_allclose( - gaussian_1d(1, 1, "scalespace"), - torch.tensor([0.20791, 0.46576, 0.20791]), - rtol=1e-3, + gaussian_1d(1, 1, "scalespace"), torch.tensor([0.20791, 0.46576, 0.20791]), rtol=1e-3 ) np.testing.assert_allclose( - gaussian_1d(1, 1, "scalespace", normalize=True), - torch.tensor([0.2358, 0.5283, 0.2358]), - rtol=1e-3, + gaussian_1d(1, 1, "scalespace", normalize=True), torch.tensor([0.2358, 0.5283, 0.2358]), rtol=1e-3 ) np.testing.assert_allclose( diff --git a/tests/test_gaussian_filter.py b/tests/test_gaussian_filter.py index 7636aa5459..9d76e44cec 100644 --- a/tests/test_gaussian_filter.py +++ b/tests/test_gaussian_filter.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,10 +19,7 @@ from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick TEST_CASES = [[{"type": "erf", "gt": 2.0}], [{"type": "scalespace", "gt": 3.0}], [{"type": "sampled", "gt": 5.0}]] -TEST_CASES_GPU = [ - [{"type": "erf", "gt": 0.8, "device": "cuda"}], - [{"type": "sampled", "gt": 5.0, "device": "cuda"}], -] +TEST_CASES_GPU = [[{"type": "erf", "gt": 0.8, "device": "cuda"}], [{"type": "sampled", "gt": 5.0, "device": "cuda"}]] TEST_CASES_3d = [ [{"type": "scalespace", "gt": 0.5, "dims": (2, 3, 8, 9, 10), "lr": 0.01, "device": "cuda"}], [{"type": "erf", "gt": 3.8, "dims": (2, 3, 8, 9, 10), "lr": 0.1, "device": "cuda"}], diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py index 9d078e65e5..547febdfaf 100644 --- a/tests/test_gaussian_sharpen.py +++ b/tests/test_gaussian_sharpen.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,50 +11,79 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import GaussianSharpen +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] + +for p in TEST_NDARRAYS: + TESTS.append( [ - [[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]], - [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + {}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.1081963, 3.4950666, 4.1081963], + [3.7239995, 2.8491793, 3.7239995], + [4.569839, 3.9529324, 4.569839], + ], + [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]], - [[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]], + {"sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.513644, 4.869134, 4.513644], + [8.467242, 9.4004135, 8.467242], + [10.416813, 12.0653515, 10.416813], + ], + [ + [15.711488, 17.569994, 15.711488], + [21.16811, 23.501041, 21.16811], + [21.614658, 24.766209, 21.614658], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]], - [[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]], + {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [3.3324685, 3.335536, 3.3324673], + [7.7666636, 8.16056, 7.7666636], + [12.662973, 14.317837, 12.6629715], + ], + [ + [15.329051, 16.57557, 15.329051], + [19.41665, 20.40139, 19.416655], + [24.659554, 27.557873, 24.659554], + ], + ] + ), ] - ), -] + ) class TestGaussianSharpen(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSharpen(**argments)(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + assert_allclose(result, expected_data, atol=0, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_gaussian_sharpend.py b/tests/test_gaussian_sharpend.py index c795b11762..d9ef503532 100644 --- a/tests/test_gaussian_sharpend.py +++ b/tests/test_gaussian_sharpend.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,46 +15,75 @@ from parameterized import parameterized from monai.transforms import GaussianSharpend +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "img"}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]], - [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + {"keys": "img"}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [4.1081963, 3.4950666, 4.1081963], + [3.7239995, 2.8491793, 3.7239995], + [4.569839, 3.9529324, 4.569839], + ], + [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"keys": "img", "sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]], - [[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]], + {"keys": "img", "sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [4.513644, 4.869134, 4.513644], + [8.467242, 9.4004135, 8.467242], + [10.416813, 12.0653515, 10.416813], + ], + [ + [15.711488, 17.569994, 15.711488], + [21.16811, 23.501041, 21.16811], + [21.614658, 24.766209, 21.614658], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"keys": "img", "sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]], - [[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]], + {"keys": "img", "sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [3.3324685, 3.335536, 3.3324673], + [7.7666636, 8.16056, 7.7666636], + [12.662973, 14.317837, 12.6629715], + ], + [ + [15.329051, 16.57557, 15.329051], + [19.41665, 20.40139, 19.416655], + [24.659554, 27.557873, 24.659554], + ], + ] + ), ] - ), -] + ) class TestGaussianSharpend(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSharpend(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data, rtol=1e-4) + assert_allclose(result["img"], expected_data, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index e51977fbee..53f2fc396b 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,54 +11,83 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import GaussianSmooth +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"sigma": 1.5}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] + +for p in TEST_NDARRAYS: + TESTS.append( [ - [ - [0.59167546, 0.69312394, 0.59167546], - [0.7956997, 0.93213004, 0.7956997], - [0.7668002, 0.8982755, 0.7668002], - ], - [[1.6105323, 1.8866735, 1.6105323], [1.9892492, 2.3303251, 1.9892492], [1.7856569, 2.091825, 1.7856569]], + {"sigma": 1.5}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [0.59167546, 0.69312394, 0.59167546], + [0.7956997, 0.93213004, 0.7956997], + [0.7668002, 0.8982755, 0.7668002], + ], + [ + [1.6105323, 1.8866735, 1.6105323], + [1.9892492, 2.3303251, 1.9892492], + [1.7856569, 2.091825, 1.7856569], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"sigma": 0.5}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.8424794, 0.99864554, 0.8424794], [1.678146, 1.9892154, 1.678146], [1.9889624, 2.3576462, 1.9889624]], - [[2.966061, 3.5158648, 2.966061], [4.1953645, 4.973038, 4.1953645], [4.112544, 4.8748655, 4.1125436]], + {"sigma": 0.5}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [0.8424794, 0.99864554, 0.8424794], + [1.678146, 1.9892154, 1.678146], + [1.9889624, 2.3576462, 1.9889624], + ], + [ + [2.966061, 3.5158648, 2.966061], + [4.1953645, 4.973038, 4.1953645], + [4.112544, 4.8748655, 4.1125436], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"sigma": [1.5, 0.5]}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.8542037, 1.0125432, 0.8542037], [1.1487541, 1.3616928, 1.1487541], [1.1070318, 1.3122368, 1.1070318]], - [[2.3251305, 2.756128, 2.3251305], [2.8718853, 3.4042323, 2.8718853], [2.5779586, 3.0558217, 2.5779586]], + {"sigma": [1.5, 0.5]}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [0.8542037, 1.0125432, 0.8542037], + [1.1487541, 1.3616928, 1.1487541], + [1.1070318, 1.3122368, 1.1070318], + ], + [ + [2.3251305, 2.756128, 2.3251305], + [2.8718853, 3.4042323, 2.8718853], + [2.5779586, 3.0558217, 2.5779586], + ], + ] + ), ] - ), -] + ) class TestGaussianSmooth(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSmooth(**argments)(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + assert_allclose(result, expected_data, atol=0, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_gaussian_smoothd.py b/tests/test_gaussian_smoothd.py index 3d7eb6195e..839bac81fe 100644 --- a/tests/test_gaussian_smoothd.py +++ b/tests/test_gaussian_smoothd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,50 +15,79 @@ from parameterized import parameterized from monai.transforms import GaussianSmoothd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "img", "sigma": 1.5}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [ - [0.59167546, 0.69312394, 0.59167546], - [0.7956997, 0.93213004, 0.7956997], - [0.7668002, 0.8982755, 0.7668002], - ], - [[1.6105323, 1.8866735, 1.6105323], [1.9892492, 2.3303251, 1.9892492], [1.7856569, 2.091825, 1.7856569]], + {"keys": "img", "sigma": 1.5}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.59167546, 0.69312394, 0.59167546], + [0.7956997, 0.93213004, 0.7956997], + [0.7668002, 0.8982755, 0.7668002], + ], + [ + [1.6105323, 1.8866735, 1.6105323], + [1.9892492, 2.3303251, 1.9892492], + [1.7856569, 2.091825, 1.7856569], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"keys": "img", "sigma": 0.5}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[0.8424794, 0.99864554, 0.8424794], [1.678146, 1.9892154, 1.678146], [1.9889624, 2.3576462, 1.9889624]], - [[2.966061, 3.5158648, 2.966061], [4.1953645, 4.973038, 4.1953645], [4.112544, 4.8748655, 4.1125436]], + {"keys": "img", "sigma": 0.5}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.8424794, 0.99864554, 0.8424794], + [1.678146, 1.9892154, 1.678146], + [1.9889624, 2.3576462, 1.9889624], + ], + [ + [2.966061, 3.5158648, 2.966061], + [4.1953645, 4.973038, 4.1953645], + [4.112544, 4.8748655, 4.1125436], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"keys": "img", "sigma": [1.5, 0.5]}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[0.8542037, 1.0125432, 0.8542037], [1.1487541, 1.3616928, 1.1487541], [1.1070318, 1.3122368, 1.1070318]], - [[2.3251305, 2.756128, 2.3251305], [2.8718853, 3.4042323, 2.8718853], [2.5779586, 3.0558217, 2.5779586]], + {"keys": "img", "sigma": [1.5, 0.5]}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.8542037, 1.0125432, 0.8542037], + [1.1487541, 1.3616928, 1.1487541], + [1.1070318, 1.3122368, 1.1070318], + ], + [ + [2.3251305, 2.756128, 2.3251305], + [2.8718853, 3.4042323, 2.8718853], + [2.5779586, 3.0558217, 2.5779586], + ], + ] + ), ] - ), -] + ) class TestGaussianSmoothd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSmoothd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data, rtol=1e-4) + assert_allclose(result["img"], expected_data, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 06446204fb..fa301201e4 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,10 +21,7 @@ TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) @@ -87,7 +84,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - [0.273476, 0.555539], + [[[0.273476]], [[0.555539]]], ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8}, @@ -99,10 +96,7 @@ ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (1, 2, 4), (1, 1, 4) @@ -179,6 +173,26 @@ def test_input_warnings(self): loss = GeneralizedDiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) + def test_differentiability(self): + prediction = torch.ones((1, 1, 1, 3)) + target = torch.ones((1, 1, 1, 3)) + prediction.requires_grad = True + target.requires_grad = True + + generalized_dice_loss = GeneralizedDiceLoss() + loss = generalized_dice_loss(prediction, target) + self.assertNotEqual(loss.grad_fn, None) + + def test_batch(self): + prediction = torch.zeros(2, 3, 3, 3) + target = torch.zeros(2, 3, 3, 3) + prediction.requires_grad = True + target.requires_grad = True + + generalized_dice_loss = GeneralizedDiceLoss(batch=True) + loss = generalized_dice_loss(prediction, target) + self.assertNotEqual(loss.grad_fn, None) + @SkipIfBeforePyTorchVersion((1, 7, 0)) def test_script(self): loss = GeneralizedDiceLoss() diff --git a/tests/test_generalized_wasserstein_dice_loss.py b/tests/test_generalized_wasserstein_dice_loss.py index 295a4a6d70..2c33d365f4 100644 --- a/tests/test_generalized_wasserstein_dice_loss.py +++ b/tests/test_generalized_wasserstein_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -159,7 +159,7 @@ def test_convergence(self): # define a model with one layer class OnelayerNet(nn.Module): def __init__(self): - super(OnelayerNet, self).__init__() + super().__init__() self.layer = nn.Linear(num_voxels, num_voxels * num_classes) def forward(self, x): diff --git a/tests/test_generate_label_classes_crop_centers.py b/tests/test_generate_label_classes_crop_centers.py index 38f2a3e0d1..4f64aadc26 100644 --- a/tests/test_generate_label_classes_crop_centers.py +++ b/tests/test_generate_label_classes_crop_centers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,11 +10,13 @@ # limitations under the License. import unittest +from copy import deepcopy -import numpy as np from parameterized import parameterized from monai.transforms import generate_label_classes_crop_centers +from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ { @@ -23,7 +25,6 @@ "ratios": [1, 2], "label_spatial_shape": [3, 3, 3], "indices": [[3, 12, 21], [1, 9, 18]], - "rand_state": np.random.RandomState(), }, list, 2, @@ -37,7 +38,6 @@ "ratios": None, "label_spatial_shape": [3, 3, 3], "indices": [[3, 12, 21], [1, 9, 18]], - "rand_state": np.random.RandomState(), }, list, 1, @@ -48,10 +48,21 @@ class TestGenerateLabelClassesCropCenters(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): - result = generate_label_classes_crop_centers(**input_data) - self.assertIsInstance(result, expected_type) - self.assertEqual(len(result), expected_count) - self.assertEqual(len(result[0]), expected_shape) + results = [] + for p in TEST_NDARRAYS + (None,): + input_data = deepcopy(input_data) + if p is not None: + input_data["indices"] = p(input_data["indices"]) + set_determinism(0) + result = generate_label_classes_crop_centers(**input_data) + self.assertIsInstance(result, expected_type) + self.assertEqual(len(result), expected_count) + self.assertEqual(len(result[0]), expected_shape) + # check for consistency between numpy, torch and torch.cuda + results.append(result) + if len(results) > 1: + for x, y in zip(result[0], result[-1]): + assert_allclose(x, y, type_test=False) if __name__ == "__main__": diff --git a/tests/test_generate_param_groups.py b/tests/test_generate_param_groups.py index ea1fad44f9..0b259442ea 100644 --- a/tests/test_generate_param_groups.py +++ b/tests/test_generate_param_groups.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,15 +18,7 @@ from monai.optimizers import generate_param_groups from monai.utils import ensure_tuple -TEST_CASE_1 = [ - { - "layer_matches": [lambda x: x.model[-1]], - "match_types": "select", - "lr_values": [1], - }, - (1, 100), - [5, 21], -] +TEST_CASE_1 = [{"layer_matches": [lambda x: x.model[-1]], "match_types": "select", "lr_values": [1]}, (1, 100), [5, 21]] TEST_CASE_2 = [ { @@ -39,11 +31,7 @@ ] TEST_CASE_3 = [ - { - "layer_matches": [lambda x: x.model[2][1].conv[0].conv], - "match_types": ["select"], - "lr_values": [1], - }, + {"layer_matches": [lambda x: x.model[2][1].conv[0].conv], "match_types": ["select"], "lr_values": [1]}, (1, 100), [2, 24], ] @@ -59,12 +47,7 @@ ] TEST_CASE_5 = [ - { - "layer_matches": [lambda x: x.model[-1]], - "match_types": ["select"], - "lr_values": [1], - "include_others": False, - }, + {"layer_matches": [lambda x: x.model[-1]], "match_types": ["select"], "lr_values": [1], "include_others": False}, (1), [5], ] @@ -86,12 +69,7 @@ class TestGenerateParamGroups(unittest.TestCase): def test_lr_values(self, input_param, expected_values, expected_groups): device = "cuda" if torch.cuda.is_available() else "cpu" net = Unet( - dimensions=3, - in_channels=1, - out_channels=3, - channels=(16, 32, 64), - strides=(2, 2), - num_res_units=1, + spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1 ).to(device) params = generate_param_groups(network=net, **input_param) @@ -107,12 +85,7 @@ def test_wrong(self): """overlapped""" device = "cuda" if torch.cuda.is_available() else "cpu" net = Unet( - dimensions=3, - in_channels=1, - out_channels=3, - channels=(16, 32, 64), - strides=(2, 2), - num_res_units=1, + spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1 ).to(device) params = generate_param_groups( diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py index 40181aa9ea..e1d9398fe3 100644 --- a/tests/test_generate_pos_neg_label_crop_centers.py +++ b/tests/test_generate_pos_neg_label_crop_centers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,35 +10,52 @@ # limitations under the License. import unittest +from copy import deepcopy -import numpy as np from parameterized import parameterized from monai.transforms import generate_pos_neg_label_crop_centers - -TEST_CASE_1 = [ - { - "spatial_size": [2, 2, 2], - "num_samples": 2, - "pos_ratio": 1.0, - "label_spatial_shape": [3, 3, 3], - "fg_indices": [1, 9, 18], - "bg_indices": [3, 12, 21], - "rand_state": np.random.RandomState(), - }, - list, - 2, - 3, -] +from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +TESTS.append( + [ + { + "spatial_size": [2, 2, 2], + "num_samples": 2, + "pos_ratio": 1.0, + "label_spatial_shape": [3, 3, 3], + "fg_indices": [1, 9, 18], + "bg_indices": [3, 12, 21], + }, + list, + 2, + 3, + ] +) class TestGeneratePosNegLabelCropCenters(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): - result = generate_pos_neg_label_crop_centers(**input_data) - self.assertIsInstance(result, expected_type) - self.assertEqual(len(result), expected_count) - self.assertEqual(len(result[0]), expected_shape) + results = [] + for p in TEST_NDARRAYS + (None,): + input_data = deepcopy(input_data) + if p is not None: + for k in ["fg_indices", "bg_indices"]: + input_data[k] = p(input_data[k]) + set_determinism(0) + result = generate_pos_neg_label_crop_centers(**input_data) + self.assertIsInstance(result, expected_type) + self.assertEqual(len(result), expected_count) + self.assertEqual(len(result[0]), expected_shape) + # check for consistency between numpy, torch and torch.cuda + results.append(result) + if len(results) > 1: + # compare every crop center + for x, y in zip(results[0], results[-1]): + assert_allclose(x, y, type_test=False) if __name__ == "__main__": diff --git a/tests/test_generate_spatial_bounding_box.py b/tests/test_generate_spatial_bounding_box.py index 32a45d8d1c..d27d5a570f 100644 --- a/tests/test_generate_spatial_bounding_box.py +++ b/tests/test_generate_spatial_bounding_box.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,60 +15,94 @@ from parameterized import parameterized from monai.transforms import generate_spatial_bounding_box +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - }, - ([1, 1], [4, 4]), -] - -TEST_CASE_2 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 1, - "channel_indices": None, - "margin": 0, - }, - ([2, 2], [3, 3]), -] - -TEST_CASE_3 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": 0, - "margin": 0, - }, - ([1, 1], [4, 4]), -] - -TEST_CASE_4 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 1, - }, - ([0, 0], [4, 5]), -] - -TEST_CASE_5 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": [2, 1], - }, - ([0, 0], [5, 5]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + }, + ([1, 1], [4, 4]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 1, + "channel_indices": None, + "margin": 0, + }, + ([2, 2], [3, 3]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": 0, + "margin": 0, + }, + ([1, 1], [4, 4]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 1, + }, + ([0, 0], [4, 5]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 1], + "allow_smaller": False, + }, + ([-1, 0], [6, 5]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 1], + "allow_smaller": True, + }, + ([0, 0], [5, 5]), + ] + ) class TestGenerateSpatialBoundingBox(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_data, expected_box): result = generate_spatial_bounding_box(**input_data) self.assertTupleEqual(result, expected_box) diff --git a/tests/test_generator.py b/tests/test_generator.py index b5d846febc..617655f86e 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_get_equivalent_dtype.py b/tests/test_get_equivalent_dtype.py index 04ba5ae5fb..fc0867523d 100644 --- a/tests/test_get_equivalent_dtype.py +++ b/tests/test_get_equivalent_dtype.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,6 +32,14 @@ def test_get_equivalent_dtype(self, im, input_dtype): out_dtype = get_equivalent_dtype(input_dtype, type(im)) self.assertEqual(out_dtype, im.dtype) + def test_native_type(self): + """the get_equivalent_dtype currently doesn't change the build-in type""" + n_type = [float, int, bool] + for n in n_type: + for im_dtype in DTYPES: + out_dtype = get_equivalent_dtype(n, type(im_dtype)) + self.assertEqual(out_dtype, n) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_get_extreme_points.py b/tests/test_get_extreme_points.py index a334c12415..457351b98c 100644 --- a/tests/test_get_extreme_points.py +++ b/tests/test_get_extreme_points.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,30 +15,37 @@ from parameterized import parameterized from monai.transforms import get_extreme_points - -TEST_CASE_1 = [ - { - "img": np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]), - "rand_state": np.random, - "background": 0, - "pert": 0.0, - }, - [(0, 1), (3, 0), (3, 0), (1, 2)], -] - -TEST_CASE_2 = [ - { - "img": np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]), - "rand_state": np.random, - "background": 0, - "pert": 0.0, - }, - [(0, 1), (3, 1), (1, 0), (1, 2)], -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p(np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]])), + "rand_state": np.random, + "background": 0, + "pert": 0.0, + }, + [(0, 1), (3, 0), (3, 0), (1, 2)], + ] + ) + + TESTS.append( + [ + { + "img": p(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]])), + "rand_state": np.random, + "background": 0, + "pert": 0.0, + }, + [(0, 1), (3, 1), (1, 0), (1, 2)], + ] + ) class TestGetExtremePoints(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected): result = get_extreme_points(**input_data) self.assertEqual(result, expected) diff --git a/tests/test_get_layers.py b/tests/test_get_layers.py index e6ea810a6b..6109052d1f 100644 --- a/tests/test_get_layers.py +++ b/tests/test_get_layers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_get_package_version.py b/tests/test_get_package_version.py index beddb340ab..c4e15c9d09 100644 --- a/tests/test_get_package_version.py +++ b/tests/test_get_package_version.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_get_unique_labels.py b/tests/test_get_unique_labels.py new file mode 100644 index 0000000000..9bc6f9b152 --- /dev/null +++ b/tests/test_get_unique_labels.py @@ -0,0 +1,44 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import torch +import torch.nn.functional as F +from parameterized import parameterized + +from monai.transforms.utils import get_unique_labels +from monai.transforms.utils_pytorch_numpy_unification import moveaxis +from tests.utils import TEST_NDARRAYS + +grid_raw = [[0, 0, 0], [0, 0, 1], [2, 2, 3], [5, 5, 6], [3, 6, 2], [5, 6, 6]] +grid = torch.Tensor(grid_raw).unsqueeze(0).to(torch.int64) +grid_onehot = moveaxis(F.one_hot(grid)[0], -1, 0) + +TESTS = [] +for p in TEST_NDARRAYS: + for o_h in (False, True): + im = grid_onehot if o_h else grid + TESTS.append([dict(img=p(im), is_onehot=o_h), {0, 1, 2, 3, 5, 6}]) + TESTS.append([dict(img=p(im), is_onehot=o_h, discard=0), {1, 2, 3, 5, 6}]) + TESTS.append([dict(img=p(im), is_onehot=o_h, discard=[1, 2]), {0, 3, 5, 6}]) + + +class TestGetUniqueLabels(unittest.TestCase): + @parameterized.expand(TESTS) + def test_correct_results(self, args, expected): + result = get_unique_labels(**args) + self.assertEqual(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index 264e2e630a..3fbe047944 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,17 +19,17 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestGibbsNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -39,36 +39,39 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return input_type(im) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 0.8 - t = GibbsNoise(alpha, as_tensor_output) + t = GibbsNoise(alpha) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) - np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) + self.assertEqual(type(out1), type(im)) + if isinstance(out1, torch.Tensor): + self.assertEqual(out1.device, im.device) + torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) + self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoise(alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(im, out, atol=1e-2) + torch.testing.assert_allclose(im, out, atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoise(alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(0 * im, out) + torch.testing.assert_allclose(0 * im, out, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index 558556489a..4905300703 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,19 +19,18 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) - + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) KEYS = ["im", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestGibbsNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -41,49 +40,56 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) + return {k: input_type(deepcopy(v)) for k, v in zip(KEYS, ims)} @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 0.8 - t = GibbsNoised(KEYS, alpha, as_tensor_output) + t = GibbsNoised(KEYS, alpha) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0) + self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() np.testing.assert_allclose(data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(0 * data[k], out[k]) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(0.0 * data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_dict_matches(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) + torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index 3373b59621..b2a52f139b 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -8,87 +8,106 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os import unittest import numpy as np import torch -from parameterized import parameterized +from monai import transforms from monai.losses.image_dissimilarity import GlobalMutualInformationLoss +from tests.utils import SkipIfBeforePyTorchVersion, download_url_or_skip_test, skip_if_quick, testing_data_config device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASES = [ - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] - .expand(1, 3, 3, 3, 3) - .div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] - .expand(1, 3, 3, 3, 3) - .div(3), - }, - -1.0986018, - ], - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] - .expand(1, 3, 3, 3, 3) - .div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] - .expand(1, 3, 3, 3, 3) - .div(3) - ** 2, - }, - -1.083999, - ], - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None].expand(1, 3, 3, 3).div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None] - .expand(1, 3, 3, 3) - .div(3) - ** 2, - }, - -1.083999, - ], - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None].expand(1, 3, 3).div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None].expand(1, 3, 3).div(3) ** 2, - }, - -1.083999, - ], - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :].div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :].div(3) ** 2, - }, - -1.083999, +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + "mri.nii") + +EXPECTED_VALUE = { + "xyz_translation": [ + -1.5860259532928467, + -0.5957175493240356, + -0.3855515122413635, + -0.28728482127189636, + -0.23416118323802948, + -0.19534644484519958, + -0.17001715302467346, + -0.15043553709983826, + -0.1366637945175171, + -0.12534910440444946, ], - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device).div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device).div(3) ** 2, - }, - -1.1920927e-07, + "xyz_rotation": [ + -1.5860259532928467, + -0.29977330565452576, + -0.18411292135715485, + -0.1582011878490448, + -0.16107326745986938, + -0.165723517537117, + -0.1970357596874237, + -0.1755618453025818, + -0.17100191116333008, + -0.17264796793460846, ], -] +} +@skip_if_quick class TestGlobalMutualInformationLoss(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_shape(self, input_param, input_data, expected_val): - result = GlobalMutualInformationLoss(**input_param).forward(**input_data) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) + def setUp(self): + config = testing_data_config("images", "Prostate_T2W_AX_1") + download_url_or_skip_test( + url=config["url"], + filepath=FILE_PATH, + hash_val=config.get("hash_val"), + hash_type=config.get("hash_type", "sha256"), + ) + + @SkipIfBeforePyTorchVersion((1, 9)) + def test_bspline(self): + loss_fn = GlobalMutualInformationLoss(kernel_type="b-spline", num_bins=32, sigma_ratio=0.015) + + transform_params_dict = { + "xyz_translation": [(i, i, i) for i in range(10)], + "xyz_rotation": [(np.pi / 100 * i, np.pi / 100 * i, np.pi / 100 * i) for i in range(10)], + } + + def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.0)): + """ + Read and transform Prostate_T2W_AX_1.nii + Args: + translate_params: a tuple of 3 floats, translation is in pixel/voxel relative to the center of the input + image. Defaults to no translation. + rotate_params: a rotation angle in radians, a tuple of 3 floats for 3D. + Defaults to no rotation. + Returns: + numpy array of shape HWD + """ + transform_list = [ + transforms.LoadImaged(keys="img"), + transforms.Affined( + keys="img", translate_params=translate_params, rotate_params=rotate_params, device=None + ), + transforms.NormalizeIntensityd(keys=["img"]), + ] + transformation = transforms.Compose(transform_list) + return transformation({"img": FILE_PATH})["img"] + + a1 = transformation() + a1 = torch.tensor(a1).unsqueeze(0).unsqueeze(0).to(device) + + for mode in transform_params_dict.keys(): + transform_params_list = transform_params_dict[mode] + expected_value_list = EXPECTED_VALUE[mode] + for transform_params, expected_value in zip(transform_params_list, expected_value_list): + a2 = transformation( + translate_params=transform_params if "translation" in mode else (0.0, 0.0, 0.0), + rotate_params=transform_params if "rotation" in mode else (0.0, 0.0, 0.0), + ) + a2 = torch.tensor(a2).unsqueeze(0).unsqueeze(0).to(device) + result = loss_fn(a2, a1).detach().cpu().numpy() + np.testing.assert_allclose(result, expected_value, rtol=1e-3, atol=5e-3) + +class TestGlobalMutualInformationLossIll(unittest.TestCase): def test_ill_shape(self): loss = GlobalMutualInformationLoss() with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_globalnet.py b/tests/test_globalnet.py index 32bc58f610..ef0209e397 100644 --- a/tests/test_globalnet.py +++ b/tests/test_globalnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 0e2401b452..f085dd916c 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -47,12 +47,12 @@ # Batch 0 [ # Channel 0 - [1, -1, 0, -1, 1], + [1, -1, 0, -1, 1] ], # Batch 1 [ # Channel 0 - [1, 1, 0, 0, -1], + [1, 1, 0, 0, -1] ], ], # Expected @@ -94,15 +94,15 @@ [0.7, 0.9, 0.0, 0.0, 0.0], # Channel 4 [0.2, 0.1, 0.2, 0.2, 0.1], - ], + ] ], # Labels [ # Batch 0 [ # Channel 0 - [0, 0, -1, 1, 1], - ], + [0, 0, -1, 1, 1] + ] ], # Expected [ @@ -112,7 +112,7 @@ [1, 1, 0, 0, 0], # Channel 1 [0, 0, 1, 1, 1], - ], + ] ], ], [ @@ -142,21 +142,15 @@ [0.4, 0.5, 0.0, 0.0, 0.0], [0.7, 0.6, 0.0, 0.0, 0.0], ], - ], + ] ], # Labels [ # Batch 0 [ # Channel 0 - [ - [-1, 1, -1, 0, -1], - [1, -1, -1, -1, -1], - [-1, -1, 0, -1, -1], - [2, 2, -1, 3, -1], - [-1, -1, -1, -1, 3], - ], - ], + [[-1, 1, -1, 0, -1], [1, -1, -1, -1, -1], [-1, -1, 0, -1, -1], [2, 2, -1, 3, -1], [-1, -1, -1, -1, 3]] + ] ], # Expected [ @@ -194,7 +188,7 @@ [0.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0], ], - ], + ] ], ], [ @@ -211,25 +205,13 @@ # Channel 0 [ # Slice 0 - [ - [0.7, 0.6, 0.0], - [0.5, 0.4, 0.0], - [0.0, 0.0, 0.0], - ], + [[0.7, 0.6, 0.0], [0.5, 0.4, 0.0], [0.0, 0.0, 0.0]], # Slice 1 - [ - [0.5, 0.6, 0.0], - [0.4, 0.3, 0.0], - [0.0, 0.0, 0.0], - ], + [[0.5, 0.6, 0.0], [0.4, 0.3, 0.0], [0.0, 0.0, 0.0]], # Slice 2 - [ - [0.3, 0.3, 0.0], - [0.2, 0.1, 0.0], - [0.0, 0.0, 0.0], - ], - ], - ], + [[0.3, 0.3, 0.0], [0.2, 0.1, 0.0], [0.0, 0.0, 0.0]], + ] + ] ], # Labels [ @@ -238,25 +220,13 @@ # Channel 0 [ # Slice 0 - [ - [0, -1, -1], - [0, -1, -1], - [-1, -1, 1], - ], + [[0, -1, -1], [0, -1, -1], [-1, -1, 1]], # Slice 1 - [ - [0, 0, -1], - [-1, -1, 1], - [-1, 1, 1], - ], + [[0, 0, -1], [-1, -1, 1], [-1, 1, 1]], # Slice 2 - [ - [0, -1, -1], - [-1, -1, -1], - [-1, -1, -1], - ], - ], - ], + [[0, -1, -1], [-1, -1, -1], [-1, -1, -1]], + ] + ] ], # Expected [ @@ -265,46 +235,22 @@ # Channel 0 [ # Slice 0 - [ - [1.0, 1.0, 0.0], - [1.0, 1.0, 0.0], - [0.0, 0.0, 0.0], - ], + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], # Slice 1 - [ - [1.0, 1.0, 0.0], - [1.0, 1.0, 0.0], - [0.0, 0.0, 0.0], - ], + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], # Slice 2 - [ - [1.0, 1.0, 0.0], - [1.0, 1.0, 0.0], - [0.0, 0.0, 0.0], - ], + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], ], # Channel 1 [ # Slice 0 - [ - [0.0, 0.0, 1.0], - [0.0, 0.0, 1.0], - [1.0, 1.0, 1.0], - ], + [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], # Slice 1 - [ - [0.0, 0.0, 1.0], - [0.0, 0.0, 1.0], - [1.0, 1.0, 1.0], - ], + [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], # Slice 2 - [ - [0.0, 0.0, 1.0], - [0.0, 0.0, 1.0], - [1.0, 1.0, 1.0], - ], + [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], ], - ], + ] ], ], ] diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 6e0aa4023e..9361d82cdf 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,49 +33,47 @@ def tearDown(self): set_determinism(None) def test_shape(self): - test_dataset = ["vwxyz", "helloworld", "worldfoobar"] - result = GridPatchDataset(dataset=test_dataset, patch_iter=identity_generator, with_coordinates=False) + # test Iterable input data + test_dataset = iter(["vwxyz", "helloworld", "worldfoobar"]) + result = GridPatchDataset(data=test_dataset, patch_iter=identity_generator, with_coordinates=False) output = [] n_workers = 0 if sys.platform == "win32" else 2 for item in DataLoader(result, batch_size=3, num_workers=n_workers): output.append("".join(item)) - expected = ["vwx", "wor", "yzh", "ldf", "ell", "oob", "owo", "ar", "rld"] + if sys.platform == "win32": + expected = ["ar", "ell", "ldf", "oob", "owo", "rld", "vwx", "wor", "yzh"] + else: + expected = ["d", "dfo", "hel", "low", "oba", "orl", "orl", "r", "vwx", "yzw"] + self.assertEqual(len("".join(expected)), len("".join(list(test_dataset)))) self.assertEqual(sorted(output), sorted(expected)) - self.assertEqual(len("".join(expected)), len("".join(test_dataset))) def test_loading_array(self): set_determinism(seed=1234) - # image dataset + # test sequence input data with images images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)] # image level patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0)) - ds = GridPatchDataset(dataset=images, patch_iter=patch_iter, transform=patch_intensity) + ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity) # use the grid patch dataset for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[1.7413, 2.7413], [5.7413, 6.7413]]], [[[9.1419, 10.1419], [13.1419, 14.1419]]]]), - rtol=1e-5, - ) - np.testing.assert_allclose( - item[1], - np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), - rtol=1e-5, + np.array([[[[1.4965, 2.4965], [5.4965, 6.4965]]], [[[11.3584, 12.3584], [15.3584, 16.3584]]]]), + rtol=1e-4, ) + np.testing.assert_allclose(item[1], np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) if sys.platform != "win32": for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=2): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[2.3944, 3.3944], [6.3944, 7.3944]]], [[[10.6551, 11.6551], [14.6551, 15.6551]]]]), + np.array([[[[1.2548, 2.2548], [5.2548, 6.2548]]], [[[9.1106, 10.1106], [13.1106, 14.1106]]]]), rtol=1e-3, ) np.testing.assert_allclose( - item[1], - np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), - rtol=1e-5, + item[1], np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5 ) diff --git a/tests/test_grid_distortion.py b/tests/test_grid_distortion.py new file mode 100644 index 0000000000..5e7ccd7c32 --- /dev/null +++ b/tests/test_grid_distortion.py @@ -0,0 +1,108 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import GridDistortion +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + dict(num_cells=3, distort_steps=[(1.5,) * 4] * 2, mode="nearest", padding_mode="zeros"), + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [3.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [3.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ).astype(np.float32) + ), + ] + ) + num_cells = (2, 2) + distort_steps = [(1.5,) * (1 + num_cells[0]), (1.0,) * (1 + num_cells[1])] + TESTS.append( + [ + dict(num_cells=num_cells, distort_steps=distort_steps, mode="bilinear", padding_mode="reflection"), + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.25, 2.25, 2.25, 2.25, 2.25, 2.25], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [3.25, 3.25, 3.25, 3.25, 3.25, 3.25], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + [ + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + ], + ] + ).astype(np.float32) + ), + ] + ) + TESTS.append( + [ + dict(num_cells=2, distort_steps=[(1.25,) * 3] * 3, mode="nearest", padding_mode="zeros"), + p(np.indices([3, 3, 3])[:1].astype(np.float32)), + p( + np.array( + [ + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ] + ] + ).astype(np.float32) + ), + ] + ) + + +class TestGridDistortion(unittest.TestCase): + @parameterized.expand(TESTS) + def test_grid_distortion(self, input_param, input_data, expected_val): + g = GridDistortion(**input_param) + result = g(input_data) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_grid_distortiond.py b/tests/test_grid_distortiond.py new file mode 100644 index 0000000000..662596f935 --- /dev/null +++ b/tests/test_grid_distortiond.py @@ -0,0 +1,85 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import GridDistortiond +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +num_cells = (2, 2) +distort_steps = [(1.5,) * (1 + n_c) for n_c in num_cells] +for p in TEST_NDARRAYS: + img = np.indices([6, 6]).astype(np.float32) + TESTS.append( + [ + dict( + keys=["img", "mask"], + num_cells=num_cells, + distort_steps=distort_steps, + mode=["bilinear", "nearest"], + padding_mode=["reflection", "zeros"], + ), + {"img": p(img), "mask": p(np.ones_like(img[:1]))}, + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.25, 2.25, 2.25, 2.25, 2.25, 2.25], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [3.25, 3.25, 3.25, 3.25, 3.25, 3.25], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + [ + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + ], + ] + ).astype(np.float32) + ), + p( + np.array( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ] + ).astype(np.float32) + ), + ] + ) + + +class TestGridDistortiond(unittest.TestCase): + @parameterized.expand(TESTS) + def test_grid_distortiond(self, input_param, input_data, expected_val_img, expected_val_mask): + g = GridDistortiond(**input_param) + result = g(input_data) + assert_allclose(result["img"], expected_val_img, rtol=1e-4, atol=1e-4) + assert_allclose(result["mask"], expected_val_mask, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py index 9e4d2e8237..561b231498 100644 --- a/tests/test_grid_pull.py +++ b/tests/test_grid_pull.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.networks.layers import grid_pull +from monai.networks.utils import meshgrid_ij from monai.utils import optional_import from tests.testing_data.cpp_resample_answers import Expected_1D_GP_bwd, Expected_1D_GP_fwd from tests.utils import skip_if_no_cpp_extension @@ -26,7 +27,7 @@ def make_grid(shape, dtype=None, device=None, requires_grad=True): ranges = [torch.arange(float(s), dtype=dtype, device=device, requires_grad=requires_grad) for s in shape] - grid = torch.stack(torch.meshgrid(*ranges), dim=-1) + grid = torch.stack(meshgrid_ij(*ranges), dim=-1) return grid[None] @@ -53,11 +54,7 @@ def make_grid(shape, dtype=None, device=None, requires_grad=True): "interpolation": interp, "bound": bound, }, - { - "val": torch.tensor([[expected_val]]), - "device": device, - "grad": torch.tensor(expected_grad), - }, + {"val": torch.tensor([[expected_val]]), "device": device, "grad": torch.tensor(expected_grad)}, ] TEST_1D_GP.append(test_case) @@ -85,7 +82,7 @@ def test_grid_pull(self, input_param, expected): grads = grads[0] else: grads = torch.cat(grads, dim=0) - self.assertTrue("{}".format(result.device).startswith(expected["device"])) + self.assertTrue(f"{result.device}".startswith(expected["device"])) np.testing.assert_allclose(result.detach().cpu().numpy(), expected["val"].cpu().numpy(), rtol=1e-4, atol=1e-4) np.testing.assert_allclose(grads.detach().cpu().numpy(), expected["grad"].cpu().numpy(), rtol=1e-4, atol=1e-4) diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index 81a3cdc96d..2331602234 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,8 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys import tempfile import unittest @@ -23,7 +21,6 @@ class TestHandlerCheckpointLoader(unittest.TestCase): def test_one_save_one_load(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.PReLU() data1 = net1.state_dict() data1["weight"] = torch.tensor([0.1]) @@ -58,7 +55,6 @@ def check_epoch(engine: Engine): self.assertEqual(engine3.state.max_epochs, 5) def test_two_save_one_load(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.PReLU() optimizer = optim.SGD(net1.parameters(), lr=0.02) data1 = net1.state_dict() @@ -80,7 +76,6 @@ def test_two_save_one_load(self): torch.testing.assert_allclose(net2.state_dict()["weight"], torch.tensor([0.1])) def test_save_single_device_load_multi_devices(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.PReLU() data1 = net1.state_dict() data1["weight"] = torch.tensor([0.1]) @@ -101,7 +96,6 @@ def test_save_single_device_load_multi_devices(self): torch.testing.assert_allclose(net2.state_dict()["module.weight"].cpu(), torch.tensor([0.1])) def test_partial_under_load(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) data1 = net1.state_dict() data1["0.weight"] = torch.tensor([0.1]) @@ -124,7 +118,6 @@ def test_partial_under_load(self): torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) def test_partial_over_load(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.Sequential(*[torch.nn.PReLU()]) data1 = net1.state_dict() data1["0.weight"] = torch.tensor([0.1]) @@ -147,7 +140,6 @@ def test_partial_over_load(self): torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) def test_strict_shape(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.Sequential(*[torch.nn.PReLU(num_parameters=5)]) data1 = net1.state_dict() data1["0.weight"] = torch.tensor([1, 2, 3, 4, 5]) diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py index bcab49f12b..c87866490c 100644 --- a/tests/test_handler_checkpoint_saver.py +++ b/tests/test_handler_checkpoint_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,9 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import os -import sys import tempfile import unittest @@ -112,16 +110,7 @@ class TestHandlerCheckpointSaver(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - ] + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] ) def test_file( self, @@ -140,7 +129,6 @@ def test_file( filenames, multi_devices=False, ): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) data = [0] * 8 # set up engine @@ -180,7 +168,6 @@ def _train_func(engine, batch): self.assertTrue(os.path.exists(os.path.join(tempdir, filename))) def test_exception(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net = torch.nn.PReLU() # set up engine @@ -199,7 +186,6 @@ def _train_func(engine, batch): self.assertTrue(os.path.exists(os.path.join(tempdir, "net_final_iteration=1.pt"))) def test_load_state_dict(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net = torch.nn.PReLU() # set up engine diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py index 87ce5ca3f8..313842b443 100644 --- a/tests/test_handler_classification_saver.py +++ b/tests/test_handler_classification_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -35,8 +35,8 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - saver = CSVSaver(output_dir=tempdir, filename="predictions2.csv") - ClassificationSaver(output_dir=tempdir, filename="predictions1.csv").attach(engine) + saver = CSVSaver(output_dir=tempdir, filename="predictions2.csv", delimiter="\t") + ClassificationSaver(output_dir=tempdir, filename="predictions1.csv", delimiter="\t").attach(engine) ClassificationSaver(saver=saver).attach(engine) data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8)]}] @@ -45,8 +45,8 @@ def _train_func(engine, batch): def _test_file(filename): filepath = os.path.join(tempdir, filename) self.assertTrue(os.path.exists(filepath)) - with open(filepath, "r") as f: - reader = csv.reader(f) + with open(filepath) as f: + reader = csv.reader(f, delimiter="\t") i = 0 for row in reader: self.assertEqual(row[0], "testfile" + str(i)) diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index 70cc0ca42f..e92009d37f 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -61,7 +61,7 @@ def _train_func(engine, batch): filepath = os.path.join(tempdir, "predictions.csv") if rank == 1: self.assertTrue(os.path.exists(filepath)) - with open(filepath, "r") as f: + with open(filepath) as f: reader = csv.reader(f) i = 0 for row in reader: diff --git a/tests/test_handler_confusion_matrix.py b/tests/test_handler_confusion_matrix.py index 0c6e36066b..5bddef26af 100644 --- a/tests/test_handler_confusion_matrix.py +++ b/tests/test_handler_confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,7 +20,7 @@ TEST_CASE_1 = [{"include_background": True, "save_details": False, "metric_name": "f1"}, 0.75] TEST_CASE_2 = [{"include_background": False, "save_details": False, "metric_name": "ppv"}, 1.0] - +TEST_CASE_3 = [{"save_details": False, "metric_name": "f1", "reduction": "mean_batch"}, torch.tensor([0.6667, 0.8000])] TEST_CASE_SEG_1 = [{"include_background": True, "metric_name": "tpr"}, 0.7] data_1: Dict[Any, Any] = { @@ -39,23 +39,15 @@ } data_2: Dict[Any, Any] = { - "y_pred": torch.tensor( - [ - [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]], - ] - ), - "y": torch.tensor( - [ - [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], - ] - ), + "y_pred": torch.tensor([[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]]), + "y": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]), } class TestHandlerConfusionMatrix(unittest.TestCase): # TODO test multi node averaged confusion matrix - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_compute(self, input_params, expected_avg): metric = ConfusionMatrix(**input_params) # test input a list of channel-first tensor @@ -68,7 +60,7 @@ def test_compute(self, input_params, expected_avg): metric.update([y_pred, y]) avg_metric = metric.compute() - self.assertAlmostEqual(avg_metric, expected_avg, places=4) + torch.testing.assert_allclose(avg_metric, expected_avg) @parameterized.expand([TEST_CASE_SEG_1]) def test_compute_seg(self, input_params, expected_avg): diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py index 40245bce2e..325a799098 100644 --- a/tests/test_handler_confusion_matrix_dist.py +++ b/tests/test_handler_confusion_matrix_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -54,12 +54,10 @@ def _val_func(engine, batch): if dist.get_rank() == 1: y_pred = torch.tensor( - [[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]], - device=device, + [[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]], device=device ) y = torch.tensor( - [[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], - device=device, + [[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=device ) metric.update([y_pred, y]) diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py index 8f0ffb2b5c..1a43ae295b 100644 --- a/tests/test_handler_decollate_batch.py +++ b/tests/test_handler_decollate_batch.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,7 +32,7 @@ def test_compute(self): [ Activationsd(keys="pred", sigmoid=True), CopyItemsd(keys="filename", times=1, names="filename_bak"), - AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2), + AsDiscreted(keys="pred", threshold=0.5, to_onehot=2), ] ) ), diff --git a/tests/test_handler_early_stop.py b/tests/test_handler_early_stop.py index efe8e89825..36604e5735 100644 --- a/tests/test_handler_early_stop.py +++ b/tests/test_handler_early_stop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,10 +23,7 @@ def _train_func(engine, batch): trainer = Engine(_train_func) EarlyStopHandler( - patience=5, - score_function=lambda x: x.state.output["loss"], - trainer=trainer, - epoch_level=False, + patience=5, score_function=lambda x: x.state.output["loss"], trainer=trainer, epoch_level=False ).attach(trainer) trainer.run(range(4), max_epochs=2) diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py index 75ab9ceb99..0350ba62fb 100644 --- a/tests/test_handler_garbage_collector.py +++ b/tests/test_handler_garbage_collector.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,13 +34,7 @@ class TestHandlerGarbageCollector(unittest.TestCase): @skipUnless(has_ignite, "Requires ignite") - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_content(self, data, trigger_event): # set up engine gb_count_dict = {} diff --git a/tests/test_handler_hausdorff_distance.py b/tests/test_handler_hausdorff_distance.py index bbc36cc2b5..7e38f0ad56 100644 --- a/tests/test_handler_hausdorff_distance.py +++ b/tests/test_handler_hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,9 +20,7 @@ def create_spherical_seg_3d( - radius: float = 20.0, - centre: Tuple[int, int, int] = (49, 49, 49), - im_shape: Tuple[int, int, int] = (99, 99, 99), + radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -91,6 +89,21 @@ def test_shape_mismatch(self): y = torch.ones((1, 1, 10, 10, 10)) hd_metric.update([y_pred, y]) + def test_reduction(self): + hd_metric = HausdorffDistance(include_background=True, reduction="mean_channel") + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + hd_metric.attach(engine, "hausdorff_distance") + + y_pred, y = TEST_SAMPLE_1 + hd_metric.update([y_pred, y]) + y_pred, y = TEST_SAMPLE_2 + hd_metric.update([y_pred, y]) + torch.testing.assert_allclose(hd_metric.compute().float(), torch.tensor([10.0, 0.0])) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py index 82a62dce21..15401fe1b2 100644 --- a/tests/test_handler_lr_scheduler.py +++ b/tests/test_handler_lr_scheduler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,7 +10,9 @@ # limitations under the License. import logging -import sys +import os +import re +import tempfile import unittest import numpy as np @@ -22,8 +24,9 @@ class TestHandlerLrSchedule(unittest.TestCase): def test_content(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) data = [0] * 8 + test_lr = 0.1 + gamma = 0.1 # set up engine def _train_func(engine, batch): @@ -41,24 +44,47 @@ def run_validation(engine): net = torch.nn.PReLU() def _reduce_lr_on_plateau(): - optimizer = torch.optim.SGD(net.parameters(), 0.1) + optimizer = torch.optim.SGD(net.parameters(), test_lr) lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1) handler = LrScheduleHandler(lr_scheduler, step_transform=lambda x: val_engine.state.metrics["val_loss"]) handler.attach(train_engine) - return lr_scheduler + return handler - def _reduce_on_step(): - optimizer = torch.optim.SGD(net.parameters(), 0.1) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) - handler = LrScheduleHandler(lr_scheduler) - handler.attach(train_engine) - return lr_scheduler + with tempfile.TemporaryDirectory() as tempdir: + key_to_handler = "test_log_lr" + key_to_print = "Current learning rate" + filename = os.path.join(tempdir, "test_lr.log") + # test with additional logging handler + file_saver = logging.FileHandler(filename, mode="w") + file_saver.setLevel(logging.INFO) + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(file_saver) + + def _reduce_on_step(): + optimizer = torch.optim.SGD(net.parameters(), test_lr) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma) + handler = LrScheduleHandler(lr_scheduler, name=key_to_handler) + handler.attach(train_engine) + return handler + + schedulers = _reduce_lr_on_plateau(), _reduce_on_step() + + train_engine.run(data, max_epochs=5) + file_saver.close() + logger.removeHandler(file_saver) - schedulers = _reduce_lr_on_plateau(), _reduce_on_step() + with open(filename) as f: + output_str = f.read() + has_key_word = re.compile(f".*{key_to_print}.*") + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) - train_engine.run(data, max_epochs=5) for scheduler in schedulers: - np.testing.assert_allclose(scheduler._last_lr[0], 0.001) + np.testing.assert_allclose(scheduler.lr_scheduler._last_lr[0], 0.001) if __name__ == "__main__": diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py index ba4fb9d413..f309c7e693 100644 --- a/tests/test_handler_mean_dice.py +++ b/tests/test_handler_mean_dice.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,12 +19,17 @@ TEST_CASE_1 = [{"include_background": True, "output_transform": from_engine(["pred", "label"])}, 0.75, (4, 2)] TEST_CASE_2 = [{"include_background": False, "output_transform": from_engine(["pred", "label"])}, 0.66666, (4, 1)] +TEST_CASE_3 = [ + {"reduction": "mean_channel", "output_transform": from_engine(["pred", "label"])}, + torch.Tensor([1.0, 0.0, 1.0, 1.0]), + (4, 2), +] class TestHandlerMeanDice(unittest.TestCase): # TODO test multi node averaged dice - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_compute(self, input_params, expected_avg, details_shape): dice_metric = MeanDice(**input_params) # set up engine @@ -46,8 +51,7 @@ def _val_func(engine, batch): engine.fire_event(Events.ITERATION_COMPLETED) engine.fire_event(Events.EPOCH_COMPLETED) - - self.assertAlmostEqual(engine.state.metrics["mean_dice"], expected_avg, places=4) + torch.testing.assert_allclose(engine.state.metrics["mean_dice"], expected_avg) self.assertTupleEqual(tuple(engine.state.metric_details["mean_dice"].shape), details_shape) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) diff --git a/tests/test_handler_metric_logger.py b/tests/test_handler_metric_logger.py index 5812605cd7..c3de866c5c 100644 --- a/tests/test_handler_metric_logger.py +++ b/tests/test_handler_metric_logger.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_metrics_saver.py b/tests/test_handler_metrics_saver.py index 17c23be274..56cfeb033b 100644 --- a/tests/test_handler_metrics_saver.py +++ b/tests/test_handler_metrics_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,6 +18,7 @@ from ignite.engine import Engine, Events from monai.handlers import MetricsSaver +from monai.utils.enums import PostFix class TestHandlerMetricsSaver(unittest.TestCase): @@ -27,13 +28,14 @@ def test_content(self): save_dir=tempdir, metrics=["metric1", "metric2"], metric_details=["metric3", "metric4"], - batch_transform=lambda x: x["image_meta_dict"], + batch_transform=lambda x: x[PostFix.meta("image")], summary_ops=["mean", "median", "max", "5percentile", "95percentile", "notnans"], + delimiter="\t", ) # set up engine data = [ - {"image_meta_dict": {"filename_or_obj": ["filepath1"]}}, - {"image_meta_dict": {"filename_or_obj": ["filepath2"]}}, + {PostFix.meta("image"): {"filename_or_obj": ["filepath1"]}}, + {PostFix.meta("image"): {"filename_or_obj": ["filepath2"]}}, ] def _val_func(engine, batch): diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 0a36a19c66..245ff492a5 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,6 +21,7 @@ from monai.handlers import MetricsSaver from monai.utils import evenly_divisible_all_gather +from monai.utils.enums import PostFix from tests.utils import DistCall, DistTestCase @@ -31,14 +32,16 @@ def test_content(self): self._run(tempdir) def _run(self, tempdir): + my_rank = dist.get_rank() fnames = ["aaa" * 300, "bbb" * 301, "ccc" * 302] metrics_saver = MetricsSaver( save_dir=tempdir, metrics=["metric1", "metric2"], metric_details=["metric3", "metric4"], - batch_transform=lambda x: x["image_meta_dict"], + batch_transform=lambda x: x[PostFix.meta("image")], summary_ops="*", + delimiter="\t", ) def _val_func(engine, batch): @@ -46,22 +49,19 @@ def _val_func(engine, batch): engine = Engine(_val_func) - if dist.get_rank() == 0: - data = [{"image_meta_dict": {"filename_or_obj": [fnames[0]]}}] + if my_rank == 0: + data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}] @engine.on(Events.EPOCH_COMPLETED) def _save_metrics0(engine): engine.state.metrics = {"metric1": 1, "metric2": 2} - engine.state.metric_details = { - "metric3": torch.tensor([[1, 2]]), - "metric4": torch.tensor([[5, 6]]), - } + engine.state.metric_details = {"metric3": torch.tensor([[1, 2]]), "metric4": torch.tensor([[5, 6]])} - if dist.get_rank() == 1: + if my_rank == 1: # different ranks have different data length data = [ - {"image_meta_dict": {"filename_or_obj": [fnames[1]]}}, - {"image_meta_dict": {"filename_or_obj": [fnames[2]]}}, + {PostFix.meta("image"): {"filename_or_obj": [fnames[1]]}}, + {PostFix.meta("image"): {"filename_or_obj": [fnames[2]]}}, ] @engine.on(Events.EPOCH_COMPLETED) @@ -82,7 +82,7 @@ def _all_gather(engine): metrics_saver.attach(engine) engine.run(data, max_epochs=1) - if dist.get_rank() == 0: + if my_rank == 0: # check the metrics.csv and content self.assertTrue(os.path.exists(os.path.join(tempdir, "metrics.csv"))) with open(os.path.join(tempdir, "metrics.csv")) as f: @@ -110,6 +110,7 @@ def _all_gather(engine): self.assertEqual(row, ["mean\t2.5000\t2.5000\t3.5000\t1.5000\t3.3000\t0.8165\t3.0000"]) self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_raw.csv"))) self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_summary.csv"))) + dist.barrier() if __name__ == "__main__": diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py new file mode 100644 index 0000000000..9f8b829481 --- /dev/null +++ b/tests/test_handler_mlflow.py @@ -0,0 +1,53 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import tempfile +import unittest +from pathlib import Path + +from ignite.engine import Engine, Events + +from monai.handlers import MLFlowHandler + + +class TestHandlerMLFlow(unittest.TestCase): + def test_metrics_track(self): + with tempfile.TemporaryDirectory() as tempdir: + + # set up engine + def _train_func(engine, batch): + return [batch + 1.0] + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get("acc", 0.1) + engine.state.metrics["acc"] = current_metric + 0.1 + engine.state.test = current_metric + + # set up testing handler + test_path = os.path.join(tempdir, "mlflow_test") + handler = MLFlowHandler( + iteration_log=False, epoch_log=True, tracking_uri=Path(test_path).as_uri(), state_attributes=["test"] + ) + handler.attach(engine) + engine.run(range(3), max_epochs=2) + handler.close() + # check logging output + self.assertTrue(len(glob.glob(test_path)) > 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_nvtx.py b/tests/test_handler_nvtx.py index fee29af344..eeca15ea6f 100644 --- a/tests/test_handler_nvtx.py +++ b/tests/test_handler_nvtx.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,42 +22,18 @@ _, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?") -TENSOR_0 = torch.tensor( - [ - [ - [[1.0], [2.0]], - [[3.0], [4.0]], - ] - ] -) +TENSOR_0 = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]]) -TENSOR_1 = torch.tensor( - [ - [ - [[0.0], [-2.0]], - [[-3.0], [4.0]], - ] - ] -) +TENSOR_1 = torch.tensor([[[[0.0], [-2.0]], [[-3.0], [4.0]]]]) -TENSOR_1_EXPECTED = torch.tensor( - [ - [[1.0], [0.5]], - [[0.25], [5.0]], - ] -) +TENSOR_1_EXPECTED = torch.tensor([[[1.0], [0.5]], [[0.25], [5.0]]]) TEST_CASE_0 = [[{"image": TENSOR_0}], TENSOR_0[0] + 1.0] TEST_CASE_1 = [[{"image": TENSOR_1}], TENSOR_1_EXPECTED] class TestHandlerDecollateBatch(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!") def test_compute(self, data, expected): # Set up handlers diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py index 5b3e845ace..72742f1956 100644 --- a/tests/test_handler_parameter_scheduler.py +++ b/tests/test_handler_parameter_scheduler.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest import torch @@ -9,7 +20,7 @@ class ToyNet(Module): def __init__(self, value): - super(ToyNet, self).__init__() + super().__init__() self.value = value def forward(self, input): diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index e9d57128cb..89087e1765 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,7 +26,7 @@ "transform": Compose( [ CopyItemsd(keys="filename", times=1, names="filename_bak"), - AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2), + AsDiscreted(keys="pred", threshold=0.5, to_onehot=2), ] ), "event": "iteration_completed", diff --git a/tests/test_handler_prob_map_producer.py b/tests/test_handler_prob_map_producer.py index b21cf03171..b3f79cf587 100644 --- a/tests/test_handler_prob_map_producer.py +++ b/tests/test_handler_prob_map_producer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,12 +32,7 @@ class TestDataset(Dataset): def __init__(self, name, size): super().__init__( data=[ - { - "name": name, - "mask_shape": (size, size), - "mask_locations": [[i, i] for i in range(size)], - "level": 0, - } + {"name": name, "mask_shape": (size, size), "mask_locations": [[i, i] for i in range(size)], "level": 0} ] ) self.len = size @@ -46,11 +41,7 @@ def __len__(self): return self.len def __getitem__(self, index): - return { - "name": self.data[0]["name"], - "mask_location": self.data[0]["mask_locations"][index], - "pred": index + 1, - } + return {"name": self.data[0]["name"], "mask_location": self.data[0]["mask_locations"][index], "pred": index + 1} class TestEvaluator(Evaluator): @@ -59,13 +50,7 @@ def _iteration(self, engine, batchdata): class TestHandlerProbMapGenerator(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_prob_map_generator(self, name, size): # set up dataset dataset = TestDataset(name, size) diff --git a/tests/test_handler_regression_metrics.py b/tests/test_handler_regression_metrics.py index 7bb72dd5d5..c6af76a4db 100644 --- a/tests/test_handler_regression_metrics.py +++ b/tests/test_handler_regression_metrics.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_regression_metrics_dist.py b/tests/test_handler_regression_metrics_dist.py index c336ccf28c..a8b644d550 100644 --- a/tests/test_handler_regression_metrics_dist.py +++ b/tests/test_handler_regression_metrics_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py index 5b80bc43eb..6e2d6be27e 100644 --- a/tests/test_handler_rocauc.py +++ b/tests/test_handler_rocauc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,7 +22,7 @@ class TestHandlerROCAUC(unittest.TestCase): def test_compute(self): auc_metric = ROCAUC() act = Activations(softmax=True) - to_onehot = AsDiscrete(to_onehot=True, num_classes=2) + to_onehot = AsDiscrete(to_onehot=2) y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])] y = [torch.Tensor([0]), torch.Tensor([1])] diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py index 8316d4c4b6..5113911d7c 100644 --- a/tests/test_handler_rocauc_dist.py +++ b/tests/test_handler_rocauc_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,7 +26,7 @@ class DistributedROCAUC(DistTestCase): def test_compute(self): auc_metric = ROCAUC() act = Activations(softmax=True) - to_onehot = AsDiscrete(to_onehot=True, num_classes=2) + to_onehot = AsDiscrete(to_onehot=2) device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" if dist.get_rank() == 0: diff --git a/tests/test_handler_segmentation_saver.py b/tests/test_handler_segmentation_saver.py index 78dea0a68b..ee6566f6cb 100644 --- a/tests/test_handler_segmentation_saver.py +++ b/tests/test_handler_segmentation_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -39,7 +39,9 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255) + saver = SegmentationSaver( + output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255, output_dtype=np.uint8 + ) saver.attach(engine) data = [ @@ -65,7 +67,9 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255) + saver = SegmentationSaver( + output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255, output_dtype=np.uint8 + ) saver.attach(engine) data = [ diff --git a/tests/test_handler_smartcache.py b/tests/test_handler_smartcache.py index b67f1226cd..ec96d47e3d 100644 --- a/tests/test_handler_smartcache.py +++ b/tests/test_handler_smartcache.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,13 +24,7 @@ class TestHandlerSmartCache(unittest.TestCase): def test_content(self): data = [0, 1, 2, 3, 4, 5, 6, 7, 8] - expected = [ - [0, 1, 2, 3, 4], - [1, 2, 3, 4, 5], - [2, 3, 4, 5, 6], - [3, 4, 5, 6, 7], - [4, 5, 6, 7, 8], - ] + expected = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7], [4, 5, 6, 7, 8]] # set up engine def _train_func(engine, batch): diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 84cdef59a8..7fe07d974b 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -43,7 +43,10 @@ def _update_metric(engine): engine.state.metrics[key_to_print] = current_metric + 0.1 # set up testing handler - stats_handler = StatsHandler(name=key_to_handler, logger_handler=log_handler) + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler(iteration_log=False, epoch_log=True, name=key_to_handler) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) @@ -51,12 +54,12 @@ def _update_metric(engine): # check logging output output_str = log_stream.getvalue() log_handler.close() - grep = re.compile(f".*{key_to_handler}.*") has_key_word = re.compile(f".*{key_to_print}.*") - for idx, line in enumerate(output_str.split("\n")): - if grep.match(line): - if idx in [5, 10]: - self.assertTrue(has_key_word.match(line)) + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) def test_loss_print(self): log_stream = StringIO() @@ -72,7 +75,10 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=log_handler) + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler(iteration_log=True, epoch_log=False, name=key_to_handler, tag_name=key_to_print) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) @@ -80,12 +86,12 @@ def _train_func(engine, batch): # check logging output output_str = log_stream.getvalue() log_handler.close() - grep = re.compile(f".*{key_to_handler}.*") has_key_word = re.compile(f".*{key_to_print}.*") - for idx, line in enumerate(output_str.split("\n")): - if grep.match(line): - if idx in [1, 2, 3, 6, 7, 8]: - self.assertTrue(has_key_word.match(line)) + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) def test_loss_dict(self): log_stream = StringIO() @@ -101,9 +107,10 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - stats_handler = StatsHandler( - name=key_to_handler, output_transform=lambda x: {key_to_print: x}, logger_handler=log_handler - ) + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler(name=key_to_handler, output_transform=lambda x: {key_to_print: x[0]}) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) @@ -111,12 +118,12 @@ def _train_func(engine, batch): # check logging output output_str = log_stream.getvalue() log_handler.close() - grep = re.compile(f".*{key_to_handler}.*") has_key_word = re.compile(f".*{key_to_print}.*") - for idx, line in enumerate(output_str.split("\n")): - if grep.match(line): - if idx in [1, 2, 3, 6, 7, 8]: - self.assertTrue(has_key_word.match(line)) + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) def test_loss_file(self): key_to_handler = "test_logging" @@ -134,20 +141,23 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler) + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(handler) + stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) handler.close() stats_handler.logger.removeHandler(handler) - with open(filename, "r") as f: + with open(filename) as f: output_str = f.read() - grep = re.compile(f".*{key_to_handler}.*") has_key_word = re.compile(f".*{key_to_print}.*") - for idx, line in enumerate(output_str.split("\n")): - if grep.match(line): - if idx in [1, 2, 3, 6, 7, 8]: - self.assertTrue(has_key_word.match(line)) + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) def test_exception(self): # set up engine @@ -163,6 +173,80 @@ def _train_func(engine, batch): with self.assertRaises(RuntimeError): engine.run(range(3), max_epochs=2) + def test_attributes_print(self): + log_stream = StringIO() + log_handler = logging.StreamHandler(log_stream) + log_handler.setLevel(logging.INFO) + key_to_handler = "test_logging" + + # set up engine + def _train_func(engine, batch): + return [torch.tensor(0.0)] + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + if not hasattr(engine.state, "test1"): + engine.state.test1 = 0.1 + engine.state.test2 = 0.2 + else: + engine.state.test1 += 0.1 + engine.state.test2 += 0.2 + + # set up testing handler + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler(name=key_to_handler, state_attributes=["test1", "test2", "test3"]) + stats_handler.attach(engine) + + engine.run(range(3), max_epochs=2) + + # check logging output + output_str = log_stream.getvalue() + log_handler.close() + has_key_word = re.compile(".*State values.*") + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) + + def test_default_logger(self): + log_stream = StringIO() + log_handler = logging.StreamHandler(log_stream) + log_handler.setLevel(logging.INFO) + key_to_print = "myLoss" + + # set up engine + def _train_func(engine, batch): + return [torch.tensor(0.0)] + + engine = Engine(_train_func) + engine.logger.addHandler(log_handler) + + # set up testing handler + stats_handler = StatsHandler(name=None, tag_name=key_to_print) + stats_handler.attach(engine) + # leverage `engine.logger` to print info + engine.logger.setLevel(logging.INFO) + level = logging.root.getEffectiveLevel() + logging.basicConfig(level=logging.INFO) + engine.run(range(3), max_epochs=2) + logging.basicConfig(level=level) + + # check logging output + output_str = log_stream.getvalue() + log_handler.close() + has_key_word = re.compile(f".*{key_to_print}.*") + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_surface_distance.py b/tests/test_handler_surface_distance.py index 82cdb50d90..c990181998 100644 --- a/tests/test_handler_surface_distance.py +++ b/tests/test_handler_surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,9 +20,7 @@ def create_spherical_seg_3d( - radius: float = 20.0, - centre: Tuple[int, int, int] = (49, 49, 49), - im_shape: Tuple[int, int, int] = (99, 99, 99), + radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -91,6 +89,21 @@ def test_shape_mismatch(self): y = torch.ones((1, 1, 10, 10, 10)) sur_metric.update([y_pred, y]) + def test_reduction(self): + sur_metric = SurfaceDistance(include_background=True, reduction="mean_channel") + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + sur_metric.attach(engine, "surface_distance") + + y_pred, y = TEST_SAMPLE_1 + sur_metric.update([y_pred, y]) + y_pred, y = TEST_SAMPLE_2 + sur_metric.update([y_pred, y]) + torch.testing.assert_allclose(sur_metric.compute().float(), torch.tensor([4.1713, 0.0000])) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py index b5d963eedf..d11bbfec59 100644 --- a/tests/test_handler_tb_image.py +++ b/tests/test_handler_tb_image.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py index 1d722e7f66..4d582d151b 100644 --- a/tests/test_handler_tb_stats.py +++ b/tests/test_handler_tb_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -36,7 +36,7 @@ def _update_metric(engine): engine.state.metrics["acc"] = current_metric + 0.1 # set up testing handler - stats_handler = TensorBoardStatsHandler(log_dir=tempdir) + stats_handler = TensorBoardStatsHandler(log_dir=tempdir, iteration_log=False, epoch_log=True) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) stats_handler.close() @@ -57,11 +57,17 @@ def _train_func(engine, batch): def _update_metric(engine): current_metric = engine.state.metrics.get("acc", 0.1) engine.state.metrics["acc"] = current_metric + 0.1 + engine.state.test = current_metric # set up testing handler writer = SummaryWriter(log_dir=tempdir) stats_handler = TensorBoardStatsHandler( - writer, output_transform=lambda x: {"loss": x[0] * 2.0}, global_epoch_transform=lambda x: x * 3.0 + summary_writer=writer, + iteration_log=True, + epoch_log=False, + output_transform=lambda x: {"loss": x[0] * 2.0}, + global_epoch_transform=lambda x: x * 3.0, + state_attributes=["test"], ) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) diff --git a/tests/test_handler_transform_inverter.py b/tests/test_handler_transform_inverter.py deleted file mode 100644 index 385311eba7..0000000000 --- a/tests/test_handler_transform_inverter.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import unittest - -import numpy as np -import torch -from ignite.engine import Engine - -from monai.data import CacheDataset, DataLoader, create_test_image_3d, decollate_batch -from monai.engines.utils import IterationEvents -from monai.handlers import TransformInverter -from monai.transforms import ( - AddChanneld, - CastToTyped, - Compose, - CopyItemsd, - LoadImaged, - Orientationd, - RandAffined, - RandAxisFlipd, - RandFlipd, - RandRotate90d, - RandRotated, - RandZoomd, - ResizeWithPadOrCropd, - ScaleIntensityd, - Spacingd, - ToTensord, -) -from monai.utils.misc import set_determinism -from tests.utils import make_nifti_image - -KEYS = ["image", "label"] - - -class TestTransformInverter(unittest.TestCase): - def test_invert(self): - set_determinism(seed=0) - im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)] - transform = Compose( - [ - LoadImaged(KEYS), - AddChanneld(KEYS), - Orientationd(KEYS, "RPS"), - Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32), - ScaleIntensityd("image", minv=1, maxv=10), - RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]), - RandAxisFlipd(KEYS, prob=0.5), - RandRotate90d(KEYS, spatial_axes=(1, 2)), - RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), - RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True), - RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"), - ResizeWithPadOrCropd(KEYS, 100), - ToTensord("image"), # test to support both Tensor and Numpy array when inverting - CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]), - CopyItemsd("label", times=2, names=["label_inverted1", "label_inverted2"]), - CopyItemsd("image", times=2, names=["image_inverted1", "image_inverted2"]), - ] - ) - data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] - - # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 - - dataset = CacheDataset(data, transform=transform, progress=False) - loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) - - # set up engine - def _train_func(engine, batch): - self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100)) - engine.state.output = engine.state.batch = decollate_batch(batch) - engine.fire_event(IterationEvents.MODEL_COMPLETED) - return engine.state.output - - engine = Engine(_train_func) - engine.register_events(*IterationEvents) - - # set up testing handler - TransformInverter( - transform=transform, - output_keys=["image_inverted1", "label_inverted1"], - batch_keys="label", - meta_keys=["image_inverted1_meta_dict", "label_inverted1_meta_dict"], - batch_meta_keys="label_meta_dict", - nearest_interp=True, - to_tensor=[True, False], - device="cpu", - ).attach(engine) - - # test different nearest interpolation values - TransformInverter( - transform=transform, - output_keys=["image_inverted2", "label_inverted2"], - batch_keys="image", - meta_keys=None, - batch_meta_keys="image_meta_dict", - meta_key_postfix="meta_dict", - nearest_interp=[True, False], - post_func=[lambda x: x + 10, lambda x: x], - ).attach(engine) - - engine.run(loader, max_epochs=1) - set_determinism(seed=None) - - for output in engine.state.output: - self.assertTupleEqual(output["image"].shape, (1, 100, 100, 100)) - self.assertTupleEqual(output["label"].shape, (1, 100, 100, 100)) - # check the nearest inerpolation mode - i = output["image_inverted1"] - torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float)) - self.assertTupleEqual(i.shape, (1, 100, 101, 107)) - i = output["label_inverted1"] - np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32)) - self.assertTupleEqual(i.shape, (1, 100, 101, 107)) - - # check the case that different items use different interpolation mode to invert transforms - d = output["image_inverted2"] - # if the interpolation mode is nearest, accumulated diff should be smaller than 1 - self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0) - self.assertTupleEqual(d.shape, (1, 100, 101, 107)) - - d = output["label_inverted2"] - # if the interpolation mode is not nearest, accumulated diff should be greater than 10000 - self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0) - self.assertTupleEqual(d.shape, (1, 100, 101, 107)) - - # check labels match - reverted = engine.state.output[-1]["label_inverted1"].astype(np.int32) - original = LoadImaged(KEYS)(data[-1])["label"] - n_good = np.sum(np.isclose(reverted, original, atol=1e-3)) - reverted_name = engine.state.batch[-1]["label_inverted1_meta_dict"]["filename_or_obj"] - original_name = data[-1]["label"] - self.assertEqual(reverted_name, original_name) - print("invert diff", reverted.size - n_good) - # 25300: 2 workers (cpu, non-macos) - # 1812: 0 workers (gpu or macos) - # 1824: torch 1.5.1 - self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824), "diff. in 3 possible values") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py index 06f400109d..42ffc8b9eb 100644 --- a/tests/test_handler_validation.py +++ b/tests/test_handler_validation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_hashing.py b/tests/test_hashing.py index ca317a72e8..5a1265bd48 100644 --- a/tests/test_hashing.py +++ b/tests/test_hashing.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index 0b313f722f..79a2c84b37 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,9 +20,7 @@ def create_spherical_seg_3d( - radius: float = 20.0, - centre: Tuple[int, int, int] = (49, 49, 49), - im_shape: Tuple[int, int, int] = (99, 99, 99), + radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -49,10 +47,7 @@ def create_spherical_seg_3d( TEST_CASES = [ - [ - [create_spherical_seg_3d(), create_spherical_seg_3d(), 1], - [0, 0, 0, 0, 0, 0], - ], + [[create_spherical_seg_3d(), create_spherical_seg_3d(), 1], [0, 0, 0, 0, 0, 0]], [ [ create_spherical_seg_3d(radius=20, centre=(20, 20, 20)), @@ -106,8 +101,8 @@ def create_spherical_seg_3d( # both pred and gt do not have foreground, metric and not_nans should be 0 np.zeros([99, 99, 99]), np.zeros([99, 99, 99]), - ], - ], + ] + ] ] diff --git a/tests/test_header_correct.py b/tests/test_header_correct.py index 4a8927fa80..aa0a4dde08 100644 --- a/tests/test_header_correct.py +++ b/tests/test_header_correct.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_highresnet.py b/tests/test_highresnet.py index 61af529b63..76c2203431 100644 --- a/tests/test_highresnet.py +++ b/tests/test_highresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 82454c34d0..10aa83293f 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,10 +28,7 @@ def create_expected_numpy_output(input_datum, **kwargs): - x = np.fft.fft( - input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), - **kwargs, - ) + x = np.fft.fft(input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), **kwargs) f = np.fft.fftfreq(x.shape[kwargs["axis"]]) u = np.heaviside(f, 0.5) new_dims_before = kwargs["axis"] diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py index b69fb1d927..95aa37f26e 100644 --- a/tests/test_histogram_normalize.py +++ b/tests/test_histogram_normalize.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,32 +15,42 @@ from parameterized import parameterized from monai.transforms import HistogramNormalize - -TEST_CASE_1 = [ - {"num_bins": 4, "min": 1, "max": 5, "mask": np.array([1, 1, 1, 1, 1, 0])}, - np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), - np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), -] - -TEST_CASE_2 = [ - {"num_bins": 4, "max": 4, "dtype": np.uint8}, - np.array([0.0, 1.0, 2.0, 3.0, 4.0]), - np.array([0, 0, 1, 3, 4]), -] - -TEST_CASE_3 = [ - {"num_bins": 256, "max": 255, "dtype": np.uint8}, - np.array([[[100.0, 200.0], [150.0, 250.0]]]), - np.array([[[0, 170], [70, 255]]]), -] +from monai.utils import get_equivalent_dtype +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"num_bins": 4, "min": 1, "max": 5, "mask": np.array([1, 1, 1, 1, 1, 0])}, + p(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])), + p(np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0])), + ] + ) + + TESTS.append( + [ + {"num_bins": 4, "max": 4, "dtype": np.uint8}, + p(np.array([0.0, 1.0, 2.0, 3.0, 4.0])), + p(np.array([0, 0, 1, 3, 4])), + ] + ) + + TESTS.append( + [ + {"num_bins": 256, "max": 255, "dtype": np.uint8}, + p(np.array([[[100.0, 200.0], [150.0, 250.0]]])), + p(np.array([[[0, 170], [70, 255]]])), + ] + ) class TestHistogramNormalize(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = HistogramNormalize(**argments)(image) - np.testing.assert_allclose(result, expected_data) - self.assertEqual(result.dtype, argments.get("dtype", np.float32)) + assert_allclose(result, expected_data) + self.assertEqual(get_equivalent_dtype(result.dtype, data_type=np.ndarray), argments.get("dtype", np.float32)) if __name__ == "__main__": diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py index 68647e82fb..7b86a9685f 100644 --- a/tests/test_histogram_normalized.py +++ b/tests/test_histogram_normalized.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,32 +15,42 @@ from parameterized import parameterized from monai.transforms import HistogramNormalized - -TEST_CASE_1 = [ - {"keys": "img", "num_bins": 4, "min": 1, "max": 5, "mask_key": "mask"}, - {"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), "mask": np.array([1, 1, 1, 1, 1, 0])}, - np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), -] - -TEST_CASE_2 = [ - {"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8}, - {"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0])}, - np.array([0, 0, 1, 3, 4]), -] - -TEST_CASE_3 = [ - {"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8}, - {"img": np.array([[[100.0, 200.0], [150.0, 250.0]]])}, - np.array([[[0, 170], [70, 255]]]), -] +from monai.utils import get_equivalent_dtype +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "img", "num_bins": 4, "min": 1, "max": 5, "mask_key": "mask"}, + {"img": p(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])), "mask": p(np.array([1, 1, 1, 1, 1, 0]))}, + p(np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0])), + ] + ) + + TESTS.append( + [ + {"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8}, + {"img": p(np.array([0.0, 1.0, 2.0, 3.0, 4.0]))}, + p(np.array([0, 0, 1, 3, 4])), + ] + ) + + TESTS.append( + [ + {"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8}, + {"img": p(np.array([[[100.0, 200.0], [150.0, 250.0]]]))}, + p(np.array([[[0, 170], [70, 255]]])), + ] + ) class TestHistogramNormalized(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = HistogramNormalized(**argments)(image)["img"] - np.testing.assert_allclose(result, expected_data) - self.assertEqual(result.dtype, argments.get("dtype", np.float32)) + assert_allclose(result, expected_data) + self.assertEqual(get_equivalent_dtype(result.dtype, data_type=np.ndarray), argments.get("dtype", np.float32)) if __name__ == "__main__": diff --git a/tests/test_identity.py b/tests/test_identity.py index 172860668c..60134c24a4 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_identityd.py b/tests/test_identityd.py index 665b7d5d1c..2df74ba2c6 100644 --- a/tests/test_identityd.py +++ b/tests/test_identityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index 3b3c06c87c..41eda803dc 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,15 @@ import numpy as np from monai.data import ImageDataset -from monai.transforms import Compose, EnsureChannelFirst, RandAdjustContrast, RandomizableTransform, Spacing +from monai.transforms import ( + Compose, + EnsureChannelFirst, + MapLabelValue, + RandAdjustContrast, + RandomizableTransform, + Spacing, +) +from monai.transforms.utility.array import ToNumpy FILENAMES = ["test1.nii.gz", "test2.nii", "test3.nii.gz"] @@ -106,16 +114,6 @@ def test_dataset(self): for d, ref in zip(dataset, ref_data): np.testing.assert_allclose(d, ref + 1, atol=1e-3) - # set seg transform, but no seg_files - with self.assertRaises(RuntimeError): - dataset = ImageDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) - _ = dataset[0] - - # set seg transform, but no seg_files - with self.assertRaises(RuntimeError): - dataset = ImageDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) - _ = dataset[0] - # loading image/label, with meta dataset = ImageDataset( full_names, @@ -133,13 +131,27 @@ def test_dataset(self): # loading image/label, with meta dataset = ImageDataset( - full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False + image_files=full_names, + seg_files=full_names, + labels=[1, 2, 3], + transform=lambda x: x + 1, + label_transform=Compose( + [ + ToNumpy(), + MapLabelValue(orig_labels=[1, 2, 3], target_labels=[30.0, 20.0, 10.0], dtype=np.float32), + ] + ), + image_only=False, ) for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)): img, seg, label, meta, seg_meta = d_tuple np.testing.assert_allclose(img, ref + 1, atol=1e-3) np.testing.assert_allclose(seg, ref, atol=1e-3) - np.testing.assert_allclose(idx + 1, label) + # test label_transform + + np.testing.assert_allclose((3 - idx) * 10.0, label) + self.assertTrue(isinstance(label, np.ndarray)) + self.assertEqual(label.dtype, np.float32) np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) np.testing.assert_allclose(seg_meta["original_affine"], np.eye(4), atol=1e-3) diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py new file mode 100644 index 0000000000..62b1147aa5 --- /dev/null +++ b/tests/test_image_rw.py @@ -0,0 +1,136 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import os +import shutil +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data.image_reader import ITKReader, NibabelReader, PILReader +from monai.data.image_writer import ITKWriter, NibabelWriter, PILWriter, register_writer, resolve_writer +from monai.transforms import LoadImage, SaveImage, moveaxis +from monai.utils import OptionalImportError +from tests.utils import TEST_NDARRAYS, assert_allclose + + +class TestLoadSaveNifti(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir, ignore_errors=True) + + def nifti_rw(self, test_data, reader, writer, dtype, resample=True): + test_data = test_data.astype(dtype) + ndim = len(test_data.shape) - 1 + for p in TEST_NDARRAYS: + output_ext = ".nii.gz" + filepath = f"testfile_{ndim}d" + saver = SaveImage( + output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer + ) + saver( + p(test_data), + { + "filename_or_obj": f"{filepath}.png", + "affine": np.eye(4), + "original_affine": np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]), + }, + ) + saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) + self.assertTrue(os.path.exists(saved_path)) + loader = LoadImage(reader=reader, squeeze_non_spatial_dims=True) + data, meta = loader(saved_path) + if meta["original_channel_dim"] == -1: + _test_data = moveaxis(test_data, 0, -1) + else: + _test_data = test_data[0] + if resample: + _test_data = moveaxis(_test_data, 0, 1) + assert_allclose(data, _test_data) + + @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, "ITKWriter"])) + def test_2d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(1, 6, 8) + self.nifti_rw(test_data, reader, writer, np.uint8) + self.nifti_rw(test_data, reader, writer, np.float32) + + @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter])) + def test_3d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(1, 2, 3, 8) + self.nifti_rw(test_data, reader, writer, int) + self.nifti_rw(test_data, reader, writer, int, False) + + @parameterized.expand(itertools.product([NibabelReader, ITKReader], ["NibabelWriter", ITKWriter])) + def test_4d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(2, 1, 3, 8) + self.nifti_rw(test_data, reader, writer, np.float16) + + +class TestLoadSavePNG(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir, ignore_errors=True) + + def png_rw(self, test_data, reader, writer, dtype, resample=True): + test_data = test_data.astype(dtype) + ndim = len(test_data.shape) - 1 + for p in TEST_NDARRAYS: + output_ext = ".png" + filepath = f"testfile_{ndim}d" + saver = SaveImage( + output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer + ) + saver(p(test_data), {"filename_or_obj": f"{filepath}.png", "spatial_shape": (6, 8)}) + saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) + self.assertTrue(os.path.exists(saved_path)) + loader = LoadImage(reader=reader) + data, meta = loader(saved_path) + if meta["original_channel_dim"] == -1: + _test_data = moveaxis(test_data, 0, -1) + else: + _test_data = test_data[0] + assert_allclose(data, _test_data) + + @parameterized.expand(itertools.product([PILReader, ITKReader], [PILWriter, ITKWriter])) + def test_2d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(1, 6, 8) + self.png_rw(test_data, reader, writer, np.uint8) + + @parameterized.expand(itertools.product([PILReader, ITKReader], ["monai.data.PILWriter", ITKWriter])) + def test_rgb(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(3, 2, 8) + self.png_rw(test_data, reader, writer, np.uint8, False) + + +class TestRegRes(unittest.TestCase): + def test_0_default(self): + self.assertTrue(len(resolve_writer(".png")) > 0, "has png writer") + self.assertTrue(len(resolve_writer(".nrrd")) > 0, "has nrrd writer") + self.assertTrue(len(resolve_writer("unknown")) > 0, "has writer") + register_writer("unknown1", lambda: (_ for _ in ()).throw(OptionalImportError)) + with self.assertRaises(OptionalImportError): + resolve_writer("unknown1") + + def test_1_new(self): + register_writer("new", lambda x: x + 1) + register_writer("new2", lambda x: x + 1) + self.assertEqual(resolve_writer("new")[0](0), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py index bd0369868e..58c4d3cfab 100644 --- a/tests/test_img2tensorboard.py +++ b/tests/test_img2tensorboard.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,29 +22,21 @@ class TestImg2Tensorboard(unittest.TestCase): def test_write_gray(self): nparr = np.ones(shape=(1, 32, 32, 32), dtype=np.float32) summary_object_np = make_animated_gif_summary( - tag="test_summary_nparr.png", - image=nparr, - max_out=1, - animation_axes=(3,), - image_axes=(1, 2), - scale_factor=253.0, + tag="test_summary_nparr.png", image=nparr, max_out=1, scale_factor=253.0 ) - assert isinstance( - summary_object_np, tensorboard.compat.proto.summary_pb2.Summary - ), "make_animated_gif_summary must return a tensorboard.summary object from numpy array" + for s in summary_object_np: + assert isinstance( + s, tensorboard.compat.proto.summary_pb2.Summary + ), "make_animated_gif_summary must return a tensorboard.summary object from numpy array" tensorarr = torch.tensor(nparr) summary_object_tensor = make_animated_gif_summary( - tag="test_summary_tensorarr.png", - image=tensorarr, - max_out=1, - animation_axes=(3,), - image_axes=(1, 2), - scale_factor=253.0, + tag="test_summary_tensorarr.png", image=tensorarr, max_out=1, frame_dim=-1, scale_factor=253.0 ) - assert isinstance( - summary_object_tensor, tensorboard.compat.proto.summary_pb2.Summary - ), "make_animated_gif_summary must return a tensorboard.summary object from tensor input" + for s in summary_object_tensor: + assert isinstance( + s, tensorboard.compat.proto.summary_pb2.Summary + ), "make_animated_gif_summary must return a tensorboard.summary object from tensor input" if __name__ == "__main__": diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index d6737c26ca..03a63cc375 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,6 +13,7 @@ from monai.data import ITKReader, NibabelReader, NumpyReader, PILReader from monai.transforms import LoadImage, LoadImaged +from tests.utils import SkipIfNoModule class TestInitLoadImage(unittest.TestCase): @@ -26,6 +27,9 @@ def test_load_image(self): inst = LoadImaged("image", reader=r) self.assertIsInstance(inst, LoadImaged) + @SkipIfNoModule("itk") + @SkipIfNoModule("nibabel") + @SkipIfNoModule("PIL") def test_readers(self): inst = ITKReader() self.assertIsInstance(inst, ITKReader) diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py new file mode 100644 index 0000000000..e6d4dfd89f --- /dev/null +++ b/tests/test_integration_bundle_run.py @@ -0,0 +1,109 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import subprocess +import sys +import tempfile +import unittest + +import nibabel as nib +import numpy as np +from parameterized import parameterized + +from monai.bundle import ConfigParser +from monai.transforms import LoadImage + +TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), (128, 128, 128)] + +TEST_CASE_2 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.yaml"), (128, 128, 128)] + + +class _Runnable42: + def __init__(self, val=1): + self.val = val + + def run(self): + assert self.val == 42 # defined in `TestBundleRun.test_tiny`` + return self.val + + +class TestBundleRun(unittest.TestCase): + def setUp(self): + self.data_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.data_dir) + + def test_tiny(self): + config_file = os.path.join(self.data_dir, "tiny_config.json") + with open(config_file, "w") as f: + json.dump({"": {"_target_": "tests.test_integration_bundle_run._Runnable42", "val": 42}}, f) + cmd = [sys.executable, "-m", "monai.bundle", "run", "--config_file", config_file] + ret = subprocess.check_call(cmd) + self.assertEqual(ret, 0) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_shape(self, config_file, expected_shape): + test_image = np.random.rand(*expected_shape) + tempdir = self.data_dir + filename = os.path.join(tempdir, "image.nii") + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename) + + # generate default args in a JSON file + def_args = {"config_file": "will be replaced by `config_file` arg"} + def_args_file = os.path.join(tempdir, "def_args.json") + ConfigParser.export_config_file(config=def_args, filepath=def_args_file) + + meta = {"datalist": [{"image": filename}], "output_dir": tempdir, "window": (96, 96, 96)} + # test YAML file + meta_file = os.path.join(tempdir, "meta.yaml") + ConfigParser.export_config_file(config=meta, filepath=meta_file, fmt="yaml") + + # test override with file, up case postfix + overridefile1 = os.path.join(tempdir, "override1.JSON") + with open(overridefile1, "w") as f: + # test override with part of the overriding file + json.dump({"move_net": "$@network_def.to(@device)"}, f) + os.makedirs(os.path.join(tempdir, "jsons"), exist_ok=True) + overridefile2 = os.path.join(tempdir, "jsons/override2.JSON") + with open(overridefile2, "w") as f: + # test override with the whole overriding file + json.dump("Dataset", f) + + saver = LoadImage(image_only=True) + + if sys.platform == "win32": + override = "--network $@network_def.to(@device) --dataset#_target_ Dataset" + else: + override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}" + # test with `monai.bundle` as CLI entry directly + cmd = f"-m monai.bundle run evaluator --postprocessing#transforms#2#output_postfix seg {override}" + la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] + test_env = os.environ.copy() + print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) + ret = subprocess.check_call(la + ["--args_file", def_args_file], env=test_env) + self.assertEqual(ret, 0) + self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape) + + # here test the script with `google fire` tool as CLI + cmd = "-m fire monai.bundle.scripts run --runner_id evaluator" + cmd += f" --evaluator#amp False {override}" + la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] + ret = subprocess.check_call(la, env=test_env) + self.assertEqual(ret, 0) + self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 03b5571973..5a742ce4f9 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,7 +12,6 @@ import os import unittest import warnings -from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -39,10 +38,8 @@ ) from monai.utils import set_determinism from tests.testing_data.integration_answers import test_integration_value -from tests.utils import DistTestCase, TimedCall, skip_if_quick +from tests.utils import DistTestCase, TimedCall, skip_if_downloading_fails, skip_if_quick, testing_data_config -TEST_DATA_URL = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE" -MD5_VALUE = "0bc7306e7427e00ad1c5526a6677552d" TASK = "integration_classification_2d" @@ -69,7 +66,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), - RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True), + RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True, dtype=np.float64), RandFlip(spatial_axis=0, prob=0.5), RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5), ToTensor(), @@ -80,7 +77,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", [LoadImage(image_only=True), AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), ToTensor()] ) y_pred_trans = Compose([ToTensor(), Activations(softmax=True)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=True, num_classes=len(np.unique(train_y)))]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=len(np.unique(train_y)))]) auc_metric = ROCAUCMetric() # create train, val data loaders @@ -186,18 +183,19 @@ def setUp(self): dataset_file = os.path.join(self.data_dir, "MedNIST.tar.gz") if not os.path.exists(data_dir): - try: - download_and_extract(TEST_DATA_URL, dataset_file, self.data_dir, MD5_VALUE) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors + with skip_if_downloading_fails(): + data_spec = testing_data_config("images", "mednist") + download_and_extract( + data_spec["url"], + dataset_file, + self.data_dir, + hash_val=data_spec["hash_val"], + hash_type=data_spec["hash_type"], + ) assert os.path.exists(data_dir) - class_names = sorted((x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))) + class_names = sorted(x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x))) image_files = [ [os.path.join(data_dir, class_name, x) for x in sorted(os.listdir(os.path.join(data_dir, class_name)))] for class_name in class_names diff --git a/tests/test_integration_determinism.py b/tests/test_integration_determinism.py index e077420420..64c018b4f5 100644 --- a/tests/test_integration_determinism.py +++ b/tests/test_integration_determinism.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -41,7 +41,7 @@ def __len__(self): return train_steps net = UNet( - dimensions=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 + spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 ).to(device) loss = DiceLoss(sigmoid=True) @@ -75,13 +75,13 @@ def setUp(self): def tearDown(self): set_determinism(seed=None) - @TimedCall(seconds=150) + @TimedCall(seconds=150, skip_timing=not torch.cuda.is_available()) def test_training(self): set_determinism(seed=0) loss, step = run_test(device=self.device) print(f"Deterministic loss {loss} at training step {step}") np.testing.assert_allclose(step, 4) - np.testing.assert_allclose(loss, 0.535927, rtol=1e-4) + np.testing.assert_allclose(loss, 0.536134, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_integration_fast_train.py b/tests/test_integration_fast_train.py new file mode 100644 index 0000000000..51b2ac1d3f --- /dev/null +++ b/tests/test_integration_fast_train.py @@ -0,0 +1,235 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import shutil +import tempfile +import time +import unittest +from glob import glob + +import nibabel as nib +import numpy as np +import torch + +import monai +from monai.data import CacheDataset, ThreadDataLoader, create_test_image_3d, decollate_batch +from monai.inferers import sliding_window_inference +from monai.losses import DiceCELoss +from monai.metrics import DiceMetric +from monai.networks.layers import Norm +from monai.networks.nets import UNet +from monai.optimizers import Novograd +from monai.transforms import ( + AsDiscrete, + Compose, + CropForegroundd, + EnsureChannelFirstd, + EnsureType, + EnsureTyped, + FgBgToIndicesd, + LoadImaged, + RandAffined, + RandAxisFlipd, + RandCropByPosNegLabeld, + RandFlipd, + RandGaussianNoised, + RandRotate90d, + RandRotated, + RandStdShiftIntensityd, + RandZoomd, + ScaleIntensityd, + Spacingd, + ToDeviced, +) +from monai.utils import set_determinism +from tests.utils import DistTestCase, SkipIfBeforePyTorchVersion, TimedCall, skip_if_no_cuda, skip_if_quick + + +@skip_if_no_cuda +@skip_if_quick +@SkipIfBeforePyTorchVersion((1, 7)) +class IntegrationFastTrain(DistTestCase): + def setUp(self): + set_determinism(seed=0) + monai.config.print_config() + + self.data_dir = tempfile.mkdtemp() + for i in range(41): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz")) + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(self.data_dir, f"seg{i:d}.nii.gz")) + + def tearDown(self): + set_determinism(seed=None) + shutil.rmtree(self.data_dir) + + # test the fast training speed is as expected + @TimedCall(seconds=100, daemon=False, force_quit=False) + def test_train_timing(self): + images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz"))) + segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz"))) + train_files = [{"image": img, "label": seg} for img, seg in zip(images[:32], segs[:32])] + val_files = [{"image": img, "label": seg} for img, seg in zip(images[-9:], segs[-9:])] + + device = torch.device("cuda:0") + # define transforms for train and validation + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")), + ScaleIntensityd(keys="image"), + CropForegroundd(keys=["image", "label"], source_key="image"), + # pre-compute foreground and background indexes + # and cache them to accelerate training + FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"), + # change to execute transforms with Tensor data + EnsureTyped(keys=["image", "label"]), + # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch + ToDeviced(keys=["image", "label"], device=device), + # randomly crop out patch samples from big + # image based on pos / neg ratio + # the image centers of negative samples + # must be in valid image area + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=(64, 64, 64), + pos=1, + neg=1, + num_samples=4, + fg_indices_key="label_fg", + bg_indices_key="label_bg", + ), + RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]), + RandAxisFlipd(keys=["image", "label"], prob=0.5), + RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(1, 2)), + RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True), + RandRotated( + keys=["image", "label"], + prob=0.5, + range_x=np.pi / 4, + mode=("bilinear", "nearest"), + align_corners=True, + dtype=np.float64, + ), + RandAffined(keys=["image", "label"], prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")), + RandGaussianNoised(keys="image", prob=0.5), + RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True), + ] + ) + + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")), + ScaleIntensityd(keys="image"), + CropForegroundd(keys=["image", "label"], source_key="image"), + EnsureTyped(keys=["image", "label"]), + # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch + ToDeviced(keys=["image", "label"], device=device), + ] + ) + + max_epochs = 5 + learning_rate = 2e-4 + val_interval = 1 # do validation for every epoch + + # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training + train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8) + val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5) + # disable multi-workers because `ThreadDataLoader` works with multi-threads + train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True) + val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1) + + loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True) + model = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + norm=Norm.BATCH, + ).to(device) + + # Novograd paper suggests to use a bigger LR than Adam, + # because Adam does normalization by element-wise second moments + optimizer = Novograd(model.parameters(), learning_rate * 10) + scaler = torch.cuda.amp.GradScaler() + + post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)]) + post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)]) + + dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) + + best_metric = -1 + total_start = time.time() + for epoch in range(max_epochs): + epoch_start = time.time() + print("-" * 10) + print(f"epoch {epoch + 1}/{max_epochs}") + model.train() + epoch_loss = 0 + step = 0 + for batch_data in train_loader: + step_start = time.time() + step += 1 + optimizer.zero_grad() + # set AMP for training + with torch.cuda.amp.autocast(): + outputs = model(batch_data["image"]) + loss = loss_function(outputs, batch_data["label"]) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + epoch_loss += loss.item() + epoch_len = math.ceil(len(train_ds) / train_loader.batch_size) + print( + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}" f" step time: {(time.time() - step_start):.4f}" + ) + epoch_loss /= step + print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") + + if (epoch + 1) % val_interval == 0: + model.eval() + with torch.no_grad(): + for val_data in val_loader: + roi_size = (96, 96, 96) + sw_batch_size = 4 + # set AMP for validation + with torch.cuda.amp.autocast(): + val_outputs = sliding_window_inference(val_data["image"], roi_size, sw_batch_size, model) + + val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)] + val_labels = [post_label(i) for i in decollate_batch(val_data["label"])] + dice_metric(y_pred=val_outputs, y=val_labels) + + metric = dice_metric.aggregate().item() + dice_metric.reset() + if metric > best_metric: + best_metric = metric + print(f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}") + print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}") + + total_time = time.time() - total_start + print(f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}") + # test expected metrics + self.assertGreater(best_metric, 0.95) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index d5eb69f7af..5c273d0a46 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,7 +21,7 @@ from torch.utils.tensorboard import SummaryWriter import monai -from monai.data import NiftiSaver, create_test_image_3d, decollate_batch +from monai.data import create_test_image_3d, decollate_batch from monai.inferers import sliding_window_inference from monai.metrics import DiceMetric from monai.networks import eval_mode @@ -34,12 +34,14 @@ LoadImaged, RandCropByPosNegLabeld, RandRotate90d, + SaveImage, ScaleIntensityd, Spacingd, ToTensor, ToTensord, ) from monai.utils import set_determinism +from monai.utils.enums import PostFix from monai.visualize import plot_2d_or_3d_image from tests.testing_data.integration_answers import test_integration_value from tests.utils import DistTestCase, TimedCall, skip_if_quick @@ -95,12 +97,12 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, # create a validation data loader val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) - val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) # create UNet, DiceLoss and Adam optimizer model = monai.networks.nets.UNet( - dimensions=3, + spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), @@ -195,11 +197,11 @@ def run_inference_test(root_dir, device="cuda:0"): val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) # sliding window inference need to input 1 image in every iteration val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) - val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) model = UNet( - dimensions=3, + spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), @@ -212,17 +214,25 @@ def run_inference_test(root_dir, device="cuda:0"): with eval_mode(model): # resampling with align_corners=True or dtype=float64 will generate # slight different results between PyTorch 1.5 an 1.6 - saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=np.float32) + saver = SaveImage( + output_dir=os.path.join(root_dir, "output"), + dtype=np.float32, + output_ext=".nii.gz", + output_postfix="seg", + mode="bilinear", + ) for val_data in val_loader: val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device) # define sliding window size and batch size for windows inference sw_batch_size, roi_size = 4, (96, 96, 96) val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - # decollate prediction into a list and execute post processing for every item + # decollate prediction into a list val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)] + val_meta = decollate_batch(val_data[PostFix.meta("img")]) # compute metrics dice_metric(y_pred=val_outputs, y=val_labels) - saver.save_batch(val_outputs, val_data["img_meta_dict"]) + for img, meta in zip(val_outputs, val_meta): # save a decollated batch of files + saver(img, meta) return dice_metric.aggregate().item() diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index b63f331ba6..af49e3db77 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,7 +34,7 @@ def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"): loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available()) net = UNet( - dimensions=3, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 + spatial_dims=3, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 ).to(device) roi_size = (16, 32, 48) sw_batch_size = batch_size diff --git a/tests/test_integration_stn.py b/tests/test_integration_stn.py index c1fcfe7a89..e655ff6755 100644 --- a/tests/test_integration_stn.py +++ b/tests/test_integration_stn.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function import unittest @@ -104,7 +103,7 @@ def setUp(self): def tearDown(self): set_determinism(seed=None) - @TimedCall(seconds=60) + @TimedCall(seconds=100, skip_timing=not torch.cuda.is_available()) def test_training(self): """ check that the quality AffineTransform backpropagation diff --git a/tests/test_integration_unet_2d.py b/tests/test_integration_unet_2d.py index a46a174dc9..e60c91968a 100644 --- a/tests/test_integration_unet_2d.py +++ b/tests/test_integration_unet_2d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,10 +32,10 @@ def __len__(self): return train_steps if net_name == "basicunet": - net = BasicUNet(dimensions=2, in_channels=1, out_channels=1, features=(4, 8, 8, 16, 16, 32)) + net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=1, features=(4, 8, 8, 16, 16, 32)) elif net_name == "unet": net = UNet( - dimensions=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 + spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 ) net.to(device) diff --git a/tests/test_integration_workers.py b/tests/test_integration_workers.py new file mode 100644 index 0000000000..21515d1f82 --- /dev/null +++ b/tests/test_integration_workers.py @@ -0,0 +1,57 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from monai.data import DataLoader +from monai.utils import set_determinism +from tests.utils import DistTestCase, SkipIfBeforePyTorchVersion, TimedCall, skip_if_no_cuda, skip_if_quick + + +def run_loading_test(num_workers=50, device="cuda:0" if torch.cuda.is_available() else "cpu", pw=False): + """multi workers stress tests""" + set_determinism(seed=0) + train_ds = list(range(10000)) + train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=num_workers, persistent_workers=pw) + answer = [] + for _ in range(2): + np.testing.assert_equal(torch.cuda.memory_allocated(), 0) + for batch_data in train_loader: + x = batch_data.to(device) + mem = torch.cuda.memory_allocated() + np.testing.assert_equal(mem > 0 and mem < 5000, True) + answer.append(x[-1].item()) + del x + return answer + + +@skip_if_quick +@skip_if_no_cuda +@SkipIfBeforePyTorchVersion((1, 9)) +class IntegrationLoading(DistTestCase): + def tearDown(self): + set_determinism(seed=None) + + @TimedCall(seconds=5000, skip_timing=not torch.cuda.is_available(), daemon=False) + def test_timing(self): + expected = None + for pw in (False, True): + result = run_loading_test(pw=pw) + if expected is None: + expected = result[0] + np.testing.assert_allclose(result[0], expected) # test for deterministic first epoch in two settings + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 7fcc0b4064..432e5e90a0 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,10 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import os import shutil -import sys import tempfile import unittest import warnings @@ -54,6 +52,7 @@ ToTensord, ) from monai.utils import set_determinism +from monai.utils.enums import PostFix from tests.testing_data.integration_answers import test_integration_value from tests.utils import DistTestCase, TimedCall, skip_if_quick @@ -98,7 +97,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): # create UNet, DiceLoss and Adam optimizer net = monai.networks.nets.UNet( - dimensions=3, + spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), @@ -114,7 +113,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), + AsDiscreted(keys="pred", threshold=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) @@ -127,8 +126,8 @@ def _forward_completed(self, engine): pass val_handlers = [ - StatsHandler(output_transform=lambda x: None), - TensorBoardStatsHandler(summary_writer=summary_writer, output_transform=lambda x: None), + StatsHandler(iteration_log=False), + TensorBoardStatsHandler(summary_writer=summary_writer, iteration_log=False), TensorBoardImageHandler( log_dir=root_dir, batch_transform=from_engine(["image", "label"]), output_transform=from_engine("pred") ), @@ -155,7 +154,7 @@ def _forward_completed(self, engine): [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), + AsDiscreted(keys="pred", threshold=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) @@ -230,7 +229,7 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor # create UNet, DiceLoss and Adam optimizer net = monai.networks.nets.UNet( - dimensions=3, + spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), @@ -242,24 +241,21 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), + AsDiscreted(keys="pred", threshold=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch` SaveImaged( - keys="pred", - meta_keys="image_meta_dict", - output_dir=root_dir, - output_postfix="seg_transform", + keys="pred", meta_keys=PostFix.meta("image"), output_dir=root_dir, output_postfix="seg_transform" ), ] ) val_handlers = [ - StatsHandler(output_transform=lambda x: None), + StatsHandler(iteration_log=False), CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}), SegmentationSaver( output_dir=root_dir, output_postfix="seg_handler", - batch_transform=from_engine("image_meta_dict"), + batch_transform=from_engine(PostFix.meta("image")), output_transform=from_engine("pred"), ), ] @@ -297,7 +293,6 @@ def setUp(self): self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") monai.config.print_config() - logging.basicConfig(stream=sys.stdout, level=logging.INFO) def tearDown(self): set_determinism(seed=None) @@ -351,20 +346,15 @@ def _test_saved_files(postfix): def test_training(self): repeated = [] - test_rounds = 3 if monai.utils.module.get_torch_version_tuple() >= (1, 6) else 2 + test_rounds = 3 for i in range(test_rounds): results = self.train_and_infer(idx=i) repeated.append(results) np.testing.assert_allclose(repeated[0], repeated[1]) - @TimedCall( - seconds=300, - skip_timing=not torch.cuda.is_available(), - daemon=False, - ) + @TimedCall(seconds=300, skip_timing=not torch.cuda.is_available(), daemon=False) def test_timing(self): - if monai.utils.module.get_torch_version_tuple() >= (1, 6): - self.train_and_infer(idx=2) + self.train_and_infer(idx=2) if __name__ == "__main__": diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py index c54e8b01f2..c9306b349f 100644 --- a/tests/test_integration_workflows_gan.py +++ b/tests/test_integration_workflows_gan.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,10 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import os import shutil -import sys import tempfile import unittest from glob import glob @@ -90,8 +88,7 @@ def generator_loss(gen_images): train_handlers = [ StatsHandler( - name="training_loss", - output_transform=lambda x: {Keys.GLOSS: x[Keys.GLOSS], Keys.DLOSS: x[Keys.DLOSS]}, + name="training_loss", output_transform=lambda x: {Keys.GLOSS: x[Keys.GLOSS], Keys.DLOSS: x[Keys.DLOSS]} ), TensorBoardStatsHandler( log_dir=root_dir, @@ -139,7 +136,6 @@ def setUp(self): self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") monai.config.print_config() - logging.basicConfig(stream=sys.stdout, level=logging.INFO) def tearDown(self): set_determinism(seed=None) diff --git a/tests/test_intensity_stats.py b/tests/test_intensity_stats.py index 059271e442..3479306180 100644 --- a/tests/test_intensity_stats.py +++ b/tests/test_intensity_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,45 +15,43 @@ from parameterized import parameterized from monai.transforms import IntensityStats +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"ops": ["max", "mean"], "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - {"affine": None}, - {"orig_max": 3.0, "orig_mean": 1.5}, -] - -TEST_CASE_2 = [ - {"ops": "std", "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - None, - {"orig_std": 1.118034}, -] - -TEST_CASE_3 = [ - {"ops": [lambda x: np.mean(x), "max", lambda x: np.min(x)], "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - None, - {"orig_custom_0": 1.5, "orig_max": 3.0, "orig_custom_1": 0.0}, -] - -TEST_CASE_4 = [ - {"ops": ["max", "mean"], "key_prefix": "orig", "channel_wise": True}, - np.array([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]), - {"affine": None}, - {"orig_max": [3.0, 7.0], "orig_mean": [1.5, 5.5]}, -] - -TEST_CASE_5 = [ - {"ops": ["max", "mean"], "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - {"affine": None}, - {"orig_max": 3.0, "orig_mean": 1.5}, -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( + [ + [ + {"ops": ["max", "mean"], "key_prefix": "orig"}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + {"affine": None}, + {"orig_max": 3.0, "orig_mean": 1.5}, + ], + [{"ops": "std", "key_prefix": "orig"}, p([[[0.0, 1.0], [2.0, 3.0]]]), None, {"orig_std": 1.118034}], + [ + {"ops": [np.mean, "max", np.min], "key_prefix": "orig"}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + None, + {"orig_custom_0": 1.5, "orig_max": 3.0, "orig_custom_1": 0.0}, + ], + [ + {"ops": ["max", "mean"], "key_prefix": "orig", "channel_wise": True}, + p([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]), + {"affine": None}, + {"orig_max": [3.0, 7.0], "orig_mean": [1.5, 5.5]}, + ], + [ + {"ops": ["max", "mean"], "key_prefix": "orig"}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + {"affine": None}, + {"orig_max": 3.0, "orig_mean": 1.5}, + ], + ] + ) class TestIntensityStats(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_param, img, meta_dict, expected): _, meta_dict = IntensityStats(**input_param)(img, meta_dict) for k, v in expected.items(): @@ -61,11 +59,12 @@ def test_value(self, input_param, img, meta_dict, expected): np.testing.assert_allclose(v, meta_dict[k], atol=1e-3) def test_mask(self): - img = np.array([[[0.0, 1.0], [2.0, 3.0]]]) - mask = np.array([[[1, 0], [1, 0]]], dtype=bool) - img, meta_dict = IntensityStats(ops=["max", "mean"], key_prefix="orig")(img, mask=mask) - np.testing.assert_allclose(meta_dict["orig_max"], 2.0, atol=1e-3) - np.testing.assert_allclose(meta_dict["orig_mean"], 1.0, atol=1e-3) + for p in TEST_NDARRAYS: + img = p([[[0.0, 1.0], [2.0, 3.0]]]) + mask = np.array([[[1, 0], [1, 0]]], dtype=bool) + img, meta_dict = IntensityStats(ops=["max", "mean"], key_prefix="orig")(img, mask=mask) + np.testing.assert_allclose(meta_dict["orig_max"], 2.0, atol=1e-3) + np.testing.assert_allclose(meta_dict["orig_mean"], 1.0, atol=1e-3) if __name__ == "__main__": diff --git a/tests/test_intensity_statsd.py b/tests/test_intensity_statsd.py index 8c8bc8795a..d6aeac61b0 100644 --- a/tests/test_intensity_statsd.py +++ b/tests/test_intensity_statsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,6 +18,7 @@ from monai.data import DataLoader, Dataset from monai.transforms import IntensityStatsd +from monai.utils.enums import PostFix TEST_CASE_1 = [ {"keys": "img", "ops": ["max", "mean"], "key_prefix": "orig", "meta_keys": "test_meta"}, @@ -29,14 +30,14 @@ TEST_CASE_2 = [ {"keys": "img", "ops": "std", "key_prefix": "orig"}, {"img": np.array([[[0.0, 1.0], [2.0, 3.0]]])}, - "img_meta_dict", + PostFix.meta("img"), {"orig_std": 1.118034}, ] TEST_CASE_3 = [ - {"keys": "img", "ops": [lambda x: np.mean(x), "max", lambda x: np.min(x)], "key_prefix": "orig"}, + {"keys": "img", "ops": [np.mean, "max", np.min], "key_prefix": "orig"}, {"img": np.array([[[0.0, 1.0], [2.0, 3.0]]])}, - "img_meta_dict", + PostFix.meta("img"), {"orig_custom_0": 1.5, "orig_max": 3.0, "orig_custom_1": 0.0}, ] @@ -68,7 +69,7 @@ def test_dataloader(self): mp.set_start_method("spawn", force=True) for d in dataloader: - meta = d["img_meta_dict"] + meta = d[PostFix.meta("img")] np.testing.assert_allclose(meta["orig_max"], [3.0, 3.0], atol=1e-3) np.testing.assert_allclose(meta["orig_mean"], [1.5, 1.5], atol=1e-3) # restore the mp method @@ -77,7 +78,7 @@ def test_dataloader(self): def test_mask(self): data = {"img": np.array([[[0.0, 1.0], [2.0, 3.0]]]), "img_mask": np.array([[[1, 0], [1, 0]]], dtype=bool)} stats = IntensityStatsd(keys="img", ops=["max", "mean"], mask_keys="img_mask", key_prefix="orig") - meta = stats(data)["img_meta_dict"] + meta = stats(data)[PostFix.meta("img")] np.testing.assert_allclose(meta["orig_max"], 2.0, atol=1e-3) np.testing.assert_allclose(meta["orig_mean"], 1.0, atol=1e-3) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index f2470d47fd..c04e9b0cd7 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -59,13 +59,13 @@ Spacingd, SpatialCropd, SpatialPadd, + TraceableTransform, Transposed, Zoomd, allow_missing_keys_mode, convert_inverse_interp_mode, ) from monai.utils import first, get_seed, optional_import, set_determinism -from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine if TYPE_CHECKING: @@ -122,23 +122,9 @@ ) ) -TESTS.append( - ( - "SpatialPadd 3d", - "3D", - 0, - SpatialPadd(KEYS, spatial_size=[112, 113, 116]), - ) -) +TESTS.append(("SpatialPadd 3d", "3D", 0, SpatialPadd(KEYS, spatial_size=[112, 113, 116]))) -TESTS.append( - ( - "SpatialCropd 2d", - "2D", - 0, - SpatialCropd(KEYS, [49, 51], [90, 89]), - ) -) +TESTS.append(("SpatialCropd 2d", "2D", 0, SpatialCropd(KEYS, [49, 51], [90, 89]))) TESTS.append( ( @@ -149,91 +135,28 @@ ) ) -TESTS.append( - ( - "SpatialCropd 2d", - "2D", - 0, - SpatialCropd(KEYS, [49, 51], [390, 89]), - ) -) +TESTS.append(("SpatialCropd 2d", "2D", 0, SpatialCropd(KEYS, [49, 51], [390, 89]))) -TESTS.append( - ( - "SpatialCropd 3d", - "3D", - 0, - SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), - ) -) +TESTS.append(("SpatialCropd 3d", "3D", 0, SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]))) TESTS.append(("RandSpatialCropd 2d", "2D", 0, RandSpatialCropd(KEYS, [96, 93], None, True, False))) TESTS.append(("RandSpatialCropd 3d", "3D", 0, RandSpatialCropd(KEYS, [96, 93, 92], None, False, False))) -TESTS.append( - ( - "BorderPadd 2d", - "2D", - 0, - BorderPadd(KEYS, [3, 7, 2, 5]), - ) -) +TESTS.append(("BorderPadd 2d", "2D", 0, BorderPadd(KEYS, [3, 7, 2, 5]))) -TESTS.append( - ( - "BorderPadd 2d", - "2D", - 0, - BorderPadd(KEYS, [3, 7]), - ) -) +TESTS.append(("BorderPadd 2d", "2D", 0, BorderPadd(KEYS, [3, 7]))) -TESTS.append( - ( - "BorderPadd 3d", - "3D", - 0, - BorderPadd(KEYS, [4]), - ) -) +TESTS.append(("BorderPadd 3d", "3D", 0, BorderPadd(KEYS, [4]))) -TESTS.append( - ( - "DivisiblePadd 2d", - "2D", - 0, - DivisiblePadd(KEYS, k=4), - ) -) +TESTS.append(("DivisiblePadd 2d", "2D", 0, DivisiblePadd(KEYS, k=4))) -TESTS.append( - ( - "DivisiblePadd 3d", - "3D", - 0, - DivisiblePadd(KEYS, k=[4, 8, 11]), - ) -) +TESTS.append(("DivisiblePadd 3d", "3D", 0, DivisiblePadd(KEYS, k=[4, 8, 11]))) -TESTS.append( - ( - "CenterSpatialCropd 2d", - "2D", - 0, - CenterSpatialCropd(KEYS, roi_size=95), - ) -) +TESTS.append(("CenterSpatialCropd 2d", "2D", 0, CenterSpatialCropd(KEYS, roi_size=95))) -TESTS.append( - ( - "CenterSpatialCropd 3d", - "3D", - 0, - CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), - ) -) +TESTS.append(("CenterSpatialCropd 3d", "3D", 0, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]))) TESTS.append(("CropForegroundd 2d", "2D", 0, CropForegroundd(KEYS, source_key="label", margin=2))) @@ -242,69 +165,20 @@ TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) -TESTS.append( - ( - "Flipd 3d", - "3D", - 0, - Flipd(KEYS, [1, 2]), - ) -) +TESTS.append(("Flipd 3d", "3D", 0, Flipd(KEYS, [1, 2]))) -TESTS.append( - ( - "RandFlipd 3d", - "3D", - 0, - RandFlipd(KEYS, 1, [1, 2]), - ) -) +TESTS.append(("RandFlipd 3d", "3D", 0, RandFlipd(KEYS, 1, [1, 2]))) -TESTS.append( - ( - "RandAxisFlipd 3d", - "3D", - 0, - RandAxisFlipd(KEYS, 1), - ) -) +TESTS.append(("RandAxisFlipd 3d", "3D", 0, RandAxisFlipd(KEYS, 1))) for acc in [True, False]: - TESTS.append( - ( - "Orientationd 3d", - "3D", - 0, - Orientationd(KEYS, "RAS", as_closest_canonical=acc), - ) - ) + TESTS.append(("Orientationd 3d", "3D", 0, Orientationd(KEYS, "RAS", as_closest_canonical=acc))) -TESTS.append( - ( - "Rotate90d 2d", - "2D", - 0, - Rotate90d(KEYS), - ) -) +TESTS.append(("Rotate90d 2d", "2D", 0, Rotate90d(KEYS))) -TESTS.append( - ( - "Rotate90d 3d", - "3D", - 0, - Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), - ) -) +TESTS.append(("Rotate90d 3d", "3D", 0, Rotate90d(KEYS, k=2, spatial_axes=(1, 2)))) -TESTS.append( - ( - "RandRotate90d 3d", - "3D", - 0, - RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), - ) -) +TESTS.append(("RandRotate90d 3d", "3D", 0, RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)))) TESTS.append(("Spacingd 3d", "3D", 3e-2, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) @@ -327,50 +201,22 @@ ) ) -TESTS.append( - ( - "Zoomd 1d", - "1D odd", - 0, - Zoomd(KEYS, zoom=2, keep_size=False), - ) -) +TESTS.append(("Zoomd 1d", "1D odd", 0, Zoomd(KEYS, zoom=2, keep_size=False))) -TESTS.append( - ( - "Zoomd 2d", - "2D", - 2e-1, - Zoomd(KEYS, zoom=0.9), - ) -) +TESTS.append(("Zoomd 2d", "2D", 2e-1, Zoomd(KEYS, zoom=0.9))) -TESTS.append( - ( - "Zoomd 3d", - "3D", - 3e-2, - Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), - ) -) +TESTS.append(("Zoomd 3d", "3D", 3e-2, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False))) TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) -TESTS.append( - ( - "RandRotated, prob 0", - "2D", - 0, - RandRotated(KEYS, prob=0), - ) -) +TESTS.append(("RandRotated, prob 0", "2D", 0, RandRotated(KEYS, prob=0, dtype=np.float64))) TESTS.append( ( "Rotated 2d", "2D", 8e-2, - Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), + Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False, dtype=np.float64), ) ) @@ -379,7 +225,7 @@ "Rotated 3d", "3D", 1e-1, - Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True), # type: ignore + Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True, dtype=np.float64), ) ) @@ -388,27 +234,13 @@ "RandRotated 3d", "3D", 1e-1, - RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore + RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1, dtype=np.float64), # type: ignore ) ) -TESTS.append( - ( - "Transposed 2d", - "2D", - 0, - Transposed(KEYS, [0, 2, 1]), # channel=0 - ) -) +TESTS.append(("Transposed 2d", "2D", 0, Transposed(KEYS, [0, 2, 1]))) # channel=0 -TESTS.append( - ( - "Transposed 3d", - "3D", - 0, - Transposed(KEYS, [0, 3, 1, 2]), # channel=0 - ) -) +TESTS.append(("Transposed 3d", "3D", 0, Transposed(KEYS, [0, 3, 1, 2]))) # channel=0 TESTS.append( ( @@ -444,14 +276,7 @@ ) ) -TESTS.append( - ( - "RandAffine 3d", - "3D", - 0, - RandAffined(KEYS, spatial_size=None, prob=0), - ) -) +TESTS.append(("RandAffine 3d", "3D", 0, RandAffined(KEYS, spatial_size=None, prob=0))) TESTS.append( ( @@ -462,32 +287,11 @@ ) ) -TESTS.append( - ( - "RandCropByPosNegLabeld 2d", - "2D", - 1e-7, - RandCropByPosNegLabeld(KEYS, "label", (99, 96), num_samples=10), - ) -) +TESTS.append(("RandCropByPosNegLabeld 2d", "2D", 1e-7, RandCropByPosNegLabeld(KEYS, "label", (99, 96), num_samples=10))) -TESTS.append( - ( - "RandSpatialCropSamplesd 2d", - "2D", - 1e-7, - RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10), - ) -) +TESTS.append(("RandSpatialCropSamplesd 2d", "2D", 1e-7, RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10))) -TESTS.append( - ( - "RandWeightedCropd 2d", - "2D", - 1e-7, - RandWeightedCropd(KEYS, "label", (90, 91), num_samples=10), - ) -) +TESTS.append(("RandWeightedCropd 2d", "2D", 1e-7, RandWeightedCropd(KEYS, "label", (90, 91), num_samples=10))) TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] @@ -566,8 +370,8 @@ def setUp(self): "other": np.array(im_1d, copy=True), } - im_2d_fname, seg_2d_fname = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] - im_3d_fname, seg_3d_fname = [make_nifti_image(i, affine) for i in create_test_image_3d(100, 101, 107)] + im_2d_fname, seg_2d_fname = (make_nifti_image(i) for i in create_test_image_2d(101, 100)) + im_3d_fname, seg_3d_fname = (make_nifti_image(i, affine) for i in create_test_image_3d(100, 101, 107)) load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) self.all_data["2D"] = load_ims({"image": im_2d_fname, "label": seg_2d_fname}) @@ -651,27 +455,15 @@ def test_inverse_inferred_seg(self, extra_transform): batch_size = 10 # num workers = 0 for mac - num_workers = 2 if sys.platform != "darwin" else 0 - transforms = Compose( - [ - AddChanneld(KEYS), - SpatialPadd(KEYS, (150, 153)), - extra_transform, - ] - ) + num_workers = 2 if sys.platform == "linux" else 0 + transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform]) num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform)) dataset = CacheDataset(test_data, transform=transforms, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) device = "cuda" if torch.cuda.is_available() else "cpu" - model = UNet( - dimensions=2, - in_channels=1, - out_channels=1, - channels=(2, 4), - strides=(2,), - ).to(device) + model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device) data = first(loader) self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) @@ -679,7 +471,7 @@ def test_inverse_inferred_seg(self, extra_transform): labels = data["label"].to(device) segs = model(labels).detach().cpu() - label_transform_key = "label" + InverseKeys.KEY_SUFFIX + label_transform_key = TraceableTransform.trace_key("label") segs_dict = {"label": segs, label_transform_key: data[label_transform_key]} segs_dict_decollated = decollate_batch(segs_dict) diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index c302e04017..4e8c6b58cc 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -48,15 +48,11 @@ for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]), RandAxisFlipd(keys=KEYS, prob=0.5), - RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), + Compose([RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), ToTensord(keys=KEYS)]), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), - RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), + RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64), RandAffined( - keys=KEYS, - prob=0.5, - rotate_range=np.pi, - device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, + keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ), ] ] @@ -67,15 +63,11 @@ for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]), RandAxisFlipd(keys=KEYS, prob=0.5), - RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), + Compose([RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), ToTensord(keys=KEYS)]), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), - RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), + RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64), RandAffined( - keys=KEYS, - prob=0.5, - rotate_range=np.pi, - device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, + keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ), ] ] @@ -91,12 +83,12 @@ def setUp(self): set_determinism(seed=0) b_size = 11 - im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)] + im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)) load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) self.data_3d = [load_ims({"image": im_fname, "label": seg_fname}) for _ in range(b_size)] b_size = 8 - im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_2d(62, 37, rad_max=10)] + im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_2d(62, 37, rad_max=10)) load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) self.data_2d = [load_ims({"image": im_fname, "label": seg_fname}) for _ in range(b_size)] @@ -107,17 +99,14 @@ def tearDown(self): @parameterized.expand(TESTS_2D + TESTS_3D) def test_collation(self, _, transform, collate_fn, ndim): - if ndim == 3: - data = self.data_3d - else: - data = self.data_2d + data = self.data_3d if ndim == 3 else self.data_2d if collate_fn: modified_transform = transform else: modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS)]) # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=modified_transform, progress=False) loader = DataLoader(dataset, num_workers, batch_size=self.batch_size, collate_fn=collate_fn) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index 5b98653f0a..64c26c4012 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,8 +34,10 @@ ResizeWithPadOrCropd, ScaleIntensityd, Spacingd, + ToTensord, ) -from monai.utils.misc import set_determinism +from monai.utils import set_determinism +from monai.utils.enums import PostFix from tests.utils import make_nifti_image KEYS = ["image", "label"] @@ -44,7 +46,7 @@ class TestInvertd(unittest.TestCase): def test_invert(self): set_determinism(seed=0) - im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)] + im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)) transform = Compose( [ LoadImaged(KEYS), @@ -56,54 +58,113 @@ def test_invert(self): RandAxisFlipd(KEYS, prob=0.5), RandRotate90d(KEYS, spatial_axes=(1, 2)), RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), - RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True), + RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64), RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"), ResizeWithPadOrCropd(KEYS, 100), # test EnsureTensor for complicated dict data and invert it - CopyItemsd("image_meta_dict", times=1, names="test_dict"), + CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"), # test to support Tensor, Numpy array and dictionary when inverting EnsureTyped(keys=["image", "test_dict"]), + ToTensord("image"), CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]), - CopyItemsd("label", times=1, names="label_inverted"), + CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]), + CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]), ] ) data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=transform, progress=False) loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) inverter = Invertd( # `image` was not copied, invert the original value directly - keys=["image", "label_inverted", "test_dict"], + keys=["image_inverted", "label_inverted", "test_dict"], transform=transform, orig_keys=["label", "label", "test_dict"], - meta_keys=["image_meta_dict", "label_inverted_meta_dict", None], - orig_meta_keys=["label_meta_dict", "label_meta_dict", None], + meta_keys=[PostFix.meta("image_inverted"), PostFix.meta("label_inverted"), None], + orig_meta_keys=[PostFix.meta("label"), PostFix.meta("label"), None], nearest_interp=True, to_tensor=[True, False, False], device="cpu", ) + inverter_1 = Invertd( + # `image` was not copied, invert the original value directly + keys=["image_inverted1", "label_inverted1"], + transform=transform, + orig_keys=["image", "image"], + meta_keys=[PostFix.meta("image_inverted1"), PostFix.meta("label_inverted1")], + orig_meta_keys=[PostFix.meta("image"), PostFix.meta("image")], + nearest_interp=[True, False], + to_tensor=[True, True], + device="cpu", + ) + + expected_keys = [ + "image", + "image_inverted", + "image_inverted1", + PostFix.meta("image_inverted1"), + PostFix.meta("image_inverted"), + PostFix.meta("image"), + "image_transforms", + "label", + "label_inverted", + "label_inverted1", + PostFix.meta("label_inverted1"), + PostFix.meta("label_inverted"), + PostFix.meta("label"), + "label_transforms", + "test_dict", + "test_dict_transforms", + ] # execute 1 epoch for d in loader: d = decollate_batch(d) for item in d: item = inverter(item) - # this unit test only covers basic function, test_handler_transform_inverter covers more + item = inverter_1(item) + + self.assertListEqual(sorted(item), expected_keys) + self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100)) self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100)) - # check the nearest inerpolation mode - i = item["image"] + # check the nearest interpolation mode + i = item["image_inverted"] torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float)) self.assertTupleEqual(i.shape[1:], (100, 101, 107)) i = item["label_inverted"] - np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32)) + torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float)) self.assertTupleEqual(i.shape[1:], (100, 101, 107)) # test inverted test_dict self.assertTrue(isinstance(item["test_dict"]["affine"], np.ndarray)) self.assertTrue(isinstance(item["test_dict"]["filename_or_obj"], str)) + # check the case that different items use different interpolation mode to invert transforms + d = item["image_inverted1"] + # if the interpolation mode is nearest, accumulated diff should be smaller than 1 + self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0) + self.assertTupleEqual(d.shape, (1, 100, 101, 107)) + + d = item["label_inverted1"] + # if the interpolation mode is not nearest, accumulated diff should be greater than 10000 + self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0) + self.assertTupleEqual(d.shape, (1, 100, 101, 107)) + + # check labels match + reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32) + original = LoadImaged(KEYS)(data[-1])["label"] + n_good = np.sum(np.isclose(reverted, original, atol=1e-3)) + reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"] + original_name = data[-1]["label"] + self.assertEqual(reverted_name, original_name) + print("invert diff", reverted.size - n_good) + # 25300: 2 workers (cpu, non-macos) + # 1812: 0 workers (gpu or macos) + # 1821: windows torch 1.10.0 + self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821), f"diff. {reverted.size - n_good}") + set_determinism(seed=None) diff --git a/tests/test_is_supported_format.py b/tests/test_is_supported_format.py index c0af8f4395..71a44bd190 100644 --- a/tests/test_is_supported_format.py +++ b/tests/test_is_supported_format.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,35 +15,17 @@ from monai.data import is_supported_format -TEST_CASE_1 = [ - {"filename": "testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, - True, -] - -TEST_CASE_2 = [ - {"filename": "./testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, - True, -] - -TEST_CASE_3 = [ - {"filename": "./test.data/file.nii.gz", "suffixes": ["nii", "nii.gz"]}, - True, -] - -TEST_CASE_4 = [ - {"filename": "./test.data/file.nii", "suffixes": ["nii", "nii.gz"]}, - True, -] - -TEST_CASE_5 = [ - {"filename": "C:\\documents\\testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, - True, -] - -TEST_CASE_6 = [ - {"filename": "1.3.12.2.1107.5.4.4.145.nii.gz", "suffixes": ["nii.gz"]}, - True, -] +TEST_CASE_1 = [{"filename": "testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, True] + +TEST_CASE_2 = [{"filename": "./testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, True] + +TEST_CASE_3 = [{"filename": "./test.data/file.nii.gz", "suffixes": ["nii", "nii.gz"]}, True] + +TEST_CASE_4 = [{"filename": "./test.data/file.nii", "suffixes": ["nii", "nii.gz"]}, True] + +TEST_CASE_5 = [{"filename": "C:\\documents\\testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, True] + +TEST_CASE_6 = [{"filename": "1.3.12.2.1107.5.4.4.145.nii.gz", "suffixes": ["nii.gz"]}, True] class TestIsSupportedFormat(unittest.TestCase): diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 7b16eaf594..2c47a2181e 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. import os +import sys import tempfile import unittest @@ -36,14 +37,9 @@ def test_shape(self): with tempfile.TemporaryDirectory() as tempdir: for i in range(6): nib.save(test_image, os.path.join(tempdir, f"test_image{str(i)}.nii.gz")) - test_data.append({"image": os.path.join(tempdir, f"test_image{str(i)}.nii.gz")}) + test_data.append({"image": os.path.join(tempdir, f"test_image{i}.nii.gz")}) - test_transform = Compose( - [ - LoadImaged(keys="image"), - SimulateDelayd(keys="image", delay_time=1e-7), - ] - ) + test_transform = Compose([LoadImaged(keys="image"), SimulateDelayd(keys="image", delay_time=1e-7)]) data_iterator = _Stream(test_data) with self.assertRaises(TypeError): # Dataset doesn't work @@ -54,7 +50,8 @@ def test_shape(self): for d in dataset: self.assertTupleEqual(d["image"].shape, expected_shape) - dataloader = DataLoader(dataset=dataset, batch_size=3, num_workers=2) + num_workers = 2 if sys.platform == "linux" else 0 + dataloader = DataLoader(dataset=dataset, batch_size=3, num_workers=num_workers) for d in dataloader: self.assertTupleEqual(d["image"].shape[1:], expected_shape) diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py new file mode 100644 index 0000000000..163fead76e --- /dev/null +++ b/tests/test_itk_writer.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch + +from monai.data import ITKWriter +from monai.utils import optional_import + +itk, has_itk = optional_import("itk") +nib, has_nibabel = optional_import("nibabel") + + +@unittest.skipUnless(has_itk, "Requires `itk` package.") +class TestITKWriter(unittest.TestCase): + def test_channel_shape(self): + with tempfile.TemporaryDirectory() as tempdir: + for c in (0, 1, 2, 3): + fname = os.path.join(tempdir, f"testing{c}.nii") + itk_writer = ITKWriter() + itk_writer.set_data_array(torch.zeros(1, 2, 3, 4), channel_dim=c, squeeze_end_dims=False) + itk_writer.set_metadata({}) + itk_writer.write(fname) + itk_obj = itk.imread(fname) + s = [1, 2, 3, 4] + s.pop(c) + np.testing.assert_allclose(itk.size(itk_obj), s) + + def test_rgb(self): + with tempfile.TemporaryDirectory() as tempdir: + fname = os.path.join(tempdir, "testing.png") + writer = ITKWriter(output_dtype=np.uint8) + writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=0) + writer.set_metadata({"spatial_shape": (5, 5)}) + writer.write(fname) + + output = np.asarray(itk.imread(fname)) + np.testing.assert_allclose(output.shape, (5, 5, 3)) + np.testing.assert_allclose(output[1, 1], (5, 5, 4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_k_space_spike_noise.py b/tests/test_k_space_spike_noise.py index bb6d05e676..85e8fec3c3 100644 --- a/tests/test_k_space_spike_noise.py +++ b/tests/test_k_space_spike_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,17 +20,15 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + for intensity in [10, None]: + TESTS.append((shape, p, intensity)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -40,38 +38,47 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d - im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + im, _ = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) + return im_type(im[None]) - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type, k_intensity): - im = self.get_data(im_shape, as_tensor_input) + im = self.get_data(im_shape, im_type) loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0] - k_intensity = 10 - t = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoise(loc, k_intensity) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(im.device, out1.device) + out1 = out1.cpu() + out2 = out2.cpu() + np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_highlighted_kspace_pixel(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_highlighted_kspace_pixel(self, im_shape, as_tensor_input, k_intensity): im = self.get_data(im_shape, as_tensor_input) loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0] - k_intensity = 10 - t = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoise(loc, k_intensity) out = t(im) - n_dims = len(im_shape) - out_k = fftshift(fftn(out, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) - log_mag = np.log(np.absolute(out_k)) - np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-4) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(im.device, out.device) + out = out.cpu() + + if k_intensity is not None: + n_dims = len(im_shape) + out_k = fftshift(fftn(out, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) + log_mag = np.log(np.absolute(out_k)) + np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-4) if __name__ == "__main__": diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py index 616662b3cd..0230f40b15 100644 --- a/tests/test_k_space_spike_noised.py +++ b/tests/test_k_space_spike_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,19 +20,16 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) KEYS = ["image", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -42,55 +39,69 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [im[None] for im in ims] - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) + ims = [im_type(im[None]) for im in ims] + return {k: v for k, v in zip(KEYS, ims)} - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_highlighted_kspace_pixel(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_highlighted_kspace_pixel(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out = t(data) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + n_dims = len(im_shape) out_k = fftshift(fftn(out[k], axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) log_mag = np.log(np.absolute(out_k)) np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-1) - @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_dict_matches(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out = t(deepcopy(data)) + for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index 527d986614..5c96b62131 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,336 +10,375 @@ # limitations under the License. import unittest +from copy import deepcopy import torch +import torch.nn.functional as F from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponent -from tests.utils import assert_allclose, clone +from monai.transforms.utils_pytorch_numpy_unification import moveaxis +from monai.utils.type_conversion import convert_to_dst_type +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose -grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]) -grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]) -grid_3 = torch.tensor( + +def to_onehot(x): + out = moveaxis(F.one_hot(torch.as_tensor(x).long())[0], -1, 0) + out, *_ = convert_to_dst_type(out, x) + return out + + +grid_1 = [[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]] +grid_2 = [[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]] +grid_3 = [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], ], -) -grid_4 = torch.tensor( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 1.0, 0.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], ], -) - - -TEST_CASE_1 = [ - "value_1", - {"independent": False, "applied_labels": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), -] - -TEST_CASE_2 = [ - "value_2", - {"independent": False, "applied_labels": [2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), ] - -TEST_CASE_3 = [ - "independent_value_1_2", - {"independent": True, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), +grid_4 = [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], ] +grid_5 = [[[0, 0, 1, 0, 0], [0, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 1, 0], [1, 1, 0, 0, 1]]] -TEST_CASE_4 = [ - "dependent_value_1_2", - {"independent": False, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + "value_1", + {"independent": False, "applied_labels": 1, "is_onehot": False}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_5 = [ - "value_1", - {"independent": True, "applied_labels": [1]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + TESTS.append( + [ + "value_2", + {"independent": False, "applied_labels": [2], "is_onehot": False}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_6 = [ - "independent_value_1_2", - {"independent": True, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + TESTS.append( + [ + "independent_value_1_2", + {"independent": True, "applied_labels": [1, 2], "is_onehot": False}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_7 = [ - "dependent_value_1_2", - {"independent": False, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), -] + TESTS.append( + [ + "dependent_value_1_2", + {"independent": False, "applied_labels": [1, 2], "is_onehot": False}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_8 = [ - "value_1_connect_1", - {"independent": False, "applied_labels": [1], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), -] + TESTS.append( + [ + "value_1", + {"independent": True, "applied_labels": [1], "is_onehot": False}, + p(grid_2), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_9 = [ - "independent_value_1_2_connect_1", - {"independent": True, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + TESTS.append( + [ + "independent_value_1_2", + {"independent": True, "applied_labels": [1, 2], "is_onehot": False}, + p(grid_2), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_10 = [ - "dependent_value_1_2_connect_1", - {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + TESTS.append( + [ + "dependent_value_1_2", + {"independent": False, "applied_labels": [1, 2], "is_onehot": False}, + p(grid_2), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), + ] + ) -TEST_CASE_11 = [ - "onehot_independent_batch_2_apply_label_1_connect_1", - {"independent": True, "applied_labels": [1], "connectivity": 1}, - grid_3, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ], - ), -] + "value_1_connect_1", + {"independent": False, "applied_labels": [1], "connectivity": 1, "is_onehot": False}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_12 = [ - "onehot_independent_batch_2_apply_label_1_connect_2", - {"independent": True, "applied_labels": [1], "connectivity": 2}, - grid_3, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ], - ), -] + "independent_value_1_2_connect_1", + {"independent": True, "applied_labels": [1, 2], "connectivity": 1, "is_onehot": False}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_13 = [ - "onehot_independent_batch_2_apply_label_1_2_connect_2", - {"independent": True, "applied_labels": [1, 2], "connectivity": 2}, - grid_3, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "onehot_none_dependent_value_1_2_connect_1", + {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_14 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_2", - {"independent": False, "applied_labels": [1, 2], "connectivity": 2}, - grid_4, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "onehot_independent_batch_2_apply_label_1_connect_1", + {"independent": True, "applied_labels": [1], "connectivity": 1, "is_onehot": True}, + p(grid_3), + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), + ] + ) -TEST_CASE_15 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_1", - {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_4, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "onehot_independent_batch_2_apply_label_1_connect_2", + {"independent": True, "applied_labels": [1], "connectivity": 2, "is_onehot": True}, + p(grid_3), + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), + ] + ) -VALID_CASES = [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - TEST_CASE_9, - TEST_CASE_10, - TEST_CASE_11, - TEST_CASE_12, - TEST_CASE_13, - TEST_CASE_14, - TEST_CASE_15, -] + TESTS.append( + [ + "onehot_independent_batch_2_apply_label_1_2_connect_2", + {"independent": True, "applied_labels": [1, 2], "connectivity": 2, "is_onehot": True}, + p(grid_3), + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_1 = ["no_applied_labels_for_single_channel", {"independent": False}, grid_1, TypeError] + TESTS.append( + [ + "onehot_dependent_batch_2_apply_label_1_2_connect_2", + {"independent": False, "applied_labels": [1, 2], "connectivity": 2, "is_onehot": True}, + p(grid_4), + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_2 = ["no_applied_labels_for_multi_channel", {"independent": False}, grid_3, TypeError] + TESTS.append( + [ + "onehot_none_dependent_batch_2_apply_label_1_2_connect_1", + {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, + p(grid_4), + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -INVALID_CASES = [ITEST_CASE_1, ITEST_CASE_2] + TESTS.append( + [ + "all_non_zero_labels", + {"independent": True}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) class TestKeepLargestConnectedComponent(unittest.TestCase): - @parameterized.expand(VALID_CASES) + @parameterized.expand(TESTS) def test_correct_results(self, _, args, input_image, expected): converter = KeepLargestConnectedComponent(**args) - if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - result = converter(clone(input_image).cuda()) + result = converter(input_image) + assert_allclose(result, expected, type_test=False) - else: - result = converter(clone(input_image)) - assert_allclose(result, expected) + @parameterized.expand(TESTS) + @SkipIfBeforePyTorchVersion((1, 7)) + def test_correct_results_before_after_onehot(self, _, args, input_image, expected): + """ + From torch==1.7, torch.argmax changes its mechanism that if there are multiple maximal values then the + indices of the first maximal value are returned (before this version, the indices of the last maximal value + are returned). + Therefore, we can may use of this changes to convert the onehotted labels into un-onehot format directly + and then check if the result stays the same. - @parameterized.expand(INVALID_CASES) - def test_raise_exception(self, _, args, input_image, expected_error): - with self.assertRaises(expected_error): - converter = KeepLargestConnectedComponent(**args) - if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - _ = converter(clone(input_image).cuda()) - else: - _ = converter(clone(input_image).clone()) + """ + converter = KeepLargestConnectedComponent(**args) + result = converter(deepcopy(input_image)) + + if "is_onehot" in args: + args["is_onehot"] = not args["is_onehot"] + # if not onehotted, onehot it and make sure result stays the same + if input_image.shape[0] == 1: + img = to_onehot(input_image) + result2 = KeepLargestConnectedComponent(**args)(img) + result2 = result2.argmax(0)[None] + assert_allclose(result, result2) + # if onehotted, un-onehot and check result stays the same + else: + img = input_image.argmax(0)[None] + result2 = KeepLargestConnectedComponent(**args)(img) + assert_allclose(result.argmax(0)[None], result2) if __name__ == "__main__": diff --git a/tests/test_keep_largest_connected_componentd.py b/tests/test_keep_largest_connected_componentd.py index 9478cfb965..a06fb51a97 100644 --- a/tests/test_keep_largest_connected_componentd.py +++ b/tests/test_keep_largest_connected_componentd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,333 +15,331 @@ from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponentd +from tests.utils import TEST_NDARRAYS, assert_allclose -grid_1 = {"img": torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]])} -grid_2 = {"img": torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]])} -grid_3 = { - "img": torch.tensor( - [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ], - ) -} -grid_4 = { - "img": torch.tensor( - [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 1.0, 0.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ], - ) -} - -TEST_CASE_1 = [ - "value_1", - {"keys": ["img"], "independent": False, "applied_labels": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), +grid_1 = [[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]] +grid_2 = [[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]] +grid_3 = [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], ] - -TEST_CASE_2 = [ - "value_2", - {"keys": ["img"], "independent": False, "applied_labels": [2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), +grid_4 = [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], ] +grid_5 = [[[0, 0, 1, 0, 0], [0, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 1, 0], [1, 1, 0, 0, 1]]] -TEST_CASE_3 = [ - "independent_value_1_2", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), -] - -TEST_CASE_4 = [ - "dependent_value_1_2", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), -] +VALID_CASES = [] +for p in TEST_NDARRAYS: + VALID_CASES.append( + [ + "value_1", + {"keys": ["img"], "independent": False, "applied_labels": 1, "is_onehot": False}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_5 = [ - "value_1", - {"keys": ["img"], "independent": True, "applied_labels": [1]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "value_2", + {"keys": ["img"], "independent": False, "applied_labels": [2], "is_onehot": False}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_6 = [ - "independent_value_1_2", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "independent_value_1_2", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "is_onehot": False}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_7 = [ - "dependent_value_1_2", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), -] + VALID_CASES.append( + [ + "dependent_value_1_2", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "is_onehot": False}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_8 = [ - "value_1_connect_1", - {"keys": ["img"], "independent": False, "applied_labels": [1], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), -] + VALID_CASES.append( + [ + "value_1", + {"keys": ["img"], "independent": True, "applied_labels": [1], "is_onehot": False}, + {"img": p(grid_2)}, + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_9 = [ - "independent_value_1_2_connect_1", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "independent_value_1_2", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "is_onehot": False}, + {"img": p(grid_2)}, + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_10 = [ - "dependent_value_1_2_connect_1", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "dependent_value_1_2", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "is_onehot": False}, + {"img": p(grid_2)}, + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), + ] + ) -TEST_CASE_11 = [ - "onehot_independent_batch_2_apply_label_1_connect_1", - {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 1}, - grid_3, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ], - ), -] + "value_1_connect_1", + {"keys": ["img"], "independent": False, "applied_labels": [1], "connectivity": 1, "is_onehot": False}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_12 = [ - "onehot_independent_batch_2_apply_label_1_connect_2", - {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 2}, - grid_3, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ], - ), -] + "independent_value_1_2_connect_1", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 1, "is_onehot": False}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_13 = [ - "onehot_independent_batch_2_apply_label_1_2_connect_2", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 2}, - grid_3, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "onehot_none_dependent_value_1_2_connect_1", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_14 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_2", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 2}, - grid_4, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "onehot_independent_batch_2_apply_label_1_connect_1", + {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 1, "is_onehot": True}, + {"img": p(grid_3)}, + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), + ] + ) -TEST_CASE_15 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_1", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_4, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "onehot_independent_batch_2_apply_label_1_connect_2", + {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 2, "is_onehot": True}, + {"img": p(grid_3)}, + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), + ] + ) -VALID_CASES = [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - TEST_CASE_9, - TEST_CASE_10, - TEST_CASE_11, - TEST_CASE_12, - TEST_CASE_13, - TEST_CASE_14, - TEST_CASE_15, -] + VALID_CASES.append( + [ + "onehot_independent_batch_2_apply_label_1_2_connect_2", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 2, "is_onehot": True}, + {"img": p(grid_3)}, + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_1 = ["no_applied_labels_for_single_channel", {"keys": ["img"], "independent": False}, grid_1, TypeError] + VALID_CASES.append( + [ + "onehot_dependent_batch_2_apply_label_1_2_connect_2", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 2, "is_onehot": True}, + {"img": p(grid_4)}, + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_2 = ["no_applied_labels_for_multi_channel", {"keys": ["img"], "independent": False}, grid_3, TypeError] + VALID_CASES.append( + [ + "onehot_none_dependent_batch_2_apply_label_1_2_connect_1", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, + {"img": p(grid_4)}, + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -INVALID_CASES = [ITEST_CASE_1, ITEST_CASE_2] + VALID_CASES.append( + [ + "single_channel_onehot", + {"keys": ["img"], "independent": False, "applied_labels": 0, "connectivity": 1, "is_onehot": True}, + {"img": p(grid_5)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]]), + ] + ) class TestKeepLargestConnectedComponentd(unittest.TestCase): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, args, input_dict, expected): converter = KeepLargestConnectedComponentd(**args) - if torch.cuda.is_available(): - input_dict["img"] = input_dict["img"].cuda() - result = converter(input_dict) - torch.allclose(result["img"], expected.cuda()) - else: - result = converter(input_dict) - torch.allclose(result["img"], expected) - - @parameterized.expand(INVALID_CASES) - def test_raise_exception(self, _, args, input_dict, expected_error): - with self.assertRaises(expected_error): - converter = KeepLargestConnectedComponentd(**args) - if torch.cuda.is_available(): - input_dict["img"] = input_dict["img"].cuda() - _ = converter(input_dict) + result = converter(input_dict) + assert_allclose(result["img"], expected, type_test=False) if __name__ == "__main__": diff --git a/tests/test_label_filter.py b/tests/test_label_filter.py index c699fb31fd..b782f90441 100644 --- a/tests/test_label_filter.py +++ b/tests/test_label_filter.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,86 +16,41 @@ from parameterized import parameterized from monai.transforms import LabelFilter -from tests.utils import assert_allclose, clone +from tests.utils import TEST_NDARRAYS, assert_allclose -grid_1 = torch.tensor( - [ - [ - [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9], - ] - ] - ] -) +grid_1 = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]) - -TEST_CASE_0 = [ - "filter_single_label", - {"applied_labels": 3}, - grid_1, - torch.tensor( +VALID_TESTS = [] +for p in TEST_NDARRAYS: + VALID_TESTS.append( [ - [ - [ - [0, 0, 3], - [0, 0, 0], - [0, 0, 0], - ] - ] + "filter_single_label", + {"applied_labels": 3}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])), ] - ), -] - + ) -TEST_CASE_1 = [ - "filter_single_label_list", - {"applied_labels": [3]}, - grid_1, - torch.tensor( + VALID_TESTS.append( [ - [ - [ - [0, 0, 3], - [0, 0, 0], - [0, 0, 0], - ] - ] + "filter_single_label_list", + {"applied_labels": [3]}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])), ] - ), -] - -TEST_CASE_2 = [ - "filter_multi_label", - {"applied_labels": [3, 5, 8]}, - grid_1, - torch.tensor( + ) + + VALID_TESTS.append( [ - [ - [ - [0, 0, 3], - [0, 5, 0], - [0, 8, 0], - ] - ] + "filter_multi_label", + {"applied_labels": [3, 5, 8]}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 5, 0], [0, 8, 0]]]])), ] - ), -] - -TEST_CASE_3 = [ - "filter_all", - {"applied_labels": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, - grid_1, - grid_1, -] - -VALID_CASES = [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, -] + ) + + VALID_TESTS.append(["filter_all", {"applied_labels": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, p(grid_1), p(grid_1)]) + ITEST_CASE_1 = ["invalid_image_data_type", {"applied_labels": 1}, [[[[1, 1, 1]]]], NotImplementedError] @@ -103,13 +58,10 @@ class TestLabelFilter(unittest.TestCase): - @parameterized.expand(VALID_CASES) + @parameterized.expand(VALID_TESTS) def test_correct_results(self, _, args, input_image, expected): converter = LabelFilter(**args) - if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - result = converter(clone(input_image).cuda()) - else: - result = converter(clone(input_image)) + result = converter(input_image) assert_allclose(result, expected) @parameterized.expand(INVALID_CASES) @@ -117,9 +69,9 @@ def test_raise_exception(self, _, args, input_image, expected_error): with self.assertRaises(expected_error): converter = LabelFilter(**args) if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - _ = converter(clone(input_image).cuda()) + _ = converter(input_image.cuda()) else: - _ = converter(clone(input_image)) + _ = converter(input_image) if __name__ == "__main__": diff --git a/tests/test_label_filterd.py b/tests/test_label_filterd.py new file mode 100644 index 0000000000..d53dc21faf --- /dev/null +++ b/tests/test_label_filterd.py @@ -0,0 +1,78 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms.post.dictionary import LabelFilterd +from tests.utils import TEST_NDARRAYS, assert_allclose + +grid_1 = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]) + +VALID_TESTS = [] +for p in TEST_NDARRAYS: + VALID_TESTS.append( + [ + "filter_single_label", + {"applied_labels": 3}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])), + ] + ) + + VALID_TESTS.append( + [ + "filter_single_label_list", + {"applied_labels": [3]}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])), + ] + ) + + VALID_TESTS.append( + [ + "filter_multi_label", + {"applied_labels": [3, 5, 8]}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 5, 0], [0, 8, 0]]]])), + ] + ) + + VALID_TESTS.append(["filter_all", {"applied_labels": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, p(grid_1), p(grid_1)]) + + +ITEST_CASE_1 = ["invalid_image_data_type", {"applied_labels": 1}, [[[[1, 1, 1]]]], NotImplementedError] + +INVALID_CASES = [ITEST_CASE_1] + + +class TestLabelFilter(unittest.TestCase): + @parameterized.expand(VALID_TESTS) + def test_correct_results(self, _, args, input_image, expected): + converter = LabelFilterd(keys="image", **args) + result = converter({"image": input_image})["image"] + assert_allclose(result, expected) + + @parameterized.expand(INVALID_CASES) + def test_raise_exception(self, _, args, input_image, expected_error): + with self.assertRaises(expected_error): + converter = LabelFilterd(keys="image", **args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + _ = converter({"image": input_image.cuda()}) + else: + _ = converter({"image": input_image}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_label_to_contour.py b/tests/test_label_to_contour.py index 8f8f3cc054..fef40af08d 100644 --- a/tests/test_label_to_contour.py +++ b/tests/test_label_to_contour.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,108 +15,107 @@ import torch from monai.transforms import LabelToContour +from tests.utils import TEST_NDARRAYS, assert_allclose -expected_output_for_cube = np.array( +expected_output_for_cube = [ [ - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - ] -) - - -def gen_fixed_cube(): + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], +] + + +def gen_fixed_cube(array_type): scale, core_start, core_end = 8, 1, 7 - cube = torch.zeros(scale, scale, scale) + cube = np.zeros((scale, scale, scale)) cube[core_start:core_end, core_start:core_end, core_start:core_end] = torch.ones( core_end - core_start, core_end - core_start, core_end - core_start ) - cube = torch.unsqueeze(cube, 0) + cube = cube[None] batch_size, channels = 10, 6 - cube = cube.repeat(batch_size, channels, 1, 1, 1) - return cube, expected_output_for_cube + cube = np.tile(cube, (batch_size, channels, 1, 1, 1)) + return array_type(cube), array_type(expected_output_for_cube) -def gen_fixed_img(): - img = torch.tensor( +def gen_fixed_img(array_type): + img = np.array( [ [0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1], @@ -124,19 +123,18 @@ def gen_fixed_img(): [0, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], ], - dtype=torch.float32, + dtype=np.float32, ) batch_size, channels = 10, 6 - img = img.repeat(batch_size, channels, 1, 1) - expected_output_for_img = torch.tensor( + img = array_type(np.tile(img, (batch_size, channels, 1, 1))) + expected_output_for_img = array_type( [ [0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 0, 0, 1], [0, 0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 0, 0, 1], [1, 1, 1, 1, 1, 1, 1], - ], - dtype=torch.float32, + ] ) return img, expected_output_for_img @@ -145,33 +143,34 @@ class TestContour(unittest.TestCase): def test_contour(self): input_param = {"kernel_type": "Laplace"} - # check 5-dim input data - test_cube, expected_output = gen_fixed_cube() - for cube in test_cube: - test_result_cube = LabelToContour(**input_param)(cube) - self.assertEqual(test_result_cube.shape, cube.shape) + for p in TEST_NDARRAYS: + # check 5-dim input data + test_cube, expected_output = gen_fixed_cube(p) + for cube in test_cube: + test_result_cube = LabelToContour(**input_param)(cube) + self.assertEqual(test_result_cube.shape, cube.shape) - test_result_np = test_result_cube.cpu().numpy() - channels = cube.shape[0] - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) + channels = cube.shape[0] + for channel in range(channels): + assert_allclose(test_result_cube[channel, ...], expected_output) - # check 4-dim input data - test_img, expected_output = gen_fixed_img() - for img in test_img: - channels = img.shape[0] - test_result_img = LabelToContour(**input_param)(img) - self.assertEqual(test_result_img.shape, img.shape) + # check 4-dim input data + test_img, expected_output = gen_fixed_img(p) + for img in test_img: + channels = img.shape[0] + test_result_img = LabelToContour(**input_param)(img) + self.assertEqual(test_result_img.shape, img.shape) - test_result_np = test_result_img.cpu().numpy() - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) + for channel in range(channels): + assert_allclose(test_result_img[channel, ...], expected_output) # check invalid input data error_input = torch.rand(1, 2) self.assertRaises(ValueError, LabelToContour(**input_param), error_input) error_input = torch.rand(1, 2, 3, 4, 5) self.assertRaises(ValueError, LabelToContour(**input_param), error_input) + error_input = np.random.rand(1, 2, 3, 4, 5) + self.assertRaises(ValueError, LabelToContour(**input_param), error_input) if __name__ == "__main__": diff --git a/tests/test_label_to_contourd.py b/tests/test_label_to_contourd.py index d3795755c7..6481e803ba 100644 --- a/tests/test_label_to_contourd.py +++ b/tests/test_label_to_contourd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,108 +15,107 @@ import torch from monai.transforms import LabelToContourd +from tests.utils import TEST_NDARRAYS, assert_allclose -expected_output_for_cube = np.array( +expected_output_for_cube = [ [ - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - ] -) - - -def gen_fixed_cube(): + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], +] + + +def gen_fixed_cube(array_type): scale, core_start, core_end = 8, 1, 7 - cube = torch.zeros(scale, scale, scale) + cube = np.zeros((scale, scale, scale)) cube[core_start:core_end, core_start:core_end, core_start:core_end] = torch.ones( core_end - core_start, core_end - core_start, core_end - core_start ) - cube = torch.unsqueeze(cube, 0) + cube = cube[None] batch_size, channels = 10, 6 - cube = cube.repeat(batch_size, channels, 1, 1, 1) - return cube, expected_output_for_cube + cube = np.tile(cube, (batch_size, channels, 1, 1, 1)) + return array_type(cube), array_type(expected_output_for_cube) -def gen_fixed_img(): - img = torch.tensor( +def gen_fixed_img(array_type): + img = np.array( [ [0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1], @@ -124,19 +123,19 @@ def gen_fixed_img(): [0, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], ], - dtype=torch.float32, + dtype=np.float32, ) batch_size, channels = 10, 6 - img = img.repeat(batch_size, channels, 1, 1) - expected_output_for_img = torch.tensor( + img = np.tile(img, (batch_size, channels, 1, 1)) + img = array_type(img) + expected_output_for_img = array_type( [ [0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 0, 0, 1], [0, 0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 0, 0, 1], [1, 1, 1, 1, 1, 1, 1], - ], - dtype=torch.float32, + ] ) return img, expected_output_for_img @@ -145,31 +144,34 @@ class TestContourd(unittest.TestCase): def test_contour(self): input_param = {"keys": "img", "kernel_type": "Laplace"} - # check 5-dim input data - test_cube, expected_output = gen_fixed_cube() - for cube in test_cube: - test_result_cube = LabelToContourd(**input_param)({"img": cube}) - self.assertEqual(test_result_cube["img"].shape, cube.shape) - - test_result_np = test_result_cube["img"].cpu().numpy() - channels = cube.shape[0] - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) - - # check 4-dim input data - test_img, expected_output = gen_fixed_img() - for img in test_img: - channels = img.shape[0] - test_result_img = LabelToContourd(**input_param)({"img": img}) - self.assertEqual(test_result_img["img"].shape, img.shape) - - test_result_np = test_result_img["img"].cpu().numpy() - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) + for p in TEST_NDARRAYS: + # check 5-dim input data + test_cube, expected_output = gen_fixed_cube(p) + for cube in test_cube: + test_result_cube = LabelToContourd(**input_param)({"img": cube}) + self.assertEqual(test_result_cube["img"].shape, cube.shape) + + test_result_np = test_result_cube["img"] + channels = cube.shape[0] + for channel in range(channels): + assert_allclose(test_result_np[channel, ...], expected_output) + + # check 4-dim input data + test_img, expected_output = gen_fixed_img(p) + for img in test_img: + channels = img.shape[0] + test_result_img = LabelToContourd(**input_param)({"img": img}) + self.assertEqual(test_result_img["img"].shape, img.shape) + + test_result_np = test_result_img["img"] + for channel in range(channels): + assert_allclose(test_result_np[channel, ...], expected_output) # check invalid input data error_input = {"img": torch.rand(1, 2)} self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) + error_input = {"img": np.random.rand(1, 2)} + self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) error_input = {"img": torch.rand(1, 2, 3, 4, 5)} self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py index 9caa7252f3..8f81a8da1a 100644 --- a/tests/test_label_to_mask.py +++ b/tests/test_label_to_mask.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -64,7 +64,7 @@ def test_value(self, argments, image, expected_data): self.assertEqual(type(result), type(image)) if isinstance(result, torch.Tensor): self.assertEqual(result.device, image.device) - assert_allclose(result, expected_data) + assert_allclose(result, expected_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py index b8f0d3c171..e67b857502 100644 --- a/tests/test_label_to_maskd.py +++ b/tests/test_label_to_maskd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -65,7 +65,7 @@ def test_value(self, argments, input_data, expected_data): self.assertEqual(type(r), type(i)) if isinstance(r, torch.Tensor): self.assertEqual(r.device, i.device) - assert_allclose(r, expected_data) + assert_allclose(r, expected_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_lambda.py b/tests/test_lambda.py index 738c81130d..c187cc979b 100644 --- a/tests/test_lambda.py +++ b/tests/test_lambda.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_lambdad.py b/tests/test_lambdad.py index 05ba0ff6bc..30d70f40fb 100644 --- a/tests/test_lambdad.py +++ b/tests/test_lambdad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_lesion_froc.py b/tests/test_lesion_froc.py index 2454de88fa..6b4989a9d5 100644 --- a/tests/test_lesion_froc.py +++ b/tests/test_lesion_froc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,18 +19,19 @@ from monai.apps.pathology.metrics import LesionFROC from monai.utils import optional_import -_, has_cucim = optional_import("cucim") +_cucim, has_cucim = optional_import("cucim") +has_cucim = has_cucim and hasattr(_cucim, "CuImage") _, has_skimage = optional_import("skimage.measure") _, has_sp = optional_import("scipy.ndimage") -PILImage, has_pil = optional_import("PIL.Image") +imwrite, has_tif = optional_import("tifffile", name="imwrite") def save_as_tif(filename, array): array = array[::-1, ...] # Upside-down - img = PILImage.fromarray(array) if not filename.endswith(".tif"): filename += ".tif" - img.save(os.path.join("tests", "testing_data", filename)) + file_path = os.path.join("tests", "testing_data", filename) + imwrite(file_path, array, compress="jpeg", tile=(16, 16)) def around(val, interval=3): @@ -301,7 +302,7 @@ class TestEvaluateTumorFROC(unittest.TestCase): @skipUnless(has_cucim, "Requires cucim") @skipUnless(has_skimage, "Requires skimage") @skipUnless(has_sp, "Requires scipy") - @skipUnless(has_pil, "Requires PIL") + @skipUnless(has_tif, "Requires tifffile") def setUp(self): prepare_test_data() diff --git a/tests/test_list_data_collate.py b/tests/test_list_data_collate.py index eebac69fcf..93b06cc187 100644 --- a/tests/test_list_data_collate.py +++ b/tests/test_list_data_collate.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_list_to_dict.py b/tests/test_list_to_dict.py index 2f026f3e29..ec81310c9f 100644 --- a/tests/test_list_to_dict.py +++ b/tests/test_list_to_dict.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,25 +15,13 @@ from monai.utils import list_to_dict -TEST_CASE_1 = [ - ["a=1", "b=2", "c=3", "d=4"], - {"a": 1, "b": 2, "c": 3, "d": 4}, -] +TEST_CASE_1 = [["a=1", "b=2", "c=3", "d=4"], {"a": 1, "b": 2, "c": 3, "d": 4}] -TEST_CASE_2 = [ - ["a=a", "b=b", "c=c", "d=d"], - {"a": "a", "b": "b", "c": "c", "d": "d"}, -] +TEST_CASE_2 = [["a=a", "b=b", "c=c", "d=d"], {"a": "a", "b": "b", "c": "c", "d": "d"}] -TEST_CASE_3 = [ - ["a=0.1", "b=0.2", "c=0.3", "d=0.4"], - {"a": 0.1, "b": 0.2, "c": 0.3, "d": 0.4}, -] +TEST_CASE_3 = [["a=0.1", "b=0.2", "c=0.3", "d=0.4"], {"a": 0.1, "b": 0.2, "c": 0.3, "d": 0.4}] -TEST_CASE_4 = [ - ["a=True", "b=TRUE", "c=false", "d=FALSE"], - {"a": True, "b": True, "c": False, "d": False}, -] +TEST_CASE_4 = [["a=True", "b=TRUE", "c=false", "d=FALSE"], {"a": True, "b": True, "c": False, "d": False}] TEST_CASE_5 = [ ["a='1'", "b=2 ", " c = 3", "d='test'", "'e'=0", "f", "g=None"], diff --git a/tests/test_lltm.py b/tests/test_lltm.py index f1311379bc..7633c2fe34 100644 --- a/tests/test_lltm.py +++ b/tests/test_lltm.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,9 @@ from parameterized import parameterized from monai.networks.layers import LLTM -from tests.utils import SkipIfNoModule +from tests.utils import SkipIfNoModule, is_tf32_env + +_rtol = 0.001 if is_tf32_env() else 0.0001 TEST_CASE_1 = [ {"input_features": 32, "state_size": 2}, @@ -50,8 +52,8 @@ def test_value_cuda(self, input_param, expected_h, expected_c): new_h, new_c = lltm(x, (h, c)) (new_h.sum() + new_c.sum()).backward() - torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=0.0001, atol=1e-04) - torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=0.0001, atol=1e-04) + torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=_rtol, atol=0.001) + torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=_rtol, atol=0.001) if __name__ == "__main__": diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index fbdb651297..33f27ee4bc 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -57,7 +57,7 @@ SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), ], (128, 128, 128), - {"pickle_protocol": 2, "lmdb_kwargs": {"map_size": 100 * 1024 ** 2}}, + {"pickle_protocol": 2, "lmdb_kwargs": {"map_size": 100 * 1024**2}}, ] TEST_CASE_6 = [ @@ -66,7 +66,7 @@ SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), ], (128, 128, 128), - {"db_name": "testdb", "lmdb_kwargs": {"map_size": 100 * 1024 ** 2}}, + {"db_name": "testdb", "lmdb_kwargs": {"map_size": 100 * 1024**2}}, ] TEST_CASE_7 = [ @@ -75,7 +75,7 @@ SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), ], (128, 128, 128), - {"db_name": "testdb", "lmdb_kwargs": {"map_size": 2 * 1024 ** 2}}, + {"db_name": "testdb", "lmdb_kwargs": {"map_size": 2 * 1024**2}}, ] diff --git a/tests/test_lmdbdataset_dist.py b/tests/test_lmdbdataset_dist.py new file mode 100644 index 0000000000..cad2949dde --- /dev/null +++ b/tests/test_lmdbdataset_dist.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +import unittest + +import numpy as np + +from monai.data import LMDBDataset, json_hashing +from monai.transforms import Transform +from tests.utils import DistCall, DistTestCase, skip_if_windows + + +class _InplaceXform(Transform): + def __call__(self, data): + if data: + data[0] = data[0] + np.pi + else: + data.append(1) + return data + + +@skip_if_windows +class TestMPLMDBDataset(DistTestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + @DistCall(nnodes=1, nproc_per_node=1) + def test_mp_cache(self): + items = [[list(range(i))] for i in range(5)] + + ds = LMDBDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir, lmdb_kwargs={"map_size": 10 * 1024}) + self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) + ds1 = LMDBDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir, lmdb_kwargs={"map_size": 10 * 1024}) + self.assertEqual(list(ds1), list(ds)) + self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) + + ds = LMDBDataset( + items, + transform=_InplaceXform(), + cache_dir=self.tempdir, + lmdb_kwargs={"map_size": 10 * 1024}, + hash_func=json_hashing, + ) + self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) + ds1 = LMDBDataset( + items, + transform=_InplaceXform(), + cache_dir=self.tempdir, + lmdb_kwargs={"map_size": 10 * 1024}, + hash_func=json_hashing, + ) + self.assertEqual(list(ds1), list(ds)) + self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) + + self.assertTrue(isinstance(ds1.info(), dict)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_load_decathlon_datalist.py b/tests/test_load_decathlon_datalist.py index fe7ff6f8a2..91d144d84f 100644 --- a/tests/test_load_decathlon_datalist.py +++ b/tests/test_load_decathlon_datalist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,6 +13,7 @@ import os import tempfile import unittest +from pathlib import Path from monai.data import load_decathlon_datalist @@ -115,7 +116,7 @@ def test_additional_items(self): file_path = os.path.join(tempdir, "test_data.json") with open(file_path, "w") as json_file: json_file.write(json_str) - result = load_decathlon_datalist(file_path, True, "training", tempdir) + result = load_decathlon_datalist(file_path, True, "training", Path(tempdir)) self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) self.assertEqual(result[1]["mask"], os.path.join(tempdir, "mask31.txt")) diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 2aa6eced65..201fe2fd5b 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -68,11 +68,7 @@ def get_data(self, _obj): (3, 128, 128, 128), ] -TEST_CASE_5 = [ - {"reader": NibabelReader(mmap=False), "image_only": False}, - ["test_image.nii.gz"], - (128, 128, 128), -] +TEST_CASE_5 = [{"reader": NibabelReader(mmap=False), "image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] TEST_CASE_6 = [{"reader": ITKReader(), "image_only": True}, ["test_image.nii.gz"], (128, 128, 128)] @@ -84,6 +80,13 @@ def get_data(self, _obj): (3, 128, 128, 128), ] +TEST_CASE_8_1 = [ + {"reader": ITKReader(channel_dim=0), "image_only": True}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + (384, 128, 128), +] + + TEST_CASE_9 = [ {"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], @@ -94,12 +97,41 @@ def get_data(self, _obj): {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", (16, 16, 4), + (16, 16, 4), ] TEST_CASE_11 = [ {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC}, "tests/testing_data/CT_DICOM", (16, 16, 4), + (16, 16, 4), +] + +TEST_CASE_12 = [ + {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC, "reverse_indexing": True}, + "tests/testing_data/CT_DICOM", + (16, 16, 4), + (4, 16, 16), +] + +TEST_CASE_13 = [{"reader": "nibabelreader", "channel_dim": 0}, "test_image.nii.gz", (3, 128, 128, 128)] + +TEST_CASE_14 = [ + {"reader": "nibabelreader", "channel_dim": -1, "ensure_channel_first": True}, + "test_image.nii.gz", + (128, 128, 128, 3), +] + +TEST_CASE_15 = [{"reader": "nibabelreader", "channel_dim": 2}, "test_image.nii.gz", (128, 128, 3, 128)] + +TEST_CASE_16 = [{"reader": "itkreader", "channel_dim": 0}, "test_image.nii.gz", (3, 128, 128, 128)] + +TEST_CASE_17 = [{"reader": "monai.data.ITKReader", "channel_dim": -1}, "test_image.nii.gz", (128, 128, 128, 3)] + +TEST_CASE_18 = [ + {"reader": "ITKReader", "channel_dim": 2, "ensure_channel_first": True}, + "test_image.nii.gz", + (128, 128, 3, 128), ] @@ -123,7 +155,7 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape): np.testing.assert_allclose(header["original_affine"], np.eye(4)) self.assertTupleEqual(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) + @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9]) def test_itk_reader(self, input_param, filenames, expected_shape): test_image = np.random.rand(128, 128, 128) with tempfile.TemporaryDirectory() as tempdir: @@ -142,11 +174,11 @@ def test_itk_reader(self, input_param, filenames, expected_shape): np.testing.assert_allclose(header["original_affine"], np_diag) self.assertTupleEqual(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_10, TEST_CASE_11]) - def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape): + @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12]) + def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, expected_np_shape): result, header = LoadImage(**input_param)(filenames) self.assertTrue("affine" in header) - self.assertEqual(header["filename_or_obj"], filenames) + self.assertEqual(header["filename_or_obj"], f"{Path(filenames)}") np.testing.assert_allclose( header["affine"], np.array( @@ -158,8 +190,8 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape): ] ), ) - self.assertTupleEqual(result.shape, expected_shape) self.assertTupleEqual(tuple(header["spatial_shape"]), expected_shape) + self.assertTupleEqual(result.shape, expected_np_shape) def test_itk_reader_multichannel(self): test_image = np.random.randint(0, 256, size=(256, 224, 3)).astype("uint8") @@ -167,12 +199,31 @@ def test_itk_reader_multichannel(self): filename = os.path.join(tempdir, "test_image.png") itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) - result, header = LoadImage(reader=ITKReader())(Path(filename)) + for flag in (False, True): + result, header = LoadImage(reader=ITKReader(reverse_indexing=flag))(Path(filename)) + + self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256)) + test_image = test_image.transpose(1, 0, 2) + np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0]) + np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1]) + np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2]) - self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256)) - np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0].T) - np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1].T) - np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2].T) + def test_load_nifti_multichannel(self): + test_image = np.random.randint(0, 256, size=(31, 64, 16, 2)).astype(np.float32) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.nii.gz") + itk_np_view = itk.image_view_from_array(test_image, is_vector=True) + itk.imwrite(itk_np_view, filename) + + itk_img, itk_header = LoadImage(reader=ITKReader())(Path(filename)) + self.assertTupleEqual(tuple(itk_header["spatial_shape"]), (16, 64, 31)) + self.assertTupleEqual(tuple(itk_img.shape), (16, 64, 31, 2)) + + nib_image, nib_header = LoadImage(reader=NibabelReader(squeeze_non_spatial_dims=True))(Path(filename)) + self.assertTupleEqual(tuple(nib_header["spatial_shape"]), (16, 64, 31)) + self.assertTupleEqual(tuple(nib_image.shape), (16, 64, 31, 2)) + + np.testing.assert_allclose(itk_img, nib_image, atol=1e-3, rtol=1e-3) def test_load_png(self): spatial_size = (256, 224) @@ -230,6 +281,29 @@ def test_my_reader(self): out = LoadImage()("test", reader=_MiniReader(is_compatible=False)) self.assertEqual(out[1]["name"], "my test") + def test_itk_meta(self): + """test metadata from a directory""" + out, meta = LoadImage(reader="ITKReader", pixel_type=itk.UC, series_meta=True)("tests/testing_data/CT_DICOM") + idx = "0008|103e" + label = itk.GDCMImageIO.GetLabelFromTag(idx, "")[1] + val = meta[idx] + expected = "Series Description=Routine Brain " + self.assertEqual(f"{label}={val}", expected) + + @parameterized.expand([TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17, TEST_CASE_18]) + def test_channel_dim(self, input_param, filename, expected_shape): + test_image = np.random.rand(*expected_shape) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, filename) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename) + result = LoadImage(**input_param)(filename) + + self.assertTupleEqual( + result[0].shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape + ) + self.assertTupleEqual(tuple(result[1]["spatial_shape"]), (128, 128, 128)) + self.assertEqual(result[1]["original_channel_dim"], input_param["channel_dim"]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index ca5b56a7d9..bc001cf2fd 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,6 +21,7 @@ from monai.data import ITKReader from monai.transforms import Compose, EnsureChannelFirstD, LoadImaged, SaveImageD +from monai.utils.enums import PostFix KEYS = ["image", "label", "extra"] @@ -54,7 +55,7 @@ def test_register(self): loader = LoadImaged(keys="img") loader.register(ITKReader()) result = loader({"img": Path(filename)}) - self.assertTupleEqual(tuple(result["img_meta_dict"]["spatial_shape"]), spatial_size[::-1]) + self.assertTupleEqual(tuple(result[PostFix.meta("img")]["spatial_shape"]), spatial_size[::-1]) self.assertTupleEqual(result["img"].shape, spatial_size[::-1]) def test_channel_dim(self): @@ -67,7 +68,7 @@ def test_channel_dim(self): loader = LoadImaged(keys="img") loader.register(ITKReader(channel_dim=2)) result = EnsureChannelFirstD("img")(loader({"img": filename})) - self.assertTupleEqual(tuple(result["img_meta_dict"]["spatial_shape"]), (32, 64, 128)) + self.assertTupleEqual(tuple(result[PostFix.meta("img")]["spatial_shape"]), (32, 64, 128)) self.assertTupleEqual(result["img"].shape, (3, 32, 64, 128)) def test_no_file(self): @@ -81,34 +82,24 @@ class TestConsistency(unittest.TestCase): def _cmp(self, filename, shape, ch_shape, reader_1, reader_2, outname, ext): data_dict = {"img": filename} keys = data_dict.keys() - xforms = Compose( - [ - LoadImaged(keys, reader=reader_1), - EnsureChannelFirstD(keys), - ] - ) + xforms = Compose([LoadImaged(keys, reader=reader_1, ensure_channel_first=True)]) img_dict = xforms(data_dict) # load dicom with itk self.assertTupleEqual(img_dict["img"].shape, ch_shape) - self.assertTupleEqual(tuple(img_dict["img_meta_dict"]["spatial_shape"]), shape) + self.assertTupleEqual(tuple(img_dict[PostFix.meta("img")]["spatial_shape"]), shape) with tempfile.TemporaryDirectory() as tempdir: save_xform = SaveImageD( - keys, meta_keys="img_meta_dict", output_dir=tempdir, squeeze_end_dims=False, output_ext=ext + keys, meta_keys=PostFix.meta("img"), output_dir=tempdir, squeeze_end_dims=False, output_ext=ext ) save_xform(img_dict) # save to nifti - new_xforms = Compose( - [ - LoadImaged(keys, reader=reader_2), - EnsureChannelFirstD(keys), - ] - ) + new_xforms = Compose([LoadImaged(keys, reader=reader_2), EnsureChannelFirstD(keys)]) out = new_xforms({"img": os.path.join(tempdir, outname)}) # load nifti with itk self.assertTupleEqual(out["img"].shape, ch_shape) - self.assertTupleEqual(tuple(out["img_meta_dict"]["spatial_shape"]), shape) - if "affine" in img_dict["img_meta_dict"] and "affine" in out["img_meta_dict"]: + self.assertTupleEqual(tuple(out[PostFix.meta("img")]["spatial_shape"]), shape) + if "affine" in img_dict[PostFix.meta("img")] and "affine" in out[PostFix.meta("img")]: np.testing.assert_allclose( - img_dict["img_meta_dict"]["affine"], out["img_meta_dict"]["affine"], rtol=1e-3 + img_dict[PostFix.meta("img")]["affine"], out[PostFix.meta("img")]["affine"], rtol=1e-3 ) np.testing.assert_allclose(out["img"], img_dict["img"], rtol=1e-3) diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 48aac7ec56..2792822c3d 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,6 +19,7 @@ from parameterized import parameterized from monai.transforms import AddChanneld, LoadImaged, Orientationd, Spacingd +from monai.utils.enums import PostFix FILES = tuple( os.path.join(os.path.dirname(__file__), "testing_data", filename) @@ -36,12 +37,12 @@ def test_load_spacingd(self, filename): res_dict = Spacingd(keys="image", pixdim=(1, 0.2, 1), diagonal=True, padding_mode="zeros")(data_dict) t1 = time.time() print(f"time monai: {t1 - t}") - anat = nibabel.Nifti1Image(data_dict["image"][0], data_dict["image_meta_dict"]["original_affine"]) + anat = nibabel.Nifti1Image(data_dict["image"][0], data_dict[PostFix.meta("image")]["original_affine"]) ref = resample_to_output(anat, (1, 0.2, 1), order=1) t2 = time.time() print(f"time scipy: {t2 - t1}") self.assertTrue(t2 >= t1) - np.testing.assert_allclose(res_dict["image_meta_dict"]["affine"], ref.affine) + np.testing.assert_allclose(res_dict[PostFix.meta("image")]["affine"], ref.affine) np.testing.assert_allclose(res_dict["image"].shape[1:], ref.shape) np.testing.assert_allclose(ref.get_fdata(), res_dict["image"][0], atol=0.05) @@ -50,20 +51,20 @@ def test_load_spacingd_rotate(self, filename): data = {"image": filename} data_dict = LoadImaged(keys="image")(data) data_dict = AddChanneld(keys="image")(data_dict) - affine = data_dict["image_meta_dict"]["affine"] - data_dict["image_meta_dict"]["original_affine"] = data_dict["image_meta_dict"]["affine"] = ( + affine = data_dict[PostFix.meta("image")]["affine"] + data_dict[PostFix.meta("image")]["original_affine"] = data_dict[PostFix.meta("image")]["affine"] = ( np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]]) @ affine ) t = time.time() res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=True, padding_mode="zeros")(data_dict) t1 = time.time() print(f"time monai: {t1 - t}") - anat = nibabel.Nifti1Image(data_dict["image"][0], data_dict["image_meta_dict"]["original_affine"]) + anat = nibabel.Nifti1Image(data_dict["image"][0], data_dict[PostFix.meta("image")]["original_affine"]) ref = resample_to_output(anat, (1, 2, 3), order=1) t2 = time.time() print(f"time scipy: {t2 - t1}") self.assertTrue(t2 >= t1) - np.testing.assert_allclose(res_dict["image_meta_dict"]["affine"], ref.affine) + np.testing.assert_allclose(res_dict[PostFix.meta("image")]["affine"], ref.affine) if "anatomical" not in filename: np.testing.assert_allclose(res_dict["image"].shape[1:], ref.shape) np.testing.assert_allclose(ref.get_fdata(), res_dict["image"][0], atol=0.05) @@ -76,13 +77,13 @@ def test_load_spacingd_non_diag(self): data = {"image": FILES[1]} data_dict = LoadImaged(keys="image")(data) data_dict = AddChanneld(keys="image")(data_dict) - affine = data_dict["image_meta_dict"]["affine"] - data_dict["image_meta_dict"]["original_affine"] = data_dict["image_meta_dict"]["affine"] = ( + affine = data_dict[PostFix.meta("image")]["affine"] + data_dict[PostFix.meta("image")]["original_affine"] = data_dict[PostFix.meta("image")]["affine"] = ( np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]]) @ affine ) res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="zeros")(data_dict) np.testing.assert_allclose( - res_dict["image_meta_dict"]["affine"], + res_dict[PostFix.meta("image")]["affine"], np.array( [ [0.0, 0.0, 3.0, -27.599409], @@ -99,7 +100,7 @@ def test_load_spacingd_rotate_non_diag(self): data_dict = AddChanneld(keys="image")(data_dict) res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="border")(data_dict) np.testing.assert_allclose( - res_dict["image_meta_dict"]["affine"], + res_dict[PostFix.meta("image")]["affine"], np.array([[-1.0, 0.0, 0.0, 32.0], [0.0, 2.0, 0.0, -40.0], [0.0, 0.0, 3.0, -16.0], [0.0, 0.0, 0.0, 1.0]]), ) @@ -110,7 +111,7 @@ def test_load_spacingd_rotate_non_diag_ornt(self): res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="border")(data_dict) res_dict = Orientationd(keys="image", axcodes="LPI")(res_dict) np.testing.assert_allclose( - res_dict["image_meta_dict"]["affine"], + res_dict[PostFix.meta("image")]["affine"], np.array([[-1.0, 0.0, 0.0, 32.0], [0.0, -2.0, 0.0, 40.0], [0.0, 0.0, -3.0, 32.0], [0.0, 0.0, 0.0, 1.0]]), ) @@ -118,14 +119,14 @@ def test_load_spacingd_non_diag_ornt(self): data = {"image": FILES[1]} data_dict = LoadImaged(keys="image")(data) data_dict = AddChanneld(keys="image")(data_dict) - affine = data_dict["image_meta_dict"]["affine"] - data_dict["image_meta_dict"]["original_affine"] = data_dict["image_meta_dict"]["affine"] = ( + affine = data_dict[PostFix.meta("image")]["affine"] + data_dict[PostFix.meta("image")]["original_affine"] = data_dict[PostFix.meta("image")]["affine"] = ( np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]]) @ affine ) res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="border")(data_dict) res_dict = Orientationd(keys="image", axcodes="LPI")(res_dict) np.testing.assert_allclose( - res_dict["image_meta_dict"]["affine"], + res_dict[PostFix.meta("image")]["affine"], np.array( [ [-3.0, 0.0, 0.0, 56.4005909], diff --git a/tests/test_loader_semaphore.py b/tests/test_loader_semaphore.py index 85c6d54f35..bbb2d4eef6 100644 --- a/tests/test_loader_semaphore.py +++ b/tests/test_loader_semaphore.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index 31954e727b..8070c27f90 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_localnet.py b/tests/test_localnet.py index dc680f15f9..1a288fb447 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index f4e857a0fa..d85509344e 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_look_up_option.py b/tests/test_look_up_option.py index 60786f2fc5..89fec1b575 100644 --- a/tests/test_look_up_option.py +++ b/tests/test_look_up_option.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index 5b730c2a77..a76808be20 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. import os +import pickle import random import sys import unittest @@ -23,6 +24,7 @@ from monai.optimizers import LearningRateFinder from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils import optional_import, set_determinism +from tests.utils import skip_if_downloading_fails if TYPE_CHECKING: import matplotlib.pyplot as plt @@ -60,14 +62,15 @@ def setUp(self): def test_lr_finder(self): # 0.001 gives 54 examples - train_ds = MedNISTDataset( - root_dir=self.root_dir, - transform=self.transforms, - section="validation", - val_frac=0.001, - download=True, - num_workers=10, - ) + with skip_if_downloading_fails(): + train_ds = MedNISTDataset( + root_dir=self.root_dir, + transform=self.transforms, + section="validation", + val_frac=0.001, + download=True, + num_workers=10, + ) train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) num_classes = train_ds.get_num_classes() @@ -78,7 +81,14 @@ def test_lr_finder(self): learning_rate = 1e-5 optimizer = torch.optim.Adam(model.parameters(), learning_rate) - lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device) + lr_finder = LearningRateFinder( + model=model, + optimizer=optimizer, + criterion=loss_function, + device=device, + pickle_module=pickle, + pickle_protocol=4, + ) lr_finder.range_test(train_loader, val_loader=train_loader, end_lr=10, num_iter=5) print(lr_finder.get_steepest_gradient(0, 0)[0]) diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index aa126f7848..a3e1ea9dd6 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,7 +19,7 @@ class SchedulerTestNet(torch.nn.Module): def __init__(self): - super(SchedulerTestNet, self).__init__() + super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) self.conv2 = torch.nn.Conv2d(1, 1, 1) @@ -28,13 +28,7 @@ def forward(self, x): TEST_CASE_LRSCHEDULER = [ - [ - { - "warmup_steps": 2, - "t_total": 10, - }, - [0.000, 0.500, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038], - ] + [{"warmup_steps": 2, "t_total": 10}, [0.000, 0.500, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038]] ] @@ -47,11 +41,11 @@ def test_shape(self, input_param, expected_lr): self.assertEqual(len([scheduler.get_last_lr()[0]]), 1) lrs_1 = [] for _ in range(input_param["t_total"]): - lrs_1.append(float("{:.3f}".format(scheduler.get_last_lr()[0]))) + lrs_1.append(float(f"{scheduler.get_last_lr()[0]:.3f}")) optimizer.step() scheduler.step() for a, b in zip(lrs_1, expected_lr): - self.assertEqual(a, b, msg="LR is wrong ! expected {}, got {}".format(b, a)) + self.assertEqual(a, b, msg=f"LR is wrong ! expected {b}, got {a}") if __name__ == "__main__": diff --git a/tests/test_make_nifti.py b/tests/test_make_nifti.py new file mode 100644 index 0000000000..951f079764 --- /dev/null +++ b/tests/test_make_nifti.py @@ -0,0 +1,43 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data.synthetic import create_test_image_2d +from monai.utils import optional_import +from tests.utils import make_nifti_image + +_, has_nib = optional_import("nibabel") + +TESTS = [] +for affine in (None, np.eye(4), torch.eye(4)): + for dir in (None, tempfile.mkdtemp()): + for fname in (None, "fname"): + TESTS.append([{"affine": affine, "dir": dir, "fname": fname}]) + + +@unittest.skipUnless(has_nib, "Requires nibabel") +class TestMakeNifti(unittest.TestCase): + @parameterized.expand(TESTS) + def test_make_nifti(self, params): + im, _ = create_test_image_2d(100, 88) + created_file = make_nifti_image(im, verbose=True, **params) + self.assertTrue(os.path.isfile(created_file)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_map_binary_to_indices.py b/tests/test_map_binary_to_indices.py index 1fafa6f446..bc96231160 100644 --- a/tests/test_map_binary_to_indices.py +++ b/tests/test_map_binary_to_indices.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,50 +15,58 @@ from parameterized import parameterized from monai.transforms import map_binary_to_indices +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": None, "image_threshold": 0.0}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] - -TEST_CASE_2 = [ - { - "label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - "image": np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]), - "image_threshold": 0.0, - }, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_3 = [ - { - "label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - "image_threshold": 1.0, - }, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_4 = [ - { - "label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - "image_threshold": 1.0, - }, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), "image": None, "image_threshold": 0.0}, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 4, 8]), + ] + ) + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + "image": p(np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])), + "image_threshold": 0.0, + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + "image": p(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + "image_threshold": 1.0, + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + "image": p(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + "image_threshold": 1.0, + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) class TestMapBinaryToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected_fg, expected_bg): fg_indices, bg_indices = map_binary_to_indices(**input_data) - np.testing.assert_allclose(fg_indices, expected_fg) - np.testing.assert_allclose(bg_indices, expected_bg) + assert_allclose(fg_indices, expected_fg, type_test=False) + assert_allclose(bg_indices, expected_bg, type_test=False) if __name__ == "__main__": diff --git a/tests/test_map_classes_to_indices.py b/tests/test_map_classes_to_indices.py index 2320954520..2f32382f6b 100644 --- a/tests/test_map_classes_to_indices.py +++ b/tests/test_map_classes_to_indices.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,86 +15,117 @@ from parameterized import parameterized from monai.transforms import map_classes_to_indices +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - # test Argmax data - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 3, "image": None, "image_threshold": 0.0}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # test Argmax data + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 3, + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], + ] + ) -TEST_CASE_2 = [ - { - "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - "num_classes": 3, - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - "image_threshold": 60, - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 3, + "image": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + "image_threshold": 60, + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], + ] + ) -TEST_CASE_3 = [ - # test One-Hot data - { - "label": np.array( - [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "image": None, - "image_threshold": 0.0, - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + TESTS.append( + [ + # test One-Hot data + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], + ] + ) -TEST_CASE_4 = [ - { - "label": np.array( - [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "num_classes": None, - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - "image_threshold": 60, - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS.append( + [ + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "num_classes": None, + "image": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + "image_threshold": 60, + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], + ] + ) -TEST_CASE_5 = [ - # test empty class - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 5, "image": None, "image_threshold": 0.0}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], -] + TESTS.append( + [ + # test empty class + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 5, + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], + ] + ) -TEST_CASE_6 = [ - # test empty class - { - "label": np.array( - [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - ] - ), - "image": None, - "image_threshold": 0.0, - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], -] + TESTS.append( + [ + # test empty class + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + ] + ) + ), + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], + ] + ) class TestMapClassesToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_data, expected_indices): indices = map_classes_to_indices(**input_data) for i, e in zip(indices, expected_indices): - np.testing.assert_allclose(i, e) + assert_allclose(i, e, type_test=False) if __name__ == "__main__": diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py index ff1d7d1eef..0416858a74 100644 --- a/tests/test_map_label_value.py +++ b/tests/test_map_label_value.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,75 +12,67 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import MapLabelValue +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, - np.array([[3, 1], [1, 2]]), - np.array([[0, 2], [2, 1]]), -] - -TEST_CASE_2 = [ - {"orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, - np.array([[[3], [5], [5], [8]]]), - np.array([[[0], [1], [1], [2]]]), -] - -TEST_CASE_3 = [ - {"orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, - np.array([3, 1, 1, 2]), - np.array([2, 0, 0, 1]), -] - -TEST_CASE_4 = [ - {"orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]}, - np.array([3, 1, 1, 2]), - np.array([2.5, 0.5, 0.5, 1.5]), -] - -TEST_CASE_5 = [ - {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, - np.array([3.5, 1.5, 1.5, 2.5]), - np.array([2, 0, 0, 1]), -] - -TEST_CASE_6 = [ - {"orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, - np.array([["label3", "label1"], ["label1", "label2"]]), - np.array([[0, 2], [2, 1]]), -] - -TEST_CASE_7 = [ - {"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"}, - np.array([[3.5, 1.5], [1.5, 2.5]]), - np.array([["label0", "label2"], ["label2", "label1"]]), -] - -TEST_CASE_8 = [ - {"orig_labels": ["label3", "label2", "label1"], "target_labels": ["label1", "label2", "label3"], "dtype": "str"}, - np.array([["label3", "label1"], ["label1", "label2"]]), - np.array([["label1", "label3"], ["label3", "label2"]]), -] - - -class TestMapLabelValue(unittest.TestCase): - @parameterized.expand( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, + [{"orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, p([[3, 1], [1, 2]]), p([[0.0, 2.0], [2.0, 1.0]])], + [ + {"orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, + p([[[3], [5], [5], [8]]]), + p([[[0.0], [1.0], [1.0], [2.0]]]), + ], + [{"orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, p([3, 1, 1, 2]), p([2.0, 0.0, 0.0, 1.0])], + [{"orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]}, p([3, 1, 1, 2]), p([2.5, 0.5, 0.5, 1.5])], ] ) + # note: PyTorch 1.5.1 doesn't support rich dtypes + TESTS.append( + [ + {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, + p([3.5, 1.5, 1.5, 2.5]), + p([2, 0, 0, 1]), + ] + ) +TESTS.extend( + [ + [ + {"orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, + np.array([["label3", "label1"], ["label1", "label2"]]), + np.array([[0, 2], [2, 1]]), + ], + [ + {"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"}, + np.array([[3.5, 1.5], [1.5, 2.5]]), + np.array([["label0", "label2"], ["label2", "label1"]]), + ], + [ + { + "orig_labels": ["label3", "label2", "label1"], + "target_labels": ["label1", "label2", "label3"], + "dtype": "str", + }, + np.array([["label3", "label1"], ["label1", "label2"]]), + np.array([["label1", "label3"], ["label3", "label2"]]), + ], + ] +) + + +class TestMapLabelValue(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_value): result = MapLabelValue(**input_param)(input_data) - np.testing.assert_equal(result, expected_value) + if isinstance(expected_value, torch.Tensor): + torch.testing.assert_allclose(result, expected_value) + else: + np.testing.assert_equal(result, expected_value) self.assertTupleEqual(result.shape, expected_value.shape) diff --git a/tests/test_map_label_valued.py b/tests/test_map_label_valued.py index 426ac28836..cf8ca6c8e2 100644 --- a/tests/test_map_label_valued.py +++ b/tests/test_map_label_valued.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_map_transform.py b/tests/test_map_transform.py index 803e699a7d..dd77ccb099 100644 --- a/tests/test_map_transform.py +++ b/tests/test_map_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index a3662eec49..b6cfe0e10c 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import MaskIntensity @@ -43,9 +44,15 @@ np.array([[[0, 0, 0], [2, 2, 2], [0, 0, 0]], [[0, 0, 0], [5, 5, 5], [0, 0, 0]]]), ] +TEST_CASE_5 = [ + {"mask_data": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])}, + torch.as_tensor([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + torch.as_tensor([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), +] + class TestMaskIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_value(self, argments, image, expected_data): result = MaskIntensity(**argments)(image) np.testing.assert_allclose(result, expected_data) diff --git a/tests/test_mask_intensityd.py b/tests/test_mask_intensityd.py index c21e26eba6..fe61e7be04 100644 --- a/tests/test_mask_intensityd.py +++ b/tests/test_mask_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_masked_dice_loss.py b/tests/test_masked_dice_loss.py index b8d69bc8f9..317da3a316 100644 --- a/tests/test_masked_dice_loss.py +++ b/tests/test_masked_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -67,7 +67,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - [[0.296529, 0.415136], [0.599976, 0.428559]], + [[[0.296529], [0.415136]], [[0.599976], [0.428559]]], ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -94,26 +94,17 @@ ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "squared_pred": True, "smooth_nr": 1e-5, "smooth_dr": 1e-5}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.178337, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "jaccard": True, "smooth_nr": 1e-5, "smooth_dr": 1e-5}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.470451, ], ] diff --git a/tests/test_masked_inference_wsi_dataset.py b/tests/test_masked_inference_wsi_dataset.py index 361c17e106..c29b95a2d8 100644 --- a/tests/test_masked_inference_wsi_dataset.py +++ b/tests/test_masked_inference_wsi_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,16 +18,16 @@ from parameterized import parameterized from monai.apps.pathology.data import MaskedInferenceWSIDataset -from monai.apps.utils import download_url from monai.utils import optional_import -from tests.utils import skip_if_quick +from tests.utils import download_url_or_skip_test, skip_if_quick, testing_data_config -_, has_cim = optional_import("cucim") +_, has_cim = optional_import("cucim", name="CuImage") _, has_osl = optional_import("openslide") -FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" -base_name, extension = os.path.splitext(os.path.basename(FILE_URL)) -FILE_NAME = "temp_" + base_name +FILE_KEY = "wsi_img" +FILE_URL = testing_data_config("images", FILE_KEY, "url") +base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff" +FILE_NAME = f"temp_{base_name}" FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", FILE_NAME + extension) MASK1 = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tissue_mask1.npy") @@ -50,28 +50,12 @@ def prepare_data(): TEST_CASE_0 = [ - { - "data": [ - {"image": FILE_PATH, "mask": MASK1}, - ], - "patch_size": 1, - "image_reader_name": "cuCIM", - }, - [ - { - "image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), - "name": FILE_NAME, - "mask_location": [100, 100], - }, - ], + {"data": [{"image": FILE_PATH, "mask": MASK1}], "patch_size": 1, "image_reader_name": "cuCIM"}, + [{"image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), "name": FILE_NAME, "mask_location": [100, 100]}], ] TEST_CASE_1 = [ - { - "data": [{"image": FILE_PATH, "mask": MASK2}], - "patch_size": 1, - "image_reader_name": "cuCIM", - }, + {"data": [{"image": FILE_PATH, "mask": MASK2}], "patch_size": 1, "image_reader_name": "cuCIM"}, [ { "image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), @@ -87,11 +71,7 @@ def prepare_data(): ] TEST_CASE_2 = [ - { - "data": [{"image": FILE_PATH, "mask": MASK4}], - "patch_size": 1, - "image_reader_name": "cuCIM", - }, + {"data": [{"image": FILE_PATH, "mask": MASK4}], "patch_size": 1, "image_reader_name": "cuCIM"}, [ { "image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), @@ -117,35 +97,21 @@ def prepare_data(): ] TEST_CASE_3 = [ - { - "data": [ - {"image": FILE_PATH, "mask": MASK1}, - ], - "patch_size": 2, - "image_reader_name": "cuCIM", - }, + {"data": [{"image": FILE_PATH, "mask": MASK1}], "patch_size": 2, "image_reader_name": "cuCIM"}, [ { "image": np.array( - [ - [[243, 243], [243, 243]], - [[243, 243], [243, 243]], - [[243, 243], [243, 243]], - ], - dtype=np.uint8, + [[[243, 243], [243, 243]], [[243, 243], [243, 243]], [[243, 243], [243, 243]]], dtype=np.uint8 ), "name": FILE_NAME, "mask_location": [100, 100], - }, + } ], ] TEST_CASE_4 = [ { - "data": [ - {"image": FILE_PATH, "mask": MASK1}, - {"image": FILE_PATH, "mask": MASK2}, - ], + "data": [{"image": FILE_PATH, "mask": MASK1}, {"image": FILE_PATH, "mask": MASK2}], "patch_size": 1, "image_reader_name": "cuCIM", }, @@ -170,28 +136,12 @@ def prepare_data(): TEST_CASE_OPENSLIDE_0 = [ - { - "data": [ - {"image": FILE_PATH, "mask": MASK1}, - ], - "patch_size": 1, - "image_reader_name": "OpenSlide", - }, - [ - { - "image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), - "name": FILE_NAME, - "mask_location": [100, 100], - }, - ], + {"data": [{"image": FILE_PATH, "mask": MASK1}], "patch_size": 1, "image_reader_name": "OpenSlide"}, + [{"image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), "name": FILE_NAME, "mask_location": [100, 100]}], ] TEST_CASE_OPENSLIDE_1 = [ - { - "data": [{"image": FILE_PATH, "mask": MASK2}], - "patch_size": 1, - "image_reader_name": "OpenSlide", - }, + {"data": [{"image": FILE_PATH, "mask": MASK2}], "patch_size": 1, "image_reader_name": "OpenSlide"}, [ { "image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), @@ -210,29 +160,18 @@ def prepare_data(): class TestMaskedInferenceWSIDataset(unittest.TestCase): def setUp(self): prepare_data() - download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") - - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - ] - ) + hash_type = testing_data_config("images", FILE_KEY, "hash_type") + hash_val = testing_data_config("images", FILE_KEY, "hash_val") + download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val) + + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) @skipUnless(has_cim, "Requires CuCIM") @skip_if_quick def test_read_patches_cucim(self, input_parameters, expected): dataset = MaskedInferenceWSIDataset(**input_parameters) self.compare_samples_expected(dataset, expected) - @parameterized.expand( - [ - TEST_CASE_OPENSLIDE_0, - TEST_CASE_OPENSLIDE_1, - ] - ) + @parameterized.expand([TEST_CASE_OPENSLIDE_0, TEST_CASE_OPENSLIDE_1]) @skipUnless(has_osl, "Requires OpenSlide") @skip_if_quick def test_read_patches_openslide(self, input_parameters, expected): diff --git a/tests/test_masked_loss.py b/tests/test_masked_loss.py index 225e3d9668..9f28d51aa4 100644 --- a/tests/test_masked_loss.py +++ b/tests/test_masked_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,7 +33,7 @@ "reduction": "sum", }, [(14.538666, 20.191753), (13.17672, 8.251623)], - ], + ] ] diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py new file mode 100644 index 0000000000..83984d1556 --- /dev/null +++ b/tests/test_matshow3d.py @@ -0,0 +1,109 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np + +from monai.transforms import AddChanneld, Compose, LoadImaged, RandSpatialCropSamplesd, RepeatChanneld, ScaleIntensityd +from monai.utils import optional_import +from monai.visualize.utils import matshow3d +from tests.utils import SkipIfNoModule + +compare_images, _ = optional_import("matplotlib.testing.compare", name="compare_images") +pyplot, has_pyplot = optional_import("matplotlib", name="pyplot") + + +@SkipIfNoModule("matplotlib") +class TestMatshow3d(unittest.TestCase): + def test_3d(self): + testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + keys = "image" + xforms = Compose([LoadImaged(keys=keys), AddChanneld(keys=keys), ScaleIntensityd(keys=keys)]) + image_path = os.path.join(testing_dir, "anatomical.nii") + ims = xforms({keys: image_path}) + + fig = pyplot.figure() # external figure + fig, _ = matshow3d(ims[keys], fig=fig, figsize=(2, 2), frames_per_row=5, every_n=2, frame_dim=-1, show=False) + + with tempfile.TemporaryDirectory() as tempdir: + tempimg = f"{tempdir}/matshow3d_test.png" + fig.savefig(tempimg) + comp = compare_images(f"{testing_dir}/matshow3d_test.png", tempimg, 5e-2) + self.assertIsNone(comp, f"value of comp={comp}") # None indicates test passed + + def test_samples(self): + testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + keys = "image" + xforms = Compose( + [ + LoadImaged(keys=keys), + AddChanneld(keys=keys), + ScaleIntensityd(keys=keys), + RandSpatialCropSamplesd(keys=keys, roi_size=(8, 8, 5), random_size=True, num_samples=10), + ] + ) + image_path = os.path.join(testing_dir, "anatomical.nii") + xforms.set_random_state(0) + ims = xforms({keys: image_path}) + fig, mat = matshow3d( + [im[keys] for im in ims], title=f"testing {keys}", figsize=(2, 2), frames_per_row=5, every_n=2, show=False + ) + self.assertTrue(mat.dtype == np.float32) + + with tempfile.TemporaryDirectory() as tempdir: + tempimg = f"{tempdir}/matshow3d_patch_test.png" + fig.savefig(tempimg) + comp = compare_images(f"{testing_dir}/matshow3d_patch_test.png", tempimg, 5e-2, in_decorator=True) + if comp: + print("not none comp: ", comp) # matplotlib 3.2.2 + np.testing.assert_allclose(comp["rms"], 30.786983, atol=1e-3, rtol=1e-3) + else: + self.assertIsNone(comp, f"value of comp={comp}") # None indicates test passed + + def test_3d_rgb(self): + testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + keys = "image" + xforms = Compose( + [ + LoadImaged(keys=keys), + AddChanneld(keys=keys), + ScaleIntensityd(keys=keys), + # change to RGB color image + RepeatChanneld(keys=keys, repeats=3), + ] + ) + image_path = os.path.join(testing_dir, "anatomical.nii") + ims = xforms({keys: image_path}) + + fig = pyplot.figure() # external figure + fig, _ = matshow3d( + volume=ims[keys], + fig=fig, + figsize=(2, 2), + frames_per_row=5, + every_n=2, + frame_dim=-1, + channel_dim=0, + show=False, + ) + + with tempfile.TemporaryDirectory() as tempdir: + tempimg = f"{tempdir}/matshow3d_rgb_test.png" + fig.savefig(tempimg) + comp = compare_images(f"{testing_dir}/matshow3d_rgb_test.png", tempimg, 5e-2) + self.assertIsNone(comp, f"value of comp={comp}") # None indicates test passed + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py index 7e08846beb..b14f6f01d3 100644 --- a/tests/test_mean_ensemble.py +++ b/tests/test_mean_ensemble.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,49 +16,50 @@ from parameterized import parameterized from monai.transforms import MeanEnsemble +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"weights": None}, - [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], - torch.ones(2, 2, 2) + 1, -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([{"weights": None}, [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2], p(torch.ones(2, 2, 2)) + 1]) -TEST_CASE_2 = [ - {"weights": None}, - torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2]), - torch.ones(2, 2, 2) + 1, -] + TESTS.append( + [{"weights": None}, p(torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2])), p(torch.ones(2, 2, 2)) + 1] + ) -TEST_CASE_3 = [ - {"weights": [1, 3]}, - [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], - torch.ones(2, 2, 2) * 2.5, -] + TESTS.append( + [{"weights": [1, 3]}, [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2], p(torch.ones(2, 2, 2)) * 2.5] + ) -TEST_CASE_4 = [ - {"weights": [[1, 3], [3, 1]]}, - [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], - torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), -] + TESTS.append( + [ + {"weights": [[1, 3], [3, 1]]}, + [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2], + p(torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1)), + ] + ) -TEST_CASE_5 = [ - {"weights": np.array([[1, 3], [3, 1]])}, - [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], - torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), -] + TESTS.append( + [ + {"weights": np.array([[1, 3], [3, 1]])}, + [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2], + p(torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1)), + ] + ) -TEST_CASE_6 = [ - {"weights": torch.tensor([[[1, 3]], [[3, 1]]])}, - [torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2], - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), -] + TESTS.append( + [ + {"weights": torch.tensor([[[1, 3]], [[3, 1]]])}, + [p(torch.ones(2, 2, 2, 2)), p(torch.ones(2, 2, 2, 2)) + 2], + p(torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)), + ] + ) class TestMeanEnsemble(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_param, img, expected_value): result = MeanEnsemble(**input_param)(img) - torch.testing.assert_allclose(result, expected_value) + assert_allclose(result, expected_value) def test_cuda_value(self): img = torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2]) diff --git a/tests/test_mean_ensembled.py b/tests/test_mean_ensembled.py index ea77ef18a0..b5e1569d65 100644 --- a/tests/test_mean_ensembled.py +++ b/tests/test_mean_ensembled.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,46 +16,61 @@ from parameterized import parameterized from monai.transforms import MeanEnsembled +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": None}, - {"pred0": torch.ones(2, 2, 2), "pred1": torch.ones(2, 2, 2) + 2}, - torch.ones(2, 2, 2) + 1, -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": None}, + {"pred0": p(torch.ones(2, 2, 2)), "pred1": p(torch.ones(2, 2, 2)) + 2}, + p(torch.ones(2, 2, 2)) + 1, + ] + ) -TEST_CASE_2 = [ - {"keys": "output", "weights": None}, - {"output": torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2])}, - torch.ones(2, 2, 2) + 1, -] + TESTS.append( + [ + {"keys": "output", "weights": None}, + {"output": p(torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2]))}, + p(torch.ones(2, 2, 2)) + 1, + ] + ) -TEST_CASE_3 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [1, 3]}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * 2.5, -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [1, 3]}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2)) + 2}, + p(torch.ones(2, 2, 2, 2)) * 2.5, + ] + ) -TEST_CASE_4 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [[1, 3], [3, 1]]}, - {"pred0": torch.ones(2, 2, 2), "pred1": torch.ones(2, 2, 2) + 2}, - torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [[1, 3], [3, 1]]}, + {"pred0": p(torch.ones(2, 2, 2)), "pred1": p(torch.ones(2, 2, 2)) + 2}, + p(torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1)), + ] + ) -TEST_CASE_5 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": np.array([[[1, 3]], [[3, 1]]])}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": np.array([[[1, 3]], [[3, 1]]])}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2)) + 2}, + p(torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)), + ] + ) -TEST_CASE_6 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": torch.tensor([[[1, 3]], [[3, 1]]])}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": torch.tensor([[[1, 3]], [[3, 1]]])}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2)) + 2}, + p(torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)), + ] + ) class TestMeanEnsembled(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_param, data, expected_value): result = MeanEnsembled(**input_param)(data) torch.testing.assert_allclose(result["output"], expected_value) @@ -67,7 +82,7 @@ def test_cuda_value(self): img = img.to(torch.device("cuda:0")) expected_value = expected_value.to(torch.device("cuda:0")) result = MeanEnsembled(keys="output", weights=torch.tensor([[[1, 3]], [[3, 1]]]))({"output": img}) - torch.testing.assert_allclose(result["output"], expected_value) + assert_allclose(result["output"], expected_value) if __name__ == "__main__": diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 2e27f4ba95..e7cc1a60ff 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,11 +12,12 @@ import os import shutil import unittest -from urllib.error import ContentTooShortError, HTTPError +from pathlib import Path from monai.apps import MedNISTDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord -from tests.utils import skip_if_quick +from monai.utils.enums import PostFix +from tests.utils import skip_if_downloading_fails, skip_if_quick MEDNIST_FULL_DATASET_LENGTH = 58954 @@ -38,32 +39,30 @@ def _test_dataset(dataset): self.assertEqual(len(dataset), int(MEDNIST_FULL_DATASET_LENGTH * dataset.test_frac)) self.assertTrue("image" in dataset[0]) self.assertTrue("label" in dataset[0]) - self.assertTrue("image_meta_dict" in dataset[0]) + self.assertTrue(PostFix.meta("image") in dataset[0]) self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64)) - try: # will start downloading if testing_dir doesn't have the MedNIST files - data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=True) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors + with skip_if_downloading_fails(): + data = MedNISTDataset( + root_dir=testing_dir, transform=transform, section="test", download=True, copy_cache=False + ) _test_dataset(data) # testing from - data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) - data.get_num_classes() + data = MedNISTDataset(root_dir=Path(testing_dir), transform=transform, section="test", download=False) + self.assertEqual(data.get_num_classes(), 6) _test_dataset(data) data = MedNISTDataset(root_dir=testing_dir, section="test", download=False) self.assertTupleEqual(data[0]["image"].shape, (64, 64)) # test same dataset length with different random seed data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False, seed=42) _test_dataset(data) + self.assertEqual(data[0]["class_name"], "AbdomenCT") + self.assertEqual(data[0]["label"].cpu().item(), 0) shutil.rmtree(os.path.join(testing_dir, "MedNIST")) try: - data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) + MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) except RuntimeError as e: print(str(e)) self.assertTrue(str(e).startswith("Cannot find dataset directory")) diff --git a/tests/test_milmodel.py b/tests/test_milmodel.py new file mode 100644 index 0000000000..ad04e96c60 --- /dev/null +++ b/tests/test_milmodel.py @@ -0,0 +1,91 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MILModel +from monai.utils.module import optional_import +from tests.utils import test_script_save + +models, _ = optional_import("torchvision.models") + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +TEST_CASE_MILMODEL = [] +for num_classes in [1, 5]: + for mil_mode in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]: + test_case = [ + {"num_classes": num_classes, "mil_mode": mil_mode, "pretrained": False}, + (1, 2, 3, 512, 512), + (1, num_classes), + ] + TEST_CASE_MILMODEL.append(test_case) + + +for trans_blocks in [1, 3]: + test_case = [ + {"num_classes": 5, "pretrained": False, "trans_blocks": trans_blocks, "trans_dropout": 0.5}, + (1, 2, 3, 512, 512), + (1, 5), + ] + TEST_CASE_MILMODEL.append(test_case) + +# torchvision backbone +TEST_CASE_MILMODEL.append( + [{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)] +) +TEST_CASE_MILMODEL.append([{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)]) + +# custom backbone +backbone = models.densenet121(pretrained=False) +backbone_nfeatures = backbone.classifier.in_features +backbone.classifier = torch.nn.Identity() +TEST_CASE_MILMODEL.append( + [ + {"num_classes": 5, "backbone": backbone, "backbone_num_features": backbone_nfeatures, "pretrained": False}, + (2, 2, 3, 512, 512), + (2, 5), + ] +) + + +class TestMilModel(unittest.TestCase): + @parameterized.expand(TEST_CASE_MILMODEL) + def test_shape(self, input_param, input_shape, expected_shape): + net = MILModel(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape, dtype=torch.float).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_args(self): + with self.assertRaises(ValueError): + MILModel( + num_classes=5, + pretrained=False, + backbone="resnet50", + backbone_num_features=2048, + mil_mode="att_trans_pyramid", + ) + + def test_script(self): + input_param, input_shape, expected_shape = TEST_CASE_MILMODEL[0] + net = MILModel(**input_param) + test_data = torch.randn(input_shape, dtype=torch.float) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 7a93f81ec3..6fec5b6854 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,11 +24,7 @@ for mlp_dim in [512, 1028, 2048, 3072]: test_case = [ - { - "hidden_size": hidden_size, - "mlp_dim": mlp_dim, - "dropout_rate": dropout_rate, - }, + {"hidden_size": hidden_size, "mlp_dim": mlp_dim, "dropout_rate": dropout_rate}, (2, 512, hidden_size), (2, 512, hidden_size), ] diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 6952e62c3c..98ad373625 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,24 +12,21 @@ import os import tempfile import unittest -from urllib.error import ContentTooShortError, HTTPError +from pathlib import Path import numpy as np import torch from parameterized import parameterized -from monai.apps import RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar +from monai.apps import download_mmar, load_from_mmar from monai.apps.mmars import MODEL_DESC from monai.apps.mmars.mmars import _get_val -from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion, skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick -TEST_CASES = [["clara_pt_prostate_mri_segmentation_1"], ["clara_pt_covid19_ct_lesion_segmentation_1"]] +TEST_CASES = [["clara_pt_prostate_mri_segmentation"], ["clara_pt_covid19_ct_lesion_segmentation"]] TEST_EXTRACT_CASES = [ ( - { - "item": "clara_pt_prostate_mri_segmentation_1", - "map_location": "cuda" if torch.cuda.is_available() else "cpu", - }, + {"item": "clara_pt_prostate_mri_segmentation", "map_location": "cuda" if torch.cuda.is_available() else "cpu"}, "UNet", np.array( [ @@ -41,7 +38,7 @@ ), ( { - "item": "clara_pt_covid19_ct_lesion_segmentation_1", + "item": "clara_pt_covid19_ct_lesion_segmentation", "map_location": "cuda" if torch.cuda.is_available() else "cpu", }, "SegResNet", @@ -67,8 +64,9 @@ ), ( { - "item": "clara_pt_fed_learning_brain_tumor_mri_segmentation_1", + "item": "clara_pt_fed_learning_brain_tumor_mri_segmentation", "map_location": "cuda" if torch.cuda.is_available() else "cpu", + "model_file": os.path.join("models", "server", "best_FL_global_model.pt"), }, "SegResNet", np.array( @@ -81,7 +79,7 @@ ), ( { - "item": "clara_pt_pathology_metastasis_detection_1", + "item": "clara_pt_pathology_metastasis_detection", "map_location": "cuda" if torch.cuda.is_available() else "cpu", }, "TorchVisionFullyConvModel", @@ -103,49 +101,30 @@ class TestMMMARDownload(unittest.TestCase): @parameterized.expand(TEST_CASES) @skip_if_quick - @SkipIfBeforePyTorchVersion((1, 6)) def test_download(self, idx): - try: - # test model specification - cand = get_model_spec(idx) - self.assertEqual(cand[RemoteMMARKeys.ID], idx) - download_mmar(idx) + with skip_if_downloading_fails(): + with self.assertLogs(level="INFO", logger="monai.apps"): + download_mmar(idx) download_mmar(idx, progress=False) # repeated to check caching with tempfile.TemporaryDirectory() as tmp_dir: download_mmar(idx, mmar_dir=tmp_dir, progress=False) - download_mmar(idx, mmar_dir=tmp_dir, progress=False, version=1) # repeated to check caching + download_mmar(idx, mmar_dir=Path(tmp_dir), progress=False, version=1) # repeated to check caching self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx))) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, HTTPError): - self.assertTrue("500" in str(e)) # http error has the code 500 - return # skipping this test due the network connection errors @parameterized.expand(TEST_EXTRACT_CASES) @skip_if_quick - @SkipIfBeforePyTorchVersion((1, 6)) def test_load_ckpt(self, input_args, expected_name, expected_val): - try: + with skip_if_downloading_fails(): output = load_from_mmar(**input_args) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, HTTPError): - self.assertTrue("500" in str(e)) # http error has the code 500 - return self.assertEqual(output.__class__.__name__, expected_name) x = next(output.parameters()) # verify the first element np.testing.assert_allclose(x[0][0].detach().cpu().numpy(), expected_val, rtol=1e-3, atol=1e-3) def test_unique(self): # model ids are unique - keys = sorted([m["id"] for m in MODEL_DESC]) + keys = sorted(m["id"] for m in MODEL_DESC) self.assertTrue(keys == sorted(set(keys))) - @SkipIfAtLeastPyTorchVersion((1, 6)) - def test_no_default(self): - with self.assertRaises(ValueError): - download_mmar(0) - def test_search(self): self.assertEqual(_get_val({"a": 1, "b": 2}, key="b"), 2) self.assertEqual(_get_val({"a": {"c": {"c": 4}}, "b": {"c": 2}}, key="b"), {"c": 2}) diff --git a/tests/test_module_list.py b/tests/test_module_list.py index 3aefaf5e0c..83c6979f30 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,7 +10,9 @@ # limitations under the License. import glob +import inspect import os +import pathlib import unittest import monai @@ -33,6 +35,34 @@ def test_public_api(self): mod.append(code_folder) self.assertEqual(sorted(monai.__all__), sorted(mod)) + def test_transform_api(self): + """monai subclasses of MapTransforms must have alias names ending with 'd', 'D', 'Dict'""" + to_exclude = {"MapTransform"} # except for these transforms + to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision"} + to_exclude_docs.update({"DeleteItems", "SelectItems", "CopyItems", "ConcatItems"}) + xforms = { + name: obj + for name, obj in monai.transforms.__dict__.items() + if inspect.isclass(obj) and issubclass(obj, monai.transforms.MapTransform) + } + names = sorted(x for x in xforms if x not in to_exclude) + remained = set(names) + doc_file = os.path.join(pathlib.Path(__file__).parent.parent, "docs", "source", "transforms.rst") + contents = pathlib.Path(doc_file).read_text() if os.path.exists(doc_file) else None + for n in names: + if not n.endswith("d"): + continue + with self.subTest(n=n): + basename = n[:-1] # Transformd basename is Transform + for docname in (f"{basename}", f"{basename}d"): + if docname in to_exclude_docs: + continue + if (contents is not None) and f"`{docname}`" not in f"{contents}": + self.assertTrue(False, f"please add `{docname}` to docs/source/transforms.rst") + for postfix in ("D", "d", "Dict"): + remained.remove(f"{basename}{postfix}") + self.assertFalse(remained) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index 01a760db72..963824f25e 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_net_adapter.py b/tests/test_net_adapter.py index b2d55129a7..39201fb600 100644 --- a/tests/test_net_adapter.py +++ b/tests/test_net_adapter.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,23 +19,11 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASE_0 = [ - {"num_classes": 1, "use_conv": True, "dim": 2}, - (2, 3, 224, 224), - (2, 1, 8, 1), -] +TEST_CASE_0 = [{"num_classes": 1, "use_conv": True, "dim": 2}, (2, 3, 224, 224), (2, 1, 8, 1)] -TEST_CASE_1 = [ - {"num_classes": 1, "use_conv": True, "dim": 3, "pool": None}, - (2, 3, 32, 32, 32), - (2, 1, 1, 1, 1), -] +TEST_CASE_1 = [{"num_classes": 1, "use_conv": True, "dim": 3, "pool": None}, (2, 3, 32, 32, 32), (2, 1, 1, 1, 1)] -TEST_CASE_2 = [ - {"num_classes": 5, "use_conv": True, "dim": 3, "pool": None}, - (2, 3, 32, 32, 32), - (2, 5, 1, 1, 1), -] +TEST_CASE_2 = [{"num_classes": 5, "use_conv": True, "dim": 3, "pool": None}, (2, 3, 32, 32, 32), (2, 5, 1, 1, 1)] TEST_CASE_3 = [ {"num_classes": 5, "use_conv": True, "pool": ("avg", {"kernel_size": 4, "stride": 1}), "dim": 3}, @@ -53,7 +41,9 @@ class TestNetAdapter(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, input_shape, expected_shape): - model = resnet18(spatial_dims=input_param["dim"]) + spatial_dims = input_param["dim"] + stride = (1, 2, 2)[:spatial_dims] + model = resnet18(spatial_dims=spatial_dims, conv1_t_stride=stride) input_param["model"] = model net = NetAdapter(**input_param).to(device) with eval_mode(net): diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py index 9698a40116..419e1202d0 100644 --- a/tests/test_network_consistency.py +++ b/tests/test_network_consistency.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,8 +20,9 @@ from parameterized.parameterized import parameterized import monai.networks.nets as nets +from monai.utils import set_determinism -extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA", None) +extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA") TESTS = [] if extra_test_data_dir is not None: @@ -33,6 +34,12 @@ class TestNetworkConsistency(unittest.TestCase): + def setUp(self): + set_determinism(0) + + def tearDown(self): + set_determinism(None) + @skipIf( len(TESTS) == 0, "To run these tests, clone https://github.com/Project-MONAI/MONAI-extra-test-data and set MONAI_EXTRA_TEST_DATA", @@ -53,8 +60,8 @@ def test_network_consistency(self, net_name, data_path, json_path): json_file.close() # Create model - model = nets.__dict__[net_name](**model_params) - model.load_state_dict(loaded_data["model"]) + model = getattr(nets, net_name)(**model_params) + model.load_state_dict(loaded_data["model"], strict=False) model.eval() in_data = loaded_data["in_data"] diff --git a/tests/test_nifti_endianness.py b/tests/test_nifti_endianness.py index bf0f27b9ca..7f179d3bde 100644 --- a/tests/test_nifti_endianness.py +++ b/tests/test_nifti_endianness.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,6 +23,7 @@ from monai.data.image_reader import PILReader from monai.transforms import LoadImage, LoadImaged from monai.transforms.io.array import switch_endianness +from monai.utils.enums import PostFix from monai.utils.module import optional_import if TYPE_CHECKING: @@ -60,8 +61,8 @@ def test_endianness(self, endianness, use_array, image_only): check_ds = Dataset(data, tr) check_loader = DataLoader(check_ds, batch_size=1) ret = next(iter(check_loader)) - if isinstance(ret, dict) and "image_meta_dict" in ret: - np.testing.assert_allclose(ret["image_meta_dict"]["spatial_shape"], [[100, 100]]) + if isinstance(ret, dict) and PostFix.meta("image") in ret: + np.testing.assert_allclose(ret[PostFix.meta("image")]["spatial_shape"], [[100, 100]]) def test_switch(self): # verify data types for data in (np.zeros((2, 1)), ("test",), [24, 42], {"foo": "bar"}, True, 42): diff --git a/tests/test_nifti_header_revise.py b/tests/test_nifti_header_revise.py index 8d9a1d4f3a..7f917cb0e9 100644 --- a/tests/test_nifti_header_revise.py +++ b/tests/test_nifti_header_revise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index f16d80659c..2c0a8dc9a3 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,56 +17,68 @@ import numpy as np from parameterized import parameterized -from monai.data import write_nifti +from monai.data import NibabelWriter from monai.transforms import LoadImage, Orientation, Spacing -from tests.utils import make_nifti_image - -TEST_IMAGE = np.arange(24).reshape((2, 4, 3)) -TEST_AFFINE = np.array( - [[-5.3, 0.0, 0.0, 102.01], [0.0, 0.52, 2.17, -7.50], [-0.0, 1.98, -0.26, -23.12], [0.0, 0.0, 0.0, 1.0]] -) - -TEST_CASES = [ - [ - TEST_IMAGE, - TEST_AFFINE, - dict(reader="NibabelReader", image_only=False, as_closest_canonical=True), - np.arange(24).reshape((2, 4, 3)), - ], - [ - TEST_IMAGE, - TEST_AFFINE, - dict(reader="NibabelReader", image_only=True, as_closest_canonical=True), - np.array( +from tests.utils import TEST_NDARRAYS, assert_allclose, make_nifti_image + +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TEST_IMAGE = p(np.arange(24).reshape((2, 4, 3))) + TEST_AFFINE = q( + np.array( + [[-5.3, 0.0, 0.0, 102.01], [0.0, 0.52, 2.17, -7.50], [-0.0, 1.98, -0.26, -23.12], [0.0, 0.0, 0.0, 1.0]] + ) + ) + TESTS.append( + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=False, as_closest_canonical=True), + np.arange(24).reshape((2, 4, 3)), + ] + ) + TESTS.append( [ - [[12.0, 15.0, 18.0, 21.0], [13.0, 16.0, 19.0, 22.0], [14.0, 17.0, 20.0, 23.0]], - [[0.0, 3.0, 6.0, 9.0], [1.0, 4.0, 7.0, 10.0], [2.0, 5.0, 8.0, 11.0]], + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=True, as_closest_canonical=True), + np.array( + [ + [[12.0, 15.0, 18.0, 21.0], [13.0, 16.0, 19.0, 22.0], [14.0, 17.0, 20.0, 23.0]], + [[0.0, 3.0, 6.0, 9.0], [1.0, 4.0, 7.0, 10.0], [2.0, 5.0, 8.0, 11.0]], + ] + ), ] - ), - ], - [ - TEST_IMAGE, - TEST_AFFINE, - dict(reader="NibabelReader", image_only=True, as_closest_canonical=False), - np.arange(24).reshape((2, 4, 3)), - ], - [ - TEST_IMAGE, - TEST_AFFINE, - dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), - np.arange(24).reshape((2, 4, 3)), - ], - [ - TEST_IMAGE, - None, - dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), - np.arange(24).reshape((2, 4, 3)), - ], -] + ) + TESTS.append( + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=True, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ] + ) + TESTS.append( + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ] + ) + TESTS.append( + [ + TEST_IMAGE, + None, + dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ] + ) class TestNiftiLoadRead(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_orientation(self, array, affine, reader_param, expected): test_image = make_nifti_image(array, affine) @@ -82,10 +94,13 @@ def test_orientation(self, array, affine, reader_param, expected): os.remove(test_image) # write test cases + writer_obj = NibabelWriter() + writer_obj.set_data_array(data_array, channel_dim=None) if header is not None: - write_nifti(data_array, test_image, header["affine"], header.get("original_affine", None)) + writer_obj.set_metadata(header) elif affine is not None: - write_nifti(data_array, test_image, affine) + writer_obj.set_metadata({"affine": affine}) + writer_obj.write(test_image, verbose=True) saved = nib.load(test_image) saved_affine = saved.affine saved_data = saved.get_fdata() @@ -93,8 +108,8 @@ def test_orientation(self, array, affine, reader_param, expected): os.remove(test_image) if affine is not None: - np.testing.assert_allclose(saved_affine, affine) - np.testing.assert_allclose(saved_data, expected) + assert_allclose(saved_affine, affine, type_test=False) + assert_allclose(saved_data, expected, type_test=False) def test_consistency(self): np.set_printoptions(suppress=True, precision=3) @@ -104,105 +119,134 @@ def test_consistency(self): data, _, new_affine = Orientation("ILP")(data, new_affine) if os.path.exists(test_image): os.remove(test_image) - write_nifti(data[0], test_image, new_affine, original_affine, mode="nearest", padding_mode="border") + writer_obj = NibabelWriter() + writer_obj.set_data_array(data[0], channel_dim=None) + writer_obj.set_metadata( + meta_dict={"affine": new_affine, "original_affine": original_affine}, mode="nearest", padding_mode="border" + ) + writer_obj.write(test_image, verbose=True) saved = nib.load(test_image) saved_data = saved.get_fdata() np.testing.assert_allclose(saved_data, np.arange(64).reshape(1, 8, 8), atol=1e-7) if os.path.exists(test_image): os.remove(test_image) - write_nifti( - data[0], - test_image, - new_affine, - original_affine, + writer_obj.set_data_array(data[0], channel_dim=None) + writer_obj.set_metadata( + meta_dict={"affine": new_affine, "original_affine": original_affine, "spatial_shape": (1, 8, 8)}, mode="nearest", padding_mode="border", - output_spatial_shape=(1, 8, 8), ) + writer_obj.write(test_image, verbose=True) saved = nib.load(test_image) saved_data = saved.get_fdata() np.testing.assert_allclose(saved_data, np.arange(64).reshape(1, 8, 8), atol=1e-7) if os.path.exists(test_image): os.remove(test_image) - # test the case that only correct orientation but don't resample - write_nifti(data[0], test_image, new_affine, original_affine, resample=False) + # test the case no resample + writer_obj.set_data_array(data[0], channel_dim=None) + writer_obj.set_metadata(meta_dict={"affine": new_affine, "original_affine": original_affine}, resample=False) + writer_obj.write(test_image, verbose=True) saved = nib.load(test_image) - # compute expected affine - start_ornt = nib.orientations.io_orientation(new_affine) - target_ornt = nib.orientations.io_orientation(original_affine) - ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) - data_shape = data[0].shape - expected_affine = new_affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) - np.testing.assert_allclose(saved.affine, expected_affine) + np.testing.assert_allclose(saved.affine, new_affine) if os.path.exists(test_image): os.remove(test_image) def test_write_2d(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(6).reshape((2, 3)) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(5).reshape((1, 5)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 1, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1])) + for p in TEST_NDARRAYS: + img = p(np.arange(6).reshape((2, 3))) + writer_obj = NibabelWriter() + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.set_metadata({"affine": np.diag([1, 1, 1]), "original_affine": np.diag([1.4, 1, 1])}) + writer_obj.write(image_name, verbose=True) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = np.arange(5).reshape((1, 5)) + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.set_metadata( + {"affine": np.diag([1, 1, 1, 3, 3]), "original_affine": np.diag([1.4, 2.0, 1, 3, 5])} + ) + writer_obj.write(image_name, verbose=True) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1])) def test_write_3d(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(6).reshape((1, 2, 3)) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[0, 1, 2], [3, 4, 5]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(5).reshape((1, 1, 5)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) + for p in TEST_NDARRAYS: + img = p(np.arange(6).reshape((1, 2, 3))) + writer_obj = NibabelWriter() + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.set_metadata({"affine": np.diag([1, 1, 1, 1]), "original_affine": np.diag([1.4, 1, 1, 1])}) + writer_obj.write(image_name, verbose=True) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[0, 1, 2], [3, 4, 5]]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = p(np.arange(5).reshape((1, 1, 5))) + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.set_metadata( + {"affine": np.diag([1, 1, 1, 3, 3]), "original_affine": np.diag([1.4, 2.0, 2, 3, 5])} + ) + writer_obj.write(image_name, verbose=True) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) def test_write_4d(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(6).reshape((1, 1, 3, 2)) - write_nifti(img, image_name, affine=np.diag([1.4, 1]), target_affine=np.diag([1, 1.4, 1])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[[0, 1], [2, 3], [4, 5]]]]) - np.testing.assert_allclose(out.affine, np.diag([1, 1.4, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(5).reshape((1, 1, 5, 1)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) + for p in TEST_NDARRAYS: + img = p(np.arange(6).reshape((1, 1, 3, 2))) + writer_obj = NibabelWriter() + writer_obj.set_data_array(img, channel_dim=-1) + writer_obj.set_metadata({"affine": np.diag([1.4, 1, 1, 1]), "original_affine": np.diag([1, 1.4, 1, 1])}) + writer_obj.write(image_name, verbose=True) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[[0, 1], [2, 3], [4, 5]]]]) + np.testing.assert_allclose(out.affine, np.diag([1, 1.4, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = p(np.arange(5).reshape((1, 1, 5, 1))) + writer_obj.set_data_array(img, channel_dim=-1, squeeze_end_dims=False) + writer_obj.set_metadata( + {"affine": np.diag([1, 1, 1, 3, 3]), "original_affine": np.diag([1.4, 2.0, 2, 3, 5])} + ) + writer_obj.write(image_name, verbose=True) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) def test_write_5d(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(12).reshape((1, 1, 3, 2, 2)) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) - out = nib.load(image_name) - np.testing.assert_allclose( - out.get_fdata(), - np.array([[[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]], [[8.0, 9.0], [10.0, 11.0]]]]]), - ) - np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(10).reshape((1, 1, 5, 1, 2)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 1.0]], [[4.0, 5.0]], [[8.0, 9.0]]]]])) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) + for p in TEST_NDARRAYS: + img = p(np.arange(12).reshape((1, 1, 3, 2, 2))) + writer_obj = NibabelWriter() + writer_obj.set_data_array(img, channel_dim=-1, squeeze_end_dims=False, spatial_ndim=None) + writer_obj.set_metadata({"affine": np.diag([1, 1, 1, 1]), "original_affine": np.diag([1.4, 1, 1, 1])}) + writer_obj.write(image_name, verbose=True) + out = nib.load(image_name) + np.testing.assert_allclose( + out.get_fdata(), + np.array([[[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]], [[8.0, 9.0], [10.0, 11.0]]]]]), + ) + np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = p(np.arange(10).reshape((1, 1, 5, 1, 2))) + writer_obj.set_data_array(img, channel_dim=-1, squeeze_end_dims=False, spatial_ndim=None) + writer_obj.set_metadata({"affine": np.diag([1, 1, 1, 3]), "original_affine": np.diag([1.4, 2.0, 2, 3])}) + writer_obj.write(image_name, verbose=True) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 2.0]], [[4.0, 5.0]], [[7.0, 9.0]]]]])) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) if __name__ == "__main__": diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py index c07084172f..6855a59041 100644 --- a/tests/test_nifti_saver.py +++ b/tests/test_nifti_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -36,12 +36,7 @@ def test_saved_content(self): def test_saved_resize_content(self): with tempfile.TemporaryDirectory() as tempdir: - saver = NiftiSaver( - output_dir=tempdir, - output_postfix="seg", - output_ext=".nii.gz", - dtype=np.float32, - ) + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) meta_data = { "filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)], @@ -56,12 +51,7 @@ def test_saved_resize_content(self): def test_saved_3d_resize_content(self): with tempfile.TemporaryDirectory() as tempdir: - saver = NiftiSaver( - output_dir=tempdir, - output_postfix="seg", - output_ext=".nii.gz", - dtype=np.float32, - ) + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) meta_data = { "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], @@ -74,6 +64,25 @@ def test_saved_3d_resize_content(self): filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + def test_saved_3d_no_resize_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver( + output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32, resample=False + ) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(10, 10, 2)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + img, _ = LoadImage("nibabelreader")(filepath) + self.assertEqual(img.shape, (1, 2, 2, 8)) + def test_squeeze_end_dims(self): with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index 2755eb4c25..5bcee1263b 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -31,51 +31,51 @@ "divisor": u(np.array([0.5, 0.5, 0.5, 0.5])), "nonzero": True, }, - np.array([0.0, 3.0, 0.0, 4.0]), - np.array([0.0, -1.0, 0.0, 1.0]), + p(np.array([0.0, 3.0, 0.0, 4.0])), + p(np.array([0.0, -1.0, 0.0, 1.0])), ] ) - TESTS.append([p, {"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) - TESTS.append([p, {"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) - TESTS.append([p, {"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append([p, {"nonzero": True}, p(np.array([0.0, 0.0, 0.0, 0.0])), p(np.array([0.0, 0.0, 0.0, 0.0]))]) + TESTS.append([p, {"nonzero": False}, p(np.array([0.0, 0.0, 0.0, 0.0])), p(np.array([0.0, 0.0, 0.0, 0.0]))]) + TESTS.append([p, {"nonzero": False}, p(np.array([1, 1, 1, 1])), p(np.array([0.0, 0.0, 0.0, 0.0]))]) TESTS.append( [ p, - {"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]), + {"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3], "dtype": np.float32}, + p(np.ones((3, 2, 2))), + p(np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]])), ] ) TESTS.append( [ p, - {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]), + {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2], "dtype": "float32"}, + p(np.ones((3, 2, 2))), + p(np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]])), ] ) TESTS.append( [ p, - {"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * -1.0, + {"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0, "dtype": torch.float32}, + p(np.ones((3, 2, 2))), + p(np.ones((3, 2, 2)) * -1.0), ] ) TESTS.append( [ p, {"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, + p(np.ones((3, 2, 2))), + p(np.ones((3, 2, 2)) * 0.5), ] ) TESTS.append( [ p, {"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, + p(np.ones((3, 2, 2))), + p(np.ones((3, 2, 2)) * 0.5), ] ) @@ -91,17 +91,14 @@ def test_default(self, im_type): self.assertEqual(im.device, normalized.device) self.assertTrue(normalized.dtype in (np.float32, torch.float32)) expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) - assert_allclose(expected, normalized, rtol=1e-3) + assert_allclose(normalized, expected, type_test=False, rtol=1e-3) @parameterized.expand(TESTS) def test_nonzero(self, in_type, input_param, input_data, expected_data): normalizer = NormalizeIntensity(**input_param) im = in_type(input_data) normalized = normalizer(im) - self.assertEqual(type(im), type(normalized)) - if isinstance(normalized, torch.Tensor): - self.assertEqual(im.device, normalized.device) - assert_allclose(expected_data, normalized) + assert_allclose(normalized, in_type(expected_data)) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_channel_wise(self, im_type): @@ -109,10 +106,7 @@ def test_channel_wise(self, im_type): input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])) expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) normalized = normalizer(input_data) - self.assertEqual(type(input_data), type(normalized)) - if isinstance(normalized, torch.Tensor): - self.assertEqual(input_data.device, normalized.device) - assert_allclose(expected, normalized) + assert_allclose(normalized, im_type(expected)) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_value_errors(self, im_type): diff --git a/tests/test_normalize_intensityd.py b/tests/test_normalize_intensityd.py index e2cec5407a..12a39b1b5b 100644 --- a/tests/test_normalize_intensityd.py +++ b/tests/test_normalize_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,7 +25,7 @@ [ {"keys": ["img"], "nonzero": True}, {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, - np.array([0.0, -1.0, 0.0, 1.0]), + p(np.array([0.0, -1.0, 0.0, 1.0])), ] ) TESTS.append( @@ -37,14 +37,14 @@ "nonzero": True, }, {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, - np.array([0.0, -1.0, 0.0, 1.0]), + p(np.array([0.0, -1.0, 0.0, 1.0])), ] ) TESTS.append( [ {"keys": ["img"], "nonzero": True}, {"img": p(np.array([0.0, 0.0, 0.0, 0.0]))}, - np.array([0.0, 0.0, 0.0, 0.0]), + p(np.array([0.0, 0.0, 0.0, 0.0])), ] ) @@ -60,7 +60,7 @@ def test_image_normalize_intensityd(self, im_type): self.assertEqual(type(im), type(normalized)) if isinstance(normalized, torch.Tensor): self.assertEqual(im.device, normalized.device) - assert_allclose(normalized, expected, rtol=1e-3) + assert_allclose(normalized, im_type(expected), rtol=1e-3) @parameterized.expand(TESTS) def test_nonzero(self, input_param, input_data, expected_data): @@ -82,7 +82,7 @@ def test_channel_wise(self, im_type): if isinstance(normalized, torch.Tensor): self.assertEqual(input_data[key].device, normalized.device) expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) - assert_allclose(normalized, expected) + assert_allclose(normalized, im_type(expected)) if __name__ == "__main__": diff --git a/tests/test_npzdictitemdataset.py b/tests/test_npzdictitemdataset.py index 2e86ef29d0..e24a2cfc1f 100644 --- a/tests/test_npzdictitemdataset.py +++ b/tests/test_npzdictitemdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py index a57a036905..c2f3679e33 100644 --- a/tests/test_numpy_reader.py +++ b/tests/test_numpy_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,12 +10,16 @@ # limitations under the License. import os +import sys import tempfile import unittest import numpy as np +import torch -from monai.data import NumpyReader +from monai.data import DataLoader, Dataset, NumpyReader +from monai.transforms import LoadImaged +from monai.utils.enums import PostFix class TestNumpyReader(unittest.TestCase): @@ -27,8 +31,8 @@ def test_npy(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data.shape) - self.assertTupleEqual(result[0].shape, test_data.shape) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data.shape) + np.testing.assert_allclose(result[0].shape, test_data.shape) np.testing.assert_allclose(result[0], test_data) def test_npz1(self): @@ -39,8 +43,8 @@ def test_npz1(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, test_data1.shape) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape) + np.testing.assert_allclose(result[0].shape, test_data1.shape) np.testing.assert_allclose(result[0], test_data1) def test_npz2(self): @@ -52,8 +56,8 @@ def test_npz2(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape) + np.testing.assert_allclose(result[0].shape, (2, 3, 4, 4)) np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) def test_npz3(self): @@ -65,8 +69,8 @@ def test_npz3(self): reader = NumpyReader(npz_keys=["test1", "test2"]) result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape) + np.testing.assert_allclose(result[0].shape, (2, 3, 4, 4)) np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) def test_npy_pickle(self): @@ -77,7 +81,7 @@ def test_npy_pickle(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath))[0].item() - self.assertTupleEqual(result["test"].shape, test_data["test"].shape) + np.testing.assert_allclose(result["test"].shape, test_data["test"].shape) np.testing.assert_allclose(result["test"], test_data["test"]) def test_kwargs(self): @@ -88,7 +92,39 @@ def test_kwargs(self): reader = NumpyReader(mmap_mode="r") result = reader.get_data(reader.read(filepath, mmap_mode=None))[0].item() - self.assertTupleEqual(result["test"].shape, test_data["test"].shape) + np.testing.assert_allclose(result["test"].shape, test_data["test"].shape) + + def test_dataloader(self): + test_data = np.random.randint(0, 256, size=[3, 4, 5]) + datalist = [] + with tempfile.TemporaryDirectory() as tempdir: + for i in range(4): + filepath = os.path.join(tempdir, f"test_data{i}.npz") + np.savez(filepath, test_data) + datalist.append({"image": filepath}) + + num_workers = 2 if sys.platform == "linux" else 0 + loader = DataLoader( + Dataset(data=datalist, transform=LoadImaged(keys="image", reader=NumpyReader())), + batch_size=2, + num_workers=num_workers, + ) + for d in loader: + for s in d[PostFix.meta("image")]["spatial_shape"]: + torch.testing.assert_allclose(s, torch.as_tensor([3, 4, 5])) + for c in d["image"]: + torch.testing.assert_allclose(c, test_data) + + def test_channel_dim(self): + test_data = np.random.randint(0, 256, size=[3, 4, 5, 2]) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npy") + np.save(filepath, test_data) + + reader = NumpyReader(channel_dim=-1) + result = reader.get_data(reader.read(filepath)) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data.shape[:-1]) + self.assertEqual(result[1]["original_channel_dim"], -1) if __name__ == "__main__": diff --git a/tests/test_nvtx_decorator.py b/tests/test_nvtx_decorator.py index e2a9ad67b8..e81c72efcf 100644 --- a/tests/test_nvtx_decorator.py +++ b/tests/test_nvtx_decorator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,52 +17,44 @@ from monai.transforms import ( Compose, + CuCIM, Flip, FlipD, RandAdjustContrast, + RandCuCIM, RandFlip, Randomizable, Rotate90, + ToCupy, + TorchVision, ToTensor, ToTensorD, ) from monai.utils import Range, optional_import +from tests.utils import HAS_CUPY _, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?") +_, has_tvt = optional_import("torchvision.transforms") +_, has_cut = optional_import("cucim.core.operations.expose.transform") -TEST_CASE_ARRAY_0 = [ - np.random.randn(3, 3), -] -TEST_CASE_ARRAY_1 = [ - np.random.randn(3, 10, 10), -] +TEST_CASE_ARRAY_0 = [np.random.randn(3, 3)] +TEST_CASE_ARRAY_1 = [np.random.randn(3, 10, 10)] -TEST_CASE_DICT_0 = [ - {"image": np.random.randn(3, 3)}, -] -TEST_CASE_DICT_1 = [ - {"image": np.random.randn(3, 10, 10)}, -] +TEST_CASE_DICT_0 = [{"image": np.random.randn(3, 3)}] +TEST_CASE_DICT_1 = [{"image": np.random.randn(3, 10, 10)}] -TEST_CASE_TORCH_0 = [ - torch.randn(3, 3), -] -TEST_CASE_TORCH_1 = [ - torch.randn(3, 10, 10), -] +TEST_CASE_TORCH_0 = [torch.randn(3, 3)] +TEST_CASE_TORCH_1 = [torch.randn(3, 10, 10)] +TEST_CASE_WRAPPER = [np.random.randn(3, 10, 10)] + +@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") class TestNVTXRangeDecorator(unittest.TestCase): @parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1]) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_tranform_array(self, input): - transforms = Compose( - [ - Range("random flip")(Flip()), - Range()(ToTensor()), - ] - ) + transforms = Compose([Range("random flip")(Flip()), Range()(ToTensor())]) # Apply transforms output = transforms(input) @@ -82,18 +74,12 @@ def test_tranform_array(self, input): self.assertIsInstance(output2, torch.Tensor) self.assertIsInstance(output3, torch.Tensor) np.testing.assert_equal(output.numpy(), output1.numpy()) - np.testing.assert_equal(output.numpy(), output1.numpy()) + np.testing.assert_equal(output.numpy(), output2.numpy()) np.testing.assert_equal(output.numpy(), output3.numpy()) @parameterized.expand([TEST_CASE_DICT_0, TEST_CASE_DICT_1]) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_tranform_dict(self, input): - transforms = Compose( - [ - Range("random flip dict")(FlipD(keys="image")), - Range()(ToTensorD("image")), - ] - ) + transforms = Compose([Range("random flip dict")(FlipD(keys="image")), Range()(ToTensorD("image"))]) # Apply transforms output = transforms(input)["image"] @@ -116,8 +102,32 @@ def test_tranform_dict(self, input): np.testing.assert_equal(output.numpy(), output2.numpy()) np.testing.assert_equal(output.numpy(), output3.numpy()) + @parameterized.expand([TEST_CASE_WRAPPER]) + @unittest.skipUnless(HAS_CUPY, "Requires CuPy.") + @unittest.skipUnless(has_cut, "Requires cuCIM transforms.") + @unittest.skipUnless(has_tvt, "Requires torchvision transforms.") + def test_wrapper_tranforms(self, input): + transform_list = [ + ToTensor(), + TorchVision(name="RandomHorizontalFlip", p=1.0), + ToCupy(), + CuCIM(name="image_flip", spatial_axis=-1), + RandCuCIM(name="rand_image_rotate_90", prob=1.0, max_k=1, spatial_axis=(-2, -1)), + ] + + transforms = Compose(transform_list) + transforms_range = Compose([Range()(t) for t in transform_list]) + + # Apply transforms + output = transforms(input) + + # Apply transforms with Range + output_r = transforms_range(input) + + # Check the outputs + np.testing.assert_equal(output.get(), output_r.get()) + @parameterized.expand([TEST_CASE_ARRAY_1]) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_tranform_randomized(self, input): # Compose deterministic and randomized transforms transforms = Compose( @@ -158,13 +168,9 @@ def test_tranform_randomized(self, input): break @parameterized.expand([TEST_CASE_TORCH_0, TEST_CASE_TORCH_1]) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_network(self, input): # Create a network - model = torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Sigmoid(), - ) + model = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Sigmoid()) # Forward output = model(input) @@ -189,7 +195,6 @@ def test_network(self, input): np.testing.assert_equal(output.numpy(), output3.numpy()) @parameterized.expand([TEST_CASE_TORCH_0, TEST_CASE_TORCH_1]) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_loss(self, input): # Create a network and loss model = torch.nn.Sigmoid() @@ -219,7 +224,6 @@ def test_loss(self, input): np.testing.assert_equal(output.numpy(), output2.numpy()) np.testing.assert_equal(output.numpy(), output3.numpy()) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_context_manager(self): model = torch.nn.Sigmoid() loss = torch.nn.BCELoss() diff --git a/tests/test_nvtx_transform.py b/tests/test_nvtx_transform.py index 6bcfe00078..01a069ed8a 100644 --- a/tests/test_nvtx_transform.py +++ b/tests/test_nvtx_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -35,29 +35,14 @@ _, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?") -TEST_CASE_ARRAY_0 = [ - np.random.randn(3, 3), -] -TEST_CASE_ARRAY_1 = [ - np.random.randn(3, 10, 10), -] -TEST_CASE_DICT_0 = [ - {"image": np.random.randn(3, 3)}, -] -TEST_CASE_DICT_1 = [ - {"image": np.random.randn(3, 10, 10)}, -] +TEST_CASE_ARRAY_0 = [np.random.randn(3, 3)] +TEST_CASE_ARRAY_1 = [np.random.randn(3, 10, 10)] +TEST_CASE_DICT_0 = [{"image": np.random.randn(3, 3)}] +TEST_CASE_DICT_1 = [{"image": np.random.randn(3, 10, 10)}] class TestNVTXTransforms(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_ARRAY_0, - TEST_CASE_ARRAY_1, - TEST_CASE_DICT_0, - TEST_CASE_DICT_1, - ] - ) + @parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1, TEST_CASE_DICT_0, TEST_CASE_DICT_1]) @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!") def test_nvtx_transfroms_alone(self, input): transforms = Compose( diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py index d58359a598..f258dfc557 100644 --- a/tests/test_occlusion_sensitivity.py +++ b/tests/test_occlusion_sensitivity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,60 +21,48 @@ out_channels_2d = 4 out_channels_3d = 3 model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device) +model_2d_2c = DenseNet121(spatial_dims=2, in_channels=2, out_channels=out_channels_2d).to(device) model_3d = DenseNet( spatial_dims=3, in_channels=1, out_channels=out_channels_3d, init_features=2, growth_rate=2, block_config=(6,) ).to(device) model_2d.eval() +model_2d_2c.eval() model_3d.eval() # 2D w/ bounding box TEST_CASE_0 = [ - { - "nn_module": model_2d, - }, - { - "x": torch.rand(1, 1, 48, 64).to(device), - "b_box": [-1, -1, 2, 40, 1, 62], - }, + {"nn_module": model_2d}, + {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [-1, -1, 2, 40, 1, 62]}, (1, 1, 39, 62, out_channels_2d), (1, 1, 39, 62), ] # 3D w/ bounding box and stride TEST_CASE_1 = [ {"nn_module": model_3d, "n_batch": 10, "stride": (2, 1, 2), "mask_size": (16, 15, 14)}, - { - "x": torch.rand(1, 1, 6, 6, 6).to(device), - "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], - }, + {"x": torch.rand(1, 1, 6, 6, 6).to(device), "b_box": [-1, -1, 2, 3, -1, -1, -1, -1]}, (1, 1, 2, 6, 6, out_channels_3d), (1, 1, 2, 6, 6), ] TEST_CASE_FAIL_0 = [ # 2D should fail, since 3 stride values given - { - "nn_module": model_2d, - "n_batch": 10, - "stride": (2, 2, 2), - }, - { - "x": torch.rand(1, 1, 48, 64).to(device), - "b_box": [-1, -1, 2, 3, -1, -1], - }, + {"nn_module": model_2d, "n_batch": 10, "stride": (2, 2, 2)}, + {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [-1, -1, 2, 3, -1, -1]}, ] TEST_CASE_FAIL_1 = [ # 2D should fail, since stride is not a factor of image size - { - "nn_module": model_2d, - "stride": 3, - }, - { - "x": torch.rand(1, 1, 48, 64).to(device), - }, + {"nn_module": model_2d, "stride": 3}, + {"x": torch.rand(1, 1, 48, 64).to(device)}, +] +TEST_MULTI_CHANNEL = [ + {"nn_module": model_2d_2c, "per_channel": False}, + {"x": torch.rand(1, 2, 48, 64).to(device)}, + (1, 1, 48, 64, out_channels_2d), + (1, 1, 48, 64), ] class TestComputeOcclusionSensitivity(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_MULTI_CHANNEL]) def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expected_shape): occ_sens = OcclusionSensitivity(**init_data) m, most_prob = occ_sens(**call_data) diff --git a/tests/test_one_of.py b/tests/test_one_of.py index d45d0f3f61..29d13d7d0c 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,12 +12,21 @@ import unittest from copy import deepcopy +import numpy as np from parameterized import parameterized -from monai.transforms import InvertibleTransform, OneOf, Transform +from monai.transforms import ( + InvertibleTransform, + OneOf, + RandScaleIntensityd, + RandShiftIntensityd, + Resized, + TraceableTransform, + Transform, +) from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform -from monai.utils.enums import InverseKeys +from monai.utils.enums import TraceKeys class X(Transform): @@ -93,9 +102,7 @@ def __init__(self, keys): self.inv_fn = lambda x: x - 100 -TESTS = [ - ((X(), Y(), X()), (1, 2, 1), (0.25, 0.5, 0.25)), -] +TESTS = [((X(), Y(), X()), (1, 2, 1), (0.25, 0.5, 0.25))] KEYS = ["x", "y"] TEST_INVERSES = [ @@ -141,30 +148,52 @@ def _match(a, b): _match(p, f) @parameterized.expand(TEST_INVERSES) - def test_inverse(self, transform, should_be_ok): + def test_inverse(self, transform, invertible): data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)} fwd_data = transform(data) - if not should_be_ok: - with self.assertRaises(RuntimeError): - transform.inverse(fwd_data) - return - - for k in KEYS: - t = fwd_data[k + InverseKeys.KEY_SUFFIX][-1] - # make sure the OneOf index was stored - self.assertEqual(t[InverseKeys.CLASS_NAME], OneOf.__name__) - # make sure index exists and is in bounds - self.assertTrue(0 <= t[InverseKeys.EXTRA_INFO]["index"] < len(transform)) + + if invertible: + for k in KEYS: + t = fwd_data[TraceableTransform.trace_key(k)][-1] + # make sure the OneOf index was stored + self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) + # make sure index exists and is in bounds + self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) # call the inverse fwd_inv_data = transform.inverse(fwd_data) - for k in KEYS: - # check transform was removed - self.assertTrue(len(fwd_inv_data[k + InverseKeys.KEY_SUFFIX]) < len(fwd_data[k + InverseKeys.KEY_SUFFIX])) - # check data is same as original (and different from forward) - self.assertEqual(fwd_inv_data[k], data[k]) - self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) + if invertible: + for k in KEYS: + # check transform was removed + self.assertTrue( + len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)]) + ) + # check data is same as original (and different from forward) + self.assertEqual(fwd_inv_data[k], data[k]) + self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) + else: + # if not invertible, should not change the data + self.assertDictEqual(fwd_data, fwd_inv_data) + + def test_inverse_compose(self): + transform = Compose( + [ + Resized(keys="img", spatial_size=[100, 100, 100]), + OneOf( + [ + RandScaleIntensityd(keys="img", factors=0.5, prob=1.0), + RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0), + ] + ), + ] + ) + transform.set_random_state(seed=0) + result = transform({"img": np.ones((1, 101, 102, 103))}) + + result = transform.inverse(result) + # invert to the original spatial shape + self.assertTupleEqual(result["img"].shape, (1, 101, 102, 103)) def test_one_of(self): p = OneOf((A(), B(), C()), (1, 2, 1)) diff --git a/tests/test_openslide_reader.py b/tests/test_openslide_reader.py deleted file mode 100644 index c0b395fd02..0000000000 --- a/tests/test_openslide_reader.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import unittest -from unittest import skipUnless - -import numpy as np -from numpy.testing import assert_array_equal -from parameterized import parameterized - -from monai.apps.utils import download_url -from monai.data.image_reader import WSIReader -from monai.utils import optional_import - -_, has_osl = optional_import("openslide") - - -FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" -FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + os.path.basename(FILE_URL)) - -HEIGHT = 32914 -WIDTH = 46000 - -TEST_CASE_0 = [FILE_PATH, (3, HEIGHT, WIDTH)] - -TEST_CASE_1 = [ - FILE_PATH, - {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0}, - np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), -] - -TEST_CASE_2 = [ - FILE_PATH, - {"location": (0, 0), "size": (2, 1), "level": 2}, - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), -] - -TEST_CASE_3 = [ - FILE_PATH, - { - "location": (0, 0), - "size": (8, 8), - "level": 2, - "grid_shape": (2, 1), - "patch_size": 2, - }, - np.array( - [ - [[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]], - [[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]], - ] - ), -] - -TEST_CASE_4 = [ - FILE_PATH, - { - "location": (0, 0), - "size": (8, 8), - "level": 2, - "grid_shape": (2, 1), - "patch_size": 1, - }, - np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), -] - - -class TestOpenSlideReader(unittest.TestCase): - @skipUnless(has_osl, "Requires OpenSlide") - def setUp(self): - download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") - - @parameterized.expand([TEST_CASE_0]) - def test_read_whole_image(self, file_path, expected_shape): - reader = WSIReader("OpenSlide") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj)[0] - self.assertTupleEqual(img.shape, expected_shape) - - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_read_region(self, file_path, patch_info, expected_img): - reader = WSIReader("OpenSlide") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) - - @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) - def test_read_patches(self, file_path, patch_info, expected_img): - reader = WSIReader("OpenSlide") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_optim_novograd.py b/tests/test_optim_novograd.py index c76501cd6f..0cf4c35cb6 100644 --- a/tests/test_optim_novograd.py +++ b/tests/test_optim_novograd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,37 +38,19 @@ def build_test_cases(data): return test_cases -TEST_CASES_ALL = build_test_cases( # normal parameters - [ - torch.randn(10, 5), - torch.randn(10), - torch.randn(5), - ] -) +TEST_CASES_ALL = build_test_cases([torch.randn(10, 5), torch.randn(10), torch.randn(5)]) # normal parameters TEST_CASES_ALL += build_test_cases( # non-contiguous parameters - [ - torch.randn(10, 5, 2)[..., 0], - torch.randn(10, 2)[..., 0], - torch.randn(5), - ] + [torch.randn(10, 5, 2)[..., 0], torch.randn(10, 2)[..., 0], torch.randn(5)] ) if torch.cuda.is_available(): TEST_CASES_ALL += build_test_cases( # gpu parameters - [ - torch.randn(10, 5).cuda(), - torch.randn(10).cuda(), - torch.randn(5).cuda(), - ] + [torch.randn(10, 5).cuda(), torch.randn(10).cuda(), torch.randn(5).cuda()] ) if torch.cuda.device_count() > 1: TEST_CASES_ALL += build_test_cases( # multi-gpu parameters - [ - torch.randn(10, 5).cuda(0), - torch.randn(10).cuda(1), - torch.randn(5).cuda(0), - ] + [torch.randn(10, 5).cuda(0), torch.randn(10).cuda(1), torch.randn(5).cuda(0)] ) diff --git a/tests/test_optional_import.py b/tests/test_optional_import.py index 05584f9f9c..b87ebf8909 100644 --- a/tests/test_optional_import.py +++ b/tests/test_optional_import.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_ori_ras_lps.py b/tests/test_ori_ras_lps.py new file mode 100644 index 0000000000..4ed223bf5b --- /dev/null +++ b/tests/test_ori_ras_lps.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data.utils import orientation_ras_lps +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES_AFFINE = [] +for p in TEST_NDARRAYS: + case_1d = p([[1.0, 0.0], [1.0, 1.0]]), p([[-1.0, 0.0], [1.0, 1.0]]) + TEST_CASES_AFFINE.append(case_1d) + case_2d_1 = p([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]), p([[-1.0, 0.0, -1.0], [1.0, 1.0, 1.0]]) + TEST_CASES_AFFINE.append(case_2d_1) + case_2d_2 = p([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), p( + [[-1.0, 0.0, -1.0], [0.0, -1.0, -1.0], [1.0, 1.0, 1.0]] + ) + TEST_CASES_AFFINE.append(case_2d_2) + case_3d = p([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 3.0]]), p( + [[-1.0, 0.0, -1.0, -1.0], [0.0, -1.0, -1.0, -2.0], [1.0, 1.0, 1.0, 3.0]] + ) + TEST_CASES_AFFINE.append(case_3d) + case_4d = p(np.ones((5, 5))), p([[-1] * 5, [-1] * 5, [1] * 5, [1] * 5, [1] * 5]) + TEST_CASES_AFFINE.append(case_4d) + + +class TestITKWriter(unittest.TestCase): + @parameterized.expand(TEST_CASES_AFFINE) + def test_ras_to_lps(self, param, expected): + assert_allclose(orientation_ras_lps(param), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_orientation.py b/tests/test_orientation.py index aa7f33a469..2b749dabad 100644 --- a/tests/test_orientation.py +++ b/tests/test_orientation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,90 +16,149 @@ from parameterized import parameterized from monai.transforms import Orientation, create_rotate, create_translate +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"axcodes": "RAS"}, + np.arange(12).reshape((2, 1, 2, 3)), + {"affine": np.eye(4)}, + np.arange(12).reshape((2, 1, 2, 3)), + "RAS", + ] + ) + TESTS.append( + [ + p, + {"axcodes": "ALS"}, + np.arange(12).reshape((2, 1, 2, 3)), + {"affine": np.diag([-1, -1, 1, 1])}, + np.array([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]), + "ALS", + ] + ) + TESTS.append( + [ + p, + {"axcodes": "RAS"}, + np.arange(12).reshape((2, 1, 2, 3)), + {"affine": np.diag([-1, -1, 1, 1])}, + np.array([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]), + "RAS", + ] + ) + TESTS.append( + [ + p, + {"axcodes": "AL"}, + np.arange(6).reshape((2, 1, 3)), + {"affine": np.eye(3)}, + np.array([[[0], [1], [2]], [[3], [4], [5]]]), + "AL", + ] + ) + TESTS.append( + [ + p, + {"axcodes": "L"}, + np.arange(6).reshape((2, 3)), + {"affine": np.eye(2)}, + np.array([[2, 1, 0], [5, 4, 3]]), + "L", + ] + ) + TESTS.append( + [ + p, + {"axcodes": "L"}, + np.arange(6).reshape((2, 3)), + {"affine": np.eye(2)}, + np.array([[2, 1, 0], [5, 4, 3]]), + "L", + ] + ) + TESTS.append( + [ + p, + {"axcodes": "L"}, + np.arange(6).reshape((2, 3)), + {"affine": np.diag([-1, 1])}, + np.arange(6).reshape((2, 3)), + "L", + ] + ) + TESTS.append( + [ + p, + {"axcodes": "LPS"}, + np.arange(12).reshape((2, 1, 2, 3)), + { + "affine": create_translate(3, (10, 20, 30)) + @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4)) + @ np.diag([-1, 1, 1, 1]) + }, + np.array([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]), + "LPS", + ] + ) + TESTS.append( + [ + p, + {"as_closest_canonical": True}, + np.arange(12).reshape((2, 1, 2, 3)), + { + "affine": create_translate(3, (10, 20, 30)) + @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4)) + @ np.diag([-1, 1, 1, 1]) + }, + np.array([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]), + "RAS", + ] + ) + TESTS.append( + [ + p, + {"as_closest_canonical": True}, + np.arange(6).reshape((1, 2, 3)), + {"affine": create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])}, + np.array([[[3, 0], [4, 1], [5, 2]]]), + "RA", + ] + ) + TESTS.append( + [ + p, + {"axcodes": "LP"}, + np.arange(6).reshape((1, 2, 3)), + {"affine": create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])}, + np.array([[[2, 5], [1, 4], [0, 3]]]), + "LP", + ] + ) + TESTS.append( + [ + p, + {"axcodes": "LPID", "labels": tuple(zip("LPIC", "RASD"))}, + np.zeros((1, 2, 3, 4, 5)), + {"affine": np.diag([-1, -0.2, -1, 1, 1])}, + np.zeros((1, 2, 3, 4, 5)), + "LPID", + ] + ) + TESTS.append( + [ + p, + {"as_closest_canonical": True, "labels": tuple(zip("LPIC", "RASD"))}, + np.zeros((1, 2, 3, 4, 5)), + {"affine": np.diag([-1, -0.2, -1, 1, 1])}, + np.zeros((1, 2, 3, 4, 5)), + "RASD", + ] + ) -TEST_CASES = [ - [ - {"axcodes": "RAS"}, - np.arange(12).reshape((2, 1, 2, 3)), - {"affine": np.eye(4)}, - np.arange(12).reshape((2, 1, 2, 3)), - "RAS", - ], - [ - {"axcodes": "ALS"}, - np.arange(12).reshape((2, 1, 2, 3)), - {"affine": np.diag([-1, -1, 1, 1])}, - np.array([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]), - "ALS", - ], - [ - {"axcodes": "RAS"}, - np.arange(12).reshape((2, 1, 2, 3)), - {"affine": np.diag([-1, -1, 1, 1])}, - np.array([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]), - "RAS", - ], - [ - {"axcodes": "AL"}, - np.arange(6).reshape((2, 1, 3)), - {"affine": np.eye(3)}, - np.array([[[0], [1], [2]], [[3], [4], [5]]]), - "AL", - ], - [{"axcodes": "L"}, np.arange(6).reshape((2, 3)), {"affine": np.eye(2)}, np.array([[2, 1, 0], [5, 4, 3]]), "L"], - [{"axcodes": "L"}, np.arange(6).reshape((2, 3)), {"affine": np.eye(2)}, np.array([[2, 1, 0], [5, 4, 3]]), "L"], - [{"axcodes": "L"}, np.arange(6).reshape((2, 3)), {"affine": np.diag([-1, 1])}, np.arange(6).reshape((2, 3)), "L"], - [ - {"axcodes": "LPS"}, - np.arange(12).reshape((2, 1, 2, 3)), - { - "affine": create_translate(3, (10, 20, 30)) - @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4)) - @ np.diag([-1, 1, 1, 1]) - }, - np.array([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]), - "LPS", - ], - [ - {"as_closest_canonical": True}, - np.arange(12).reshape((2, 1, 2, 3)), - { - "affine": create_translate(3, (10, 20, 30)) - @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4)) - @ np.diag([-1, 1, 1, 1]) - }, - np.array([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]), - "RAS", - ], - [ - {"as_closest_canonical": True}, - np.arange(6).reshape((1, 2, 3)), - {"affine": create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])}, - np.array([[[3, 0], [4, 1], [5, 2]]]), - "RA", - ], - [ - {"axcodes": "LP"}, - np.arange(6).reshape((1, 2, 3)), - {"affine": create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])}, - np.array([[[2, 5], [1, 4], [0, 3]]]), - "LP", - ], - [ - {"axcodes": "LPID", "labels": tuple(zip("LPIC", "RASD"))}, - np.zeros((1, 2, 3, 4, 5)), - {"affine": np.diag([-1, -0.2, -1, 1, 1])}, - np.zeros((1, 2, 3, 4, 5)), - "LPID", - ], - [ - {"as_closest_canonical": True, "labels": tuple(zip("LPIC", "RASD"))}, - np.zeros((1, 2, 3, 4, 5)), - {"affine": np.diag([-1, -0.2, -1, 1, 1])}, - np.zeros((1, 2, 3, 4, 5)), - "RASD", - ], -] ILL_CASES = [ # no axcodes or as_cloest_canonical @@ -110,14 +169,15 @@ class TestOrientationCase(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_ornt(self, init_param, img, data_param, expected_data, expected_code): + @parameterized.expand(TESTS) + def test_ornt(self, in_type, init_param, img, data_param, expected_data, expected_code): + img = in_type(img) ornt = Orientation(**init_param) res = ornt(img, **data_param) if not isinstance(res, tuple): - np.testing.assert_allclose(res, expected_data) + assert_allclose(res, in_type(expected_data)) return - np.testing.assert_allclose(res[0], expected_data) + assert_allclose(res[0], in_type(expected_data)) original_affine = data_param["affine"] np.testing.assert_allclose(original_affine, res[1]) new_code = nib.orientations.aff2axcodes(res[2], labels=ornt.labels) diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py index 452172ce9b..a4b953c8b5 100644 --- a/tests/test_orientationd.py +++ b/tests/test_orientationd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,77 +15,80 @@ import numpy as np from monai.transforms import Orientationd +from monai.utils.enums import PostFix +from tests.utils import TEST_NDARRAYS class TestOrientationdCase(unittest.TestCase): def test_orntd(self): - data = {"seg": np.ones((2, 1, 2, 3)), "seg_meta_dict": {"affine": np.eye(4)}} + data = {"seg": np.ones((2, 1, 2, 3)), PostFix.meta("seg"): {"affine": np.eye(4)}} ornt = Orientationd(keys="seg", axcodes="RAS") res = ornt(data) np.testing.assert_allclose(res["seg"].shape, (2, 1, 2, 3)) - code = nib.aff2axcodes(res["seg_meta_dict"]["affine"], ornt.ornt_transform.labels) + code = nib.aff2axcodes(res[PostFix.meta("seg")]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("R", "A", "S")) def test_orntd_3d(self): - data = { - "seg": np.ones((2, 1, 2, 3)), - "img": np.ones((2, 1, 2, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, - } - ornt = Orientationd(keys=("img", "seg"), axcodes="PLI") - res = ornt(data) - np.testing.assert_allclose(res["img"].shape, (2, 2, 1, 3)) - np.testing.assert_allclose(res["seg"].shape, (2, 2, 1, 3)) - code = nib.aff2axcodes(res["seg_meta_dict"]["affine"], ornt.ornt_transform.labels) - self.assertEqual(code, ("P", "L", "I")) - code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) - self.assertEqual(code, ("P", "L", "I")) + for p in TEST_NDARRAYS: + data = { + "seg": p(np.ones((2, 1, 2, 3))), + "img": p(np.ones((2, 1, 2, 3))), + PostFix.meta("seg"): {"affine": np.eye(4)}, + PostFix.meta("img"): {"affine": np.eye(4)}, + } + ornt = Orientationd(keys=("img", "seg"), axcodes="PLI") + res = ornt(data) + np.testing.assert_allclose(res["img"].shape, (2, 2, 1, 3)) + np.testing.assert_allclose(res["seg"].shape, (2, 2, 1, 3)) + code = nib.aff2axcodes(res[PostFix.meta("seg")]["affine"], ornt.ornt_transform.labels) + self.assertEqual(code, ("P", "L", "I")) + code = nib.aff2axcodes(res[PostFix.meta("img")]["affine"], ornt.ornt_transform.labels) + self.assertEqual(code, ("P", "L", "I")) def test_orntd_2d(self): data = { "seg": np.ones((2, 1, 3)), "img": np.ones((2, 1, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + PostFix.meta("seg"): {"affine": np.eye(4)}, + PostFix.meta("img"): {"affine": np.eye(4)}, } ornt = Orientationd(keys=("img", "seg"), axcodes="PLI") res = ornt(data) np.testing.assert_allclose(res["img"].shape, (2, 3, 1)) - code = nib.aff2axcodes(res["seg_meta_dict"]["affine"], ornt.ornt_transform.labels) + code = nib.aff2axcodes(res[PostFix.meta("seg")]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("P", "L", "S")) - code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) + code = nib.aff2axcodes(res[PostFix.meta("img")]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("P", "L", "S")) def test_orntd_1d(self): data = { "seg": np.ones((2, 3)), "img": np.ones((2, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + PostFix.meta("seg"): {"affine": np.eye(4)}, + PostFix.meta("img"): {"affine": np.eye(4)}, } ornt = Orientationd(keys=("img", "seg"), axcodes="L") res = ornt(data) np.testing.assert_allclose(res["img"].shape, (2, 3)) - code = nib.aff2axcodes(res["seg_meta_dict"]["affine"], ornt.ornt_transform.labels) + code = nib.aff2axcodes(res[PostFix.meta("seg")]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("L", "A", "S")) - code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) + code = nib.aff2axcodes(res[PostFix.meta("img")]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("L", "A", "S")) def test_orntd_canonical(self): data = { "seg": np.ones((2, 1, 2, 3)), "img": np.ones((2, 1, 2, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + PostFix.meta("seg"): {"affine": np.eye(4)}, + PostFix.meta("img"): {"affine": np.eye(4)}, } ornt = Orientationd(keys=("img", "seg"), as_closest_canonical=True) res = ornt(data) np.testing.assert_allclose(res["img"].shape, (2, 1, 2, 3)) np.testing.assert_allclose(res["seg"].shape, (2, 1, 2, 3)) - code = nib.aff2axcodes(res["seg_meta_dict"]["affine"], ornt.ornt_transform.labels) + code = nib.aff2axcodes(res[PostFix.meta("seg")]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("R", "A", "S")) - code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) + code = nib.aff2axcodes(res[PostFix.meta("img")]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("R", "A", "S")) def test_orntd_no_metadata(self): @@ -93,7 +96,7 @@ def test_orntd_no_metadata(self): ornt = Orientationd(keys="seg", axcodes="RAS") res = ornt(data) np.testing.assert_allclose(res["seg"].shape, (2, 1, 2, 3)) - code = nib.aff2axcodes(res["seg_meta_dict"]["affine"], ornt.ornt_transform.labels) + code = nib.aff2axcodes(res[PostFix.meta("seg")]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("R", "A", "S")) diff --git a/tests/test_p3d_block.py b/tests/test_p3d_block.py new file mode 100644 index 0000000000..62b7098dcd --- /dev/null +++ b/tests/test_p3d_block.py @@ -0,0 +1,75 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import P3DActiConvNormBlock + +TEST_CASES_3D = [ + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 0, "mode": 0}, + (7, 32, 16, 32, 8), + (7, 16, 14, 30, 6), + ], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1, "mode": 0}, # check padding + (7, 32, 16, 32, 8), + (7, 16, 16, 32, 8), + ], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 0, "mode": 1}, + (7, 32, 16, 32, 8), + (7, 16, 14, 30, 6), + ], + [ + { + "in_channel": 32, + "out_channel": 16, + "kernel_size": 3, + "padding": 0, + "mode": 2, + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.2}), + }, + (7, 32, 16, 32, 8), + (7, 16, 14, 30, 6), + ], + [ + { + "in_channel": 32, + "out_channel": 16, + "kernel_size": 4, + "padding": 0, + "mode": 0, + "norm_name": ("INSTANCE", {"affine": True}), + }, + (7, 32, 16, 32, 8), + (7, 16, 13, 29, 5), + ], +] + + +class TestP3D(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_3d(self, input_param, input_shape, expected_shape): + net = P3DActiConvNormBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill(self): + with self.assertRaises(ValueError): + P3DActiConvNormBlock(in_channel=32, out_channel=16, kernel_size=3, padding=0, mode=3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index a8c544558f..530e5f86a3 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,6 +20,7 @@ from monai.data import CacheDataset, DataLoader from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms import ( + Compose, PadListDataCollate, RandRotate, RandRotate90, @@ -29,24 +30,26 @@ RandSpatialCropd, RandZoom, RandZoomd, + ToTensor, + ToTensord, ) from monai.utils import set_determinism TESTS: List[Tuple] = [] for pad_collate in [ - lambda x: pad_list_data_collate(batch=x, method="end", mode="constant", constant_values=1), - PadListDataCollate(method="end", mode="constant", constant_values=1), + lambda x: pad_list_data_collate(batch=x, method="end", mode="constant"), + PadListDataCollate(method="end", mode="constant"), ]: TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) - TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) + TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False, dtype=np.float64))) TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) - TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2))) + TESTS.append((dict, pad_collate, Compose([RandRotate90d("image", prob=1, max_k=2), ToTensord("image")]))) TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True))) - TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False))) + TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False, dtype=np.float64))) TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) - TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2))) + TESTS.append((list, pad_collate, Compose([RandRotate90(prob=1, max_k=2), ToTensor()]))) class _Dataset(torch.utils.data.Dataset): diff --git a/tests/test_parallel_execution.py b/tests/test_parallel_execution.py index c4115d21ef..6186e73f68 100644 --- a/tests/test_parallel_execution.py +++ b/tests/test_parallel_execution.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_parallel_execution_dist.py b/tests/test_parallel_execution_dist.py new file mode 100644 index 0000000000..f067b71d14 --- /dev/null +++ b/tests/test_parallel_execution_dist.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +import torch.distributed as dist + +from monai.engines import create_multigpu_supervised_trainer +from tests.utils import DistCall, DistTestCase, skip_if_no_cuda + + +def fake_loss(y_pred, y): + return (y_pred[0] + y).sum() + + +def fake_data_stream(): + while True: + yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64)) + + +class DistributedTestParallelExecution(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + @skip_if_no_cuda + def test_distributed(self): + device = torch.device(f"cuda:{dist.get_rank()}") + net = torch.nn.Conv2d(1, 1, 3, padding=1).to(device) + opt = torch.optim.Adam(net.parameters(), 1e-3) + + trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [device], distributed=True) + trainer.run(fake_data_stream(), 2, 2) + # assert the trainer output is loss value + self.assertTrue(isinstance(trainer.state.output, float)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_partition_dataset.py b/tests/test_partition_dataset.py index a954bfae91..687cf8df34 100644 --- a/tests/test_partition_dataset.py +++ b/tests/test_partition_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -117,16 +117,7 @@ class TestPartitionDataset(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - ] + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] ) def test_value(self, input_param, result): self.assertListEqual(partition_dataset(**input_param), result) diff --git a/tests/test_partition_dataset_classes.py b/tests/test_partition_dataset_classes.py index 3aef47107a..4ed283bdd7 100644 --- a/tests/test_partition_dataset_classes.py +++ b/tests/test_partition_dataset_classes.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 4f6e9a25fd..a46c117b75 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,7 +29,7 @@ def test_shape(self): test_dataset = ["vwxyz", "hello", "world"] n_per_image = len(test_dataset[0]) - result = PatchDataset(dataset=test_dataset, patch_func=identity, samples_per_image=n_per_image) + result = PatchDataset(data=test_dataset, patch_func=identity, samples_per_image=n_per_image) output = [] n_workers = 0 if sys.platform == "win32" else 2 @@ -50,7 +50,7 @@ def test_loading_array(self): patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) image_ds = Dataset(images, transform=patch_intensity) # patch level - ds = PatchDataset(dataset=image_ds, patch_func=sampler, samples_per_image=n_samples, transform=patch_intensity) + ds = PatchDataset(data=image_ds, patch_func=sampler, samples_per_image=n_samples, transform=patch_intensity) np.testing.assert_equal(len(ds), n_samples * len(images)) # use the patch dataset, length: len(images) x samplers_per_image @@ -59,7 +59,7 @@ def test_loading_array(self): np.testing.assert_allclose( item[0], np.array( - [[[1.779992, 2.779992, 3.779992], [5.779992, 6.779992, 7.779992], [9.779992, 10.779992, 11.779992]]] + [[[1.338681, 2.338681, 3.338681], [5.338681, 6.338681, 7.338681], [9.338681, 10.338681, 11.338681]]] ), rtol=1e-5, ) @@ -71,9 +71,9 @@ def test_loading_array(self): np.array( [ [ - [5.025618, 6.025618, 7.025618], - [9.025618, 10.025618, 11.025618], - [13.025618, 14.025618, 15.025618], + [4.957847, 5.957847, 6.957847], + [8.957847, 9.957847, 10.957847], + [12.957847, 13.957847, 14.957847], ] ] ), diff --git a/tests/test_patch_wsi_dataset.py b/tests/test_patch_wsi_dataset.py index f775f28376..c351ce5f79 100644 --- a/tests/test_patch_wsi_dataset.py +++ b/tests/test_patch_wsi_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,28 +18,27 @@ from parameterized import parameterized from monai.apps.pathology.data import PatchWSIDataset -from monai.apps.utils import download_url from monai.utils import optional_import +from tests.utils import download_url_or_skip_test, testing_data_config -_, has_cim = optional_import("cucim") +_cucim, has_cim = optional_import("cucim") +has_cim = has_cim and hasattr(_cucim, "CuImage") _, has_osl = optional_import("openslide") -FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" -FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + os.path.basename(FILE_URL)) +FILE_KEY = "wsi_img" +FILE_URL = testing_data_config("images", FILE_KEY, "url") +base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension) TEST_CASE_0 = [ { - "data": [ - {"image": FILE_PATH, "location": [0, 0], "label": [1]}, - ], + "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}], "region_size": (1, 1), "grid_shape": (1, 1), "patch_size": 1, "image_reader_name": "cuCIM", }, - [ - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}, - ], + [{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}], ] TEST_CASE_1 = [ @@ -60,47 +59,35 @@ TEST_CASE_2 = [ { - "data": [ - {"image": FILE_PATH, "location": [0, 0], "label": [1]}, - ], + "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}], "region_size": 1, "grid_shape": 1, "patch_size": 1, "image_reader_name": "cuCIM", }, - [ - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}, - ], + [{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}], ] TEST_CASE_3 = [ { - "data": [ - {"image": FILE_PATH, "location": [0, 0], "label": [[[0, 1], [1, 0]]]}, - ], + "data": [{"image": FILE_PATH, "location": [0, 0], "label": [[[0, 1], [1, 0]]]}], "region_size": 1, "grid_shape": 1, "patch_size": 1, "image_reader_name": "cuCIM", }, - [ - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])}, - ], + [{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])}], ] TEST_CASE_OPENSLIDE_0 = [ { - "data": [ - {"image": FILE_PATH, "location": [0, 0], "label": [1]}, - ], + "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}], "region_size": (1, 1), "grid_shape": (1, 1), "patch_size": 1, "image_reader_name": "OpenSlide", }, - [ - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}, - ], + [{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}], ] TEST_CASE_OPENSLIDE_1 = [ @@ -122,16 +109,11 @@ class TestPatchWSIDataset(unittest.TestCase): def setUp(self): - download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") - - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - ] - ) + hash_type = testing_data_config("images", FILE_KEY, "hash_type") + hash_val = testing_data_config("images", FILE_KEY, "hash_val") + download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val) + + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) @skipUnless(has_cim, "Requires CuCIM") def test_read_patches_cucim(self, input_parameters, expected): dataset = PatchWSIDataset(**input_parameters) @@ -142,12 +124,7 @@ def test_read_patches_cucim(self, input_parameters, expected): self.assertIsNone(assert_array_equal(samples[i]["label"], expected[i]["label"])) self.assertIsNone(assert_array_equal(samples[i]["image"], expected[i]["image"])) - @parameterized.expand( - [ - TEST_CASE_OPENSLIDE_0, - TEST_CASE_OPENSLIDE_1, - ] - ) + @parameterized.expand([TEST_CASE_OPENSLIDE_0, TEST_CASE_OPENSLIDE_1]) @skipUnless(has_osl, "Requires OpenSlide") def test_read_patches_openslide(self, input_parameters, expected): dataset = PatchWSIDataset(**input_parameters) diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 6c9ac78a99..4af2b47ba5 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_pathology_he_stain.py b/tests/test_pathology_he_stain.py index 1d74f485e9..7b884315fc 100644 --- a/tests/test_pathology_he_stain.py +++ b/tests/test_pathology_he_stain.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -73,12 +73,7 @@ class TestExtractHEStains(unittest.TestCase): @parameterized.expand( - [ - NEGATIVE_VALUE_TEST_CASE, - INVALID_VALUE_TEST_CASE, - EXTRACT_STAINS_TEST_CASE_0, - EXTRACT_STAINS_TEST_CASE_1, - ] + [NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1] ) def test_transparent_image(self, image): """ @@ -112,13 +107,7 @@ def test_identical_result_vectors(self, image): result = ExtractHEStains()(image) np.testing.assert_array_equal(result[:, 0], result[:, 1]) - @parameterized.expand( - [ - EXTRACT_STAINS_TEST_CASE_00, - EXTRACT_STAINS_TEST_CASE_4, - EXTRACT_STAINS_TEST_CASE_5, - ] - ) + @parameterized.expand([EXTRACT_STAINS_TEST_CASE_00, EXTRACT_STAINS_TEST_CASE_4, EXTRACT_STAINS_TEST_CASE_5]) def test_result_value(self, image, expected_data): """ Test that an input image returns an expected stain matrix. @@ -156,12 +145,7 @@ def test_result_value(self, image, expected_data): class TestNormalizeHEStains(unittest.TestCase): @parameterized.expand( - [ - NEGATIVE_VALUE_TEST_CASE, - INVALID_VALUE_TEST_CASE, - NORMALIZE_STAINS_TEST_CASE_0, - NORMALIZE_STAINS_TEST_CASE_1, - ] + [NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1] ) def test_transparent_image(self, image): """ diff --git a/tests/test_pathology_he_stain_dict.py b/tests/test_pathology_he_stain_dict.py index 8d51579cb2..7d6c3ffb75 100644 --- a/tests/test_pathology_he_stain_dict.py +++ b/tests/test_pathology_he_stain_dict.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -100,13 +100,7 @@ def test_identical_result_vectors(self, image): result = ExtractHEStainsD([key])({key: image}) np.testing.assert_array_equal(result[key][:, 0], result[key][:, 1]) - @parameterized.expand( - [ - EXTRACT_STAINS_TEST_CASE_00, - EXTRACT_STAINS_TEST_CASE_4, - EXTRACT_STAINS_TEST_CASE_5, - ] - ) + @parameterized.expand([EXTRACT_STAINS_TEST_CASE_00, EXTRACT_STAINS_TEST_CASE_4, EXTRACT_STAINS_TEST_CASE_5]) def test_result_value(self, image, expected_data): """ Test that an input image returns an expected stain matrix. diff --git a/tests/test_pathology_prob_nms.py b/tests/test_pathology_prob_nms.py index 223b136ea7..3399e33afa 100644 --- a/tests/test_pathology_prob_nms.py +++ b/tests/test_pathology_prob_nms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -41,12 +41,7 @@ class TestPathologyProbNMS(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASES_2D, - TEST_CASES_3D, - ] - ) + @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D]) def test_output(self, class_args, call_args, probs_map, expected): nms = PathologyProbNMS(**class_args) output = nms(probs_map, **call_args) diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index 8446f566ef..17575c79f7 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. import os +import pickle import tempfile import unittest @@ -56,7 +57,13 @@ def test_cache(self): items = [[list(range(i))] for i in range(5)] with tempfile.TemporaryDirectory() as tempdir: - ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir) + ds = PersistentDataset( + data=items, + transform=_InplaceXform(), + cache_dir=tempdir, + pickle_module="pickle", + pickle_protocol=pickle.HIGHEST_PROTOCOL, + ) self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir) self.assertEqual(list(ds1), list(ds)) diff --git a/tests/test_persistentdataset_dist.py b/tests/test_persistentdataset_dist.py index d45bba03e5..20dcb2c264 100644 --- a/tests/test_persistentdataset_dist.py +++ b/tests/test_persistentdataset_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_phl_cpu.py b/tests/test_phl_cpu.py index 31e28bd39d..d479f554b4 100644 --- a/tests/test_phl_cpu.py +++ b/tests/test_phl_cpu.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,7 +42,7 @@ # Batch 0 [ # Channel 0 - [1, 0.2, 0.5, 0, 1], + [1, 0.2, 0.5, 0, 1] ], # Batch 1 [ @@ -79,15 +79,15 @@ [0, 0, 0, 0, 1], # Channel 2 [0, 0, 1, 0, 0], - ], + ] ], # Features [ # Batch 0 [ # Channel 0 - [1, 0.2, 0.5, 0.2, 1], - ], + [1, 0.2, 0.5, 0.2, 1] + ] ], # Expected [ @@ -99,7 +99,7 @@ [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], # Channel 2 [0.201235, 0.208194, 0.205409, 0.208194, 0.201235], - ], + ] ], ], [ @@ -113,7 +113,7 @@ [ # Channel 0 [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]] - ], + ] ], # Features [ @@ -125,7 +125,7 @@ [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], # Channel 2 [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]], - ], + ] ], # Expected [ @@ -139,7 +139,7 @@ [7.613517, 7.359183, 5.846500, 5.638952, 5.350098], [7.598255, 7.458446, 5.912375, 5.583625, 5.233126], ] - ], + ] ], ], [ @@ -164,7 +164,7 @@ # Frame 4 [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], ] - ], + ] ], # Features [ @@ -183,7 +183,7 @@ # Frame 4 [[0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], ] - ], + ] ], # Expected [ @@ -232,7 +232,7 @@ [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], ], ] - ], + ] ], ], ] diff --git a/tests/test_phl_cuda.py b/tests/test_phl_cuda.py index 8f7fc6fc3d..d49a60ecd9 100644 --- a/tests/test_phl_cuda.py +++ b/tests/test_phl_cuda.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,7 +42,7 @@ # Batch 0 [ # Channel 0 - [1, 0.2, 0.5, 0, 1], + [1, 0.2, 0.5, 0, 1] ], # Batch 1 [ @@ -79,15 +79,15 @@ [0, 0, 0, 0, 1], # Channel 2 [0, 0, 1, 0, 0], - ], + ] ], # Features [ # Batch 0 [ # Channel 0 - [1, 0.2, 0.5, 0.2, 1], - ], + [1, 0.2, 0.5, 0.2, 1] + ] ], # Expected [ @@ -99,7 +99,7 @@ [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], # Channel 2 [0.201235, 0.208194, 0.205409, 0.208194, 0.201235], - ], + ] ], ], [ @@ -113,7 +113,7 @@ [ # Channel 0 [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]] - ], + ] ], # Features [ @@ -125,7 +125,7 @@ [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], # Channel 2 [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]], - ], + ] ], # Expected [ @@ -139,7 +139,7 @@ [7.712976, 7.429060, 5.789552, 5.594258, 5.371737], [7.701185, 7.492719, 5.860026, 5.538241, 5.281656], ] - ], + ] ], ], ] diff --git a/tests/test_pil_reader.py b/tests/test_pil_reader.py index 0a076b581e..0f7792a56c 100644 --- a/tests/test_pil_reader.py +++ b/tests/test_pil_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_plot_2d_or_3d_image.py b/tests/test_plot_2d_or_3d_image.py index 645658e311..2e4adb93e3 100644 --- a/tests/test_plot_2d_or_3d_image.py +++ b/tests/test_plot_2d_or_3d_image.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,11 @@ from parameterized import parameterized from torch.utils.tensorboard import SummaryWriter +from monai.utils import optional_import from monai.visualize import plot_2d_or_3d_image +from tests.utils import SkipIfNoModule + +SummaryWriterX, _ = optional_import("tensorboardX", name="SummaryWriter") TEST_CASE_1 = [(1, 1, 10, 10)] @@ -32,10 +36,30 @@ class TestPlot2dOr3dImage(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_tb_image_shape(self, shape): + def test_tb_image(self, shape): with tempfile.TemporaryDirectory() as tempdir: writer = SummaryWriter(log_dir=tempdir) - plot_2d_or_3d_image(torch.zeros(shape), 0, writer) + plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=3, frame_dim=-1) + writer.flush() + writer.close() + self.assertTrue(len(glob.glob(tempdir)) > 0) + + @SkipIfNoModule("tensorboardX") + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_tbx_image(self, shape): + with tempfile.TemporaryDirectory() as tempdir: + writer = SummaryWriterX(log_dir=tempdir) + plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=2) + writer.flush() + writer.close() + self.assertTrue(len(glob.glob(tempdir)) > 0) + + @SkipIfNoModule("tensorboardX") + @parameterized.expand([TEST_CASE_5]) + def test_tbx_video(self, shape): + with tempfile.TemporaryDirectory() as tempdir: + writer = SummaryWriterX(log_dir=tempdir) + plot_2d_or_3d_image(torch.rand(shape), 0, writer, max_channels=3) writer.flush() writer.close() self.assertTrue(len(glob.glob(tempdir)) > 0) diff --git a/tests/test_png_rw.py b/tests/test_png_rw.py index 265b31b83b..47b5571ac0 100644 --- a/tests/test_png_rw.py +++ b/tests/test_png_rw.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ import numpy as np from PIL import Image -from monai.data import write_png +from monai.data.image_writer import PILWriter class TestPngWrite(unittest.TestCase): @@ -25,7 +25,9 @@ def test_write_gray(self): image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3) img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) + writer_obj = PILWriter(output_dtype=np.uint8) + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.write(image_name, format="PNG") out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) @@ -35,7 +37,9 @@ def test_write_gray_1height(self): image_name = os.path.join(out_dir, "test.png") img = np.random.rand(1, 3) img_save_val = (65535 * img).astype(np.uint16) - write_png(img, image_name, scale=65535) + writer_obj = PILWriter(output_dtype=np.uint16, scale=65535) + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.write(image_name, format="PNG") out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) @@ -45,17 +49,22 @@ def test_write_gray_1channel(self): image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3, 1) img_save_val = (255 * img).astype(np.uint8).squeeze(2) - write_png(img, image_name, scale=255) + writer_obj = PILWriter(output_dtype=np.uint8, scale=255) + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.write(image_name, format="PNG") out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) def test_write_rgb(self): + """testing default kwargs and obj_kwargs""" with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3, 3) img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) + writer_obj = PILWriter(output_dtype=np.uint8) + writer_obj.set_data_array(img, channel_dim=-1) + writer_obj.write(image_name) out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) @@ -65,7 +74,9 @@ def test_write_2channels(self): image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3, 2) img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) + writer_obj = PILWriter(output_dtype=np.uint8) + writer_obj.set_data_array(img, channel_dim=-1) + writer_obj.write(image_name, format="PNG") out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) @@ -74,7 +85,10 @@ def test_write_output_shape(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 2, 3) - write_png(img, image_name, (4, 4), scale=255) + writer_obj = PILWriter(output_dtype=np.uint8) + writer_obj.set_data_array(img, channel_dim=-1) + writer_obj.set_metadata({"spatial_shape": (4, 4)}, scale=255) + writer_obj.write(image_name, format="PNG") out = np.asarray(Image.open(image_name)) np.testing.assert_allclose(out.shape, (4, 4, 3)) diff --git a/tests/test_png_saver.py b/tests/test_png_saver.py index f8ea1df54b..d832718643 100644 --- a/tests/test_png_saver.py +++ b/tests/test_png_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -60,11 +60,7 @@ def test_saved_specified_root(self): with tempfile.TemporaryDirectory() as tempdir: saver = PNGSaver( - output_dir=tempdir, - output_postfix="seg", - output_ext=".png", - scale=255, - data_root_dir="test", + output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255, data_root_dir="test" ) meta_data = { diff --git a/tests/test_polyval.py b/tests/test_polyval.py index 4ff05bc817..db3bcaca53 100644 --- a/tests/test_polyval.py +++ b/tests/test_polyval.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_prepare_batch_default.py b/tests/test_prepare_batch_default.py new file mode 100644 index 0000000000..96051b5e82 --- /dev/null +++ b/tests/test_prepare_batch_default.py @@ -0,0 +1,68 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.engines import PrepareBatchDefault, SupervisedEvaluator +from tests.utils import assert_allclose + + +class TestNet(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x + + +class TestPrepareBatchDefault(unittest.TestCase): + def test_content(self): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [ + { + "image": torch.tensor([1, 2]), + "label": torch.tensor([3, 4]), + "extra1": torch.tensor([5, 6]), + "extra2": 16, + "extra3": "test", + } + ] + # set up engine + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=TestNet(), + non_blocking=False, + prepare_batch=PrepareBatchDefault(), + decollate=False, + mode="eval", + ) + evaluator.run() + output = evaluator.state.output + assert_allclose(output["image"], torch.tensor([1, 2], device=device)) + assert_allclose(output["label"], torch.tensor([3, 4], device=device)) + + def test_empty_data(self): + dataloader = [] + evaluator = SupervisedEvaluator( + val_data_loader=dataloader, + device=torch.device("cpu"), + epoch_length=0, + network=TestNet(), + non_blocking=False, + prepare_batch=PrepareBatchDefault(), + decollate=False, + ) + evaluator.run() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prepare_batch_default_dist.py b/tests/test_prepare_batch_default_dist.py new file mode 100644 index 0000000000..95d01d2a16 --- /dev/null +++ b/tests/test_prepare_batch_default_dist.py @@ -0,0 +1,73 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import torch +import torch.distributed as dist +from parameterized import parameterized + +from monai.engines import PrepareBatchDefault, SupervisedEvaluator +from tests.utils import DistCall, DistTestCase, assert_allclose + +TEST_CASE_1 = [ + [ + # data for rank 0, has 1 iteration + [{"image": torch.tensor([1, 1]), "label": torch.tensor([1, 0])}], + # data for rank 1, has 2 iterations + [ + {"image": torch.tensor([1, 0]), "label": torch.tensor([1, 0])}, + {"image": torch.tensor([1]), "label": torch.tensor([0])}, + ], + ] +] + +TEST_CASE_2 = [ + [ + # data for rank 0 + [{"image": torch.tensor([0, 1, 1, 0, 1]), "label": torch.tensor([1, 1, 0, 0, 1])}], + # data for rank 1 + [], + ] +] + + +class TestNet(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x + + +class DistributedPrepareBatchDefault(DistTestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @DistCall(nnodes=1, nproc_per_node=2, node_rank=0) + def test_compute(self, dataloaders): + device = torch.device(f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu") + dataloader = dataloaders[dist.get_rank()] + # set up engine + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=len(dataloader), + network=TestNet(), + non_blocking=False, + prepare_batch=PrepareBatchDefault(), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + if len(dataloader) > 0: + assert_allclose(output["image"], dataloader[-1]["image"].to(device=device)) + assert_allclose(output["label"], dataloader[-1]["label"].to(device=device)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prepare_batch_extra_input.py b/tests/test_prepare_batch_extra_input.py new file mode 100644 index 0000000000..79c9a13679 --- /dev/null +++ b/tests/test_prepare_batch_extra_input.py @@ -0,0 +1,76 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.engines import PrepareBatchExtraInput, SupervisedEvaluator +from tests.utils import assert_allclose + +TEST_CASE_0 = [ + {"extra_keys": "extra1"}, + {"x": torch.tensor([1, 2]), "t1": torch.tensor([5, 6]), "t2": None, "t3": None}, +] + +TEST_CASE_1 = [ + {"extra_keys": ["extra1", "extra3"]}, + {"x": torch.tensor([1, 2]), "t1": torch.tensor([5, 6]), "t2": "test", "t3": None}, +] + +TEST_CASE_2 = [ + {"extra_keys": {"t1": "extra2", "t2": "extra3", "t3": "extra1"}}, + {"x": torch.tensor([1, 2]), "t1": 16, "t2": "test", "t3": torch.tensor([5, 6])}, +] + + +class TestNet(torch.nn.Module): + def forward(self, x: torch.Tensor, t1=None, t2=None, t3=None): + return {"x": x, "t1": t1, "t2": t2, "t3": t3} + + +class TestPrepareBatchExtraInput(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + def test_content(self, input_args, expected_value): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [ + { + "image": torch.tensor([1, 2]), + "label": torch.tensor([3, 4]), + "extra1": torch.tensor([5, 6]), + "extra2": 16, + "extra3": "test", + } + ] + # set up engine + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=TestNet(), + non_blocking=True, + prepare_batch=PrepareBatchExtraInput(**input_args), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + assert_allclose(output["image"], torch.tensor([1, 2], device=device)) + assert_allclose(output["label"], torch.tensor([3, 4], device=device)) + for k, v in output["pred"].items(): + if isinstance(v, torch.Tensor): + assert_allclose(v, expected_value[k].to(device)) + else: + self.assertEqual(v, expected_value[k]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_print_info.py b/tests/test_print_info.py index 64f0b66949..591316884c 100644 --- a/tests/test_print_info.py +++ b/tests/test_print_info.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_print_transform_backends.py b/tests/test_print_transform_backends.py index 4164687f01..2db00fea39 100644 --- a/tests/test_print_transform_backends.py +++ b/tests/test_print_transform_backends.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_probnms.py b/tests/test_probnms.py index e51d1017d8..aab312c1db 100644 --- a/tests/test_probnms.py +++ b/tests/test_probnms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,87 +16,54 @@ from parameterized import parameterized from monai.transforms.post.array import ProbNMS +from tests.utils import TEST_NDARRAYS, assert_allclose -probs_map_1 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, probs_map_1, []] +TESTS = [] +for p in TEST_NDARRAYS: + probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, probs_map_1, []]) -probs_map_2 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_2[33, 33] = 0.7 -probs_map_2[66, 66] = 0.9 -expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_2 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, - probs_map_2, - expected_2, -] + probs_map_2 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_2[33, 33] = 0.7 + probs_map_2[66, 66] = 0.9 + expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, probs_map_2, expected_2]) -probs_map_3 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_3[56, 58] = 0.7 -probs_map_3[60, 66] = 0.8 -probs_map_3[66, 66] = 0.9 -expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] -TEST_CASES_2D_3 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, - probs_map_3, - expected_3, -] + probs_map_3 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_3[56, 58] = 0.7 + probs_map_3[60, 66] = 0.8 + probs_map_3[66, 66] = 0.9 + expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, probs_map_3, expected_3]) -probs_map_4 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_4[33, 33] = 0.7 -probs_map_4[66, 66] = 0.9 -expected_4 = [[0.9, 66, 66]] -TEST_CASES_2D_4 = [ - {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, - probs_map_4, - expected_4, -] + probs_map_4 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_4[33, 33] = 0.7 + probs_map_4[66, 66] = 0.9 + expected_4 = [[0.9, 66, 66]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, probs_map_4, expected_4]) -probs_map_5 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_5 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_5, []] + probs_map_5 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_5, []]) -probs_map_6 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -TEST_CASES_2D_6 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_6, []] + probs_map_6 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_6[33, 33] = 0.7 + probs_map_6[66, 66] = 0.9 + expected_6 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_6, expected_6]) -probs_map_7 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -probs_map_7[33, 33] = 0.7 -probs_map_7[66, 66] = 0.9 -if torch.cuda.is_available(): - probs_map_7 = probs_map_7.cuda() -expected_7 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_7 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, - probs_map_7, - expected_7, -] - -probs_map_3d = torch.rand([50, 50, 50]).uniform_(0, 0.5) -probs_map_3d[25, 25, 25] = 0.7 -probs_map_3d[45, 45, 45] = 0.9 -expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] -TEST_CASES_3D = [ - {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, - probs_map_3d, - expected_3d, -] + probs_map_3d = p(torch.rand([50, 50, 50]).uniform_(0, 0.5)) + probs_map_3d[25, 25, 25] = 0.7 + probs_map_3d[45, 45, 45] = 0.9 + expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] + TESTS.append([{"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, probs_map_3d, expected_3d]) class TestProbNMS(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASES_2D_1, - TEST_CASES_2D_2, - TEST_CASES_2D_3, - TEST_CASES_2D_4, - TEST_CASES_2D_5, - TEST_CASES_2D_6, - TEST_CASES_2D_7, - TEST_CASES_3D, - ] - ) + @parameterized.expand(TESTS) def test_output(self, class_args, probs_map, expected): nms = ProbNMS(**class_args) output = nms(probs_map) - np.testing.assert_allclose(output, expected) + assert_allclose(output, expected) if __name__ == "__main__": diff --git a/tests/test_probnmsd.py b/tests/test_probnmsd.py index 5b75d4310f..bb2315487b 100644 --- a/tests/test_probnmsd.py +++ b/tests/test_probnmsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,89 +10,63 @@ # limitations under the License. import unittest +from typing import Any, List import numpy as np import torch from parameterized import parameterized from monai.transforms.post.dictionary import ProbNMSD +from tests.utils import TEST_NDARRAYS -probs_map_1 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, {"prob_map": probs_map_1}, []] +TESTS: List[Any] = [] +for p in TEST_NDARRAYS: + probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, {"prob_map": probs_map_1}, []]) -probs_map_2 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_2[33, 33] = 0.7 -probs_map_2[66, 66] = 0.9 -expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_2 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, - {"prob_map": probs_map_2}, - expected_2, -] - -probs_map_3 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_3[56, 58] = 0.7 -probs_map_3[60, 66] = 0.8 -probs_map_3[66, 66] = 0.9 -expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] -TEST_CASES_2D_3 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, - {"prob_map": probs_map_3}, - expected_3, -] + probs_map_2 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_2[33, 33] = 0.7 + probs_map_2[66, 66] = 0.9 + expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append( + [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, {"prob_map": probs_map_2}, expected_2] + ) -probs_map_4 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_4[33, 33] = 0.7 -probs_map_4[66, 66] = 0.9 -expected_4 = [[0.9, 66, 66]] -TEST_CASES_2D_4 = [ - {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, - {"prob_map": probs_map_4}, - expected_4, -] + probs_map_3 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_3[56, 58] = 0.7 + probs_map_3[60, 66] = 0.8 + probs_map_3[66, 66] = 0.9 + expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] + TESTS.append( + [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, {"prob_map": probs_map_3}, expected_3] + ) -probs_map_5 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_5 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_5}, []] + probs_map_4 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_4[33, 33] = 0.7 + probs_map_4[66, 66] = 0.9 + expected_4 = [[0.9, 66, 66]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, {"prob_map": probs_map_4}, expected_4]) -probs_map_6 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -TEST_CASES_2D_6 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_6}, []] + probs_map_5 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_5}, []]) -probs_map_7 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -probs_map_7[33, 33] = 0.7 -probs_map_7[66, 66] = 0.9 -if torch.cuda.is_available(): - probs_map_7 = probs_map_7.cuda() -expected_7 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_7 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, - {"prob_map": probs_map_7}, - expected_7, -] + probs_map_6 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_6[33, 33] = 0.7 + probs_map_6[66, 66] = 0.9 + expected_6 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_6}, expected_6]) -probs_map_3d = torch.rand([50, 50, 50]).uniform_(0, 0.5) -probs_map_3d[25, 25, 25] = 0.7 -probs_map_3d[45, 45, 45] = 0.9 -expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] -TEST_CASES_3D = [ - {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, - {"prob_map": probs_map_3d}, - expected_3d, -] + probs_map_3d = p(torch.rand([50, 50, 50]).uniform_(0, 0.5)) + probs_map_3d[25, 25, 25] = 0.7 + probs_map_3d[45, 45, 45] = 0.9 + expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] + TESTS.append( + [{"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, {"prob_map": probs_map_3d}, expected_3d] + ) class TestProbNMS(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASES_2D_1, - TEST_CASES_2D_2, - TEST_CASES_2D_3, - TEST_CASES_2D_4, - TEST_CASES_2D_5, - TEST_CASES_2D_6, - TEST_CASES_2D_7, - TEST_CASES_3D, - ] - ) + @parameterized.expand(TESTS) def test_output(self, class_args, probs_map, expected): nms = ProbNMSD(keys="prob_map", **class_args) output = nms(probs_map) diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py new file mode 100644 index 0000000000..68abb9571f --- /dev/null +++ b/tests/test_pytorch_version_after.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from parameterized import parameterized + +from monai.utils import pytorch_after + +TEST_CASES = ( + (1, 5, 9, "1.6.0"), + (1, 6, 0, "1.6.0"), + (1, 6, 1, "1.6.0", False), + (1, 7, 0, "1.6.0", False), + (2, 6, 0, "1.6.0", False), + (0, 6, 0, "1.6.0a0+3fd9dcf"), + (1, 5, 9, "1.6.0a0+3fd9dcf"), + (1, 6, 0, "1.6.0a0+3fd9dcf", False), + (1, 6, 1, "1.6.0a0+3fd9dcf", False), + (2, 6, 0, "1.6.0a0+3fd9dcf", False), + (1, 6, 0, "1.6.0-rc0+3fd9dcf", False), # defaults to prerelease + (1, 6, 0, "1.6.0rc0", False), + (1, 6, 0, "1.6", True), + (1, 6, 0, "1", False), + (1, 6, 0, "1.6.0+cpu", True), + (1, 6, 1, "1.6.0+cpu", False), +) + + +class TestPytorchVersionCompare(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_compare(self, a, b, p, current, expected=True): + """Test pytorch_after with a and b""" + self.assertEqual(pytorch_after(a, b, p, current), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_query_memory.py b/tests/test_query_memory.py index 22c29598fc..cdb44d3eb1 100644 --- a/tests/test_query_memory.py +++ b/tests/test_query_memory.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_rand_adjust_contrast.py b/tests/test_rand_adjust_contrast.py index d7d750957d..eaeff70d51 100644 --- a/tests/test_rand_adjust_contrast.py +++ b/tests/test_rand_adjust_contrast.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandAdjustContrast -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_1 = [(0.5, 4.5)] @@ -26,14 +26,16 @@ class TestRandAdjustContrast(NumpyImageTestCase2D): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_correct_results(self, gamma): adjuster = RandAdjustContrast(prob=1.0, gamma=gamma) - result = adjuster(self.imt) - epsilon = 1e-7 - img_min = self.imt.min() - img_range = self.imt.max() - img_min - expected = ( - np.power(((self.imt - img_min) / float(img_range + epsilon)), adjuster.gamma_value) * img_range + img_min - ) - np.testing.assert_allclose(expected, result, rtol=1e-05) + for p in TEST_NDARRAYS: + result = adjuster(p(self.imt)) + epsilon = 1e-7 + img_min = self.imt.min() + img_range = self.imt.max() - img_min + expected = ( + np.power(((self.imt - img_min) / float(img_range + epsilon)), adjuster.gamma_value) * img_range + + img_min + ) + assert_allclose(expected, result, rtol=1e-05, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_adjust_contrastd.py b/tests/test_rand_adjust_contrastd.py index e4b61293bb..e5f1f6099a 100644 --- a/tests/test_rand_adjust_contrastd.py +++ b/tests/test_rand_adjust_contrastd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandAdjustContrastd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_1 = [(0.5, 4.5)] @@ -26,14 +26,16 @@ class TestRandAdjustContrastd(NumpyImageTestCase2D): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_correct_results(self, gamma): adjuster = RandAdjustContrastd("img", prob=1.0, gamma=gamma) - result = adjuster({"img": self.imt}) - epsilon = 1e-7 - img_min = self.imt.min() - img_range = self.imt.max() - img_min - expected = ( - np.power(((self.imt - img_min) / float(img_range + epsilon)), adjuster.gamma_value) * img_range + img_min - ) - np.testing.assert_allclose(expected, result["img"], rtol=1e-05) + for p in TEST_NDARRAYS: + result = adjuster({"img": p(self.imt)}) + epsilon = 1e-7 + img_min = self.imt.min() + img_range = self.imt.max() - img_min + expected = ( + np.power(((self.imt - img_min) / float(img_range + epsilon)), adjuster.adjuster.gamma_value) * img_range + + img_min + ) + assert_allclose(expected, result["img"], rtol=1e-05, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py index 1e1a23bc09..8de408ab84 100644 --- a/tests/test_rand_affine.py +++ b/tests/test_rand_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,114 +16,130 @@ from parameterized import parameterized from monai.transforms import RandAffine +from monai.utils.type_conversion import convert_data_type +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=-1), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3)), "spatial_size": (2, 2)}, - np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]]), - ], - [ - dict(as_tensor_output=True, device=None), - {"img": torch.ones((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), cache_grid=True), - {"img": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - cache_grid=True, - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "spatial_size": (3, 3)}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - spatial_size=(3, 3), - cache_grid=True, - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], -] +_rtol = 1e-3 if is_tf32_env() else 1e-4 + +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [dict(device=device), {"img": p(torch.arange(27).reshape((3, 3, 3)))}, p(np.arange(27).reshape((3, 3, 3)))] + ) + TESTS.append( + [ + dict(device=device, spatial_size=-1), + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]])), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.ones((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), cache_grid=True), + {"img": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + cache_grid=True, + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "spatial_size": (3, 3)}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) -ARR_NUMPY = np.arange(9 * 10).reshape(1, 9, 10) -ARR_TORCH = torch.Tensor(ARR_NUMPY) TEST_CASES_SKIPPED_CONSISTENCY = [] -for im in (ARR_NUMPY, ARR_TORCH): - for as_tensor_output in (True, False): - for in_dtype_is_int in (True, False): - TEST_CASES_SKIPPED_CONSISTENCY.append((im, as_tensor_output, in_dtype_is_int)) +for p in TEST_NDARRAYS: + for in_dtype in (np.int32, np.float32): + TEST_CASES_SKIPPED_CONSISTENCY.append((p(np.arange(9 * 10).reshape(1, 9, 10)), in_dtype)) class TestRandAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine(self, input_param, input_data, expected_val): g = RandAffine(**input_param) g.set_random_state(123) result = g(**input_data) if input_param.get("cache_grid", False): self.assertTrue(g._cached_grid is not None) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4) def test_ill_cache(self): with self.assertWarns(UserWarning): @@ -132,15 +148,11 @@ def test_ill_cache(self): RandAffine(cache_grid=True, spatial_size=(1, 1, -1)) @parameterized.expand(TEST_CASES_SKIPPED_CONSISTENCY) - def test_skipped_transform_consistency(self, im, as_tensor_output, in_dtype_is_int): - t1 = RandAffine(prob=0, as_tensor_output=as_tensor_output) - t2 = RandAffine(prob=1, spatial_size=(10, 11), as_tensor_output=as_tensor_output) + def test_skipped_transform_consistency(self, im, in_dtype): + t1 = RandAffine(prob=0) + t2 = RandAffine(prob=1, spatial_size=(10, 11)) - # change dtype to int32 or float32 - if in_dtype_is_int: - im = im.astype("int32") if isinstance(im, np.ndarray) else im.int() - else: - im = im.astype("float32") if isinstance(im, np.ndarray) else im.float() + im, *_ = convert_data_type(im, dtype=in_dtype) out1 = t1(im) out2 = t2(im) diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 605d0a30ba..60ac40f468 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,182 +16,194 @@ from parameterized import parameterized from monai.transforms import RandAffineGrid +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, - {"grid": torch.arange(0, 27).reshape((3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [-32.81998, -33.910976, -35.001972], - [-36.092968, -37.183964, -38.27496], - [-39.36596, -40.456955, -41.54795], - ], - [[2.1380205, 3.1015975, 4.0651755], [5.028752, 5.9923296, 6.955907], [7.919484, 8.883063, 9.84664]], - [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], - ] - ) - ), - ], - [ - {"translate_range": (3, 3, 3), "as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (3, 3, 3)}, - np.array( +_rtol = 1e-1 if is_tf32_env else 1e-4 + +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( [ - [ - [ - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - ], - [ - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - ], - [ - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - ], - ], - [ - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - ], - [ - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - ], - [ - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ], - ] - ), - ], - [ - {"rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, - {"grid": torch.arange(0, 108).reshape((4, 3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [ - [-9.4201e00, -8.1672e00, -6.9143e00], - [-5.6614e00, -4.4085e00, -3.1556e00], - [-1.9027e00, -6.4980e-01, 6.0310e-01], - ], - [ - [1.8560e00, 3.1089e00, 4.3618e00], - [5.6147e00, 6.8676e00, 8.1205e00], - [9.3734e00, 1.0626e01, 1.1879e01], - ], + {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, + {"grid": p(torch.arange(0, 27).reshape((3, 3, 3)))}, + p( + np.array( [ - [1.3132e01, 1.4385e01, 1.5638e01], - [1.6891e01, 1.8144e01, 1.9397e01], - [2.0650e01, 2.1902e01, 2.3155e01], - ], - ], - [ - [ - [9.9383e-02, -4.8845e-01, -1.0763e00], - [-1.6641e00, -2.2519e00, -2.8398e00], - [-3.4276e00, -4.0154e00, -4.6032e00], - ], - [ - [-5.1911e00, -5.7789e00, -6.3667e00], - [-6.9546e00, -7.5424e00, -8.1302e00], - [-8.7180e00, -9.3059e00, -9.8937e00], - ], - [ - [-1.0482e01, -1.1069e01, -1.1657e01], - [-1.2245e01, -1.2833e01, -1.3421e01], - [-1.4009e01, -1.4596e01, -1.5184e01], - ], - ], + [ + [-32.81998, -33.910976, -35.001972], + [-36.092968, -37.183964, -38.27496], + [-39.36596, -40.456955, -41.54795], + ], + [ + [2.1380205, 3.1015975, 4.0651755], + [5.028752, 5.9923296, 6.955907], + [7.919484, 8.883063, 9.84664], + ], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], + ] + ) + ), + ] + ) + TESTS.append( + [ + {"translate_range": (3, 3, 3), "device": device}, + {"spatial_size": (3, 3, 3)}, + np.array( [ [ - [5.9635e01, 6.1199e01, 6.2764e01], - [6.4328e01, 6.5892e01, 6.7456e01], - [6.9021e01, 7.0585e01, 7.2149e01], - ], - [ - [7.3714e01, 7.5278e01, 7.6842e01], - [7.8407e01, 7.9971e01, 8.1535e01], - [8.3099e01, 8.4664e01, 8.6228e01], + [ + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + ], + [ + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + ], + [ + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + ], ], [ - [8.7792e01, 8.9357e01, 9.0921e01], - [9.2485e01, 9.4049e01, 9.5614e01], - [9.7178e01, 9.8742e01, 1.0031e02], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], ], - ], - [ [ - [8.1000e01, 8.2000e01, 8.3000e01], - [8.4000e01, 8.5000e01, 8.6000e01], - [8.7000e01, 8.8000e01, 8.9000e01], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], ], [ - [9.0000e01, 9.1000e01, 9.2000e01], - [9.3000e01, 9.4000e01, 9.5000e01], - [9.6000e01, 9.7000e01, 9.8000e01], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], ], + ] + ), + ] + ) + TESTS.append( + [ + {"device": device, "rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, + {"grid": p(torch.arange(0, 108).reshape((4, 3, 3, 3)))}, + p( + np.array( [ - [9.9000e01, 1.0000e02, 1.0100e02], - [1.0200e02, 1.0300e02, 1.0400e02], - [1.0500e02, 1.0600e02, 1.0700e02], - ], - ], - ] - ) - ), - ], -] + [ + [ + [-9.4201e00, -8.1672e00, -6.9143e00], + [-5.6614e00, -4.4085e00, -3.1556e00], + [-1.9027e00, -6.4980e-01, 6.0310e-01], + ], + [ + [1.8560e00, 3.1089e00, 4.3618e00], + [5.6147e00, 6.8676e00, 8.1205e00], + [9.3734e00, 1.0626e01, 1.1879e01], + ], + [ + [1.3132e01, 1.4385e01, 1.5638e01], + [1.6891e01, 1.8144e01, 1.9397e01], + [2.0650e01, 2.1902e01, 2.3155e01], + ], + ], + [ + [ + [9.9383e-02, -4.8845e-01, -1.0763e00], + [-1.6641e00, -2.2519e00, -2.8398e00], + [-3.4276e00, -4.0154e00, -4.6032e00], + ], + [ + [-5.1911e00, -5.7789e00, -6.3667e00], + [-6.9546e00, -7.5424e00, -8.1302e00], + [-8.7180e00, -9.3059e00, -9.8937e00], + ], + [ + [-1.0482e01, -1.1069e01, -1.1657e01], + [-1.2245e01, -1.2833e01, -1.3421e01], + [-1.4009e01, -1.4596e01, -1.5184e01], + ], + ], + [ + [ + [5.9635e01, 6.1199e01, 6.2764e01], + [6.4328e01, 6.5892e01, 6.7456e01], + [6.9021e01, 7.0585e01, 7.2149e01], + ], + [ + [7.3714e01, 7.5278e01, 7.6842e01], + [7.8407e01, 7.9971e01, 8.1535e01], + [8.3099e01, 8.4664e01, 8.6228e01], + ], + [ + [8.7792e01, 8.9357e01, 9.0921e01], + [9.2485e01, 9.4049e01, 9.5614e01], + [9.7178e01, 9.8742e01, 1.0031e02], + ], + ], + [ + [ + [8.1000e01, 8.2000e01, 8.3000e01], + [8.4000e01, 8.5000e01, 8.6000e01], + [8.7000e01, 8.8000e01, 8.9000e01], + ], + [ + [9.0000e01, 9.1000e01, 9.2000e01], + [9.3000e01, 9.4000e01, 9.5000e01], + [9.6000e01, 9.7000e01, 9.8000e01], + ], + [ + [9.9000e01, 1.0000e02, 1.0100e02], + [1.0200e02, 1.0300e02, 1.0400e02], + [1.0500e02, 1.0600e02, 1.0700e02], + ], + ], + ] + ) + ), + ] + ) class TestRandAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine_grid(self, input_param, input_data, expected_val): g = RandAffineGrid(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data[device]) + assert_allclose(result, expected_val, type_test=False, rtol=_rtol, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index d2f8a60665..882b5554e6 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,179 +17,190 @@ from monai.transforms import RandAffined from monai.utils import GridSampleMode +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None, spatial_size=None, keys=("img", "seg")), - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), keys=("img", "seg")), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=False, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - cache_grid=True, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - np.array([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - spatial_size=(3, 3), - keys=("img", "seg"), - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - mode=("bilinear", "nearest"), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode=GridSampleMode.BILINEAR, - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - cache_grid=True, - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], -] +_rtol = 1e-3 if is_tf32_env() else 1e-4 + +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + dict(device=device, spatial_size=None, keys=("img", "seg")), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), keys=("img", "seg")), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode="bilinear", + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=("bilinear", "nearest"), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode=GridSampleMode.BILINEAR, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) class TestRandAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affined(self, input_param, input_data, expected_val): g = RandAffined(**input_param).set_random_state(123) res = g(input_data) @@ -200,28 +211,20 @@ def test_rand_affined(self, input_param, input_data, expected_val): if "_transforms" in key: continue expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=_rtol, atol=1e-3) + + g.set_random_state(4) + res = g(input_data) + # affine should be tensor because the resampler only supports pytorch backend + self.assertTrue(isinstance(res["img_transforms"][0]["extra_info"]["affine"], torch.Tensor)) def test_ill_cache(self): with self.assertWarns(UserWarning): # spatial size is None - RandAffined( - as_tensor_output=False, device=None, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg") - ) + RandAffined(device=device, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg")) with self.assertWarns(UserWarning): # spatial size is dynamic - RandAffined( - as_tensor_output=False, - device=None, - spatial_size=(2, -1), - prob=1.0, - cache_grid=True, - keys=("img", "seg"), - ) + RandAffined(device=device, spatial_size=(2, -1), prob=1.0, cache_grid=True, keys=("img", "seg")) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index c05c3a1e0d..b7c504557f 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,10 +22,8 @@ def test_correct_results(self): for p in TEST_NDARRAYS: flip = RandAxisFlip(prob=1.0) result = flip(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, flip._axis)) - assert_allclose(np.stack(expected), result) + expected = [np.flip(channel, flip._axis) for channel in self.imt[0]] + assert_allclose(result, p(np.stack(expected))) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index 7bef0baa63..ff97d5dc1e 100644 --- a/tests/test_rand_axis_flipd.py +++ b/tests/test_rand_axis_flipd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,10 +23,8 @@ def test_correct_results(self): flip = RandAxisFlipd(keys="img", prob=1.0) result = flip({"img": p(self.imt[0])})["img"] - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, flip._axis)) - assert_allclose(np.stack(expected), result) + expected = [np.flip(channel, flip.flipper._axis) for channel in self.imt[0]] + assert_allclose(result, p(np.stack(expected))) if __name__ == "__main__": diff --git a/tests/test_random_bias_field.py b/tests/test_rand_bias_field.py similarity index 65% rename from tests/test_random_bias_field.py rename to tests/test_rand_bias_field.py index 5aeeb79874..b3aa8e9174 100644 --- a/tests/test_random_bias_field.py +++ b/tests/test_rand_bias_field.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,29 +12,34 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandBiasField -TEST_CASES_2D = [{}, (3, 32, 32)] -TEST_CASES_3D = [{}, (3, 32, 32, 32)] -TEST_CASES_2D_ZERO_RANGE = [{"coeff_range": (0.0, 0.0)}, (2, 3, 3)] -TEST_CASES_2D_ONES = [{"coeff_range": (1.0, 1.0)}, np.asarray([[[7.389056, 0.1353353], [7.389056, 22026.46]]])] +TEST_CASES_2D = [{"prob": 1.0}, (3, 32, 32)] +TEST_CASES_3D = [{"prob": 1.0}, (3, 32, 32, 32)] +TEST_CASES_2D_ZERO_RANGE = [{"prob": 1.0, "coeff_range": (0.0, 0.0)}, (2, 3, 3)] +TEST_CASES_2D_ONES = [ + {"prob": 1.0, "coeff_range": (1.0, 1.0)}, + np.asarray([[[7.389056, 0.1353353], [7.389056, 22026.46]]]), +] class TestRandBiasField(unittest.TestCase): @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D]) def test_output_shape(self, class_args, img_shape): - for degree in [1, 2, 3]: - bias_field = RandBiasField(degree=degree, **class_args) - img = np.random.rand(*img_shape) - output = bias_field(img) - np.testing.assert_equal(output.shape, img_shape) - np.testing.assert_equal(output.dtype, bias_field.dtype) - - img_zero = np.zeros([*img_shape]) - output_zero = bias_field(img_zero) - np.testing.assert_equal(output_zero, img_zero) + for fn in (np.random, torch): + for degree in [1, 2, 3]: + bias_field = RandBiasField(degree=degree, **class_args) + img = fn.rand(*img_shape) + output = bias_field(img) + np.testing.assert_equal(output.shape, img_shape) + self.assertTrue(output.dtype in (np.float32, torch.float32)) + + img_zero = np.zeros([*img_shape]) + output_zero = bias_field(img_zero) + np.testing.assert_equal(output_zero, img_zero) @parameterized.expand([TEST_CASES_2D_ZERO_RANGE]) def test_zero_range(self, class_args, img_shape): diff --git a/tests/test_random_bias_fieldd.py b/tests/test_rand_bias_fieldd.py similarity index 89% rename from tests/test_random_bias_fieldd.py rename to tests/test_rand_bias_fieldd.py index aa2e206de9..da08cfe053 100644 --- a/tests/test_random_bias_fieldd.py +++ b/tests/test_rand_bias_fieldd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,11 +16,11 @@ from monai.transforms import RandBiasFieldd -TEST_CASES_2D = [{}, (3, 32, 32)] -TEST_CASES_3D = [{}, (3, 32, 32, 32)] -TEST_CASES_2D_ZERO_RANGE = [{"coeff_range": (0.0, 0.0)}, (3, 32, 32)] +TEST_CASES_2D = [{"prob": 1.0}, (3, 32, 32)] +TEST_CASES_3D = [{"prob": 1.0}, (3, 32, 32, 32)] +TEST_CASES_2D_ZERO_RANGE = [{"prob": 1.0, "coeff_range": (0.0, 0.0)}, (3, 32, 32)] TEST_CASES_2D_ONES = [ - {"coeff_range": (1.0, 1.0)}, + {"prob": 1.0, "coeff_range": (1.0, 1.0)}, np.asarray([[[7.3890562e00, 1.3533528e-01], [7.3890562e00, 2.2026465e04]]]), ] diff --git a/tests/test_rand_coarse_dropout.py b/tests/test_rand_coarse_dropout.py index 830832c2a5..a05d323277 100644 --- a/tests/test_rand_coarse_dropout.py +++ b/tests/test_rand_coarse_dropout.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandCoarseDropout @@ -52,12 +53,20 @@ np.random.randint(0, 2, size=[3, 3, 3, 4]), ] +TEST_CASE_7 = [ + {"holes": 2, "spatial_size": [2, 2, 2], "dropout_holes": False, "fill_value": (3, 6), "prob": 1.0}, + torch.randint(0, 2, size=[3, 3, 3, 4]), +] + class TestRandCoarseDropout(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand( + [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] + ) def test_value(self, input_param, input_data): dropout = RandCoarseDropout(**input_param) result = dropout(input_data) + self.assertEqual(type(result), type(input_data)) holes = input_param.get("holes") max_holes = input_param.get("max_holes") spatial_size = fall_back_tuple(input_param.get("spatial_size"), input_data.shape[1:]) diff --git a/tests/test_rand_coarse_dropoutd.py b/tests/test_rand_coarse_dropoutd.py index fc898a9fca..e54db130a5 100644 --- a/tests/test_rand_coarse_dropoutd.py +++ b/tests/test_rand_coarse_dropoutd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,14 +28,7 @@ ] TEST_CASE_2 = [ - { - "keys": "img", - "holes": 2, - "spatial_size": [2, 2, 2], - "fill_value": 5, - "max_spatial_size": [4, 4, 3], - "prob": 1.0, - }, + {"keys": "img", "holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "max_spatial_size": [4, 4, 3], "prob": 1.0}, {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, ] diff --git a/tests/test_rand_coarse_shuffle.py b/tests/test_rand_coarse_shuffle.py new file mode 100644 index 0000000000..fb7311e5a3 --- /dev/null +++ b/tests/test_rand_coarse_shuffle.py @@ -0,0 +1,62 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandCoarseShuffle + +TEST_CASES = [ + [ + {"holes": 5, "spatial_size": 1, "max_spatial_size": -1, "prob": 0.0}, + {"img": np.arange(8).reshape((1, 2, 2, 2))}, + np.arange(8).reshape((1, 2, 2, 2)), + ], + [ + {"holes": 10, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(27).reshape((1, 3, 3, 3))}, + np.asarray( + [ + [ + [[8, 19, 26], [24, 6, 15], [0, 13, 25]], + [[17, 3, 5], [10, 1, 12], [22, 4, 11]], + [[21, 20, 23], [14, 2, 16], [18, 9, 7]], + ] + ] + ), + ], + [ + {"holes": 2, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(16).reshape((2, 2, 2, 2))}, + np.asarray([[[[6, 1], [4, 3]], [[0, 2], [7, 5]]], [[[14, 10], [9, 8]], [[12, 15], [13, 11]]]]), + ], + [ + {"holes": 2, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": torch.arange(16).reshape((2, 2, 2, 2))}, + torch.as_tensor([[[[6, 1], [4, 3]], [[0, 2], [7, 5]]], [[[14, 10], [9, 8]], [[12, 15], [13, 11]]]]), + ], +] + + +class TestRandCoarseShuffle(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shuffle(self, input_param, input_data, expected_val): + g = RandCoarseShuffle(**input_param) + g.set_random_state(seed=12) + result = g(**input_data) + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_coarse_shuffled.py b/tests/test_rand_coarse_shuffled.py new file mode 100644 index 0000000000..fa9c17286d --- /dev/null +++ b/tests/test_rand_coarse_shuffled.py @@ -0,0 +1,56 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandCoarseShuffled + +TEST_CASES = [ + [ + {"keys": "img", "holes": 5, "spatial_size": 1, "max_spatial_size": -1, "prob": 0.0}, + {"img": np.arange(8).reshape((1, 2, 2, 2))}, + np.arange(8).reshape((1, 2, 2, 2)), + ], + [ + {"keys": "img", "holes": 10, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(27).reshape((1, 3, 3, 3))}, + np.asarray( + [ + [ + [[8, 19, 26], [24, 6, 15], [0, 13, 25]], + [[17, 3, 5], [10, 1, 12], [22, 4, 11]], + [[21, 20, 23], [14, 2, 16], [18, 9, 7]], + ] + ] + ), + ], + [ + {"keys": "img", "holes": 2, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(16).reshape((2, 2, 2, 2))}, + np.asarray([[[[6, 1], [4, 3]], [[0, 2], [7, 5]]], [[[14, 10], [9, 8]], [[12, 15], [13, 11]]]]), + ], +] + + +class TestRandCoarseShuffled(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shuffle(self, input_param, input_data, expected_val): + g = RandCoarseShuffled(**input_param) + g.set_random_state(seed=12) + result = g(input_data) + np.testing.assert_allclose(result["img"], expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index b21f971042..11d73df74e 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,68 +15,121 @@ from parameterized import parameterized from monai.transforms import ClassesToIndices, RandCropByLabelClasses +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ +TESTS_INDICES, TESTS_SHAPE = [], [] +for p in TEST_NDARRAYS: # One-Hot label - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "num_classes": None, - "spatial_size": [2, 2, -1], - "ratios": [1, 1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 3), -] + TESTS_INDICES.append( + [ + { + "label": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + list, + (3, 2, 2, 3), + ] + ) -TEST_CASE_1 = [ - # Argmax label - { - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 2), -] + TESTS_INDICES.append( + [ + # Argmax label + { + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + list, + (3, 2, 2, 2), + ] + ) -TEST_CASE_2 = [ - # provide label at runtime - { - "label": None, - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] + TESTS_SHAPE.append( + [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 2, 2, 2), + ] + ) + TESTS_SHAPE.append( + [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [4, 4, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 3, 3, 2), + ] + ) + TESTS_SHAPE.append( + [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [4, 4, 4], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 3, 3, 3), + ] + ) class TestRandCropByLabelClasses(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS_INDICES + TESTS_SHAPE) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByLabelClasses(**input_param)(**input_data) self.assertIsInstance(result, expected_type) self.assertTupleEqual(result[0].shape, expected_shape) - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TESTS_INDICES) def test_indices(self, input_param, input_data, expected_type, expected_shape): input_param["indices"] = ClassesToIndices(num_classes=input_param["num_classes"])(input_param["label"]) result = RandCropByLabelClasses(**input_param)(**input_data) diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 829096953b..92780458e0 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,52 +15,107 @@ from parameterized import parameterized from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - # One-Hot label - { - "keys": "img", - "label_key": "label", - "num_classes": None, - "spatial_size": [2, 2, -1], - "ratios": [1, 1, 1], - "num_samples": 2, - "image_key": "image", - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 3), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # One-Hot label + { + "keys": "img", + "label_key": "label", + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 2, 2, 3), + ] + ) -TEST_CASE_1 = [ - # Argmax label - { - "keys": "img", - "label_key": "label", - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image_key": "image", - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] + TESTS.append( + [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + }, + list, + (3, 2, 2, 2), + ] + ) + + TESTS.append( + [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [4, 4, 2], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + }, + list, + (3, 3, 3, 2), + ] + ) + + TESTS.append( + [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [4, 4, 4], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + }, + list, + (3, 3, 3, 3), + ] + ) class TestRandCropByLabelClassesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TESTS) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByLabelClassesd(**input_param)(input_data) self.assertIsInstance(result, expected_type) diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index e0f669ab3f..f8b8a77a45 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,68 +10,123 @@ # limitations under the License. import unittest +from copy import deepcopy import numpy as np from parameterized import parameterized from monai.transforms import RandCropByPosNegLabel +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "spatial_size": [2, 2, -1], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 3), -] +TESTS = [] +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [2, 2, -1], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 2, 2, 3), + ] +) +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 2, 2, 2), + ] +) +TESTS.append( + [ + { + "label": None, + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + { + "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + }, + (3, 2, 2, 2), + ] +) +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [4, 4, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "allow_smaller": True, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 3, 3, 2), + ] +) +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [4, 4, 4], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "allow_smaller": True, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 3, 3, 3), + ] +) -TEST_CASE_1 = [ - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 2), -] -TEST_CASE_2 = [ - { - "label": None, - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] +class TestRandCropByPosNegLabel(unittest.TestCase): + @staticmethod + def convert_data_type(im_type, d, keys=("img", "image", "label")): + out = deepcopy(d) + for k, v in out.items(): + if k in keys and isinstance(v, np.ndarray): + out[k] = im_type(v) + return out + @parameterized.expand(TESTS) + def test_type_shape(self, input_param, input_data, expected_shape): + results = [] + for p in TEST_NDARRAYS: + input_param_mod = self.convert_data_type(p, input_param) + input_data_mod = self.convert_data_type(p, input_data) + cropper = RandCropByPosNegLabel(**input_param_mod) + cropper.set_random_state(0) + result = cropper(**input_data_mod) -class TestRandCropByPosNegLabel(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_type_shape(self, input_param, input_data, expected_type, expected_shape): - result = RandCropByPosNegLabel(**input_param)(**input_data) - self.assertIsInstance(result, expected_type) - self.assertTupleEqual(result[0].shape, expected_shape) + self.assertIsInstance(result, list) + self.assertTupleEqual(result[0].shape, expected_shape) + + # check for same results across numpy, torch.Tensor and torch.cuda.Tensor + result = np.asarray([i if isinstance(i, np.ndarray) else i.cpu().numpy() for i in result]) + results.append(np.asarray(result)) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 17a3e117bb..df85b29b00 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,90 +10,142 @@ # limitations under the License. import unittest +from copy import deepcopy import numpy as np from parameterized import parameterized from monai.transforms import RandCropByPosNegLabeld +from monai.utils.enums import PostFix +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - { - "keys": ["image", "extra", "label"], - "label_key": "label", - "spatial_size": [-1, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image_key": None, - "image_threshold": 0, - }, - { - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, - }, - list, - (3, 3, 2, 2), +TESTS = [ + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [-1, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + }, + { + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + PostFix.meta("image"): {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 3, 2, 2), + ], + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + }, + { + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + PostFix.meta("label"): {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 2, 2, 2), + ], + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + }, + { + "image": np.zeros([3, 3, 3, 3]) - 1, + "extra": np.zeros([3, 3, 3, 3]), + "label": np.ones([3, 3, 3, 3]), + PostFix.meta("extra"): {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 2, 2, 2), + ], + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [4, 4, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + "allow_smaller": True, + }, + { + "image": np.zeros([3, 3, 3, 3]) - 1, + "extra": np.zeros([3, 3, 3, 3]), + "label": np.ones([3, 3, 3, 3]), + PostFix.meta("extra"): {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 3, 3, 2), + ], + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [4, 4, 4], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + "allow_smaller": True, + }, + { + "image": np.zeros([3, 3, 3, 3]) - 1, + "extra": np.zeros([3, 3, 3, 3]), + "label": np.ones([3, 3, 3, 3]), + PostFix.meta("extra"): {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 3, 3, 3), + ], ] -TEST_CASE_1 = [ - { - "keys": ["image", "extra", "label"], - "label_key": "label", - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image_key": None, - "image_threshold": 0, - }, - { - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, - }, - list, - (3, 2, 2, 2), -] -TEST_CASE_2 = [ - { - "keys": ["image", "extra", "label"], - "label_key": "label", - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image_key": None, - "image_threshold": 0, - }, - { - "image": np.zeros([3, 3, 3, 3]) - 1, - "extra": np.zeros([3, 3, 3, 3]), - "label": np.ones([3, 3, 3, 3]), - "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, - }, - list, - (3, 2, 2, 2), -] +class TestRandCropByPosNegLabeld(unittest.TestCase): + @staticmethod + def convert_data_type(im_type, d, keys=("img", "image", "label")): + out = deepcopy(d) + for k, v in out.items(): + if k in keys and isinstance(v, np.ndarray): + out[k] = im_type(v) + return out + @parameterized.expand(TESTS) + def test_type_shape(self, input_param, input_data, expected_shape): + for p in TEST_NDARRAYS: + input_param_mod = self.convert_data_type(p, input_param) + input_data_mod = self.convert_data_type(p, input_data) + cropper = RandCropByPosNegLabeld(**input_param_mod) + cropper.set_random_state(0) + result = cropper(input_data_mod) -class TestRandCropByPosNegLabeld(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_type_shape(self, input_param, input_data, expected_type, expected_shape): - result = RandCropByPosNegLabeld(**input_param)(input_data) - self.assertIsInstance(result, expected_type) - self.assertTupleEqual(result[0]["image"].shape, expected_shape) - self.assertTupleEqual(result[0]["extra"].shape, expected_shape) - self.assertTupleEqual(result[0]["label"].shape, expected_shape) - _len = len(tuple(input_data.keys())) - self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) - for i, item in enumerate(result): - self.assertEqual(item["image_meta_dict"]["patch_index"], i) - self.assertEqual(item["label_meta_dict"]["patch_index"], i) - self.assertEqual(item["extra_meta_dict"]["patch_index"], i) + self.assertIsInstance(result, list) + + _len = len(tuple(input_data.keys())) + self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) + for k in ("image", "extra", "label"): + self.assertTupleEqual(result[0][k].shape, expected_shape) + for i, item in enumerate(result): + self.assertEqual(item[PostFix.meta(k)]["patch_index"], i) if __name__ == "__main__": diff --git a/tests/test_rand_cucim_dict_transform.py b/tests/test_rand_cucim_dict_transform.py new file mode 100644 index 0000000000..cd41b7f49a --- /dev/null +++ b/tests/test_rand_cucim_dict_transform.py @@ -0,0 +1,185 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandCuCIMd +from monai.utils import optional_import, set_determinism +from tests.utils import HAS_CUPY, skip_if_no_cuda + +_, has_cut = optional_import("cucim.core.operations.expose.transform") +cp, _ = optional_import("cupy") + +set_determinism(seed=0) + +TEST_CASE_COLOR_JITTER_1 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8), + np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8), +] + +TEST_CASE_FLIP_1 = [ + {"name": "image_flip", "spatial_axis": -1}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_RAND_ROTATE_1 = [ + {"name": "rand_image_rotate_90", "prob": 1.0, "max_k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32), +] + + +TEST_CASE_RAND_ROTATE_2 = [ + {"name": "rand_image_rotate_90", "prob": 0.0, "max_k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), +] + +TEST_CASE_SCALE_INTENSITY_1 = [ + {"name": "scale_intensity_range", "a_min": 0.0, "a_max": 4.0, "b_min": 0.0, "b_max": 1.0, "clip": False}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32), +] + +TEST_CASE_ZOOM_1 = [ + {"name": "zoom", "zoom_factor": (0.5, 0.5)}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + +TEST_CASE_RAND_ZOOM_1 = [ + {"name": "rand_zoom", "prob": 1.0, "min_zoom": 0.5, "max_zoom": 0.5}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + +TEST_CASE_RAND_ZOOM_2 = [ + {"name": "rand_zoom", "prob": 0.0, "min_zoom": 0.5, "max_zoom": 0.5}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.mgrid[:3, 1:4].astype(dtype=np.float32), +] + + +@skip_if_no_cuda +@unittest.skipUnless(HAS_CUPY, "CuPy is required.") +@unittest.skipUnless(has_cut, "cuCIM transforms are required.") +class TestRandCuCIMDict(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_numpy_single(self, params, input, expected): + input = {"image": input} + # apply_prob=1.0 + output = RandCuCIMd(keys="image", apply_prob=1.0, **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIMd(keys="image", apply_prob=0.0, **params)(input)["image"] + self.assertTrue(output.dtype == input["image"].dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, input["image"]) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_numpy_batch(self, params, input, expected): + input = {"image": input[cp.newaxis, ...]} + expected = expected[cp.newaxis, ...] + # apply_prob=1.0 + output = RandCuCIMd(keys="image", apply_prob=1.0, **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIMd(keys="image", apply_prob=0.0, **params)(input)["image"] + self.assertTrue(output.dtype == input["image"].dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, input["image"]) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_cupy_single(self, params, input, expected): + input = {"image": cp.asarray(input)} + expected = cp.asarray(expected) + # apply_prob=1.0 + output = RandCuCIMd(keys="image", apply_prob=1.0, **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIMd(keys="image", apply_prob=0.0, **params)(input)["image"] + self.assertTrue(output.dtype == input["image"].dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, input["image"]) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_cupy_batch(self, params, input, expected): + input = {"image": cp.asarray(input)[cp.newaxis, ...]} + expected = cp.asarray(expected)[cp.newaxis, ...] + # apply_prob=1.0 + output = RandCuCIMd(keys="image", **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIMd(keys="image", apply_prob=0.0, **params)(input)["image"] + self.assertTrue(output.dtype == input["image"].dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, input["image"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_cucim_transform.py b/tests/test_rand_cucim_transform.py new file mode 100644 index 0000000000..0950329833 --- /dev/null +++ b/tests/test_rand_cucim_transform.py @@ -0,0 +1,184 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandCuCIM +from monai.utils import optional_import, set_determinism +from tests.utils import HAS_CUPY, skip_if_no_cuda + +_, has_cut = optional_import("cucim.core.operations.expose.transform") +cp, _ = optional_import("cupy") + +set_determinism(seed=0) + +TEST_CASE_COLOR_JITTER_1 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8), + np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8), +] + +TEST_CASE_FLIP_1 = [ + {"name": "image_flip", "spatial_axis": -1}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_RAND_ROTATE_1 = [ + {"name": "rand_image_rotate_90", "prob": 1.0, "max_k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32), +] + + +TEST_CASE_RAND_ROTATE_2 = [ + {"name": "rand_image_rotate_90", "prob": 0.0, "max_k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), +] + +TEST_CASE_SCALE_INTENSITY_1 = [ + {"name": "scale_intensity_range", "a_min": 0.0, "a_max": 4.0, "b_min": 0.0, "b_max": 1.0, "clip": False}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32), +] + +TEST_CASE_ZOOM_1 = [ + {"name": "zoom", "zoom_factor": (0.5, 0.5)}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + +TEST_CASE_RAND_ZOOM_1 = [ + {"name": "rand_zoom", "prob": 1.0, "min_zoom": 0.5, "max_zoom": 0.5}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + +TEST_CASE_RAND_ZOOM_2 = [ + {"name": "rand_zoom", "prob": 0.0, "min_zoom": 0.5, "max_zoom": 0.5}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.mgrid[:3, 1:4].astype(dtype=np.float32), +] + + +@skip_if_no_cuda +@unittest.skipUnless(HAS_CUPY, "CuPy is required.") +@unittest.skipUnless(has_cut, "cuCIM transforms are required.") +class TestRandCuCIM(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_numpy_single(self, params, input, expected): + # apply_prob=1.0 + output = RandCuCIM(apply_prob=1.0, **params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIM(apply_prob=0.0, **params)(input) + self.assertTrue(output.dtype == input.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, input) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_numpy_batch(self, params, input, expected): + input = input[cp.newaxis, ...] + expected = expected[cp.newaxis, ...] + # apply_prob=1.0 + output = RandCuCIM(apply_prob=1.0, **params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIM(apply_prob=0.0, **params)(input) + self.assertTrue(output.dtype == input.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, input) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_cupy_single(self, params, input, expected): + input = cp.asarray(input) + expected = cp.asarray(expected) + # apply_prob=1.0 + output = RandCuCIM(apply_prob=1.0, **params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIM(apply_prob=0.0, **params)(input) + self.assertTrue(output.dtype == input.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, input) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_cupy_batch(self, params, input, expected): + input = cp.asarray(input)[cp.newaxis, ...] + expected = cp.asarray(expected)[cp.newaxis, ...] + # apply_prob=1.0 + output = RandCuCIM(**params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIM(apply_prob=0.0, **params)(input) + self.assertTrue(output.dtype == input.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, input) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_deform_grid.py b/tests/test_rand_deform_grid.py index 7c12c263d2..8a2c8bf6eb 100644 --- a/tests/test_rand_deform_grid.py +++ b/tests/test_rand_deform_grid.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,10 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import RandDeformGrid +from tests.utils import assert_allclose TEST_CASES = [ [ @@ -129,11 +129,7 @@ def test_rand_deform_grid(self, input_param, input_data, expected_val): g = RandDeformGrid(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, type_test=False, rtol=1e-3, atol=1e-3) if __name__ == "__main__": diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py index fbfb7d5761..bc23a6c5cb 100644 --- a/tests/test_rand_elastic_2d.py +++ b/tests/test_rand_elastic_2d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,90 +16,103 @@ from parameterized import parameterized from monai.transforms import Rand2DElastic +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2)}, - np.ones((3, 2, 2)), - ], - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "padding_mode": "zeros", - }, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2), "mode": "bilinear"}, - np.array( +_rtol = 5e-3 if is_tf32_env() else 1e-4 + +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + { + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "padding_mode": "zeros", + }, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2), "mode": "bilinear"}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), ] - ), - ], - [ - { - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + { + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), ] - ), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), ] - ), - ], -] + ) class TestRand2DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elastic(self, input_param, input_data, expected_val): g = Rand2DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index c63282d571..39ce779cb0 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,69 +16,79 @@ from parameterized import parameterized from monai.transforms import Rand3DElastic +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(72).reshape((2, 3, 3, 4))}, - np.arange(72).reshape((2, 3, 3, 4)), - ], - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.ones((2, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "mode": "bilinear"}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(72).reshape((2, 3, 3, 4)))}, + p(np.arange(72).reshape((2, 3, 3, 4))), + ] + ) + TESTS.append( + [ + {"magnitude_range": (0.3, 2.3), "sigma_range": (1.0, 20.0), "prob": 0.0, "device": device}, + {"img": p(torch.ones((2, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + {"magnitude_range": (0.3, 0.3), "sigma_range": (1.0, 2.0), "prob": 0.9, "device": device}, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "mode": "bilinear"}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) class TestRand3DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elastic(self, input_param, input_data, expected_val): g = Rand3DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-1, atol=1e-1) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index f8eb026088..ead39e5731 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,127 +16,149 @@ from parameterized import parameterized from monai.transforms import Rand2DElasticd +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.3, 0.3), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(4).reshape((1, 2, 2)), "seg": torch.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape((1, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "padding_mode": "zeros", - "device": None, - "spatial_size": (2, 2), - "mode": "bilinear", - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.array( +_rtol = 5e-3 if is_tf32_env() else 1e-4 + +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.3, 0.3), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(4).reshape((1, 2, 2))), "seg": p(torch.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape((1, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "padding_mode": "zeros", + "device": device, + "spatial_size": (2, 2), + "mode": "bilinear", + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), + ] + ) + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + { + "keys": ("img", "seg"), + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + { + "img": p( + torch.tensor( + [ + [[1.3584, 1.9251], [5.6266, 6.6427]], + [[10.3584, 10.9251], [14.6266, 15.6427]], + [[19.3584, 19.9251], [23.6266, 24.6427]], + ] + ) + ), + "seg": p( + torch.tensor( + [[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]] + ) + ), + }, ] - ), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - { - "img": torch.tensor( - [ - [[1.3584, 1.9251], [5.6266, 6.6427]], - [[10.3584, 10.9251], [14.6266, 15.6427]], - [[19.3584, 19.9251], [23.6266, 24.6427]], - ] - ), - "seg": torch.tensor([[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]]), - }, - ], -] + ) class TestRand2DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elasticd(self, input_param, input_data, expected_val): g = Rand2DElasticd(**input_param) g.set_random_state(123) @@ -144,11 +166,7 @@ def test_rand_2d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=_rtol, atol=5e-3) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 47ab814882..c78ed1f42e 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,98 +16,128 @@ from parameterized import parameterized from monai.transforms import Rand3DElasticd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, -1, -1), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 3, 3)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(8).reshape((1, 2, 2, 2)), "seg": torch.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - "mode": "bilinear", - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": True, - "device": torch.device("cpu:0"), - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - { - "img": torch.tensor([[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]]), - "seg": torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]]), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, -1, -1), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 3, 3))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(8).reshape((1, 2, 2, 2))), "seg": p(torch.arange(8).reshape((1, 2, 2, 2)))}, + p(np.arange(8).reshape((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + "mode": "bilinear", + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + { + "img": p( + torch.tensor( + [[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]] + ) + ), + "seg": p(torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]])), + }, + ] + ) class TestRand3DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elasticd(self, input_param, input_data, expected_val): g = Rand3DElasticd(**input_param) g.set_random_state(123) @@ -115,11 +145,7 @@ def test_rand_3d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-2, atol=1e-2) if __name__ == "__main__": diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index b3c514cb1f..b9e9a8c4d6 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,12 +34,10 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: im = p(self.imt[0]) flip = RandFlip(prob=1.0, spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) result = flip(im) - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index 8972024fd8..9a92661c59 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,11 +26,9 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: flip = RandFlipd(keys="img", prob=1.0, spatial_axis=spatial_axis) result = flip({"img": p(self.imt[0])})["img"] - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_rand_gaussian_noise.py b/tests/test_rand_gaussian_noise.py index d376add460..1f2adfb9e7 100644 --- a/tests/test_rand_gaussian_noise.py +++ b/tests/test_rand_gaussian_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_rand_gaussian_noised.py b/tests/test_rand_gaussian_noised.py index 4b0d2a311a..be1df0f2e6 100644 --- a/tests/test_rand_gaussian_noised.py +++ b/tests/test_rand_gaussian_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,14 +29,16 @@ class TestRandGaussianNoised(NumpyImageTestCase2D): @parameterized.expand(TESTS) def test_correct_results(self, _, im_type, keys, mean, std): - gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std) + gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64) gaussian_fn.set_random_state(seed) im = im_type(self.imt) noised = gaussian_fn({k: im for k in keys}) np.random.seed(seed) + # simulate the randomize() of transform np.random.random() + noise = np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) for k in keys: - expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) + expected = self.imt + noise self.assertEqual(type(im), type(noised[k])) if isinstance(noised[k], torch.Tensor): noised[k] = noised[k].cpu() diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py index 909f96f56b..06563a35b6 100644 --- a/tests/test_rand_gaussian_sharpen.py +++ b/tests/test_rand_gaussian_sharpen.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,88 +11,127 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import RandGaussianSharpen +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"prob": 1.0}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] + +for p in TEST_NDARRAYS: + TESTS.append( [ - [[5.2919216, 5.5854445, 5.29192], [11.3982, 12.62332, 11.398202], [14.870525, 17.323769, 14.870527]], - [[20.413757, 22.767355, 20.413757], [28.495504, 31.558315, 28.495499], [29.99236, 34.505676, 29.992361]], + {"prob": 1.0}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [5.2919216, 5.5854445, 5.29192], + [11.3982, 12.62332, 11.398202], + [14.870525, 17.323769, 14.870527], + ], + [ + [20.413757, 22.767355, 20.413757], + [28.495504, 31.558315, 28.495499], + [29.99236, 34.505676, 29.992361], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": 0.4, - "sigma2_y": 0.4, - "sigma2_z": 0.4, - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.1071496, 3.597953, 4.1071477], [10.062014, 9.825114, 10.0620165], [14.698058, 15.818766, 14.698058]], - [[18.211048, 18.16049, 18.211048], [25.155039, 24.56279, 25.155039], [28.801964, 30.381308, 28.801964]], + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": 0.4, + "sigma2_y": 0.4, + "sigma2_z": 0.4, + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.1071496, 3.597953, 4.1071477], + [10.062014, 9.825114, 10.0620165], + [14.698058, 15.818766, 14.698058], + ], + [ + [18.211048, 18.16049, 18.211048], + [25.155039, 24.56279, 25.155039], + [28.801964, 30.381308, 28.801964], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.81077, 4.4237204, 4.81077], [12.061236, 12.298177, 12.061236], [17.362553, 19.201174, 17.362553]], - [[21.440754, 22.142393, 21.440754], [30.15308, 30.745445, 30.153086], [33.99255, 36.919838, 33.99255]], + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.81077, 4.4237204, 4.81077], + [12.061236, 12.298177, 12.061236], + [17.362553, 19.201174, 17.362553], + ], + [ + [21.440754, 22.142393, 21.440754], + [30.15308, 30.745445, 30.153086], + [33.99255, 36.919838, 33.99255], + ], + ] + ), ] - ), -] + ) -TEST_CASE_4 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "approx": "scalespace", - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.430213, 3.2278745, 4.4302144], [10.325399, 8.507457, 10.325399], [17.494898, 16.5609, 17.494894]], - [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "approx": "scalespace", + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.430213, 3.2278745, 4.4302144], + [10.325399, 8.507457, 10.325399], + [17.494898, 16.5609, 17.494894], + ], + [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + ] + ), ] - ), -] + ) class TestRandGaussianSharpen(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): converter = RandGaussianSharpen(**argments) converter.set_random_state(seed=0) result = converter(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + assert_allclose(result, expected_data, atol=0, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_gaussian_sharpend.py b/tests/test_rand_gaussian_sharpend.py index 9ba29ee71b..ecffa547e0 100644 --- a/tests/test_rand_gaussian_sharpend.py +++ b/tests/test_rand_gaussian_sharpend.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,87 +15,126 @@ from parameterized import parameterized from monai.transforms import RandGaussianSharpend +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "img", "prob": 1.0}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [[5.2919216, 5.5854445, 5.29192], [11.3982, 12.62332, 11.398202], [14.870525, 17.323769, 14.870527]], - [[20.413757, 22.767355, 20.413757], [28.495504, 31.558315, 28.495499], [29.99236, 34.505676, 29.992361]], + {"keys": "img", "prob": 1.0}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [5.2919216, 5.5854445, 5.29192], + [11.3982, 12.62332, 11.398202], + [14.870525, 17.323769, 14.870527], + ], + [ + [20.413757, 22.767355, 20.413757], + [28.495504, 31.558315, 28.495499], + [29.99236, 34.505676, 29.992361], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - { - "keys": "img", - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": 0.4, - "sigma2_y": 0.4, - "sigma2_z": 0.4, - "prob": 1.0, - }, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[4.1071496, 3.597953, 4.1071477], [10.062014, 9.825114, 10.0620165], [14.698058, 15.818766, 14.698058]], - [[18.211048, 18.16049, 18.211048], [25.155039, 24.56279, 25.155039], [28.801964, 30.381308, 28.801964]], + { + "keys": "img", + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": 0.4, + "sigma2_y": 0.4, + "sigma2_z": 0.4, + "prob": 1.0, + }, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [4.1071496, 3.597953, 4.1071477], + [10.062014, 9.825114, 10.0620165], + [14.698058, 15.818766, 14.698058], + ], + [ + [18.211048, 18.16049, 18.211048], + [25.155039, 24.56279, 25.155039], + [28.801964, 30.381308, 28.801964], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - { - "keys": "img", - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "prob": 1.0, - }, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[4.81077, 4.4237204, 4.81077], [12.061236, 12.298177, 12.061236], [17.362553, 19.201174, 17.362553]], - [[21.440754, 22.142393, 21.440754], [30.15308, 30.745445, 30.153086], [33.99255, 36.919838, 33.99255]], + { + "keys": "img", + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "prob": 1.0, + }, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [4.81077, 4.4237204, 4.81077], + [12.061236, 12.298177, 12.061236], + [17.362553, 19.201174, 17.362553], + ], + [ + [21.440754, 22.142393, 21.440754], + [30.15308, 30.745445, 30.153086], + [33.99255, 36.919838, 33.99255], + ], + ] + ), ] - ), -] + ) -TEST_CASE_4 = [ - { - "keys": "img", - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "approx": "scalespace", - "prob": 1.0, - }, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[4.430213, 3.2278745, 4.4302144], [10.325399, 8.507457, 10.325399], [17.494898, 16.5609, 17.494894]], - [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + { + "keys": "img", + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "approx": "scalespace", + "prob": 1.0, + }, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [4.430213, 3.2278745, 4.4302144], + [10.325399, 8.507457, 10.325399], + [17.494898, 16.5609, 17.494894], + ], + [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + ] + ), ] - ), -] + ) class TestRandGaussianSharpend(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): converter = RandGaussianSharpend(**argments) converter.set_random_state(seed=0) result = converter(image) - np.testing.assert_allclose(result["img"], expected_data, rtol=1e-4) + assert_allclose(result["img"], expected_data, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_gaussian_smooth.py b/tests/test_rand_gaussian_smooth.py index 889ed7d6d5..d51618be95 100644 --- a/tests/test_rand_gaussian_smooth.py +++ b/tests/test_rand_gaussian_smooth.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,48 +15,81 @@ from parameterized import parameterized from monai.transforms import RandGaussianSmooth +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"sigma_x": (0.5, 1.5), "prob": 1.0}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [[0.71806467, 0.9074683, 0.71806467], [1.0718315, 1.3545481, 1.0718315], [1.0337002, 1.306359, 1.0337002]], - [[2.0318885, 2.5678391, 2.0318885], [2.6795788, 3.3863702, 2.6795788], [2.3475242, 2.9667296, 2.3475242]], + {"sigma_x": (0.5, 1.5), "prob": 1.0}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array( + [ + [ + [0.71806467, 0.9074683, 0.71806467], + [1.0718315, 1.3545481, 1.0718315], + [1.0337002, 1.306359, 1.0337002], + ], + [ + [2.0318885, 2.5678391, 2.0318885], + [2.6795788, 3.3863702, 2.6795788], + [2.3475242, 2.9667296, 2.3475242], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "prob": 1.0}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.7686928, 0.9848021, 0.7686928], [1.1474025, 1.4699818, 1.1474024], [1.1065826, 1.4176859, 1.1065826]], - [[2.1751494, 2.7866683, 2.1751497], [2.8685062, 3.6749542, 2.8685062], [2.5130394, 3.219552, 2.5130394]], + {"sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "prob": 1.0}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array( + [ + [ + [0.7686928, 0.9848021, 0.7686928], + [1.1474025, 1.4699818, 1.1474024], + [1.1065826, 1.4176859, 1.1065826], + ], + [ + [2.1751494, 2.7866683, 2.1751497], + [2.8685062, 3.6749542, 2.8685062], + [2.5130394, 3.219552, 2.5130394], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "approx": "scalespace", "prob": 1.0}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.8128456, 0.96736777, 0.8128456], [1.2742369, 1.5164697, 1.2742369], [1.2800367, 1.5233722, 1.2800368]], - [[2.3825073, 2.8354228, 2.3825073], [3.1855922, 3.7911744, 3.1855922], [2.8496985, 3.391427, 2.8496985]], + {"sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "approx": "scalespace", "prob": 1.0}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array( + [ + [ + [0.8128456, 0.96736777, 0.8128456], + [1.2742369, 1.5164697, 1.2742369], + [1.2800367, 1.5233722, 1.2800368], + ], + [ + [2.3825073, 2.8354228, 2.3825073], + [3.1855922, 3.7911744, 3.1855922], + [2.8496985, 3.391427, 2.8496985], + ], + ] + ), ] - ), -] + ) class TestRandGaussianSmooth(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): converter = RandGaussianSmooth(**argments) converter.set_random_state(seed=0) result = converter(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + assert_allclose(result, expected_data, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_gaussian_smoothd.py b/tests/test_rand_gaussian_smoothd.py index 2eedc9071c..e0ef0a8bb5 100644 --- a/tests/test_rand_gaussian_smoothd.py +++ b/tests/test_rand_gaussian_smoothd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,48 +15,81 @@ from parameterized import parameterized from monai.transforms import RandGaussianSmoothd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "img", "sigma_x": (0.5, 1.5), "prob": 1.0}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [[0.71806467, 0.9074683, 0.71806467], [1.0718315, 1.3545481, 1.0718315], [1.0337002, 1.306359, 1.0337002]], - [[2.0318885, 2.5678391, 2.0318885], [2.6795788, 3.3863702, 2.6795788], [2.3475242, 2.9667296, 2.3475242]], + {"keys": "img", "sigma_x": (0.5, 1.5), "prob": 1.0}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.71806467, 0.9074683, 0.71806467], + [1.0718315, 1.3545481, 1.0718315], + [1.0337002, 1.306359, 1.0337002], + ], + [ + [2.0318885, 2.5678391, 2.0318885], + [2.6795788, 3.3863702, 2.6795788], + [2.3475242, 2.9667296, 2.3475242], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"keys": "img", "sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "prob": 1.0}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[0.7686928, 0.9848021, 0.7686928], [1.1474025, 1.4699818, 1.1474024], [1.1065826, 1.4176859, 1.1065826]], - [[2.1751494, 2.7866683, 2.1751497], [2.8685062, 3.6749542, 2.8685062], [2.5130394, 3.219552, 2.5130394]], + {"keys": "img", "sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "prob": 1.0}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.7686928, 0.9848021, 0.7686928], + [1.1474025, 1.4699818, 1.1474024], + [1.1065826, 1.4176859, 1.1065826], + ], + [ + [2.1751494, 2.7866683, 2.1751497], + [2.8685062, 3.6749542, 2.8685062], + [2.5130394, 3.219552, 2.5130394], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"keys": "img", "sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "approx": "scalespace", "prob": 1.0}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[0.8128456, 0.96736777, 0.8128456], [1.2742369, 1.5164697, 1.2742369], [1.2800367, 1.5233722, 1.2800368]], - [[2.3825073, 2.8354228, 2.3825073], [3.1855922, 3.7911744, 3.1855922], [2.8496985, 3.391427, 2.8496985]], + {"keys": "img", "sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "approx": "scalespace", "prob": 1.0}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.8128456, 0.96736777, 0.8128456], + [1.2742369, 1.5164697, 1.2742369], + [1.2800367, 1.5233722, 1.2800368], + ], + [ + [2.3825073, 2.8354228, 2.3825073], + [3.1855922, 3.7911744, 3.1855922], + [2.8496985, 3.391427, 2.8496985], + ], + ] + ), ] - ), -] + ) class TestRandGaussianSmoothd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): converter = RandGaussianSmoothd(**argments) converter.set_random_state(seed=0) result = converter(image) - np.testing.assert_allclose(result["img"], expected_data, rtol=1e-4) + assert_allclose(result["img"], expected_data, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index a0701d09c3..fe928038da 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,17 +19,17 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestRandGibbsNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -39,50 +39,50 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return input_type(im) @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_0_prob(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] - t = RandGibbsNoise(0.0, alpha, as_tensor_output) + t = RandGibbsNoise(0.0, alpha) out = t(im) - np.testing.assert_allclose(im, out) + torch.testing.assert_allclose(im, out, rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] - t = RandGibbsNoise(1.0, alpha, as_tensor_output) + t = RandGibbsNoise(1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) - np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) + self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = RandGibbsNoise(1.0, alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(im, out, atol=1e-2) + torch.testing.assert_allclose(im, out, atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoise(1.0, alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(0 * im, out) + torch.testing.assert_allclose(0 * im, out, rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoise(1.0, alpha) _ = t(deepcopy(im)) diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index b778bffdda..8c5e045b90 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,19 +19,19 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) KEYS = ["im", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestRandGibbsNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -41,70 +41,76 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) + return {k: input_type(v) for k, v in zip(KEYS, ims)} @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_0_prob(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] - t = RandGibbsNoised(KEYS, 0.0, alpha, as_tensor_output) + t = RandGibbsNoised(KEYS, 0.0, alpha) out = t(data) for k in KEYS: - np.testing.assert_allclose(data[k], out[k]) + torch.testing.assert_allclose(data[k], out[k], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] - t = RandGibbsNoised(KEYS, 1.0, alpha, as_tensor_output) + t = RandGibbsNoised(KEYS, 1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(data)) t.set_random_state(42) out2 = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0) + self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() np.testing.assert_allclose(data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(0 * data[k], out[k]) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(0.0 * data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_dict_matches(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = [0.5, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) + torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoised(KEYS, 1.0, alpha) _ = t(deepcopy(data)) - self.assertGreaterEqual(t.sampled_alpha, 0.5) - self.assertLessEqual(t.sampled_alpha, 0.51) + self.assertTrue(0.5 <= t.rand_gibbs_noise.sampled_alpha <= 0.51) if __name__ == "__main__": diff --git a/tests/test_rand_grid_distortion.py b/tests/test_rand_grid_distortion.py new file mode 100644 index 0000000000..80f19df0db --- /dev/null +++ b/tests/test_rand_grid_distortion.py @@ -0,0 +1,94 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandGridDistortion +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + seed = 0 + TESTS.append( + [ + dict(num_cells=2, prob=1.0, distort_limit=0.5, mode="nearest", padding_mode="zeros"), + seed, + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 0.0], + [4.0, 4.0, 4.0, 4.0, 4.0, 0.0], + [4.0, 4.0, 4.0, 4.0, 4.0, 0.0], + [5.0, 5.0, 5.0, 5.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ).astype(np.float32) + ), + ] + ) + seed = 1 + TESTS.append( + [ + dict(num_cells=(2, 2), prob=1.0, distort_limit=0.1, mode="bilinear", padding_mode="reflection"), + seed, + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.5660975, 1.5660975, 1.5660975, 1.5660975, 1.5660974, 1.5660975], + [3.132195, 3.132195, 3.132195, 3.132195, 3.132195, 3.132195], + [3.132195, 3.132195, 3.132195, 3.132195, 3.132195, 3.132195], + [4.482229, 4.482229, 4.482229, 4.482229, 4.482229, 4.482229], + [4.167737, 4.167737, 4.167737, 4.167737, 4.167737, 4.167737], + ], + [ + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940266, 2.7880538, 2.7880538, 4.1657557, 4.456543], + ], + ] + ).astype(np.float32) + ), + ] + ) + + +class TestRandGridDistortion(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_grid_distortion(self, input_param, seed, input_data, expected_val): + g = RandGridDistortion(**input_param) + g.set_random_state(seed=seed) + result = g(input_data) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_grid_distortiond.py b/tests/test_rand_grid_distortiond.py new file mode 100644 index 0000000000..323848dc0b --- /dev/null +++ b/tests/test_rand_grid_distortiond.py @@ -0,0 +1,88 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandGridDistortiond +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +num_cells = 2 +seed = 0 +for p in TEST_NDARRAYS: + img = np.indices([6, 6]).astype(np.float32) + TESTS.append( + [ + dict( + keys=["img", "mask"], + num_cells=num_cells, + prob=1.0, + distort_limit=(-0.1, 0.1), + mode=["bilinear", "nearest"], + padding_mode="zeros", + ), + seed, + {"img": p(img), "mask": p(np.ones_like(img[:1]))}, + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.5645568, 1.5645568, 1.5645568, 1.5645568, 1.5645568, 0.0], + [3.1291137, 3.1291137, 3.1291137, 3.1291137, 3.1291137, 0.0], + [3.1291137, 3.1291137, 3.1291137, 3.1291137, 3.1291137, 0.0], + [4.6599426, 4.6599426, 4.6599426, 4.6599426, 4.6599426, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ).astype(np.float32) + ), + p( + np.array( + [ + [ + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ] + ) + ), + ] + ) + + +class TestRandGridDistortiond(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_grid_distortiond(self, input_param, seed, input_data, expected_val_img, expected_val_mask): + g = RandGridDistortiond(**input_param) + g.set_random_state(seed=seed) + result = g(input_data) + assert_allclose(result["img"], expected_val_img, rtol=1e-4, atol=1e-4) + assert_allclose(result["mask"], expected_val_mask, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py index b258cc5a7e..c66f7859c6 100644 --- a/tests/test_rand_histogram_shift.py +++ b/tests/test_rand_histogram_shift.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,33 +15,40 @@ from parameterized import parameterized from monai.transforms import RandHistogramShift - -TEST_CASES = [ - [ - {"num_control_points": 5, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - {"num_control_points": 5, "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)}, - np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]), - ], - [ - {"num_control_points": (5, 20), "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)}, - np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]), - ], -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"num_control_points": 5, "prob": 0.0}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + np.arange(8).reshape((1, 2, 2, 2)), + ] + ) + TESTS.append( + [ + {"num_control_points": 5, "prob": 0.9}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))}, + np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]), + ] + ) + TESTS.append( + [ + {"num_control_points": (5, 20), "prob": 0.9}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))}, + np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]), + ] + ) class TestRandHistogramShift(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_histogram_shift(self, input_param, input_data, expected_val): g = RandHistogramShift(**input_param) g.set_random_state(123) result = g(**input_data) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_histogram_shiftd.py b/tests/test_rand_histogram_shiftd.py index 806e4f5cf2..fe8ddf9ffd 100644 --- a/tests/test_rand_histogram_shiftd.py +++ b/tests/test_rand_histogram_shiftd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,47 +14,60 @@ import numpy as np from parameterized import parameterized -from monai.transforms import RandHistogramShiftD - -TEST_CASES = [ - [ - {"keys": ("img",), "num_control_points": 5, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - ], - [ - {"keys": ("img",), "num_control_points": 5, "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - { - "img": np.array( - [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]] - ), - "seg": np.ones(8).reshape((1, 2, 2, 2)), - }, - ], - [ - {"keys": ("img",), "num_control_points": (5, 20), "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - { - "img": np.array( - [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]] - ), - "seg": np.ones(8).reshape((1, 2, 2, 2)), - }, - ], -] +from monai.transforms.intensity.dictionary import RandHistogramShiftd +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ("img",), "num_control_points": 5, "prob": 0.0}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "seg": p(np.ones(8).reshape((1, 2, 2, 2)))}, + {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, + ] + ) + TESTS.append( + [ + {"keys": ("img",), "num_control_points": 5, "prob": 0.9}, + { + "img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)), + "seg": p(np.ones(8).reshape((1, 2, 2, 2))), + }, + { + "img": np.array( + [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]] + ), + "seg": np.ones(8).reshape((1, 2, 2, 2)), + }, + ] + ) + TESTS.append( + [ + {"keys": ("img",), "num_control_points": (5, 20), "prob": 0.9}, + { + "img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)), + "seg": p(np.ones(8).reshape((1, 2, 2, 2))), + }, + { + "img": np.array( + [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]] + ), + "seg": np.ones(8).reshape((1, 2, 2, 2)), + }, + ] + ) class TestRandHistogramShiftD(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_histogram_shiftd(self, input_param, input_data, expected_val): - g = RandHistogramShiftD(**input_param) + g = RandHistogramShiftd(**input_param) g.set_random_state(123) res = g(input_data) for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-4, atol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py index 71f7e36d9b..8027194555 100644 --- a/tests/test_rand_k_space_spike_noise.py +++ b/tests/test_rand_k_space_spike_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,18 +19,15 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise, RandKSpaceSpikeNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - for channel_wise in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input, channel_wise)) + for p in TEST_NDARRAYS: + for channel_wise in (True, False): + TESTS.append((shape, p, channel_wise)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestRandKSpaceSpikeNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -40,50 +37,68 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return im_type(im) - @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_0_prob(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 15] - t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise) out = t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(out.device, im.device) + im, out = im.cpu(), out.cpu() np.testing.assert_allclose(im, out) - @parameterized.expand(TEST_CASES) - def test_1_prob(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_1_prob(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 14] - t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise) out = t(im) - base_t = KSpaceSpikeNoise(t.sampled_locs, [14], as_tensor_output) + base_t = KSpaceSpikeNoise(t.sampled_locs, [14]) out = out - base_t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(out.device, im.device) + im, out = im.cpu(), out.cpu() np.testing.assert_allclose(out, im * 0) - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 15] - t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(out1.device, im.device) + out1, out2 = out1.cpu(), out2.cpu() np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_intensity(self, im_shape, _, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_intensity(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 14.1] t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise) _ = t(deepcopy(im)) self.assertGreaterEqual(t.sampled_k_intensity[0], 14) self.assertLessEqual(t.sampled_k_intensity[0], 14.1) + @parameterized.expand(TESTS) + def test_default_intensity(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) + t = RandKSpaceSpikeNoise(1.0, intensity_range=None, channel_wise=channel_wise) + out = t(deepcopy(im)) + self.assertEqual(out.shape, im.shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 1056ebf163..7a6a73b215 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,19 +19,16 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandKSpaceSpikeNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) KEYS = ["image", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -41,107 +38,53 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [im[None] for im in ims] - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) - - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - - data = self.get_data(im_shape, as_tensor_input) - - intensity_ranges = {"image": (13, 15), "label": (13, 15)} - t = RandKSpaceSpikeNoised( - KEYS, - global_prob=1.0, - prob=1.0, - intensity_ranges=intensity_ranges, - channel_wise=True, - as_tensor_output=as_tensor_output, - ) - t.set_rand_state(42) + ims = [im_type(im[None]) for im in ims] + return {k: v for k, v in zip(KEYS, ims)} + + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): + + data = self.get_data(im_shape, im_type) + + t = RandKSpaceSpikeNoised(KEYS, prob=1.0, intensity_range=(13, 15), channel_wise=True) + t.set_random_state(42) out1 = t(deepcopy(data)) - t.set_rand_state(42) + t.set_random_state(42) out2 = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() np.testing.assert_allclose(out1[k], out2[k], atol=1e-10) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) - - @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) - intensity_ranges = {"image": (13, 15), "label": (13, 15)} - t1 = RandKSpaceSpikeNoised( - KEYS, - global_prob=0.0, - prob=1.0, - intensity_ranges=intensity_ranges, - channel_wise=True, - as_tensor_output=as_tensor_output, - ) - - t2 = RandKSpaceSpikeNoised( - KEYS, - global_prob=0.0, - prob=1.0, - intensity_ranges=intensity_ranges, - channel_wise=True, - as_tensor_output=as_tensor_output, - ) + + @parameterized.expand(TESTS) + def test_0_prob(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) + + t1 = RandKSpaceSpikeNoised(KEYS, prob=0.0, intensity_range=(13, 15), channel_wise=True) + + t2 = RandKSpaceSpikeNoised(KEYS, prob=0.0, intensity_range=(13, 15), channel_wise=True) out1 = t1(data) out2 = t2(data) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() + data[k] = data[k].cpu() + np.testing.assert_allclose(data[k], out1[k]) np.testing.assert_allclose(data[k], out2[k]) - @parameterized.expand(TEST_CASES) - def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): - - data = self.get_data(im_shape, as_tensor_input) - intensity_ranges = {"image": (13, 13.1), "label": (13, 13.1)} - t = RandKSpaceSpikeNoised( - KEYS, - global_prob=1.0, - prob=1.0, - intensity_ranges=intensity_ranges, - channel_wise=True, - as_tensor_output=True, - ) - - _ = t(data) - self.assertGreaterEqual(t.transforms["image"].sampled_k_intensity[0], 13) - self.assertLessEqual(t.transforms["image"].sampled_k_intensity[0], 13.1) - self.assertGreaterEqual(t.transforms["label"].sampled_k_intensity[0], 13) - self.assertLessEqual(t.transforms["label"].sampled_k_intensity[0], 13.1) - - @parameterized.expand(TEST_CASES) - def test_same_transformation(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) - # use same image for both dictionary entries to check same trans is applied to them - data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} - - intensity_ranges = {"image": (13, 15), "label": (13, 15)} - # use common_sampling = True to ask for the same transformation - t = RandKSpaceSpikeNoised( - KEYS, - global_prob=1.0, - prob=1.0, - intensity_ranges=intensity_ranges, - channel_wise=True, - common_sampling=True, - as_tensor_output=True, - ) - - out = t(deepcopy(data)) - - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_lambda.py b/tests/test_rand_lambda.py index bf537883cf..043f44aec4 100644 --- a/tests/test_rand_lambda.py +++ b/tests/test_rand_lambda.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index 0a127839b8..854fef8879 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_rand_local_patch_shuffle.py b/tests/test_rand_local_patch_shuffle.py deleted file mode 100644 index 8e2eefb5d1..0000000000 --- a/tests/test_rand_local_patch_shuffle.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from parameterized import parameterized - -from monai.transforms import LocalPatchShuffling - -TEST_CASES = [ - [ - {"number_blocks": 10, "blocksize_ratio": 1, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - {"number_blocks": 10, "blocksize_ratio": 1, "prob": 1.0}, - {"img": np.arange(27).reshape((1, 3, 3, 3))}, - [ - [ - [[9, 1, 2], [3, 4, 5], [6, 7, 8]], - [[0, 10, 11], [12, 4, 14], [15, 16, 17]], - [[18, 19, 20], [21, 22, 23], [24, 25, 26]], - ] - ], - ], -] - - -class TestLocalPatchShuffle(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_local_patch_shuffle(self, input_param, input_data, expected_val): - g = LocalPatchShuffling(**input_param) - g.set_random_state(seed=12) - result = g(**input_data) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py index 7ec5fc4dc4..8e2ea1ee3a 100644 --- a/tests/test_rand_rician_noise.py +++ b/tests/test_rand_rician_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_rand_rician_noised.py b/tests/test_rand_rician_noised.py index 010bbcb310..05707059bc 100644 --- a/tests/test_rand_rician_noised.py +++ b/tests/test_rand_rician_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,11 +29,12 @@ class TestRandRicianNoisedNumpy(NumpyImageTestCase2D): @parameterized.expand(TESTS) def test_correct_results(self, _, in_type, keys, mean, std): - rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, mean=mean, std=std) + rician_fn = RandRicianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64) rician_fn.set_random_state(seed) noised = rician_fn({k: in_type(self.imt) for k in keys}) np.random.seed(seed) for k in keys: + # simulate the `randomize` function of transform np.random.random() _std = np.random.uniform(0, std) expected = np.sqrt( diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 0ff8508a0f..7a85fce23b 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,25 +10,60 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotate -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotate2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + + +class TestRandRotate2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotate( range_x=degrees, prob=1.0, @@ -36,9 +71,10 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor mode=mode, padding_mode=padding_mode, align_corners=align_corners, + dtype=np.float64, ) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) _order = 0 if mode == "nearest" else 1 if mode == "border": @@ -52,38 +88,14 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated[0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRandRotate3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ] - ) - def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotate( range_x=x, range_y=y, @@ -93,10 +105,11 @@ def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_cor mode=mode, padding_mode=padding_mode, align_corners=align_corners, + dtype=np.float64, ) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt[0]) - np.testing.assert_allclose(rotated.shape, expected) + rotated = rotate_fn(im_type(self.imt[0])) + torch.testing.assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 50a1b28e53..b845944062 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,49 +14,45 @@ import numpy as np from monai.transforms import RandRotate90 -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandRotate90(NumpyImageTestCase2D): def test_default(self): rotate = RandRotate90() - rotate.set_random_state(123) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(123) + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_k(self): rotate = RandRotate90(max_k=2) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(123) + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(123) + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_prob_k_spatial_axes(self): rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) if __name__ == "__main__": diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index a487b695f5..ded18e430a 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,53 +14,49 @@ import numpy as np from monai.transforms import RandRotate90d -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandRotate90d(NumpyImageTestCase2D): def test_default(self): key = None rotate = RandRotate90d(keys=key) - rotate.set_random_state(123) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(123) + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_k(self): key = "test" rotate = RandRotate90d(keys=key, max_k=2) - rotate.set_random_state(234) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_spatial_axes(self): key = "test" rotate = RandRotate90d(keys=key, spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_prob_k_spatial_axes(self): key = "test" rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_no_key(self): key = "unknown" diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py index 47b4b7107e..464b37d925 100644 --- a/tests/test_rand_rotated.py +++ b/tests/test_rand_rotated.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,26 +10,104 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotated from monai.utils import GridSampleMode, GridSamplePadMode -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotated2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append( + ( + p, + np.pi / 2, + -np.pi / 6, + (0.0, np.pi), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + False, + (1, 87, 104, 109), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + GridSampleMode.NEAREST, + GridSamplePadMode.ZEROS, + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + TEST_CASES_3D.append( + (p, (-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)) + ) + + +class TestRandRotated2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotated( "img", range_x=degrees, @@ -38,9 +116,10 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor mode=mode, padding_mode=padding_mode, align_corners=align_corners, + dtype=np.float64, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) _order = 0 if mode == "nearest" else 1 if padding_mode == "border": @@ -49,74 +128,20 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor _mode = "reflect" else: _mode = "constant" - angle = rotate_fn.x + angle = rotate_fn.rand_rotate.x expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v expected = np.stack(expected).astype(np.float32) good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRandRotated3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 2, - -np.pi / 6, - (0.0, np.pi), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - False, - (1, 87, 104, 109), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - GridSampleMode.NEAREST, - GridSamplePadMode.ZEROS, - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ((-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)), - ] - ) - def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotated( "img", range_x=x, @@ -127,9 +152,10 @@ def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corn mode=mode, padding_mode=padding_mode, align_corners=align_corners, + dtype=np.float64, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) np.testing.assert_allclose(rotated["img"].shape, expected) diff --git a/tests/test_rand_scale_crop.py b/tests/test_rand_scale_crop.py index db5487ebff..5d6312002f 100644 --- a/tests/test_rand_scale_crop.py +++ b/tests/test_rand_scale_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandScaleCrop +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -55,22 +56,25 @@ class TestRandScaleCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape): - result = RandScaleCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + for p in TEST_NDARRAYS: + result = RandScaleCrop(**input_param)(p(input_data)) + self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandScaleCrop(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + cropper = RandScaleCrop(**input_param) + result = cropper(p(input_data)) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandScaleCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result.shape, expected_shape) + for p in TEST_NDARRAYS: + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(seed=123) + result = cropper(p(input_data)) + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_scale_cropd.py b/tests/test_rand_scale_cropd.py index 265c6c467d..5e833fef98 100644 --- a/tests/test_rand_scale_cropd.py +++ b/tests/test_rand_scale_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandScaleCropd +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"keys": "img", "roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -23,7 +24,8 @@ ] TEST_CASE_2 = [ - {"keys": "img", "roi_scale": [1.0, 1.0, 1.0], "random_center": False}, + # test `allow_missing_keys` with key "label" + {"keys": ["label", "img"], "roi_scale": [1.0, 1.0, 1.0], "random_center": False, "allow_missing_keys": True}, {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, (3, 3, 3, 3), ] @@ -66,10 +68,14 @@ def test_shape(self, input_param, input_data, expected_shape): @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandScaleCropd(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + cropper = RandScaleCropd(**input_param) + input_data["img"] = p(input_data["img"]) + result = cropper(input_data) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose( + result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False + ) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index 750d88bfad..5aa5c7b964 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,8 +24,10 @@ def test_value(self): scaler.set_random_state(seed=0) result = scaler(p(self.imt)) np.random.seed(0) + # simulate the randomize() of transform + np.random.random() expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)) - assert_allclose(result, expected, rtol=1e-7, atol=0) + assert_allclose(result, p(expected), rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_scale_intensityd.py b/tests/test_rand_scale_intensityd.py index a8d2e63f65..655bd88ee0 100644 --- a/tests/test_rand_scale_intensityd.py +++ b/tests/test_rand_scale_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,14 +19,16 @@ class TestRandScaleIntensityd(NumpyImageTestCase2D): def test_value(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0) scaler.set_random_state(seed=0) result = scaler({key: p(self.imt)}) np.random.seed(0) + # simulate the randomize function of transform + np.random.random() expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) if __name__ == "__main__": diff --git a/tests/test_rand_shift_intensity.py b/tests/test_rand_shift_intensity.py index 4c4dd87dfe..b4f32a385a 100644 --- a/tests/test_rand_shift_intensity.py +++ b/tests/test_rand_shift_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,6 +23,8 @@ def test_value(self): shifter.set_random_state(seed=0) result = shifter(self.imt, factor=1.0) np.random.seed(0) + # simulate the randomize() of transform + np.random.random() expected = self.imt + np.random.uniform(low=-1.0, high=1.0) np.testing.assert_allclose(result, expected) diff --git a/tests/test_rand_shift_intensityd.py b/tests/test_rand_shift_intensityd.py index 6766236146..4d05149e3c 100644 --- a/tests/test_rand_shift_intensityd.py +++ b/tests/test_rand_shift_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,28 +14,33 @@ import numpy as np from monai.transforms import IntensityStatsd, RandShiftIntensityd +from monai.utils.enums import PostFix from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandShiftIntensityd(NumpyImageTestCase2D): def test_value(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" shifter = RandShiftIntensityd(keys=[key], offsets=1.0, prob=1.0) shifter.set_random_state(seed=0) result = shifter({key: p(self.imt)}) np.random.seed(0) + # simulate the randomize() of transform + np.random.random() expected = self.imt + np.random.uniform(low=-1.0, high=1.0) - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) def test_factor(self): key = "img" stats = IntensityStatsd(keys=key, ops="max", key_prefix="orig") shifter = RandShiftIntensityd(keys=[key], offsets=1.0, factor_key=["orig_max"], prob=1.0) - data = {key: self.imt, key + "_meta_dict": {"affine": None}} + data = {key: self.imt, PostFix.meta(key): {"affine": None}} shifter.set_random_state(seed=0) result = shifter(stats(data)) np.random.seed(0) + # simulate the randomize() of transform + np.random.random() expected = self.imt + np.random.uniform(low=-1.0, high=1.0) * np.nanmax(self.imt) np.testing.assert_allclose(result[key], expected) diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index 01e057e589..8f4bb0fffa 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCrop +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_0 = [ {"roi_size": [3, 3, -1], "random_center": True}, @@ -56,10 +57,11 @@ def test_shape(self, input_param, input_data, expected_shape): @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandSpatialCrop(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + cropper = RandSpatialCrop(**input_param) + result = cropper(p(input_data)) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_random_shape(self, input_param, input_data, expected_shape): diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index 0ade9bbbba..18fdf38773 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropSamples +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False}, @@ -70,14 +71,15 @@ class TestRandSpatialCropSamples(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape, expected_last_item): - xform = RandSpatialCropSamples(**input_param) - xform.set_random_state(1234) - result = xform(input_data) + for p in TEST_NDARRAYS: + xform = RandSpatialCropSamples(**input_param) + xform.set_random_state(1234) + result = xform(p(input_data)) - np.testing.assert_equal(len(result), input_param["num_samples"]) - for item, expected in zip(result, expected_shape): - self.assertTupleEqual(item.shape, expected) - np.testing.assert_allclose(result[-1], expected_last_item) + np.testing.assert_equal(len(result), input_param["num_samples"]) + for item, expected in zip(result, expected_shape): + self.assertTupleEqual(item.shape, expected) + assert_allclose(result[-1], expected_last_item, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index 3f5eee7b27..0891068488 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,8 @@ from parameterized import parameterized from monai.transforms import Compose, RandSpatialCropSamplesd, ToTensord +from monai.utils.enums import PostFix +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True}, @@ -38,31 +40,48 @@ }, ] -TEST_CASE_2 = [ - {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, - {"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)}, - [(3, 3, 3, 3), (3, 2, 3, 3), (3, 2, 2, 3), (3, 2, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 2, 2, 3), (3, 3, 2, 3)], - { - "img": np.array( +TEST_CASE_2 = [] +for p in TEST_NDARRAYS: + TEST_CASE_2.append( + [ + {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, + {"img": p(np.arange(81).reshape(3, 3, 3, 3)), "seg": p(np.arange(81, 0, -1).reshape(3, 3, 3, 3))}, [ - [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]], - [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]], - [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]], - ] - ), - "seg": np.array( - [ - [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]], - [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]], - [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]], - ] - ), - }, -] + (3, 3, 3, 3), + (3, 2, 3, 3), + (3, 2, 2, 3), + (3, 2, 3, 3), + (3, 3, 3, 3), + (3, 3, 3, 3), + (3, 2, 2, 3), + (3, 3, 2, 3), + ], + { + "img": p( + np.array( + [ + [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]], + [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]], + [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]], + ] + ) + ), + "seg": p( + np.array( + [ + [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]], + [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]], + [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]], + ] + ) + ), + }, + ] + ) class TestRandSpatialCropSamplesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, *TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape, expected_last): xform = RandSpatialCropSamplesd(**input_param) xform.set_random_state(1234) @@ -71,20 +90,16 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last): self.assertTupleEqual(item["img"].shape, expected) self.assertTupleEqual(item["seg"].shape, expected) for i, item in enumerate(result): - self.assertEqual(item["img_meta_dict"]["patch_index"], i) - self.assertEqual(item["seg_meta_dict"]["patch_index"], i) - np.testing.assert_allclose(item["img"], expected_last["img"]) - np.testing.assert_allclose(item["seg"], expected_last["seg"]) + self.assertEqual(item[PostFix.meta("img")]["patch_index"], i) + self.assertEqual(item[PostFix.meta("seg")]["patch_index"], i) + assert_allclose(item["img"], expected_last["img"], type_test=True) + assert_allclose(item["seg"], expected_last["seg"], type_test=True) def test_deep_copy(self): data = {"img": np.ones((1, 10, 11, 12))} num_samples = 3 sampler = RandSpatialCropSamplesd( - keys=["img"], - roi_size=(3, 3, 3), - num_samples=num_samples, - random_center=True, - random_size=False, + keys=["img"], roi_size=(3, 3, 3), num_samples=num_samples, random_center=True, random_size=False ) transform = Compose([ToTensord(keys="img"), sampler]) samples = transform(data) diff --git a/tests/test_rand_spatial_cropd.py b/tests/test_rand_spatial_cropd.py index 610c1974aa..9e6e86eea2 100644 --- a/tests/test_rand_spatial_cropd.py +++ b/tests/test_rand_spatial_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropd +from tests.utils import TEST_NDARRAYS TEST_CASE_0 = [ {"keys": "img", "roi_size": [3, 3, -1], "random_center": True}, @@ -67,10 +68,12 @@ def test_value(self, input_param, input_data): @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandSpatialCropd(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + for p in TEST_NDARRAYS: + cropper = RandSpatialCropd(**input_param) + cropper.set_random_state(seed=123) + input_data["img"] = p(input_data["img"]) + result = cropper(input_data) + self.assertTupleEqual(result["img"].shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_std_shift_intensity.py b/tests/test_rand_std_shift_intensity.py index 0c6382555e..fdf386fee4 100644 --- a/tests/test_rand_std_shift_intensity.py +++ b/tests/test_rand_std_shift_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,6 +22,8 @@ class TestRandStdShiftIntensity(NumpyImageTestCase2D): def test_value(self): for p in TEST_NDARRAYS: np.random.seed(0) + # simulate the randomize() of transform + np.random.random() factor = np.random.uniform(low=-1.0, high=1.0) offset = factor * np.std(self.imt) expected = p(self.imt + offset) diff --git a/tests/test_rand_std_shift_intensityd.py b/tests/test_rand_std_shift_intensityd.py index 0ab017a42d..e98d1e3ad3 100644 --- a/tests/test_rand_std_shift_intensityd.py +++ b/tests/test_rand_std_shift_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,6 +23,8 @@ def test_value(self): for p in TEST_NDARRAYS: key = "img" np.random.seed(0) + # simulate the randomize() of transform + np.random.random() factor = np.random.uniform(low=-1.0, high=1.0) expected = self.imt + factor * np.std(self.imt) shifter = RandStdShiftIntensityd(keys=[key], factors=1.0, prob=1.0) diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index 39a9439122..dae7f05016 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,127 +12,159 @@ import unittest import numpy as np +import torch +from parameterized.parameterized import parameterized from monai.transforms.croppad.array import RandWeightedCrop -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose -class TestRandWeightedCrop2D(NumpyImageTestCase2D): - def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCrop((10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[80, 21], [30, 17], [40, 31]]) +def get_data(ndim): + im_gen = NumpyImageTestCase2D() if ndim == 2 else NumpyImageTestCase3D() + im_gen.setUp() + return im_gen.imt[0], im_gen.seg1[0], im_gen.segn[0] + + +IMT_2D, SEG1_2D, SEGN_2D = get_data(ndim=2) +IMT_3D, SEG1_3D, SEGN_3D = get_data(ndim=3) + - def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((10, -1), n_samples) - weight = np.zeros_like(img) +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + im = SEG1_2D + weight = np.zeros_like(im) weight[0, 30, 17] = 1.1 weight[0, 40, 31] = 1 weight[0, 80, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32], [105, 32], [20, 32]]) - - def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCrop((10000, 400), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "small roi 2d", + dict(spatial_size=(10, 12), num_samples=3), + p(im), + q(weight), + (1, 10, 12), + [[80, 21], [30, 17], [40, 31]], + ] + ) + im = IMT_2D + TESTS.append( + [ + "default roi 2d", + dict(spatial_size=(10, -1), num_samples=3), + p(im), + q(weight), + (1, 10, 64), + [[14, 32], [105, 32], [20, 32]], + ] + ) + im = SEGN_2D + weight = np.zeros_like(im) weight[0, 30, 17] = 1.1 weight[0, 10, 1] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 128, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[64, 32], [64, 32], [64, 32]]) - for res in result: - np.testing.assert_allclose(res, self.segn[0]) - - def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((20, 40), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "large roi 2d", + dict(spatial_size=(10000, 400), num_samples=3), + p(im), + q(weight), + (1, 128, 64), + [[64, 32], [64, 32], [64, 32]], + ] + ) + im = IMT_2D + weight = np.zeros_like(im) weight[0, 30, 17] = np.inf weight[0, 10, 1] = -np.inf weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 20, 40)) - np.testing.assert_allclose(np.asarray(crop.centers), [[63, 37], [31, 43], [66, 20]]) - - -class TestRandWeightedCrop(NumpyImageTestCase3D): - def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCrop((8, 10, 12), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "bad w 2d", + dict(spatial_size=(20, 40), num_samples=3), + p(im), + q(weight), + (1, 20, 40), + [[63, 37], [31, 43], [66, 20]], + ] + ) + im = SEG1_3D + weight = np.zeros_like(im) weight[0, 5, 30, 17] = 1.1 weight[0, 8, 40, 31] = 1 weight[0, 11, 23, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 8, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[11, 23, 21], [5, 30, 17], [8, 40, 31]]) - - def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((10, -1, -1), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "small roi 3d", + dict(spatial_size=(8, 10, 12), num_samples=3), + p(im), + q(weight), + (1, 8, 10, 12), + [[11, 23, 21], [5, 30, 17], [8, 40, 31]], + ] + ) + im = IMT_3D + weight = np.zeros_like(im) weight[0, 7, 17] = 1.1 weight[0, 13, 31] = 1.1 weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) - - def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCrop((10000, 400, 80), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "default roi 3d", + dict(spatial_size=(10, -1, -1), num_samples=3), + p(im), + q(weight), + (1, 10, 64, 80), + [[14, 32, 40], [41, 32, 40], [20, 32, 40]], + ] + ) + im = SEGN_3D + weight = np.zeros_like(im) weight[0, 30, 17, 20] = 1.1 weight[0, 10, 1, 17] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) - for res in result: - np.testing.assert_allclose(res, self.segn[0]) - - def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((48, 64, 80), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "large roi 3d", + dict(spatial_size=(10000, 400, 80), num_samples=3), + p(im), + q(weight), + (1, 48, 64, 80), + [[24, 32, 40], [24, 32, 40], [24, 32, 40]], + ] + ) + im = IMT_3D + weight = np.zeros_like(im) weight[0, 30, 17] = np.inf weight[0, 10, 1] = -np.inf weight[0, 10, 20] = -np.nan + TESTS.append( + [ + "bad w 3d", + dict(spatial_size=(48, 64, 80), num_samples=3), + p(im), + q(weight), + (1, 48, 64, 80), + [[24, 32, 40], [24, 32, 40], [24, 32, 40]], + ] + ) + + +class TestRandWeightedCrop(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, expected_vals): + crop = RandWeightedCrop(**input_params) crop.set_random_state(10) result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + self.assertTrue(len(result) == input_params["num_samples"]) + assert_allclose(result[0].shape, expected_shape) + for c, e in zip(crop.centers, expected_vals): + assert_allclose(c, e, type_test=False) + # if desired ROI is larger than image, check image is unchanged + if all(s >= i for i, s in zip(img.shape[1:], input_params["spatial_size"])): + for res in result: + self.assertEqual(type(img), type(res)) + if isinstance(img, torch.Tensor): + self.assertEqual(res.device, img.device) + assert_allclose(res, img) if __name__ == "__main__": diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 367ce3beb9..a357398f1c 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,148 +14,178 @@ import numpy as np from monai.transforms.croppad.dictionary import RandWeightedCropd -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from monai.utils.enums import PostFix +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose class TestRandWeightedCrop(NumpyImageTestCase2D): def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - d = {"img": img, "w": weight} - result = crop(d) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[80, 21], [30, 17], [40, 31]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.seg1[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (10, 12), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + crop.set_random_state(10) + d = {"img": p(img), "w": q(weight)} + result = crop(d) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) + for c, e in zip(crop.centers, [[80, 21], [30, 17], [40, 31]]): + assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - data = {"im": img, "weight": weight, "others": np.nan} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32], [105, 32], [20, 32]]) - np.testing.assert_allclose(result[1]["coords"], [105, 32]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + crop.set_random_state(10) + data = {"im": p(img), "weight": q(weight), "others": np.nan} + result = crop(data) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) + for c, e in zip(crop.centers, [[14, 32], [105, 32], [20, 32]]): + assert_allclose(c, e, type_test=False) + assert_allclose(result[1]["coords"], [105, 32], type_test=False) def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 10, 1] = 1 - crop.set_random_state(10) - data = {"img": img, "seg": self.imt[0], "weight": weight} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[64, 32], [64, 32], [64, 32]]) - np.testing.assert_allclose(result[1]["location"], [64, 32]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.segn[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 10, 1] = 1 + crop.set_random_state(10) + data = {"img": p(img), "seg": p(self.imt[0]), "weight": q(weight)} + result = crop(data) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) + for c, e in zip(crop.centers, [[64, 32], [64, 32], [64, 32]]): + assert_allclose(c, e, type_test=False) + assert_allclose(result[1]["location"], [64, 32], type_test=False) def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) - np.testing.assert_allclose(np.asarray(crop.centers), [[63, 37], [31, 43], [66, 20]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) + for c, e in zip(crop.centers, [[63, 37], [31, 43], [66, 20]]): + assert_allclose(c, e, type_test=False) class TestRandWeightedCrop3D(NumpyImageTestCase3D): def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 5, 30, 17] = 1.1 - weight[0, 8, 40, 31] = 1 - weight[0, 11, 23, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[11, 23, 21], [5, 30, 17], [8, 40, 31]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.seg1[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) + weight = np.zeros_like(img) + weight[0, 5, 30, 17] = 1.1 + weight[0, 8, 40, 31] = 1 + weight[0, 11, 23, 21] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) + for c, e in zip(crop.centers, [[11, 23, 21], [5, 30, 17], [8, 40, 31]]): + assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) + weight = np.zeros_like(img) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) + for c, e in zip(crop.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): + assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17, 20] = 1.1 - weight[0, 10, 1, 17] = 1 - crop.set_random_state(10) - result = crop({"img": img, "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.segn[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17, 20] = 1.1 + weight[0, 10, 1, 17] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) + for c, e in zip(crop.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): + assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) + for c, e in zip(crop.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): + assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_patch_index(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight, "img_meta_dict": {"affine": None}}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) - for i in range(n_samples): - np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["img_meta_dict"]["patch_index"], i) - np.testing.assert_allclose(result[i]["seg_meta_dict"]["patch_index"], i) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) + weight = np.zeros_like(img) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + crop.set_random_state(10) + result = crop( + {"img": p(img), "seg": p(self.segn[0]), "w": q(weight), PostFix.meta("img"): {"affine": None}} + ) + self.assertTrue(len(result) == n_samples) + for c, e in zip(crop.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): + assert_allclose(c, e, type_test=False) + for i in range(n_samples): + np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[i][PostFix.meta("img")]["patch_index"], i) + np.testing.assert_allclose(result[i][PostFix.meta("seg")]["patch_index"], i) if __name__ == "__main__": diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py index c21bc8b9e9..35472024ef 100644 --- a/tests/test_rand_zoom.py +++ b/tests/test_rand_zoom.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,7 @@ from monai.transforms import RandZoom from monai.utils import GridSampleMode, InterpolateMode -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(0.8, 1.2, "nearest", False), (0.8, 1.2, InterpolateMode.NEAREST, False)] @@ -25,36 +25,28 @@ class TestRandZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, mode, keep_size): - random_zoom = RandZoom( - prob=1.0, - min_zoom=min_zoom, - max_zoom=max_zoom, - mode=mode, - keep_size=keep_size, - ) - random_zoom.set_random_state(1234) - zoomed = random_zoom(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(zoomed, expected, atol=1.0) + for p in TEST_NDARRAYS: + random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode, keep_size=keep_size) + random_zoom.set_random_state(1234) + zoomed = random_zoom(p(self.imt[0])) + expected = [ + zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False) + for channel in self.imt[0] + ] + + expected = np.stack(expected).astype(np.float32) + assert_allclose(zoomed, p(expected), atol=1.0) def test_keep_size(self): - random_zoom = RandZoom( - prob=1.0, - min_zoom=0.6, - max_zoom=0.7, - keep_size=True, - padding_mode="constant", - constant_values=2, - ) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + im = p(self.imt[0]) + random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) @parameterized.expand( [ @@ -64,23 +56,19 @@ def test_keep_size(self): ] ) def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises): - with self.assertRaises(raises): - random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) - random_zoom(self.imt[0]) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) + random_zoom(p(self.imt[0])) def test_auto_expand_3d(self): - random_zoom = RandZoom( - prob=1.0, - min_zoom=[0.8, 0.7], - max_zoom=[1.2, 1.3], - mode="nearest", - keep_size=False, - ) - random_zoom.set_random_state(1234) - test_data = np.random.randint(0, 2, size=[2, 2, 3, 4]) - zoomed = random_zoom(test_data) - np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) - np.testing.assert_allclose(zoomed.shape, (2, 2, 3, 3)) + for p in TEST_NDARRAYS: + random_zoom = RandZoom(prob=1.0, min_zoom=[0.8, 0.7], max_zoom=[1.2, 1.3], mode="nearest", keep_size=False) + random_zoom.set_random_state(1234) + test_data = p(np.random.randint(0, 2, size=[2, 2, 3, 4])) + zoomed = random_zoom(test_data) + assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) + assert_allclose(zoomed.shape, (2, 2, 3, 3)) if __name__ == "__main__": diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py index 4ccb1aad64..a22f2f36f1 100644 --- a/tests/test_rand_zoomd.py +++ b/tests/test_rand_zoomd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from scipy.ndimage import zoom as zoom_scipy from monai.transforms import RandZoomd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(0.8, 1.2, "nearest", None, False)] @@ -34,52 +34,47 @@ def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_siz align_corners=align_corners, keep_size=keep_size, ) - random_zoom.set_random_state(1234) + for p in TEST_NDARRAYS: + random_zoom.set_random_state(1234) - zoomed = random_zoom({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(expected, zoomed[key], atol=1.0) + zoomed = random_zoom({key: p(self.imt[0])}) + expected = [ + zoom_scipy(channel, zoom=random_zoom.rand_zoom._zoom, mode="nearest", order=0, prefilter=False) + for channel in self.imt[0] + ] + + expected = np.stack(expected).astype(np.float32) + assert_allclose(zoomed[key], p(expected), atol=1.0) def test_keep_size(self): key = "img" random_zoom = RandZoomd( - keys=key, - prob=1.0, - min_zoom=0.6, - max_zoom=0.7, - keep_size=True, - padding_mode="constant", - constant_values=2, + keys=key, prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True, padding_mode="constant", constant_values=2 ) - zoomed = random_zoom({key: self.imt[0]}) - self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + zoomed = random_zoom({key: p(self.imt[0])}) + np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:]) @parameterized.expand( [("no_min_zoom", None, 1.1, "bilinear", TypeError), ("invalid_order", 0.9, 1.1, "s", ValueError)] ) def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises): key = "img" - with self.assertRaises(raises): - random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) - random_zoom({key: self.imt[0]}) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) + random_zoom({key: p(self.imt[0])}) def test_auto_expand_3d(self): random_zoom = RandZoomd( - keys="img", - prob=1.0, - min_zoom=[0.8, 0.7], - max_zoom=[1.2, 1.3], - mode="nearest", - keep_size=False, + keys="img", prob=1.0, min_zoom=[0.8, 0.7], max_zoom=[1.2, 1.3], mode="nearest", keep_size=False ) - random_zoom.set_random_state(1234) - test_data = {"img": np.random.randint(0, 2, size=[2, 2, 3, 4])} - zoomed = random_zoom(test_data) - np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) - np.testing.assert_allclose(zoomed["img"].shape, (2, 2, 3, 3)) + for p in TEST_NDARRAYS: + random_zoom.set_random_state(1234) + test_data = {"img": p(np.random.randint(0, 2, size=[2, 2, 3, 4]))} + zoomed = random_zoom(test_data) + assert_allclose(random_zoom.rand_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) + assert_allclose(zoomed["img"].shape, (2, 2, 3, 3)) if __name__ == "__main__": diff --git a/tests/test_randomizable.py b/tests/test_randomizable.py index 9972bded0f..7445287a12 100644 --- a/tests/test_randomizable.py +++ b/tests/test_randomizable.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_randtorchvisiond.py b/tests/test_randtorchvisiond.py index d0485ce405..2e96d723ee 100644 --- a/tests/test_randtorchvisiond.py +++ b/tests/test_randtorchvisiond.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,19 +29,10 @@ {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])}, torch.tensor( [ - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - ], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + ] ), ] @@ -50,24 +41,9 @@ {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])}, torch.tensor( [ - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], ] ), ] diff --git a/tests/test_reference_resolver.py b/tests/test_reference_resolver.py new file mode 100644 index 0000000000..e6b01c05f4 --- /dev/null +++ b/tests/test_reference_resolver.py @@ -0,0 +1,110 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +import monai +from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from monai.bundle.reference_resolver import ReferenceResolver +from monai.data import DataLoader +from monai.transforms import LoadImaged, RandTorchVisiond +from monai.utils import min_version, optional_import + +_, has_tv = optional_import("torchvision", "0.8.0", min_version) + +# test instance with no dependencies +TEST_CASE_1 = [ + { + # all the recursively parsed config items + "transform#1": {"_target_": "LoadImaged", "keys": ["image"]}, + "transform#1#_target_": "LoadImaged", + "transform#1#keys": ["image"], + "transform#1#keys#0": "image", + }, + "transform#1", + LoadImaged, +] +# test depends on other component and executable code +TEST_CASE_2 = [ + { + # some the recursively parsed config items + "dataloader": {"_target_": "DataLoader", "dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"}, + "dataset": {"_target_": "Dataset", "data": [1, 2]}, + "dataloader#_target_": "DataLoader", + "dataloader#dataset": "@dataset", + "dataloader#collate_fn": "$monai.data.list_data_collate", + "dataset#_target_": "Dataset", + "dataset#data": [1, 2], + "dataset#data#0": 1, + "dataset#data#1": 2, + }, + "dataloader", + DataLoader, +] +# test config has key `name` +TEST_CASE_3 = [ + { + # all the recursively parsed config items + "transform#1": {"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25}, + "transform#1#_target_": "RandTorchVisiond", + "transform#1#keys": "image", + "transform#1#name": "ColorJitter", + "transform#1#brightness": 0.25, + }, + "transform#1", + RandTorchVisiond, +] + + +class TestReferenceResolver(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2] + ([TEST_CASE_3] if has_tv else [])) + def test_resolve(self, configs, expected_id, output_type): + locator = ComponentLocator() + resolver = ReferenceResolver() + # add items to resolver + for k, v in configs.items(): + if ConfigComponent.is_instantiable(v): + resolver.add_item(ConfigComponent(config=v, id=k, locator=locator)) + elif ConfigExpression.is_expression(v): + resolver.add_item(ConfigExpression(config=v, id=k, globals={"monai": monai, "torch": torch})) + else: + resolver.add_item(ConfigItem(config=v, id=k)) + + result = resolver.get_resolved_content(expected_id) # the root id is `expected_id` here + self.assertTrue(isinstance(result, output_type)) + + # test lazy instantiation + item = resolver.get_item(expected_id, resolve=True) + config = item.get_config() + config["_disabled_"] = False + item.update_config(config=config) + if isinstance(item, ConfigComponent): + result = item.instantiate() + else: + result = item.get_config() + self.assertTrue(isinstance(result, output_type)) + + def test_circular_references(self): + locator = ComponentLocator() + resolver = ReferenceResolver() + configs = {"A": "@B", "B": "@C", "C": "@A"} + for k, v in configs.items(): + resolver.add_item(ConfigComponent(config=v, id=k, locator=locator)) + for k in ["A", "B", "C"]: + with self.assertRaises(ValueError): + resolver.get_resolved_content(k) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index b864a64647..822a056879 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,25 +17,15 @@ from parameterized import parameterized from monai.losses import BendingEnergyLoss, GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss +from tests.utils import SkipIfBeforePyTorchVersion TEST_CASES = [ - [BendingEnergyLoss, {}, ["pred"]], - [ - LocalNormalizedCrossCorrelationLoss, - {"kernel_size": 7, "kernel_type": "rectangular"}, - ["pred", "target"], - ], - [ - LocalNormalizedCrossCorrelationLoss, - {"kernel_size": 5, "kernel_type": "triangular"}, - ["pred", "target"], - ], - [ - LocalNormalizedCrossCorrelationLoss, - {"kernel_size": 3, "kernel_type": "gaussian"}, - ["pred", "target"], - ], + [BendingEnergyLoss, {}, ["pred"], 3], + [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 7, "kernel_type": "rectangular"}, ["pred", "target"]], + [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 5, "kernel_type": "triangular"}, ["pred", "target"]], + [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 3, "kernel_type": "gaussian"}, ["pred", "target"]], [GlobalMutualInformationLoss, {"num_bins": 10}, ["pred", "target"]], + [GlobalMutualInformationLoss, {"kernel_type": "b-spline", "num_bins": 10}, ["pred", "target"]], ] @@ -51,7 +41,8 @@ def tearDown(self): torch.backends.cudnn.benchmark = True @parameterized.expand(TEST_CASES) - def test_convergence(self, loss_type, loss_args, forward_args): + @SkipIfBeforePyTorchVersion((1, 9)) + def test_convergence(self, loss_type, loss_args, forward_args, pred_channels=1): """ The goal of this test is to assess if the gradient of the loss function is correct by testing if we can train a one layer neural network @@ -69,11 +60,11 @@ def test_convergence(self, loss_type, loss_args, forward_args): # define a one layer model class OnelayerNet(nn.Module): def __init__(self): - super(OnelayerNet, self).__init__() + super().__init__() self.layer = nn.Sequential( nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), nn.ReLU(), - nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), + nn.Conv3d(in_channels=1, out_channels=pred_channels, kernel_size=3, padding=1), ) def forward(self, x): diff --git a/tests/test_regunet.py b/tests/test_regunet.py index 4dd968a1cf..e37ca49538 100644 --- a/tests/test_regunet.py +++ b/tests/test_regunet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_regunet_block.py b/tests/test_regunet_block.py index 9b96875432..3be02ea377 100644 --- a/tests/test_regunet_block.py +++ b/tests/test_regunet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py index ebbe6c730c..39b42cc4b0 100644 --- a/tests/test_remove_repeated_channel.py +++ b/tests/test_remove_repeated_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_remove_repeated_channeld.py b/tests/test_remove_repeated_channeld.py index 9d4812791e..9db66a6aa0 100644 --- a/tests/test_remove_repeated_channeld.py +++ b/tests/test_remove_repeated_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py index e246dd1212..3d74b6479c 100644 --- a/tests/test_repeat_channel.py +++ b/tests/test_repeat_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_repeat_channeld.py b/tests/test_repeat_channeld.py index 3b73962bb9..a348f3eea9 100644 --- a/tests/test_repeat_channeld.py +++ b/tests/test_repeat_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_require_pkg.py b/tests/test_require_pkg.py new file mode 100644 index 0000000000..dbe63fff3f --- /dev/null +++ b/tests/test_require_pkg.py @@ -0,0 +1,77 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.utils import OptionalImportError, min_version, require_pkg + + +class TestRequirePkg(unittest.TestCase): + def test_class(self): + @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version) + class TestClass: + pass + + TestClass() + + def test_function(self): + @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version) + def test_func(x): + return x + + test_func(x=None) + + def test_warning(self): + @require_pkg(pkg_name="test123", raise_error=False) + def test_func(x): + return x + + test_func(x=None) + + def test_class_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="test123") + class TestClass: + pass + + TestClass() + + def test_class_version_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="torch", version="10000", version_checker=min_version) + class TestClass: + pass + + TestClass() + + def test_func_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="test123") + def test_func(x): + return x + + test_func(x=None) + + def test_func_versions_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="torch", version="10000", version_checker=min_version) + def test_func(x): + return x + + test_func(x=None) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_resample_datalist.py b/tests/test_resample_datalist.py new file mode 100644 index 0000000000..fa120b261e --- /dev/null +++ b/tests/test_resample_datalist.py @@ -0,0 +1,40 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data import resample_datalist + +TEST_CASE_1 = [ + {"data": [1, 2, 3, 4, 5], "factor": 2.5, "random_pick": True, "seed": 123}, + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 2, 4, 5], +] + +TEST_CASE_2 = [ + {"data": [1, 2, 3, 4, 5], "factor": 2.5, "random_pick": False, "seed": 0}, + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3], +] + +TEST_CASE_3 = [{"data": [1, 2, 3, 4, 5], "factor": 0.6, "random_pick": True, "seed": 123}, [2, 4, 5]] + + +class TestResampleDatalist(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_value_shape(self, input_param, expected): + result = resample_datalist(**input_param) + np.testing.assert_allclose(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py new file mode 100644 index 0000000000..e1f6a28998 --- /dev/null +++ b/tests/test_resample_to_match.py @@ -0,0 +1,61 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import os +import tempfile +import unittest + +import nibabel as nib +import numpy as np +from parameterized import parameterized + +from monai.data.image_reader import ITKReader, NibabelReader +from monai.data.image_writer import ITKWriter +from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImaged +from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config + +TEST_CASES = ["itkreader", "nibabelreader"] + + +class TestResampleToMatch(unittest.TestCase): + def setUp(self): + self.fnames = [] + for key in ("0000_t2_tse_tra_4", "0000_ep2d_diff_tra_7"): + fname = os.path.join(os.path.dirname(__file__), "testing_data", f"test_{key}.nii.gz") + url = testing_data_config("images", key, "url") + hash_type = testing_data_config("images", key, "hash_type") + hash_val = testing_data_config("images", key, "hash_val") + download_url_or_skip_test(url=url, filepath=fname, hash_type=hash_type, hash_val=hash_val) + self.fnames.append(fname) + + @parameterized.expand(itertools.product([NibabelReader, ITKReader], ["monai.data.NibabelWriter", ITKWriter])) + def test_correct(self, reader, writer): + with tempfile.TemporaryDirectory() as temp_dir: + loader = Compose([LoadImaged(("im1", "im2"), reader=reader), EnsureChannelFirstd(("im1", "im2"))]) + data = loader({"im1": self.fnames[0], "im2": self.fnames[1]}) + + im_mod, meta = ResampleToMatch()(data["im2"], data["im2_meta_dict"], data["im1_meta_dict"]) + current_dims = copy.deepcopy(meta.get("dim")) + saver = SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False, writer=writer) + meta["filename_or_obj"] = "file3.nii.gz" + saver({"im3": im_mod, "im3_meta_dict": meta}) + + saved = nib.load(os.path.join(temp_dir, meta["filename_or_obj"])) + assert_allclose(data["im1"].shape[1:], saved.shape) + assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19])) + if current_dims is not None: + assert_allclose(saved.header["dim"], current_dims) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_resample_to_matchd.py b/tests/test_resample_to_matchd.py new file mode 100644 index 0000000000..d9dbeee133 --- /dev/null +++ b/tests/test_resample_to_matchd.py @@ -0,0 +1,79 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +from monai.transforms import ( + Compose, + CopyItemsd, + EnsureChannelFirstd, + Invertd, + Lambda, + LoadImaged, + ResampleToMatchd, + SaveImaged, +) +from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config + + +def update_fname(d): + d["im3_meta_dict"]["filename_or_obj"] = "file3.nii.gz" + return d + + +class TestResampleToMatchd(unittest.TestCase): + def setUp(self): + self.fnames = [] + for key in ("0000_t2_tse_tra_4", "0000_ep2d_diff_tra_7"): + fname = os.path.join(os.path.dirname(__file__), "testing_data", f"test_{key}.nii.gz") + url = testing_data_config("images", key, "url") + hash_type = testing_data_config("images", key, "hash_type") + hash_val = testing_data_config("images", key, "hash_val") + download_url_or_skip_test(url=url, filepath=fname, hash_type=hash_type, hash_val=hash_val) + self.fnames.append(fname) + + def test_correct(self): + with tempfile.TemporaryDirectory() as temp_dir: + transforms = Compose( + [ + LoadImaged(("im1", "im2")), + EnsureChannelFirstd(("im1", "im2")), + CopyItemsd(("im2", "im2_meta_dict"), names=("im3", "im3_meta_dict")), + ResampleToMatchd("im3", "im1_meta_dict"), + Lambda(update_fname), + SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False), + ] + ) + data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]}) + # check that output sizes match + assert_allclose(data["im1"].shape, data["im3"].shape) + # and that the meta data has been updated accordingly + assert_allclose(data["im3"].shape[1:], data["im3_meta_dict"]["spatial_shape"], type_test=False) + assert_allclose(data["im3_meta_dict"]["affine"], data["im1_meta_dict"]["affine"]) + # check we're different from the original + self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape))) + self.assertTrue( + any( + i != j + for i, j in zip( + data["im3_meta_dict"]["affine"].flatten(), data["im2_meta_dict"]["affine"].flatten() + ) + ) + ) + # test the inverse + data = Invertd("im3", transforms, "im3")(data) + assert_allclose(data["im2"].shape, data["im3"].shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_resampler.py b/tests/test_resampler.py index 2be94acebd..7dfb86a7a9 100644 --- a/tests/test_resampler.py +++ b/tests/test_resampler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,69 +17,146 @@ from monai.transforms import Resample from monai.transforms.utils import create_grid +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((2, 2)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="border", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]]), - ], - [ - dict(padding_mode="reflection", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2)), "mode": "nearest"}, - np.array([[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"}, - np.array( - [ +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + dict(padding_mode="zeros", device=device), + {"grid": p(create_grid((2, 2))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q(np.array([[[0.0, 1.0], [2.0, 3.0]]])), ] - ] - ), - ], - [ - dict(padding_mode="border", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"}, - np.array( - [ + ) + TESTS.append( [ - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], + dict(padding_mode="zeros", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q( + np.array( + [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]] + ) + ), ] - ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="border", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q( + np.array( + [[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="reflection", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2))), "mode": "nearest"}, + q( + np.array( + [[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + { + "grid": p(create_grid((4, 4, 4))), + "img": q(np.arange(8).reshape((1, 2, 2, 2))), + "mode": "bilinear", + }, + q( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="border", device=device), + { + "grid": p(create_grid((4, 4, 4))), + "img": q(np.arange(8).reshape((1, 2, 2, 2))), + "mode": "bilinear", + }, + q( + np.array( + [ + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [2.0, 2.0, 3.0, 3.0], + [2.0, 2.0, 3.0, 3.0], + ], + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [2.0, 2.0, 3.0, 3.0], + [2.0, 2.0, 3.0, 3.0], + ], + [ + [4.0, 4.0, 5.0, 5.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + [6.0, 6.0, 7.0, 7.0], + ], + [ + [4.0, 4.0, 5.0, 5.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + [6.0, 6.0, 7.0, 7.0], + ], + ] + ] + ) + ), + ] + ) class TestResample(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_resample(self, input_param, input_data, expected_val): g = Resample(**input_param) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data["device"]) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_resize.py b/tests/test_resize.py index e5ec5dd1a9..06246b2358 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import Resize -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_0 = [{"spatial_size": 15}, (6, 10, 15)] @@ -45,16 +45,17 @@ def test_correct_results(self, spatial_size, mode): _order = 1 if spatial_size == (32, -1): spatial_size = (32, 64) - expected = [] - for channel in self.imt[0]: - expected.append( - skimage.transform.resize( - channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False - ) + expected = [ + skimage.transform.resize( + channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False ) + for channel in self.imt[0] + ] + expected = np.stack(expected).astype(np.float32) - out = resize(self.imt[0]) - np.testing.assert_allclose(out, expected, atol=0.9) + for p in TEST_NDARRAYS: + out = resize(p(self.imt[0])) + assert_allclose(out, expected, type_test=False, atol=0.9) @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_longest_shape(self, input_param, expected_shape): diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 46f1fc86cc..f81e1d4b08 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,47 +12,38 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ResizeWithPadOrCrop +from tests.utils import TEST_NDARRAYS TEST_CASES = [ - [ - {"spatial_size": [15, 8, 8], "mode": "constant"}, - (3, 8, 8, 4), - (3, 15, 8, 8), - ], + [{"spatial_size": [15, 8, 8], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 8, 8)], [ {"spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, (3, 8, 8, 4), (3, 15, 4, 8), ], - [ - {"spatial_size": [15, 4, -1], "mode": "constant"}, - (3, 8, 8, 4), - (3, 15, 4, 4), - ], - [ - {"spatial_size": [15, 4, -1], "mode": "reflect"}, - (3, 8, 8, 4), - (3, 15, 4, 4), - ], - [ - {"spatial_size": [-1, -1, -1], "mode": "reflect"}, - (3, 8, 8, 4), - (3, 8, 8, 4), - ], + [{"spatial_size": [15, 4, -1], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 4, 4)], + [{"spatial_size": [15, 4, -1], "mode": "reflect"}, (3, 8, 8, 4), (3, 15, 4, 4)], + [{"spatial_size": [-1, -1, -1], "mode": "reflect"}, (3, 8, 8, 4), (3, 8, 8, 4)], ] class TestResizeWithPadOrCrop(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_shape, expected_shape): - paddcroper = ResizeWithPadOrCrop(**input_param) - result = paddcroper(np.zeros(input_shape)) - np.testing.assert_allclose(result.shape, expected_shape) - result = paddcroper(np.zeros(input_shape), mode="constant") - np.testing.assert_allclose(result.shape, expected_shape) + for p in TEST_NDARRAYS: + if isinstance(p(0), torch.Tensor) and ( + "constant_values" in input_param or input_param["mode"] == "reflect" + ): + continue + paddcroper = ResizeWithPadOrCrop(**input_param) + result = paddcroper(p(np.zeros(input_shape))) + np.testing.assert_allclose(result.shape, expected_shape) + result = paddcroper(p(np.zeros(input_shape)), mode="constant") + np.testing.assert_allclose(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 32a62a9e16..28993a2bf4 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,45 +12,37 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ResizeWithPadOrCropd +from tests.utils import TEST_NDARRAYS TEST_CASES = [ - [ - {"keys": "img", "spatial_size": [15, 8, 8], "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - (3, 15, 8, 8), - ], + [{"keys": "img", "spatial_size": [15, 8, 8], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 8, 8)], [ {"keys": "img", "spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 8), ], - [ - {"keys": "img", "spatial_size": [15, 4, -1], "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - (3, 15, 4, 4), - ], - [ - {"keys": "img", "spatial_size": [15, 4, -1], "mode": "reflect"}, - {"img": np.zeros((3, 8, 8, 4))}, - (3, 15, 4, 4), - ], - [ - {"keys": "img", "spatial_size": [-1, -1, -1], "mode": "reflect"}, - {"img": np.zeros((3, 8, 8, 4))}, - (3, 8, 8, 4), - ], + [{"keys": "img", "spatial_size": [15, 4, -1], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)], + [{"keys": "img", "spatial_size": [15, 4, -1], "mode": "reflect"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)], + [{"keys": "img", "spatial_size": [-1, -1, -1], "mode": "reflect"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 8, 8, 4)], ] class TestResizeWithPadOrCropd(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_data, expected_val): - paddcroper = ResizeWithPadOrCropd(**input_param) - result = paddcroper(input_data) - np.testing.assert_allclose(result["img"].shape, expected_val) + for p in TEST_NDARRAYS: + if isinstance(p(0), torch.Tensor) and ( + "constant_values" in input_param or input_param["mode"] == "reflect" + ): + continue + paddcroper = ResizeWithPadOrCropd(**input_param) + input_data["img"] = p(input_data["img"]) + result = paddcroper(input_data) + np.testing.assert_allclose(result["img"].shape, expected_val) if __name__ == "__main__": diff --git a/tests/test_resized.py b/tests/test_resized.py index 930faf00eb..d7374ea930 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import Resized -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 10, 15)] @@ -48,16 +48,17 @@ def test_correct_results(self, spatial_size, mode): _order = 1 if spatial_size == (32, -1): spatial_size = (32, 64) - expected = [] - for channel in self.imt[0]: - expected.append( - skimage.transform.resize( - channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False - ) + expected = [ + skimage.transform.resize( + channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False ) + for channel in self.imt[0] + ] + expected = np.stack(expected).astype(np.float32) - out = resize({"img": self.imt[0]})["img"] - np.testing.assert_allclose(out, expected, atol=0.9) + for p in TEST_NDARRAYS: + out = resize({"img": p(self.imt[0])})["img"] + assert_allclose(out, expected, type_test=False, atol=0.9) @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_longest_shape(self, input_param, expected_shape): diff --git a/tests/test_resnet.py b/tests/test_resnet.py index c4ba5c2e16..688f7827b1 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -31,25 +31,72 @@ device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_1 = [ # 3D, batch 3, 2 input channel - {"pretrained": False, "spatial_dims": 3, "n_input_channels": 2, "num_classes": 3}, + { + "pretrained": False, + "spatial_dims": 3, + "n_input_channels": 2, + "num_classes": 3, + "conv1_t_size": 7, + "conv1_t_stride": (2, 2, 2), + }, (3, 2, 32, 64, 48), (3, 3), ] TEST_CASE_2 = [ # 2D, batch 2, 1 input channel - {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3}, + { + "pretrained": False, + "spatial_dims": 2, + "n_input_channels": 1, + "num_classes": 3, + "conv1_t_size": [7, 7], + "conv1_t_stride": [2, 2], + }, + (2, 1, 32, 64), + (2, 3), +] + +TEST_CASE_2_A = [ # 2D, batch 2, 1 input channel, shortcut type A + { + "pretrained": False, + "spatial_dims": 2, + "n_input_channels": 1, + "num_classes": 3, + "shortcut_type": "A", + "conv1_t_size": (7, 7), + "conv1_t_stride": 2, + }, (2, 1, 32, 64), (2, 3), ] TEST_CASE_3 = [ # 1D, batch 1, 2 input channels - {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3}, + { + "pretrained": False, + "spatial_dims": 1, + "n_input_channels": 2, + "num_classes": 3, + "conv1_t_size": [3], + "conv1_t_stride": 1, + }, (1, 2, 32), (1, 3), ] +TEST_CASE_3_A = [ # 1D, batch 1, 2 input channels + {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3, "shortcut_type": "A"}, + (1, 2, 32), + (1, 3), +] + +TEST_CASE_4 = [ # 2D, batch 2, 1 input channel + {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3, "feed_forward": False}, + (2, 1, 32, 64), + ((2, 512), (2, 2048)), +] + TEST_CASES = [] -for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]: +for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case]) @@ -64,7 +111,10 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape): net = model(**input_param).to(device) with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) + if input_param.get("feed_forward", True): + self.assertEqual(result.shape, expected_shape) + else: + self.assertTrue(result.shape in expected_shape) @parameterized.expand(TEST_SCRIPT_CASES) def test_script(self, model, input_param, input_shape, expected_shape): diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 436c952d4b..01842f6d73 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,42 +10,44 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import Rotate -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D - -TEST_CASES_2D = [ - (np.pi / 6, False, "bilinear", "border", False), - (np.pi / 4, True, "bilinear", "border", False), - (-np.pi / 4.5, True, "nearest", "reflection", False), - (np.pi, False, "nearest", "zeros", False), - (-np.pi / 2, False, "bilinear", "zeros", True), -] - -TEST_CASES_3D = [ - (-np.pi / 2, True, "nearest", "border", False), - (np.pi / 4, True, "bilinear", "border", False), - (-np.pi / 4.5, True, "nearest", "reflection", False), - (np.pi, False, "nearest", "zeros", False), - (-np.pi / 2, False, "bilinear", "zeros", False), -] - -TEST_CASES_SHAPE_3D = [ - ([-np.pi / 2, 1.0, 2.0], "nearest", "border", False), - ([np.pi / 4, 0, 0], "bilinear", "border", False), - ([-np.pi / 4.5, -20, 20], "nearest", "reflection", False), -] +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D + +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", False)) + TEST_CASES_2D.append((p, -np.pi / 2, False, "bilinear", "zeros", True)) + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append((p, -np.pi / 2, True, "nearest", "border", False)) + TEST_CASES_3D.append((p, np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_3D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_3D.append((p, np.pi, False, "nearest", "zeros", False)) + TEST_CASES_3D.append((p, -np.pi / 2, False, "bilinear", "zeros", False)) + +TEST_CASES_SHAPE_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_SHAPE_3D.append((p, [-np.pi / 2, 1.0, 2.0], "nearest", "border", False)) + TEST_CASES_SHAPE_3D.append((p, [np.pi / 4, 0, 0], "bilinear", "border", False)) + TEST_CASES_SHAPE_3D.append((p, [-np.pi / 4.5, -20, 20], "nearest", "reflection", False)) class TestRotate2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotate(angle, keep_size, mode, padding_mode, align_corners) - rotated = rotate_fn(self.imt[0]) + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): + rotate_fn = Rotate(angle, keep_size, mode, padding_mode, align_corners, dtype=np.float64) + rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -60,25 +62,20 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne for channel in self.imt[0]: expected.append( scipy.ndimage.rotate( - channel, - -np.rad2deg(angle), - (0, 1), - not keep_size, - order=_order, - mode=_mode, - prefilter=False, + channel, -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRotate3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotate([angle, 0, 0], keep_size, mode, padding_mode, align_corners) - rotated = rotate_fn(self.imt[0]) + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): + rotate_fn = Rotate([angle, 0, 0], keep_size, mode, padding_mode, align_corners, dtype=np.float64) + rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -93,33 +90,29 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne for channel in self.imt[0]: expected.append( scipy.ndimage.rotate( - channel, - -np.rad2deg(angle), - (1, 2), - not keep_size, - order=_order, - mode=_mode, - prefilter=False, + channel, -np.rad2deg(angle), (1, 2), not keep_size, order=_order, mode=_mode, prefilter=False ) ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated n_good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(expected.size - n_good, 5, "diff at most 5 pixels") @parameterized.expand(TEST_CASES_SHAPE_3D) - def test_correct_shape(self, angle, mode, padding_mode, align_corners): - rotate_fn = Rotate(angle, True, align_corners=align_corners) - rotated = rotate_fn(self.imt[0], mode=mode, padding_mode=padding_mode) + def test_correct_shape(self, im_type, angle, mode, padding_mode, align_corners): + rotate_fn = Rotate(angle, True, align_corners=align_corners, dtype=np.float64) + rotated = rotate_fn(im_type(self.imt[0]), mode=mode, padding_mode=padding_mode) np.testing.assert_allclose(self.imt[0].shape, rotated.shape) def test_ill_case(self): - rotate_fn = Rotate(10, True) - with self.assertRaises(ValueError): # wrong shape - rotate_fn(self.imt) - - rotate_fn = Rotate(10, keep_size=False) - with self.assertRaises(ValueError): # wrong mode - rotate_fn(self.imt[0], mode="trilinear") + for p in TEST_NDARRAYS: + rotate_fn = Rotate(10, True) + with self.assertRaises(ValueError): # wrong shape + rotate_fn(p(self.imt)) + + rotate_fn = Rotate(10, keep_size=False) + with self.assertRaises(ValueError): # wrong mode + rotate_fn(p(self.imt[0]), mode="trilinear") if __name__ == "__main__": diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index 4ab39d5cf6..9865120688 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,45 +14,41 @@ import numpy as np from monai.transforms import Rotate90 -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRotate90(NumpyImageTestCase2D): def test_rotate90_default(self): rotate = Rotate90() - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_k(self): rotate = Rotate90(k=2) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, -1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) if __name__ == "__main__": diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py index 3d71ead82a..ef4bad9419 100644 --- a/tests/test_rotate90d.py +++ b/tests/test_rotate90d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,49 +14,45 @@ import numpy as np from monai.transforms import Rotate90d -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRotate90d(NumpyImageTestCase2D): def test_rotate90_default(self): key = "test" rotate = Rotate90d(keys=key) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_k(self): key = None rotate = Rotate90d(keys=key, k=2) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_spatial_axes(self): key = "test" rotate = Rotate90d(keys=key, spatial_axes=(0, 1)) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_prob_k_spatial_axes(self): key = "test" rotate = Rotate90d(keys=key, k=2, spatial_axes=(0, 1)) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_no_key(self): key = "unknown" diff --git a/tests/test_rotated.py b/tests/test_rotated.py index 2ea421101b..43b5a68f61 100644 --- a/tests/test_rotated.py +++ b/tests/test_rotated.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,36 +10,40 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import Rotated -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D -TEST_CASES_2D = [ - (-np.pi / 6, False, "bilinear", "border", False), - (-np.pi / 4, True, "bilinear", "border", False), - (np.pi / 4.5, True, "nearest", "reflection", False), - (-np.pi, False, "nearest", "zeros", False), - (np.pi / 2, False, "bilinear", "zeros", True), -] +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, -np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_2D.append((p, -np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_2D.append((p, -np.pi, False, "nearest", "zeros", False)) + TEST_CASES_2D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) -TEST_CASES_3D = [ - (-np.pi / 6, False, "bilinear", "border", False), - (-np.pi / 4, True, "bilinear", "border", False), - (np.pi / 4.5, True, "nearest", "reflection", False), - (-np.pi, False, "nearest", "zeros", False), - (np.pi / 2, False, "bilinear", "zeros", True), -] +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append((p, -np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_3D.append((p, -np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_3D.append((p, np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_3D.append((p, -np.pi, False, "nearest", "zeros", False)) + TEST_CASES_3D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) class TestRotated2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotated(("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): + rotate_fn = Rotated( + ("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners, dtype=np.float64 + ) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -52,6 +56,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") @@ -64,9 +70,11 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne class TestRotated3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotated(("img", "seg"), [0, angle, 0], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): + rotate_fn = Rotated( + ("img", "seg"), [0, angle, 0], keep_size, (mode, "nearest"), padding_mode, align_corners, dtype=np.float64 + ) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -79,6 +87,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected.astype(np.float32), rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels.") @@ -86,14 +96,16 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne self.segn[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=0, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(int) - self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 130) + self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 160) class TestRotated3DXY(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotated(("img", "seg"), [0, 0, angle], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): + rotate_fn = Rotated( + ("img", "seg"), [0, 0, angle], keep_size, (mode, "nearest"), padding_mode, align_corners, dtype=np.float64 + ) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -106,6 +118,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels") @@ -113,7 +127,7 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne self.segn[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=0, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(int) - self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 130) + self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 160) if __name__ == "__main__": diff --git a/tests/test_saliency_inferer.py b/tests/test_saliency_inferer.py index 416b7170ae..c97bcb7811 100644 --- a/tests/test_saliency_inferer.py +++ b/tests/test_saliency_inferer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_sample_slices.py b/tests/test_sample_slices.py new file mode 100644 index 0000000000..117d39b486 --- /dev/null +++ b/tests/test_sample_slices.py @@ -0,0 +1,41 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.utils import sample_slices +from tests.utils import TEST_NDARRAYS, assert_allclose + +# test data[:, [1, ], ...] +TEST_CASE_1 = [torch.tensor([[[0, 2], [1, 0]]]), 1, True, (1,), torch.tensor([[[1, 0]]])] +# test data[:, [0, 2], ...] +TEST_CASE_2 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, True, (0, 2), torch.tensor([[[0, 2], [4, 5]]])] +# test data[:, [0: 2], ...] +TEST_CASE_3 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (0, 2), torch.tensor([[[0, 2], [1, 0]]])] +# test data[:, [1: ], ...] +TEST_CASE_4 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (1, None), torch.tensor([[[1, 0], [4, 5]]])] +# test data[:, [0: 3: 2], ...] +TEST_CASE_5 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (0, 3, 2), torch.tensor([[[0, 2], [4, 5]]])] + + +class TestSampleSlices(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_shape(self, input_data, dim, as_indices, vals, expected_result): + for p in TEST_NDARRAYS: + result = sample_slices(p(input_data), dim, as_indices, *vals) + assert_allclose(p(expected_result), result) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_distributed_sampler.py b/tests/test_sampler_dist.py similarity index 97% rename from tests/test_distributed_sampler.py rename to tests/test_sampler_dist.py index 0a439874bd..8b140f3ff8 100644 --- a/tests/test_distributed_sampler.py +++ b/tests/test_sampler_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_save_classificationd.py b/tests/test_save_classificationd.py index 67dc0320a6..10c65c2044 100644 --- a/tests/test_save_classificationd.py +++ b/tests/test_save_classificationd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,12 +13,14 @@ import os import tempfile import unittest +from pathlib import Path import numpy as np import torch from monai.data import CSVSaver, decollate_batch from monai.transforms import Compose, CopyItemsd, SaveClassificationd +from monai.utils.enums import PostFix class TestSaveClassificationd(unittest.TestCase): @@ -27,34 +29,37 @@ def test_saved_content(self): data = [ { "pred": torch.zeros(8), - "image_meta_dict": {"filename_or_obj": ["testfile" + str(i) for i in range(8)]}, + PostFix.meta("image"): {"filename_or_obj": ["testfile" + str(i) for i in range(8)]}, }, { "pred": torch.zeros(8), - "image_meta_dict": {"filename_or_obj": ["testfile" + str(i) for i in range(8, 16)]}, + PostFix.meta("image"): {"filename_or_obj": ["testfile" + str(i) for i in range(8, 16)]}, }, { "pred": torch.zeros(8), - "image_meta_dict": {"filename_or_obj": ["testfile" + str(i) for i in range(16, 24)]}, + PostFix.meta("image"): {"filename_or_obj": ["testfile" + str(i) for i in range(16, 24)]}, }, ] - saver = CSVSaver(output_dir=tempdir, filename="predictions2.csv", overwrite=False, flush=False) + saver = CSVSaver( + output_dir=Path(tempdir), filename="predictions2.csv", overwrite=False, flush=False, delimiter="\t" + ) # set up test transforms post_trans = Compose( [ - CopyItemsd(keys="image_meta_dict", times=1, names="pred_meta_dict"), + CopyItemsd(keys=PostFix.meta("image"), times=1, names=PostFix.meta("pred")), # 1st saver saves data into CSV file SaveClassificationd( keys="pred", saver=None, meta_keys=None, - output_dir=tempdir, + output_dir=Path(tempdir), filename="predictions1.csv", + delimiter="\t", overwrite=True, ), # 2rd saver only saves data into the cache, manually finalize later - SaveClassificationd(keys="pred", saver=saver, meta_key_postfix="meta_dict"), + SaveClassificationd(keys="pred", saver=saver, meta_key_postfix=PostFix.meta()), ] ) # simulate inference 2 iterations @@ -71,9 +76,10 @@ def test_saved_content(self): trans2 = SaveClassificationd( keys="pred", saver=None, - meta_keys="image_meta_dict", # specify meta key, so no need to copy anymore + meta_keys=PostFix.meta("image"), # specify meta key, so no need to copy anymore output_dir=tempdir, filename="predictions1.csv", + delimiter="\t", overwrite=False, ) d = decollate_batch(data[2]) @@ -83,8 +89,8 @@ def test_saved_content(self): def _test_file(filename, count): filepath = os.path.join(tempdir, filename) self.assertTrue(os.path.exists(filepath)) - with open(filepath, "r") as f: - reader = csv.reader(f) + with open(filepath) as f: + reader = csv.reader(f, delimiter="\t") i = 0 for row in reader: self.assertEqual(row[0], "testfile" + str(i)) diff --git a/tests/test_save_image.py b/tests/test_save_image.py index f7c8e07f06..a1297c1e61 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,36 +13,35 @@ import tempfile import unittest +import numpy as np import torch from parameterized import parameterized from monai.transforms import SaveImage -TEST_CASE_1 = [ - torch.randint(0, 255, (1, 2, 3, 4)), - {"filename_or_obj": "testfile0.nii.gz"}, - ".nii.gz", - False, -] +TEST_CASE_1 = [torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nii.gz"}, ".nii.gz", False] + +TEST_CASE_2 = [torch.randint(0, 255, (1, 2, 3, 4)), None, ".nii.gz", False] + +TEST_CASE_3 = [torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nrrd"}, ".nrrd", False] -TEST_CASE_2 = [ - torch.randint(0, 255, (1, 2, 3, 4)), - None, - ".nii.gz", +TEST_CASE_4 = [ + np.random.randint(0, 255, (3, 2, 4, 5), dtype=np.uint8), + {"filename_or_obj": "testfile0.dcm"}, + ".dcm", False, ] class TestSaveImage(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_saved_content(self, test_data, meta_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImage( output_dir=tempdir, output_ext=output_ext, resample=resample, - # test saving into the same folder - separate_folder=False, + separate_folder=False, # test saving into the same folder ) trans(test_data, meta_data) diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py index 35bbea9628..a6988683e5 100644 --- a/tests/test_save_imaged.py +++ b/tests/test_save_imaged.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,34 +17,42 @@ from parameterized import parameterized from monai.transforms import SaveImaged +from monai.utils.enums import PostFix TEST_CASE_1 = [ + {"img": torch.randint(0, 255, (1, 2, 3, 4)), PostFix.meta("img"): {"filename_or_obj": "testfile0.nii.gz"}}, + ".nii.gz", + False, +] + +TEST_CASE_2 = [ { "img": torch.randint(0, 255, (1, 2, 3, 4)), - "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"}, + PostFix.meta("img"): {"filename_or_obj": "testfile0.nii.gz"}, + "patch_index": 6, }, ".nii.gz", False, ] -TEST_CASE_2 = [ +TEST_CASE_3 = [ { "img": torch.randint(0, 255, (1, 2, 3, 4)), - "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"}, + PostFix.meta("img"): {"filename_or_obj": "testfile0.nrrd"}, "patch_index": 6, }, - ".nii.gz", + ".nrrd", False, ] class TestSaveImaged(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_saved_content(self, test_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImaged( keys=["img", "pred"], - meta_keys="img_meta_dict", + meta_keys=PostFix.meta("img"), output_dir=tempdir, output_ext=output_ext, resample=resample, @@ -52,7 +60,7 @@ def test_saved_content(self, test_data, output_ext, resample): ) trans(test_data) - patch_index = test_data["img_meta_dict"].get("patch_index", None) + patch_index = test_data[PostFix.meta("img")].get("patch_index", None) patch_index = f"_{patch_index}" if patch_index is not None else "" filepath = os.path.join("testfile0", "testfile0" + "_trans" + patch_index + output_ext) self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) diff --git a/tests/test_save_state.py b/tests/test_save_state.py new file mode 100644 index 0000000000..c48b12ebdc --- /dev/null +++ b/tests/test_save_state.py @@ -0,0 +1,70 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch +import torch.optim as optim +from parameterized import parameterized + +from monai.networks import save_state + +TEST_CASE_1 = [torch.nn.PReLU(), ["weight"]] + +TEST_CASE_2 = [{"net": torch.nn.PReLU()}, ["net"]] + +TEST_CASE_3 = [{"net": torch.nn.PReLU(), "opt": optim.SGD(torch.nn.PReLU().parameters(), lr=0.02)}, ["net", "opt"]] + +TEST_CASE_4 = [torch.nn.DataParallel(torch.nn.PReLU()), ["weight"]] + +TEST_CASE_5 = [{"net": torch.nn.DataParallel(torch.nn.PReLU())}, ["net"]] + +TEST_CASE_6 = [torch.nn.PReLU(), ["weight"], True, True, None, {"pickle_protocol": 2}] + +TEST_CASE_7 = [torch.nn.PReLU().state_dict(), ["weight"]] + +TEST_CASE_8 = [torch.nn.PReLU(), ["weight"], False] + +TEST_CASE_9 = [torch.nn.PReLU(), ["weight"], True, False] + +TEST_CASE_10 = [torch.nn.PReLU(), ["weight"], True, True, torch.save] + + +class TestSaveState(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + ] + ) + def test_file(self, src, expected_keys, create_dir=True, atomic=True, func=None, kwargs=None): + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, "test_ckpt.pt") + if kwargs is None: + kwargs = {} + save_state(src=src, path=path, create_dir=create_dir, atomic=atomic, func=func, **kwargs) + ckpt = dict(torch.load(path)) + for k in ckpt.keys(): + self.assertIn(k, expected_keys) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_savitzky_golay_filter.py b/tests/test_savitzky_golay_filter.py index c9bcd9687e..0e54276533 100644 --- a/tests/test_savitzky_golay_filter.py +++ b/tests/test_savitzky_golay_filter.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,7 +25,7 @@ torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) - 1e-15, # absolute tolerance + 1e-6, # absolute tolerance ] TEST_CASE_1D = [ @@ -35,21 +35,21 @@ .unsqueeze(0) .unsqueeze(0), # Expected output: zero padded, so linear interpolation # over length-3 windows will result in output of [2/3, 1, 2/3]. - 1e-15, # absolute tolerance + 1e-6, # absolute tolerance ] TEST_CASE_2D_AXIS_2 = [ {"window_length": 3, "order": 1}, # along default axis (2, first spatial dim) torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance + 1e-6, # absolute tolerance ] TEST_CASE_2D_AXIS_3 = [ {"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim) torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance + 1e-6, # absolute tolerance ] # Replicated-padding trivial tests @@ -59,7 +59,7 @@ torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) - 1e-15, # absolute tolerance + 1e-6, # absolute tolerance ] TEST_CASE_1D_REP = [ @@ -67,21 +67,21 @@ torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation # over length-3 windows will result in output of [2/3, 1, 2/3]. - 1e-15, # absolute tolerance + 1e-6, # absolute tolerance ] TEST_CASE_2D_AXIS_2_REP = [ {"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim) torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance + 1e-6, # absolute tolerance ] TEST_CASE_2D_AXIS_3_REP = [ {"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim) torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance + 1e-6, # absolute tolerance ] # Sine smoothing @@ -99,57 +99,40 @@ class TestSavitzkyGolayCPU(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE, - TEST_CASE_1D, - TEST_CASE_2D_AXIS_2, - TEST_CASE_2D_AXIS_3, - TEST_CASE_SINE_SMOOTH, - ] + [TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH] ) - def test_value(self, arguments, image, expected_data, atol): + def test_value(self, arguments, image, expected_data, atol, rtol=1e-5): result = SavitzkyGolayFilter(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) + np.testing.assert_allclose(result, expected_data, atol=atol, rtol=rtol) class TestSavitzkyGolayCPUREP(unittest.TestCase): @parameterized.expand( [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] ) - def test_value(self, arguments, image, expected_data, atol): + def test_value(self, arguments, image, expected_data, atol, rtol=1e-5): result = SavitzkyGolayFilter(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) + np.testing.assert_allclose(result, expected_data, atol=atol, rtol=rtol) @skip_if_no_cuda class TestSavitzkyGolayGPU(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE, - TEST_CASE_1D, - TEST_CASE_2D_AXIS_2, - TEST_CASE_2D_AXIS_3, - TEST_CASE_SINE_SMOOTH, - ] + [TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH] ) - def test_value(self, arguments, image, expected_data, atol): + def test_value(self, arguments, image, expected_data, atol, rtol=1e-5): result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) - np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol, rtol=rtol) @skip_if_no_cuda class TestSavitzkyGolayGPUREP(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE_REP, - TEST_CASE_1D_REP, - TEST_CASE_2D_AXIS_2_REP, - TEST_CASE_2D_AXIS_3_REP, - ] + [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] ) - def test_value(self, arguments, image, expected_data, atol): + def test_value(self, arguments, image, expected_data, atol, rtol=1e-5): result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) - np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol, rtol=rtol) if __name__ == "__main__": diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 45d0ea3e4d..ac42cf806e 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,14 +25,14 @@ np.expand_dims(np.array([1.0]), 0), # Input data: Single value np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) - 1e-15, # absolute tolerance + 1e-5, # absolute tolerance ] TEST_CASE_2D_AXIS_2 = [ {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) np.expand_dims(np.ones((2, 3)), 0), np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), - 1e-15, # absolute tolerance + 1e-5, # absolute tolerance ] # Replicated-padding trivial tests @@ -42,7 +42,7 @@ np.expand_dims(np.array([1.0]), 0), # Input data: Single value np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) - 1e-15, # absolute tolerance + 1e-5, # absolute tolerance ] # Sine smoothing @@ -59,19 +59,13 @@ class TestSavitzkyGolaySmooth(unittest.TestCase): - @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH, TEST_CASE_SINGLE_VALUE_REP] + ) def test_value(self, arguments, image, expected_data, atol): for p in TEST_NDARRAYS: - result = SavitzkyGolaySmooth(**arguments)(p(image)) - torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol) - - -class TestSavitzkyGolaySmoothREP(unittest.TestCase): - @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) - def test_value(self, arguments, image, expected_data, atol): - for p in TEST_NDARRAYS: - result = SavitzkyGolaySmooth(**arguments)(p(image)) - torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol) + result = SavitzkyGolaySmooth(**arguments)(p(image.astype(np.float32))) + torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-4, atol=atol) if __name__ == "__main__": diff --git a/tests/test_savitzky_golay_smoothd.py b/tests/test_savitzky_golay_smoothd.py new file mode 100644 index 0000000000..6f0b33f533 --- /dev/null +++ b/tests/test_savitzky_golay_smoothd.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import SavitzkyGolaySmoothd +from tests.utils import TEST_NDARRAYS + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"keys": "img", "window_length": 3, "order": 1}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-5, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"keys": "img", "window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) + np.expand_dims(np.ones((2, 3)), 0), + np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), + 1e-5, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"keys": "img", "window_length": 3, "order": 1, "mode": "replicate"}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-5, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"keys": "img", "window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), + # Should be smoothed out to zeros + np.expand_dims(np.zeros(100), 0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolaySmoothd(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH, TEST_CASE_SINGLE_VALUE_REP] + ) + def test_value(self, arguments, image, expected_data, atol): + for p in TEST_NDARRAYS: + result = SavitzkyGolaySmoothd(**arguments)({"img": p(image.astype(np.float32))})["img"] + torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-4, atol=atol) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index c2485af616..bd1adac4f4 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import numpy as np from monai.transforms import ScaleIntensity +from monai.transforms.utils import rescale_array from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose @@ -26,14 +27,45 @@ def test_range_scale(self): maxa = self.imt.max() norm = (self.imt - mina) / (maxa - mina) expected = p((norm * (2.0 - 1.0)) + 1.0) - assert_allclose(result, expected, rtol=1e-7, atol=0) + assert_allclose(result, expected, type_test=False, rtol=1e-7, atol=0) def test_factor_scale(self): for p in TEST_NDARRAYS: scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1) result = scaler(p(self.imt)) expected = p((self.imt * (1 + 0.1)).astype(np.float32)) - assert_allclose(result, expected, rtol=1e-7, atol=0) + assert_allclose(result, p(expected), rtol=1e-7, atol=0) + + def test_max_none(self): + for p in TEST_NDARRAYS: + scaler = ScaleIntensity(minv=0.0, maxv=None, factor=0.1) + result = scaler(p(self.imt)) + expected = rescale_array(p(self.imt), minv=0.0, maxv=None) + assert_allclose(result, expected, rtol=1e-3, atol=1e-3) + + def test_int(self): + """integers should be handled by converting them to floats first.""" + for p in TEST_NDARRAYS: + scaler = ScaleIntensity(minv=1.0, maxv=2.0) + result = scaler(p(self.imt.astype(int))) + _imt = self.imt.astype(int).astype(np.float32) + mina = _imt.min() + maxa = _imt.max() + norm = (_imt - mina) / (maxa - mina) + expected = p((norm * (2.0 - 1.0)) + 1.0) + assert_allclose(result, expected, type_test=False, rtol=1e-7, atol=0) + + def test_channel_wise(self): + for p in TEST_NDARRAYS: + scaler = ScaleIntensity(minv=1.0, maxv=2.0, channel_wise=True) + data = p(np.tile(self.imt, (3, 1, 1, 1))) + result = scaler(data) + mina = self.imt.min() + maxa = self.imt.max() + for i, c in enumerate(data): + norm = (c - mina) / (maxa - mina) + expected = p((norm * (2.0 - 1.0)) + 1.0) + assert_allclose(result[i], expected, type_test=False, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_scale_intensity_range.py b/tests/test_scale_intensity_range.py index cba07d9157..faddf9001b 100644 --- a/tests/test_scale_intensity_range.py +++ b/tests/test_scale_intensity_range.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,16 +14,25 @@ import numpy as np from monai.transforms import ScaleIntensityRange -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class IntensityScaleIntensityRange(NumpyImageTestCase2D): def test_image_scale_intensity_range(self): - scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80) - scaled = scaler(self.imt) - expected = (self.imt - 20) / 88 - expected = expected * 30 + 50 - self.assertTrue(np.allclose(scaled, expected)) + scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80, dtype=np.uint8) + for p in TEST_NDARRAYS: + scaled = scaler(p(self.imt)) + self.assertTrue(scaled.dtype, np.uint8) + expected = (((self.imt - 20) / 88) * 30 + 50).astype(np.uint8) + assert_allclose(scaled, p(expected)) + + def test_image_scale_intensity_range_none_clip(self): + scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=None, b_max=80, clip=True, dtype=np.uint8) + for p in TEST_NDARRAYS: + scaled = scaler(p(self.imt)) + self.assertTrue(scaled.dtype, np.uint8) + expected = (np.clip((self.imt - 20) / 88, None, 80)).astype(np.uint8) + assert_allclose(scaled, p(expected)) if __name__ == "__main__": diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 015162c8de..f8656dd929 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,12 +14,12 @@ import numpy as np from monai.transforms.intensity.array import ScaleIntensityRangePercentiles -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D): def test_scaling(self): - img = self.imt + img = self.imt[0] lower = 10 upper = 99 b_min = 0 @@ -27,13 +27,14 @@ def test_scaling(self): a_min = np.percentile(img, lower) a_max = np.percentile(img, upper) - expected = (img - a_min) / (a_max - a_min) - expected = (expected * (b_max - b_min)) + b_min - scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max) - self.assertTrue(np.allclose(expected, scaler(img))) + expected = (((img - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8) + scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8) + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(expected), rtol=1e-4) def test_relative_scaling(self): - img = self.imt + img = self.imt[0] lower = 10 upper = 99 b_min = 100 @@ -47,7 +48,16 @@ def test_relative_scaling(self): expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min) expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min - self.assertTrue(np.allclose(expected_img, scaler(img))) + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(expected_img), rtol=1e-3) + + scaler = ScaleIntensityRangePercentiles( + lower=lower, upper=upper, b_min=b_min, b_max=b_max, relative=True, clip=True + ) + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(np.clip(expected_img, expected_b_min, expected_b_max)), rtol=1e-4) def test_invalid_instantiation(self): self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=-10, upper=99, b_min=0, b_max=255) @@ -55,6 +65,26 @@ def test_invalid_instantiation(self): self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=-20, b_min=0, b_max=255) self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=900, b_min=0, b_max=255) + def test_channel_wise(self): + img = np.tile(self.imt, (3, 1, 1, 1)) + lower = 10 + upper = 99 + b_min = 0 + b_max = 255 + scaler = ScaleIntensityRangePercentiles( + lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8 + ) + expected = [] + for c in img: + a_min = np.percentile(c, lower) + a_max = np.percentile(c, upper) + expected.append((((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8)) + expected = np.stack(expected) + + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(expected), rtol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index 9d0fe8284a..5441832a77 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,14 +14,12 @@ import numpy as np from monai.transforms.intensity.dictionary import ScaleIntensityRangePercentilesd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestScaleIntensityRangePercentilesd(NumpyImageTestCase2D): def test_scaling(self): img = self.imt - data = {} - data["img"] = img lower = 10 upper = 99 b_min = 0 @@ -29,12 +27,14 @@ def test_scaling(self): a_min = np.percentile(img, lower) a_max = np.percentile(img, upper) - expected = (img - a_min) / (a_max - a_min) - expected = (expected * (b_max - b_min)) + b_min - - scaler = ScaleIntensityRangePercentilesd(keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max) + expected = (((img - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8) - self.assertTrue(np.allclose(expected, scaler(data)["img"])) + for p in TEST_NDARRAYS: + data = {"img": p(img)} + scaler = ScaleIntensityRangePercentilesd( + keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8 + ) + assert_allclose(p(expected), scaler(data)["img"], rtol=1e-4) def test_relative_scaling(self): img = self.imt @@ -55,7 +55,7 @@ def test_relative_scaling(self): expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min) expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min - self.assertTrue(np.allclose(expected_img, scaler(data)["img"])) + np.testing.assert_allclose(expected_img, scaler(data)["img"]) def test_invalid_instantiation(self): self.assertRaises( @@ -70,6 +70,29 @@ def test_invalid_instantiation(self): self.assertRaises( ValueError, ScaleIntensityRangePercentilesd, keys=["img"], lower=30, upper=1000, b_min=0, b_max=255 ) + with self.assertRaises(ValueError): + s = ScaleIntensityRangePercentilesd(keys=["img"], lower=30, upper=90, b_min=None, b_max=20, relative=True) + s(self.imt) + + def test_channel_wise(self): + img = np.tile(self.imt, (3, 1, 1, 1)) + lower = 10 + upper = 99 + b_min = 0 + b_max = 255 + scaler = ScaleIntensityRangePercentilesd( + keys="img", lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8 + ) + expected = [] + for c in img: + a_min = np.percentile(c, lower) + a_max = np.percentile(c, upper) + expected.append((((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8)) + expected = np.stack(expected) + + for p in TEST_NDARRAYS: + data = {"img": p(img)} + assert_allclose(scaler(data)["img"], p(expected), rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_scale_intensity_ranged.py b/tests/test_scale_intensity_ranged.py index a8cac414e8..ffbd3e44c4 100644 --- a/tests/test_scale_intensity_ranged.py +++ b/tests/test_scale_intensity_ranged.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,20 +11,27 @@ import unittest -import numpy as np - from monai.transforms import ScaleIntensityRanged -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class IntensityScaleIntensityRanged(NumpyImageTestCase2D): def test_image_scale_intensity_ranged(self): key = "img" scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=50, b_max=80) - scaled = scaler({key: self.imt}) - expected = (self.imt - 20) / 88 - expected = expected * 30 + 50 - self.assertTrue(np.allclose(scaled[key], expected)) + for p in TEST_NDARRAYS: + scaled = scaler({key: p(self.imt)}) + expected = (self.imt - 20) / 88 + expected = expected * 30 + 50 + assert_allclose(scaled[key], p(expected)) + + def test_image_scale_intensity_ranged_none(self): + key = "img" + scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=None, b_max=None) + for p in TEST_NDARRAYS: + scaled = scaler({key: p(self.imt)}) + expected = (self.imt - 20) / 88 + assert_allclose(scaled[key], p(expected)) if __name__ == "__main__": diff --git a/tests/test_scale_intensityd.py b/tests/test_scale_intensityd.py index 6e13dbc272..42f1527490 100644 --- a/tests/test_scale_intensityd.py +++ b/tests/test_scale_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,23 +19,36 @@ class TestScaleIntensityd(NumpyImageTestCase2D): def test_range_scale(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0) result = scaler({key: p(self.imt)}) mina = np.min(self.imt) maxa = np.max(self.imt) norm = (self.imt - mina) / (maxa - mina) expected = (norm * (2.0 - 1.0)) + 1.0 - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) def test_factor_scale(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1) result = scaler({key: p(self.imt)}) expected = (self.imt * (1 + 0.1)).astype(np.float32) - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) + + def test_channel_wise(self): + key = "img" + for p in TEST_NDARRAYS: + scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0, channel_wise=True) + data = p(self.imt) + result = scaler({key: data}) + mina = self.imt.min() + maxa = self.imt.max() + for i, c in enumerate(data): + norm = (c - mina) / (maxa - mina) + expected = p((norm * (2.0 - 1.0)) + 1.0) + assert_allclose(result[key][i], expected, type_test=False, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_se_block.py b/tests/test_se_block.py index 1f515a7fb4..88983a7746 100644 --- a/tests/test_se_block.py +++ b/tests/test_se_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_se_blocks.py b/tests/test_se_blocks.py index e9aed7d9d9..400ee85e7f 100644 --- a/tests/test_se_blocks.py +++ b/tests/test_se_blocks.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_seg_loss_integration.py b/tests/test_seg_loss_integration.py index d2f991f160..f4a0f25267 100644 --- a/tests/test_seg_loss_integration.py +++ b/tests/test_seg_loss_integration.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -91,7 +91,7 @@ def test_convergence(self, loss_type, loss_args, forward_args): # define a one layer model class OnelayerNet(nn.Module): def __init__(self): - super(OnelayerNet, self).__init__() + super().__init__() self.layer_1 = nn.Linear(num_voxels, 200) self.acti = nn.ReLU() self.layer_2 = nn.Linear(200, num_voxels * num_classes) diff --git a/tests/test_segresnet.py b/tests/test_segresnet.py index ea6ca5b5dd..b7c37f87b9 100644 --- a/tests/test_segresnet.py +++ b/tests/test_segresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -70,6 +70,7 @@ "init_filters": init_filters, "out_channels": out_channels, "upsample_mode": upsample_mode, + "act": ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), "input_image_size": ([16] * spatial_dims), "vae_estimate_std": vae_estimate_std, }, @@ -88,7 +89,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): SegResNet(spatial_dims=4) def test_script(self): diff --git a/tests/test_segresnet_block.py b/tests/test_segresnet_block.py index eb8cc9676b..9bb435ac1d 100644 --- a/tests/test_segresnet_block.py +++ b/tests/test_segresnet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_select_cross_validation_folds.py b/tests/test_select_cross_validation_folds.py index 6dbd004e71..7693baca80 100644 --- a/tests/test_select_cross_validation_folds.py +++ b/tests/test_select_cross_validation_folds.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_select_itemsd.py b/tests/test_select_itemsd.py index bf63864eb0..ba75a27cff 100644 --- a/tests/test_select_itemsd.py +++ b/tests/test_select_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 3d561aac2f..407bee341c 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,11 +28,7 @@ for num_heads in [4, 6, 8, 12]: test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - }, + {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate}, (2, 512, hidden_size), (2, 512, hidden_size), ] diff --git a/tests/test_senet.py b/tests/test_senet.py index 1c6222d6a0..57ca49237d 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest from typing import TYPE_CHECKING from unittest import skipUnless @@ -16,10 +17,11 @@ import torch from parameterized import parameterized +import monai.networks.nets.senet as se_mod from monai.networks import eval_mode from monai.networks.nets import SENet154, SEResNet50, SEResNet101, SEResNet152, SEResNext50, SEResNext101 from monai.utils import optional_import -from tests.utils import test_pretrained_networks, test_script_save +from tests.utils import test_is_quick, test_pretrained_networks, test_script_save, testing_data_config if TYPE_CHECKING: import pretrainedmodels @@ -31,6 +33,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu" + NET_ARGS = {"spatial_dims": 3, "in_channels": 2, "num_classes": 2} TEST_CASE_1 = [SENet154, NET_ARGS] TEST_CASE_2 = [SEResNet50, NET_ARGS] @@ -60,6 +63,43 @@ def test_script(self, net, net_args): class TestPretrainedSENET(unittest.TestCase): + def setUp(self): + self.original_urls = se_mod.SE_NET_MODELS.copy() + if test_is_quick(): + testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + testing_data_urls = { + "senet154": { + "url": testing_data_config("models", "senet154-c7b49a05", "url"), + "filename": "senet154-c7b49a05.pth", + }, + "se_resnet50": { + "url": testing_data_config("models", "se_resnet50-ce0d4300", "url"), + "filename": "se_resnet50-ce0d4300.pth", + }, + "se_resnet101": { + "url": testing_data_config("models", "se_resnet101-7e38fcc6", "url"), + "filename": "se_resnet101-7e38fcc6.pth", + }, + "se_resnet152": { + "url": testing_data_config("models", "se_resnet152-d17c99b7", "url"), + "filename": "se_resnet152-d17c99b7.pth", + }, + "se_resnext50_32x4d": { + "url": testing_data_config("models", "se_resnext50_32x4d-a260b3a4", "url"), + "filename": "se_resnext50_32x4d-a260b3a4.pth", + }, + "se_resnext101_32x4d": { + "url": testing_data_config("models", "se_resnext101_32x4d-3b2fe3d8", "url"), + "filename": "se_resnext101_32x4d-3b2fe3d8.pth", + }, + } + for item in testing_data_urls: + testing_data_urls[item]["filename"] = os.path.join(testing_dir, testing_data_urls[item]["filename"]) + se_mod.SE_NET_MODELS = testing_data_urls + + def tearDown(self): + se_mod.SE_NET_MODELS = self.original_urls.copy() + @parameterized.expand([TEST_CASE_PRETRAINED_1]) def test_senet_shape(self, model, input_param): net = test_pretrained_networks(model, input_param, device) diff --git a/tests/test_separable_filter.py b/tests/test_separable_filter.py new file mode 100644 index 0000000000..e152ad2c2b --- /dev/null +++ b/tests/test_separable_filter.py @@ -0,0 +1,85 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from monai.networks.layers import separable_filtering + + +class SeparableFilterTestCase(unittest.TestCase): + def test_1d(self): + a = torch.tensor([[list(range(10))]], dtype=torch.float) + out = separable_filtering(a, torch.tensor([-1, 0, 1])) + expected = np.array([[[1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, -8.0]]]) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + if torch.cuda.is_available(): + out = separable_filtering(a.cuda(), torch.tensor([-1, 0, 1]).cuda()) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + + def test_2d(self): + a = torch.tensor([[[list(range(7)), list(range(7, 0, -1)), list(range(7))]]], dtype=torch.float) + expected = np.array( + [ + [28.0, 28.0, 28.0, 28.0, 28.0, 28.0], + [30.0, 34.0, 38.0, 42.0, 46.0, 50.0], + [28.0, 28.0, 28.0, 28.0, 28.0, 28.0], + ] + ) + expected = expected[None][None] + out = separable_filtering(a, [torch.tensor([1, 1, 1]), torch.tensor([2, 2])]) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + if torch.cuda.is_available(): + out = separable_filtering(a.cuda(), [torch.tensor([1, 1, 1]).cuda(), torch.tensor([2, 2]).cuda()]) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + + def test_3d(self): + a = torch.tensor( + [[list(range(7)), list(range(7)), list(range(7))], [list(range(7)), list(range(7)), list(range(7))]], + dtype=torch.float, + ) + a = a[None][None] + a = a.expand(2, 3, -1, -1, -1) + expected = np.array( + [ + [ + [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0], + [6.0, 18.0, 36.0, 54.0, 72.0, 90.0, 66.0], + [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0], + ], + [ + [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0], + [6.0, 18.0, 36.0, 54.0, 72.0, 90.0, 66.0], + [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0], + ], + ] + ) + expected = expected + # testing shapes + k = torch.tensor([1, 1, 1]) + for kernel in (k, [k] * 3): + out = separable_filtering(a, kernel) + np.testing.assert_allclose(out.cpu().numpy()[1][2], expected, rtol=1e-4) + if torch.cuda.is_available(): + out = separable_filtering( + a.cuda(), kernel.cuda() if isinstance(kernel, torch.Tensor) else [k.cuda() for k in kernel] + ) + np.testing.assert_allclose(out.cpu().numpy()[0][1], expected, rtol=1e-4) + + def test_wrong_args(self): + with self.assertRaisesRegex(TypeError, ""): + separable_filtering(((1, 1, 1, 2, 3, 2)), torch.ones((2,))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_set_determinism.py b/tests/test_set_determinism.py index 537aa36676..7d6c54909d 100644 --- a/tests/test_set_determinism.py +++ b/tests/test_set_determinism.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -40,6 +40,8 @@ def test_values(self): self.assertEqual(seed, get_seed()) a = np.random.randint(seed) b = torch.randint(seed, (1,)) + # tset when global flag support is disabled + torch.backends.disable_global_flags() set_determinism(seed=seed) c = np.random.randint(seed) d = torch.randint(seed, (1,)) diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py index b6da879f4b..75cbd6fb0d 100644 --- a/tests/test_set_visible_devices.py +++ b/tests/test_set_visible_devices.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_shift_intensity.py b/tests/test_shift_intensity.py index b73c18b6a5..ecded268ab 100644 --- a/tests/test_shift_intensity.py +++ b/tests/test_shift_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_shift_intensityd.py b/tests/test_shift_intensityd.py index 0396857781..e28b7f54e4 100644 --- a/tests/test_shift_intensityd.py +++ b/tests/test_shift_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import numpy as np from monai.transforms import IntensityStatsd, ShiftIntensityd +from monai.utils.enums import PostFix from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose @@ -24,13 +25,13 @@ def test_value(self): shifter = ShiftIntensityd(keys=[key], offset=1.0) result = shifter({key: p(self.imt)}) expected = self.imt + 1.0 - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) def test_factor(self): key = "img" stats = IntensityStatsd(keys=key, ops="max", key_prefix="orig") shifter = ShiftIntensityd(keys=[key], offset=1.0, factor_key=["orig_max"]) - data = {key: self.imt, key + "_meta_dict": {"affine": None}} + data = {key: self.imt, PostFix.meta(key): {"affine": None}} result = shifter(stats(data)) expected = self.imt + 1.0 * np.nanmax(self.imt) diff --git a/tests/test_simple_aspp.py b/tests/test_simple_aspp.py index fbc8cb37d1..9c952bd791 100644 --- a/tests/test_simple_aspp.py +++ b/tests/test_simple_aspp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_simulatedelay.py b/tests/test_simulatedelay.py index 3a4686218e..3a0507dae7 100644 --- a/tests/test_simulatedelay.py +++ b/tests/test_simulatedelay.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,7 +24,7 @@ class TestSimulateDelay(NumpyImageTestCase2D): def test_value(self, delay_test_time: float): resize = SimulateDelay(delay_time=delay_test_time) start: float = time.time() - result = resize(self.imt[0]) + _ = resize(self.imt[0]) stop: float = time.time() measured_approximate: float = stop - start np.testing.assert_allclose(delay_test_time, measured_approximate, rtol=0.5) diff --git a/tests/test_simulatedelayd.py b/tests/test_simulatedelayd.py index 58bd3eb6b8..cbabb68e0f 100644 --- a/tests/test_simulatedelayd.py +++ b/tests/test_simulatedelayd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_skip_connection.py b/tests/test_skip_connection.py index 2118842ed0..f523891084 100644 --- a/tests/test_skip_connection.py +++ b/tests/test_skip_connection.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,11 +24,7 @@ result_shape = (input_shape[0] * 2, *input_shape[1:]) else: result_shape = input_shape - test_case = [ - {"dim": 0, "mode": type_1}, - input_shape, - result_shape, - ] + test_case = [{"dim": 0, "mode": type_1}, input_shape, result_shape] TEST_CASES_3D.append(test_case) diff --git a/tests/test_slice_inferer.py b/tests/test_slice_inferer.py new file mode 100644 index 0000000000..3c52082d85 --- /dev/null +++ b/tests/test_slice_inferer.py @@ -0,0 +1,51 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import SliceInferer +from monai.networks.nets import UNet + +TEST_CASES = ["0", "1", "2"] + + +class TestSliceInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, spatial_dim): + spatial_dim = int(spatial_dim) + + model = UNet( + spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2), num_res_units=2 + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + + # Initialize a dummy 3D tensor volume with shape (N,C,D,H,W) + input_volume = torch.ones(1, 1, 64, 256, 256, device=device) + + # Remove spatial dim to slide across from the roi_size + roi_size = list(input_volume.shape[2:]) + roi_size.pop(spatial_dim) + + # Initialize and run inferer + inferer = SliceInferer(roi_size=roi_size, spatial_dim=spatial_dim, sw_batch_size=1, cval=-1) + result = inferer(input_volume, model) + + self.assertTupleEqual(result.shape, input_volume.shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index a22e5990bf..5b6995c1ea 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,8 +16,11 @@ from parameterized import parameterized from monai.inferers import SlidingWindowInferer, sliding_window_inference +from monai.utils import optional_import from tests.utils import skip_if_no_cuda +_, has_tqdm = optional_import("tqdm") + TEST_CASES = [ [(2, 3, 16), (4,), 3, 0.25, "constant", torch.device("cpu:0")], # 1D small roi [(2, 3, 16, 15, 7, 9), 4, 3, 0.25, "constant", torch.device("cpu:0")], # 4D small roi @@ -33,14 +36,7 @@ [(1, 3, 16, 7), (80, 50), 7, 0.5, "gaussian", torch.device("cpu:0")], # 2D large overlap, gaussian [(1, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, "gaussian", torch.device("cpu:0")], # 3D small roi, gaussian [(3, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, "gaussian", torch.device("cpu:0")], # 3D small roi, gaussian - [ - (1, 3, 16, 15, 7), - (4, 10, 7), - 3, - 0.25, - "gaussian", - torch.device("cuda:0"), - ], # test inference on gpu if availabe + [(1, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, "gaussian", torch.device("cuda:0")], # test inference on gpu if availabe [(1, 3, 16, 15, 7), (4, 1, 7), 3, 0.25, "constant", torch.device("cpu:0")], # 3D small roi [(5, 3, 16, 15, 7), (4, 1, 7), 3, 0.25, "constant", torch.device("cpu:0")], # 3D small roi ] @@ -152,6 +148,7 @@ def compute(self, data): cval=-1, mode="gaussian", sigma_scale=1.0, + progress=has_tqdm, ) expected = np.array( [ @@ -229,15 +226,16 @@ def compute(data, test1, test2): 0.0, device, device, + has_tqdm, t1, test2=t2, ) expected = np.ones((1, 1, 3, 3)) + 2.0 np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) - result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1)( - inputs, compute, t1, test2=t2 - ) + result = SlidingWindowInferer( + roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1, progress=has_tqdm + )(inputs, compute, t1, test2=t2) np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) diff --git a/tests/test_smartcache_patch_wsi_dataset.py b/tests/test_smartcache_patch_wsi_dataset.py index c484e5fc69..e2150edce5 100644 --- a/tests/test_smartcache_patch_wsi_dataset.py +++ b/tests/test_smartcache_patch_wsi_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,13 +18,16 @@ from parameterized import parameterized from monai.apps.pathology.data import SmartCachePatchWSIDataset -from monai.apps.utils import download_url from monai.utils import optional_import +from tests.utils import download_url_or_skip_test, testing_data_config -_, has_cim = optional_import("cucim") +_cucim, has_cim = optional_import("cucim") +has_cim = has_cim and hasattr(_cucim, "CuImage") -FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" -FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + os.path.basename(FILE_URL)) +FILE_KEY = "wsi_img" +FILE_URL = testing_data_config("images", FILE_KEY, "url") +base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension) TEST_CASE_0 = [ { @@ -43,6 +46,7 @@ "cache_num": 2, "num_init_workers": 1, "num_replace_workers": 1, + "copy_cache": False, }, [ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0]]])}, @@ -131,15 +135,11 @@ class TestSmartCachePatchWSIDataset(unittest.TestCase): def setUp(self): - download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") + hash_type = testing_data_config("images", FILE_KEY, "hash_type") + hash_val = testing_data_config("images", FILE_KEY, "hash_val") + download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val) - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) @skipUnless(has_cim, "Requires CuCIM") def test_read_patches(self, input_parameters, expected): dataset = SmartCachePatchWSIDataset(**input_parameters) diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index e2675f4d8c..e7d51be63a 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -174,13 +174,7 @@ def test_datalist(self): data_list = [np.array([i]) for i in range(5)] data_list_backup = copy.copy(data_list) - SmartCacheDataset( - data=data_list, - transform=None, - cache_rate=0.5, - replace_rate=0.4, - shuffle=True, - ) + SmartCacheDataset(data=data_list, transform=None, cache_rate=0.5, replace_rate=0.4, shuffle=True) np.testing.assert_allclose(data_list, data_list_backup) diff --git a/tests/test_smooth_field.py b/tests/test_smooth_field.py new file mode 100644 index 0000000000..5849b96167 --- /dev/null +++ b/tests/test_smooth_field.py @@ -0,0 +1,143 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from itertools import product + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandSmoothDeformd, RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env + +_rtol = 5e-3 if is_tf32_env() else 1e-4 + +INPUT_SHAPES = ((1, 8, 8), (2, 8, 8), (1, 8, 8, 8)) + +TESTS_CONTRAST = [] +TESTS_INTENSITY = [] +TESTS_DEFORM = [] + +KEY = "test" + +for arr_type, shape in product(TEST_NDARRAYS, INPUT_SHAPES): + in_arr = arr_type(np.ones(shape, np.float32)) + exp_arr = arr_type(np.ones(shape, np.float32)) + rand_size = (4,) * (len(shape) - 1) + + device = torch.device("cpu") + + if isinstance(in_arr, torch.Tensor) and in_arr.get_device() >= 0: + device = torch.device(in_arr.get_device()) + + TESTS_CONTRAST.append( + ( + {"keys": (KEY,), "spatial_size": shape[1:], "rand_size": rand_size, "prob": 1.0, "device": device}, + {KEY: in_arr}, + {KEY: exp_arr}, + ) + ) + + TESTS_INTENSITY.append( + ( + { + "keys": (KEY,), + "spatial_size": shape[1:], + "rand_size": rand_size, + "prob": 1.0, + "device": device, + "gamma": (0.9, 1), + }, + {KEY: in_arr}, + {KEY: exp_arr}, + ) + ) + + TESTS_DEFORM.append( + ( + { + "keys": (KEY,), + "spatial_size": shape[1:], + "rand_size": rand_size, + "prob": 1.0, + "device": device, + "def_range": 0.1, + }, + {KEY: in_arr}, + {KEY: exp_arr}, + ) + ) + + +class TestSmoothField(unittest.TestCase): + @parameterized.expand(TESTS_CONTRAST) + def test_rand_smooth_field_adjust_contrastd(self, input_param, input_data, expected_val): + g = RandSmoothFieldAdjustContrastd(**input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + def test_rand_smooth_field_adjust_contrastd_pad(self): + input_param, input_data, expected_val = TESTS_CONTRAST[0] + + g = RandSmoothFieldAdjustContrastd(pad=1, **input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + @parameterized.expand(TESTS_INTENSITY) + def test_rand_smooth_field_adjust_intensityd(self, input_param, input_data, expected_val): + g = RandSmoothFieldAdjustIntensityd(**input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + def test_rand_smooth_field_adjust_intensityd_pad(self): + input_param, input_data, expected_val = TESTS_INTENSITY[0] + + g = RandSmoothFieldAdjustIntensityd(pad=1, **input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + @parameterized.expand(TESTS_DEFORM) + def test_rand_smooth_deformd(self, input_param, input_data, expected_val): + g = RandSmoothDeformd(**input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + def test_rand_smooth_deformd_pad(self): + input_param, input_data, expected_val = TESTS_DEFORM[0] + + g = RandSmoothDeformd(pad=1, **input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 6be6730c5a..80df981b73 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,155 +12,214 @@ import unittest import numpy as np +import torch from parameterized import parameterized +from monai.data.utils import affine_to_spacing from monai.transforms import Spacing from monai.utils import ensure_tuple, fall_back_tuple +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"pixdim": (1.0, 1.5), "padding_mode": "zeros", "dtype": float}, - np.arange(4).reshape((1, 2, 2)) + 1.0, # data - {"affine": np.eye(4)}, - np.array([[[1.0, 1.0], [3.0, 2.0]]]), - ], - [ - {"pixdim": 1.0, "padding_mode": "zeros", "dtype": float}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.eye(4)}, - np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), - ], - [ - {"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.eye(4)}, - np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), - ], - [ - {"pixdim": (1.0, 0.2, 1.5), "diagonal": False, "padding_mode": "zeros", "align_corners": True}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.array([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]])}, - np.array([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]), - ], - [ - {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, - np.arange(24).reshape((2, 3, 4)), # data - {"affine": np.diag([-3.0, 0.2, 1.5, 1])}, - np.array([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]), - ], - [ - {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, - np.arange(24).reshape((2, 3, 4)), # data - {}, - np.array([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]), - ], - [ - {"pixdim": (1.0, 1.0)}, - np.arange(24).reshape((2, 3, 4)), # data - {}, - np.array( - [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] - ), - ], - [ - {"pixdim": (4.0, 5.0, 6.0)}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]])}, - np.arange(24).reshape((1, 2, 3, 4)), # data - ], - [ - {"pixdim": (4.0, 5.0, 6.0), "diagonal": True}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, - np.array( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - ], - [ - {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, - np.array( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - ], - [ - {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, - np.array( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - ], - [ - {"pixdim": (1.9, 4.0), "padding_mode": "zeros", "diagonal": True}, - np.arange(24).reshape((1, 4, 6)), # data - {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, - np.array( - [ +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"pixdim": (1.0, 1.5), "padding_mode": "zeros", "dtype": float}, + np.arange(4).reshape((1, 2, 2)) + 1.0, # data + {"affine": np.eye(4)}, + np.array([[[1.0, 1.0], [3.0, 2.0]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": 1.0, "padding_mode": "zeros", "dtype": float}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.eye(4)}, + np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.eye(4)}, + np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.0, 0.2, 1.5), "diagonal": False, "padding_mode": "zeros", "align_corners": True}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.array([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]])}, + np.array([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, + np.arange(24).reshape((2, 3, 4)), # data + {"affine": np.diag([-3.0, 0.2, 1.5, 1])}, + np.array([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, + np.arange(24).reshape((2, 3, 4)), # data + {}, + np.array([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.0, 1.0), "align_corners": True}, + np.arange(24).reshape((2, 3, 4)), # data + {}, + np.array( + [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0)}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]])}, + np.arange(24).reshape((1, 2, 3, 4)), # data + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0), "diagonal": True}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, + np.array( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, + np.array( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, + np.array( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.9, 4.0), "padding_mode": "zeros", "diagonal": True}, + np.arange(24).reshape((1, 4, 6)), # data + {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, + np.array( [ - [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], - [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], - [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], - [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], - [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], - [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], - [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0], + [ + [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], + [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], + [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], + [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], + [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], + [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], + [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0], + ] ] - ] - ), - ], - [ - {"pixdim": (5.0, 3.0), "padding_mode": "border", "diagonal": True, "dtype": np.float32}, - np.arange(24).reshape((1, 4, 6)), # data - {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, - np.array( - [ + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (5.0, 3.0), "padding_mode": "border", "diagonal": True, "dtype": np.float32}, + np.arange(24).reshape((1, 4, 6)), # data + {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, + np.array( [ - [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8], - [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3], - [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8], + [ + [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8], + [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3], + [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8], + ] ] - ] - ), - ], - [ - {"pixdim": (5.0, 3.0), "padding_mode": "zeros", "diagonal": True, "dtype": np.float32}, - np.arange(24).reshape((1, 4, 6)), # data - {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, - np.array( - [ + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (5.0, 3.0), "padding_mode": "zeros", "diagonal": True, "dtype": np.float32}, + np.arange(24).reshape((1, 4, 6)), # data + {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, + np.array( [ - [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000], - [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000], - [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000], + [ + [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000], + [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000], + [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000], + ] ] - ] - ), - ], - [ - {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.eye(4)}, - np.array([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]), - ], -] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.eye(4)}, + np.array([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]), + ] + ) + TESTS.append( # 5D input + [ + p, + {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float, "align_corners": True}, + np.ones((1, 2, 2, 2, 1)), # data + {"affine": np.eye(4)}, + np.ones((1, 2, 2, 3, 1)), + ] + ) class TestSpacingCase(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_spacing(self, init_param, img, data_param, expected_output): - res = Spacing(**init_param)(img, **data_param) - if not isinstance(res, tuple): - np.testing.assert_allclose(res, expected_output, atol=1e-6) - return - np.testing.assert_allclose(res[0], expected_output, atol=1e-6) - sr = len(res[0].shape) - 1 + @parameterized.expand(TESTS) + def test_spacing(self, in_type, init_param, img, data_param, expected_output): + _img = in_type(img) + output_data, _, new_affine = Spacing(**init_param)(_img, **data_param) + if isinstance(_img, torch.Tensor): + self.assertEqual(_img.device, output_data.device) + output_data = output_data.cpu() + + np.testing.assert_allclose(output_data, expected_output, atol=1e-1, rtol=1e-1) + sr = min(len(output_data.shape) - 1, 3) if isinstance(init_param["pixdim"], float): init_param["pixdim"] = [init_param["pixdim"]] * sr init_pixdim = ensure_tuple(init_param["pixdim"]) init_pixdim = init_param["pixdim"][:sr] - norm = np.sqrt(np.sum(np.square(res[2]), axis=0))[:sr] + norm = affine_to_spacing(new_affine, sr) np.testing.assert_allclose(fall_back_tuple(init_pixdim, norm), norm) diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index 61a4a4c38b..060d908699 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,82 +10,89 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np +import torch +from parameterized import parameterized from monai.transforms import Spacingd +from monai.utils.enums import PostFix +from tests.utils import TEST_NDARRAYS, assert_allclose - -class TestSpacingDCase(unittest.TestCase): - def test_spacingd_3d(self): - data = {"image": np.ones((2, 10, 15, 20)), "image_meta_dict": {"affine": np.eye(4)}} - spacing = Spacingd(keys="image", pixdim=(1, 2, 1.4)) - res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) - np.testing.assert_allclose(res["image"].shape, (2, 10, 8, 15)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag([1, 2, 1.4, 1.0])) - - def test_spacingd_2d(self): - data = {"image": np.ones((2, 10, 20)), "image_meta_dict": {"affine": np.eye(3)}} - spacing = Spacingd(keys="image", pixdim=(1, 2)) - res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) - np.testing.assert_allclose(res["image"].shape, (2, 10, 10)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 2, 1))) - - def test_spacingd_2d_no_metadata(self): - data = {"image": np.ones((2, 10, 20))} - spacing = Spacingd(keys="image", pixdim=(1, 2)) - res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) - np.testing.assert_allclose(res["image"].shape, (2, 10, 10)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 2, 1))) - - def test_interp_all(self): - data = { - "image": np.arange(20).reshape((2, 1, 10)), - "seg": np.ones((2, 1, 10)), - "image_meta_dict": {"affine": np.eye(4)}, - "seg_meta_dict": {"affine": np.eye(4)}, - } - spacing = Spacingd( - keys=("image", "seg"), - mode="nearest", - pixdim=( - 1, - 0.2, - ), +TESTS: List[Tuple] = [] +for p in TEST_NDARRAYS: + TESTS.append( + ( + "spacing 3d", + {"image": p(np.ones((2, 10, 15, 20))), PostFix.meta("image"): {"affine": p(np.eye(4))}}, + dict(keys="image", pixdim=(1, 2, 1.4)), + ("image", PostFix.meta("image"), "image_transforms"), + (2, 10, 8, 15), + p(np.diag([1, 2, 1.4, 1.0])), ) - res = spacing(data) - self.assertEqual( - ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), - tuple(sorted(res)), + ) + TESTS.append( + ( + "spacing 2d", + {"image": np.ones((2, 10, 20)), PostFix.meta("image"): {"affine": np.eye(3)}}, + dict(keys="image", pixdim=(1, 2)), + ("image", PostFix.meta("image"), "image_transforms"), + (2, 10, 10), + np.diag((1, 2, 1)), ) - np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) - - def test_interp_sep(self): - data = { - "image": np.ones((2, 1, 10)), - "seg": np.ones((2, 1, 10)), - "image_meta_dict": {"affine": np.eye(4)}, - "seg_meta_dict": {"affine": np.eye(4)}, - } - spacing = Spacingd( - keys=("image", "seg"), - mode=("bilinear", "nearest"), - pixdim=( - 1, - 0.2, - ), + ) + TESTS.append( + ( + "spacing 2d no metadata", + {"image": np.ones((2, 10, 20))}, + dict(keys="image", pixdim=(1, 2)), + ("image", PostFix.meta("image"), "image_transforms"), + (2, 10, 10), + np.diag((1, 2, 1)), + ) + ) + TESTS.append( + ( + "interp all", + { + "image": np.arange(20).reshape((2, 1, 10)), + "seg": np.ones((2, 1, 10)), + PostFix.meta("image"): {"affine": np.eye(4)}, + PostFix.meta("seg"): {"affine": np.eye(4)}, + }, + dict(keys=("image", "seg"), mode="nearest", pixdim=(1, 0.2)), + ("image", PostFix.meta("image"), "image_transforms", "seg", PostFix.meta("seg"), "seg_transforms"), + (2, 1, 46), + np.diag((1, 0.2, 1, 1)), ) - res = spacing(data) - self.assertEqual( - ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), - tuple(sorted(res)), + ) + TESTS.append( + ( + "interp sep", + { + "image": np.ones((2, 1, 10)), + "seg": np.ones((2, 1, 10)), + PostFix.meta("image"): {"affine": np.eye(4)}, + PostFix.meta("seg"): {"affine": np.eye(4)}, + }, + dict(keys=("image", "seg"), mode=("bilinear", "nearest"), pixdim=(1, 0.2)), + ("image", PostFix.meta("image"), "image_transforms", "seg", PostFix.meta("seg"), "seg_transforms"), + (2, 1, 46), + np.diag((1, 0.2, 1, 1)), ) - np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) + ) + + +class TestSpacingDCase(unittest.TestCase): + @parameterized.expand(TESTS) + def test_spacingd(self, _, data, kw_args, expected_keys, expected_shape, expected_affine): + res = Spacingd(**kw_args)(data) + if isinstance(data["image"], torch.Tensor): + self.assertEqual(data["image"].device, res["image"].device) + self.assertEqual(expected_keys, tuple(sorted(res))) + np.testing.assert_allclose(res["image"].shape, expected_shape) + assert_allclose(res[PostFix.meta("image")]["affine"], expected_affine) if __name__ == "__main__": diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index c76915f0a3..bf1eb11491 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,54 +16,41 @@ from parameterized import parameterized from monai.transforms import SpatialCrop +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - {"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, - (3, 3, 3, 3), - (3, 2, 2, 2), - ], +TESTS = [ + [{"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, (3, 1, 1, 1), (3, 1, 1, 1)], [{"roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], [{"roi_start": [0, 0], "roi_end": [2, 2]}, (3, 3, 3, 3), (3, 2, 2, 3)], - [ - {"roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, - (3, 3, 3, 3), - (3, 2, 2, 2), - ], - [ - {"roi_start": [0, 0, 0, 0, 0], "roi_end": [8, 8, 8, 2, 2]}, - (3, 3, 3, 3), - (3, 3, 3, 3), - ], - [ - {"roi_start": [1, 0, 0], "roi_end": [1, 8, 8]}, - (3, 3, 3, 3), - (3, 0, 3, 3), - ], - [ - {"roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, - (3, 3, 3, 3), - (3, 1, 2, 2), - ], + [{"roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"roi_start": [0, 0, 0, 0, 0], "roi_end": [8, 8, 8, 2, 2]}, (3, 3, 3, 3), (3, 3, 3, 3)], + [{"roi_start": [1, 0, 0], "roi_end": [1, 8, 8]}, (3, 3, 3, 3), (3, 0, 3, 3)], + [{"roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, (3, 3, 3, 3), (3, 1, 2, 2)], ] -TEST_ERRORS = [ - [{"roi_slices": [slice(s, e, 2) for s, e in zip([-1, -2, 0], [None, None, 2])]}], -] +TEST_ERRORS = [[{"roi_slices": [slice(s, e, 2) for s, e in zip([-1, -2, 0], [None, None, 2])]}]] class TestSpatialCrop(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_shape, expected_shape): input_data = np.random.randint(0, 2, size=input_shape) - result = SpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) - - @parameterized.expand(TEST_CASES) - def test_tensor_shape(self, input_param, input_shape, expected_shape): - input_data = torch.randint(0, 2, size=input_shape, device="cuda" if torch.cuda.is_available() else "cpu") - result = SpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + input_param_mod = { + k: q(v) if k != "roi_slices" and q is not None else v for k, v in input_param.items() + } + im = p(input_data) + result = SpatialCrop(**input_param_mod)(im) + self.assertEqual(type(im), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im.device) + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + assert_allclose(results[0], results[-1], type_test=False) @parameterized.expand(TEST_ERRORS) def test_error(self, input_param): diff --git a/tests/test_spatial_cropd.py b/tests/test_spatial_cropd.py index 797c25d34b..5b16f460fd 100644 --- a/tests/test_spatial_cropd.py +++ b/tests/test_spatial_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,38 +15,49 @@ from parameterized import parameterized from monai.transforms import SpatialCropd +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), - ], - [ - {"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), - ], - [ - {"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 3), - ], - [ - {"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), - ], - [ - {"keys": ["img"], "roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 1, 2, 2), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 3), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 1, 2, 2), + ] + ) class TestSpatialCropd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = SpatialCropd(**input_param)(input_data) self.assertTupleEqual(result["img"].shape, expected_shape) diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 86d010bbad..4cdeb6d64e 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,44 +17,39 @@ from parameterized import parameterized from monai.transforms import SpatialPad -from monai.utils.enums import NumpyPadMode +from monai.utils.enums import NumpyPadMode, PytorchPadMode from monai.utils.misc import set_determinism from tests.utils import TEST_NDARRAYS TESTS = [] -# Numpy modes -MODES: List = [ +MODES = [] + +# Test modes +NP_MODES: List = [ "constant", "edge", - "linear_ramp", - "maximum", - "mean", - "median", - "minimum", - "reflect", - "symmetric", + # `reflect` mode is not supported in some PyTorch versions, skip the test + # "reflect", "wrap", - "empty", ] -MODES += [NumpyPadMode(i) for i in MODES] +MODES += NP_MODES +MODES += [NumpyPadMode(i) for i in NP_MODES] + +PT_MODES: list = [ + "constant", + "replicate", + "circular", + # `reflect` mode is not supported in some PyTorch versions, skip the test + # "reflect", +] +MODES += PT_MODES +MODES += [PytorchPadMode(i) for i in PT_MODES] for mode in MODES: - TESTS.append( - [ - {"spatial_size": [50, 50], "method": "end", "mode": mode}, - (1, 2, 2), - (1, 50, 50), - ] - ) - - TESTS.append( - [ - {"spatial_size": [15, 4, -1], "method": "symmetric", "mode": mode}, - (3, 8, 8, 4), - (3, 15, 8, 4), - ] - ) + TESTS.append([{"spatial_size": [3, 4], "method": "end", "mode": mode}, (1, 2, 3), (1, 3, 4)]) + + TESTS.append([{"spatial_size": [15, 4, -1], "method": "symmetric", "mode": mode}, (3, 8, 8, 4), (3, 15, 8, 4)]) class TestSpatialPad(unittest.TestCase): @@ -86,14 +81,19 @@ def test_pad_shape(self, input_param, input_shape, expected_shape): torch.testing.assert_allclose(results[0], results[-1], atol=0, rtol=1e-5) def test_pad_kwargs(self): - padder = SpatialPad( - spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) - ) for p in TEST_NDARRAYS: - result = padder(p(np.zeros((3, 8, 4)))) - if isinstance(result, torch.Tensor): - result = result.cpu().numpy() - torch.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0) + input_data = p(np.zeros((3, 8, 4))) + if isinstance(input_data, torch.Tensor): + result = ( + SpatialPad(spatial_size=[15, 8], method="end", mode="constant", value=2)(img=input_data) + .cpu() + .numpy() + ) + else: + result = SpatialPad( + spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) + )(img=input_data) + torch.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0) torch.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1, rtol=1e-7, atol=0) diff --git a/tests/test_spatial_padd.py b/tests/test_spatial_padd.py index 8400bb82cc..762a1145f5 100644 --- a/tests/test_spatial_padd.py +++ b/tests/test_spatial_padd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_spatial_resample.py b/tests/test_spatial_resample.py new file mode 100644 index 0000000000..9ee84de85b --- /dev/null +++ b/tests/test_spatial_resample.py @@ -0,0 +1,146 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.config import USE_COMPILED +from monai.transforms import SpatialResample +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]), # flip the second + np.asarray([[-1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), # flip the first + ] +): + for p in TEST_NDARRAYS: + for p_src in TEST_NDARRAYS: + for align in (False, True): + for interp_mode in ("nearest", "bilinear"): + TESTS.append( + [ + {}, # default no params + np.arange(4).reshape((1, 2, 2)) + 1.0, # data + { + "src_affine": p_src(np.eye(3)), + "dst_affine": p(dst), + "dtype": np.float32, + "align_corners": align, + "mode": interp_mode, + "padding_mode": "zeros", + }, + np.array([[[2.0, 1.0], [4.0, 3.0]]]) if ind == 0 else np.array([[[3.0, 4.0], [1.0, 2.0]]]), + ] + ) + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + np.asarray([[-1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + ] +): + for p_src in TEST_NDARRAYS: + for align in (True, False): + if align and USE_COMPILED: + interp = ("nearest", "bilinear", 0, 1) + else: + interp = ("nearest", "bilinear") # type: ignore + for interp_mode in interp: # type: ignore + for padding_mode in ("zeros", "border", "reflection"): + TESTS.append( + [ + {}, # default no params + np.arange(12).reshape((1, 2, 2, 3)) + 1.0, # data + { + "src_affine": p_src(np.eye(4)), + "dst_affine": p_src(dst), + "dtype": np.float64, + "align_corners": align, + "mode": interp_mode, + "padding_mode": padding_mode, + }, + np.array([[[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]], [[10.0, 11.0, 12.0], [7.0, 8.0, 9.0]]]]) + if ind == 0 + else np.array( + [[[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]] + ), + ] + ) + + +class TestSpatialResample(unittest.TestCase): + @parameterized.expand(itertools.product(TEST_NDARRAYS, TESTS)) + def test_flips(self, p_type, args): + init_param, img, data_param, expected_output = args + _img = p_type(img) + _expected_output = p_type(expected_output) + output_data, output_dst = SpatialResample(**init_param)(img=_img, **data_param) + assert_allclose(output_data, _expected_output, rtol=1e-2, atol=1e-2) + expected_dst = ( + data_param.get("dst_affine") if data_param.get("dst_affine") is not None else data_param.get("src_affine") + ) + assert_allclose(output_dst, expected_dst, type_test=False, rtol=1e-2, atol=1e-2) + + @parameterized.expand(itertools.product([True, False], TEST_NDARRAYS)) + def test_4d_5d(self, is_5d, p_type): + new_shape = (1, 2, 2, 3, 1, 1) if is_5d else (1, 2, 2, 3, 1) + img = np.arange(12).reshape(new_shape) + img = np.tile(img, (1, 1, 1, 1, 2, 2) if is_5d else (1, 1, 1, 1, 2)) + _img = p_type(img) + dst = np.asarray([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]]) + output_data, output_dst = SpatialResample(dtype=np.float32)( + img=_img, src_affine=p_type(np.eye(4)), dst_affine=dst + ) + expected_data = ( + np.asarray( + [ + [ + [[[0.0, 0.0], [0.0, 1.0]], [[0.5, 0.0], [1.5, 1.0]], [[1.0, 2.0], [2.0, 2.0]]], + [[[3.0, 3.0], [3.0, 4.0]], [[3.5, 3.0], [4.5, 4.0]], [[4.0, 5.0], [5.0, 5.0]]], + ], + [ + [[[6.0, 6.0], [6.0, 7.0]], [[6.5, 6.0], [7.5, 7.0]], [[7.0, 8.0], [8.0, 8.0]]], + [[[9.0, 9.0], [9.0, 10.0]], [[9.5, 9.0], [10.5, 10.0]], [[10.0, 11.0], [11.0, 11.0]]], + ], + ], + dtype=np.float32, + ) + if is_5d + else np.asarray( + [ + [[[0.5, 0.0], [0.0, 2.0], [1.5, 1.0]], [[3.5, 3.0], [3.0, 5.0], [4.5, 4.0]]], + [[[6.5, 6.0], [6.0, 8.0], [7.5, 7.0]], [[9.5, 9.0], [9.0, 11.0], [10.5, 10.0]]], + ], + dtype=np.float32, + ) + ) + assert_allclose(output_data, p_type(expected_data[None]), rtol=1e-2, atol=1e-2) + assert_allclose(output_dst, dst, type_test=False, rtol=1e-2, atol=1e-2) + + def test_ill_affine(self): + img = np.arange(12).reshape(1, 2, 2, 3) + ill_affine = np.asarray( + [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]] + ) + with self.assertRaises(ValueError): + SpatialResample()(img=img, src_affine=np.eye(4), dst_affine=ill_affine) + with self.assertRaises(ValueError): + SpatialResample()(img=img, src_affine=ill_affine, dst_affine=np.eye(3)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spatial_resampled.py b/tests/test_spatial_resampled.py new file mode 100644 index 0000000000..73f83791d9 --- /dev/null +++ b/tests/test_spatial_resampled.py @@ -0,0 +1,113 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.config import USE_COMPILED +from monai.transforms import SpatialResampleD +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]), # flip the second + np.asarray([[-1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), # flip the first + ] +): + for p in TEST_NDARRAYS: + for p_src in TEST_NDARRAYS: + for align in (False, True): + for interp_mode in ("nearest", "bilinear"): + TESTS.append( + [ + {}, # default no params + np.arange(4).reshape((1, 2, 2)) + 1.0, # data + { + "src": p_src(np.eye(3)), + "dst": p(dst), + "dtype": np.float32, + "align_corners": align, + "mode": interp_mode, + "padding_mode": "zeros", + }, + np.array([[[2.0, 1.0], [4.0, 3.0]]]) if ind == 0 else np.array([[[3.0, 4.0], [1.0, 2.0]]]), + ] + ) + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + np.asarray([[-1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + ] +): + for p_src in TEST_NDARRAYS: + for align in (True, False): + if align and USE_COMPILED: + interp = ("nearest", "bilinear", 0, 1) + else: + interp = ("nearest", "bilinear") # type: ignore + for interp_mode in interp: # type: ignore + for padding_mode in ("zeros", "border", "reflection"): + TESTS.append( + [ + {}, # default no params + np.arange(12).reshape((1, 2, 2, 3)) + 1.0, # data + { + "src": p_src(np.eye(4)), + "dst": p_src(dst), + "dtype": np.float64, + "align_corners": align, + "mode": interp_mode, + "padding_mode": padding_mode, + }, + np.array([[[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]], [[10.0, 11.0, 12.0], [7.0, 8.0, 9.0]]]]) + if ind == 0 + else np.array( + [[[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]] + ), + ] + ) + + +class TestSpatialResample(unittest.TestCase): + @parameterized.expand(itertools.product(TEST_NDARRAYS, TESTS)) + def test_flips_inverse(self, p_type, args): + _, img, data_param, expected_output = args + _img = p_type(img) + _expected_output = p_type(expected_output) + input_dict = {"img": _img, "img_meta_dict": {"src": data_param.get("src"), "dst": data_param.get("dst")}} + xform = SpatialResampleD( + keys="img", + meta_src_keys="src", + meta_dst_keys="dst", + mode=data_param.get("mode"), + padding_mode=data_param.get("padding_mode"), + align_corners=data_param.get("align_corners"), + ) + output_data = xform(input_dict) + assert_allclose(output_data["img"], _expected_output, rtol=1e-2, atol=1e-2) + assert_allclose( + output_data["img_meta_dict"]["src"], data_param.get("dst"), type_test=False, rtol=1e-2, atol=1e-2 + ) + + inverted = xform.inverse(output_data) + self.assertEqual(inverted["img_transforms"], []) # no further invert after inverting + assert_allclose(inverted["img_meta_dict"]["src"], data_param.get("src"), type_test=False, rtol=1e-2, atol=1e-2) + assert_allclose(inverted["img"], _img, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_split_channel.py b/tests/test_split_channel.py index 38315a102c..75216227e4 100644 --- a/tests/test_split_channel.py +++ b/tests/test_split_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_split_channeld.py b/tests/test_split_channeld.py index f1df24364d..7a34855676 100644 --- a/tests/test_split_channeld.py +++ b/tests/test_split_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -51,13 +51,7 @@ ] ) - TESTS.append( - [ - {"keys": "pred", "channel_dim": 1}, - {"pred": p(np.random.randint(2, size=(3, 2, 4)))}, - (3, 1, 4), - ] - ) + TESTS.append([{"keys": "pred", "channel_dim": 1}, {"pred": p(np.random.randint(2, size=(3, 2, 4)))}, (3, 1, 4)]) class TestSplitChanneld(unittest.TestCase): diff --git a/tests/test_split_on_grid.py b/tests/test_split_on_grid.py index a187835e7b..b1d4cd93c5 100644 --- a/tests/test_split_on_grid.py +++ b/tests/test_split_on_grid.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,11 +11,11 @@ import unittest -import numpy as np import torch from parameterized import parameterized from monai.apps.pathology.transforms import SplitOnGrid +from tests.utils import TEST_NDARRAYS, assert_allclose A11 = torch.randn(3, 2, 2) A12 = torch.randn(3, 2, 2) @@ -26,105 +26,58 @@ A2 = torch.cat([A21, A22], 2) A = torch.cat([A1, A2], 1) -TEST_CASE_0 = [ - {"grid_size": (2, 2)}, - A, - torch.stack([A11, A12, A21, A22]), +TEST_CASE_0 = [{"grid_size": (2, 2)}, A, torch.stack([A11, A12, A21, A22])] +TEST_CASE_1 = [{"grid_size": (2, 1)}, A, torch.stack([A1, A2])] +TEST_CASE_2 = [{"grid_size": (1, 2)}, A1, torch.stack([A11, A12])] +TEST_CASE_3 = [{"grid_size": (1, 2)}, A2, torch.stack([A21, A22])] +TEST_CASE_4 = [{"grid_size": (1, 1), "patch_size": (2, 2)}, A, torch.stack([A11])] +TEST_CASE_5 = [{"grid_size": 1, "patch_size": 4}, A, torch.stack([A])] +TEST_CASE_6 = [{"grid_size": 2, "patch_size": 2}, A, torch.stack([A11, A12, A21, A22])] +TEST_CASE_7 = [{"grid_size": 1}, A, torch.stack([A])] +TEST_CASE_8 = [ + {"grid_size": (2, 2), "patch_size": 2}, + torch.arange(12).reshape(1, 3, 4).to(torch.float32), + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), ] -TEST_CASE_1 = [ - {"grid_size": (2, 1)}, - A, - torch.stack([A1, A2]), -] - -TEST_CASE_2 = [ - {"grid_size": (1, 2)}, - A1, - torch.stack([A11, A12]), -] - -TEST_CASE_3 = [ - {"grid_size": (1, 2)}, - A2, - torch.stack([A21, A22]), -] - -TEST_CASE_4 = [ - {"grid_size": (1, 1), "patch_size": (2, 2)}, - A, - torch.stack([A11]), -] - -TEST_CASE_5 = [ - {"grid_size": 1, "patch_size": 4}, - A, - torch.stack([A]), -] - -TEST_CASE_6 = [ - {"grid_size": 2, "patch_size": 2}, - A, - torch.stack([A11, A12, A21, A22]), -] - -TEST_CASE_7 = [ - {"grid_size": 1}, - A, - torch.stack([A]), -] - -TEST_CASE_MC_0 = [ - {"grid_size": (2, 2)}, - [A, A], - [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], -] - - -TEST_CASE_MC_1 = [ - {"grid_size": (2, 1)}, - [A] * 5, - [torch.stack([A1, A2])] * 5, -] - - -TEST_CASE_MC_2 = [ - {"grid_size": (1, 2)}, - [A1, A2], - [torch.stack([A11, A12]), torch.stack([A21, A22])], -] +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) + +TEST_CASE_MC_0 = [{"grid_size": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]] +TEST_CASE_MC_1 = [{"grid_size": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5] +TEST_CASE_MC_2 = [{"grid_size": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]] + +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) class TestSplitOnGrid(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - ] - ) - def test_split_pathce_single_call(self, input_parameters, img, expected): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, image, expected): + input_image = in_type(image) splitter = SplitOnGrid(**input_parameters) - output = splitter(img) - np.testing.assert_equal(output.numpy(), expected.numpy()) + output = splitter(input_image) + assert_allclose(output, expected, type_test=False) - @parameterized.expand( - [ - TEST_CASE_MC_0, - TEST_CASE_MC_1, - TEST_CASE_MC_2, - ] - ) - def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list): + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): splitter = SplitOnGrid(**input_parameters) - for img, expected in zip(img_list, expected_list): - output = splitter(img) - np.testing.assert_equal(output.numpy(), expected.numpy()) + for image, expected in zip(img_list, expected_list): + input_image = in_type(image) + output = splitter(input_image) + assert_allclose(output, expected, type_test=False) if __name__ == "__main__": diff --git a/tests/test_split_on_grid_dict.py b/tests/test_split_on_grid_dict.py index 96ec095423..778a38da34 100644 --- a/tests/test_split_on_grid_dict.py +++ b/tests/test_split_on_grid_dict.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,11 +11,11 @@ import unittest -import numpy as np import torch from parameterized import parameterized from monai.apps.pathology.transforms import SplitOnGridDict +from tests.utils import TEST_NDARRAYS, assert_allclose A11 = torch.randn(3, 2, 2) A12 = torch.randn(3, 2, 2) @@ -26,105 +26,74 @@ A2 = torch.cat([A21, A22], 2) A = torch.cat([A1, A2], 1) -TEST_CASE_0 = [ - {"keys": "image", "grid_size": (2, 2)}, - {"image": A}, - torch.stack([A11, A12, A21, A22]), -] - -TEST_CASE_1 = [ - {"keys": "image", "grid_size": (2, 1)}, - {"image": A}, - torch.stack([A1, A2]), -] - -TEST_CASE_2 = [ - {"keys": "image", "grid_size": (1, 2)}, - {"image": A1}, - torch.stack([A11, A12]), -] - -TEST_CASE_3 = [ - {"keys": "image", "grid_size": (1, 2)}, - {"image": A2}, - torch.stack([A21, A22]), -] - -TEST_CASE_4 = [ - {"keys": "image", "grid_size": (1, 1), "patch_size": (2, 2)}, - {"image": A}, - torch.stack([A11]), -] - -TEST_CASE_5 = [ - {"keys": "image", "grid_size": 1, "patch_size": 4}, - {"image": A}, - torch.stack([A]), +TEST_CASE_0 = [{"keys": "image", "grid_size": (2, 2)}, {"image": A}, torch.stack([A11, A12, A21, A22])] +TEST_CASE_1 = [{"keys": "image", "grid_size": (2, 1)}, {"image": A}, torch.stack([A1, A2])] +TEST_CASE_2 = [{"keys": "image", "grid_size": (1, 2)}, {"image": A1}, torch.stack([A11, A12])] +TEST_CASE_3 = [{"keys": "image", "grid_size": (1, 2)}, {"image": A2}, torch.stack([A21, A22])] +TEST_CASE_4 = [{"keys": "image", "grid_size": (1, 1), "patch_size": (2, 2)}, {"image": A}, torch.stack([A11])] +TEST_CASE_5 = [{"keys": "image", "grid_size": 1, "patch_size": 4}, {"image": A}, torch.stack([A])] +TEST_CASE_6 = [{"keys": "image", "grid_size": 2, "patch_size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])] +TEST_CASE_7 = [{"keys": "image", "grid_size": 1}, {"image": A}, torch.stack([A])] +TEST_CASE_8 = [ + {"keys": "image", "grid_size": (2, 2), "patch_size": 2}, + {"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)}, + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), ] -TEST_CASE_6 = [ - {"keys": "image", "grid_size": 2, "patch_size": 2}, - {"image": A}, - torch.stack([A11, A12, A21, A22]), -] - -TEST_CASE_7 = [ - {"keys": "image", "grid_size": 1}, - {"image": A}, - torch.stack([A]), -] +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) TEST_CASE_MC_0 = [ {"keys": "image", "grid_size": (2, 2)}, [{"image": A}, {"image": A}], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], ] - - TEST_CASE_MC_1 = [ {"keys": "image", "grid_size": (2, 1)}, - [{"image": A}] * 5, - [torch.stack([A1, A2])] * 5, + [{"image": A}, {"image": A}, {"image": A}], + [torch.stack([A1, A2])] * 3, ] - - TEST_CASE_MC_2 = [ {"keys": "image", "grid_size": (1, 2)}, [{"image": A1}, {"image": A2}], [torch.stack([A11, A12]), torch.stack([A21, A22])], ] +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) + class TestSplitOnGridDict(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - ] - ) - def test_split_pathce_single_call(self, input_parameters, img_dict, expected): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected): + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) splitter = SplitOnGridDict(**input_parameters) - output = splitter(img_dict)[input_parameters["keys"]] - np.testing.assert_equal(output.numpy(), expected.numpy()) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) - @parameterized.expand( - [ - TEST_CASE_MC_0, - TEST_CASE_MC_1, - TEST_CASE_MC_2, - ] - ) - def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list): + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): splitter = SplitOnGridDict(**input_parameters) for img_dict, expected in zip(img_list, expected_list): - output = splitter(img_dict)[input_parameters["keys"]] - np.testing.assert_equal(output.numpy(), expected.numpy()) + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) if __name__ == "__main__": diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py index 15ff7e94d6..8403efe836 100644 --- a/tests/test_squeezedim.py +++ b/tests/test_squeezedim.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_squeezedimd.py b/tests/test_squeezedimd.py index 35e7cd5d74..6baf4696a5 100644 --- a/tests/test_squeezedimd.py +++ b/tests/test_squeezedimd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_state_cacher.py b/tests/test_state_cacher.py index 139e7b8374..e4164be272 100644 --- a/tests/test_state_cacher.py +++ b/tests/test_state_cacher.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pickle import unittest from os.path import exists, join +from pathlib import Path from tempfile import gettempdir import torch @@ -20,20 +22,15 @@ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" -TEST_CASE_0 = [ - torch.Tensor([1]).to(DEVICE), - {"in_memory": True}, -] +TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}] TEST_CASE_1 = [ torch.Tensor([1]).to(DEVICE), - {"in_memory": False, "cache_dir": gettempdir()}, -] -TEST_CASE_2 = [ - torch.Tensor([1]).to(DEVICE), - {"in_memory": False, "allow_overwrite": False}, + {"in_memory": False, "cache_dir": gettempdir(), "pickle_module": None, "pickle_protocol": pickle.HIGHEST_PROTOCOL}, ] +TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}] +TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2] +TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] class TestStateCacher(unittest.TestCase): @@ -44,7 +41,7 @@ def test_state_cacher(self, data_obj, params): state_cacher = StateCacher(**params) # store it - state_cacher.store(key, data_obj) + state_cacher.store(key, data_obj, pickle_module=pickle) # create clone then modify original data_obj_orig = data_obj.clone() data_obj += 1 diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py index 5c16e14c45..55750161ec 100644 --- a/tests/test_std_shift_intensity.py +++ b/tests/test_std_shift_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_std_shift_intensityd.py b/tests/test_std_shift_intensityd.py index 4eb256f1e5..595da5cbc2 100644 --- a/tests/test_std_shift_intensityd.py +++ b/tests/test_std_shift_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_subpixel_upsample.py b/tests/test_subpixel_upsample.py index 07e110d7a7..0216f164c3 100644 --- a/tests/test_subpixel_upsample.py +++ b/tests/test_subpixel_upsample.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,37 +24,30 @@ for dim in range(1, 4): for factor in range(1, 3): test_case = [ - {"dimensions": dim, "in_channels": inch, "scale_factor": factor}, + {"spatial_dims": dim, "in_channels": inch, "scale_factor": factor}, (2, inch, *([8] * dim)), (2, inch, *([8 * factor] * dim)), ] TEST_CASE_SUBPIXEL.append(test_case) TEST_CASE_SUBPIXEL_2D_EXTRA = [ - {"dimensions": 2, "in_channels": 2, "scale_factor": 3}, + {"spatial_dims": 2, "in_channels": 2, "scale_factor": 3}, (2, 2, 8, 4), # different size for H and W (2, 2, 24, 12), ] TEST_CASE_SUBPIXEL_3D_EXTRA = [ - {"dimensions": 3, "in_channels": 1, "scale_factor": 2}, + {"spatial_dims": 3, "in_channels": 1, "scale_factor": 2}, (2, 1, 16, 8, 4), # different size for H, W and D (2, 1, 32, 16, 8), ] conv_block = nn.Sequential( - Conv[Conv.CONV, 3](1, 4, kernel_size=1), - Conv[Conv.CONV, 3]( - 4, - 8, - kernel_size=3, - stride=1, - padding=1, - ), + Conv[Conv.CONV, 3](1, 4, kernel_size=1), Conv[Conv.CONV, 3](4, 8, kernel_size=3, stride=1, padding=1) ) TEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA = [ - {"dimensions": 3, "in_channels": 1, "scale_factor": 2, "conv_block": conv_block}, + {"spatial_dims": 3, "in_channels": 1, "scale_factor": 2, "conv_block": conv_block}, (2, 1, 16, 8, 4), # different size for H, W and D (2, 1, 32, 16, 8), ] diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index e5d2145a1f..edfe9e8663 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,9 +20,7 @@ def create_spherical_seg_3d( - radius: float = 20.0, - centre: Tuple[int, int, int] = (49, 49, 49), - im_shape: Tuple[int, int, int] = (99, 99, 99), + radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -49,10 +47,7 @@ def create_spherical_seg_3d( TEST_CASES = [ - [ - [create_spherical_seg_3d(), create_spherical_seg_3d()], - [0, 0], - ], + [[create_spherical_seg_3d(), create_spherical_seg_3d()], [0, 0]], [ [ create_spherical_seg_3d(radius=20, centre=(20, 20, 20)), @@ -66,14 +61,14 @@ def create_spherical_seg_3d( create_spherical_seg_3d(radius=33, centre=(19, 33, 22)), create_spherical_seg_3d(radius=33, centre=(20, 33, 22)), ], - [0.35021200688332677, 0.3483278807706289], + [0.350217, 0.3483278807706289], ], [ [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), ], - [13.975673696300824, 12.040033513150455], + [15.117741, 12.040033513150455], ], [ [ @@ -81,7 +76,7 @@ def create_spherical_seg_3d( create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), "chessboard", ], - [10.792254295459173, 9.605067064083457], + [11.492719, 9.605067064083457], ], [ [ @@ -89,23 +84,10 @@ def create_spherical_seg_3d( create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), "taxicab", ], - [17.32691760951026, 12.432687531048186], - ], - [ - [ - np.zeros([99, 99, 99]), - create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - ], - [np.inf, np.inf], - ], - [ - [ - create_spherical_seg_3d(), - np.zeros([99, 99, 99]), - "taxicab", - ], - [np.inf, np.inf], + [20.214613, 12.432687531048186], ], + [[np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22))], [np.inf, np.inf]], + [[create_spherical_seg_3d(), np.zeros([99, 99, 99]), "taxicab"], [np.inf, np.inf]], ] TEST_CASES_NANS = [ @@ -114,8 +96,8 @@ def create_spherical_seg_3d( # both pred and gt do not have foreground, metric and not_nans should be 0 np.zeros([99, 99, 99]), np.zeros([99, 99, 99]), - ], - ], + ] + ] ] @@ -139,7 +121,7 @@ def test_value(self, input_data, expected_value): sur_metric(batch_seg_1, batch_seg_2) result = sur_metric.aggregate() expected_value_curr = expected_value[ct] - np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) + np.testing.assert_allclose(expected_value_curr, result, rtol=1e-5) ct += 1 @parameterized.expand(TEST_CASES_NANS) @@ -153,8 +135,8 @@ def test_nans(self, input_data): batch_seg_2 = [seg_2.unsqueeze(0)] sur_metric(batch_seg_1, batch_seg_2) result, not_nans = sur_metric.aggregate() - np.testing.assert_allclose(0, result, rtol=1e-7) - np.testing.assert_allclose(0, not_nans, rtol=1e-7) + np.testing.assert_allclose(0, result, rtol=1e-5) + np.testing.assert_allclose(0, not_nans, rtol=1e-5) if __name__ == "__main__": diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py index 97ab12a588..fadf6255ff 100644 --- a/tests/test_synthetic.py +++ b/tests/test_synthetic.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,29 +18,10 @@ from monai.utils import set_determinism TEST_CASES = [ + [2, {"width": 64, "height": 64, "rad_max": 10, "rad_min": 4}, 0.1479004, 0.739502, (64, 64), 5], [ 2, - { - "width": 64, - "height": 64, - "rad_max": 10, - "rad_min": 4, - }, - 0.1479004, - 0.739502, - (64, 64), - 5, - ], - [ - 2, - { - "width": 32, - "height": 28, - "num_objs": 3, - "rad_max": 5, - "rad_min": 1, - "noise_max": 0.2, - }, + {"width": 32, "height": 28, "num_objs": 3, "rad_max": 5, "rad_min": 1, "noise_max": 0.2}, 0.1709315, 0.4040179, (32, 28), @@ -48,15 +29,7 @@ ], [ 3, - { - "width": 64, - "height": 64, - "depth": 45, - "num_seg_classes": 3, - "channel_dim": -1, - "rad_max": 10, - "rad_min": 4, - }, + {"width": 64, "height": 64, "depth": 45, "num_seg_classes": 3, "channel_dim": -1, "rad_max": 10, "rad_min": 4}, 0.025132, 0.0753961, (64, 64, 45, 1), diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index a07d59703d..21186adc3c 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,10 +21,20 @@ from monai.data.utils import pad_list_data_collate from monai.losses import DiceLoss from monai.networks.nets import UNet -from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, CropForegroundd, DivisiblePadd, RandAffined +from monai.transforms import ( + Activations, + AddChanneld, + AsDiscrete, + Compose, + CropForegroundd, + DivisiblePadd, + RandAffined, + RandScaleIntensityd, +) from monai.transforms.croppad.dictionary import SpatialPadd -from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd, Spacingd +from monai.transforms.spatial.dictionary import RandFlipd, Spacingd from monai.utils import optional_import, set_determinism +from monai.utils.enums import PostFix from tests.utils import TEST_NDARRAYS if TYPE_CHECKING: @@ -46,14 +56,14 @@ def get_data(num_examples, input_size, data_type=np.asarray, include_label=True) create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1 ) data = [] - for _ in range(num_examples): + for i in range(num_examples): im, label = custom_create_test_image_2d() d = {} - d["image"] = data_type(im) - d["image_meta_dict"] = {"affine": np.eye(4)} + d["image"] = data_type(im[:, i:]) + d[PostFix.meta("image")] = {"affine": np.eye(4)} if include_label: - d["label"] = data_type(label) - d["label_meta_dict"] = {"affine": np.eye(4)} + d["label"] = data_type(label[:, i:]) + d[PostFix.meta("label")] = {"affine": np.eye(4)} data.append(d) return data[0] if num_examples == 1 else data @@ -64,12 +74,13 @@ def tearDown(self) -> None: set_determinism(None) def test_test_time_augmentation(self): - input_size = (20, 20) - device = "cuda" if torch.cuda.is_available() else "cpu" + input_size = (20, 40) # test different input data shape to pad list collate keys = ["image", "label"] num_training_ims = 10 + train_data = self.get_data(num_training_ims, input_size) test_data = self.get_data(1, input_size) + device = "cuda" if torch.cuda.is_available() else "cpu" transforms = Compose( [ @@ -113,34 +124,46 @@ def test_test_time_augmentation(self): epoch_loss /= len(train_loader) - post_trans = Compose( - [ - Activations(sigmoid=True), - AsDiscrete(threshold_values=True), - ] + post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) + + tt_aug = TestTimeAugmentation( + transform=transforms, + batch_size=5, + num_workers=0, + inferrer_fn=model, + device=device, + to_tensor=True, + output_device="cpu", + post_func=post_trans, ) - - def inferrer_fn(x): - return post_trans(model(x)) - - tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device) mode, mean, std, vvc = tt_aug(test_data) self.assertEqual(mode.shape, (1,) + input_size) self.assertEqual(mean.shape, (1,) + input_size) self.assertTrue(all(np.unique(mode) == (0, 1))) - self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) + self.assertGreaterEqual(mean.min(), 0.0) + self.assertLessEqual(mean.max(), 1.0) self.assertEqual(std.shape, (1,) + input_size) self.assertIsInstance(vvc, float) - def test_fail_non_random(self): + def test_warn_non_random(self): transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) - with self.assertRaises(RuntimeError): + with self.assertWarns(UserWarning): TestTimeAugmentation(transforms, None, None, None) - def test_fail_random_but_not_invertible(self): - transforms = Compose([AddChanneld("im"), Rand2DElasticd("im", None, None)]) - with self.assertRaises(RuntimeError): - TestTimeAugmentation(transforms, None, None, None) + def test_warn_random_but_has_no_invertible(self): + transforms = Compose( + [AddChanneld("image"), RandFlipd("image", prob=1.0), RandScaleIntensityd("image", 0.1, prob=1.0)] + ) + with self.assertWarns(UserWarning): + tta = TestTimeAugmentation(transforms, 5, 0, orig_key="image") + tta(self.get_data(1, (20, 20), data_type=np.float32)) + + def test_warn_random_but_all_not_invertible(self): + """test with no invertible stack""" + transforms = Compose([AddChanneld("image"), RandScaleIntensityd("image", 0.1, prob=1.0)]) + with self.assertWarns(UserWarning): + tta = TestTimeAugmentation(transforms, 1, 0, orig_key="image") + tta(self.get_data(1, (20, 20), data_type=np.float32)) def test_single_transform(self): for p in TEST_NDARRAYS: @@ -155,7 +178,7 @@ def test_image_no_label(self): @unittest.skipUnless(has_nib, "Requires nibabel") def test_requires_meta_dict(self): - transforms = Compose([RandFlipd("image"), Spacingd("image", pixdim=1.0)]) + transforms = Compose([AddChanneld("image"), RandFlipd("image"), Spacingd("image", pixdim=1.1)]) tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image") tta(self.get_data(1, (20, 20), include_label=False)) diff --git a/tests/test_thread_buffer.py b/tests/test_thread_buffer.py index 507b6909be..04511220f8 100644 --- a/tests/test_thread_buffer.py +++ b/tests/test_thread_buffer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -78,6 +78,22 @@ def test_time(self): f"Buffered time {buffered_time} should be less than unbuffered time {unbuffered_time}", ) + def test_dataloader_repeats(self): + dataset = Dataset(data=self.datalist, transform=self.transform) + dataloader = ThreadDataLoader(dataset=dataset, batch_size=2, num_workers=0, repeats=2) + + previous_batch = None + + for d in dataloader: + self.assertEqual(d["image"][0], "spleen_19.nii.gz") + self.assertEqual(d["image"][1], "spleen_31.nii.gz") + + if previous_batch is None: + previous_batch = d + else: + self.assertTrue(previous_batch is d, "Batch object was not repeated") + previous_batch = None + if __name__ == "__main__": unittest.main() diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 543dab4d0c..2419b390fd 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -79,7 +79,7 @@ def test_plot(self): # a third non-image key is added to test that this is correctly ignored when plotting data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img, "Not Image Data": ["This isn't an image"]} - loader = DataLoader([data] * 10) + loader = DataLoader([data] * 20, batch_size=2) trainer = SupervisedTrainer( device=torch.device("cpu"), @@ -102,7 +102,7 @@ def test_plot(self): with tempfile.TemporaryDirectory() as tempdir: tempimg = f"{tempdir}/threadcontainer_plot_test.png" fig.savefig(tempimg) - comp = compare_images(f"{testing_dir}/threadcontainer_plot_test.png", tempimg, 1e-2) + comp = compare_images(f"{testing_dir}/threadcontainer_plot_test.png", tempimg, 5e-2) self.assertIsNone(comp, comp) # None indicates test passed diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py index a6d3895709..01321f1b0b 100644 --- a/tests/test_threshold_intensity.py +++ b/tests/test_threshold_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,20 +15,21 @@ from parameterized import parameterized from monai.transforms import ThresholdIntensity +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [{"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)] - -TEST_CASE_2 = [{"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)] - -TEST_CASE_3 = [{"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)]) + TESTS.append([p, {"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)]) + TESTS.append([p, {"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)]) class TestThresholdIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = np.arange(10) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = in_type(np.arange(10)) result = ThresholdIntensity(**input_param)(test_data) - np.testing.assert_allclose(result, expected_value) + assert_allclose(result, in_type(expected_value)) if __name__ == "__main__": diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py index efcfcfe604..e0610ebb5b 100644 --- a/tests/test_threshold_intensityd.py +++ b/tests/test_threshold_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,31 +15,41 @@ from parameterized import parameterized from monai.transforms import ThresholdIntensityd - -TEST_CASE_1 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, - (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), -] - -TEST_CASE_2 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, - (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), -] - -TEST_CASE_3 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, - (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, + (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, + (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, + (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), + ] + ) class TestThresholdIntensityd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = {"image": np.arange(10), "label": np.arange(10), "extra": np.arange(10)} + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))} result = ThresholdIntensityd(**input_param)(test_data) - np.testing.assert_allclose(result["image"], expected_value) - np.testing.assert_allclose(result["label"], expected_value) - np.testing.assert_allclose(result["extra"], expected_value) + assert_allclose(result["image"], in_type(expected_value)) + assert_allclose(result["label"], in_type(expected_value)) + assert_allclose(result["extra"], in_type(expected_value)) if __name__ == "__main__": diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py new file mode 100644 index 0000000000..09434de5e0 --- /dev/null +++ b/tests/test_tile_on_grid.py @@ -0,0 +1,141 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms import TileOnGrid +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + } + ] + ) + +for tile_size in [8, 16]: + for step in [4, 8]: + TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}]) + +TESTS = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES: + TESTS.append([p, *tc]) + +TEST_CASES2 = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + } + ] + ) + +TESTS2 = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES2: + TESTS2.append([p, *tc]) + + +def make_image( + tile_count: int, + tile_size: int, + step: int = 0, + random_offset: bool = False, + filter_mode: Optional[str] = None, + seed=123, + **kwargs, +): + + tile_count = int(np.sqrt(tile_count)) + pad = 0 + if random_offset: + pad = 3 + + if step == 0: + step = tile_size + + image = np.random.randint( + 200, + size=[3, (tile_count - 1) * step + tile_size + pad, (tile_count - 1) * step + tile_size + pad], + dtype=np.uint8, + ) + imlarge = image + + random_state = np.random.RandomState(seed) + + if random_offset: + pad_h = image.shape[1] % tile_size + pad_w = image.shape[2] % tile_size + offset = (random_state.randint(pad_h) if pad_h > 0 else 0, random_state.randint(pad_w) if pad_w > 0 else 0) + image = image[:, offset[0] :, offset[1] :] + + tiles_list = [] + for x in range(tile_count): + for y in range(tile_count): + tiles_list.append(image[:, x * step : x * step + tile_size, y * step : y * step + tile_size]) + + tiles = np.stack(tiles_list, axis=0) + + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count**2: + tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] + + return imlarge, tiles + + +class TestTileOnGrid(unittest.TestCase): + @parameterized.expand(TESTS) + def test_tile_patch_single_call(self, in_type, input_parameters): + + img, tiles = make_image(**input_parameters) + input_img = in_type(img) + + tiler = TileOnGrid(**input_parameters) + output = tiler(input_img) + assert_allclose(output, tiles, type_test=False) + + @parameterized.expand(TESTS2) + def test_tile_patch_random_call(self, in_type, input_parameters): + + img, tiles = make_image(**input_parameters, seed=123) + input_img = in_type(img) + + tiler = TileOnGrid(**input_parameters) + tiler.set_random_state(seed=123) + + output = tiler(input_img) + assert_allclose(output, tiles, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py new file mode 100644 index 0000000000..c6f35fe738 --- /dev/null +++ b/tests/test_tile_on_grid_dict.py @@ -0,0 +1,176 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import numpy as np +import torch +from parameterized import parameterized + +from monai.apps.pathology.transforms import TileOnGridDict +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + for return_list_of_dicts in [False, True]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + "return_list_of_dicts": return_list_of_dicts, + } + ] + ) + +for tile_size in [8, 16]: + for step in [4, 8]: + TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}]) + +TESTS = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES: + TESTS.append([p, *tc]) + +TEST_CASES2 = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + for return_list_of_dicts in [False, True]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + "return_list_of_dicts": return_list_of_dicts, + } + ] + ) + +TESTS2 = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES2: + TESTS2.append([p, *tc]) + +for tile_size in [8, 16]: + for step in [4, 8]: + TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}]) + + +def make_image( + tile_count: int, + tile_size: int, + step: int = 0, + random_offset: bool = False, + filter_mode: Optional[str] = None, + seed=123, + **kwargs, +): + + tile_count = int(np.sqrt(tile_count)) + pad = 0 + if random_offset: + pad = 3 + + if step == 0: + step = tile_size + + image = np.random.randint( + 200, + size=[3, (tile_count - 1) * step + tile_size + pad, (tile_count - 1) * step + tile_size + pad], + dtype=np.uint8, + ) + imlarge = image + + random_state = np.random.RandomState(seed) + + if random_offset: + pad_h = image.shape[1] % tile_size + pad_w = image.shape[2] % tile_size + offset = (random_state.randint(pad_h) if pad_h > 0 else 0, random_state.randint(pad_w) if pad_w > 0 else 0) + image = image[:, offset[0] :, offset[1] :] + + tiles_list = [] + for x in range(tile_count): + for y in range(tile_count): + tiles_list.append(image[:, x * step : x * step + tile_size, y * step : y * step + tile_size]) + + tiles = np.stack(tiles_list, axis=0) + + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count**2: + tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] + + return imlarge, tiles + + +class TestTileOnGridDict(unittest.TestCase): + @parameterized.expand(TESTS) + def test_tile_patch_single_call(self, in_type, input_parameters): + + key = "image" + input_parameters["keys"] = key + + img, tiles = make_image(**input_parameters) + input_img = in_type(img) + + splitter = TileOnGridDict(**input_parameters) + + output = splitter({key: input_img}) + + if input_parameters.get("return_list_of_dicts", False): + if isinstance(input_img, torch.Tensor): + output = torch.stack([ix[key] for ix in output], axis=0) + else: + output = np.stack([ix[key] for ix in output], axis=0) + else: + output = output[key] + + assert_allclose(output, tiles, type_test=False) + + @parameterized.expand(TESTS2) + def test_tile_patch_random_call(self, in_type, input_parameters): + + key = "image" + input_parameters["keys"] = key + + random_state = np.random.RandomState(123) + seed = random_state.randint(10000) + img, tiles = make_image(**input_parameters, seed=seed) + input_img = in_type(img) + + splitter = TileOnGridDict(**input_parameters) + splitter.set_random_state(seed=123) + + output = splitter({key: input_img}) + + if input_parameters.get("return_list_of_dicts", False): + if isinstance(input_img, torch.Tensor): + output = torch.stack([ix[key] for ix in output], axis=0) + else: + output = np.stack([ix[key] for ix in output], axis=0) + else: + output = output[key] + assert_allclose(output, tiles, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_timedcall.py b/tests/test_timedcall_dist.py similarity index 95% rename from tests/test_timedcall.py rename to tests/test_timedcall_dist.py index de10abb8f7..a2b3ae585a 100644 --- a/tests/test_timedcall.py +++ b/tests/test_timedcall_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,7 @@ from tests.utils import TimedCall -@TimedCall(seconds=10 if sys.platform == "linux" else 60, force_quit=False) +@TimedCall(seconds=20 if sys.platform == "linux" else 60, force_quit=False) def case_1_seconds(arg=None): time.sleep(1) return "good" if not arg else arg diff --git a/tests/test_to_contiguous.py b/tests/test_to_contiguous.py new file mode 100644 index 0000000000..a9c2a78278 --- /dev/null +++ b/tests/test_to_contiguous.py @@ -0,0 +1,44 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from monai.transforms import convert_to_contiguous +from tests.utils import assert_allclose + + +class TestToContiguous(unittest.TestCase): + def test_decollation_dict(self): + tochange = np.moveaxis(np.zeros((2, 3, 4)), 0, -1) + test_dict = {"test_key": [[1]], 0: np.array(0), 1: np.array([0]), "nested": {"nested": [tochange]}} + output = convert_to_contiguous(test_dict) + self.assertEqual(output["test_key"], [[1]]) + assert_allclose(output[0], np.array(0)) + assert_allclose(output[1], np.array([0])) + self.assertTrue(output["nested"]["nested"][0].flags.c_contiguous) + + def test_decollation_seq(self): + tochange = torch.zeros(2, 3, 4).transpose(0, 1) + test_dict = [[[1]], np.array(0), np.array([0]), torch.tensor(1.0), [[tochange]], "test_string"] + output = convert_to_contiguous(test_dict) + self.assertEqual(output[0], [[1]]) + assert_allclose(output[1], np.array(0)) + assert_allclose(output[2], np.array([0])) + assert_allclose(output[3], torch.tensor(1.0)) + self.assertTrue(output[4][0][0].is_contiguous()) + self.assertEqual(output[5], "test_string") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index 8b00e12539..36edf24f3f 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,60 +17,92 @@ from monai.transforms import ToCupy from monai.utils import optional_import -from tests.utils import skip_if_no_cuda +from tests.utils import HAS_CUPY, skip_if_no_cuda -cp, has_cp = optional_import("cupy") +cp, _ = optional_import("cupy") +@skipUnless(HAS_CUPY, "CuPy is required.") class TestToCupy(unittest.TestCase): - @skipUnless(has_cp, "CuPy is required.") def test_cupy_input(self): - test_data = cp.array([[1, 2], [3, 4]]) + test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data) + + def test_cupy_input_dtype(self): + test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32) + test_data = cp.rot90(test_data) + self.assertFalse(test_data.flags["C_CONTIGUOUS"]) + result = ToCupy(cp.uint8)(test_data) + self.assertTrue(result.dtype == cp.uint8) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_numpy_input(self): - test_data = np.array([[1, 2], [3, 4]]) + test_data = np.array([[1, 2], [3, 4]], dtype=np.float32) test_data = np.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data) + + def test_numpy_input_dtype(self): + test_data = np.array([[1, 2], [3, 4]], dtype=np.float32) + test_data = np.rot90(test_data) + self.assertFalse(test_data.flags["C_CONTIGUOUS"]) + result = ToCupy(np.uint8)(test_data) + self.assertTrue(result.dtype == cp.uint8) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_tensor_input(self): - test_data = torch.tensor([[1, 2], [3, 4]]) + test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - cp.testing.assert_allclose(result, test_data.numpy()) + cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") @skip_if_no_cuda def test_tensor_cuda_input(self): - test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).cuda() test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data) + + @skip_if_no_cuda + def test_tensor_cuda_input_dtype(self): + test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.uint8).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + + result = ToCupy(dtype="float32")(test_data) + self.assertTrue(result.dtype == cp.float32) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - cp.testing.assert_allclose(result, test_data.cpu().numpy()) + cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_list_tuple(self): test_data = [[1, 2], [3, 4]] - result = ToCupy()(test_data) + result = ToCupy(wrap_sequence=True)(test_data) cp.testing.assert_allclose(result, cp.asarray(test_data)) test_data = ((1, 2), (3, 4)) - result = ToCupy()(test_data) + result = ToCupy(wrap_sequence=True)(test_data) cp.testing.assert_allclose(result, cp.asarray(test_data)) diff --git a/tests/test_to_cupyd.py b/tests/test_to_cupyd.py index 6f40bafe1c..3e778ae269 100644 --- a/tests/test_to_cupyd.py +++ b/tests/test_to_cupyd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,13 +17,13 @@ from monai.transforms import ToCupyd from monai.utils import optional_import -from tests.utils import skip_if_no_cuda +from tests.utils import HAS_CUPY, skip_if_no_cuda -cp, has_cp = optional_import("cupy") +cp, _ = optional_import("cupy") +@skipUnless(HAS_CUPY, "CuPy is required.") class TestToCupyd(unittest.TestCase): - @skipUnless(has_cp, "CuPy is required.") def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) @@ -33,7 +33,6 @@ def test_cupy_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) test_data = np.rot90(test_data) @@ -43,7 +42,6 @@ def test_numpy_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) test_data = test_data.rot90() @@ -53,7 +51,6 @@ def test_tensor_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data.numpy()) - @skipUnless(has_cp, "CuPy is required.") @skip_if_no_cuda def test_tensor_cuda_input(self): test_data = torch.tensor([[1, 2], [3, 4]]).cuda() @@ -64,13 +61,12 @@ def test_tensor_cuda_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data.cpu().numpy()) - @skipUnless(has_cp, "CuPy is required.") def test_list_tuple(self): test_data = [[1, 2], [3, 4]] - result = ToCupyd(keys="img")({"img": test_data})["img"] + result = ToCupyd(keys="img", wrap_sequence=True)({"img": test_data})["img"] cp.testing.assert_allclose(result, cp.asarray(test_data)) test_data = ((1, 2), (3, 4)) - result = ToCupyd(keys="img")({"img": test_data})["img"] + result = ToCupyd(keys="img", wrap_sequence=True)({"img": test_data})["img"] cp.testing.assert_allclose(result, cp.asarray(test_data)) diff --git a/tests/test_to_device.py b/tests/test_to_device.py index 9855a353f0..70f1ea8828 100644 --- a/tests/test_to_device.py +++ b/tests/test_to_device.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_to_deviced.py b/tests/test_to_deviced.py index 0d5d1d1cdc..7d075ad365 100644 --- a/tests/test_to_deviced.py +++ b/tests/test_to_deviced.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,9 +24,7 @@ def test_value(self): device = "cuda:0" data = [{"img": torch.tensor(i)} for i in range(4)] dataset = CacheDataset( - data=data, - transform=ToDeviced(keys="img", device=device, non_blocking=True), - cache_rate=1.0, + data=data, transform=ToDeviced(keys="img", device=device, non_blocking=True), cache_rate=1.0 ) dataloader = ThreadDataLoader(dataset=dataset, num_workers=0, batch_size=1) for i, d in enumerate(dataloader): diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index b48727c01d..e1f135a289 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,13 +17,13 @@ from monai.transforms import ToNumpy from monai.utils import optional_import -from tests.utils import assert_allclose, skip_if_no_cuda +from tests.utils import HAS_CUPY, assert_allclose, skip_if_no_cuda -cp, has_cp = optional_import("cupy") +cp, _ = optional_import("cupy") class TestToNumpy(unittest.TestCase): - @skipUnless(has_cp, "CuPy is required.") + @skipUnless(HAS_CUPY, "CuPy is required.") def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) @@ -31,25 +31,26 @@ def test_cupy_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data.get()) + assert_allclose(result, test_data.get(), type_test=False) def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) test_data = np.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) - result = ToNumpy()(test_data) + result = ToNumpy(dtype="float32")(test_data) self.assertTrue(isinstance(result, np.ndarray)) + self.assertTrue(result.dtype == np.float32) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) - result = ToNumpy()(test_data) + result = ToNumpy(dtype=torch.uint8)(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) @skip_if_no_cuda def test_tensor_cuda_input(self): @@ -59,21 +60,22 @@ def test_tensor_cuda_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) def test_list_tuple(self): test_data = [[1, 2], [3, 4]] result = ToNumpy()(test_data) - assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data), type_test=False) test_data = ((1, 2), (3, 4)) - result = ToNumpy()(test_data) - assert_allclose(result, np.asarray(test_data)) + result = ToNumpy(wrap_sequence=False)(test_data) + self.assertTrue(type(result), tuple) + assert_allclose(result, ((np.asarray(1), np.asarray(2)), (np.asarray(3), np.asarray(4)))) def test_single_value(self): for test_data in [5, np.array(5), torch.tensor(5)]: - result = ToNumpy()(test_data) + result = ToNumpy(dtype=np.uint8)(test_data) self.assertTrue(isinstance(result, np.ndarray)) - assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data), type_test=False) self.assertEqual(result.ndim, 0) diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py index 5acaef39c7..ba7cf798ef 100644 --- a/tests/test_to_numpyd.py +++ b/tests/test_to_numpyd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,13 +17,13 @@ from monai.transforms import ToNumpyd from monai.utils import optional_import -from tests.utils import assert_allclose, skip_if_no_cuda +from tests.utils import HAS_CUPY, assert_allclose, skip_if_no_cuda -cp, has_cp = optional_import("cupy") +cp, _ = optional_import("cupy") class TestToNumpyd(unittest.TestCase): - @skipUnless(has_cp, "CuPy is required.") + @skipUnless(HAS_CUPY, "CuPy is required.") def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) @@ -31,7 +31,7 @@ def test_cupy_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data.get()) + assert_allclose(result, test_data.get(), type_test=False) def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) @@ -40,7 +40,7 @@ def test_numpy_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) @@ -49,7 +49,7 @@ def test_tensor_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) @skip_if_no_cuda def test_tensor_cuda_input(self): @@ -59,7 +59,7 @@ def test_tensor_cuda_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_to_onehot.py b/tests/test_to_onehot.py index c3e373955d..c08672bfb2 100644 --- a/tests/test_to_onehot.py +++ b/tests/test_to_onehot.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_to_pil.py b/tests/test_to_pil.py index 5690645dd8..0a1351028c 100644 --- a/tests/test_to_pil.py +++ b/tests/test_to_pil.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -43,7 +43,7 @@ class TestToPIL(unittest.TestCase): def test_value(self, test_data): result = ToPIL()(test_data) self.assertTrue(isinstance(result, PILImageImage)) - assert_allclose(np.array(result), test_data) + assert_allclose(np.array(result), test_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_to_pild.py b/tests/test_to_pild.py index 3a15b1e507..d00ecf13d4 100644 --- a/tests/test_to_pild.py +++ b/tests/test_to_pild.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -30,9 +30,7 @@ PILImageImage, _ = optional_import("PIL.Image", name="Image") im = [[1.0, 2.0], [3.0, 4.0]] -TESTS = [] -for p in TEST_NDARRAYS: - TESTS.append([{"keys": "image"}, {"image": p(im)}]) +TESTS = [[{"keys": "image"}, {"image": p(im)}] for p in TEST_NDARRAYS] if has_pil: TESTS.append([{"keys": "image"}, {"image": pil_image_fromarray(np.array(im))}]) @@ -43,7 +41,7 @@ class TestToPIL(unittest.TestCase): def test_values(self, input_param, test_data): result = ToPILd(**input_param)(test_data)[input_param["keys"]] self.assertTrue(isinstance(result, PILImageImage)) - assert_allclose(np.array(result), test_data[input_param["keys"]]) + assert_allclose(np.array(result), test_data[input_param["keys"]], type_test=False) if __name__ == "__main__": diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index 6ac06983f6..bfc61cdb19 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,10 +11,13 @@ import unittest +import torch from parameterized import parameterized from monai.transforms import ToTensor -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import HAS_CUPY, TEST_NDARRAYS, assert_allclose, optional_import + +cp, _ = optional_import("cupy") im = [[1, 2], [3, 4]] @@ -32,16 +35,26 @@ class TestToTensor(unittest.TestCase): @parameterized.expand(TESTS) def test_array_input(self, test_data, expected_shape): - result = ToTensor()(test_data) - assert_allclose(result, test_data) + result = ToTensor(dtype=torch.float32, device="cpu", wrap_sequence=True)(test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand(TESTS_SINGLE) def test_single_input(self, test_data): result = ToTensor()(test_data) - assert_allclose(result, test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) self.assertEqual(result.ndim, 0) + @unittest.skipUnless(HAS_CUPY, "CuPy is required.") + def test_cupy(self): + test_data = [[1, 2], [3, 4]] + cupy_array = cp.ascontiguousarray(cp.asarray(test_data)) + result = ToTensor()(cupy_array) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py new file mode 100644 index 0000000000..b26d41345a --- /dev/null +++ b/tests/test_torchscript_utils.py @@ -0,0 +1,112 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch + +from monai.config import get_config_values +from monai.data import load_net_with_metadata, save_net_with_metadata +from monai.utils import JITMetadataKeys +from monai.utils.module import pytorch_after + + +class TestModule(torch.nn.Module): + def forward(self, x): + return x + 10 + + +class TestTorchscript(unittest.TestCase): + def test_save_net_with_metadata(self): + """Save a network without metadata to a file.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test") + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + def test_save_net_with_metadata_ext(self): + """Save a network without metadata to a file.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test.zip") + + self.assertTrue(os.path.isfile(f"{tempdir}/test.zip")) + + def test_save_net_with_metadata_with_extra(self): + """Save a network with simple metadata to a file.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + def test_load_net_with_metadata(self): + """Save then load a network with no metadata or other extra files.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test") + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + + del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be + + self.assertEqual(meta, get_config_values()) + self.assertEqual(extra_files, {}) + + def test_load_net_with_metadata_with_extra(self): + """Save then load a network with basic metadata.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + + del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be + + test_compare = get_config_values() + test_compare.update(test_metadata) + + self.assertEqual(meta, test_compare) + self.assertEqual(extra_files, {}) + + def test_save_load_more_extra_files(self): + """Save then load extra file data from a torchscript file.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + more_extra_files = {"test.txt": b"This is test data"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata, more_extra_files=more_extra_files) + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + _, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.pt", more_extra_files=("test.txt",)) + + if pytorch_after(1, 7): + self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"]) + else: + self.assertEqual(more_extra_files["test.txt"].decode(), loaded_extra_files["test.txt"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_torchvision.py b/tests/test_torchvision.py index 0846b7f6b6..e0844eb4b9 100644 --- a/tests/test_torchvision.py +++ b/tests/test_torchvision.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,75 +11,54 @@ import unittest -import torch from parameterized import parameterized from monai.transforms import TorchVision from monai.utils import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose -TEST_CASE_1 = [ - {"name": "ColorJitter"}, - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), -] - -TEST_CASE_2 = [ - {"name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5}, - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), - torch.tensor( - [ - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - ], - ), -] - -TEST_CASE_3 = [ - {"name": "Pad", "padding": [1, 1, 1, 1]}, - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), - torch.tensor( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( [ [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], + {"name": "ColorJitter"}, + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), ], [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], + {"name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5}, + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + p( + [ + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + ] + ), ], [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], + {"name": "Pad", "padding": [1, 1, 1, 1]}, + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + p( + [ + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + ] + ), ], ] - ), -] + ) @SkipIfBeforePyTorchVersion((1, 7)) class TestTorchVision(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, input_param, input_data, expected_value): set_determinism(seed=0) result = TorchVision(**input_param)(input_data) - torch.testing.assert_allclose(result, expected_value) + assert_allclose(result, expected_value, rtol=1e-3) if __name__ == "__main__": diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py index d6d3ea69c9..98b300eeac 100644 --- a/tests/test_torchvision_fc_model.py +++ b/tests/test_torchvision_fc_model.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -115,17 +115,7 @@ class TestTorchVisionFCModel(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) @skipUnless(has_tv, "Requires TorchVision.") def test_without_pretrained(self, input_param, input_shape, expected_shape): net = TorchVisionFCModel(**input_param).to(device) diff --git a/tests/test_torchvision_fully_conv_model.py b/tests/test_torchvision_fully_conv_model.py index af2c1458d3..34a61ce9fa 100644 --- a/tests/test_torchvision_fully_conv_model.py +++ b/tests/test_torchvision_fully_conv_model.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,23 +23,11 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASE_0 = [ - {"model_name": "resnet18", "num_classes": 1, "pretrained": False}, - (2, 3, 224, 224), - (2, 1, 1, 1), -] +TEST_CASE_0 = [{"model_name": "resnet18", "num_classes": 1, "pretrained": False}, (2, 3, 224, 224), (2, 1, 1, 1)] -TEST_CASE_1 = [ - {"model_name": "resnet18", "num_classes": 1, "pretrained": False}, - (2, 3, 256, 256), - (2, 1, 2, 2), -] +TEST_CASE_1 = [{"model_name": "resnet18", "num_classes": 1, "pretrained": False}, (2, 3, 256, 256), (2, 1, 2, 2)] -TEST_CASE_2 = [ - {"model_name": "resnet101", "num_classes": 5, "pretrained": False}, - (2, 3, 256, 256), - (2, 5, 2, 2), -] +TEST_CASE_2 = [{"model_name": "resnet101", "num_classes": 5, "pretrained": False}, (2, 3, 256, 256), (2, 5, 2, 2)] TEST_CASE_3 = [ {"model_name": "resnet101", "num_classes": 5, "pool_size": 6, "pretrained": False}, @@ -70,14 +58,7 @@ class TestTorchVisionFullyConvModel(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) @skipUnless(has_tv, "Requires TorchVision.") def test_without_pretrained(self, input_param, input_shape, expected_shape): net = TorchVisionFullyConvModel(**input_param).to(device) @@ -85,13 +66,7 @@ def test_without_pretrained(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - @parameterized.expand( - [ - TEST_CASE_PRETRAINED_0, - TEST_CASE_PRETRAINED_1, - TEST_CASE_PRETRAINED_2, - ] - ) + @parameterized.expand([TEST_CASE_PRETRAINED_0, TEST_CASE_PRETRAINED_1, TEST_CASE_PRETRAINED_2]) @skipUnless(has_tv, "Requires TorchVision.") def test_with_pretrained(self, input_param, input_shape, expected_shape, expected_value): net = TorchVisionFullyConvModel(**input_param).to(device) diff --git a/tests/test_torchvisiond.py b/tests/test_torchvisiond.py index 4f42bc95f7..4c62c6e41a 100644 --- a/tests/test_torchvisiond.py +++ b/tests/test_torchvisiond.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,19 +29,10 @@ {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])}, torch.tensor( [ - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - ], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + ] ), ] @@ -50,24 +41,9 @@ {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])}, torch.tensor( [ - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], ] ), ] diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py new file mode 100644 index 0000000000..bc6aad3a62 --- /dev/null +++ b/tests/test_traceable_transform.py @@ -0,0 +1,53 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.transforms.inverse import TraceableTransform + + +class _TraceTest(TraceableTransform): + def __call__(self, data): + self.push_transform(data) + return data + + def pop(self, data): + self.pop_transform(data) + return data + + +class TestTraceable(unittest.TestCase): + def test_default(self): + expected_key = "_transforms" + a = _TraceTest() + self.assertEqual(a.trace_key(), expected_key) + + data = {"image": "test"} + data = a(data) # adds to the stack + self.assertTrue(isinstance(data[expected_key], list)) + self.assertEqual(data[expected_key][0]["class"], "_TraceTest") + + data = a(data) # adds to the stack + self.assertEqual(len(data[expected_key]), 2) + self.assertEqual(data[expected_key][-1]["class"], "_TraceTest") + + with self.assertRaises(IndexError): + a.pop({"test": "test"}) # no stack in the data + data = a.pop(data) + data = a.pop(data) + self.assertEqual(data[expected_key], []) + + with self.assertRaises(IndexError): # no more items + a.pop(data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_train_mode.py b/tests/test_train_mode.py index 1acb443041..231e3854f0 100644 --- a/tests/test_train_mode.py +++ b/tests/test_train_mode.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_transchex.py b/tests/test_transchex.py new file mode 100644 index 0000000000..462ce64fd6 --- /dev/null +++ b/tests/test_transchex.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.transchex import Transchex +from tests.utils import skip_if_quick + +TEST_CASE_TRANSCHEX = [] +for drop_out in [0.4]: + for in_channels in [3]: + for img_size in [224]: + for patch_size in [16, 32]: + for num_language_layers in [2]: + for num_vision_layers in [4]: + for num_mixed_layers in [3]: + for num_classes in [8]: + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * 2, + "patch_size": (patch_size,) * 2, + "num_vision_layers": num_vision_layers, + "num_mixed_layers": num_mixed_layers, + "num_language_layers": num_language_layers, + "num_classes": num_classes, + "drop_out": drop_out, + }, + (2, num_classes), # type: ignore + ] + TEST_CASE_TRANSCHEX.append(test_case) + + +@skip_if_quick +class TestTranschex(unittest.TestCase): + @parameterized.expand(TEST_CASE_TRANSCHEX) + def test_shape(self, input_param, expected_shape): + net = Transchex(**input_param) + with eval_mode(net): + result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224))) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + Transchex( + in_channels=3, + img_size=(128, 128), + patch_size=(16, 16), + num_language_layers=2, + num_mixed_layers=4, + num_vision_layers=2, + num_classes=2, + drop_out=5.0, + ) + + with self.assertRaises(ValueError): + Transchex( + in_channels=1, + img_size=(97, 97), + patch_size=(16, 16), + num_language_layers=6, + num_mixed_layers=6, + num_vision_layers=8, + num_classes=8, + drop_out=0.4, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index 616e3e7ec9..d6131d010c 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 10882c9dd8..94a5b49c3a 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,18 +20,8 @@ TESTS = [] for p in TEST_NDARRAYS: - TESTS.append( - [ - p(np.arange(5 * 4).reshape(5, 4)), - None, - ] - ) - TESTS.append( - [ - p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), - [2, 0, 1], - ] - ) + TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), None]) + TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), [2, 0, 1]]) class TestTranspose(unittest.TestCase): @@ -42,7 +32,7 @@ def test_transpose(self, im, indices): if isinstance(im, torch.Tensor): im = im.cpu().numpy() out2 = np.transpose(im, indices) - assert_allclose(out1, out2) + assert_allclose(out1, out2, type_test=False) if __name__ == "__main__": diff --git a/tests/test_transposed.py b/tests/test_transposed.py index 88ecd0c872..14e62eb9da 100644 --- a/tests/test_transposed.py +++ b/tests/test_transposed.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,30 +21,10 @@ TESTS = [] for p in TEST_NDARRAYS: - TESTS.append( - [ - p(np.arange(5 * 4).reshape(5, 4)), - [1, 0], - ] - ) - TESTS.append( - [ - p(np.arange(5 * 4).reshape(5, 4)), - None, - ] - ) - TESTS.append( - [ - p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), - [2, 0, 1], - ] - ) - TESTS.append( - [ - p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), - None, - ] - ) + TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), [1, 0]]) + TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), None]) + TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), [2, 0, 1]]) + TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), None]) class TestTranspose(unittest.TestCase): @@ -57,13 +37,13 @@ def test_transpose(self, im, indices): if isinstance(im, torch.Tensor): im = im.cpu().numpy() out_gt = np.transpose(im, indices) - assert_allclose(out_im1, out_gt) - assert_allclose(out_im2, out_gt) + assert_allclose(out_im1, out_gt, type_test=False) + assert_allclose(out_im2, out_gt, type_test=False) # test inverse fwd_inv_data = tr.inverse(out_data) for i, j in zip(data.values(), fwd_inv_data.values()): - assert_allclose(i, j) + assert_allclose(i, j, type_test=False) if __name__ == "__main__": diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index 0bc2ca2e70..2bb2409360 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,10 +21,7 @@ TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) @@ -107,18 +104,12 @@ ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "alpha": 0.3, "beta": 0.7, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.3589, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "alpha": 0.7, "beta": 0.3, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.247366, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) diff --git a/tests/test_unet.py b/tests/test_unet.py index 4091c4e9d7..5f126fed97 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ TEST_CASE_0 = [ # single channel 2D, batch 16, no residual { - "dimensions": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (16, 32, 64), @@ -36,7 +36,7 @@ TEST_CASE_1 = [ # single channel 2D, batch 16 { - "dimensions": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (16, 32, 64), @@ -49,7 +49,7 @@ TEST_CASE_2 = [ # single channel 3D, batch 16 { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 1, "out_channels": 3, "channels": (16, 32, 64), @@ -62,7 +62,7 @@ TEST_CASE_3 = [ # 4-channel 3D, batch 16 { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 4, "out_channels": 3, "channels": (16, 32, 64), @@ -75,7 +75,7 @@ TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 4, "out_channels": 3, "channels": (16, 32, 64), @@ -89,7 +89,7 @@ TEST_CASE_5 = [ # 4-channel 3D, batch 16, LeakyReLU activation { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 4, "out_channels": 3, "channels": (16, 32, 64), @@ -103,7 +103,7 @@ TEST_CASE_6 = [ # 4-channel 3D, batch 16, LeakyReLU activation explicit { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 4, "out_channels": 3, "channels": (16, 32, 64), @@ -120,7 +120,7 @@ ILL_CASES = [ [ { # len(channels) < 2 - "dimensions": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (16,), @@ -130,7 +130,7 @@ ], [ { # len(strides) < len(channels) - 1 - "dimensions": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (8, 8, 8), @@ -139,8 +139,8 @@ } ], [ - { # len(kernel_size) = 3, dimensions = 2 - "dimensions": 2, + { # len(kernel_size) = 3, spatial_dims = 2 + "spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (8, 8, 8), @@ -149,8 +149,8 @@ } ], [ - { # len(up_kernel_size) = 2, dimensions = 3 - "dimensions": 3, + { # len(up_kernel_size) = 2, spatial_dims = 3 + "spatial_dims": 3, "in_channels": 1, "out_channels": 3, "channels": (8, 8, 8), @@ -170,13 +170,15 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_script(self): - net = UNet(dimensions=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0) + net = UNet( + spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 + ) test_data = torch.randn(16, 1, 32, 32) test_script_save(net, test_data) def test_script_without_running_stats(self): net = UNet( - dimensions=2, + spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), @@ -188,13 +190,7 @@ def test_script_without_running_stats(self): test_script_save(net, test_data) def test_ill_input_shape(self): - net = UNet( - dimensions=2, - in_channels=1, - out_channels=3, - channels=(16, 32, 64), - strides=(2, 2), - ) + net = UNet(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2)) with eval_mode(net): with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match"): net.forward(torch.randn(2, 1, 16, 5)) @@ -202,7 +198,7 @@ def test_ill_input_shape(self): @parameterized.expand(ILL_CASES) def test_ill_input_hyper_params(self, input_param): with self.assertRaises(ValueError): - net = UNet(**input_param) + _ = UNet(**input_param) if __name__ == "__main__": diff --git a/tests/test_unetr.py b/tests/test_unetr.py index d19ed2ca59..40619de9dc 100644 --- a/tests/test_unetr.py +++ b/tests/test_unetr.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,6 +16,7 @@ from monai.networks import eval_mode from monai.networks.nets.unetr import UNETR +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASE_UNETR = [] for dropout_rate in [0.4]: @@ -52,7 +53,7 @@ TEST_CASE_UNETR.append(test_case) -class TestPatchEmbeddingBlock(unittest.TestCase): +class TestUNETR(unittest.TestCase): @parameterized.expand(TEST_CASE_UNETR) def test_shape(self, input_param, input_shape, expected_shape): net = UNETR(**input_param) @@ -117,6 +118,17 @@ def test_ill_arg(self): dropout_rate=0.2, ) + @parameterized.expand(TEST_CASE_UNETR) + @SkipIfBeforePyTorchVersion((1, 9)) + def test_script(self, input_param, input_shape, _): + net = UNETR(**(input_param)) + net.eval() + with torch.no_grad(): + torch.jit.script(net) + + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_unetr_block.py b/tests/test_unetr_block.py index 7546918a2c..c0f14c829d 100644 --- a/tests/test_unetr_block.py +++ b/tests/test_unetr_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_upsample_block.py b/tests/test_upsample_block.py index 7b8ada399c..aa4141aabc 100644 --- a/tests/test_upsample_block.py +++ b/tests/test_upsample_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,11 +20,7 @@ TEST_CASES = [ [{"dimensions": 2, "in_channels": 4}, (7, 4, 32, 48), (7, 4, 64, 96)], # 4-channel 2D, batch 7 - [ - {"dimensions": 1, "in_channels": 4, "out_channels": 3}, - (16, 4, 63), - (16, 3, 126), - ], # 4-channel 1D, batch 16 + [{"dimensions": 1, "in_channels": 4, "out_channels": 3}, (16, 4, 63), (16, 3, 126)], # 4-channel 1D, batch 16 [ {"dimensions": 1, "in_channels": 4, "out_channels": 8, "mode": "deconv", "align_corners": False}, (16, 4, 20), @@ -78,14 +74,7 @@ expected_shape = (16, 5, 4 * s, 5 * s, 6 * s) for t in UpsampleMode: test_case = [ - { - "dimensions": 3, - "in_channels": 3, - "out_channels": 5, - "mode": t, - "scale_factor": s, - "align_corners": True, - }, + {"dimensions": 3, "in_channels": 3, "out_channels": 5, "mode": t, "scale_factor": s, "align_corners": True}, (16, 3, 4, 5, 6), ] test_case.append(expected_shape) diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py new file mode 100644 index 0000000000..b13378debe --- /dev/null +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -0,0 +1,71 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.utils_pytorch_numpy_unification import mode, percentile +from monai.utils import set_determinism +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose + +TEST_MODE = [] +for p in TEST_NDARRAYS: + TEST_MODE.append([p(np.array([1, 2, 3, 4, 4, 5])), p(4), False]) + TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False]) + TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True]) + + +class TestPytorchNumpyUnification(unittest.TestCase): + def setUp(self) -> None: + set_determinism(0) + + def test_percentile(self): + for size in (1, 100): + q = np.random.randint(0, 100, size=size) + results = [] + for p in TEST_NDARRAYS: + arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32)) + results.append(percentile(arr, q)) + # pre torch 1.7, no `quantile`. Our own method doesn't interpolate, + # so we can only be accurate to 0.5 + atol = 0.5 if not hasattr(torch, "quantile") else 1e-4 + assert_allclose(results[0], results[-1], type_test=False, atol=atol) + + def test_fails(self): + for p in TEST_NDARRAYS: + for q in (-1, 101): + arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32)) + with self.assertRaises(ValueError): + percentile(arr, q) + + @SkipIfBeforePyTorchVersion((1, 7)) + def test_dim(self): + q = np.random.randint(0, 100, size=50) + results = [] + for p in TEST_NDARRAYS: + arr = p(np.arange(6).reshape(1, 2, 3).astype(np.float32)) + results.append(percentile(arr, q, dim=1)) + # pre torch 1.7, no `quantile`. Our own method doesn't interpolate, + # so we can only be accurate to 0.5 + atol = 0.5 if not hasattr(torch, "quantile") else 1e-4 + assert_allclose(results[0], results[-1], type_test=False, atol=atol) + + @parameterized.expand(TEST_MODE) + def test_mode(self, array, expected, to_long): + res = mode(array, to_long=to_long) + assert_allclose(res, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_varautoencoder.py b/tests/test_varautoencoder.py index 7a4a546d87..95fea8afcb 100644 --- a/tests/test_varautoencoder.py +++ b/tests/test_varautoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ TEST_CASE_0 = [ # single channel 2D, batch 4, no residual { - "dimensions": 2, + "spatial_dims": 2, "in_shape": (1, 128, 128), "out_channels": 1, "latent_size": 2, @@ -37,7 +37,7 @@ TEST_CASE_1 = [ # single channel 2D, batch 4 { - "dimensions": 2, + "spatial_dims": 2, "in_shape": (1, 128, 128), "out_channels": 1, "latent_size": 2, @@ -50,7 +50,7 @@ TEST_CASE_2 = [ # 3-channel 2D, batch 4, LeakyReLU activation { - "dimensions": 2, + "spatial_dims": 2, "in_shape": (3, 128, 128), "out_channels": 3, "latent_size": 2, @@ -64,7 +64,7 @@ TEST_CASE_3 = [ # 4-channel 3D, batch 4 { - "dimensions": 3, + "spatial_dims": 3, "in_shape": (4, 128, 128, 128), "out_channels": 3, "latent_size": 2, @@ -88,7 +88,7 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_script(self): net = VarAutoEncoder( - dimensions=2, in_shape=(1, 32, 32), out_channels=1, latent_size=2, channels=(4, 8), strides=(2, 2) + spatial_dims=2, in_shape=(1, 32, 32), out_channels=1, latent_size=2, channels=(4, 8), strides=(2, 2) ) test_data = torch.randn(2, 1, 32, 32) test_script_save(net, test_data) diff --git a/tests/test_version_leq.py b/tests/test_version_leq.py index a1913069d3..86fccca9fb 100644 --- a/tests/test_version_leq.py +++ b/tests/test_version_leq.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -67,6 +67,9 @@ def _pairwise(iterable): ("0post1", "0.4post1"), ("2.1.0-rc1", "2.1.0"), ("2.1dev", "2.1a0"), + (1.6, "1.6.0"), + ("1.6.0", 1.6), + (1.6, 1.7), ) + tuple(_pairwise(reversed(torture.split()))) diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index 47c116cd5d..2137926424 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index eebf32d70b..acca06d405 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -40,23 +40,13 @@ ] # 2D TEST_CASE_2 = [ - { - "model": "senet2d", - "shape": (2, 3, 64, 64), - "feature_shape": (2, 1, 2, 2), - "target_layers": "layer4", - }, + {"model": "senet2d", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"}, (2, 1, 64, 64), ] # 3D TEST_CASE_3 = [ - { - "model": "senet3d", - "shape": (2, 3, 8, 8, 48), - "feature_shape": (2, 1, 1, 1, 2), - "target_layers": "layer4", - }, + {"model": "senet3d", "shape": (2, 3, 8, 8, 48), "feature_shape": (2, 1, 1, 1, 2), "target_layers": "layer4"}, (2, 1, 8, 8, 48), ] @@ -88,6 +78,16 @@ def test_shape(self, input_data, expected_shape): result2 = cam(x=image, layer_idx=-1, class_idx=model(image).max(1)[-1].cpu()) torch.testing.assert_allclose(result, result2) + def test_ill(self): + model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) + for name, x in model.named_parameters(): + if "features" in name: + x.requires_grad = False + cam = GradCAM(nn_module=model, target_layers="class_layers.relu") + image = torch.rand((2, 1, 48, 64)) + with self.assertRaises(RuntimeError): + cam(x=image) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_vis_gradcampp.py b/tests/test_vis_gradcampp.py index 92a4b2ac7b..a261b6055b 100644 --- a/tests/test_vis_gradcampp.py +++ b/tests/test_vis_gradcampp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -39,23 +39,13 @@ ] # 2D TEST_CASE_2 = [ - { - "model": "senet2d", - "shape": (2, 3, 64, 64), - "feature_shape": (2, 1, 2, 2), - "target_layers": "layer4", - }, + {"model": "senet2d", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"}, (2, 1, 64, 64), ] # 3D TEST_CASE_3 = [ - { - "model": "senet3d", - "shape": (2, 3, 8, 8, 48), - "feature_shape": (2, 1, 1, 1, 2), - "target_layers": "layer4", - }, + {"model": "senet3d", "shape": (2, 3, 8, 8, 48), "feature_shape": (2, 1, 1, 1, 2), "target_layers": "layer4"}, (2, 1, 8, 8, 48), ] diff --git a/tests/test_vit.py b/tests/test_vit.py index cdf0888222..3ef847626d 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,6 +16,7 @@ from monai.networks import eval_mode from monai.networks.nets.vit import ViT +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASE_Vit = [] for dropout_rate in [0.6]: @@ -27,7 +28,7 @@ for mlp_dim in [3072]: for num_layers in [4]: for num_classes in [8]: - for pos_embed in ["conv"]: + for pos_embed in ["conv", "perceptron"]: for classification in [False, True]: for nd in (2, 3): test_case = [ @@ -54,7 +55,7 @@ TEST_CASE_Vit.append(test_case) -class TestPatchEmbeddingBlock(unittest.TestCase): +class TestViT(unittest.TestCase): @parameterized.expand(TEST_CASE_Vit) def test_shape(self, input_param, input_shape, expected_shape): net = ViT(**input_param) @@ -133,6 +134,17 @@ def test_ill_arg(self): dropout_rate=0.3, ) + @parameterized.expand(TEST_CASE_Vit) + @SkipIfBeforePyTorchVersion((1, 9)) + def test_script(self, input_param, input_shape, _): + net = ViT(**(input_param)) + net.eval() + with torch.no_grad(): + torch.jit.script(net) + + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py new file mode 100644 index 0000000000..8320fef02d --- /dev/null +++ b/tests/test_vitautoenc.py @@ -0,0 +1,140 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vitautoenc import ViTAutoEnc + +TEST_CASE_Vitautoenc = [] +for in_channels in [1, 4]: + for img_size in [64, 96, 128]: + for patch_size in [16]: + for pos_embed in ["conv", "perceptron"]: + for nd in [2, 3]: + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * nd, + "patch_size": (patch_size,) * nd, + "hidden_size": 768, + "mlp_dim": 3072, + "num_layers": 4, + "num_heads": 12, + "pos_embed": pos_embed, + "dropout_rate": 0.6, + "spatial_dims": nd, + }, + (2, in_channels, *([img_size] * nd)), + (2, 1, *([img_size] * nd)), + ] + + TEST_CASE_Vitautoenc.append(test_case) + +TEST_CASE_Vitautoenc.append( + [ + { + "in_channels": 1, + "img_size": (512, 512, 32), + "patch_size": (16, 16, 16), + "hidden_size": 768, + "mlp_dim": 3072, + "num_layers": 4, + "num_heads": 12, + "pos_embed": "conv", + "dropout_rate": 0.6, + "spatial_dims": 3, + }, + (2, 1, 512, 512, 32), + (2, 1, 512, 512, 32), + ] +) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_Vitautoenc) + def test_shape(self, input_param, input_shape, expected_shape): + net = ViTAutoEnc(**input_param) + with eval_mode(net): + result, _ = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=1, + img_size=(128, 128, 128), + patch_size=(16, 16, 16), + hidden_size=128, + mlp_dim=3072, + num_layers=12, + num_heads=12, + pos_embed="conv", + dropout_rate=5.0, + ) + + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(64, 64, 64), + hidden_size=512, + mlp_dim=3072, + num_layers=12, + num_heads=8, + pos_embed="perceptron", + dropout_rate=0.3, + ) + + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=1, + img_size=(96, 96, 96), + patch_size=(8, 8, 8), + hidden_size=512, + mlp_dim=3072, + num_layers=12, + num_heads=14, + pos_embed="conv", + dropout_rate=0.3, + ) + + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=1, + img_size=(97, 97, 97), + patch_size=(4, 4, 4), + hidden_size=768, + mlp_dim=3072, + num_layers=12, + num_heads=8, + pos_embed="perceptron", + dropout_rate=0.3, + ) + + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=4, + img_size=(96, 96, 96), + patch_size=(16, 16, 16), + hidden_size=768, + mlp_dim=3072, + num_layers=12, + num_heads=12, + pos_embed="perc", + dropout_rate=0.3, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vnet.py b/tests/test_vnet.py index 4eba5396b2..add0396bd8 100644 --- a/tests/test_vnet.py +++ b/tests/test_vnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_vote_ensemble.py b/tests/test_vote_ensemble.py index 74c19d5f48..79868d4706 100644 --- a/tests/test_vote_ensemble.py +++ b/tests/test_vote_ensemble.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,66 +15,67 @@ from parameterized import parameterized from monai.transforms import VoteEnsemble +from tests.utils import TEST_NDARRAYS, assert_allclose -# shape: [2, 1, 1] -TEST_CASE_1 = [ - {"num_classes": None}, - [torch.tensor([[[1]], [[0]]]), torch.tensor([[[1]], [[0]]]), torch.tensor([[[0]], [[1]]])], - torch.tensor([[[1.0]], [[0.0]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + # shape: [2, 1, 1] + TESTS.append( + [ + {"num_classes": None}, + [p(torch.tensor([[[1]], [[0]]])), p(torch.tensor([[[1]], [[0]]])), p(torch.tensor([[[0]], [[1]]]))], + p(torch.tensor([[[1.0]], [[0.0]]])), + ] + ) -# shape: [1, 2, 1, 1] -TEST_CASE_2 = [ - {"num_classes": None}, - torch.stack([torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])]), - torch.tensor([[[[1.0]], [[0.0]]]]), -] + # shape: [1, 2, 1, 1] + TESTS.append( + [ + {"num_classes": None}, + p( + torch.stack( + [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] + ) + ), + p(torch.tensor([[[[1.0]], [[0.0]]]])), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_3 = [ - {"num_classes": 3}, - [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + {"num_classes": 3}, + [p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[1], [1]]]))], + p(torch.tensor([[[0], [2]]])), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_4 = [ - {"num_classes": 5}, - [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + {"num_classes": 5}, + [p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[1], [1]]]))], + p(torch.tensor([[[0], [2]]])), + ] + ) -# shape: [1] -TEST_CASE_5 = [ - {"num_classes": 3}, - [torch.tensor([2]), torch.tensor([2]), torch.tensor([1])], - torch.tensor([2]), -] + # shape: [1] + TESTS.append( + [{"num_classes": 3}, [p(torch.tensor([2])), p(torch.tensor([2])), p(torch.tensor([1]))], p(torch.tensor([2]))] + ) -# shape: 1 -TEST_CASE_6 = [ - {"num_classes": 3}, - [torch.tensor(2), torch.tensor(2), torch.tensor(1)], - torch.tensor(2), -] + # shape: 1 + TESTS.append([{"num_classes": 3}, [p(torch.tensor(2)), p(torch.tensor(2)), p(torch.tensor(1))], p(torch.tensor(2))]) class TestVoteEnsemble(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_param, img, expected_value): result = VoteEnsemble(**input_param)(img) - torch.testing.assert_allclose(result, expected_value) - - def test_cuda_value(self): - img = torch.stack( - [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] - ) - expected_value = torch.tensor([[[[1.0]], [[0.0]]]]) - if torch.cuda.is_available(): - img = img.to(torch.device("cuda:0")) - expected_value = expected_value.to(torch.device("cuda:0")) - result = VoteEnsemble(num_classes=None)(img) - torch.testing.assert_allclose(result, expected_value) + if isinstance(img, torch.Tensor): + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(result.device, img.device) + assert_allclose(result, expected_value) if __name__ == "__main__": diff --git a/tests/test_vote_ensembled.py b/tests/test_vote_ensembled.py index e94213733f..e42a57f3b7 100644 --- a/tests/test_vote_ensembled.py +++ b/tests/test_vote_ensembled.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,64 +15,79 @@ from parameterized import parameterized from monai.transforms import VoteEnsembled +from tests.utils import TEST_NDARRAYS, assert_allclose -# shape: [1, 2, 1, 1] -TEST_CASE_1 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": None}, - { - "pred0": torch.tensor([[[[1]], [[0]]]]), - "pred1": torch.tensor([[[[1]], [[0]]]]), - "pred2": torch.tensor([[[[0]], [[1]]]]), - }, - torch.tensor([[[[1.0]], [[0.0]]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + # shape: [1, 2, 1, 1] + TESTS.append( + [ + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": None}, + { + "pred0": p(torch.tensor([[[[1]], [[0]]]])), + "pred1": p(torch.tensor([[[[1]], [[0]]]])), + "pred2": p(torch.tensor([[[[0]], [[1]]]])), + }, + p(torch.tensor([[[[1.0]], [[0.0]]]])), + ] + ) -# shape: [1, 2, 1, 1] -TEST_CASE_2 = [ - {"keys": "output", "output_key": "output", "num_classes": None}, - { - "output": torch.stack( - [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] - ) - }, - torch.tensor([[[[1.0]], [[0.0]]]]), -] + # shape: [1, 2, 1, 1] + TESTS.append( + [ + {"keys": "output", "output_key": "output", "num_classes": None}, + { + "output": p( + torch.stack( + [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] + ) + ) + }, + p(torch.tensor([[[[1.0]], [[0.0]]]])), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_3 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, - { - "pred0": torch.tensor([[[0], [2]]]), - "pred1": torch.tensor([[[0], [2]]]), - "pred2": torch.tensor([[[1], [1]]]), - }, - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, + { + "pred0": p(torch.tensor([[[0], [2]]])), + "pred1": p(torch.tensor([[[0], [2]]])), + "pred2": p(torch.tensor([[[1], [1]]])), + }, + p(torch.tensor([[[0], [2]]])), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_4 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 5}, - { - "pred0": torch.tensor([[[0], [2]]]), - "pred1": torch.tensor([[[0], [2]]]), - "pred2": torch.tensor([[[1], [1]]]), - }, - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 5}, + { + "pred0": p(torch.tensor([[[0], [2]]])), + "pred1": p(torch.tensor([[[0], [2]]])), + "pred2": p(torch.tensor([[[1], [1]]])), + }, + p(torch.tensor([[[0], [2]]])), + ] + ) -# shape: [1] -TEST_CASE_5 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, - {"pred0": torch.tensor([2]), "pred1": torch.tensor([2]), "pred2": torch.tensor([1])}, - torch.tensor([2]), -] + # shape: [1] + TESTS.append( + [ + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, + {"pred0": p(torch.tensor([2])), "pred1": p(torch.tensor([2])), "pred2": p(torch.tensor([1]))}, + p(torch.tensor([2])), + ] + ) class TestVoteEnsembled(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_param, img, expected_value): result = VoteEnsembled(**input_param)(img) - torch.testing.assert_allclose(result["output"], expected_value) + assert_allclose(result["output"], expected_value) def test_cuda_value(self): img = torch.stack( @@ -83,7 +98,7 @@ def test_cuda_value(self): img = img.to(torch.device("cuda:0")) expected_value = expected_value.to(torch.device("cuda:0")) result = VoteEnsembled(keys="output", num_classes=None)({"output": img}) - torch.testing.assert_allclose(result["output"], expected_value) + assert_allclose(result["output"], expected_value) if __name__ == "__main__": diff --git a/tests/test_warp.py b/tests/test_warp.py index c6c79a369a..c039b57211 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os import unittest import numpy as np @@ -18,8 +18,9 @@ from monai.config.deviceconfig import USE_COMPILED from monai.networks.blocks.warp import Warp +from monai.transforms import LoadImaged from monai.utils import GridSampleMode, GridSamplePadMode -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule, download_url_or_skip_test, testing_data_config LOW_POWER_TEST_CASES = [ # run with BUILD_MONAI=1 to test csrc/resample, BUILD_MONAI=0 to test native grid_sample [ @@ -96,6 +97,25 @@ class TestWarp(unittest.TestCase): + def setUp(self): + config = testing_data_config("images", "Prostate_T2W_AX_1") + download_url_or_skip_test( + url=config["url"], + filepath=FILE_PATH, + hash_val=config.get("hash_val"), + hash_type=config.get("hash_type", "sha256"), + ) + + @SkipIfNoModule("itk") + def test_itk_benchmark(self): + img, ddf = load_img_and_sample_ddf() + monai_result = monai_warp(img, ddf) + itk_result = itk_warp(img, ddf) + relative_diff = np.mean( + np.divide(monai_result - itk_result, itk_result, out=np.zeros_like(itk_result), where=(itk_result != 0)) + ) + self.assertTrue(relative_diff < 0.01) + @parameterized.expand(TEST_CASES, skip_on_empty=True) def test_resample(self, input_param, input_data, expected_val): warp_layer = Warp(**input_param) @@ -127,5 +147,92 @@ def test_grad(self): gradcheck(warp_layer, (input_image, ddf), atol=1e-2, eps=1e-2) +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + "mri.nii") + + +def load_img_and_sample_ddf(): + # load image + img = LoadImaged(keys="img")({"img": FILE_PATH})["img"] + # W, H, D -> D, H, W + img = img.transpose((2, 1, 0)) + + # randomly sample ddf such that maximum displacement in each direction equals to one-tenth of the image dimension in + # that direction. + ddf = np.random.random((3, *img.shape)).astype(np.float32) # (3, D, H, W) + ddf[0] = ddf[0] * img.shape[0] * 0.1 + ddf[1] = ddf[1] * img.shape[1] * 0.1 + ddf[2] = ddf[2] * img.shape[2] * 0.1 + return img, ddf + + +def itk_warp(img, ddf): + """ + warping with python itk + Args: + img: numpy array of shape (D, H, W) + ddf: numpy array of shape (3, D, H, W) + + Returns: + warped_img: numpy arrap of shape (D, H, W) + """ + import itk + + # 3, D, H, W -> D, H, W, 3 + ddf = ddf.transpose((1, 2, 3, 0)) + # x, y, z -> z, x, y + ddf = ddf[..., ::-1] + + dimension = 3 + + # initialise image + pixel_type = itk.F # float32 + image_type = itk.Image[pixel_type, dimension] + itk_img = itk.PyBuffer[image_type].GetImageFromArray(img.astype(np.float32), is_vector=None) + + # initialise displacement field + vector_component_type = itk.F + vector_pixel_type = itk.Vector[vector_component_type, dimension] + displacement_field_type = itk.Image[vector_pixel_type, dimension] + displacement_field = itk.PyBuffer[displacement_field_type].GetImageFromArray(ddf.astype(np.float32), is_vector=True) + + # initialise warp_filter + warp_filter = itk.WarpImageFilter[image_type, image_type, displacement_field_type].New() + interpolator = itk.LinearInterpolateImageFunction[image_type, itk.D].New() + warp_filter.SetInterpolator(interpolator) + warp_filter.SetOutputSpacing(itk_img.GetSpacing()) + warp_filter.SetOutputOrigin(itk_img.GetOrigin()) + warp_filter.SetOutputDirection(itk_img.GetDirection()) + + # warp + warp_filter.SetDisplacementField(displacement_field) + warp_filter.SetInput(itk_img) + warped_img = warp_filter.GetOutput() + warped_img = np.asarray(warped_img) + + return warped_img + + +def monai_warp(img, ddf): + """ + warp with MONAI + Args: + img: numpy array of shape (D, H, W) + ddf: numpy array of shape (3, D, H, W) + + Returns: + warped_img: numpy arrap of shape (D, H, W) + """ + warp_layer = Warp(padding_mode="zeros") + # turn to tensor and add channel dim + monai_img = torch.tensor(img).unsqueeze(0) + ddf = torch.tensor(ddf) + # img -> batch -> img + warped_img = warp_layer(monai_img.unsqueeze(0), ddf.unsqueeze(0)).squeeze(0) + # remove channel dim + warped_img = np.asarray(warped_img.squeeze(0)) + + return warped_img + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_distributed_weighted_random_sampler.py b/tests/test_weighted_random_sampler_dist.py similarity index 91% rename from tests/test_distributed_weighted_random_sampler.py rename to tests/test_weighted_random_sampler_dist.py index b8e088fdcf..13404a8acb 100644 --- a/tests/test_distributed_weighted_random_sampler.py +++ b/tests/test_weighted_random_sampler_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,10 +25,7 @@ def test_sampling(self): data = [1, 2, 3, 4, 5] weights = [1, 2, 3, 4, 5] sampler = DistributedWeightedRandomSampler( - weights=weights, - dataset=data, - shuffle=False, - generator=torch.Generator().manual_seed(0), + weights=weights, dataset=data, shuffle=False, generator=torch.Generator().manual_seed(0) ) samples = np.array([data[i] for i in list(sampler)]) diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py index 68c5ad30c4..36d5c0c843 100644 --- a/tests/test_with_allow_missing_keys.py +++ b/tests/test_with_allow_missing_keys.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_write_metrics_reports.py b/tests/test_write_metrics_reports.py index f736db9961..101e1137b6 100644 --- a/tests/test_write_metrics_reports.py +++ b/tests/test_write_metrics_reports.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,6 +13,7 @@ import os import tempfile import unittest +from pathlib import Path import torch @@ -23,7 +24,7 @@ class TestWriteMetricsReports(unittest.TestCase): def test_content(self): with tempfile.TemporaryDirectory() as tempdir: write_metrics_reports( - save_dir=tempdir, + save_dir=Path(tempdir), images=["filepath1", "filepath2"], metrics={"metric1": 1, "metric2": 2}, metric_details={"metric3": torch.tensor([[1, 2], [2, 3]]), "metric4": torch.tensor([[5, 6], [7, 8]])}, diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py new file mode 100644 index 0000000000..6ee02143b8 --- /dev/null +++ b/tests/test_wsireader.py @@ -0,0 +1,239 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from numpy.testing import assert_array_equal +from parameterized import parameterized + +from monai.data import DataLoader, Dataset +from monai.data.image_reader import WSIReader +from monai.transforms import Compose, LoadImaged, ToTensord +from monai.utils import first, optional_import +from monai.utils.enums import PostFix +from tests.utils import download_url_or_skip_test, testing_data_config + +cucim, has_cucim = optional_import("cucim") +has_cucim = has_cucim and hasattr(cucim, "CuImage") +openslide, has_osl = optional_import("openslide") +imwrite, has_tiff = optional_import("tifffile", name="imwrite") +_, has_codec = optional_import("imagecodecs") +has_tiff = has_tiff and has_codec + +FILE_KEY = "wsi_img" +FILE_URL = testing_data_config("images", FILE_KEY, "url") +base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension) + +HEIGHT = 32914 +WIDTH = 46000 + +TEST_CASE_0 = [FILE_PATH, 2, (3, HEIGHT // 4, WIDTH // 4)] + +TEST_CASE_TRANSFORM_0 = [FILE_PATH, 4, (HEIGHT // 16, WIDTH // 16), (1, 3, HEIGHT // 16, WIDTH // 16)] + +TEST_CASE_1 = [ + FILE_PATH, + {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0}, + np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), +] + +TEST_CASE_2 = [ + FILE_PATH, + {"location": (0, 0), "size": (2, 1), "level": 2}, + np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), +] + +TEST_CASE_3 = [ + FILE_PATH, + {"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 2}, + np.array( + [ + [[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]], + [[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]], + ] + ), +] + +TEST_CASE_4 = [ + FILE_PATH, + {"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 1}, + np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), +] + +TEST_CASE_5 = [ + FILE_PATH, + {"location": (HEIGHT - 2, WIDTH - 2), "level": 0, "grid_shape": (1, 1)}, + np.array([[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[237, 237], [237, 237]]]), +] + + +TEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)] # CHW + +TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW + +TEST_CASE_ERROR_GRAY = [np.ones((16, 16), dtype=np.uint8)] # no color channel +TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color + + +def save_rgba_tiff(array: np.ndarray, filename: str, mode: str): + """ + Save numpy array into a TIFF RGB/RGBA file + + Args: + array: numpy ndarray with the shape of CxHxW and C==3 representing a RGB image + filename: the filename to be used for the tiff file. '_RGB.tiff' or '_RGBA.tiff' will be appended to this filename. + mode: RGB or RGBA + """ + if mode == "RGBA": + array = np.concatenate([array, 255 * np.ones_like(array[0])[np.newaxis]]).astype(np.uint8) + + img_rgb = array.transpose(1, 2, 0) + imwrite(filename, img_rgb, shape=img_rgb.shape, tile=(16, 16)) + + return filename + + +def save_gray_tiff(array: np.ndarray, filename: str): + """ + Save numpy array into a TIFF file + + Args: + array: numpy ndarray with any shape + filename: the filename to be used for the tiff file. + """ + img_gray = array + imwrite(filename, img_gray, shape=img_gray.shape, photometric="rgb") + + return filename + + +@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!") +def setUpModule(): # noqa: N802 + hash_type = testing_data_config("images", FILE_KEY, "hash_type") + hash_val = testing_data_config("images", FILE_KEY, "hash_val") + download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val) + + +class WSIReaderTests: + class Tests(unittest.TestCase): + backend = None + + @parameterized.expand([TEST_CASE_0]) + def test_read_whole_image(self, file_path, level, expected_shape): + reader = WSIReader(self.backend, level=level) + with reader.read(file_path) as img_obj: + img = reader.get_data(img_obj)[0] + self.assertTupleEqual(img.shape, expected_shape) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_5]) + def test_read_region(self, file_path, patch_info, expected_img): + kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} + reader = WSIReader(self.backend, **kwargs) + with reader.read(file_path, **kwargs) as img_obj: + if self.backend == "tifffile": + with self.assertRaises(ValueError): + reader.get_data(img_obj, **patch_info)[0] + else: + # Read twice to check multiple calls + img = reader.get_data(img_obj, **patch_info)[0] + img2 = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, img2.shape) + self.assertIsNone(assert_array_equal(img, img2)) + self.assertTupleEqual(img.shape, expected_img.shape) + self.assertIsNone(assert_array_equal(img, expected_img)) + + @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) + def test_read_patches(self, file_path, patch_info, expected_img): + reader = WSIReader(self.backend) + with reader.read(file_path) as img_obj: + if self.backend == "tifffile": + with self.assertRaises(ValueError): + reader.get_data(img_obj, **patch_info)[0] + else: + img = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, expected_img.shape) + self.assertIsNone(assert_array_equal(img, expected_img)) + + @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) + @skipUnless(has_tiff, "Requires tifffile.") + def test_read_rgba(self, img_expected): + # skip for OpenSlide since not working with images without tiles + if self.backend == "openslide": + return + image = {} + reader = WSIReader(self.backend) + for mode in ["RGB", "RGBA"]: + file_path = save_rgba_tiff( + img_expected, + os.path.join(os.path.dirname(__file__), "testing_data", f"temp_tiff_image_{mode}.tiff"), + mode=mode, + ) + with reader.read(file_path) as img_obj: + image[mode], _ = reader.get_data(img_obj) + + self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) + self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) + + @parameterized.expand([TEST_CASE_ERROR_GRAY, TEST_CASE_ERROR_3D]) + @skipUnless(has_tiff, "Requires tifffile.") + def test_read_malformats(self, img_expected): + reader = WSIReader(self.backend) + file_path = save_gray_tiff( + img_expected, os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff") + ) + with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError if has_osl else ValueError)): + with reader.read(file_path) as img_obj: + reader.get_data(img_obj) + + @parameterized.expand([TEST_CASE_TRANSFORM_0]) + def test_with_dataloader(self, file_path, level, expected_spatial_shape, expected_shape): + train_transform = Compose( + [ + LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level), + ToTensord(keys=["image"]), + ] + ) + dataset = Dataset([{"image": file_path}], transform=train_transform) + data_loader = DataLoader(dataset) + data: dict = first(data_loader) + for s in data[PostFix.meta("image")]["spatial_shape"]: + torch.testing.assert_allclose(s, expected_spatial_shape) + self.assertTupleEqual(data["image"].shape, expected_shape) + + +@skipUnless(has_cucim, "Requires cucim") +class TestCuCIM(WSIReaderTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "cucim" + + +@skipUnless(has_osl, "Requires OpenSlide") +class TestOpenSlide(WSIReaderTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "openslide" + + +@skipUnless(has_tiff, "Requires TiffFile") +class TestTiffFile(WSIReaderTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "tifffile" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_zipdataset.py b/tests/test_zipdataset.py index 710ca71fc2..f381e0a453 100644 --- a/tests/test_zipdataset.py +++ b/tests/test_zipdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_zoom.py b/tests/test_zoom.py index e6710ede29..1a7694072e 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,11 +12,12 @@ import unittest import numpy as np +import torch from parameterized import parameterized from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoom -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(1.5, "nearest"), (1.5, "nearest"), (0.8, "bilinear"), (0.8, "area")] @@ -26,38 +27,42 @@ class TestZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, mode): - zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) - zoomed = zoom_fn(self.imt[0]) - _order = 0 - if mode.endswith("linear"): - _order = 1 - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(zoomed, expected, atol=1.0) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) + zoomed = zoom_fn(p(self.imt[0])) + _order = 0 + if mode.endswith("linear"): + _order = 1 + expected = [] + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) + expected = np.stack(expected).astype(np.float32) + assert_allclose(zoomed, p(expected), atol=1.0) def test_keep_size(self): - zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True, padding_mode="constant", constant_values=2) - zoomed = zoom_fn(self.imt[0], mode="bilinear") - np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True) + zoomed = zoom_fn(p(self.imt[0]), mode="bilinear") + assert_allclose(zoomed.shape, self.imt.shape[1:]) - zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) - zoomed = zoom_fn(self.imt[0]) - np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) + zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) + zoomed = zoom_fn(p(self.imt[0])) + assert_allclose(zoomed.shape, self.imt.shape[1:]) @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, zoom, mode, raises): - with self.assertRaises(raises): - zoom_fn = Zoom(zoom=zoom, mode=mode) - zoom_fn(self.imt[0]) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + zoom_fn = Zoom(zoom=zoom, mode=mode) + zoom_fn(p(self.imt[0])) def test_padding_mode(self): - zoom_fn = Zoom(zoom=0.5, mode="nearest", padding_mode="constant", keep_size=True) - test_data = np.array([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]) - zoomed = zoom_fn(test_data) - expected = np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) - np.testing.assert_allclose(zoomed, expected) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=0.5, mode="nearest", padding_mode="constant", keep_size=True) + test_data = p([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]) + zoomed = zoom_fn(test_data) + expected = p([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) + torch.testing.assert_allclose(zoomed, expected) if __name__ == "__main__": diff --git a/tests/test_zoom_affine.py b/tests/test_zoom_affine.py index 49c3c0dcac..3c4bcd302c 100644 --- a/tests/test_zoom_affine.py +++ b/tests/test_zoom_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py index 1a1a905d80..87a5cec22b 100644 --- a/tests/test_zoomd.py +++ b/tests/test_zoomd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoomd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(1.5, "nearest", False), (0.3, "bilinear", False), (0.8, "bilinear", False)] @@ -27,38 +27,37 @@ class TestZoomd(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, mode, keep_size): key = "img" - zoom_fn = Zoomd( - key, - zoom=zoom, - mode=mode, - keep_size=keep_size, - ) - zoomed = zoom_fn({key: self.imt[0]}) - _order = 0 - if mode.endswith("linear"): - _order = 1 - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(expected, zoomed[key], atol=1.0) + zoom_fn = Zoomd(key, zoom=zoom, mode=mode, keep_size=keep_size) + for p in TEST_NDARRAYS: + zoomed = zoom_fn({key: p(self.imt[0])}) + _order = 0 + if mode.endswith("linear"): + _order = 1 + expected = [ + zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False) for channel in self.imt[0] + ] + + expected = np.stack(expected).astype(np.float32) + assert_allclose(zoomed[key], p(expected), atol=1.0) def test_keep_size(self): key = "img" zoom_fn = Zoomd(key, zoom=0.6, keep_size=True, padding_mode="constant", constant_values=2) - zoomed = zoom_fn({key: self.imt[0]}) - self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + zoomed = zoom_fn({key: p(self.imt[0])}) + np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:]) - zoom_fn = Zoomd(key, zoom=1.3, keep_size=True) - zoomed = zoom_fn({key: self.imt[0]}) - self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) + zoom_fn = Zoomd(key, zoom=1.3, keep_size=True) + zoomed = zoom_fn({key: self.imt[0]}) + self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, zoom, mode, raises): key = "img" - with self.assertRaises(raises): - zoom_fn = Zoomd(key, zoom=zoom, mode=mode) - zoom_fn({key: self.imt[0]}) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + zoom_fn = Zoomd(key, zoom=zoom, mode=mode) + zoom_fn({key: p(self.imt[0])}) if __name__ == "__main__": diff --git a/tests/testing_data/cpp_resample_answers.py b/tests/testing_data/cpp_resample_answers.py index 51ac6ccda9..93f596619e 100644 --- a/tests/testing_data/cpp_resample_answers.py +++ b/tests/testing_data/cpp_resample_answers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ def _read_testing_data_answers(fname: Optional[str] = None, delimiter=",") -> Li pwd = os.path.dirname(os.path.abspath(__file__)) filename = os.path.join(pwd, fname) if not os.path.isfile(filename): - warnings.warn("test data {} not found.".format(filename)) + warnings.warn(f"test data {filename} not found.") return answers with open(filename) as f: res_reader = csv.reader(f, delimiter=delimiter) diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json new file mode 100644 index 0000000000..8bcdcd244e --- /dev/null +++ b/tests/testing_data/data_config.json @@ -0,0 +1,78 @@ +{ + "images": { + "wsi_img": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CMU-1.tiff", + "hash_type": "sha256", + "hash_val": "73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd" + }, + "favicon": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/favicon.ico.zip", + "hash_type": "sha256", + "hash_val": "3a3635c8d8adb81feebc5926b4106e8eb643a24a4be2a69a9d35f9a578acadb5" + }, + "icon": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/icon.tar.gz", + "hash_type": "sha256", + "hash_val": "90f24cd8f20f3932624da95190ce384302261acf0ea15b358f7832e3b6becac0" + }, + "mednist": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz", + "hash_type": "sha256", + "hash_val": "f2f4881ff8799a170b10a403495f0ce0ad7486491901cde67a647e6627e7f916" + }, + "Prostate_T2W_AX_1": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/Prostate_T2W_AX_1.nii", + "hash_type": "sha256", + "hash_val": "a14231f539c0f365a5f83f2a046969a9b9870e56ffd126fd8e7242364d25938a" + }, + "0000_t2_tse_tra_4": { + "url": "https://github.com/rcuocolo/PROSTATEx_masks/raw/master/Files/lesions/Images/T2/ProstateX-0000_t2_tse_tra_4.nii.gz", + "hash_type": "md5", + "hash_val": "adb3f1c4db66a6481c3e4a2a3033c7d5" + }, + "0000_ep2d_diff_tra_7": { + "url": "https://github.com/rcuocolo/PROSTATEx_masks/raw/master/Files/lesions/Images/ADC/ProstateX-0000_ep2d_diff_tra_7.nii.gz", + "hash_type": "md5", + "hash_val": "f12a11ad0ebb0b1876e9e010564745d2" + } + }, + "models": { + "senet154-c7b49a05": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/senet154-c7b49a05.pth", + "hash_type": "sha256", + "hash_val": "c7b49a056b98b0bed65b0237c27acdead655e599669215573d357ad337460413" + }, + "se_resnet101-7e38fcc6": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnet101-7e38fcc6.pth", + "hash_type": "sha256", + "hash_val": "7e38fcc64eff3225a3ea4e6081efeb6087e8d5a61c204d94edc2ed1aab0b9d70" + }, + "se_resnet152-d17c99b7": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnet152-d17c99b7.pth", + "hash_type": "sha256", + "hash_val": "d17c99b703dcca2d2507ddfb68f72625a2f7e23ee64396eb992f1b2cf7e6bdc1" + }, + "se_resnet50-ce0d4300": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnet50-ce0d4300.pth", + "hash_type": "sha256", + "hash_val": "ce0d430017d3f4aa6b5658c72209f3bfffb060207fd26a2ef0b203ce592eba01" + }, + "se_resnext101_32x4d-3b2fe3d8": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext101_32x4d-3b2fe3d8.pth", + "hash_type": "sha256", + "hash_val": "3b2fe3d8acb8de7d5976c4baf518f24a0237509272a69366e816682d3e57b989" + }, + "se_resnext50_32x4d-a260b3a4": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext50_32x4d-a260b3a4.pth", + "hash_type": "sha256", + "hash_val": "a260b3a40f82dfe37c58d26a612bcf7bef0d27c6fed096226b0e4e9fb364168e" + } + }, + "configs": { + "test_meta_file": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203171008.json", + "hash_type": "md5", + "hash_val": "e3a7e23d1113a1f3e6c69f09b6f9ce2c" + } + } +} diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json new file mode 100644 index 0000000000..cc9ddef866 --- /dev/null +++ b/tests/testing_data/inference.json @@ -0,0 +1,114 @@ +{ + "dataset_dir": "/workspace/data/Task09_Spleen", + "import_glob": "$import glob", + "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", + "set_seed": "$monai.utils.set_determinism(0)", + "print_test_name": "$print('json_test')", + "print_glob_file": "$print(glob.__file__)", + "network_def": { + "_target_": "UNet", + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 2, + "channels": [ + 2, + 2, + 4, + 8, + 4 + ], + "strides": [ + 2, + 2, + 2, + 2 + ], + "num_res_units": 2, + "norm": "batch" + }, + "network": "need override", + "preprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "LoadImaged", + "keys": "image" + }, + { + "_target_": "EnsureChannelFirstd", + "keys": "image" + }, + { + "_target_": "ScaleIntensityd", + "keys": "image" + }, + { + "_target_": "RandRotated", + "_disabled_": true, + "keys": "image" + }, + { + "_target_": "EnsureTyped", + "keys": "image" + } + ] + }, + "dataset": { + "_target_": "need override", + "data": "@_meta_#datalist", + "transform": "@preprocessing" + }, + "dataloader": { + "_target_": "DataLoader", + "dataset": "@dataset", + "batch_size": 1, + "shuffle": false, + "num_workers": 4 + }, + "inferer": { + "_target_": "SlidingWindowInferer", + "roi_size": [ + 64, + 64, + 32 + ], + "sw_batch_size": 4, + "overlap": 0.25 + }, + "postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "Activationsd", + "keys": "pred", + "softmax": true + }, + { + "_target_": "AsDiscreted", + "keys": "pred", + "argmax": true + }, + { + "_target_": "SaveImaged", + "keys": "pred", + "meta_keys": "image_meta_dict", + "output_dir": "@_meta_#output_dir" + } + ] + }, + "evaluator": { + "_target_": "SupervisedEvaluator", + "_requires_": [ + "@set_seed", + "@print_test_name", + "@print_glob_file", + "$print('test_in_line_json')" + ], + "device": "@device", + "val_data_loader": "@dataloader", + "network": "@network", + "inferer": "@inferer", + "postprocessing": "@postprocessing", + "amp": false + } +} diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml new file mode 100644 index 0000000000..4973d4473f --- /dev/null +++ b/tests/testing_data/inference.yaml @@ -0,0 +1,81 @@ +--- +dataset_dir: "/workspace/data/Task09_Spleen" +device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" +set_seed: "$monai.utils.set_determinism(0)" +print_test_name: "$print('yaml_test')" +network_def: + _target_: UNet + spatial_dims: 3 + in_channels: 1 + out_channels: 2 + channels: + - 2 + - 2 + - 4 + - 8 + - 4 + strides: + - 2 + - 2 + - 2 + - 2 + num_res_units: 2 + norm: batch +network: need override +preprocessing: + _target_: Compose + transforms: + - _target_: LoadImaged + keys: image + - _target_: EnsureChannelFirstd + keys: image + - _target_: ScaleIntensityd + keys: image + - _target_: RandRotated + _disabled_: true + keys: image + - _target_: EnsureTyped + keys: image +dataset: + _target_: need override + data: "@_meta_#datalist" + transform: "@preprocessing" +dataloader: + _target_: DataLoader + dataset: "@dataset" + batch_size: 1 + shuffle: false + num_workers: 4 +inferer: + _target_: SlidingWindowInferer + roi_size: + - 64 + - 64 + - 32 + sw_batch_size: 4 + overlap: 0.25 +postprocessing: + _target_: Compose + transforms: + - _target_: Activationsd + keys: pred + softmax: true + - _target_: AsDiscreted + keys: pred + argmax: true + - _target_: SaveImaged + keys: pred + meta_keys: image_meta_dict + output_dir: "@_meta_#output_dir" +evaluator: + _target_: SupervisedEvaluator + _requires_: + - "$print('test_in_line_yaml')" + - "@set_seed" + - "@print_test_name" + device: "@device" + val_data_loader: "@dataloader" + network: "@network" + inferer: "@inferer" + postprocessing: "@postprocessing" + amp: false diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index ccb4293a40..99765a2b33 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -418,7 +418,7 @@ 0.17794132232666016, 0.18584394454956055, 0.03577899932861328, - ], + ] }, "integration_segmentation_3d": { # for the mixed readers "losses": [ @@ -433,6 +433,153 @@ "infer_metric": 0.9326590299606323, }, }, + { # test answers for PyTorch 21.10 + "integration_classification_2d": { + "losses": [0.7806222991199251, 0.16259610306495315, 0.07529311385124353, 0.04640352608529246], + "best_metric": 0.9999369155431564, + "infer_prop": [1030, 898, 981, 1033, 960, 1046], + }, + "integration_segmentation_3d": { + "losses": [ + 0.5462362408638001, + 0.4913381844758987, + 0.4526856362819672, + 0.43404580652713776, + 0.42532919645309447, + 0.4160102754831314, + ], + "best_metric": 0.9357608556747437, + "infer_metric": 0.9359462857246399, + "output_sums": [ + 0.14133183650702907, + 0.15129517085134564, + 0.15039408698301698, + 0.1388800895551786, + 0.18765019147239637, + 0.16847158867677473, + 0.14567945622102715, + 0.16728557092807228, + 0.15601444057659314, + 0.17816339678760573, + 0.1616256801482474, + 0.16733042976922818, + 0.14342795433701588, + 0.1122946416901734, + 0.16105778942392063, + 0.20017543167070598, + 0.17512204704647916, + 0.09592956823274325, + 0.19316383411238341, + 0.2022308530579937, + 0.19527218778022315, + 0.2075871950564991, + 0.16083565516485876, + 0.13111518931029637, + 0.1473909261474288, + 0.14161210629657228, + 0.23102446985179093, + 0.15980667305916593, + 0.14760356792082058, + 0.1018092235719272, + 0.11792260857122504, + 0.1285278390386459, + 0.11275165891441473, + 0.15101653432548032, + 0.16236351926994622, + 0.1932631773335222, + 0.2221395787381994, + 0.18003549292918666, + 0.18940543270178078, + 0.07430261166443994, + ], + }, + "integration_workflows": { + "output_sums": [ + 0.14211511611938477, + 0.1516571044921875, + 0.1381092071533203, + 0.13403034210205078, + 0.18480682373046875, + 0.16382598876953125, + 0.14140796661376953, + 0.1665945053100586, + 0.15700864791870117, + 0.17697620391845703, + 0.16163396835327148, + 0.16488313674926758, + 0.1442713737487793, + 0.11060476303100586, + 0.16111087799072266, + 0.19617986679077148, + 0.1744403839111328, + 0.052786827087402344, + 0.19046974182128906, + 0.19913578033447266, + 0.19527721405029297, + 0.2032318115234375, + 0.16050148010253906, + 0.13228464126586914, + 0.1512293815612793, + 0.1372208595275879, + 0.22692251205444336, + 0.16164922714233398, + 0.14729642868041992, + 0.10398292541503906, + 0.1195836067199707, + 0.13096046447753906, + 0.11221647262573242, + 0.1521167755126953, + 0.1599421501159668, + 0.1898345947265625, + 0.21675777435302734, + 0.1777491569519043, + 0.18526840209960938, + 0.035144805908203125, + ], + "output_sums_2": [ + 0.14200592041015625, + 0.15146303176879883, + 0.13796186447143555, + 0.1339101791381836, + 0.18489742279052734, + 0.1637406349182129, + 0.14113903045654297, + 0.16657161712646484, + 0.15676355361938477, + 0.17683839797973633, + 0.1614980697631836, + 0.16493558883666992, + 0.14408016204833984, + 0.11035394668579102, + 0.1610560417175293, + 0.1962742805480957, + 0.17439842224121094, + 0.05285835266113281, + 0.19057941436767578, + 0.19914865493774414, + 0.19533538818359375, + 0.20333576202392578, + 0.16032838821411133, + 0.13197898864746094, + 0.1510462760925293, + 0.13703680038452148, + 0.2270984649658203, + 0.16144943237304688, + 0.1472611427307129, + 0.10393238067626953, + 0.11940813064575195, + 0.1307811737060547, + 0.11203241348266602, + 0.15186500549316406, + 0.15992307662963867, + 0.18991422653198242, + 0.21689796447753906, + 0.1777033805847168, + 0.18547868728637695, + 0.035192012786865234, + ], + }, + }, ] @@ -440,6 +587,8 @@ def test_integration_value(test_name, key, data, rtol=1e-2): for (idx, expected) in enumerate(EXPECTED_ANSWERS): if test_name not in expected: continue + if key not in expected[test_name]: + continue value = expected[test_name][key] if np.allclose(data, value, rtol=rtol): print(f"matched {idx} result of {test_name}, {key}, {rtol}.") diff --git a/tests/testing_data/matshow3d_patch_test.png b/tests/testing_data/matshow3d_patch_test.png new file mode 100644 index 0000000000..a4d89e3446 Binary files /dev/null and b/tests/testing_data/matshow3d_patch_test.png differ diff --git a/tests/testing_data/matshow3d_rgb_test.png b/tests/testing_data/matshow3d_rgb_test.png new file mode 100644 index 0000000000..7c8e224c0e Binary files /dev/null and b/tests/testing_data/matshow3d_rgb_test.png differ diff --git a/tests/testing_data/matshow3d_test.png b/tests/testing_data/matshow3d_test.png new file mode 100644 index 0000000000..d720a0c407 Binary files /dev/null and b/tests/testing_data/matshow3d_test.png differ diff --git a/tests/testing_data/metadata.json b/tests/testing_data/metadata.json new file mode 100644 index 0000000000..42a55b114c --- /dev/null +++ b/tests/testing_data/metadata.json @@ -0,0 +1,76 @@ +{ + "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203171008.json", + "version": "0.1.0", + "changelog": { + "0.1.0": "complete the model package", + "0.0.1": "initialize the model package structure" + }, + "monai_version": "0.8.0", + "pytorch_version": "1.10.0", + "numpy_version": "1.21.2", + "optional_packages_version": { + "nibabel": "3.2.1" + }, + "task": "Decathlon spleen segmentation", + "description": "A pre-trained model for volumetric (3D) segmentation of the spleen from CT image", + "authorship": "MONAI team", + "copyright": "Copyright (c) MONAI Consortium", + "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/", + "data_type": "dicom", + "image_classes": "single channel data, intensity scaled to [0, 1]", + "label_classes": "single channel data, 1 is spleen, 0 is everything else", + "pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background", + "eval_metrics": { + "mean_dice": 0.96 + }, + "intended_use": "This is an example, not to be used for diagnostic purposes", + "references": [ + "Xia, Yingda, et al. '3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training. arXiv preprint arXiv:1811.12506 (2018). https://arxiv.org/abs/1811.12506.", + "Kerfoot E., Clough J., Oksuz I., Lee J., King A.P., Schnabel J.A. (2019) Left-Ventricle Quantification Using Residual U-Net. In: Pop M. et al. (eds) Statistical Atlases and Computational Models of the Heart. Atrial Segmentation and LV Quantification Challenges. STACOM 2018. Lecture Notes in Computer Science, vol 11395. Springer, Cham. https://doi.org/10.1007/978-3-030-12029-0_40" + ], + "network_data_format": { + "inputs": { + "image": { + "type": "image", + "format": "magnitude", + "num_channels": 1, + "spatial_shape": [ + 160, + 160, + 160 + ], + "dtype": "float32", + "value_range": [ + 0, + 1 + ], + "is_patch_data": false, + "channel_def": { + "0": "image" + } + } + }, + "outputs": { + "pred": { + "type": "image", + "format": "segmentation", + "num_channels": 2, + "spatial_shape": [ + 160, + 160, + 160 + ], + "dtype": "float32", + "value_range": [ + 0, + 1 + ], + "is_patch_data": false, + "channel_def": { + "0": "background", + "1": "spleen" + } + } + } + } +} diff --git a/tests/utils.py b/tests/utils.py index 1375cd2d72..3065f9b3df 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,6 +13,8 @@ import datetime import functools import importlib +import json +import operator import os import queue import sys @@ -21,27 +23,41 @@ import traceback import unittest import warnings -from functools import partial -from io import BytesIO +from contextlib import contextmanager +from functools import partial, reduce from subprocess import PIPE, Popen from typing import Callable, Optional, Tuple -from urllib.error import ContentTooShortError, HTTPError, URLError +from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch import torch.distributed as dist +from monai.apps.utils import download_url from monai.config import NdarrayTensor from monai.config.deviceconfig import USE_COMPILED from monai.config.type_definitions import NdarrayOrTensor from monai.data import create_test_image_2d, create_test_image_3d -from monai.utils import ensure_tuple, optional_import, set_determinism -from monai.utils.misc import is_module_ver_at_least -from monai.utils.module import version_leq +from monai.networks import convert_to_torchscript +from monai.utils import optional_import +from monai.utils.module import pytorch_after, version_leq +from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") quick_test_var = "QUICKTEST" +_tf32_enabled = None +_test_data_config: dict = {} + + +def testing_data_config(*keys): + """get _test_data_config[keys0][keys1]...[keysN]""" + if not _test_data_config: + with open(os.path.join(os.path.dirname(__file__), "testing_data", "data_config.json")) as c: + _config = json.load(c) + for k, v in _config.items(): + _test_data_config[k] = v + return reduce(operator.getitem, keys, _test_data_config) def clone(data: NdarrayTensor) -> NdarrayTensor: @@ -57,31 +73,95 @@ def clone(data: NdarrayTensor) -> NdarrayTensor: return copy.deepcopy(data) -def assert_allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, *args, **kwargs): +def assert_allclose( + actual: NdarrayOrTensor, + desired: NdarrayOrTensor, + type_test: bool = True, + device_test: bool = False, + *args, + **kwargs, +): """ - Assert that all values of two data objects are close. + Assert that types and all values of two data objects are close. Args: - a (NdarrayOrTensor): Pytorch Tensor or numpy array for comparison - b (NdarrayOrTensor): Pytorch Tensor or numpy array to compare against + actual: Pytorch Tensor or numpy array for comparison. + desired: Pytorch Tensor or numpy array to compare against. + type_test: whether to test that `actual` and `desired` are both numpy arrays or torch tensors. + device_test: whether to test the device property. + args: extra arguments to pass on to `np.testing.assert_allclose`. + kwargs: extra arguments to pass on to `np.testing.assert_allclose`. + + """ - a = a.cpu() if isinstance(a, torch.Tensor) else a - b = b.cpu() if isinstance(b, torch.Tensor) else b - np.testing.assert_allclose(a, b, *args, **kwargs) + if type_test: + # check both actual and desired are of the same type + np.testing.assert_equal(isinstance(actual, np.ndarray), isinstance(desired, np.ndarray), "numpy type") + np.testing.assert_equal(isinstance(actual, torch.Tensor), isinstance(desired, torch.Tensor), "torch type") + if isinstance(desired, torch.Tensor) or isinstance(actual, torch.Tensor): + if device_test: + np.testing.assert_equal(str(actual.device), str(desired.device), "torch device check") # type: ignore + actual = actual.cpu().numpy() if isinstance(actual, torch.Tensor) else actual + desired = desired.cpu().numpy() if isinstance(desired, torch.Tensor) else desired + np.testing.assert_allclose(actual, desired, *args, **kwargs) -def test_pretrained_networks(network, input_param, device): + +@contextmanager +def skip_if_downloading_fails(): try: - net = network(**input_param).to(device) - except (URLError, HTTPError, ContentTooShortError) as e: - raise unittest.SkipTest(e) - return net + yield + except (ContentTooShortError, HTTPError, ConnectionError) as e: + raise unittest.SkipTest(f"error while downloading: {e}") from e + except RuntimeError as rt_e: + if "unexpected EOF" in str(rt_e): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e # incomplete download + if "network issue" in str(rt_e): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + if "gdown dependency" in str(rt_e): # no gdown installed + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + if "md5 check" in str(rt_e): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + raise rt_e + + +def test_pretrained_networks(network, input_param, device): + with skip_if_downloading_fails(): + return network(**input_param).to(device) def test_is_quick(): return os.environ.get(quick_test_var, "").lower() == "true" +def is_tf32_env(): + """ + The environment variable NVIDIA_TF32_OVERRIDE=0 will override any defaults + or programmatic configuration of NVIDIA libraries, and consequently, + cuBLAS will not accelerate FP32 computations with TF32 tensor cores. + """ + global _tf32_enabled + if _tf32_enabled is None: + _tf32_enabled = False + if ( + torch.cuda.is_available() + and not version_leq(f"{torch.version.cuda}", "10.100") + and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0" + and torch.cuda.device_count() > 0 # at least 11.0 + ): + try: + # with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result + g_gpu = torch.Generator(device="cuda") + g_gpu.manual_seed(2147483647) + a_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu) + b_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu) + _tf32_enabled = (a_full.float() @ b_full.float() - a_full @ b_full).abs().max().item() > 0.001 # 0.1713 + except BaseException: + pass + print(f"tf32 enabled: {_tf32_enabled}") + return _tf32_enabled + + def skip_if_quick(obj): """ Skip the unit tests if environment variable `quick_test_var=true`. @@ -143,7 +223,7 @@ class SkipIfBeforePyTorchVersion: def __init__(self, pytorch_version_tuple): self.min_version = pytorch_version_tuple - self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple) + self.version_too_old = not pytorch_after(*pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf( @@ -157,8 +237,7 @@ class SkipIfAtLeastPyTorchVersion: def __init__(self, pytorch_version_tuple): self.max_version = pytorch_version_tuple - test_ver = ".".join(map(str, self.max_version)) - self.version_too_new = version_leq(test_ver, torch.__version__) + self.version_too_new = pytorch_after(*pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf( @@ -166,19 +245,69 @@ def __call__(self, obj): )(obj) -def make_nifti_image(array, affine=None): +def is_main_test_process(): + ps = torch.multiprocessing.current_process() + if not ps or not hasattr(ps, "name"): + return False + return ps.name.startswith("Main") + + +def has_cupy(): + """ + Returns True if the user has installed a version of cupy. + """ + cp, has_cp = optional_import("cupy") + if not is_main_test_process(): + return has_cp # skip the check if we are running in subprocess + if not has_cp: + return False + try: # test cupy installation with a basic example + x = cp.arange(6, dtype="f").reshape(2, 3) + y = cp.arange(3, dtype="f") + kernel = cp.ElementwiseKernel( + "float32 x, float32 y", "float32 z", """ if (x - 2 > y) { z = x * y; } else { z = x + y; } """, "my_kernel" + ) + flag = kernel(x, y)[0, 0] == 0 + del x, y, kernel + cp.get_default_memory_pool().free_all_blocks() + return flag + except Exception: + return False + + +HAS_CUPY = has_cupy() + + +def make_nifti_image(array: NdarrayOrTensor, affine=None, dir=None, fname=None, suffix=".nii.gz", verbose=False): """ Create a temporary nifti image on the disk and return the image name. User is responsible for deleting the temporary file when done with it. """ + if isinstance(array, torch.Tensor): + array, *_ = convert_data_type(array, np.ndarray) + if isinstance(affine, torch.Tensor): + affine, *_ = convert_data_type(affine, np.ndarray) if affine is None: affine = np.eye(4) test_image = nib.Nifti1Image(array, affine) - temp_f, image_name = tempfile.mkstemp(suffix=".nii.gz") - nib.save(test_image, image_name) - os.close(temp_f) - return image_name + # if dir not given, create random. Else, make sure it exists. + if dir is None: + dir = tempfile.mkdtemp() + else: + os.makedirs(dir, exist_ok=True) + + # If fname not given, get random one. Else, concat dir, fname and suffix. + if fname is None: + temp_f, fname = tempfile.mkstemp(suffix=suffix, dir=dir) + os.close(temp_f) + else: + fname = os.path.join(dir, fname + suffix) + + nib.save(test_image, fname) + if verbose: + print(f"File written: {fname}.") + return fname def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState] = None): @@ -298,8 +427,7 @@ def run_process(self, func, local_rank, args, kwargs, results): os.environ["RANK"] = str(self.nproc_per_node * self.node_rank + local_rank) if torch.cuda.is_available(): - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - torch.cuda.set_device(int(local_rank)) + torch.cuda.set_device(int(local_rank)) # using device ids from CUDA_VISIBILE_DEVICES dist.init_process_group( backend=self.backend, @@ -354,6 +482,7 @@ def _wrapper(*args, **kwargs): for p in processes: p.join() assert results.get(), "Distributed call failed." + _del_original_func(obj) return _wrapper @@ -435,6 +564,7 @@ def _wrapper(*args, **kwargs): finally: p.join() + _del_original_func(obj) res = None try: res = results.get(block=False) @@ -460,6 +590,15 @@ def _cache_original_func(obj) -> None: _original_funcs[obj.__name__] = obj +def _del_original_func(obj): + """pop the original function from cache.""" + global _original_funcs + _original_funcs.pop(obj.__name__, None) + if torch.cuda.is_available(): # clean up the cached function + torch.cuda.synchronize() + torch.cuda.empty_cache() + + def _call_original_func(name, module, *args, **kwargs): if name not in _original_funcs: _original_module = importlib.import_module(module) # reimport, refresh _original_funcs @@ -524,59 +663,37 @@ def setUp(self): self.segn = torch.tensor(self.segn) -def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4): +def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): """ Test the ability to save `net` as a Torchscript object, reload it, and apply inference. The value `inputs` is - forward-passed through the original and loaded copy of the network and their results returned. Both `net` and its - reloaded copy are set to evaluation mode if `eval_nets` is True. The forward pass for both is done without - gradient accumulation. + forward-passed through the original and loaded copy of the network and their results returned. + The forward pass for both is done without gradient accumulation. The test will be performed with CUDA if available, else CPU. """ - if True: - device = "cpu" - else: - # TODO: It would be nice to be able to use GPU if - # available, but this currently causes CI failures. - if not device: - device = "cuda" if torch.cuda.is_available() else "cpu" - - # Convert to device - inputs = [i.to(device) for i in inputs] - - scripted = torch.jit.script(net.cpu()) - buffer = scripted.save_to_buffer() - reloaded_net = torch.jit.load(BytesIO(buffer)).to(device) - net.to(device) - - if eval_nets: - net.eval() - reloaded_net.eval() - - with torch.no_grad(): - set_determinism(seed=0) - result1 = net(*inputs) - result2 = reloaded_net(*inputs) - set_determinism(seed=None) - - # convert results to tuples if needed to allow iterating over pairs of outputs - result1 = ensure_tuple(result1) - result2 = ensure_tuple(result2) - - for i, (r1, r2) in enumerate(zip(result1, result2)): - if None not in (r1, r2): # might be None - np.testing.assert_allclose( - r1.detach().cpu().numpy(), - r2.detach().cpu().numpy(), - rtol=rtol, - atol=0, - err_msg=f"failed on comparison number: {i}", - ) + # TODO: would be nice to use GPU if available, but it currently causes CI failures. + device = "cpu" + with tempfile.TemporaryDirectory() as tempdir: + convert_to_torchscript( + model=net, + filename_or_obj=os.path.join(tempdir, "model.ts"), + verify=True, + inputs=inputs, + device=device, + rtol=rtol, + atol=atol, + ) + + +def download_url_or_skip_test(*args, **kwargs): + """``download_url`` and skip the tests if any downloading error occurs.""" + with skip_if_downloading_fails(): + download_url(*args, **kwargs) def query_memory(n=2): """ - Find best n idle devices and return a string of device ids. + Find best n idle devices and return a string of device ids using the `nvidia-smi` command. """ bash_string = "nvidia-smi --query-gpu=power.draw,temperature.gpu,memory.used --format=csv,noheader,nounits" @@ -587,7 +704,7 @@ def query_memory(n=2): free_memory = np.asarray(free_memory, dtype=float).T free_memory[1] += free_memory[0] # combine 0/1 column measures ids = np.lexsort(free_memory)[:n] - except (FileNotFoundError, TypeError, IndexError): + except (TypeError, IndexError, OSError): ids = range(n) if isinstance(n, int) else [] return ",".join(f"{int(x)}" for x in ids)