diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index f7024f1a08..e1eeb92c6b 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -12,6 +12,6 @@ A few sentences describing the changes proposed in this pull request. - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. -- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests`. +- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. diff --git a/.github/workflows/cron-mmar.yml b/.github/workflows/cron-mmar.yml new file mode 100644 index 0000000000..f61ba59368 --- /dev/null +++ b/.github/workflows/cron-mmar.yml @@ -0,0 +1,42 @@ +name: cron-mmar + +on: + schedule: + - cron: "0 2 * * *" # at 02:00 UTC + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +concurrency: + # automatically cancel the previously triggered workflows when there's a newer version + group: mmar-tests-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + cron-load: + if: github.repository == 'Project-MONAI/MONAI' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: cache weekly timestamp + id: pip-cache + run: echo "::set-output name=datew::$(date '+%Y-%V')" + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel + python -m pip install -r requirements-dev.txt + - name: Loading MMARs + run: | + # clean up temporary files + $(pwd)/runtests.sh --build --clean + # run tests + python -m tests.ngc_mmar_loading diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index a36cfbcdb9..f50363b1e0 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -15,7 +15,7 @@ jobs: runs-on: [self-hosted, linux, x64, common] strategy: matrix: - pytorch-version: [1.5.1, 1.6.0, 1.7.1, 1.8.1, latest] + pytorch-version: [1.6.0, 1.7.1, 1.8.1, 1.9.1, latest] steps: - uses: actions/checkout@v2 - name: Install the dependencies @@ -25,14 +25,14 @@ jobs: python -m pip uninstall -y torch torchvision if [ ${{ matrix.pytorch-version }} == "latest" ]; then python -m pip install torch torchvision - elif [ ${{ matrix.pytorch-version }} == "1.5.1" ]; then - python -m pip install torch==1.5.1 torchvision==0.6.1 elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then python -m pip install torch==1.6.0 torchvision==0.7.0 elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then python -m pip install torch==1.7.1 torchvision==0.8.2 elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then python -m pip install torch==1.8.1 torchvision==0.9.1 + elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then + python -m pip install torch==1.9.1 torchvision==0.10.1 fi python -m pip install -r requirements-dev.txt python -m pip list @@ -48,8 +48,8 @@ jobs: python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))' - BUILD_MONAI=1 ./runtests.sh --coverage --unittests # unit tests with coverage report - BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests # unit tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report coverage xml if pgrep python; then pkill python; fi - name: Upload coverage @@ -62,7 +62,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:21.02", "pytorch:21.08"] # 21.02 for backward comp. + container: ["pytorch:21.02", "pytorch:21.12"] # 21.02 for backward comp. container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -91,8 +91,8 @@ jobs: python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))' - BUILD_MONAI=1 ./runtests.sh --coverage --unittests # unit tests with coverage report - BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests # unit tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report coverage xml if pgrep python; then pkill python; fi - name: Upload coverage @@ -106,7 +106,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:21.02", "pytorch:21.08"] # 21.02 for backward comp. + container: ["pytorch:21.02", "pytorch:21.12"] # 21.02 for backward comp. container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -173,7 +173,7 @@ jobs: cron-docker: if: github.repository == 'Project-MONAI/MONAI' container: - image: localhost:5000/local_monai:dockerhub # use currently latest, locally available dockerhub image + image: docker://projectmonai/monai:latest # this might be slow and has the pull count limitations options: "--gpus all" runs-on: [self-hosted, linux, x64, common] steps: @@ -190,8 +190,8 @@ jobs: python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))' ngc --version - BUILD_MONAI=1 ./runtests.sh --coverage --pytype --unittests # unit tests with pytype checks, coverage report - BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --pytype --unittests --disttests # unit tests with pytype checks, coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report coverage xml if pgrep python; then pkill python; fi - name: Upload coverage @@ -204,7 +204,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:21.08-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:21.09-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, common] steps: @@ -215,7 +215,7 @@ jobs: which python python -m pip install --upgrade pip wheel python -m pip install -r requirements-dev.txt - BUILD_MONAI=0 python setup.py develop # install monai + BUILD_MONAI=1 python setup.py develop # install monai nvidia-smi export CUDA_VISIBLE_DEVICES=$(python -m tests.utils) echo $CUDA_VISIBLE_DEVICES @@ -234,5 +234,7 @@ jobs: trap 'if pgrep python; then pkill python; fi;' ERR python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & cd /opt/tutorials + python -c 'import monai; monai.config.print_debug_info()' $(pwd)/runner.sh + python -c 'import monai; monai.config.print_debug_info()' if pgrep python; then pkill python; fi diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 3104224e2b..7140cd7dd8 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -13,20 +13,21 @@ on: workflow_dispatch: jobs: - versioning: + versioning_dev: # compute versioning file from python setup.py # upload as artifact - # (also used in release.yml) if: github.repository == 'Project-MONAI/MONAI' - container: - image: localhost:5000/local_monai:latest - runs-on: [self-hosted, linux, x64, build_only] + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 # full history so that we can git describe with: ref: dev fetch-depth: 0 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 - shell: bash run: | git describe @@ -43,13 +44,11 @@ jobs: ls -al rm -rf {*,.[^.]*} - local_docker: - # builds two versions: local_monai:latest and local_monai:dockerhub - # latest: used for local tests - # dockerhub: release, no flake package + docker_build_dev: + # builds projectmonai/monai:latest if: github.repository == 'Project-MONAI/MONAI' - needs: versioning - runs-on: [self-hosted, linux, x64, build_only] + needs: versioning_dev + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 with: @@ -58,67 +57,47 @@ jobs: uses: actions/download-artifact@v2 with: name: _version.py + - name: Install Latest Docker + run: | + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - + sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" + sudo apt-get update + sudo apt-get install docker-ce - name: docker_build shell: bash run: | # get tag info for versioning cat _version.py mv _version.py monai/ - # build and run original docker image for local registry - docker build -t localhost:5000/local_monai:latest -f Dockerfile . - docker push localhost:5000/local_monai:latest - # build once more w/ tag "latest": remove flake package as it is not needed on hub.docker.com + + # build "latest": remove flake package as it is not needed on hub.docker.com sed -i '/flake/d' requirements-dev.txt docker build -t projectmonai/monai:latest -f Dockerfile . - # also push as tag "dockerhub" to local registry - docker image tag projectmonai/monai:latest localhost:5000/local_monai:dockerhub - docker push localhost:5000/local_monai:dockerhub + # distribute as always w/ tag "latest" to hub.docker.com echo "${{ secrets.DOCKER_PW }}" | docker login -u projectmonai --password-stdin + docker push projectmonai/monai:latest docker logout docker image prune -f - docker_test_latest: - if: github.repository == 'Project-MONAI/MONAI' - needs: local_docker - container: - image: localhost:5000/local_monai:latest - runs-on: [self-hosted, linux, x64, common] - steps: - - name: Import - run: | - export CUDA_VISIBLE_DEVICES=$(python -m tests.utils) - echo $CUDA_VISIBLE_DEVICES - trap 'if pgrep python; then pkill python; fi;' ERR - python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & - python -c 'import monai; monai.config.print_config()' - cd /opt/monai - ls -al - ngc --version - python -m tests.min_tests - if pgrep python; then pkill python; fi - env: - QUICKTEST: True - docker_test_dockerhub: if: github.repository == 'Project-MONAI/MONAI' - needs: local_docker + needs: docker_build_dev container: - image: localhost:5000/local_monai:dockerhub - runs-on: [self-hosted, linux, x64, common] + image: docker://projectmonai/monai:latest + options: "--shm-size=4g --ipc=host" + runs-on: ubuntu-latest steps: - name: Import run: | export CUDA_VISIBLE_DEVICES=$(python -m tests.utils) echo $CUDA_VISIBLE_DEVICES - trap 'if pgrep python; then pkill python; fi;' ERR - python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & - python -c 'import monai; monai.config.print_config()' + python -c 'import monai; monai.config.print_debug_info()' cd /opt/monai ls -al ngc --version - python -m tests.min_tests - if pgrep python; then pkill python; fi + ./runtests.sh --min + shell: bash env: QUICKTEST: True diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index ed025e98fe..4b93632723 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -34,7 +34,7 @@ jobs: which python python -m pip install --upgrade pip wheel python -m pip uninstall -y torch torchvision - python -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/torch_stable.html python -m pip install -r requirements-dev.txt - name: Run integration tests run: | @@ -46,8 +46,8 @@ jobs: python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))' - BUILD_MONAI=1 ./runtests.sh --net - BUILD_MONAI=1 ./runtests.sh --unittests + BUILD_MONAI=1 ./runtests.sh --build --net + BUILD_MONAI=1 ./runtests.sh --build --unittests --disttests if pgrep python; then pkill python; fi shell: bash - name: Add reaction diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index 999567ae16..848eaedcd6 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -19,35 +19,37 @@ jobs: strategy: matrix: environment: - - "PT16+CUDA110" - "PT17+CUDA102" - - "PT17+CUDA110" - "PT18+CUDA102" + - "PT18+CUDA112" - "PT19+CUDA114" - - "PT19+CUDA102" + - "PT110+CUDA115" + - "PT110+CUDA102" include: - - environment: PT16+CUDA110 - # we explicitly set pytorch to -h to avoid pip install error - pytorch: "-h" - base: "nvcr.io/nvidia/pytorch:20.07-py3" + # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - environment: PT17+CUDA102 pytorch: "torch==1.7.1 torchvision==0.8.2" base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04" - - environment: PT17+CUDA110 - # we explicitly set pytorch to -h to avoid pip install error - pytorch: "-h" - base: "nvcr.io/nvidia/pytorch:20.09-py3" - environment: PT18+CUDA102 pytorch: "torch==1.8.1 torchvision==0.9.1" base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04" + - environment: PT18+CUDA112 + # we explicitly set pytorch to -h to avoid pip install error + # 21.03: 1.9.0a0+df837d0 + pytorch: "-h" + base: "nvcr.io/nvidia/pytorch:21.03-py3" - environment: PT19+CUDA114 # we explicitly set pytorch to -h to avoid pip install error - # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - # 21.08: 1.10.0a0+3fd9dcf + # 21.10: 1.10.0a0+0aef44c + pytorch: "-h" + base: "nvcr.io/nvidia/pytorch:21.10-py3" + - environment: PT110+CUDA115 + # we explicitly set pytorch to -h to avoid pip install error + # 21.12: 1.11.0a0+b6df043 pytorch: "-h" - base: "nvcr.io/nvidia/pytorch:21.08-py3" - - environment: PT19+CUDA102 - pytorch: "torch==1.9.0 torchvision==0.10.0" + base: "nvcr.io/nvidia/pytorch:21.12-py3" + - environment: PT110+CUDA102 + pytorch: "torch==1.10.1 torchvision==0.11.2" base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04" container: image: ${{ matrix.base }} @@ -59,9 +61,9 @@ jobs: run: | if [ ${{ matrix.environment }} = "PT17+CUDA102" ] || \ [ ${{ matrix.environment }} = "PT18+CUDA102" ] || \ - [ ${{ matrix.environment }} = "PT19+CUDA102" ] + [ ${{ matrix.environment }} = "PT110+CUDA102" ] then - PYVER=3.6 PYSFX=3 DISTUTILS=python3-distutils && \ + PYVER=3.7 PYSFX=3 DISTUTILS=python3-distutils && \ apt-get update && apt-get install -y --no-install-recommends \ curl \ pkg-config \ @@ -100,6 +102,8 @@ jobs: run: | which python python -m pip install --upgrade pip wheel + # fixes preinstalled ruamel_yaml error from the docker image + rm -rf $(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")/ruamel* python -m pip install ${{ matrix.pytorch }} python -m pip install -r requirements-dev.txt python -m pip list @@ -120,8 +124,8 @@ jobs: python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))' python -c "import monai; monai.config.print_config()" # build for the current self-hosted CI Tesla V100 - BUILD_MONAI=1 TORCH_CUDA_ARCH_LIST="7.0" ./runtests.sh --quick --unittests - if [ ${{ matrix.environment }} = "PT19+CUDA102" ]; then + BUILD_MONAI=1 TORCH_CUDA_ARCH_LIST="7.0" ./runtests.sh --build --quick --unittests --disttests + if [ ${{ matrix.environment }} = "PT110+CUDA102" ]; then # test the clang-format tool downloading once coverage run -m tests.clang_format_utils fi diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml new file mode 100644 index 0000000000..b7816542ea --- /dev/null +++ b/.github/workflows/pythonapp-min.yml @@ -0,0 +1,170 @@ +name: build-min + +on: + # quick tests for pull requests and the releasing branches + push: + branches: + - dev + - main + - releasing/* + pull_request: + +concurrency: + # automatically cancel the previously triggered workflows when there's a newer version + group: build-min-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + # caching of these jobs: + # - docker-py3-pip- (shared) + # - ubuntu py37 pip- + # - os-latest-pip- (shared) + min-dep-os: # min dependencies installed tests for different OS + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [windows-latest, macOS-latest, ubuntu-latest] + timeout-minutes: 40 + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Prepare pip wheel + run: | + which python + python -m pip install --upgrade pip wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + echo "::set-output name=dir::$(pip cache dir)" + shell: bash + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }} + - if: runner.os == 'windows' + name: Install torch cpu from pytorch.org (Windows only) + run: | + python -m pip install torch==1.10.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install the dependencies + run: | + # min. requirements + python -m pip install torch==1.10.1 + python -m pip install -r requirements-min.txt + python -m pip list + BUILD_MONAI=0 python setup.py develop # no compile of extensions + shell: bash + - name: Run quick tests (CPU ${{ runner.os }}) + run: | + python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' + python -c "import monai; monai.config.print_config()" + ./runtests.sh --min + shell: bash + env: + QUICKTEST: True + + min-dep-py3: # min dependencies installed tests for different python + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.7, 3.8, 3.9] + timeout-minutes: 40 + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Prepare pip wheel + run: | + which python + python -m pip install --user --upgrade pip setuptools wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + echo "::set-output name=dir::$(pip cache dir)" + shell: bash + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} + - name: Install the dependencies + run: | + # min. requirements + python -m pip install torch==1.10.1 + python -m pip install -r requirements-min.txt + python -m pip list + BUILD_MONAI=0 python setup.py develop # no compile of extensions + shell: bash + - name: Run quick tests (CPU ${{ runner.os }}) + run: | + python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' + python -c "import monai; monai.config.print_config()" + ./runtests.sh --min + env: + QUICKTEST: True + + min-dep-pytorch: # min dependencies installed tests for different pytorch + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + pytorch-version: [1.6.0, 1.7.1, 1.8.1, 1.9.1, latest] + timeout-minutes: 40 + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Prepare pip wheel + run: | + which python + python -m pip install --user --upgrade pip setuptools wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + echo "::set-output name=dir::$(pip cache dir)" + shell: bash + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} + - name: Install the dependencies + run: | + # min. requirements + if [ ${{ matrix.pytorch-version }} == "latest" ]; then + python -m pip install torch + elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then + python -m pip install torch==1.6.0 + elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then + python -m pip install torch==1.7.1 + elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then + python -m pip install torch==1.8.1 + elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then + python -m pip install torch==1.9.1 + fi + python -m pip install -r requirements-min.txt + python -m pip list + BUILD_MONAI=0 python setup.py develop # no compile of extensions + shell: bash + - name: Run quick tests (pytorch ${{ matrix.pytorch-version }}) + run: | + python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' + python -c "import monai; monai.config.print_config()" + ./runtests.sh --min + env: + QUICKTEST: True diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 3f18263e9e..7237c8d54d 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -44,9 +44,9 @@ jobs: - name: Lint and type check run: | # clean up temporary files - $(pwd)/runtests.sh --clean + $(pwd)/runtests.sh --build --clean # Git hub actions have 2 cores, so parallize pytype - $(pwd)/runtests.sh --codeformat -j 2 + $(pwd)/runtests.sh --build --codeformat -j 2 quick-py3: # full dependencies installed tests for different OS runs-on: ${{ matrix.os }} @@ -87,10 +87,10 @@ jobs: - if: runner.os == 'windows' name: Install torch cpu from pytorch.org (Windows only) run: | - python -m pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install torch==1.10.1+cpu torchvision==0.11.2+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install the dependencies run: | - python -m pip install torch==1.9.0 torchvision==0.10.0 + python -m pip install torch==1.10.1 torchvision==0.11.2 cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list @@ -106,100 +106,6 @@ jobs: env: QUICKTEST: True - min-dep-os: # min dependencies installed tests for different OS - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [windows-latest, macOS-latest, ubuntu-latest] - timeout-minutes: 40 - steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - name: Prepare pip wheel - run: | - which python - python -m pip install --upgrade pip wheel - - name: cache weekly timestamp - id: pip-cache - run: | - echo "::set-output name=datew::$(date '+%Y-%V')" - echo "::set-output name=dir::$(pip cache dir)" - shell: bash - - name: cache for pip - uses: actions/cache@v2 - id: cache - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }} - - if: runner.os == 'windows' - name: Install torch cpu from pytorch.org (Windows only) - run: | - python -m pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - - name: Install the dependencies - run: | - # min. requirements - python -m pip install torch==1.9.0 - python -m pip install -r requirements-min.txt - python -m pip list - BUILD_MONAI=0 python setup.py develop # no compile of extensions - shell: bash - - name: Run quick tests (CPU ${{ runner.os }}) - run: | - python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' - python -c "import monai; monai.config.print_config()" - ./runtests.sh --min - env: - QUICKTEST: True - - min-dep-py3: # min dependencies installed tests for different python - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.6, 3.7, 3.8, 3.9] - timeout-minutes: 40 - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Prepare pip wheel - run: | - which python - python -m pip install --user --upgrade pip setuptools wheel - - name: cache weekly timestamp - id: pip-cache - run: | - echo "::set-output name=datew::$(date '+%Y-%V')" - echo "::set-output name=dir::$(pip cache dir)" - shell: bash - - name: cache for pip - uses: actions/cache@v2 - id: cache - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} - - name: Install the dependencies - run: | - # min. requirements - python -m pip install torch==1.9.0 - python -m pip install -r requirements-min.txt - python -m pip list - BUILD_MONAI=0 python setup.py develop # no compile of extensions - shell: bash - - name: Run quick tests (CPU ${{ runner.os }}) - run: | - python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' - python -c "import monai; monai.config.print_config()" - ./runtests.sh --min - env: - QUICKTEST: True - packaging: runs-on: ubuntu-latest env: @@ -231,7 +137,7 @@ jobs: # install the latest pytorch for testing # however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated # fresh torch installation according to pyproject.toml - python -m pip install torch>=1.5 torchvision + python -m pip install torch>=1.6 torchvision - name: Check packages run: | pip uninstall monai diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index bfdc639788..9cef1ff090 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 with: @@ -87,18 +87,18 @@ jobs: versioning: # compute versioning file from python setup.py # upload as artifact - # (also used in docker.yml) if: github.repository == 'Project-MONAI/MONAI' needs: packaging - container: - image: localhost:5000/local_monai:latest - runs-on: [self-hosted, linux, x64, build_only] + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 # full history so that we can git describe with: - ref: main fetch-depth: 0 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 - shell: bash run: | git describe @@ -118,11 +118,9 @@ jobs: release_tag_docker: if: github.repository == 'Project-MONAI/MONAI' needs: versioning - runs-on: [self-hosted, linux, x64, build_only] + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - with: - ref: main - name: Download version uses: actions/download-artifact@v2 with: @@ -136,6 +134,13 @@ jobs: run: | echo "$RELEASE_VERSION" cat _version.py + - if: startsWith(github.ref, 'refs/tags/') + name: Install latest docker + run: | + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - + sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" + sudo apt-get update + sudo apt-get install docker-ce - if: startsWith(github.ref, 'refs/tags/') name: build with the tag env: @@ -144,6 +149,17 @@ jobs: run: | # get tag info for versioning mv _version.py monai/ + # version checks + target=" \"version\": \"$RELEASE_VERSION\"" + local=`grep "\"version\"" monai/_version.py` + echo "$target" + echo "$local" + if [[ "$local" == "$target" ]]; then + echo "matched version string" + else + echo "unmatched version string, please check the tagging branch." + exit 1 + fi # remove flake package as it is not needed on hub.docker.com sed -i '/flake/d' requirements-dev.txt docker build -t projectmonai/monai:"$RELEASE_VERSION" -f Dockerfile . diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index d0dc3a9f10..ede2c87b92 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -43,7 +43,8 @@ jobs: which python python -m pip install --upgrade pip wheel python -m pip uninstall -y torch torchvision - python -m pip install torch==1.9.0 torchvision==0.10.0 + rm -rf $(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")/ruamel* + python -m pip install torch==1.10.1 torchvision==0.11.2 python -m pip install -r requirements-dev.txt - name: Run unit tests report coverage run: | @@ -58,8 +59,8 @@ jobs: python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null & python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))' - BUILD_MONAI=1 ./runtests.sh --coverage --unittests # unit tests with coverage report - BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests # unit tests with coverage report + BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report coverage xml if pgrep python; then pkill python; fi shell: bash @@ -73,7 +74,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 with: @@ -97,13 +98,13 @@ jobs: - name: Install the dependencies run: | python -m pip install --upgrade pip wheel - python -m pip install torch==1.9.0 torchvision==0.10.0 + python -m pip install torch==1.10.1 torchvision==0.11.2 python -m pip install -r requirements-dev.txt - name: Run quick tests CPU ubuntu run: | python -m pip list python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' - BUILD_MONAI=1 ./runtests.sh --quick --unittests + BUILD_MONAI=1 ./runtests.sh --build --quick --unittests --disttests coverage xml - name: Upload coverage uses: codecov/codecov-action@v1 diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index df0b5dd759..0d899b1f30 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -33,7 +33,7 @@ jobs: export YEAR_WEEK=$(date +'%y%U') echo "Year week for tag is ${YEAR_WEEK}" if ! [[ $YEAR_WEEK =~ ^[0-9]{4}$ ]] ; then echo "Wrong 'year week' format. Should be 4 digits."; exit 1 ; fi - git tag "0.7.dev${YEAR_WEEK}" + git tag "0.9.dev${YEAR_WEEK}" git log -1 git tag --list python setup.py sdist bdist_wheel diff --git a/.gitignore b/.gitignore index 7444d7f2f9..13155c3088 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,4 @@ tests/testing_data/*.tiff # VSCode .vscode/ +*.zip diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c36c96186c..980c0c0e06 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,18 +22,35 @@ repos: args: ['--maxkb=1024'] - id: detect-private-key - #- repo: https://github.com/asottile/pyupgrade - # rev: v2.23.2 - # hooks: - # - id: pyupgrade - # args: [--py36-plus] - # name: Upgrade code + - repo: https://github.com/asottile/pyupgrade + rev: v2.29.0 + hooks: + - id: pyupgrade + args: [--py37-plus] + name: Upgrade code + exclude: | + (?x)^( + versioneer.py| + monai/_version.py + )$ - #- repo: https://github.com/asottile/yesqa - # rev: v1.2.3 - # hooks: - # - id: yesqa - # name: Unused noqa + - repo: https://github.com/asottile/yesqa + rev: v1.2.3 + hooks: + - id: yesqa + name: Unused noqa + additional_dependencies: + - flake8>=3.8.1 + - flake8-bugbear + - flake8-comprehensions + - flake8-executable + - flake8-pyi + - pep8-naming + exclude: | + (?x)^( + monai/__init__.py| + docs/source/conf.py + )$ #- repo: https://github.com/PyCQA/isort # rev: 5.9.3 diff --git a/CHANGELOG.md b/CHANGELOG.md index bdbd23e7dd..53ce406849 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,97 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [Unreleased] -* renamed model's `n_classes` to `num_classes` + +## [0.8.0] - 2021-11-25 +### Added +* Overview of [new features in v0.8](docs/source/whatsnew_0_8.md) +* Network modules for differentiable neural network topology search (DiNTS) +* Multiple Instance Learning transforms and models for digital pathology WSI analysis +* Vision transformers for self-supervised representation learning +* Contrastive loss for self-supervised learning +* Finalized major improvements of 200+ components in `monai.transforms` to support input and backend in PyTorch and NumPy +* Initial registration module benchmarking with `GlobalMutualInformationLoss` as an example +* `monai.transforms` documentation with visual examples and the utility functions +* Event handler for `MLfLow` integration +* Enhanced data visualization functions including `blend_images` and `matshow3d` +* `RandGridDistortion` and `SmoothField` in `monai.transforms` +* Support of randomized shuffle buffer in iterable datasets +* Performance review and enhancements for data type casting +* Cumulative averaging API with distributed environment support +* Module utility functions including `require_pkg` and `pytorch_after` +* Various usability enhancements such as `allow_smaller` when sampling ROI and `wrap_sequence` when casting object types +* `tifffile` support in `WSIReader` +* Regression tests for the fast training workflows +* Various tutorials and demos including educational contents at [MONAI Bootcamp 2021](https://github.com/Project-MONAI/MONAIBootcamp2021) +### Changed +* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.10-py3` from `nvcr.io/nvidia/pytorch:21.08-py3` +* Decoupled `TraceKeys` and `TraceableTransform` APIs from `InvertibleTransform` +* Skipping affine-based resampling when `resample=False` in `NiftiSaver` +* Deprecated `threshold_values: bool` and `num_classes: int` in `AsDiscrete` +* Enhanced `apply_filter` for spatially 1D, 2D and 3D inputs with non-separable kernels +* Logging with `logging` in downloading and model archives in `monai.apps` +* API documentation site now defaults to `stable` instead of `latest` +* `skip-magic-trailing-comma` in coding style enforcements +* Pre-merge CI pipelines now include unit tests with Nvidia Ampere architecture +### Removed +* Support for PyTorch 1.5 +* The deprecated `DynUnetV1` and the related network blocks +* GitHub self-hosted CI/CD pipelines for package releases +### Fixed +* Support of path-like objects as file path inputs in most modules +* Issue of `decollate_batch` for dictionary of empty lists +* Typos in documentation and code examples in various modules +* Issue of no available keys when `allow_missing_keys=True` for the `MapTransform` +* Issue of redundant computation when normalization factors are 0.0 and 1.0 in `ScaleIntensity` +* Incorrect reports of registered readers in `ImageReader` +* Wrong numbering of iterations in `StatsHandler` +* Naming conflicts in network modules and aliases +* Incorrect output shape when `reduction="none"` in `FocalLoss` +* Various usability issues reported by users + +## [0.7.0] - 2021-09-24 +### Added +* Overview of [new features in v0.7](docs/source/whatsnew_0_7.md) +* Initial phase of major usability improvements in `monai.transforms` to support input and backend in PyTorch and NumPy +* Performance enhancements, with [profiling and tuning guides](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md) for typical use cases +* Reproducing [training modules and workflows](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) of state-of-the-art Kaggle competition solutions +* 24 new transforms, including + * `OneOf` meta transform + * DeepEdit guidance signal transforms for interactive segmentation + * Transforms for self-supervised pre-training + * Integration of [NVIDIA Tools Extension](https://developer.nvidia.com/blog/nvidia-tools-extension-api-nvtx-annotation-tool-for-profiling-code-in-python-and-c-c/) (NVTX) + * Integration of [cuCIM](https://github.com/rapidsai/cucim) + * Stain normalization and contextual grid for digital pathology +* `Transchex` network for vision-language transformers for chest X-ray analysis +* `DatasetSummary` utility in `monai.data` +* `WarmupCosineSchedule` +* Deprecation warnings and documentation support for better backwards compatibility +* Padding with additional `kwargs` and different backend API +* Additional options such as `dropout` and `norm` in various networks and their submodules + +### Changed +* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.08-py3` from `nvcr.io/nvidia/pytorch:21.06-py3` +* Deprecated input argument `n_classes`, in favor of `num_classes` +* Deprecated input argument `dimensions` and `ndims`, in favor of `spatial_dims` +* Updated the Sphinx-based documentation theme for better readability +* `NdarrayTensor` type is replaced by `NdarrayOrTensor` for simpler annotations +* Self-attention-based network blocks now support both 2D and 3D inputs + +### Removed +* The deprecated `TransformInverter`, in favor of `monai.transforms.InvertD` +* GitHub self-hosted CI/CD pipelines for nightly and post-merge tests +* `monai.handlers.utils.evenly_divisible_all_gather` +* `monai.handlers.utils.string_list_all_gather` + +### Fixed +* A Multi-thread cache writing issue in `LMDBDataset` +* Output shape convention inconsistencies of the image readers +* Output directory and file name flexibility issue for `NiftiSaver`, `PNGSaver` +* Requirement of the `label` field in test-time augmentation +* Input argument flexibility issues for `ThreadDataLoader` +* Decoupled `Dice` and `CrossEntropy` intermediate results in `DiceCELoss` +* Improved documentation, code examples, and warning messages in various modules +* Various usability issues reported by users ## [0.6.0] - 2021-07-08 ### Added @@ -25,6 +115,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * Fully compatible with PyTorch 1.9 * `--disttests` and `--min` options for `runtests.sh` * Initial support of pre-merge tests with Nvidia Blossom system + ### Changed * Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.06-py3` from `nvcr.io/nvidia/pytorch:21.04-py3` @@ -34,11 +125,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * Unified the terms: `post_transform` is renamed to `postprocessing`, `pre_transform` is renamed to `preprocessing` * Unified the postprocessing transforms and event handlers to accept the "channel-first" data format * `evenly_divisible_all_gather` and `string_list_all_gather` moved to `monai.utils.dist` + ### Removed * Support of 'batched' input for postprocessing transforms and event handlers * `TorchVisionFullyConvModel` * `set_visible_devices` utility function * `SegmentationSaver` and `TransformsInverter` handlers + ### Fixed * Issue of handling big-endian image headers * Multi-thread issue for non-random transforms in the cache-based datasets @@ -269,9 +362,11 @@ the postprocessing steps should be used before calling the metrics methods * Optionally depend on PyTorch-Ignite v0.4.2 instead of v0.3.0 * Optionally depend on torchvision, ITK * Enhanced CI tests with 8 new testing environments + ### Removed * `MONAI/examples` folder (relocated into [`Project-MONAI/tutorials`](https://github.com/Project-MONAI/tutorials)) * `MONAI/research` folder (relocated to [`Project-MONAI/research-contributions`](https://github.com/Project-MONAI/research-contributions)) + ### Fixed * `dense_patch_slices` incorrect indexing * Data type issue in `GeneralizedWassersteinDiceLoss` @@ -302,6 +397,7 @@ the postprocessing steps should be used before calling the metrics methods * Cross-platform CI tests supporting multiple Python versions * Optional import mechanism * Experimental features for third-party transforms integration + ### Changed > For more details please visit [the project wiki](https://github.com/Project-MONAI/MONAI/wiki/Notable-changes-between-0.1.0-and-0.2.0) * Core modules now require numpy >= 1.17 @@ -311,9 +407,11 @@ the postprocessing steps should be used before calling the metrics methods * Base Docker image upgraded to `nvcr.io/nvidia/pytorch:20.03-py3` from `nvcr.io/nvidia/pytorch:19.10-py3` * Enhanced local testing tools * Documentation website domain changed to https://docs.monai.io + ### Removed * Support of Python < 3.6 * Automatic installation of optional dependencies including pytorch-ignite, nibabel, tensorboard, pillow, scipy, scikit-image + ### Fixed * Various issues in type and argument names consistency * Various issues in docstring and documentation site @@ -336,7 +434,9 @@ the postprocessing steps should be used before calling the metrics methods [highlights]: https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md -[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/0.6.0...HEAD +[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/0.8.0...HEAD +[0.8.0]: https://github.com/Project-MONAI/MONAI/compare/0.7.0...0.8.0 +[0.7.0]: https://github.com/Project-MONAI/MONAI/compare/0.6.0...0.7.0 [0.6.0]: https://github.com/Project-MONAI/MONAI/compare/0.5.3...0.6.0 [0.5.3]: https://github.com/Project-MONAI/MONAI/compare/0.5.0...0.5.3 [0.5.0]: https://github.com/Project-MONAI/MONAI/compare/0.4.0...0.5.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0dce26582a..ec6bf35ef9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -65,7 +65,7 @@ python -m pip install -U -r requirements-dev.txt License information: all source code files should start with this paragraph: ``` -# Copyright MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -113,7 +113,7 @@ or (for new features that would not break existing functionality): ``` It is recommended that the new test `test_[module_name].py` is constructed by using only -python 3.6+ build-in functions, `torch`, `numpy`, `coverage` (for reporting code coverages) and `parameterized` (for organising test cases) packages. +python 3.7+ build-in functions, `torch`, `numpy`, `coverage` (for reporting code coverages) and `parameterized` (for organising test cases) packages. If it requires any other external packages, please make sure: - the packages are listed in [`requirements-dev.txt`](requirements-dev.txt) - the new test `test_[module_name].py` is added to the `exclude_cases` in [`./tests/min_tests.py`](./tests/min_tests.py) so that @@ -228,7 +228,7 @@ Notably, for example, ``import monai.transforms.Spacing`` is the equivalent of ``monai.transforms.spatial.array.Spacing`` if ``class Spacing`` defined in file `monai/transforms/spatial/array.py` is decorated with ``@export("monai.transforms")``. -For string definition, [f-string](https://www.python.org/dev/peps/pep-0498/) is recommended to use over `%-print` and `format-print` from python 3.6. So please try to use `f-string` if you need to define any string object. +For string definition, [f-string](https://www.python.org/dev/peps/pep-0498/) is recommended to use over `%-print` and `format-print`. So please try to use `f-string` if you need to define any string object. #### Backwards compatibility MONAI is currently under active development, and with major version zero (following the [Semantic Versioning](https://semver.org/)). @@ -289,9 +289,9 @@ When major features are ready for a milestone, to prepare for a new release: repository's artifacts (e.g. the file at https://github.com/Project-MONAI/MONAI/actions/runs/66570977). - Check the release test at [TestPyPI](https://test.pypi.org/project/monai/), download the artifacts when the CI finishes. - Optionally run [the cron testing jobs](https://github.com/Project-MONAI/MONAI/blob/dev/.github/workflows/cron.yml) on `releasing/[version number]`. +- Rebase `releasing/[version number]` to `main`, make sure all the test pipelines succeed. - Once the release candidate is verified, tag and push a milestone, for example, `git push origin 0.1.0`. The tag must be with the latest commit of `releasing/[version number]`. -- Rebase `releasing/[version number]` to `main`, make sure all the test pipelines succeed. - Upload the packages to [PyPI](https://pypi.org/project/monai/). This could be done manually by ``twine upload dist/*``, given the artifacts are unzipped to the folder ``dist/``. - Merge `releasing/[version number]` to `dev`, this step must make sure that the tagging commit unchanged on `dev`. diff --git a/Dockerfile b/Dockerfile index 77fe1f828f..43dc103e0f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:21.08-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:21.12-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" diff --git a/README.md b/README.md index e9facef64d..bb217b7247 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ [![codecov](https://codecov.io/gh/Project-MONAI/MONAI/branch/dev/graph/badge.svg)](https://codecov.io/gh/Project-MONAI/MONAI) [![PyPI version](https://badge.fury.io/py/monai.svg)](https://badge.fury.io/py/monai) -MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/). +MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/). Its ambitions are: - developing a community of academic, industrial and clinical researchers collaborating on a common foundation; - creating state-of-the-art, end-to-end training workflows for healthcare imaging; @@ -19,7 +19,7 @@ Its ambitions are: ## Features > _The codebase is currently under active development._ -> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New in 0.6](https://docs.monai.io/en/latest/whatsnew_0_6.html) of the current milestone release._ +> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the current milestone release._ - flexible pre-processing for multi-dimensional medical imaging data; - compositional & portable APIs for ease of integration in existing workflows; @@ -47,7 +47,7 @@ Examples and notebook tutorials are located at [Project-MONAI/tutorials](https:/ Technical documentation is available at [docs.monai.io](https://docs.monai.io). ## Contributing -For guidance on making a contribution to MONAI, see the [contributing guidelines](CONTRIBUTING.md). +For guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md). ## Community Join the conversation on Twitter [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9). diff --git a/docs/images/blend.png b/docs/images/blend.png new file mode 100644 index 0000000000..fdf8a21385 Binary files /dev/null and b/docs/images/blend.png differ diff --git a/docs/images/brats_distributed.png b/docs/images/brats_distributed.png new file mode 100644 index 0000000000..90877336ee Binary files /dev/null and b/docs/images/brats_distributed.png differ diff --git a/docs/images/dints-overview.png b/docs/images/dints-overview.png new file mode 100644 index 0000000000..e5f592a8e5 Binary files /dev/null and b/docs/images/dints-overview.png differ diff --git a/docs/images/distributed_training.png b/docs/images/distributed_training.png deleted file mode 100644 index 378855aab3..0000000000 Binary files a/docs/images/distributed_training.png and /dev/null differ diff --git a/docs/images/fast_training.png b/docs/images/fast_training.png index d0584b9dac..34e47bcb21 100644 Binary files a/docs/images/fast_training.png and b/docs/images/fast_training.png differ diff --git a/docs/images/matshow3d.png b/docs/images/matshow3d.png new file mode 100644 index 0000000000..f71e69a99f Binary files /dev/null and b/docs/images/matshow3d.png differ diff --git a/docs/images/mil-patches.jpg b/docs/images/mil-patches.jpg new file mode 100644 index 0000000000..fd904943be Binary files /dev/null and b/docs/images/mil-patches.jpg differ diff --git a/docs/images/nsight_comparison.png b/docs/images/nsight_comparison.png new file mode 100644 index 0000000000..9b91826513 Binary files /dev/null and b/docs/images/nsight_comparison.png differ diff --git a/docs/images/rand_gaussian_noise.png b/docs/images/rand_gaussian_noise.png new file mode 100644 index 0000000000..a824ea8cc6 Binary files /dev/null and b/docs/images/rand_gaussian_noise.png differ diff --git a/docs/images/ssl_overview.png b/docs/images/ssl_overview.png new file mode 100644 index 0000000000..68fa1af576 Binary files /dev/null and b/docs/images/ssl_overview.png differ diff --git a/docs/images/threaddataloader.png b/docs/images/threaddataloader.png new file mode 100644 index 0000000000..565df8d0d4 Binary files /dev/null and b/docs/images/threaddataloader.png differ diff --git a/docs/requirements.txt b/docs/requirements.txt index 00dd4d2c1e..b56fd970d2 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ -f https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl -torch>=1.5 -pytorch-ignite==0.4.5 +torch>=1.6 +pytorch-ignite==0.4.7 numpy>=1.17 itk>=5.2 nibabel @@ -20,3 +20,8 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops +transformers +mlflow +tensorboardX +imagecodecs; platform_system == "Linux" +tifffile; platform_system == "Linux" diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 11d60767ec..f4f7aff2d2 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -114,7 +114,11 @@ Clara MMARs .. automodule:: monai.apps.pathology.transforms.spatial.array .. autoclass:: SplitOnGrid :members: +.. autoclass:: TileOnGrid + :members: .. automodule:: monai.apps.pathology.transforms.spatial.dictionary .. autoclass:: SplitOnGridd :members: +.. autoclass:: TileOnGridd + :members: diff --git a/docs/source/conf.py b/docs/source/conf.py index 324be8a0fd..47a45a78e8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,7 +22,7 @@ # -- Project information ----------------------------------------------------- project = "MONAI" -copyright = "2020 - 2021 MONAI Consortium" +copyright = "MONAI Consortium" author = "MONAI Contributors" # The full version, including alpha/beta/rc tags @@ -107,16 +107,8 @@ def generate_apidocs(*args): html_theme_options = { "external_links": [{"url": "https://github.com/Project-MONAI/tutorials", "name": "Tutorials"}], "icon_links": [ - { - "name": "GitHub", - "url": "https://github.com/project-monai/monai", - "icon": "fab fa-github-square", - }, - { - "name": "Twitter", - "url": "https://twitter.com/projectmonai", - "icon": "fab fa-twitter-square", - }, + {"name": "GitHub", "url": "https://github.com/project-monai/monai", "icon": "fab fa-github-square"}, + {"name": "Twitter", "url": "https://twitter.com/projectmonai", "icon": "fab fa-twitter-square"}, ], "collapse_navigation": True, "navigation_depth": 3, diff --git a/docs/source/data.rst b/docs/source/data.rst index 6e7f5e2773..7f95354e8c 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -21,6 +21,18 @@ Generic Interfaces :members: :special-members: __next__ +`DatasetFunc` +~~~~~~~~~~~~~ +.. autoclass:: DatasetFunc + :members: + :special-members: __next__ + +`ShuffleBuffer` +~~~~~~~~~~~~~~~ +.. autoclass:: ShuffleBuffer + :members: + :special-members: __next__ + `CSVIterableDataset` ~~~~~~~~~~~~~~~~~~~~ .. autoclass:: CSVIterableDataset @@ -194,6 +206,9 @@ DatasetSummary Decathlon Datalist ~~~~~~~~~~~~~~~~~~ .. autofunction:: monai.data.load_decathlon_datalist +.. autofunction:: monai.data.load_decathlon_properties +.. autofunction:: monai.data.check_missing_files +.. autofunction:: monai.data.create_cross_validation_datalist DataLoader @@ -205,6 +220,9 @@ ThreadBuffer ~~~~~~~~~~~~ .. autoclass:: monai.data.ThreadBuffer +ThreadDataLoader +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.data.ThreadDataLoader TestTimeAugmentation ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/engines.rst b/docs/source/engines.rst index cc0ec3c659..0cd40afb78 100644 --- a/docs/source/engines.rst +++ b/docs/source/engines.rst @@ -11,20 +11,21 @@ Multi-GPU data parallel .. automodule:: monai.engines.multi_gpu_supervised_trainer :members: - Workflows --------- -.. automodule:: monai.engines.workflow -.. currentmodule:: monai.engines.workflow +.. currentmodule:: monai.engines + +`BaseWorkflow` +~~~~~~~~~~~~~~ +.. autoclass:: BaseWorkflow + :members: `Workflow` ~~~~~~~~~~ .. autoclass:: Workflow :members: -.. currentmodule:: monai.engines - `Trainer` ~~~~~~~~~ .. autoclass:: Trainer @@ -54,3 +55,8 @@ Workflows ~~~~~~~~~~~~~~~~~~~ .. autoclass:: EnsembleEvaluator :members: + +Utilities +--------- +.. automodule:: monai.engines.utils + :members: diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 5caccc6b4b..d32b6d88e3 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -150,11 +150,6 @@ GarbageCollector handler .. autoclass:: GarbageCollector :members: -Transform inverter ------------------- -.. autoclass:: TransformInverter - :members: - Post processing --------------- .. autoclass:: PostProcessing @@ -165,6 +160,11 @@ Decollate batch .. autoclass:: DecollateBatch :members: +MLFlow handler +-------------- +.. autoclass:: MLFlowHandler + :members: + NVTX Handlers ------------- .. automodule:: monai.handlers.nvtx_handlers diff --git a/docs/source/highlights.md b/docs/source/highlights.md index 141c0846d1..0e2c9500a4 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -16,7 +16,7 @@ The overall architecture and modules are shown in the following figure: The rest of this page provides more details for each module. * [Data I/O, processing and augmentation](#medical-image-data-i-o-processing-and-augmentation) -* [Datasets](#datasets) +* [Datasets and DataLoader](#datasets-and-dataloader) * [Loss functions](#losses) * [Optimizers](#optimizers) * [Network architectures](#network-architectures) @@ -25,7 +25,7 @@ The rest of this page provides more details for each module. * [Result writing](#result-writing) * [Workflows](#workflows) * [Research](#research) -* [GPU acceleration](#gpu-acceleration) +* [Performance optimization and GPU acceleration](#performance-optimization-and-gpu-acceleration) * [Applications](#applications) ## Medical image data I/O, processing and augmentation @@ -56,8 +56,15 @@ transformations. These currently include, for example: [2D transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/transforms_demo_2d.ipynb) shows the detailed usage of several MONAI medical image specific transforms. ![2d transform examples](../images/medical_transforms.png) -### 3. Fused spatial transforms and GPU acceleration -As medical image volumes are usually large (in multi-dimensional arrays), pre-processing performance affects the overall pipeline speed. MONAI provides affine transforms to execute fused spatial operations, supports GPU acceleration via native PyTorch for high performance. + +### 3. Transforms support both NumPy array and PyTorch Tensor (CPU or GPU accelerated) +From MONAI v0.7 we introduced PyTorch `Tensor` based computation in transforms, many transforms already support both `NumPy array` and `Tensor` as input types and computational backends. To get the supported backends of every transform, please execute: `python monai/transforms/utils.py`. + +To accelerate the transforms, a common approach is to leverage GPU parallel-computation. Users can first convert input data into GPU Tensor by `ToTensor` or `EnsureType` transform, then the following transforms can execute on GPU based on PyTorch `Tensor` APIs. +GPU transform tutorial is available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb). + +### 4. Fused spatial transforms +As medical image volumes are usually large (in multi-dimensional arrays), pre-processing performance affects the overall pipeline speed. MONAI provides affine transforms to execute fused spatial operations. For example: ```py @@ -67,20 +74,21 @@ affine = Affine( scale_params=(1.2, 1.2), translate_params=(200, 40), padding_mode='zeros', - device=torch.device('cuda:0') ) # convert the image using bilinear interpolation new_img = affine(image, spatial_size=(300, 400), mode='bilinear') ``` Experiments and test results are available at [Fused transforms test](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/transform_speed.ipynb). -Currently, all the geometric image transforms (Spacing, Zoom, Rotate, Resize, etc.) are designed based on the PyTorch native interfaces. [Geometric transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/3d_image_transforms.ipynb) indicates the usage of affine transforms with 3D medical images. +Currently, all the geometric image transforms (Spacing, Zoom, Rotate, Resize, etc.) are designed based on the PyTorch native interfaces. So all of them support GPU acceleration via `GPU Tensor` operations for high performance. + +[Geometric transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/3d_image_transforms.ipynb) indicates the usage of affine transforms with 3D medical images. ![3d transform examples](../images/affine.png) -### 4. Randomly crop out batch images based on positive/negative ratio +### 5. Randomly crop out batch images based on positive/negative ratio Medical image data volume may be too large to fit into GPU memory. A widely-used approach is to randomly draw small size data samples during training and run a “sliding window” routine for inference. MONAI currently provides general random sampling strategies including class-balanced fixed ratio sampling which may help stabilize the patch-based training process. A typical example is in [Spleen 3D segmentation tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb), which achieves the class-balanced sampling with `RandCropByPosNegLabel` transform. -### 5. Deterministic training for reproducibility +### 6. Deterministic training for reproducibility Deterministic training support is necessary and important for deep learning research, especially in the medical field. Users can easily set the random seed to all the random transforms in MONAI locally and will not affect other non-deterministic modules in the user's program. For example: @@ -99,16 +107,16 @@ Users can also enable/disable deterministic at the beginning of training program monai.utils.set_determinism(seed=0, additional_settings=None) ``` -### 6. Multiple transform chains +### 7. Multiple transform chains To apply different transforms on the same data and concatenate the results, MONAI provides `CopyItems` transform to make copies of specified items in the data dictionary and `ConcatItems` transform to combine specified items on the expected dimension, and also provides `DeleteItems` transform to delete unnecessary items to save memory. Typical usage is to scale the intensity of the same image into different ranges and concatenate the results together. ![multiple transform chains](../images/multi_transform_chains.png) -### 7. Debug transforms with DataStats +### 8. Debug transforms with DataStats When transforms are combined with the "compose" function, it's not easy to track the output of a specific transform. To help debug errors in the composed transforms, MONAI provides utility transforms such as `DataStats` to print out intermediate data properties such as `data shape`, `value range`, `data value`, `Additional information`, etc. It's a self-contained transform and can be integrated into any transform chain. -### 8. Post-processing transforms for model output +### 9. Post-processing transforms for model output MONAI also provides post-processing transforms for handling the model outputs. Currently, the transforms include: - Adding an activation layer (Sigmoid, Softmax, etc.). - Converting to discrete values (Argmax, One-Hot, Threshold value, etc), as below figure (b). @@ -119,12 +127,19 @@ MONAI also provides post-processing transforms for handling the model outputs. C After decollating the batch data of model output and applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualize data in the TensorBoard. [Postprocessing transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/postprocessing_transforms.ipynb) shows an example with several main transforms for post-processing. ![post-processing transforms](../images/postprocessing_transforms.png) -### 9. Integrate third-party transforms +### 10. Integrate third-party transforms The design of MONAI transforms emphasis code readability and usability. It works for array data or dictionary-based data. MONAI also provides `Adaptor` tools to accommodate different data format for 3rd party transforms. To convert the data shapes or types, utility transforms such as `ToTensor`, `ToNumpy`, `SqueezeDim` are also provided. So it's easy to enhance the transform chain by seamlessly integrating transforms from external packages, including: `ITK`, `BatchGenerator`, `TorchIO` and `Rising`. For more details, please check out the tutorial: [integrate 3rd party transforms into MONAI program](https://github.com/Project-MONAI/tutorials/blob/master/modules/integrate_3rd_party_transforms.ipynb). -### 10. IO factory for medical image formats +In digital pathology training, due to the immense burden of loading images, the CPU is preoccupied by loading images and cannot catch up with preparing the data. This causes the pipeline to become IO bound and results in under-utilization of GPU. To overcome this bottleneck, [cuCIM](https://github.com/rapidsai/cucim) has implemented an optimized version of several common transforms that we are using in digital pathology pipeline. These transforms are natively being run on GPU and act on CuPy arrays. MONAI provides `CuCIM` and `RandCuCIM` adapters to integrate the `cuCIM` library. For instance: +```py +RandCuCIM(name="color_jitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04) +CuCIM(name="scale_intensity_range", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0) +``` +It has shown a significant speed up in pathology training metastasis detection model. + +### 11. IO factory for medical image formats Many popular image formats exist in the medical domain, and they are quite different with rich metadata information. To easily handle different medical image formats in the same pipeline, [MONAI provides `LoadImage` transform](https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb), which can automatically choose image readers based on the supported suffixes and in the following priority order: - User-specified reader at runtime when calling this loader. - Registered readers from the latest to the first in the list. @@ -134,13 +149,13 @@ The `ImageReader` API is quite straightforward, users can easily extend it for t With these pre-defined image readers, MONAI can load images in formats: `NIfTI`, `DICOM`, `PNG`, `JPG`, `BMP`, `NPY/NPZ`, etc. -### 11. Save transform data into NIfTI or PNG files +### 12. Save transform data into NIfTI or PNG files To convert images into files or debug the transform chain, MONAI provides `SaveImage` transform. Users can inject this transform into the transform chain to save the results. -### 12. Automatically ensure `channel-first` data shape +### 13. Automatically ensure `channel-first` data shape Medical images have different shape formats. They can be `channel-last`, `channel-first` or even `no-channel`. We may, for example, want to load several `no-channel` images and stack them as `channel-first` data. To improve the user experience, MONAI provided an `EnsureChannelFirst` transform to automatically detect data shape according to the meta information and convert it to the `channel-first` format consistently. -### 13. Invert spatial transforms and test-time augmentations +### 14. Invert spatial transforms and test-time augmentations It is often desirable to invert the previously applied spatial transforms (resize, flip, rotate, zoom, crop, pad, etc.) within the deep learning workflows, for example, to resume to the original imaging space after processing the image data in a normalized data space. Many spatial transforms are enhanced with an `inverse` operation since in v0.5. The [model inference tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py) shows a basic example. If the pipeline includes random transformations, users may want to observe the effect that these transformations have on the output. The typical approach is that we pass the same input through the transforms multiple times with different random realizations. Then use the inverse transforms to move all the results to a common space, and calculate the metrics. MONAI provided `TestTimeAugmentation` for this feature, which by default will calculate the `mode`, `mean`, `standard deviation` and `volume variation coefficient`. @@ -148,20 +163,31 @@ If the pipeline includes random transformations, users may want to observe the e [Invert transforms and TTA tutorials](https://github.com/Project-MONAI/tutorials/blob/master/modules/inverse_transforms_and_test_time_augmentations.ipynb) introduce details about the API with usage examples. (1) The last column is the inverted data of model output: + ![invert transform](../images/invert_transforms.png) (2) The TTA results of `mode`, `mean` and `standard deviation`: + ![test time augmentation](../images/tta.png) -## Datasets +### 15. Visualization of transform examples +To help clearly introduce the transform functionalities, MONAI provides visualization examples in the [API document](https://docs.monai.io/en/latest/transforms.html) for almost every transform, including spatial transforms, intensity transforms, crop / pad transforms, etc. + +For example: + +![rand gaussian noise](../images/rand_gaussian_noise.png) + +## Datasets and DataLoader ### 1. Cache IO and transforms data to accelerate training Users often need to train the model with many (potentially thousands of) epochs over the data to achieve the desired model quality. A native PyTorch implementation may repeatedly load data and run the same preprocessing steps for every epoch during training, which can be time-consuming and unnecessary, especially when the medical image volumes are large. MONAI provides a multi-thread `CacheDataset` and `LMDBDataset` to accelerate these transformation steps during training by storing the intermediate outcomes before the first randomized transform in the transform chain. Enabling this feature could potentially give 10x training speedups in the [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb). + ![digital pathology](../images/cache_dataset.png) ### 2. Cache intermediate outcomes into persistent storage The `PersistentDataset` is similar to the CacheDataset, where the intermediate cache values are persisted to disk storage or LMDB for rapid retrieval between experimental runs (as is the case when tuning hyperparameters), or when the entire data set size exceeds available memory. The `PersistentDataset` could achieve similar performance when comparing to `CacheDataset` in [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb). + ![cachedataset speed](../images/datasets_speed.png) ### 3. SmartCache mechanism for big datasets @@ -212,6 +238,7 @@ To quickly get started with popular training data in the medical domain, MONAI p MONAI always welcome new contributions of public datasets, please refer to existing Datasets and leverage the download and extracting APIs, etc. [Public datasets tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/public_datasets.ipynb) indicates how to quickly set up training workflows with `MedNISTDataset` and `DecathlonDataset` and how to create a new `Dataset` for public data. The common workflow of predefined datasets: + ![pre-defined dataset](../images/dataset_progress.png) ### 7. Partition dataset for cross validation @@ -221,6 +248,13 @@ The `partition_dataset` utility in MONAI can perform different types of partitio CSV tables are often used in additional to image data to incorporate adjunct information, such as patient demographics, lab results, image acquisition parameters and other non-image data, MONAI provides `CSVDataset` to load CSV files and `CSVIterableDataset` to load large CSV files with scalable data access. In addition to the regular preprocessing transform while loading, it also supports multiple CSV files loading, joining tables, rows and columns selection and grouping. [CSVDatasets tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/csv_datasets.ipynb) shows detailed usage examples. +### 9. `ThreadDataLoader` vs. `DataLoader` +If the transforms are light-weighted, especially when we cache all the data in RAM, the multiprocessing of PyTorch `DataLoader` may cause unnecessary IPC time and cause the drop of GPU utilization after every epoch. MONAI provides `ThreadDataLoader` which executes the transforms in a separate thread: + +![threaddataloader](../images/threaddataloader.png) + +a `ThreadDataLoader` example is available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb). + ## Losses There are domain-specific loss functions in the medical imaging research which are not typically used in generic computer vision tasks. As an important module of MONAI, these loss functions are implemented in PyTorch, such as `DiceLoss`, `GeneralizedDiceLoss`, `MaskedDiceLoss`, `TverskyLoss`, `FocalLoss`, `DiceCELoss`, and `DiceFocalLoss`, etc. @@ -228,6 +262,7 @@ There are domain-specific loss functions in the medical imaging research which a MONAI provides several advanced features in optimizers to help accelerate the training or fine-tuning progress. For example, `Novograd` optimizer can be used to converge faster than the traditional optimizers. And users can easily define different learning rates for the model layers based [on the `generate_param_groups` utility API](https://github.com/Project-MONAI/tutorials/blob/master/modules/layer_wise_learning_rate.ipynb). Another important feature is `LearningRateFinder`. The learning rate range test increases the learning rate in a pre-training run between two boundaries in a linear or exponential manner. It provides valuable information on how well the network can be trained over a range of learning rates and what the optimal learning rates are. [LearningRateFinder tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/learning_rate.ipynb) indicates the API usage examples. + ![learning rate finder plot](../images/lr_finder.png) ## Network architectures @@ -249,7 +284,7 @@ add_module('conv1', conv_type(in_channels, out_channels, kernel_size=1, bias=Fal ``` ### 2. Implementation of generic 2D/3D networks -And there are several 1D/2D/3D-compatible implementations of intermediate blocks and generic networks, such as UNet, DynUNet, DenseNet, GAN, AHNet, VNet, SENet(and SEResNet, SEResNeXt), SegResNet, EfficientNet, Attention-based networks. All the networks can support PyTorch serialization pipeline based on `torch.jit.script`. +And there are several 1D/2D/3D-compatible implementations of intermediate blocks and generic networks, such as UNet, DynUNet, DenseNet, GAN, AHNet, VNet, SENet(and SEResNet, SEResNeXt), SegResNet, EfficientNet, Attention-based transformer networks, Multi-instance learning networks, DiNTS for AutoML, etc. All the networks can support PyTorch serialization pipeline based on `torch.jit.script`. ### 3. Network adapter to finetune final layers Instead of training from scratch, we often leverage the existing models, and finetune the final layers of a network for new learning tasks. MONAI provides a `NetAdapter` to easily replace the last layer of a model by a convolutional layer or a fully-connected layer. A typical usage example is to adapt [Torchvision models trained with ImageNet](https://pytorch.org/vision/stable/models.html) for other learning tasks. @@ -282,10 +317,22 @@ With a `Cumulative` base class, intermediate metric outcomes can be automaticall ### 3. Metrics report generation During evaluation, users usually save the metrics of every input image, then analyze the bad cases to improve the deep learning pipeline. To save detailed information of metrics, MONAI provided a handler `MetricsSaver`, which can save the final metric values, raw metric of every model output channel of every input image, metrics summary report of operations: `mean`, `median`, `max`, `min`, `percentile`, `std`, etc. The `MeanDice` reports of validation with prostate dataset are as below: + ![metrics report example](../images/metrics_report.png) ## Visualization -Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). +Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). To work with ignite program, MONAI also provides several ignite handlers to visualize training curve and metrics with `TensorBoard` or `MLFlow`, more details is available in [TensorBoard and MLFlow handlers example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb). + +To easily visualize a 3D image as frames of 2D images, MONAI provides the utility `matshow3d` based on `matplotlib` library. It can plot frames of image for the specified dimension, showing a spleen 3D image as example: +`matshow3d(volume=image, figsize=(100, 100), every_n=10, frame_dim=-1 show=True, cmap="gray")` + +![matshow3d example](../images/matshow3d.png) + +MONAI also provides the `blend_images` utility to blend the `image` and `label` to a RGB color image to better visualize the segmentation regions with the specified `cmap` mode and weights, etc. Showing a spleen segmentation `image` and the corresponding `label` as example: + +![blend example](../images/blend.png) + +For more details of `TensorBoard utility`, `matshow3d` and `blend_images`, please check the [visualziation tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/transform_visualization.ipynb). And to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models: @@ -364,12 +411,33 @@ G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zha [A reimplementation](https://monai.io/research/lamp-automated-model-parallelism) of the LAMP system originally proposed by: Wentao Zhu, Can Zhao, Wenqi Li, Holger Roth, Ziyue Xu, and Daguang Xu (2020) "LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation." MICCAI 2020 (Early Accept, paper link: https://arxiv.org/abs/2006.12575) + ![LAMP UNet](../images/unet-pipe.png) -## GPU acceleration +### 3. DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation +MONAI integrated the `DiNTS` module to support more flexible topologies and joint two-level search. It provides a topology guaranteed discretization algorithm and a discretization aware topology loss for the search stage to minimize the discretization gap, and a cost usage aware search method which can search 3D networks with different GPU memory requirements. For more details, please check the [DiNTS tutorial](https://monai.io/research/dints.html). + +![DiNTS](../images/dints-overview.png) + +### 4. Accounting for Dependencies in Deep Learning Based Multiple Instance Learning for Whole Slide Imaging +For [classification of digital pathology whole slide images (WSI)](https://arxiv.org/abs/2111.01556), MONAI introduces new transforms and network modules for multiple instance learning. These include self-attention transformer blocks for explicitly accounting of the dependencies between instances (image patches) during training. For more details, please check out the [multiple instance learning tutorial](https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning). ![multi-instance](../images/mil-patches.jpg) + +### 5. Self-supervised representation learning +MONAI starts to explore self-supervised representation learning in this milestone release. The Vision Transformer has been extended to learn from self-supervised reconstruction tasks with various data augmentation and a regularized contrastive loss. The weights of the pre-trained backbone could be used to enhance the performance of the novel downstream deep learning tasks. + +The [tutorial](https://github.com/Project-MONAI/tutorials/tree/master/self_supervised_pretraining) shows how to generate a good set of pre-trained weights using unlabeled data with self-supervised tasks, then use the pre-trained weights to perform fine-tuning on a fully supervised volumetric segmentation task using a transformer based `UNETR`. + +![self-supervised](../images/ssl_overview.png) + +## Performance optimization and GPU acceleration +Typically, model training is a time-consuming step during deep learning development, especially in medical imaging applications. Volumetric medical images are usually large (as multi-dimensional arrays) and the model training process can be complex. Even with powerful hardware (e.g. CPU/GPU with large RAM), it is not easy to fully leverage them to achieve high performance. MONAI provides [a fast training guide to achieve the best performance](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md). + NVIDIA GPUs have been widely applied in many areas of deep learning training and evaluation, and the CUDA parallel computation shows obvious acceleration when comparing to traditional computation methods. To fully leverage GPU features, many popular mechanisms raised, like automatic mixed precision (AMP), distributed data parallel, etc. MONAI can support these features and provides rich examples. -### 1. Auto mixed precision(AMP) +### 1. Profiling the pipelines +First of all, MONAI provides several methods based on `DLProf`, `Nsight`, `NVTX` and `NVML` for users to analyze their programs to identify the performance bottleneck. The analyses include operation-based GPU activity and overall GPU activity during model training. They will greatly help users manage computing bottlenecks and provide insights for the area to be improved for better computing efficiency. The detailed example is shown in the [performance profiling tutorial](https://github.com/Project-MONAI/tutorials/blob/master/performance_profiling/radiology/profiling_train_base_nvtx.md). + +### 2. Auto mixed precision(AMP) In 2017, NVIDIA researchers developed a methodology for mixed-precision training, which combined single-precision (FP32) with half-precision (e.g. FP16) format when training a network, and it achieved the same accuracy as FP32 training using the same hyperparameters. For the PyTorch 1.6 release, developers at NVIDIA and Facebook moved mixed precision functionality into PyTorch core as the AMP package, `torch.cuda.amp`. @@ -379,16 +447,16 @@ MONAI workflows can easily set `amp=True/False` in `SupervisedTrainer` or `Super We also executed the same test program on NVIDIA A100 GPU with the same software environment, obtained faster results: ![amp a100 results](../images/amp_training_a100.png) More details is available at [AMP training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/automatic_mixed_precision.ipynb). -We also tried to combine AMP with `CacheDataset` and `Novograd` optimizer to achieve the fast training in MONAI, able to obtain approximately 12x speedup compared with a Pytorch native implementation when the training converges at a validation mean dice of 0.93. Benchmark for reference: +We also tried to combine `AMP` with `CacheDataset`, `GPU cache`, `GPU transforms`, `ThreadDataLoader`, `DiceCE` loss function and `Novograd` optimizer to achieve the fast training in MONAI, able to obtain approximately `200x` speedup compared with a Pytorch native implementation when the training converges at a validation mean dice of `0.95`. Benchmark for reference: ![fast training results](../images/fast_training.png) More details is available at [Fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb). -### 2. Distributed data parallel -Distributed data parallel is an important feature of PyTorch to connect multiple GPU devices on single or multiple nodes to train or evaluate models. The distributed data parallel APIs of MONAI are compatible with native PyTorch distributed module, pytorch-ignite distributed module, Horovod, XLA, and the SLURM platform. MONAI provides demos for reference: train/evaluate with PyTorch DDP, train/evaluate with Horovod, train/evaluate with Ignite DDP, partition dataset and train with SmartCacheDataset, as well as a real world training example based on Decathlon challenge Task01 - Brain Tumor segmentation. The demo contains distributed caching, training, and validation. We obtained performance benchmarks for reference (based on PyTorch 1.6, CUDA 11, NVIDIA V100 GPUs): +### 3. Distributed data parallel +Distributed data parallel is an important feature of PyTorch to connect multiple GPU devices on single or multiple nodes to train or evaluate models. The distributed data parallel APIs of MONAI are compatible with native PyTorch distributed module, pytorch-ignite distributed module, Horovod, XLA, and the SLURM platform. MONAI provides demos for reference: train/evaluate with PyTorch DDP, train/evaluate with Horovod, train/evaluate with Ignite DDP, partition dataset and train with SmartCacheDataset, as well as a real world training example based on Decathlon challenge Task01 - Brain Tumor segmentation. The [tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/distributed_training/brats_training_ddp.py) contains distributed caching, training, and validation. We obtained performance benchmarks for reference (based on PyTorch 1.9.1, CUDA 11.4, NVIDIA V100 GPUs. The `optimization` means that with more GPU resources, we can split the data and cache into GPU memory and execute GPU transforms directly): -![distributed training results](../images/distributed_training.png) +![distributed training results](../images/brats_distributed.png) -### 3. C++/CUDA optimized modules +### 4. C++/CUDA optimized modules To further accelerate the domain-specific routines in the workflows, MONAI C++/CUDA implementation are introduced as extensions of the PyTorch native implementations. MONAI provides the modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions): - via `setuptools`, for modules including `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`. @@ -396,6 +464,26 @@ MONAI provides the modules using [the two ways of building C++ extensions from P The following figure shows results of MONAI's Gaussian mixture models applied to tissue and surgical tools segmentation: ![Gaussian mixture models as a postprocessing step](../images/gmm_feature_set_comparison_s.png) +### 5. Cache IO and transforms data to GPU memory +Even with `CacheDataset`, we usually need to copy the same data to GPU memory for GPU random transforms or network computation in every epoch. An efficient approach is to cache the data to GPU memory directly, then every epoch can start from GPU computation immediately. + +For example: +```py +train_transforms = [ + LoadImaged(...), + AddChanneld(...), + Spacingd(...), + Orientationd(...), + ScaleIntensityRanged(...), + EnsureTyped(..., data_type="tensor"), + ToDeviced(..., device="cuda:0"), + RandCropByPosNegLabeld(...), +] +dataset = CacheDataset(..., transform=train_trans) +``` +Here we convert to PyTorch `Tensor` with `EnsureTyped` transform and move data to GPU with `ToDeviced` transform. `CacheDataset` caches the transform results until `ToDeviced`, so it is in GPU memory. Then in every epoch, the program fetches cached data from GPU memory and only executes the random transform `RandCropByPosNegLabeld` on GPU directly. +GPU caching example is available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb). + ## Applications The research area of medical image deep learning is expanding fast. To apply the latest achievements into applications, MONAI contains many application components to build end-to-end solutions or prototypes for other similar use cases. @@ -417,3 +505,8 @@ Starting from v0.5.0, MONAI provides experimental features for building learning The following figure shows the registration of CT images acquired at different time points for a single patient using MONAI: ![3d registration](../images/3d_paired.png) + +### 4. Reproducing the state-of-the-art Kaggle competition solutions +[A reimplementation](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) of the 4th place solution of RANZCR CLiP - Catheter and Line Position Challenge in Kaggle: https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification + +The original solution is produced by Team Watercooled, and the authors are Dieter (https://www.kaggle.com/christofhenkel) and Psi (https://www.kaggle.com/philippsinger). diff --git a/docs/source/installation.md b/docs/source/installation.md index 08ab109142..2635a6c386 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -14,7 +14,7 @@ --- -MONAI's core functionality is written in Python 3 (>= 3.6) and only requires [Numpy](https://numpy.org/) and [Pytorch](https://pytorch.org/). +MONAI's core functionality is written in Python 3 (>= 3.7) and only requires [Numpy](https://numpy.org/) and [Pytorch](https://pytorch.org/). The package is currently distributed via Github as the primary source code repository, and the Python package index (PyPI). The pre-built Docker images are made available on DockerHub. @@ -174,9 +174,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, +`tifffile`, `imagecodecs`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/docs/source/losses.rst b/docs/source/losses.rst index fc7c302ea3..dfd8ce2ddb 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -63,6 +63,11 @@ Segmentation Losses .. autoclass:: TverskyLoss :members: +`ContrastiveLoss` +~~~~~~~~~~~~~~~~~ +.. autoclass:: ContrastiveLoss + :members: + Registration Losses ------------------- diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index c1ed25831d..332571d345 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -8,6 +8,8 @@ Metrics `FROC` ------ +.. autofunction:: compute_fp_tp_probs +.. autofunction:: compute_froc_curve_data .. autofunction:: compute_froc_score `Metric` @@ -47,6 +49,7 @@ Metrics `Confusion matrix` ------------------ .. autofunction:: get_confusion_matrix +.. autofunction:: compute_confusion_matrix_metric .. autoclass:: ConfusionMatrixMetric :members: @@ -54,6 +57,7 @@ Metrics `Hausdorff distance` -------------------- .. autofunction:: compute_hausdorff_distance +.. autofunction:: compute_percent_hausdorff_distance .. autoclass:: HausdorffDistanceMetric :members: @@ -84,3 +88,13 @@ Metrics ---------------------------- .. autoclass:: PSNRMetric :members: + +`Cumulative average` +-------------------- +.. autoclass:: CumulativeAverage + :members: + +Utilities +--------- +.. automodule:: monai.metrics.utils + :members: diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 54c2756535..720a3723dc 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -21,7 +21,7 @@ Blocks :members: `CRF` -~~~~~~~~~~~~~ +~~~~~ .. autoclass:: CRF :members: @@ -73,6 +73,8 @@ Blocks :members: .. autoclass:: UnetUpBlock :members: +.. autoclass:: UnetOutBlock + :members: `SegResnet Block` ~~~~~~~~~~~~~~~~~ @@ -188,6 +190,26 @@ Blocks .. autoclass:: PatchEmbeddingBlock :members: +`FactorizedIncreaseBlock` +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: FactorizedIncreaseBlock + :members: + +`FactorizedReduceBlock` +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: FactorizedReduceBlock + :members: + +`P3DActiConvNormBlock` +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: P3DActiConvNormBlock + :members: + +`ActiConvNormBlock` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ActiConvNormBlock + :members: + `Warp` ~~~~~~ .. autoclass:: Warp @@ -234,6 +256,11 @@ Layers .. automodule:: monai.networks.layers.Conv :members: +`Pad` +~~~~~ +.. automodule:: monai.networks.layers.Pad + :members: + `Pool` ~~~~~~ .. automodule:: monai.networks.layers.Pool @@ -256,6 +283,19 @@ Layers .. autoclass:: Flatten :members: +`Reshape` +~~~~~~~~~ +.. autoclass:: Reshape + :members: + +`separable_filtering` +~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: separable_filtering + +`apply_filter` +~~~~~~~~~~~~~~ +.. autofunction:: apply_filter + `GaussianFilter` ~~~~~~~~~~~~~~~~ .. autoclass:: GaussianFilter @@ -267,7 +307,7 @@ Layers :members: `PHLFilter` -~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~ .. autoclass:: PHLFilter `GaussianMixtureModel` @@ -353,6 +393,11 @@ Nets .. autoclass:: EfficientNet :members: +`BlockArgs` +~~~~~~~~~~~ +.. autoclass:: BlockArgs + :members: + `EfficientNetBN` ~~~~~~~~~~~~~~~~ .. autoclass:: EfficientNetBN @@ -373,6 +418,11 @@ Nets .. autoclass:: SegResNetVAE :members: +`ResNet` +~~~~~~~~ +.. autoclass:: ResNet + :members: + `SENet` ~~~~~~~ .. autoclass:: SENet @@ -470,11 +520,21 @@ Nets .. autoclass:: ViT :members: +`ViTAutoEnc` +~~~~~~~~~~~~ +.. autoclass:: ViTAutoEnc + :members: + `FullyConnectedNet` ~~~~~~~~~~~~~~~~~~~ .. autoclass:: FullyConnectedNet :members: +`VarFullyConnectedNet` +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: VarFullyConnectedNet + :members: + `Generator` ~~~~~~~~~~~ .. autoclass:: Generator @@ -500,6 +560,11 @@ Nets .. autoclass:: Critic :members: +`Transchex` +~~~~~~~~~~~~~~~~ +.. autoclass:: Transchex + :members: + `NetAdapter` ~~~~~~~~~~~~ .. autoclass:: NetAdapter @@ -515,6 +580,31 @@ Nets .. autoclass:: TorchVisionFullyConvModel :members: +`MILModel` +~~~~~~~~~~ +.. autoclass:: MILModel + :members: + +`DiNTS` +~~~~~~~ +.. autoclass:: DiNTS + :members: + +`TopologyConstruction for DiNTS` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TopologyConstruction + :members: + +`TopologyInstance for DiNTS` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TopologyInstance + :members: + +`TopologySearch for DiNTS` +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TopologySearch + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst index c766ac3cf9..67cbdc0951 100644 --- a/docs/source/optimizers.rst +++ b/docs/source/optimizers.rst @@ -6,6 +6,11 @@ Optimizers ========== .. currentmodule:: monai.optimizers +`LearningRateFinder` +-------------------- +.. autoclass:: LearningRateFinder + :members: + `Novograd` ---------- .. autoclass:: Novograd @@ -14,3 +19,18 @@ Optimizers `Generate parameter groups` --------------------------- .. autofunction:: generate_param_groups + +`ExponentialLR` +--------------- +.. autoclass:: ExponentialLR + :members: + +`LinearLR` +---------- +.. autoclass:: LinearLR + :members: + +`WarmupCosineSchedule` +---------------------- +.. autoclass:: WarmupCosineSchedule + :members: diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index b8f57e0dbe..49ed4c9e6c 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -43,6 +43,11 @@ Generic Interfaces .. autoclass:: InvertibleTransform :members: +`TraceableTransform` +^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: TraceableTransform + :members: + `BatchInverseTransform` ^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: BatchInverseTransform @@ -53,80 +58,121 @@ Generic Interfaces .. autoclass:: Decollated :members: +`OneOf` +^^^^^^^ +.. autoclass:: OneOf + :members: + Vanilla Transforms ------------------ Crop and Pad ^^^^^^^^^^^^ +`PadListDataCollate` +"""""""""""""""""""" +.. autoclass:: PadListDataCollate + :members: + :special-members: __call__ + +`Pad` +""""" +.. autoclass:: Pad + :members: + :special-members: __call__ + `SpatialPad` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialPad.png + :alt: example of SpatialPad .. autoclass:: SpatialPad :members: :special-members: __call__ `BorderPad` """"""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/BorderPad.png + :alt: example of BorderPad .. autoclass:: BorderPad :members: :special-members: __call__ `DivisiblePad` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/DivisiblePad.png + :alt: example of DivisiblePad .. autoclass:: DivisiblePad :members: :special-members: __call__ `SpatialCrop` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialCrop.png + :alt: example of SpatialCrop .. autoclass:: SpatialCrop :members: :special-members: __call__ `CenterSpatialCrop` """"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CenterSpatialCrop.png + :alt: example of CenterSpatialCrop .. autoclass:: CenterSpatialCrop :members: :special-members: __call__ `RandSpatialCrop` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSpatialCrop.png + :alt: example of RandSpatialCrop .. autoclass:: RandSpatialCrop :members: :special-members: __call__ `RandSpatialCropSamples` """""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSpatialCropSamples.png + :alt: example of RandSpatialCropSamples .. autoclass:: RandSpatialCropSamples :members: :special-members: __call__ `CropForeground` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CropForeground.png + :alt: example of CropForeground .. autoclass:: CropForeground :members: :special-members: __call__ `RandWeightedCrop` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandWeightedCrop.png + :alt: example of RandWeightedCrop .. autoclass:: RandWeightedCrop :members: :special-members: __call__ `RandCropByPosNegLabel` """"""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCropByPosNegLabel.png + :alt: example of RandCropByPosNegLabel .. autoclass:: RandCropByPosNegLabel :members: :special-members: __call__ `RandCropByLabelClasses` """""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCropByLabelClasses.png + :alt: example of RandCropByLabelClasses .. autoclass:: RandCropByLabelClasses :members: :special-members: __call__ `ResizeWithPadOrCrop` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ResizeWithPadOrCrop.png + :alt: example of ResizeWithPadOrCrop .. autoclass:: ResizeWithPadOrCrop :members: :special-members: __call__ @@ -139,12 +185,16 @@ Crop and Pad `RandScaleCrop` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandScaleCrop.png + :alt: example of RandScaleCrop .. autoclass:: RandScaleCrop :members: :special-members: __call__ `CenterScaleCrop` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CenterScaleCrop.png + :alt: example of CenterScaleCrop .. autoclass:: CenterScaleCrop :members: :special-members: __call__ @@ -154,126 +204,168 @@ Intensity `RandGaussianNoise` """"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianNoise.png + :alt: example of RandGaussianNoise .. autoclass:: RandGaussianNoise :members: :special-members: __call__ `ShiftIntensity` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ShiftIntensity.png + :alt: example of ShiftIntensity .. autoclass:: ShiftIntensity :members: :special-members: __call__ `RandShiftIntensity` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandShiftIntensity.png + :alt: example of RandShiftIntensity .. autoclass:: RandShiftIntensity :members: :special-members: __call__ `StdShiftIntensity` """"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/StdShiftIntensity.png + :alt: example of StdShiftIntensity .. autoclass:: StdShiftIntensity :members: :special-members: __call__ `RandStdShiftIntensity` """"""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandStdShiftIntensity.png + :alt: example of RandStdShiftIntensity .. autoclass:: RandStdShiftIntensity :members: :special-members: __call__ `RandBiasField` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandBiasField.png + :alt: example of RandBiasField .. autoclass:: RandBiasField :members: :special-members: __call__ `ScaleIntensity` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensity.png + :alt: example of ScaleIntensity .. autoclass:: ScaleIntensity :members: :special-members: __call__ `RandScaleIntensity` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandScaleIntensity.png + :alt: example of RandScaleIntensity .. autoclass:: RandScaleIntensity :members: :special-members: __call__ `NormalizeIntensity` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/NormalizeIntensity.png + :alt: example of NormalizeIntensity .. autoclass:: NormalizeIntensity :members: :special-members: __call__ `ThresholdIntensity` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ThresholdIntensity.png + :alt: example of ThresholdIntensity .. autoclass:: ThresholdIntensity :members: :special-members: __call__ `ScaleIntensityRange` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensityRange.png + :alt: example of ScaleIntensityRange .. autoclass:: ScaleIntensityRange :members: :special-members: __call__ `ScaleIntensityRangePercentiles` """""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensityRangePercentiles.png + :alt: example of ScaleIntensityRangePercentiles .. autoclass:: ScaleIntensityRangePercentiles :members: :special-members: __call__ `AdjustContrast` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/AdjustContrast.png + :alt: example of AdjustContrast .. autoclass:: AdjustContrast :members: :special-members: __call__ `RandAdjustContrast` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAdjustContrast.png + :alt: example of RandAdjustContrast .. autoclass:: RandAdjustContrast :members: :special-members: __call__ `MaskIntensity` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/MaskIntensity.png + :alt: example of MaskIntensity .. autoclass:: MaskIntensity :members: :special-members: __call__ `SavitzkyGolaySmooth` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SavitzkyGolaySmooth.png + :alt: example of SavitzkyGolaySmooth .. autoclass:: SavitzkyGolaySmooth :members: :special-members: __call__ `GaussianSmooth` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSmooth.png + :alt: example of GaussianSmooth .. autoclass:: GaussianSmooth :members: :special-members: __call__ `RandGaussianSmooth` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianSmooth.png + :alt: example of RandGaussianSmooth .. autoclass:: RandGaussianSmooth :members: :special-members: __call__ `GaussianSharpen` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSharpen.png + :alt: example of GaussianSharpen .. autoclass:: GaussianSharpen :members: :special-members: __call__ `RandGaussianSharpen` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianSharpen.png + :alt: example of RandGaussianSharpen .. autoclass:: RandGaussianSharpen :members: :special-members: __call__ `RandHistogramShift` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandHistogramShift.png + :alt: example of RandHistogramShift .. autoclass:: RandHistogramShift :members: :special-members: __call__ @@ -286,43 +378,71 @@ Intensity `GibbsNoise` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GibbsNoise.png + :alt: example of GibbsNoise .. autoclass:: GibbsNoise :members: :special-members: __call__ `RandGibbsNoise` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGibbsNoise.png + :alt: example of RandGibbsNoise .. autoclass:: RandGibbsNoise :members: :special-members: __call__ `KSpaceSpikeNoise` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/KSpaceSpikeNoise.png + :alt: example of KSpaceSpikeNoise .. autoclass:: KSpaceSpikeNoise :members: :special-members: __call__ `RandKSpaceSpikeNoise` """""""""""""""""""""" - .. autoclass:: RandKSpaceSpikeNoise - :members: - :special-members: __call__ +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandKSpaceSpikeNoise.png + :alt: example of RandKSpaceSpikeNoise +.. autoclass:: RandKSpaceSpikeNoise + :members: + :special-members: __call__ + +`RandRicianNoise` +""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRicianNoise.png + :alt: example of RandRicianNoise +.. autoclass:: RandRicianNoise + :members: + :special-members: __call__ + +`RandCoarseTransform` +""""""""""""""""""""" +.. autoclass:: RandCoarseTransform + :members: + :special-members: __call__ `RandCoarseDropout` """"""""""""""""""" - .. autoclass:: RandCoarseDropout - :members: - :special-members: __call__ +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCoarseDropout.png + :alt: example of RandCoarseDropout +.. autoclass:: RandCoarseDropout + :members: + :special-members: __call__ + +`RandCoarseShuffle` +""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCoarseShuffle.png + :alt: example of RandCoarseShuffle +.. autoclass:: RandCoarseShuffle + :members: + :special-members: __call__ `HistogramNormalize` """""""""""""""""""" - .. autoclass:: HistogramNormalize - :members: - :special-members: __call__ - -`LocalPatchShuffling` -""""""""""""""""""""" -.. autoclass:: LocalPatchShuffling +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/HistogramNormalize.png + :alt: example of HistogramNormalize +.. autoclass:: HistogramNormalize :members: :special-members: __call__ @@ -381,18 +501,24 @@ Post-processing `AsDiscrete` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/AsDiscrete.png + :alt: example of AsDiscrete .. autoclass:: AsDiscrete :members: :special-members: __call__ `KeepLargestConnectedComponent` """"""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/KeepLargestConnectedComponent.png + :alt: example of KeepLargestConnectedComponent .. autoclass:: KeepLargestConnectedComponent :members: :special-members: __call__ `LabelFilter` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/LabelFilter.png + :alt: example of LabelFilter .. autoclass:: LabelFilter :members: :special-members: __call__ @@ -405,6 +531,8 @@ Post-processing `LabelToContour` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/LabelToContour.png + :alt: example of LabelToContour .. autoclass:: LabelToContour :members: :special-members: __call__ @@ -431,42 +559,56 @@ Spatial `Spacing` """"""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Spacing.png + :alt: example of Spacing .. autoclass:: Spacing :members: :special-members: __call__ `Orientation` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Orientation.png + :alt: example of Orientation .. autoclass:: Orientation :members: :special-members: __call__ `RandRotate` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate.png + :alt: example of RandRotate .. autoclass:: RandRotate :members: :special-members: __call__ `RandFlip` """""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandFlip.png + :alt: example of RandFlip .. autoclass:: RandFlip :members: :special-members: __call__ `RandAxisFlip` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAxisFlip.png + :alt: example of RandAxisFlip .. autoclass:: RandAxisFlip :members: :special-members: __call__ `RandZoom` """""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandZoom.png + :alt: example of RandZoom .. autoclass:: RandZoom :members: :special-members: __call__ `Affine` """""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Affine.png + :alt: example of Affine .. autoclass:: Affine :members: :special-members: __call__ @@ -479,6 +621,8 @@ Spatial `RandAffine` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAffine.png + :alt: example of RandAffine .. autoclass:: RandAffine :members: :special-members: __call__ @@ -501,57 +645,110 @@ Spatial :members: :special-members: __call__ +`GridDistortion` +"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GridDistortion.png + :alt: example of GridDistortion +.. autoclass:: GridDistortion + :members: + :special-members: __call__ + +`RandGridDistortion` +"""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGridDistortion.png + :alt: example of RandGridDistortion +.. autoclass:: RandGridDistortion + :members: + :special-members: __call__ + `Rand2DElastic` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rand2DElastic.png + :alt: example of Rand2DElastic .. autoclass:: Rand2DElastic :members: :special-members: __call__ `Rand3DElastic` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rand3DElastic.png + :alt: example of Rand3DElastic .. autoclass:: Rand3DElastic :members: :special-members: __call__ `Rotate90` """""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rotate90.png + :alt: example of Rotate90 .. autoclass:: Rotate90 :members: :special-members: __call__ `RandRotate90` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate90.png + :alt: example of RandRotate90 .. autoclass:: RandRotate90 :members: :special-members: __call__ `Flip` """""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Flip.png + :alt: example of Flip .. autoclass:: Flip :members: :special-members: __call__ `Resize` """""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Resize.png + :alt: example of Resize .. autoclass:: Resize :members: :special-members: __call__ `Rotate` """""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rotate.png + :alt: example of Rotate .. autoclass:: Rotate :members: :special-members: __call__ `Zoom` """""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Zoom.png + :alt: example of Zoom .. autoclass:: Zoom :members: :special-members: __call__ -`AddCoordinateChannels` -""""""""""""""""""""""" -.. autoclass:: AddCoordinateChannels +Smooth Field +^^^^^^^^^^^^ + +`RandSmoothFieldAdjustContrast` +""""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothFieldAdjustContrast.png + :alt: example of RandSmoothFieldAdjustContrast +.. autoclass:: RandSmoothFieldAdjustContrast + :members: + :special-members: __call__ + +`RandSmoothFieldAdjustIntensity` +"""""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothFieldAdjustIntensity.png + :alt: example of RandSmoothFieldAdjustIntensity +.. autoclass:: RandSmoothFieldAdjustIntensity + :members: + :special-members: __call__ + +`RandSmoothDeform` +"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothDeform.png + :alt: example of RandSmoothDeform +.. autoclass:: RandSmoothDeform :members: :special-members: __call__ @@ -661,6 +858,12 @@ Utility :members: :special-members: __call__ +`RemoveRepeatedChannel` +""""""""""""""""""""""" +.. autoclass:: RemoveRepeatedChannel + :members: + :special-members: __call__ + `LabelToMask` """"""""""""" .. autoclass:: LabelToMask @@ -711,16 +914,34 @@ Utility `IntensityStats` """""""""""""""" - .. autoclass:: IntensityStats - :members: - :special-members: __call__ +.. autoclass:: IntensityStats + :members: + :special-members: __call__ `ToDevice` """""""""" - .. autoclass:: ToDevice +.. autoclass:: ToDevice :members: :special-members: __call__ +`CuCIM` +""""""" +.. autoclass:: CuCIM + :members: + :special-members: __call__ + +`RandCuCIM` +""""""""""" +.. autoclass:: RandCuCIM + :members: + :special-members: __call__ + +`AddCoordinateChannels` +""""""""""""""""""""""" +.. autoclass:: AddCoordinateChannels + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -730,72 +951,96 @@ Crop and Pad (Dict) `SpatialPadd` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialPadd.png + :alt: example of SpatialPadd .. autoclass:: SpatialPadd :members: :special-members: __call__ `BorderPadd` """""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/BorderPadd.png + :alt: example of BorderPadd .. autoclass:: BorderPadd :members: :special-members: __call__ `DivisiblePadd` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/DivisiblePadd.png + :alt: example of DivisiblePadd .. autoclass:: DivisiblePadd :members: :special-members: __call__ `SpatialCropd` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialCropd.png + :alt: example of SpatialCropd .. autoclass:: SpatialCropd :members: :special-members: __call__ `CenterSpatialCropd` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CenterSpatialCropd.png + :alt: example of CenterSpatialCropd .. autoclass:: CenterSpatialCropd :members: :special-members: __call__ `RandSpatialCropd` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSpatialCropd.png + :alt: example of RandSpatialCropd .. autoclass:: RandSpatialCropd :members: :special-members: __call__ `RandSpatialCropSamplesd` """"""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSpatialCropSamplesd.png + :alt: example of RandSpatialCropSamplesd .. autoclass:: RandSpatialCropSamplesd :members: :special-members: __call__ `CropForegroundd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CropForegroundd.png + :alt: example of CropForegroundd .. autoclass:: CropForegroundd :members: :special-members: __call__ `RandWeightedCropd` """"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandWeightedCropd.png + :alt: example of RandWeightedCropd .. autoclass:: RandWeightedCropd :members: :special-members: __call__ `RandCropByPosNegLabeld` """""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCropByPosNegLabeld.png + :alt: example of RandCropByPosNegLabeld .. autoclass:: RandCropByPosNegLabeld :members: :special-members: __call__ `RandCropByLabelClassesd` """"""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCropByLabelClassesd.png + :alt: example of RandCropByLabelClassesd .. autoclass:: RandCropByLabelClassesd :members: :special-members: __call__ `ResizeWithPadOrCropd` """""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ResizeWithPadOrCropd.png + :alt: example of ResizeWithPadOrCropd .. autoclass:: ResizeWithPadOrCropd :members: :special-members: __call__ @@ -808,12 +1053,16 @@ Crop and Pad (Dict) `RandScaleCropd` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandScaleCropd.png + :alt: example of RandScaleCropd .. autoclass:: RandScaleCropd :members: :special-members: __call__ `CenterScaleCropd` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/CenterScaleCropd.png + :alt: example of CenterScaleCropd .. autoclass:: CenterScaleCropd :members: :special-members: __call__ @@ -823,159 +1072,235 @@ Intensity (Dict) `RandGaussianNoised` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianNoised.png + :alt: example of RandGaussianNoised .. autoclass:: RandGaussianNoised :members: :special-members: __call__ `ShiftIntensityd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ShiftIntensityd.png + :alt: example of ShiftIntensityd .. autoclass:: ShiftIntensityd :members: :special-members: __call__ `RandShiftIntensityd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandShiftIntensityd.png + :alt: example of RandShiftIntensityd .. autoclass:: RandShiftIntensityd :members: :special-members: __call__ `StdShiftIntensityd` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/StdShiftIntensityd.png + :alt: example of StdShiftIntensityd .. autoclass:: StdShiftIntensityd :members: :special-members: __call__ `RandStdShiftIntensityd` """""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandStdShiftIntensityd.png + :alt: example of RandStdShiftIntensityd .. autoclass:: RandStdShiftIntensityd :members: :special-members: __call__ `RandBiasFieldd` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandBiasFieldd.png + :alt: example of RandBiasFieldd .. autoclass:: RandBiasFieldd :members: :special-members: __call__ `ScaleIntensityd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensityd.png + :alt: example of ScaleIntensityd .. autoclass:: ScaleIntensityd :members: :special-members: __call__ `RandScaleIntensityd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandScaleIntensityd.png + :alt: example of RandScaleIntensityd .. autoclass:: RandScaleIntensityd :members: :special-members: __call__ `NormalizeIntensityd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/NormalizeIntensityd.png + :alt: example of NormalizeIntensityd .. autoclass:: NormalizeIntensityd :members: :special-members: __call__ `ThresholdIntensityd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ThresholdIntensityd.png + :alt: example of ThresholdIntensityd .. autoclass:: ThresholdIntensityd :members: :special-members: __call__ `ScaleIntensityRanged` """""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensityRanged.png + :alt: example of ScaleIntensityRanged .. autoclass:: ScaleIntensityRanged :members: :special-members: __call__ `GibbsNoised` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GibbsNoised.png + :alt: example of GibbsNoised .. autoclass:: GibbsNoised :members: :special-members: __call__ `RandGibbsNoised` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGibbsNoised.png + :alt: example of RandGibbsNoised .. autoclass:: RandGibbsNoised :members: :special-members: __call__ `KSpaceSpikeNoised` """""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/KSpaceSpikeNoised.png + :alt: example of KSpaceSpikeNoised .. autoclass:: KSpaceSpikeNoised :members: :special-members: __call__ `RandKSpaceSpikeNoised` """"""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandKSpaceSpikeNoised.png + :alt: example of RandKSpaceSpikeNoised .. autoclass:: RandKSpaceSpikeNoised :members: :special-members: __call__ +`RandRicianNoised` +"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRicianNoised.png + :alt: example of RandRicianNoised +.. autoclass:: RandRicianNoised + :members: + :special-members: __call__ + `ScaleIntensityRangePercentilesd` """"""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ScaleIntensityRangePercentilesd.png + :alt: example of ScaleIntensityRangePercentilesd .. autoclass:: ScaleIntensityRangePercentilesd :members: :special-members: __call__ `AdjustContrastd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/AdjustContrastd.png + :alt: example of AdjustContrastd .. autoclass:: AdjustContrastd :members: :special-members: __call__ `RandAdjustContrastd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAdjustContrastd.png + :alt: example of RandAdjustContrastd .. autoclass:: RandAdjustContrastd :members: :special-members: __call__ `MaskIntensityd` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/MaskIntensityd.png + :alt: example of MaskIntensityd .. autoclass:: MaskIntensityd :members: :special-members: __call__ +`SavitzkyGolaySmoothd` +"""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SavitzkyGolaySmoothd.png + :alt: example of SavitzkyGolaySmoothd +.. autoclass:: SavitzkyGolaySmoothd + :members: + :special-members: __call__ + `GaussianSmoothd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSmoothd.png + :alt: example of GaussianSmoothd .. autoclass:: GaussianSmoothd :members: :special-members: __call__ `RandGaussianSmoothd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianSmoothd.png + :alt: example of RandGaussianSmoothd .. autoclass:: RandGaussianSmoothd :members: :special-members: __call__ `GaussianSharpend` """""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSharpend.png + :alt: example of GaussianSharpend .. autoclass:: GaussianSharpend :members: :special-members: __call__ `RandGaussianSharpend` """""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGaussianSharpend.png + :alt: example of RandGaussianSharpend .. autoclass:: RandGaussianSharpend :members: :special-members: __call__ `RandHistogramShiftd` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandHistogramShiftd.png + :alt: example of RandHistogramShiftd .. autoclass:: RandHistogramShiftd :members: :special-members: __call__ `RandCoarseDropoutd` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCoarseDropoutd.png + :alt: example of RandCoarseDropoutd .. autoclass:: RandCoarseDropoutd :members: :special-members: __call__ +`RandCoarseShuffled` +"""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandCoarseShuffled.png + :alt: example of RandCoarseShuffled +.. autoclass:: RandCoarseShuffled + :members: + :special-members: __call__ + `HistogramNormalized` """"""""""""""""""""" - .. autoclass:: HistogramNormalized - :members: - :special-members: __call__ +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/HistogramNormalized.png + :alt: example of HistogramNormalized +.. autoclass:: HistogramNormalized + :members: + :special-members: __call__ IO (Dict) @@ -1004,18 +1329,24 @@ Post-processing (Dict) `AsDiscreted` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/AsDiscreted.png + :alt: example of AsDiscreted .. autoclass:: AsDiscreted :members: :special-members: __call__ `KeepLargestConnectedComponentd` """""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/KeepLargestConnectedComponentd.png + :alt: example of KeepLargestConnectedComponentd .. autoclass:: KeepLargestConnectedComponentd :members: :special-members: __call__ `LabelFilterd` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/LabelFilterd.png + :alt: example of LabelFilterd .. autoclass:: LabelFilterd :members: :special-members: __call__ @@ -1028,6 +1359,8 @@ Post-processing (Dict) `LabelToContourd` """"""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/LabelToContourd.png + :alt: example of LabelToContourd .. autoclass:: LabelToContourd :members: :special-members: __call__ @@ -1067,103 +1400,172 @@ Spatial (Dict) `Spacingd` """""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Spacingd.png + :alt: example of Spacingd .. autoclass:: Spacingd :members: :special-members: __call__ `Orientationd` """""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Orientationd.png + :alt: example of Orientationd .. autoclass:: Orientationd :members: :special-members: __call__ `Flipd` """"""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Flipd.png + :alt: example of Flipd .. autoclass:: Flipd :members: :special-members: __call__ `RandFlipd` """"""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandFlipd.png + :alt: example of RandFlipd .. autoclass:: RandFlipd :members: :special-members: __call__ `RandAxisFlipd` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAxisFlipd.png + :alt: example of RandAxisFlipd .. autoclass:: RandAxisFlipd :members: :special-members: __call__ `Rotated` """"""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rotated.png + :alt: example of Rotated .. autoclass:: Rotated :members: :special-members: __call__ `RandRotated` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotated.png + :alt: example of RandRotated .. autoclass:: RandRotated :members: :special-members: __call__ `Zoomd` """"""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Zoomd.png + :alt: example of Zoomd .. autoclass:: Zoomd :members: :special-members: __call__ `RandZoomd` """"""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandZoomd.png + :alt: example of RandZoomd .. autoclass:: RandZoomd :members: :special-members: __call__ `RandRotate90d` """"""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate90d.png + :alt: example of RandRotate90d .. autoclass:: RandRotate90d :members: :special-members: __call__ `Rotate90d` """"""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rotate90d.png + :alt: example of Rotate90d .. autoclass:: Rotate90d :members: :special-members: __call__ `Resized` """"""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Resized.png + :alt: example of Resized .. autoclass:: Resized :members: :special-members: __call__ `Affined` """"""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Affined.png + :alt: example of Affined .. autoclass:: Affined :members: :special-members: __call__ `RandAffined` """"""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandAffined.png + :alt: example of RandAffined .. autoclass:: RandAffined :members: :special-members: __call__ `Rand2DElasticd` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rand2DElasticd.png + :alt: example of Rand2DElasticd .. autoclass:: Rand2DElasticd :members: :special-members: __call__ `Rand3DElasticd` """""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rand3DElasticd.png + :alt: example of Rand3DElasticd .. autoclass:: Rand3DElasticd :members: :special-members: __call__ -`AddCoordinateChannelsd` -"""""""""""""""""""""""" -.. autoclass:: AddCoordinateChannelsd +`GridDistortiond` +""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GridDistortiond.png + :alt: example of GridDistortiond +.. autoclass:: GridDistortiond + :members: + :special-members: __call__ + +`RandGridDistortiond` +""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGridDistortiond.png + :alt: example of RandGridDistortiond +.. autoclass:: RandGridDistortiond + :members: + :special-members: __call__ + +Smooth Field (Dict) +^^^^^^^^^^^^^^^^^^^ + +`RandSmoothFieldAdjustContrastd` +"""""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothFieldAdjustContrastd.png + :alt: example of RandSmoothFieldAdjustContrastd +.. autoclass:: RandSmoothFieldAdjustContrastd + :members: + :special-members: __call__ + +`RandSmoothFieldAdjustIntensityd` +""""""""""""""""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothFieldAdjustIntensityd.png + :alt: example of RandSmoothFieldAdjustIntensityd +.. autoclass:: RandSmoothFieldAdjustIntensityd + :members: + :special-members: __call__ + +`RandSmoothDeformd` +""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothDeformd.png + :alt: example of RandSmoothDeformd +.. autoclass:: RandSmoothDeformd :members: :special-members: __call__ @@ -1230,6 +1632,12 @@ Utility (Dict) :members: :special-members: __call__ +`ToPIL` +""""""" +.. autoclass:: ToPIL + :members: + :special-members: __call__ + `ToCupyd` """"""""" .. autoclass:: ToCupyd @@ -1352,15 +1760,37 @@ Utility (Dict) `ToDeviced` """"""""""" - .. autoclass:: ToDeviced - :members: - :special-members: __call__ +.. autoclass:: ToDeviced + :members: + :special-members: __call__ +`CuCIMd` +"""""""" +.. autoclass:: CuCIMd + :members: + :special-members: __call__ + +`RandCuCIMd` +"""""""""""" +.. autoclass:: RandCuCIMd + :members: + :special-members: __call__ + +`AddCoordinateChannelsd` +"""""""""""""""""""""""" +.. autoclass:: AddCoordinateChannelsd + :members: + :special-members: __call__ Transform Adaptors ------------------ .. automodule:: monai.transforms.adaptors +`FunctionSignature` +^^^^^^^^^^^^^^^^^^^ +.. autoclass:: FunctionSignature + :members: + `adaptor` ^^^^^^^^^ .. autofunction:: monai.transforms.adaptors.adaptor @@ -1377,3 +1807,6 @@ Utilities --------- .. automodule:: monai.transforms.utils :members: + +.. automodule:: monai.transforms.utils_pytorch_numpy_unification + :members: diff --git a/docs/source/utils.rst b/docs/source/utils.rst index a9aea7932b..881519936b 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -43,7 +43,7 @@ Profiling Deprecated ---------- -.. automodule:: monai.utils.deprecated +.. automodule:: monai.utils.deprecate_utils :members: @@ -51,3 +51,28 @@ Type conversion --------------- .. automodule:: monai.utils.type_conversion :members: + +Decorators +---------- +.. automodule:: monai.utils.decorators + :members: + +Distributed Data Parallel +------------------------- +.. automodule:: monai.utils.dist + :members: + +Enums +----- +.. automodule:: monai.utils.enums + :members: + +Jupyter Utilities +----------------- +.. automodule:: monai.utils.jupyter_utils + :members: + +State Cacher +------------ +.. automodule:: monai.utils.state_cacher + :members: diff --git a/docs/source/visualize.rst b/docs/source/visualize.rst index 850fd51770..3779feec88 100644 --- a/docs/source/visualize.rst +++ b/docs/source/visualize.rst @@ -24,3 +24,8 @@ Occlusion sensitivity .. automodule:: monai.visualize.occlusion_sensitivity :members: + +Utilities +--------- +.. automodule:: monai.visualize.utils + :members: diff --git a/docs/source/whatsnew.rst b/docs/source/whatsnew.rst index daed871e14..1f27d651db 100644 --- a/docs/source/whatsnew.rst +++ b/docs/source/whatsnew.rst @@ -6,5 +6,7 @@ What's New .. toctree:: :maxdepth: 1 + whatsnew_0_8.md + whatsnew_0_7.md whatsnew_0_6.md whatsnew_0_5.md diff --git a/docs/source/whatsnew_0_6.md b/docs/source/whatsnew_0_6.md index bdc419df37..8df0503142 100644 --- a/docs/source/whatsnew_0_6.md +++ b/docs/source/whatsnew_0_6.md @@ -1,4 +1,4 @@ -# What's new in 0.6 🎉🎉 +# What's new in 0.6 - Decollating mini-batches as an essential post-processing step - Pythonic APIs to load the pretrained models from Clara Train MMARs diff --git a/docs/source/whatsnew_0_7.md b/docs/source/whatsnew_0_7.md new file mode 100644 index 0000000000..6df64948b0 --- /dev/null +++ b/docs/source/whatsnew_0_7.md @@ -0,0 +1,63 @@ +# What's new in 0.7 + +- Performance enhancements with profiling and tuning guides +- Major usability improvements in `monai.transforms` +- Reimplementing state-of-the-art Kaggle solutions +- Vision-language multimodal transformer architectures + +## Performance enhancements with profiling and tuning guides + +Model training is often a time-consuming step during deep learning development, +especially for medical imaging applications. Even with powerful hardware (e.g. +CPU/GPU with large RAM), the workflows often require careful profiling and +tuning to achieve high performance. MONAI has been focusing on performance +enhancements, and in this version, a fast model training guide is provided +to help build highly performant workflows, with a comprehensive overview of +the profiling tools and practical strategies: +https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md. + +The following figure shows the use of [Nvidia Nsight™ Systems](https://developer.nvidia.com/nsight-systems) for system-wide +performance analysis during a performance enhancement study. +![nsight_vis](../images/nsight_comparison.png) + +With the performance profiling and enhancements, several typical use cases were studied to +improve the training efficiency. The following figure shows that fast +training using MONAI can be `200` times faster than a regular baseline ([learn +more](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb)), and it's `20` times faster than the MONAI v0.6 fast training solution. +![fast_training](../images/fast_training.png) + +## Major usability improvements in `monai.transforms` for NumPy/PyTorch inputs and backends + + MONAI starts to roll out major usability enhancements for the + `monai.transforms` module. Many transforms are now supporting both NumPy and + PyTorch, as input types and computational backends. To get the supported backends of every transform, please execute: `python monai/transforms/utils.py`. + +One benefit of these enhancements is that the users can now better leverage the +GPUs for preprocessing. By transferring the input data onto GPU using +`ToTensor` or `EnsureType`, and applying the GPU-based transforms to the data, +[the tutorial of spleen +segmentation](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb) +shows the great potential of using the flexible modules for fast and efficient +training. + +## Reimplementing state-of-the-art Kaggle solutions + +With this release, we actively evaluate and enhance the quality and flexibility +of the MONAI core modules, using the public Kaggle challenge as a testbed. [A +reimplementation](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) +of a state-of-the-art solution at [Kaggle RANZCR CLiP - Catheter and Line +Position +Challenge](https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification) +is made available in this version. + +## Vision-language multimodal transformers + +In this release, MONAI adds support for training multimodal (vision + language) +transformers that can handle both image and textual data. MONAI introduces the +`TransCheX` model which consists of vision, language, and mixed-modality +transformer layers for processing chest X-ray and their corresponding +radiological reports within a unified framework. In addition to `TransCheX`, +users have the flexibility to alter the architecture by varying the number of +vision, language and mixed-modality layers and customizing the classification +head. In addition, the model can be initialized from pre-trained BERT language +models for fine-tuning. diff --git a/docs/source/whatsnew_0_8.md b/docs/source/whatsnew_0_8.md new file mode 100644 index 0000000000..bbaf01f5de --- /dev/null +++ b/docs/source/whatsnew_0_8.md @@ -0,0 +1,56 @@ +# What's new in 0.8 🎉🎉 + +- Differentiable neural network topology search +- Multiple instance learning for digital pathology WSI analysis +- Self-supervised representation learning +- Major usability improvements in `monai.transforms` + +## Differentiable neural network topology search +MONAI integrates `DiNTS`: [Differentiable Neural Network Topology Search for 3D +Medical Image Segmentation](https://arxiv.org/abs/2103.15954). The neural +architecture search module supports flexible multi-path topology search with +high search efficiency and budgeted memory usage. + +It provides a topology guaranteed discretization algorithm and a +discretization-aware topology loss for the search stage to minimize the +discretization gap. The module is memory usage aware and is able to search 3D +networks with different GPU memory requirements. For more details, please check out the +[DiNTS tutorial](https://monai.io/research/dints.html). + +![DiNTS](../images/dints-overview.png) + +## Multiple instance learning for digital pathology WSI analysis +For [classification of digital pathology whole slide images +(WSI)](https://arxiv.org/abs/2111.01556), MONAI introduces new transforms and +network modules for multiple instance learning. These include self-attention +transformer blocks for explicitly accounting of the dependencies between instances +(image patches) during training. For more details, +please check out the [multiple instance learning tutorial](https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning). + +![multi-instance](../images/mil-patches.jpg) + +## Self-supervised representation learning +MONAI starts to explore self-supervised representation learning in this +milestone release. The Vision Transformer has been extended to learn from self-supervised +reconstruction tasks with various data augmentation and a regularized +contrastive loss. The weights of the pre-trained backbone could be used to +enhance the performance of the novel downstream deep learning tasks. + +The [tutorial](https://github.com/Project-MONAI/tutorials/tree/master/self_supervised_pretraining) +shows how to generate a good set of pre-trained weights using unlabeled data +with self-supervised tasks, then use the pre-trained weights to perform +fine-tuning on a fully supervised volumetric segmentation task using a transformer based `UNETR`. + +![self-supervised](../images/ssl_overview.png) + +## Major usability improvements in `monai.transforms` +`monai.transforms` are now more flexible and easy to use in version 0.8. +- Input type handling and backend APIs are improved to support both + NumPy and PyTorch where possible. +- Visual examples are added to the documentation to illustrate the effects of + various image processing. +- New visualization utilities are provided and enhanced for quick qualitative + assessments of the model by visualizing, for example, the volumetric image + inputs, segmentation maps, and intermediate feature maps. + The visualization tutorial is available for + [TensorBoard utility, `matshow3d` and `blend_images`](https://github.com/Project-MONAI/tutorials/blob/master/modules/transform_visualization.ipynb). diff --git a/monai/__init__.py b/monai/__init__.py index 2c7c920162..68a232b46d 100644 --- a/monai/__init__.py +++ b/monai/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,22 +15,24 @@ from ._version import get_versions PY_REQUIRED_MAJOR = 3 -PY_REQUIRED_MINOR = 6 +PY_REQUIRED_MINOR = 7 version_dict = get_versions() __version__: str = version_dict.get("version", "0+unknown") __revision_id__: str = version_dict.get("full-revisionid") del get_versions, version_dict -__copyright__ = "(c) 2020 - 2021 MONAI Consortium" +__copyright__ = "(c) MONAI Consortium" __basedir__ = os.path.dirname(__file__) -if not (sys.version_info.major == PY_REQUIRED_MAJOR and sys.version_info.minor >= PY_REQUIRED_MINOR): - raise RuntimeError( - "MONAI requires Python {}.{} or higher. But the current Python is: {}".format( - PY_REQUIRED_MAJOR, PY_REQUIRED_MINOR, sys.version - ), +if sys.version_info.major != PY_REQUIRED_MAJOR or sys.version_info.minor < PY_REQUIRED_MINOR: + import warnings + + warnings.warn( + f"MONAI requires Python {PY_REQUIRED_MAJOR}.{PY_REQUIRED_MINOR} or higher. " + f"But the current Python is: {sys.version}", + category=RuntimeWarning, ) from .utils.module import load_submodules # noqa: E402 diff --git a/monai/_extensions/__init__.py b/monai/_extensions/__init__.py index 3718894b7c..fd32d71840 100644 --- a/monai/_extensions/__init__.py +++ b/monai/_extensions/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/_extensions/gmm/gmm.cpp b/monai/_extensions/gmm/gmm.cpp index ecb85e252a..686fddb721 100644 --- a/monai/_extensions/gmm/gmm.cpp +++ b/monai/_extensions/gmm/gmm.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/_extensions/gmm/gmm.h b/monai/_extensions/gmm/gmm.h index 9a43351eb9..6317baa41a 100644 --- a/monai/_extensions/gmm/gmm.h +++ b/monai/_extensions/gmm/gmm.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/_extensions/gmm/gmm_cpu.cpp b/monai/_extensions/gmm/gmm_cpu.cpp index 144e66806c..c9b55490eb 100644 --- a/monai/_extensions/gmm/gmm_cpu.cpp +++ b/monai/_extensions/gmm/gmm_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -22,5 +22,5 @@ void learn_cpu(const float* input, const int* labels, float* gmm, float* scratch void apply_cpu(const float* gmm, const float* input, float* output, unsigned int batch_count, unsigned int element_count) { - throw std::invalid_argument("GMM recieved a cpu tensor but is not yet implemented for the cpu"); + throw std::invalid_argument("GMM received a cpu tensor but is not yet implemented for the cpu"); } diff --git a/monai/_extensions/gmm/gmm_cuda.cu b/monai/_extensions/gmm/gmm_cuda.cu index 36af48b06c..765ffe5b1c 100644 --- a/monai/_extensions/gmm/gmm_cuda.cu +++ b/monai/_extensions/gmm/gmm_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/_extensions/gmm/gmm_cuda_linalg.cuh b/monai/_extensions/gmm/gmm_cuda_linalg.cuh index 49e68c8442..9d54d80d3b 100644 --- a/monai/_extensions/gmm/gmm_cuda_linalg.cuh +++ b/monai/_extensions/gmm/gmm_cuda_linalg.cuh @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/_extensions/loader.py b/monai/_extensions/loader.py index 5f77480ecc..2b57302fb0 100644 --- a/monai/_extensions/loader.py +++ b/monai/_extensions/loader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,7 +34,7 @@ def timeout(time, message): except KeyboardInterrupt as e: if timer is not None and timer.is_alive(): raise e # interrupt from user? - raise TimeoutError(message) + raise TimeoutError(message) from e finally: if timer is not None: try: @@ -84,11 +84,7 @@ def load_module( # This will either run the build or return the existing .so object. name = module_name + platform_str.replace(".", "_") module = load( - name=name, - sources=source, - extra_cflags=define_args, - extra_cuda_cflags=define_args, - verbose=verbose_build, + name=name, sources=source, extra_cflags=define_args, extra_cuda_cflags=define_args, verbose=verbose_build ) return module diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index ef4352cabd..893f7877d2 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,4 @@ from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar -from .utils import check_hash, download_and_extract, download_url, extractall +from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index c766914026..8b03393fd6 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,13 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import sys +from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Union import numpy as np from monai.apps.utils import download_and_extract +from monai.config.type_definitions import PathLike from monai.data import ( CacheDataset, load_decathlon_datalist, @@ -50,6 +51,12 @@ class MedNISTDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. if 0 a single thread will be used. Default is 0. + progress: whether to display a progress bar when downloading dataset and computing the transform cache content. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. Raises: ValueError: When ``root_dir`` is not a directory. @@ -64,7 +71,7 @@ class MedNISTDataset(Randomizable, CacheDataset): def __init__( self, - root_dir: str, + root_dir: PathLike, section: str, transform: Union[Sequence[Callable], Callable] = (), download: bool = False, @@ -74,20 +81,30 @@ def __init__( cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0, + progress: bool = True, + copy_cache: bool = True, ) -> None: - if not os.path.isdir(root_dir): + root_dir = Path(root_dir) + if not root_dir.is_dir(): raise ValueError("Root directory root_dir must be a directory.") self.section = section self.val_frac = val_frac self.test_frac = test_frac self.set_random_state(seed=seed) - tarfile_name = os.path.join(root_dir, self.compressed_file_name) - dataset_dir = os.path.join(root_dir, self.dataset_folder_name) + tarfile_name = root_dir / self.compressed_file_name + dataset_dir = root_dir / self.dataset_folder_name self.num_class = 0 if download: - download_and_extract(self.resource, tarfile_name, root_dir, self.md5) + download_and_extract( + url=self.resource, + filepath=tarfile_name, + output_dir=root_dir, + hash_val=self.md5, + hash_type="md5", + progress=progress, + ) - if not os.path.exists(dataset_dir): + if not dataset_dir.is_dir(): raise RuntimeError( f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it." ) @@ -95,31 +112,33 @@ def __init__( if transform == (): transform = LoadImaged("image") CacheDataset.__init__( - self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers + self, + data=data, + transform=transform, + cache_num=cache_num, + cache_rate=cache_rate, + num_workers=num_workers, + progress=progress, + copy_cache=copy_cache, ) - def randomize(self, data: List[int]) -> None: + def randomize(self, data: np.ndarray) -> None: self.R.shuffle(data) def get_num_classes(self) -> int: """Get number of classes.""" return self.num_class - def _generate_data_list(self, dataset_dir: str) -> List[Dict]: + def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: """ Raises: ValueError: When ``section`` is not one of ["training", "validation", "test"]. """ - class_names = sorted((x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x)))) + dataset_dir = Path(dataset_dir) + class_names = sorted(f"{x.name}" for x in dataset_dir.iterdir() if x.is_dir()) # folder name as the class name self.num_class = len(class_names) - image_files = [ - [ - os.path.join(dataset_dir, class_names[i], x) - for x in os.listdir(os.path.join(dataset_dir, class_names[i])) - ] - for i in range(self.num_class) - ] + image_files = [[f"{x}" for x in (dataset_dir / class_names[i]).iterdir()] for i in range(self.num_class)] num_each = [len(image_files[i]) for i in range(self.num_class)] image_files_list = [] image_class = [] @@ -145,13 +164,9 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: raise ValueError( f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].' ) - + # the types of label and class name should be compatible with the pytorch dataloader return [ - { - "image": image_files_list[i], - "label": image_class[i], - "class_name": class_name[i], - } + {"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]} for i in section_indices ] @@ -184,6 +199,12 @@ class DecathlonDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. if 0 a single thread will be used. Default is 0. + progress: whether to display a progress bar when downloading dataset and computing the transform cache content. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. Raises: ValueError: When ``root_dir`` is not a directory. @@ -238,7 +259,7 @@ class DecathlonDataset(Randomizable, CacheDataset): def __init__( self, - root_dir: str, + root_dir: PathLike, task: str, section: str, transform: Union[Sequence[Callable], Callable] = (), @@ -248,20 +269,30 @@ def __init__( cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0, + progress: bool = True, + copy_cache: bool = True, ) -> None: - if not os.path.isdir(root_dir): + root_dir = Path(root_dir) + if not root_dir.is_dir(): raise ValueError("Root directory root_dir must be a directory.") self.section = section self.val_frac = val_frac self.set_random_state(seed=seed) if task not in self.resource: raise ValueError(f"Unsupported task: {task}, available options are: {list(self.resource.keys())}.") - dataset_dir = os.path.join(root_dir, task) + dataset_dir = root_dir / task tarfile_name = f"{dataset_dir}.tar" if download: - download_and_extract(self.resource[task], tarfile_name, root_dir, self.md5[task]) + download_and_extract( + url=self.resource[task], + filepath=tarfile_name, + output_dir=root_dir, + hash_val=self.md5[task], + hash_type="md5", + progress=progress, + ) - if not os.path.exists(dataset_dir): + if not dataset_dir.exists(): raise RuntimeError( f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it." ) @@ -279,11 +310,18 @@ def __init__( "numTraining", "numTest", ] - self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys) + self._properties = load_decathlon_properties(dataset_dir / "dataset.json", property_keys) if transform == (): transform = LoadImaged(["image", "label"]) CacheDataset.__init__( - self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers + self, + data=data, + transform=transform, + cache_num=cache_num, + cache_rate=cache_rate, + num_workers=num_workers, + progress=progress, + copy_cache=copy_cache, ) def get_indices(self) -> np.ndarray: @@ -293,7 +331,7 @@ def get_indices(self) -> np.ndarray: """ return self.indices - def randomize(self, data: List[int]) -> None: + def randomize(self, data: np.ndarray) -> None: self.R.shuffle(data) def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None): @@ -308,9 +346,11 @@ def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None): return {key: self._properties[key] for key in ensure_tuple(keys)} return {} - def _generate_data_list(self, dataset_dir: str) -> List[Dict]: + def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: + # the types of the item in data list should be compatible with the dataloader + dataset_dir = Path(dataset_dir) section = "training" if self.section in ["training", "validation"] else "test" - datalist = load_decathlon_datalist(os.path.join(dataset_dir, "dataset.json"), True, section) + datalist = load_decathlon_datalist(dataset_dir / "dataset.json", True, section) return self._split_datalist(datalist) def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: @@ -349,14 +389,15 @@ class CrossValidation: root_dir="./", task="Task09_Spleen", section="training", + transform=train_transform, download=True, ) dataset_fold0_train = cvdataset.get_dataset(folds=[1, 2, 3, 4]) - dataset_fold0_val = cvdataset.get_dataset(folds=0) + dataset_fold0_val = cvdataset.get_dataset(folds=0, transform=val_transform, download=False) # execute training for fold 0 ... - dataset_fold1_train = cvdataset.get_dataset(folds=[1]) - dataset_fold1_val = cvdataset.get_dataset(folds=[0, 2, 3, 4]) + dataset_fold1_train = cvdataset.get_dataset(folds=[0, 2, 3, 4]) + dataset_fold1_val = cvdataset.get_dataset(folds=1, transform=val_transform, download=False) # execute training for fold 1 ... ... @@ -366,13 +407,7 @@ class CrossValidation: """ - def __init__( - self, - dataset_cls, - nfolds: int = 5, - seed: int = 0, - **dataset_params, - ) -> None: + def __init__(self, dataset_cls, nfolds: int = 5, seed: int = 0, **dataset_params) -> None: if not hasattr(dataset_cls, "_split_datalist"): raise ValueError("dataset class must have _split_datalist API.") self.dataset_cls = dataset_cls @@ -380,20 +415,24 @@ def __init__( self.seed = seed self.dataset_params = dataset_params - def get_dataset(self, folds: Union[Sequence[int], int]): + def get_dataset(self, folds: Union[Sequence[int], int], **dataset_params): """ Generate dataset based on the specified fold indice in the cross validation group. Args: folds: index of folds for training or validation, if a list of values, concatenate the data. + dataset_params: other additional parameters for the dataset_cls base class, will override + the same parameters in `self.dataset_params`. """ nfolds = self.nfolds seed = self.seed + dataset_params_ = dict(self.dataset_params) + dataset_params_.update(dataset_params) class _NsplitsDataset(self.dataset_cls): # type: ignore def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: data = partition_dataset(data=datalist, num_partitions=nfolds, shuffle=True, seed=seed) return select_cross_validation_folds(partitions=data, folds=folds) - return _NsplitsDataset(**self.dataset_params) + return _NsplitsDataset(**dataset_params_) diff --git a/monai/apps/deepedit/__init__.py b/monai/apps/deepedit/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/apps/deepedit/__init__.py +++ b/monai/apps/deepedit/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py index 845e7bd1d0..abd9e01473 100644 --- a/monai/apps/deepedit/transforms.py +++ b/monai/apps/deepedit/transforms.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import logging from typing import Dict, Hashable, Mapping, Tuple @@ -15,12 +26,7 @@ class DiscardAddGuidanced(MapTransform): - def __init__( - self, - keys: KeysCollection, - probability: float = 1.0, - allow_missing_keys: bool = False, - ): + def __init__(self, keys: KeysCollection, probability: float = 1.0, allow_missing_keys: bool = False): """ Discard positive and negative points randomly or Add the two channels for inference time @@ -54,11 +60,7 @@ class ResizeGuidanceCustomd(Transform): Resize the guidance based on cropped vs resized image. """ - def __init__( - self, - guidance: str, - ref_image: str, - ) -> None: + def __init__(self, guidance: str, ref_image: str) -> None: self.guidance = guidance self.ref_image = ref_image @@ -69,8 +71,8 @@ def __call__(self, data): factor = np.divide(current_shape, d["image_meta_dict"]["dim"][1:4]) pos_clicks, neg_clicks = d["foreground"], d["background"] - pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else [] - neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else [] + pos = np.multiply(pos_clicks, factor).astype(int, copy=False).tolist() if len(pos_clicks) else [] + neg = np.multiply(neg_clicks, factor).astype(int, copy=False).tolist() if len(neg_clicks) else [] d[self.guidance] = [pos, neg] return d @@ -156,7 +158,7 @@ def _apply(self, guidance, discrepancy): guidance[0].append([-1] * len(neg)) guidance[1].append(neg) - return json.dumps(np.asarray(guidance).astype(int).tolist()) + return json.dumps(np.asarray(guidance, dtype=int).tolist()) def __call__(self, data): d = dict(data) diff --git a/monai/apps/deepgrow/__init__.py b/monai/apps/deepgrow/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/apps/deepgrow/__init__.py +++ b/monai/apps/deepgrow/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index acaeba0bc3..763377763a 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -97,7 +97,7 @@ def create_dataset( image = os.path.abspath(image) label = os.path.abspath(label) if label else None - logging.info("Image: {}; Label: {}".format(image, label if label else None)) + logging.info(f"Image: {image}; Label: {label if label else None}") data = transforms({image_key: image, label_key: label}) if dimension == 2: data = _save_data_2d( @@ -154,7 +154,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): if vol_label is not None and np.sum(label) == 0: continue - image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid) + image_file_prefix = f"vol_idx_{vol_idx:0>4d}_slice_{sid:0>3d}" image_file = os.path.join(dataset_dir, "images", image_file_prefix) image_file += ".npy" @@ -165,9 +165,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): # Test Data if vol_label is None: data_list.append( - { - "image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file, - } + {"image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file} ) continue @@ -177,7 +175,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): unique_labels_count = max(unique_labels_count, len(unique_labels)) for idx in unique_labels: - label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx)) + label_file_prefix = f"{image_file_prefix}_region_{int(idx):0>2d}" label_file = os.path.join(dataset_dir, "labels", label_file_prefix) label_file += ".npy" @@ -226,7 +224,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): label_count = 0 unique_labels_count = 0 - image_file_prefix = "vol_idx_{:0>4d}".format(vol_idx) + image_file_prefix = f"vol_idx_{vol_idx:0>4d}" image_file = os.path.join(dataset_dir, "images", image_file_prefix) image_file += ".npy" @@ -236,11 +234,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): # Test Data if vol_label is None: - data_list.append( - { - "image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file, - } - ) + data_list.append({"image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file}) else: # For all Labels unique_labels = np.unique(vol_label.flatten()) @@ -248,7 +242,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): unique_labels_count = max(unique_labels_count, len(unique_labels)) for idx in unique_labels: - label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx)) + label_file_prefix = f"{image_file_prefix}_region_{int(idx):0>2d}" label_file = os.path.join(dataset_dir, "labels", label_file_prefix) label_file += ".npy" diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 81e82c958d..73bc8e7e0b 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,6 +22,7 @@ class Interaction: """ Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. + For more details please refer to: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. This implementation is based on: Sakinis et al., Interactive segmentation of medical images through diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index db450792b0..d70eb31c99 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from typing import Callable, Dict, Optional, Sequence, Union +from typing import Callable, Dict, Hashable, List, Optional, Sequence, Union import numpy as np import torch @@ -18,8 +18,8 @@ from monai.networks.layers import GaussianFilter from monai.transforms import Resize, SpatialCrop from monai.transforms.transform import MapTransform, Randomizable, Transform -from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils import InterpolateMode, ensure_tuple, ensure_tuple_rep, min_version, optional_import +from monai.transforms.utils import generate_spatial_bounding_box, is_positive +from monai.utils import InterpolateMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, min_version, optional_import measure, _ = optional_import("skimage.measure", "0.14.2", min_version) distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") @@ -145,7 +145,7 @@ def _apply(self, label, sid): def __call__(self, data): d = dict(data) self.randomize(data) - d[self.guidance] = json.dumps(self._apply(d[self.label], self.sid).astype(int).tolist()) + d[self.guidance] = json.dumps(self._apply(d[self.label], self.sid).astype(int, copy=False).tolist()) return d @@ -163,13 +163,7 @@ class AddGuidanceSignald(Transform): """ - def __init__( - self, - image: str = "image", - guidance: str = "guidance", - sigma: int = 2, - number_intensity_ch: int = 1, - ): + def __init__(self, image: str = "image", guidance: str = "guidance", sigma: int = 2, number_intensity_ch: int = 1): self.image = image self.guidance = guidance self.sigma = sigma @@ -276,12 +270,7 @@ class AddRandomGuidanced(Randomizable, Transform): """ - def __init__( - self, - guidance: str = "guidance", - discrepancy: str = "discrepancy", - probability: str = "probability", - ): + def __init__(self, guidance: str = "guidance", discrepancy: str = "discrepancy", probability: str = "probability"): self.guidance = guidance self.discrepancy = discrepancy self.probability = probability @@ -334,7 +323,7 @@ def _apply(self, guidance, discrepancy): guidance[0].append([-1] * len(neg)) guidance[1].append(neg) - return json.dumps(np.asarray(guidance).astype(int).tolist()) + return json.dumps(np.asarray(guidance, dtype=int).tolist()) def __call__(self, data): d = dict(data) @@ -398,7 +387,7 @@ def __init__( keys: KeysCollection, source_key: str, spatial_size: Union[Sequence[int], np.ndarray], - select_fn: Callable = lambda x: x > 0, + select_fn: Callable = is_positive, channel_indices: Optional[IndexSelection] = None, margin: int = 0, meta_keys: Optional[KeysCollection] = None, @@ -431,8 +420,8 @@ def __call__(self, data): d[self.source_key], self.select_fn, self.channel_indices, self.margin ) - center = list(np.mean([box_start, box_end], axis=0).astype(int)) - current_size = list(np.subtract(box_end, box_start).astype(int)) + center = list(np.mean([box_start, box_end], axis=0).astype(int, copy=False)) + current_size = list(np.subtract(box_end, box_start).astype(int, copy=False)) if np.all(np.less(current_size, self.spatial_size)): cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) @@ -476,7 +465,7 @@ class AddGuidanceFromPointsd(Transform): background: key that represents user background (-ve) clicks. axis: axis that represents slices in 3D volume. (axis to Depth) depth_first: if depth (slices) is positioned at first dimension. - dimensions: dimensions based on model used for deepgrow (2D vs 3D). + spatial_dims: dimensions based on model used for deepgrow (2D vs 3D). slice_key: key that represents applicable slice to add guidance. meta_keys: explicitly indicate the key of the meta data dictionary of `ref_image`. for example, for data with key `image`, the metadata by default is in `image_meta_dict`. @@ -486,8 +475,13 @@ class AddGuidanceFromPointsd(Transform): to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + """ + @deprecated_arg(name="dimensions", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, ref_image, @@ -496,10 +490,11 @@ def __init__( background: str = "background", axis: int = 0, depth_first: bool = True, - dimensions: int = 2, + spatial_dims: int = 2, slice_key: str = "slice", meta_keys: Optional[str] = None, meta_key_postfix: str = "meta_dict", + dimensions: Optional[int] = None, ): self.ref_image = ref_image self.guidance = guidance @@ -507,7 +502,7 @@ def __init__( self.background = background self.axis = axis self.depth_first = depth_first - self.dimensions = dimensions + self.dimensions = spatial_dims if dimensions is None else dimensions self.slice = slice_key self.meta_keys = meta_keys self.meta_key_postfix = meta_key_postfix @@ -533,9 +528,9 @@ def _apply(self, pos_clicks, neg_clicks, factor, slice_num): guidance = [pos, neg, slice_idx] else: if len(pos_clicks): - pos = np.multiply(pos_clicks, factor).astype(int).tolist() + pos = np.multiply(pos_clicks, factor).astype(int, copy=False).tolist() if len(neg_clicks): - neg = np.multiply(neg_clicks, factor).astype(int).tolist() + neg = np.multiply(neg_clicks, factor).astype(int, copy=False).tolist() guidance = [pos, neg] return guidance @@ -560,7 +555,7 @@ def __call__(self, data): fg_bg_clicks = [] for key in [self.foreground, self.background]: clicks = d[key] - clicks = list(np.array(clicks).astype(int)) + clicks = list(np.array(clicks, dtype=int)) if self.depth_first: for i in range(len(clicks)): clicks[i] = list(np.roll(clicks[i], 1)) @@ -649,13 +644,17 @@ def bounding_box(self, points, img_shape): def __call__(self, data): d: Dict = dict(data) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + guidance = d[self.guidance] - original_spatial_shape = d[self.keys[0]].shape[1:] + original_spatial_shape = d[first_key].shape[1:] # type: ignore box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape) - center = list(np.mean([box_start, box_end], axis=0).astype(int)) + center = list(np.mean([box_start, box_end], axis=0).astype(int, copy=False)) spatial_size = self.spatial_size - box_size = list(np.subtract(box_end, box_start).astype(int)) + box_size = list(np.subtract(box_end, box_start).astype(int, copy=False)) spatial_size = spatial_size[-len(box_size) :] if len(spatial_size) < len(box_size): @@ -739,8 +738,8 @@ def __call__(self, data): factor = np.divide(current_shape, cropped_shape) pos_clicks, neg_clicks = guidance[0], guidance[1] - pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else [] - neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else [] + pos = np.multiply(pos_clicks, factor).astype(int, copy=False).tolist() if len(pos_clicks) else [] + neg = np.multiply(neg_clicks, factor).astype(int, copy=False).tolist() if len(neg_clicks) else [] d[self.guidance] = [pos, neg] return d diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index 396be2e87d..8f1448bb06 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index e7ff28ce44..801b826bd1 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,14 +17,15 @@ """ import json -import os import warnings -from typing import Mapping, Union +from pathlib import Path +from typing import Mapping, Optional, Union import torch import monai.networks.nets as monai_nets -from monai.apps.utils import download_and_extract +from monai.apps.utils import download_and_extract, logger +from monai.config.type_definitions import PathLike from monai.utils.module import optional_import from .model_desc import MODEL_DESC @@ -42,7 +43,7 @@ def get_model_spec(idx: Union[int, str]): for cand in MODEL_DESC: if str(cand[Keys.ID]).strip().lower() == key: return cand - print(f"Available specs are: {MODEL_DESC}.") + logger.info(f"Available specs are: {MODEL_DESC}.") raise ValueError(f"Unknown MODEL_DESC request: {idx}") @@ -98,7 +99,9 @@ def _get_ngc_doc_url(model_name: str, model_prefix=""): return f"https://ngc.nvidia.com/catalog/models/{model_prefix}{model_name}" -def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, version: int = -1): +def download_mmar( + item, mmar_dir: Optional[PathLike] = None, progress: bool = True, api: bool = False, version: int = -1 +): """ Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train. @@ -128,10 +131,10 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, if not mmar_dir: get_dir, has_home = optional_import("torch.hub", name="get_dir") if has_home: - mmar_dir = os.path.join(get_dir(), "mmars") + mmar_dir = Path(get_dir()) / "mmars" else: raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") - + mmar_dir = Path(mmar_dir) if api: model_dict = _get_all_ngc_models(item) if len(model_dict) == 0: @@ -140,10 +143,10 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, for k, v in model_dict.items(): ver = v["latest"] if version == -1 else str(version) download_url = _get_ngc_url(k, ver) - model_dir = os.path.join(mmar_dir, v["name"]) + model_dir = mmar_dir / v["name"] download_and_extract( url=download_url, - filepath=os.path.join(mmar_dir, f'{v["name"]}_{ver}.zip'), + filepath=mmar_dir / f'{v["name"]}_{ver}.zip', output_dir=model_dir, hash_val=None, hash_type="md5", @@ -161,11 +164,11 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, if version > 0: ver = str(version) model_fullname = f"{item[Keys.NAME]}_{ver}" - model_dir = os.path.join(mmar_dir, model_fullname) + model_dir = mmar_dir / model_fullname model_url = item.get(Keys.URL) or _get_ngc_url(item[Keys.NAME], version=ver, model_prefix="nvidia/med/") download_and_extract( url=model_url, - filepath=os.path.join(mmar_dir, f"{model_fullname}.{item[Keys.FILE_TYPE]}"), + filepath=mmar_dir / f"{model_fullname}.{item[Keys.FILE_TYPE]}", output_dir=model_dir, hash_val=item[Keys.HASH_VAL], hash_type=item[Keys.HASH_TYPE], @@ -178,7 +181,7 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, def load_from_mmar( item, - mmar_dir=None, + mmar_dir: Optional[PathLike] = None, progress: bool = True, version: int = -1, map_location=None, @@ -212,11 +215,11 @@ def load_from_mmar( if not isinstance(item, Mapping): item = get_model_spec(item) model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version) - model_file = os.path.join(model_dir, item[Keys.MODEL_FILE]) - print(f'\n*** "{item[Keys.ID]}" available at {model_dir}.') + model_file = model_dir / item[Keys.MODEL_FILE] + logger.info(f'\n*** "{item[Keys.ID]}" available at {model_dir}.') # loading with `torch.jit.load` - if f"{model_file}".endswith(".ts"): + if model_file.name.endswith(".ts"): if not pretrained: warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.") if weights_only: @@ -232,7 +235,7 @@ def load_from_mmar( model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={}) if not model_config: # 2. search json CONFIG_FILE for model config spec. - json_path = os.path.join(model_dir, item.get(Keys.CONFIG_FILE, "config_train.json")) + json_path = model_dir / item.get(Keys.CONFIG_FILE, "config_train.json") with open(json_path) as f: conf_dict = json.load(f) conf_dict = dict(conf_dict) @@ -264,18 +267,18 @@ def load_from_mmar( else: raise ValueError(f"Could not load model config {model_config}.") - print(f"*** Model: {model_cls}") + logger.info(f"*** Model: {model_cls}") model_kwargs = model_config.get("args", None) if model_kwargs: model_inst = model_cls(**model_kwargs) - print(f"*** Model params: {model_kwargs}") + logger.info(f"*** Model params: {model_kwargs}") else: model_inst = model_cls() if pretrained: model_inst.load_state_dict(model_dict.get(model_key, model_dict)) - print("\n---") + logger.info("\n---") doc_url = item.get(Keys.DOC) or _get_ngc_doc_url(item[Keys.NAME], model_prefix="nvidia:med:") - print(f"For more information, please visit {doc_url}\n") + logger.info(f"For more information, please visit {doc_url}\n") return model_inst diff --git a/monai/apps/mmars/model_desc.py b/monai/apps/mmars/model_desc.py index fca6f60da5..ae0e9cae30 100644 --- a/monai/apps/mmars/model_desc.py +++ b/monai/apps/mmars/model_desc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/__init__.py b/monai/apps/pathology/__init__.py index 80f32403ea..81742caf65 100644 --- a/monai/apps/pathology/__init__.py +++ b/monai/apps/pathology/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/data/__init__.py b/monai/apps/pathology/data/__init__.py index 64556b6f6e..e1b2ef7bd2 100644 --- a/monai/apps/pathology/data/__init__.py +++ b/monai/apps/pathology/data/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/data/datasets.py b/monai/apps/pathology/data/datasets.py index 3694ca4144..bfab7c49da 100644 --- a/monai/apps/pathology/data/datasets.py +++ b/monai/apps/pathology/data/datasets.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -64,7 +64,7 @@ def __init__( self.patch_size = ensure_tuple_rep(patch_size, 2) self.image_path_list = list({x["image"] for x in self.data}) - self.image_reader_name = image_reader_name + self.image_reader_name = image_reader_name.lower() self.image_reader = WSIReader(image_reader_name) self.wsi_object_dict = None if self.image_reader_name != "openslide": @@ -122,6 +122,10 @@ class SmartCachePatchWSIDataset(SmartCacheDataset): num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. If num_replace_workers is None then the number returned by os.cpu_count() is used. progress: whether to display a progress bar when caching for the first epoch. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cache content + or every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. """ @@ -139,6 +143,7 @@ def __init__( num_init_workers: Optional[int] = None, num_replace_workers: Optional[int] = None, progress: bool = True, + copy_cache: bool = True, ): patch_wsi_dataset = PatchWSIDataset( data=data, @@ -157,6 +162,7 @@ def __init__( num_replace_workers=num_replace_workers, progress=progress, shuffle=False, + copy_cache=copy_cache, ) @@ -190,7 +196,7 @@ def __init__( self.patch_size = ensure_tuple_rep(patch_size, 2) # set up whole slide image reader - self.image_reader_name = image_reader_name + self.image_reader_name = image_reader_name.lower() self.image_reader = WSIReader(image_reader_name) # process data and create a list of dictionaries containing all required data and metadata @@ -293,11 +299,7 @@ def _load_a_patch(self, index): location_on_image = sample["image_locations"][patch_num] location_on_mask = sample["mask_locations"][patch_num] - image, _ = self.image_reader.get_data( - img=sample["image"], - location=location_on_image, - size=self.patch_size, - ) + image, _ = self.image_reader.get_data(img=sample["image"], location=location_on_image, size=self.patch_size) processed_sample = {"image": image, "name": sample["name"], "mask_location": location_on_mask} return processed_sample diff --git a/monai/apps/pathology/handlers/__init__.py b/monai/apps/pathology/handlers/__init__.py index 3a788ffa26..0638950bd8 100644 --- a/monai/apps/pathology/handlers/__init__.py +++ b/monai/apps/pathology/handlers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/handlers/prob_map_producer.py b/monai/apps/pathology/handlers/prob_map_producer.py index 7ac4a0e45b..0615b44b8f 100644 --- a/monai/apps/pathology/handlers/prob_map_producer.py +++ b/monai/apps/pathology/handlers/prob_map_producer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -62,9 +62,10 @@ def attach(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - self.num_images = len(engine.data_loader.dataset.data) + data_loader = engine.data_loader # type: ignore + self.num_images = len(data_loader.dataset.data) - for sample in engine.data_loader.dataset.data: + for sample in data_loader.dataset.data: name = sample["name"] self.prob_map[name] = np.zeros(sample["mask_shape"], dtype=self.dtype) self.counter[name] = len(sample["mask_locations"]) @@ -84,6 +85,8 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ + if not isinstance(engine.state.batch, dict) or not isinstance(engine.state.output, dict): + raise ValueError("engine.state.batch and engine.state.output must be dictionaries.") names = engine.state.batch["name"] locs = engine.state.batch["mask_location"] pred = engine.state.output["pred"] diff --git a/monai/apps/pathology/metrics/__init__.py b/monai/apps/pathology/metrics/__init__.py index ad62df524a..f19811dcaf 100644 --- a/monai/apps/pathology/metrics/__init__.py +++ b/monai/apps/pathology/metrics/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/metrics/lesion_froc.py b/monai/apps/pathology/metrics/lesion_froc.py index 2140de0080..6073bd0cda 100644 --- a/monai/apps/pathology/metrics/lesion_froc.py +++ b/monai/apps/pathology/metrics/lesion_froc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -78,11 +78,7 @@ def __init__( self.itc_diameter = itc_diameter self.eval_thresholds = eval_thresholds self.image_reader = WSIReader(image_reader_name) - self.nms = PathologyProbNMS( - sigma=nms_sigma, - prob_threshold=nms_prob_threshold, - box_size=nms_box_size, - ) + self.nms = PathologyProbNMS(sigma=nms_sigma, prob_threshold=nms_prob_threshold, box_size=nms_box_size) def prepare_inference_result(self, sample: Dict): """ @@ -151,12 +147,7 @@ def compute_fp_tp(self): total_tp_probs.extend(tp_probs) total_num_targets += num_targets - return ( - np.array(total_fp_probs), - np.array(total_tp_probs), - total_num_targets, - num_images, - ) + return np.array(total_fp_probs), np.array(total_tp_probs), total_num_targets, num_images def evaluate(self): """ @@ -168,17 +159,12 @@ def evaluate(self): # compute FROC curve given the evaluation of all images fps_per_image, total_sensitivity = compute_froc_curve_data( - fp_probs=fp_probs, - tp_probs=tp_probs, - num_targets=num_targets, - num_images=num_images, + fp_probs=fp_probs, tp_probs=tp_probs, num_targets=num_targets, num_images=num_images ) # compute FROC score give specific evaluation threshold froc_score = compute_froc_score( - fps_per_image=fps_per_image, - total_sensitivity=total_sensitivity, - eval_thresholds=self.eval_thresholds, + fps_per_image=fps_per_image, total_sensitivity=total_sensitivity, eval_thresholds=self.eval_thresholds ) return froc_score diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 1be96b8e34..290c0ba6a8 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,8 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .spatial.array import SplitOnGrid -from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict +from .spatial.array import SplitOnGrid, TileOnGrid +from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict from .stain.array import ExtractHEStains, NormalizeHEStains from .stain.dictionary import ( ExtractHEStainsd, diff --git a/monai/apps/pathology/transforms/spatial/__init__.py b/monai/apps/pathology/transforms/spatial/__init__.py index 07ba222ab0..eed111d2b6 100644 --- a/monai/apps/pathology/transforms/spatial/__init__.py +++ b/monai/apps/pathology/transforms/spatial/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,5 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .array import SplitOnGrid -from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict +from .array import SplitOnGrid, TileOnGrid +from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 53e0c63715..fe1383c08d 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,13 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Sequence, Tuple, Union +import numpy as np import torch +from numpy.lib.stride_tricks import as_strided -from monai.transforms.transform import Transform +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import Randomizable, Transform +from monai.utils import convert_data_type, convert_to_dst_type +from monai.utils.enums import TransformBackends -__all__ = ["SplitOnGrid"] +__all__ = ["SplitOnGrid", "TileOnGrid"] class SplitOnGrid(Transform): @@ -24,19 +29,19 @@ class SplitOnGrid(Transform): This transform works only with torch.Tensor inputs. Args: - grid_shape: a tuple or an integer define the shape of the grid upon which to extract patches. + grid_size: a tuple or an integer define the shape of the grid upon which to extract patches. If it's an integer, the value will be repeated for each dimension. Default is 2x2 patch_size: a tuple or an integer that defines the output patch sizes. If it's an integer, the value will be repeated for each dimension. - The default is (0, 0), where the patch size will be infered from the grid shape. + The default is (0, 0), where the patch size will be inferred from the grid shape. - Note: the shape of the input image is infered based on the first image used. + Note: the shape of the input image is inferred based on the first image used. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( - self, - grid_size: Union[int, Tuple[int, int]] = (2, 2), - patch_size: Optional[Union[int, Tuple[int, int]]] = None, + self, grid_size: Union[int, Tuple[int, int]] = (2, 2), patch_size: Optional[Union[int, Tuple[int, int]]] = None ): # Grid size if isinstance(grid_size, int): @@ -50,17 +55,42 @@ def __init__( else: self.patch_size = patch_size - def __call__(self, image: torch.Tensor) -> torch.Tensor: + def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: if self.grid_size == (1, 1) and self.patch_size is None: - return torch.stack([image]) + if isinstance(image, torch.Tensor): + return torch.stack([image]) + elif isinstance(image, np.ndarray): + return np.stack([image]) # type: ignore + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + patch_size, steps = self.get_params(image.shape[1:]) - patches = ( - image.unfold(1, patch_size[0], steps[0]) - .unfold(2, patch_size[1], steps[1]) - .flatten(1, 2) - .transpose(0, 1) - .contiguous() - ) + patches: NdarrayOrTensor + if isinstance(image, torch.Tensor): + patches = ( + image.unfold(1, patch_size[0], steps[0]) + .unfold(2, patch_size[1], steps[1]) + .flatten(1, 2) + .transpose(0, 1) + .contiguous() + ) + elif isinstance(image, np.ndarray): + x_step, y_step = steps + c_stride, x_stride, y_stride = image.strides + n_channels = image.shape[0] + patches = as_strided( + image, + shape=(*self.grid_size, n_channels, patch_size[0], patch_size[1]), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + writeable=False, + ) + # flatten the first two dimensions + patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:]) + # make it a contiguous array + patches = np.ascontiguousarray(patches) + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + return patches def get_params(self, image_size): @@ -75,3 +105,159 @@ def get_params(self, image_size): ) return patch_size, steps + + +class TileOnGrid(Randomizable, Transform): + """ + Tile the 2D image into patches on a grid and maintain a subset of it. + This transform works only with np.ndarray inputs for 2D images. + + Args: + tile_count: number of tiles to extract, if None extracts all non-background tiles + Defaults to ``None``. + tile_size: size of the square tile + Defaults to ``256``. + step: step size + Defaults to ``None`` (same as tile_size) + random_offset: Randomize position of the grid, instead of starting from the top-left corner + Defaults to ``False``. + pad_full: pad image to the size evenly divisible by tile_size + Defaults to ``False``. + background_val: the background constant (e.g. 255 for white background) + Defaults to ``255``. + filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset + Defaults to ``min`` (which assumes background is high value) + + """ + + backend = [TransformBackends.NUMPY] + + def __init__( + self, + tile_count: Optional[int] = None, + tile_size: int = 256, + step: Optional[int] = None, + random_offset: bool = False, + pad_full: bool = False, + background_val: int = 255, + filter_mode: str = "min", + ): + self.tile_count = tile_count + self.tile_size = tile_size + self.random_offset = random_offset + self.pad_full = pad_full + self.background_val = background_val + self.filter_mode = filter_mode + + if step is None: + # non-overlapping grid + self.step = self.tile_size + else: + self.step = step + + self.offset = (0, 0) + self.random_idxs = np.array((0,)) + + if self.filter_mode not in ["min", "max", "random"]: + raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode)) + + def randomize(self, img_size: Sequence[int]) -> None: + + c, h, w = img_size + + self.offset = (0, 0) + if self.random_offset: + pad_h = h % self.tile_size + pad_w = w % self.tile_size + self.offset = (self.R.randint(pad_h) if pad_h > 0 else 0, self.R.randint(pad_w) if pad_w > 0 else 0) + h = h - self.offset[0] + w = w - self.offset[1] + + if self.pad_full: + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size + h = h + pad_h + w = w + pad_w + + h_n = (h - self.tile_size + self.step) // self.step + w_n = (w - self.tile_size + self.step) // self.step + tile_total = h_n * w_n + + if self.tile_count is not None and tile_total > self.tile_count: + self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False) + else: + self.random_idxs = np.array((0,)) + + def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: + img_np: np.ndarray + img_np, *_ = convert_data_type(image, np.ndarray) # type: ignore + + # add random offset + self.randomize(img_size=img_np.shape) + + if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0): + img_np = img_np[:, self.offset[0] :, self.offset[1] :] + + # pad to full size, divisible by tile_size + if self.pad_full: + c, h, w = img_np.shape + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size + img_np = np.pad( + img_np, + [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]], + constant_values=self.background_val, + ) + + # extact tiles + x_step, y_step = self.step, self.step + h_tile, w_tile = self.tile_size, self.tile_size + c_image, h_image, w_image = img_np.shape + c_stride, x_stride, y_stride = img_np.strides + llw = as_strided( + img_np, + shape=((h_image - h_tile) // x_step + 1, (w_image - w_tile) // y_step + 1, c_image, h_tile, w_tile), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + writeable=False, + ) + img_np = llw.reshape(-1, c_image, h_tile, w_tile) + + # if keeping all patches + if self.tile_count is None: + # retain only patches with significant foreground content to speed up inference + # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference + thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size + if self.filter_mode == "min": + # default, keep non-background tiles (small values) + idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh) + img_np = img_np[idxs.reshape(-1)] + elif self.filter_mode == "max": + idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh) + img_np = img_np[idxs.reshape(-1)] + + else: + if len(img_np) > self.tile_count: + + if self.filter_mode == "min": + # default, keep non-background tiles (smallest values) + idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count] + img_np = img_np[idxs] + elif self.filter_mode == "max": + idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[-self.tile_count :] + img_np = img_np[idxs] + else: + # random subset (more appropriate for WSIs without distinct background) + if self.random_idxs is not None: + img_np = img_np[self.random_idxs] + + elif len(img_np) < self.tile_count: + img_np = np.pad( + img_np, + [[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]], + constant_values=self.background_val, + ) + + image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype) + + return image diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index 10b01a39de..d5c34a0840 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,16 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Hashable, Mapping, Optional, Tuple, Union - -import torch +import copy +from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union from monai.config import KeysCollection -from monai.transforms.transform import MapTransform +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, Randomizable -from .array import SplitOnGrid +from .array import SplitOnGrid, TileOnGrid -__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"] +__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"] class SplitOnGridd(MapTransform): @@ -27,15 +27,17 @@ class SplitOnGridd(MapTransform): This transform works only with torch.Tensor inputs. Args: - grid_shape: a tuple or an integer define the shape of the grid upon which to extract patches. + grid_size: a tuple or an integer define the shape of the grid upon which to extract patches. If it's an integer, the value will be repeated for each dimension. Default is 2x2 patch_size: a tuple or an integer that defines the output patch sizes. If it's an integer, the value will be repeated for each dimension. - The default is (0, 0), where the patch size will be infered from the grid shape. + The default is (0, 0), where the patch size will be inferred from the grid shape. - Note: the shape of the input image is infered based on the first image used. + Note: the shape of the input image is inferred based on the first image used. """ + backend = SplitOnGrid.backend + def __init__( self, keys: KeysCollection, @@ -46,11 +48,89 @@ def __init__( super().__init__(keys, allow_missing_keys) self.splitter = SplitOnGrid(grid_size=grid_size, patch_size=patch_size) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.splitter(d[key]) return d +class TileOnGridd(Randomizable, MapTransform): + """ + Tile the 2D image into patches on a grid and maintain a subset of it. + This transform works only with np.ndarray inputs for 2D images. + + Args: + tile_count: number of tiles to extract, if None extracts all non-background tiles + Defaults to ``None``. + tile_size: size of the square tile + Defaults to ``256``. + step: step size + Defaults to ``None`` (same as tile_size) + random_offset: Randomize position of the grid, instead of starting from the top-left corner + Defaults to ``False``. + pad_full: pad image to the size evenly divisible by tile_size + Defaults to ``False``. + background_val: the background constant (e.g. 255 for white background) + Defaults to ``255``. + filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset + Defaults to ``min`` (which assumes background is high value) + + """ + + backend = SplitOnGrid.backend + + def __init__( + self, + keys: KeysCollection, + tile_count: Optional[int] = None, + tile_size: int = 256, + step: Optional[int] = None, + random_offset: bool = False, + pad_full: bool = False, + background_val: int = 255, + filter_mode: str = "min", + allow_missing_keys: bool = False, + return_list_of_dicts: bool = False, + ): + super().__init__(keys, allow_missing_keys) + + self.return_list_of_dicts = return_list_of_dicts + self.seed = None + + self.splitter = TileOnGrid( + tile_count=tile_count, + tile_size=tile_size, + step=step, + random_offset=random_offset, + pad_full=pad_full, + background_val=background_val, + filter_mode=filter_mode, + ) + + def randomize(self, data: Any = None) -> None: + self.seed = self.R.randint(10000) # type: ignore + + def __call__( + self, data: Mapping[Hashable, NdarrayOrTensor] + ) -> Union[Dict[Hashable, NdarrayOrTensor], List[Dict[Hashable, NdarrayOrTensor]]]: + + self.randomize() + + d = dict(data) + for key in self.key_iterator(d): + self.splitter.set_random_state(seed=self.seed) # same random seed for all keys + d[key] = self.splitter(d[key]) + + if self.return_list_of_dicts: + d_list = [] + for i in range(len(d[self.keys[0]])): + d_list.append({k: d[k][i] if k in self.keys else copy.deepcopy(d[k]) for k in d.keys()}) + d = d_list # type: ignore + + return d + + SplitOnGridDict = SplitOnGridD = SplitOnGridd +TileOnGridDict = TileOnGridD = TileOnGridd diff --git a/monai/apps/pathology/transforms/stain/__init__.py b/monai/apps/pathology/transforms/stain/__init__.py index 824f40a579..dfa235de55 100644 --- a/monai/apps/pathology/transforms/stain/__init__.py +++ b/monai/apps/pathology/transforms/stain/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/transforms/stain/array.py b/monai/apps/pathology/transforms/stain/array.py index ccddc6b243..3b3a293451 100644 --- a/monai/apps/pathology/transforms/stain/array.py +++ b/monai/apps/pathology/transforms/stain/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -52,12 +52,12 @@ def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray: """Perform Stain Deconvolution and return stain matrix for the image. Args: - img: uint8 RGB image to perform stain deconvolution on + image: uint8 RGB image to perform stain deconvolution on Return: he: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values) """ - # check image type and vlues + # check image type and values if not isinstance(image, np.ndarray): raise TypeError("Image must be of type numpy.ndarray.") if image.min() < 0: @@ -67,7 +67,7 @@ def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray: # reshape image and calculate absorbance image = image.reshape((-1, 3)) - image = image.astype(np.float32) + 1.0 + image = image.astype(np.float32, copy=False) + 1.0 absorbance = -np.log(image.clip(max=self.tli) / self.tli) # remove transparent pixels @@ -76,7 +76,7 @@ def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray: raise ValueError("All pixels of the input image are below the absorbance threshold.") # compute eigenvectors - _, eigvecs = np.linalg.eigh(np.cov(absorbance_hat.T).astype(np.float32)) + _, eigvecs = np.linalg.eigh(np.cov(absorbance_hat.T).astype(np.float32, copy=False)) # project on the plane spanned by the eigenvectors corresponding to the two largest eigenvalues t_hat = absorbance_hat.dot(eigvecs[:, 1:3]) @@ -162,7 +162,7 @@ def __call__(self, image: np.ndarray) -> np.ndarray: Return: image_norm: stain normalized image/patch """ - # check image type and vlues + # check image type and values if not isinstance(image, np.ndarray): raise TypeError("Image must be of type numpy.ndarray.") if image.min() < 0: @@ -186,7 +186,7 @@ def __call__(self, image: np.ndarray) -> np.ndarray: conc = np.linalg.lstsq(he, y, rcond=None)[0] # normalize stain concentrations - max_conc = np.array([np.percentile(conc[0, :], 99), np.percentile(conc[1, :], 99)], dtype=np.float32) + max_conc = np.asarray([np.percentile(conc[0, :], 99), np.percentile(conc[1, :], 99)], dtype=np.float32) tmp = np.divide(max_conc, self.max_cref, dtype=np.float32) image_c = np.divide(conc, tmp[:, np.newaxis], dtype=np.float32) diff --git a/monai/apps/pathology/transforms/stain/dictionary.py b/monai/apps/pathology/transforms/stain/dictionary.py index 976af1e7c7..eb8eba43f8 100644 --- a/monai/apps/pathology/transforms/stain/dictionary.py +++ b/monai/apps/pathology/transforms/stain/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/apps/pathology/utils.py b/monai/apps/pathology/utils.py index 54d49f5717..5a57364a11 100644 --- a/monai/apps/pathology/utils.py +++ b/monai/apps/pathology/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -62,11 +62,7 @@ class PathologyProbNMS(ProbNMS): Pathology. """ - def __call__( - self, - probs_map: Union[np.ndarray, torch.Tensor], - resolution_level: int = 0, - ): + def __call__(self, probs_map: Union[np.ndarray, torch.Tensor], resolution_level: int = 0): """ probs_map: the input probabilities map, it must have shape (H[, W, ...]). resolution_level: the level at which the probabilities map is made. diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 36fac955fe..77138fc45b 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,17 +10,21 @@ # limitations under the License. import hashlib +import logging import os import shutil +import sys import tarfile import tempfile import warnings import zipfile +from pathlib import Path from typing import TYPE_CHECKING, Optional from urllib.error import ContentTooShortError, HTTPError, URLError from urllib.request import urlretrieve -from monai.utils import min_version, optional_import +from monai.config.type_definitions import PathLike +from monai.utils import look_up_option, min_version, optional_import gdown, has_gdown = optional_import("gdown", "3.6") @@ -31,18 +35,47 @@ else: tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") -__all__ = [ - "check_hash", - "download_url", - "extractall", - "download_and_extract", -] +__all__ = ["check_hash", "download_url", "extractall", "download_and_extract", "get_logger", "SUPPORTED_HASH_TYPES"] +DEFAULT_FMT = "%(asctime)s - %(levelname)s - %(message)s" +SUPPORTED_HASH_TYPES = {"md5": hashlib.md5, "sha1": hashlib.sha1, "sha256": hashlib.sha256, "sha512": hashlib.sha512} -def _basename(p): + +def get_logger( + module_name: str = "monai.apps", + fmt: str = DEFAULT_FMT, + datefmt: Optional[str] = None, + logger_handler: Optional[logging.Handler] = None, +): + """ + Get a `module_name` logger with the specified format and date format. + By default, the logger will print to `stdout` at the INFO level. + If `module_name` is `None`, return the root logger. + `fmt` and `datafmt` are passed to a `logging.Formatter` object + (https://docs.python.org/3/library/logging.html#formatter-objects). + `logger_handler` can be used to add an additional handler. + """ + logger = logging.getLogger(module_name) + logger.propagate = False + logger.setLevel(logging.INFO) + handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) + handler.setFormatter(formatter) + logger.addHandler(handler) + if logger_handler is not None: + logger.addHandler(logger_handler) + return logger + + +# apps module-level default logger +logger = get_logger("monai.apps") +__all__.append("logger") + + +def _basename(p: PathLike) -> str: """get the last part of the path (removing the trailing slash if it exists)""" sep = os.path.sep + (os.path.altsep or "") + "/ " - return os.path.basename(p.rstrip(sep)) + return Path(f"{p}".rstrip(sep)).name def _download_with_progress(url, filepath, progress: bool = True): @@ -69,66 +102,59 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None): self.total = tsize self.update(b * bsize - self.n) # will also set self.n = b * bsize - with TqdmUpTo( - unit="B", - unit_scale=True, - unit_divisor=1024, - miniters=1, - desc=_basename(filepath), - ) as t: + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=_basename(filepath)) as t: urlretrieve(url, filepath, reporthook=t.update_to) else: if not has_tqdm and progress: warnings.warn("tqdm is not installed, will not show the downloading progress bar.") urlretrieve(url, filepath) - except (URLError, HTTPError, ContentTooShortError, IOError) as e: - print(f"Download failed from {url} to {filepath}.") + except (URLError, HTTPError, ContentTooShortError, OSError) as e: + logger.error(f"Download failed from {url} to {filepath}.") raise e -def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") -> bool: +def check_hash(filepath: PathLike, val: Optional[str] = None, hash_type: str = "md5") -> bool: """ Verify hash signature of specified file. Args: filepath: path of source file to verify hash value. val: expected hash value of the file. - hash_type: 'md5' or 'sha1', defaults to 'md5'. + hash_type: type of hash algorithm to use, default is `"md5"`. + The supported hash types are `"md5"`, `"sha1"`, `"sha256"`, `"sha512"`. + See also: :py:data:`monai.apps.utils.SUPPORTED_HASH_TYPES`. """ if val is None: - print(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.") + logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.") return True - if hash_type.lower() == "md5": - actual_hash = hashlib.md5() - elif hash_type.lower() == "sha1": - actual_hash = hashlib.sha1() - else: - raise NotImplementedError(f"Unknown 'hash_type' {hash_type}.") + actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES) + actual_hash = actual_hash_func() try: with open(filepath, "rb") as f: for chunk in iter(lambda: f.read(1024 * 1024), b""): actual_hash.update(chunk) except Exception as e: - print(f"Exception in check_hash: {e}") + logger.error(f"Exception in check_hash: {e}") return False if val != actual_hash.hexdigest(): - print(f"check_hash failed {actual_hash.hexdigest()}.") + logger.error(f"check_hash failed {actual_hash.hexdigest()}.") return False - print(f"Verified '{_basename(filepath)}', {hash_type}: {val}.") + logger.info(f"Verified '{_basename(filepath)}', {hash_type}: {val}.") return True def download_url( - url: str, filepath: str = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True + url: str, filepath: PathLike = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True ) -> None: """ Download file from specified URL link, support process bar and hash check. Args: url: source URL link to download file. - filepath: target filepath to save the downloaded file. If undefined, `os.path.basename(url)` will be used. + filepath: target filepath to save the downloaded file (including the filename). + If undefined, `os.path.basename(url)` will be used. hash_val: expected hash value to validate the downloaded file. if None, skip hash validation. hash_type: 'md5' or 'sha1', defaults to 'md5'. @@ -146,33 +172,34 @@ def download_url( """ if not filepath: - filepath = os.path.abspath(os.path.join(".", _basename(url))) - print(f"Default downloading to '{filepath}'") - if os.path.exists(filepath): + filepath = Path(".", _basename(url)).resolve() + logger.info(f"Default downloading to '{filepath}'") + filepath = Path(filepath) + if filepath.exists(): if not check_hash(filepath, hash_val, hash_type): raise RuntimeError( f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}." ) - print(f"File exists: {filepath}, skipped downloading.") + logger.info(f"File exists: {filepath}, skipped downloading.") return with tempfile.TemporaryDirectory() as tmp_dir: - tmp_name = os.path.join(tmp_dir, f"{_basename(filepath)}") + tmp_name = Path(tmp_dir, _basename(filepath)) if url.startswith("https://drive.google.com"): if not has_gdown: raise RuntimeError("To download files from Google Drive, please install the gdown dependency.") - gdown.download(url, tmp_name, quiet=not progress) + gdown.download(url, f"{tmp_name}", quiet=not progress) else: _download_with_progress(url, tmp_name, progress=progress) - if not os.path.exists(tmp_name): + if not tmp_name.exists(): raise RuntimeError( f"Download of file from {url} to {filepath} failed due to network issue or denied permission." ) - file_dir = os.path.dirname(filepath) + file_dir = filepath.parent if file_dir: os.makedirs(file_dir, exist_ok=True) - shutil.move(tmp_name, filepath) # copy the downloaded to a user-specified cache. - print(f"Downloaded: {filepath}") + shutil.move(f"{tmp_name}", f"{filepath}") # copy the downloaded to a user-specified cache. + logger.info(f"Downloaded: {filepath}") if not check_hash(filepath, hash_val, hash_type): raise RuntimeError( f"{hash_type} check of downloaded file failed: URL={url}, " @@ -181,8 +208,8 @@ def download_url( def extractall( - filepath: str, - output_dir: str = ".", + filepath: PathLike, + output_dir: PathLike = ".", hash_val: Optional[str] = None, hash_type: str = "md5", file_type: str = "", @@ -211,24 +238,25 @@ def extractall( """ if has_base: # the extracted files will be in this folder - cache_dir = os.path.join(output_dir, _basename(filepath).split(".")[0]) + cache_dir = Path(output_dir, _basename(filepath).split(".")[0]) else: - cache_dir = output_dir - if os.path.exists(cache_dir) and len(os.listdir(cache_dir)) > 0: - print(f"Non-empty folder exists in {cache_dir}, skipped extracting.") + cache_dir = Path(output_dir) + if cache_dir.exists() and next(cache_dir.iterdir(), None) is not None: + logger.info(f"Non-empty folder exists in {cache_dir}, skipped extracting.") return + filepath = Path(filepath) if hash_val and not check_hash(filepath, hash_val, hash_type): raise RuntimeError( f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}." ) - print(f"Writing into directory: {output_dir}.") + logger.info(f"Writing into directory: {output_dir}.") _file_type = file_type.lower().strip() - if filepath.endswith("zip") or _file_type == "zip": + if filepath.name.endswith("zip") or _file_type == "zip": zip_file = zipfile.ZipFile(filepath) zip_file.extractall(output_dir) zip_file.close() return - if filepath.endswith("tar") or filepath.endswith("tar.gz") or "tar" in _file_type: + if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type: tar_file = tarfile.open(filepath) tar_file.extractall(output_dir) tar_file.close() @@ -240,8 +268,8 @@ def extractall( def download_and_extract( url: str, - filepath: str = "", - output_dir: str = ".", + filepath: PathLike = "", + output_dir: PathLike = ".", hash_val: Optional[str] = None, hash_type: str = "md5", file_type: str = "", @@ -268,6 +296,6 @@ def download_and_extract( progress: whether to display progress bar. """ with tempfile.TemporaryDirectory() as tmp_dir: - filename = filepath or os.path.join(tmp_dir, f"{_basename(url)}") + filename = filepath or Path(tmp_dir, _basename(url)).resolve() download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress) extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base) diff --git a/monai/config/__init__.py b/monai/config/__init__.py index c929cb2362..bf1b66fe92 100644 --- a/monai/config/__init__.py +++ b/monai/config/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,11 +12,21 @@ from .deviceconfig import ( USE_COMPILED, IgniteInfo, + get_config_values, get_gpu_info, + get_optional_config_values, get_system_info, print_config, print_debug_info, print_gpu_info, print_system_info, ) -from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayOrTensor, NdarrayTensor, TensorOrList +from .type_definitions import ( + DtypeLike, + IndexSelection, + KeysCollection, + NdarrayOrTensor, + NdarrayTensor, + PathLike, + TensorOrList, +) diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 273431fc72..91b944bde5 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -73,6 +73,8 @@ def get_optional_config_values(): output["psutil"] = psutil_version output["pandas"] = get_package_version("pandas") output["einops"] = get_package_version("einops") + output["transformers"] = get_package_version("transformers") + output["mlflow"] = get_package_version("mlflow") return output @@ -88,6 +90,7 @@ def print_config(file=sys.stdout): print(f"{k} version: {v}", file=file, flush=True) print(f"MONAI flags: HAS_EXT = {HAS_EXT}, USE_COMPILED = {USE_COMPILED}") print(f"MONAI rev id: {monai.__revision_id__}") + print(f"MONAI __file__: {monai.__file__}") print("\nOptional dependencies:", file=file, flush=True) for k, v in get_optional_config_values().items(): @@ -121,7 +124,7 @@ def get_system_info() -> OrderedDict: elif output["System"] == "Darwin": _dict_append(output, "Mac version", lambda: platform.mac_ver()[0]) else: - with open("/etc/os-release", "r") as rel_f: + with open("/etc/os-release") as rel_f: linux_ver = re.search(r'PRETTY_NAME="(.*)"', rel_f.read()) if linux_ver: _dict_append(output, "Linux version", lambda: linux_ver.group(1)) @@ -198,8 +201,7 @@ def get_gpu_info() -> OrderedDict: if num_gpus > 0: _dict_append(output, "Current device", torch.cuda.current_device) - if hasattr(torch.cuda, "get_arch_list"): # get_arch_list is new in torch 1.7.1 - _dict_append(output, "Library compiled for CUDA architectures", torch.cuda.get_arch_list) + _dict_append(output, "Library compiled for CUDA architectures", torch.cuda.get_arch_list) for gpu in range(num_gpus): gpu_info = torch.cuda.get_device_properties(gpu) diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py index 91ac74961b..686befb2eb 100644 --- a/monai/config/type_definitions.py +++ b/monai/config/type_definitions.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Collection, Hashable, Iterable, Sequence, TypeVar, Union import numpy as np @@ -29,7 +30,15 @@ # may be implemented). Consistent use of the concept and recorded documentation of # the rationale and convention behind it lowers the learning curve for new # developers. For readability, short names are preferred. -__all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor", "NdarrayOrTensor", "TensorOrList"] +__all__ = [ + "KeysCollection", + "IndexSelection", + "DtypeLike", + "NdarrayTensor", + "NdarrayOrTensor", + "TensorOrList", + "PathLike", +] #: KeysCollection @@ -51,8 +60,8 @@ # container must be iterable. IndexSelection = Union[Iterable[int], int] -#: Type of datatypes: Adapted from https://github.com/numpy/numpy/blob/master/numpy/typing/_dtype_like.py -DtypeLike = Union[np.dtype, type, None] +#: Type of datatypes: Adapted from https://github.com/numpy/numpy/blob/v1.21.4/numpy/typing/_dtype_like.py#L121 +DtypeLike = Union[np.dtype, type, str, None] #: NdarrayTensor # @@ -66,3 +75,6 @@ #: TensorOrList: The TensorOrList type is used for defining `batch-first Tensor` or `list of channel-first Tensor`. TensorOrList = Union[torch.Tensor, Sequence[torch.Tensor]] + +#: PathLike: The PathLike type is used for defining a file path. +PathLike = Union[str, os.PathLike] diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index b4bb0f2c04..a2fa8bfc56 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/bilateral/bilateral.cpp b/monai/csrc/filtering/bilateral/bilateral.cpp index 2720d312e2..183e13bb23 100644 --- a/monai/csrc/filtering/bilateral/bilateral.cpp +++ b/monai/csrc/filtering/bilateral/bilateral.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/bilateral/bilateral.h b/monai/csrc/filtering/bilateral/bilateral.h index c7a68d7457..288684666a 100644 --- a/monai/csrc/filtering/bilateral/bilateral.h +++ b/monai/csrc/filtering/bilateral/bilateral.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp index 2e6c7dbe20..51573ebbc0 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp index 847a452396..3e11f2ef57 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu index f73ae19ac9..a24e6ed092 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu index 719d1643d3..20df9419fa 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/filtering.h b/monai/csrc/filtering/filtering.h index 3dcdfc473b..3e680010ed 100644 --- a/monai/csrc/filtering/filtering.h +++ b/monai/csrc/filtering/filtering.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/permutohedral/hash_table.cuh b/monai/csrc/filtering/permutohedral/hash_table.cuh index 1acff5f276..5507034dea 100644 --- a/monai/csrc/filtering/permutohedral/hash_table.cuh +++ b/monai/csrc/filtering/permutohedral/hash_table.cuh @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp index d8fd3eaaeb..6ffe966f2c 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/permutohedral/permutohedral.h b/monai/csrc/filtering/permutohedral/permutohedral.h index 1c9d1a031e..cdc0f693d0 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral.h +++ b/monai/csrc/filtering/permutohedral/permutohedral.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp index 8c0dc8e546..bc4b2cabc7 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu index d1d78eb940..fa06590fa9 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu +++ b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/lltm/lltm.h b/monai/csrc/lltm/lltm.h index 33e17416f8..398ea92bb9 100644 --- a/monai/csrc/lltm/lltm.h +++ b/monai/csrc/lltm/lltm.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/lltm/lltm_cpu.cpp b/monai/csrc/lltm/lltm_cpu.cpp index 295c592d00..7cd2251f9f 100644 --- a/monai/csrc/lltm/lltm_cpu.cpp +++ b/monai/csrc/lltm/lltm_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/lltm/lltm_cuda.cu b/monai/csrc/lltm/lltm_cuda.cu index 4633348477..1293bec7e9 100644 --- a/monai/csrc/lltm/lltm_cuda.cu +++ b/monai/csrc/lltm/lltm_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/bounds_common.h b/monai/csrc/resample/bounds_common.h index 4997c7d968..2a7778b76b 100644 --- a/monai/csrc/resample/bounds_common.h +++ b/monai/csrc/resample/bounds_common.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/interpolation_common.h b/monai/csrc/resample/interpolation_common.h index 35899298bf..8e2edbc8b7 100644 --- a/monai/csrc/resample/interpolation_common.h +++ b/monai/csrc/resample/interpolation_common.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/pushpull.h b/monai/csrc/resample/pushpull.h index 1c20cc0114..b056bb77c2 100644 --- a/monai/csrc/resample/pushpull.h +++ b/monai/csrc/resample/pushpull.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/pushpull_cpu.cpp b/monai/csrc/resample/pushpull_cpu.cpp index dd10dd76ee..d83557c6c3 100644 --- a/monai/csrc/resample/pushpull_cpu.cpp +++ b/monai/csrc/resample/pushpull_cpu.cpp @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/resample/pushpull_cuda.cu b/monai/csrc/resample/pushpull_cuda.cu index 38d34ffe98..4a2d6c27ef 100644 --- a/monai/csrc/resample/pushpull_cuda.cu +++ b/monai/csrc/resample/pushpull_cuda.cu @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/utils/common_utils.h b/monai/csrc/utils/common_utils.h index 4d09377e65..aabea3c99f 100644 --- a/monai/csrc/utils/common_utils.h +++ b/monai/csrc/utils/common_utils.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/utils/meta_macros.h b/monai/csrc/utils/meta_macros.h index 980b253bbe..0f15b623e3 100644 --- a/monai/csrc/utils/meta_macros.h +++ b/monai/csrc/utils/meta_macros.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/utils/resample_utils.h b/monai/csrc/utils/resample_utils.h index bbdf258b4c..77df65e924 100644 --- a/monai/csrc/utils/resample_utils.h +++ b/monai/csrc/utils/resample_utils.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/csrc/utils/tensor_description.h b/monai/csrc/utils/tensor_description.h index dadd26c5f5..c604003b9d 100644 --- a/monai/csrc/utils/tensor_description.h +++ b/monai/csrc/utils/tensor_description.h @@ -1,5 +1,5 @@ /* -Copyright 2020 - 2021 MONAI Consortium +Copyright (c) MONAI Consortium Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/monai/data/__init__.py b/monai/data/__init__.py index fca170335b..ae0bb86c35 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,6 +17,7 @@ CacheNTransDataset, CSVDataset, Dataset, + DatasetFunc, LMDBDataset, NPZDictItemDataset, PersistentDataset, @@ -24,11 +25,16 @@ ZipDataset, ) from .dataset_summary import DatasetSummary -from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties +from .decathlon_datalist import ( + check_missing_files, + create_cross_validation_datalist, + load_decathlon_datalist, + load_decathlon_properties, +) from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader -from .iterable_dataset import CSVIterableDataset, IterableDataset +from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti from .png_saver import PNGSaver @@ -37,6 +43,7 @@ from .synthetic import create_test_image_2d, create_test_image_3d from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer, ThreadDataLoader +from .torchscript_utils import load_net_with_metadata, save_net_with_metadata from .utils import ( compute_importance_map, compute_shape_offset, @@ -57,7 +64,7 @@ partition_dataset_classes, pickle_hashing, rectify_header_sform_qform, - rep_scalar_to_batch, + resample_datalist, select_cross_validation_folds, set_rnd, sorted_dict, diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index 62f407bfd5..d4ca0e7576 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,11 +12,13 @@ import os import warnings from collections import OrderedDict +from pathlib import Path from typing import Dict, Optional, Union import numpy as np import torch +from monai.config.type_definitions import PathLike from monai.utils import ImageMetaKey as Key @@ -33,7 +35,7 @@ class CSVSaver: def __init__( self, - output_dir: str = "./", + output_dir: PathLike = "./", filename: str = "predictions.csv", overwrite: bool = True, flush: bool = False, @@ -48,12 +50,12 @@ def __init__( default to False. """ - self.output_dir = output_dir + self.output_dir = Path(output_dir) self._cache_dict: OrderedDict = OrderedDict() if not (isinstance(filename, str) and filename[-4:] == ".csv"): warnings.warn("CSV filename is not a string ends with '.csv'.") - self._filepath = os.path.join(output_dir, filename) - if os.path.exists(self._filepath) and overwrite: + self._filepath = self.output_dir / filename + if self._filepath.exists() and overwrite: os.remove(self._filepath) self.flush = flush @@ -64,8 +66,8 @@ def finalize(self) -> None: Writes the cached dict to a csv """ - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) + if not self.output_dir.exists(): + self.output_dir.mkdir(parents=True, exist_ok=True) with open(self._filepath, "a") as f: for k, v in self._cache_dict.items(): f.write(k) diff --git a/monai/data/dataloader.py b/monai/data/dataloader.py index 2c9174e9f4..3117a27c02 100644 --- a/monai/data/dataloader.py +++ b/monai/data/dataloader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -81,8 +81,4 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None: if "worker_init_fn" not in kwargs: kwargs.update({"worker_init_fn": worker_init_fn}) - super().__init__( # type: ignore[call-overload] - dataset=dataset, - num_workers=num_workers, - **kwargs, - ) + super().__init__(dataset=dataset, num_workers=num_workers, **kwargs) # type: ignore[call-overload] diff --git a/monai/data/dataset.py b/monai/data/dataset.py index c970e83d0d..cbb534f04a 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,12 +26,14 @@ import numpy as np import torch +from torch.serialization import DEFAULT_PROTOCOL from torch.utils.data import Dataset as _TorchDataset from torch.utils.data import Subset -from monai.data.utils import convert_tables_to_dicts, first, pickle_hashing +from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform -from monai.utils import MAX_SEED, ensure_tuple, get_seed, min_version, optional_import +from monai.utils import MAX_SEED, deprecated_arg, get_seed, look_up_option, min_version, optional_import +from monai.utils.misc import first if TYPE_CHECKING: from tqdm import tqdm @@ -95,6 +97,56 @@ def __getitem__(self, index: Union[int, slice, Sequence[int]]): return self._transform(index) +class DatasetFunc(Dataset): + """ + Execute function on the input dataset and leverage the output to act as a new Dataset. + It can be used to load / fetch the basic dataset items, like the list of `image, label` paths. + Or chain together to execute more complicated logic, like `partition_dataset`, `resample_datalist`, etc. + The `data` arg of `Dataset` will be applied to the first arg of callable `func`. + Usage example:: + + data_list = DatasetFunc( + data="path to file", + func=monai.data.load_decathlon_datalist, + data_list_key="validation", + base_dir="path to base dir", + ) + # partition dataset for every rank + data_partition = DatasetFunc( + data=data_list, + func=lambda **kwargs: monai.data.partition_dataset(**kwargs)[torch.distributed.get_rank()], + num_partitions=torch.distributed.get_world_size(), + ) + dataset = Dataset(data=data_partition, transform=transforms) + + Args: + data: input data for the func to process, will apply to `func` as the first arg. + func: callable function to generate dataset items. + kwargs: other arguments for the `func` except for the first arg. + + """ + + def __init__(self, data: Any, func: Callable, **kwargs) -> None: + super().__init__(data=None, transform=None) # type:ignore + self.src = data + self.func = func + self.kwargs = kwargs + self.reset() + + def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **kwargs): + """ + Reset the dataset items with specified `func`. + + Args: + data: if not None, execute `func` on it, default to `self.src`. + func: if not None, execute the `func` with specified `kwargs`, default to `self.func`. + kwargs: other arguments for the `func` except for the first arg. + + """ + src = self.src if data is None else data + self.data = self.func(src, **self.kwargs) if func is None else func(src, **kwargs) + + class PersistentDataset(Dataset): """ Persistent storage of pre-computed values to efficiently manage larger than memory dictionary format data, @@ -151,6 +203,8 @@ def __init__( transform: Union[Sequence[Callable], Callable], cache_dir: Optional[Union[Path, str]], hash_func: Callable[..., bytes] = pickle_hashing, + pickle_module: str = "pickle", + pickle_protocol: int = DEFAULT_PROTOCOL, ) -> None: """ Args: @@ -166,6 +220,18 @@ def __init__( If `cache_dir` is `None`, there is effectively no caching. hash_func: a callable to compute hash from data items to be cached. defaults to `monai.data.utils.pickle_hashing`. + pickle_module: string representing the module used for pickling metadata and objects, + default to `"pickle"`. due to the pickle limitation in multi-processing of Dataloader, + we can't use `pickle` as arg directly, so here we use a string name instead. + if want to use other pickle module at runtime, just register like: + >>> from monai.data import utils + >>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, + and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. + pickle_protocol: can be specified to override the default protocol, default to `2`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. """ if not isinstance(transform, Compose): @@ -173,6 +239,8 @@ def __init__( super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None self.hash_func = hash_func + self.pickle_module = pickle_module + self.pickle_protocol = pickle_protocol if self.cache_dir is not None: if not self.cache_dir.exists(): self.cache_dir.mkdir(parents=True, exist_ok=True) @@ -273,7 +341,12 @@ def _cachecheck(self, item_transformed): # which may leave partially written cache files in an incomplete state with tempfile.TemporaryDirectory() as tmpdirname: temp_hash_file = Path(tmpdirname) / hashfile.name - torch.save(_item_transformed, temp_hash_file) + torch.save( + obj=_item_transformed, + f=temp_hash_file, + pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), + pickle_protocol=self.pickle_protocol, + ) if temp_hash_file.is_file() and not hashfile.is_file(): # On Unix, if target exists and is a file, it will be replaced silently if the user has permission. # for more details: https://docs.python.org/3/library/shutil.html#shutil.move. @@ -301,6 +374,8 @@ def __init__( cache_n_trans: int, cache_dir: Optional[Union[Path, str]], hash_func: Callable[..., bytes] = pickle_hashing, + pickle_module: str = "pickle", + pickle_protocol: int = DEFAULT_PROTOCOL, ) -> None: """ Args: @@ -317,9 +392,28 @@ def __init__( If `cache_dir` is `None`, there is effectively no caching. hash_func: a callable to compute hash from data items to be cached. defaults to `monai.data.utils.pickle_hashing`. - - """ - super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func) + pickle_module: string representing the module used for pickling metadata and objects, + default to `"pickle"`. due to the pickle limitation in multi-processing of Dataloader, + we can't use `pickle` as arg directly, so here we use a string name instead. + if want to use other pickle module at runtime, just register like: + >>> from monai.data import utils + >>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, + and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. + pickle_protocol: can be specified to override the default protocol, default to `2`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. + + """ + super().__init__( + data=data, + transform=transform, + cache_dir=cache_dir, + hash_func=hash_func, + pickle_module=pickle_module, + pickle_protocol=pickle_protocol, + ) self.cache_n_trans = cache_n_trans def _pre_transform(self, item_transformed): @@ -406,12 +500,13 @@ def __init__( lmdb_kwargs: additional keyword arguments to the lmdb environment. for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class """ - super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func) + super().__init__( + data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, pickle_protocol=pickle_protocol + ) self.progress = progress if not self.cache_dir: raise ValueError("cache_dir must be specified.") self.db_file = self.cache_dir / f"{db_name}.lmdb" - self.pickle_protocol = pickle_protocol self.lmdb_kwargs = lmdb_kwargs or {} if not self.lmdb_kwargs.get("map_size", 0): self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size @@ -575,6 +670,7 @@ def __init__( cache_rate: float = 1.0, num_workers: Optional[int] = None, progress: bool = True, + copy_cache: bool = True, ) -> None: """ Args: @@ -587,11 +683,17 @@ def __init__( num_workers: the number of worker processes to use. If num_workers is None then the number returned by os.cpu_count() is used. progress: whether to display a progress bar. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. """ if not isinstance(transform, Compose): transform = Compose(transform) super().__init__(data=data, transform=transform) self.progress = progress + self.copy_cache = copy_cache self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data)) self.num_workers = num_workers if self.num_workers is not None: @@ -656,7 +758,8 @@ def _transform(self, index: int): # only need to deep copy data on first non-deterministic transform if not start_run: start_run = True - data = deepcopy(data) + if self.copy_cache: + data = deepcopy(data) data = apply_transform(_transform, data) return data @@ -722,6 +825,10 @@ class SmartCacheDataset(Randomizable, CacheDataset): shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch. it will not modify the original input data sequence in-place. seed: random seed if shuffle is `True`, default to `0`. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cache content + or every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. """ def __init__( @@ -736,6 +843,7 @@ def __init__( progress: bool = True, shuffle: bool = True, seed: int = 0, + copy_cache: bool = True, ) -> None: if shuffle: self.set_random_state(seed=seed) @@ -743,7 +851,7 @@ def __init__( self.randomize(data) self.shuffle = shuffle - super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress) + super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache) if self._cache is None: self._cache = self._fill_cache() if self.cache_num >= len(data): @@ -977,7 +1085,7 @@ def __init__(self, datasets: Sequence, transform: Optional[Callable] = None) -> super().__init__(list(datasets), transform=transform) def __len__(self) -> int: - return min((len(dataset) for dataset in self.data)) + return min(len(dataset) for dataset in self.data) def _transform(self, index: int): def to_list(x): @@ -1164,8 +1272,9 @@ class CSVDataset(Dataset): ] Args: - filename: the filename of expected CSV file to load. if providing a list - of filenames, it will load all the files and join tables. + src: if provided the filename of CSV file, it can be a str, URL, path object or file-like object to load. + also support to provide pandas `DataFrame` directly, will skip loading from filename. + if provided a list of filenames or pandas `DataFrame`, it will join the tables. row_indices: indices of the expected rows to load. it should be a list, every item can be a int number or a range `[start, end)` for the indices. for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None, @@ -1191,11 +1300,15 @@ class CSVDataset(Dataset): transform: transform to apply on the loaded items of a dictionary data. kwargs: additional arguments for `pandas.merge()` API to join tables. + .. deprecated:: 0.8.0 + ``filename`` is deprecated, use ``src`` instead. + """ + @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.") def __init__( self, - filename: Union[str, Sequence[str]], + src: Optional[Union[str, Sequence[str]]] = None, # also can be `DataFrame` or sequense of `DataFrame` row_indices: Optional[Sequence[Union[int, str]]] = None, col_names: Optional[Sequence[str]] = None, col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None, @@ -1203,14 +1316,20 @@ def __init__( transform: Optional[Callable] = None, **kwargs, ): - files = ensure_tuple(filename) - dfs = [pd.read_csv(f) for f in files] + srcs = (src,) if not isinstance(src, (tuple, list)) else src + dfs: List = [] + for i in srcs: + if isinstance(i, str): + dfs.append(pd.read_csv(i)) + elif isinstance(i, pd.DataFrame): + dfs.append(i) + else: + raise ValueError("`src` must be file path or pandas `DataFrame`.") + + # in case treating deprecated arg `filename` as kwargs, remove it from `kwargs` + kwargs.pop("filename", None) + data = convert_tables_to_dicts( - dfs=dfs, - row_indices=row_indices, - col_names=col_names, - col_types=col_types, - col_groups=col_groups, - **kwargs, + dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs ) super().__init__(data=data, transform=transform) diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py index a8598eb6c8..dd8a94143b 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,8 +15,11 @@ import numpy as np import torch +from monai.config import KeysCollection from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset +from monai.transforms import concatenate +from monai.utils import convert_data_type class DatasetSummary: @@ -38,6 +41,7 @@ def __init__( dataset: Dataset, image_key: Optional[str] = "image", label_key: Optional[str] = "label", + meta_key: Optional[KeysCollection] = None, meta_key_postfix: str = "meta_dict", num_workers: int = 0, **kwargs, @@ -47,11 +51,16 @@ def __init__( dataset: dataset from which to load the data. image_key: key name of images (default: ``image``). label_key: key name of labels (default: ``label``). + meta_key: explicitly indicate the key of the corresponding meta data dictionary. + for example, for data with key `image`, the metadata by default is in `image_meta_dict`. + the meta data is a dictionary object which contains: filename, affine, original_shape, etc. + if None, will try to construct meta_keys by `{image_key}_{meta_key_postfix}`. meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from dict, the meta data is a dictionary object (default: ``meta_dict``). num_workers: how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process (default: ``0``). - kwargs: other parameters (except batch_size) for DataLoader (this class forces to use ``batch_size=1``). + kwargs: other parameters (except `batch_size` and `num_workers`) for DataLoader, + this class forces to use ``batch_size=1``. """ @@ -59,18 +68,17 @@ def __init__( self.image_key = image_key self.label_key = label_key - if image_key: - self.meta_key = "{}_{}".format(image_key, meta_key_postfix) + self.meta_key = meta_key or f"{image_key}_{meta_key_postfix}" self.all_meta_data: List = [] def collect_meta_data(self): """ This function is used to collect the meta data for all images of the dataset. """ - if not self.meta_key: - raise ValueError("To collect meta data for the dataset, `meta_key` should exist.") for data in self.data_loader: + if self.meta_key not in data: + raise ValueError(f"To collect meta data for the dataset, key `{self.meta_key}` must exist in `data`.") self.all_meta_data.append(data[self.meta_key]) def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0): @@ -78,8 +86,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: Calculate the target spacing according to all spacings. If the target spacing is very anisotropic, decrease the spacing value of the maximum axis according to percentile. - So far, this function only supports NIFTI images which store spacings in headers with key "pixdim". After loading - with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`. + So far, this function only supports NIFTI images which store spacings in headers with key "pixdim". + After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`. Args: spacing_key: key of spacing in meta data (default: ``pixdim``). @@ -92,8 +100,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: self.collect_meta_data() if spacing_key not in self.all_meta_data[0]: raise ValueError("The provided spacing_key is not in self.all_meta_data.") - - all_spacings = torch.cat([data[spacing_key][:, 1:4] for data in self.all_meta_data], dim=0).numpy() + all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0) + all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True) target_spacing = np.median(all_spacings, axis=0) if max(target_spacing) / min(target_spacing) >= anisotropic_threshold: @@ -126,6 +134,8 @@ def calculate_statistics(self, foreground_threshold: int = 0): image, label = data[self.image_key], data[self.label_key] else: image, label = data + image, *_ = convert_data_type(data=image, output_type=torch.Tensor) + label, *_ = convert_data_type(data=label, output_type=torch.Tensor) voxel_max.append(image.max().item()) voxel_min.append(image.min().item()) @@ -169,6 +179,8 @@ def calculate_percentiles( image, label = data[self.image_key], data[self.label_key] else: image, label = data + image, *_ = convert_data_type(data=image, output_type=torch.Tensor) + label, *_ = convert_data_type(data=label, output_type=torch.Tensor) intensities = image[torch.where(label > foreground_threshold)].tolist() if sampling_flag: diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 663b68a08e..d2a9c3d220 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,18 +11,22 @@ import json import os +import warnings +from pathlib import Path from typing import Dict, List, Optional, Sequence, Union, overload +from monai.config import KeysCollection, PathLike +from monai.data.utils import partition_dataset, select_cross_validation_folds from monai.utils import ensure_tuple @overload -def _compute_path(base_dir: str, element: str, check_path: bool = False) -> str: +def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ... @overload -def _compute_path(base_dir: str, element: List[str], check_path: bool = False) -> List[str]: +def _compute_path(base_dir: PathLike, element: List[PathLike], check_path: bool = False) -> List[str]: ... @@ -39,24 +43,24 @@ def _compute_path(base_dir, element, check_path=False): """ - def _join_path(base_dir: str, item: str): + def _join_path(base_dir: PathLike, item: PathLike): result = os.path.normpath(os.path.join(base_dir, item)) if check_path and not os.path.exists(result): # if not an existing path, don't join with base dir - return item - return result + return f"{item}" + return f"{result}" - if isinstance(element, str): + if isinstance(element, (str, os.PathLike)): return _join_path(base_dir, element) if isinstance(element, list): for e in element: - if not isinstance(e, str): + if not isinstance(e, (str, os.PathLike)): return element return [_join_path(base_dir, e) for e in element] return element -def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> List[Dict]: +def _append_paths(base_dir: PathLike, is_segmentation: bool, items: List[Dict]) -> List[Dict]: """ Args: base_dir: the base directory of the dataset. @@ -80,10 +84,10 @@ def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> Li def load_decathlon_datalist( - data_list_file_path: str, + data_list_file_path: PathLike, is_segmentation: bool = True, data_list_key: str = "training", - base_dir: Optional[str] = None, + base_dir: Optional[PathLike] = None, ) -> List[Dict]: """Load image/label paths of decathlon challenge from JSON file @@ -110,7 +114,8 @@ def load_decathlon_datalist( ] """ - if not os.path.isfile(data_list_file_path): + data_list_file_path = Path(data_list_file_path) + if not data_list_file_path.is_file(): raise ValueError(f"Data list file {data_list_file_path} does not exist.") with open(data_list_file_path) as json_file: json_data = json.load(json_file) @@ -121,15 +126,12 @@ def load_decathlon_datalist( expected_data = [{"image": i} for i in expected_data] if base_dir is None: - base_dir = os.path.dirname(data_list_file_path) + base_dir = data_list_file_path.parent return _append_paths(base_dir, is_segmentation, expected_data) -def load_decathlon_properties( - data_property_file_path: str, - property_keys: Union[Sequence[str], str], -) -> Dict: +def load_decathlon_properties(data_property_file_path: PathLike, property_keys: Union[Sequence[str], str]) -> Dict: """Load the properties from the JSON file contains data property with specified `property_keys`. Args: @@ -140,7 +142,8 @@ def load_decathlon_properties( `modality`, `labels`, `numTraining`, `numTest`, etc. """ - if not os.path.isfile(data_property_file_path): + data_property_file_path = Path(data_property_file_path) + if not data_property_file_path.is_file(): raise ValueError(f"Data property file {data_property_file_path} does not exist.") with open(data_property_file_path) as json_file: json_data = json.load(json_file) @@ -151,3 +154,97 @@ def load_decathlon_properties( raise KeyError(f"key {key} is not in the data property file.") properties[key] = json_data[key] return properties + + +def check_missing_files( + datalist: List[Dict], keys: KeysCollection, root_dir: Optional[PathLike] = None, allow_missing_keys: bool = False +): + """Checks whether some files in the Decathlon datalist are missing. + It would be helpful to check missing files before a heavy training run. + + Args: + datalist: a list of data items, every item is a dictionary. + usually generated by `load_decathlon_datalist` API. + keys: expected keys to check in the datalist. + root_dir: if not None, provides the root dir for the relative file paths in `datalist`. + allow_missing_keys: whether allow missing keys in the datalist items. + if False, raise exception if missing. default to False. + + Returns: + A list of missing filenames. + + """ + missing_files = [] + for item in datalist: + for k in ensure_tuple(keys): + if k not in item: + if not allow_missing_keys: + raise ValueError(f"key `{k}` is missing in the datalist item: {item}") + continue + + for f in ensure_tuple(item[k]): + if not isinstance(f, (str, os.PathLike)): + raise ValueError(f"filepath of key `{k}` must be a string or a list of strings, but got: {f}.") + f = Path(f) + if isinstance(root_dir, (str, os.PathLike)): + f = Path(root_dir).joinpath(f) + if not f.exists(): + missing_files.append(f) + + return missing_files + + +def create_cross_validation_datalist( + datalist: List[Dict], + nfolds: int, + train_folds: Union[Sequence[int], int], + val_folds: Union[Sequence[int], int], + train_key: str = "training", + val_key: str = "validation", + filename: Optional[Union[Path, str]] = None, + shuffle: bool = True, + seed: int = 0, + check_missing: bool = False, + keys: Optional[KeysCollection] = None, + root_dir: Optional[str] = None, + allow_missing_keys: bool = False, + raise_error: bool = True, +): + """ + Utility to create new Decathlon style datalist based on cross validation partition. + + Args: + datalist: loaded list of dictionaries for all the items to partition. + nfolds: number of the kfold split. + train_folds: indices of folds for training part. + val_folds: indices of folds for validation part. + train_key: the key of train part in the new datalist, defaults to "training". + val_key: the key of validation part in the new datalist, defaults to "validation". + filename: if not None and ends with ".json", save the new datalist into JSON file. + shuffle: whether to shuffle the datalist before partition, defaults to `True`. + seed: if `shuffle` is True, set the random seed, defaults to `0`. + check_missing: whether to check all the files specified by `keys` are existing. + keys: if not None and check_missing_files is True, the expected keys to check in the datalist. + root_dir: if not None, provides the root dir for the relative file paths in `datalist`. + allow_missing_keys: if check_missing_files is `True`, whether allow missing keys in the datalist items. + if False, raise exception if missing. default to False. + raise_error: when found missing files, if `True`, raise exception and stop, if `False`, print warning. + + """ + if check_missing and keys is not None: + files = check_missing_files(datalist, keys, root_dir, allow_missing_keys) + if files: + msg = f"some files of the datalist are missing: {files}" + if raise_error: + raise ValueError(msg) + warnings.warn(msg) + + data = partition_dataset(data=datalist, num_partitions=nfolds, shuffle=shuffle, seed=seed) + train_list = select_cross_validation_folds(partitions=data, folds=train_folds) + val_list = select_cross_validation_folds(partitions=data, folds=val_folds) + ret = {train_key: train_list, val_key: val_list} + if isinstance(filename, (str, Path)): + with open(filename, "w") as f: + json.dump(ret, f, indent=4) + + return ret diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 5b2a4d7abd..7a0f79d00e 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -141,7 +141,7 @@ def __iter__(self): try: iter_end = len(self.dataset) # TODO: support iterable self.dataset except TypeError: - raise NotImplementedError("image dataset must implement `len()`.") + raise NotImplementedError("image dataset must implement `len()`.") from None if worker_info is not None: # split workload diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index 874b9dc004..0ab71cd444 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index cd1486d6d3..501518dc04 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,37 +9,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import warnings from abc import ABC, abstractmethod +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern -from monai.config import DtypeLike, KeysCollection +from monai.config import DtypeLike, KeysCollection, PathLike from monai.data.utils import correct_nifti_header_if_necessary from monai.transforms.utility.array import EnsureChannelFirst -from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import +from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg from .utils import is_supported_format if TYPE_CHECKING: - import cucim import itk # type: ignore import nibabel as nib - import openslide from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage - has_itk = has_nib = has_pil = has_cim = has_osl = True + has_itk = has_nib = has_pil = True else: itk, has_itk = optional_import("itk", allow_namespace_pkg=True) nib, has_nib = optional_import("nibabel") Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") PILImage, has_pil = optional_import("PIL.Image") - cucim, has_cim = optional_import("cucim") - openslide, has_osl = optional_import("openslide") + +OpenSlide, _ = optional_import("openslide", name="OpenSlide") +CuImage, _ = optional_import("cucim", name="CuImage") +TiffFile, _ = optional_import("tifffile", name="TiffFile") __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"] @@ -64,7 +64,7 @@ class ImageReader(ABC): """ @abstractmethod - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified `filename` is supported by the current reader. This method should return True if the reader is able to read the format suggested by the @@ -78,7 +78,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def read(self, data: Union[Sequence[str], str], **kwargs) -> Union[Sequence[Any], Any]: + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Sequence[Any], Any]: """ Read image data from specified file or files. Note that it returns a data object or a sequence of data objects. @@ -136,6 +136,7 @@ def _stack_images(image_list: List, meta_dict: Dict): return np.stack(image_list, axis=0) +@require_pkg(pkg_name="itk") class ITKReader(ImageReader): """ Load medical images based on ITK library. @@ -154,18 +155,25 @@ class ITKReader(ImageReader): series_name: the name of the DICOM series if there are multiple ones. used when loading DICOM series. + reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. + If ``False``, the spatial indexing follows the numpy convention; + otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``. + This option does not affect the metadata. kwargs: additional args for `itk.imread` API. more details about available args: https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py """ - def __init__(self, channel_dim: Optional[int] = None, series_name: str = "", **kwargs): + def __init__( + self, channel_dim: Optional[int] = None, series_name: str = "", reverse_indexing: bool = False, **kwargs + ): super().__init__() self.kwargs = kwargs self.channel_dim = channel_dim self.series_name = series_name + self.reverse_indexing = reverse_indexing - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by ITK reader. @@ -176,7 +184,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ return has_itk - def read(self, data: Union[Sequence[str], str], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): """ Read image data from specified file or files, it can read a list of `no-channel` images and stack them together as multi-channels data in `get_data()`. @@ -192,11 +200,12 @@ def read(self, data: Union[Sequence[str], str], **kwargs): """ img_ = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - if os.path.isdir(name): + name = f"{name}" + if Path(name).is_dir(): # read DICOM series # https://itk.org/ITKExamples/src/IO/GDCM/ReadDICOMSeriesAndWrite3DImage names_generator = itk.GDCMSeriesFileNames.New() @@ -306,39 +315,50 @@ def _get_array_data(self, img): Following PyTorch conventions, the returned array data has contiguous channels, e.g. for an RGB image, all red channel image pixels are contiguous in memory. - The first axis of the returned array is the channel axis. + The last axis of the returned array is the channel axis. + + See also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Modules/Bridge/NumPy/wrapping/PyBuffer.i.in Args: img: an ITK image object loaded from an image file. """ - channels = img.GetNumberOfComponentsPerPixel() - np_data = itk.array_view_from_image(img).T - if channels == 1: - return np_data - if channels != np_data.shape[0]: - warnings.warn("itk_img.GetNumberOfComponentsPerPixel != numpy data channels") - return np.moveaxis(np_data, 0, -1) # channel last is compatible with `write_nifti` + np_img = itk.array_view_from_image(img, keep_axes=False) + if img.GetNumberOfComponentsPerPixel() == 1: # handling spatial images + return np_img if self.reverse_indexing else np_img.T + # handling multi-channel images + return np_img if self.reverse_indexing else np.moveaxis(np_img.T, 0, -1) +@require_pkg(pkg_name="nibabel") class NibabelReader(ImageReader): """ Load NIfTI format images based on Nibabel library. Args: as_closest_canonical: if True, load the image as closest to canonical axis format. + squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) kwargs: additional args for `nibabel.load` API. more details about available args: https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py """ - def __init__(self, as_closest_canonical: bool = False, dtype: DtypeLike = np.float32, **kwargs): + def __init__( + self, + as_closest_canonical: bool = False, + squeeze_non_spatial_dims: bool = False, + dtype: DtypeLike = np.float32, + **kwargs, + ): super().__init__() self.as_closest_canonical = as_closest_canonical + self.squeeze_non_spatial_dims = squeeze_non_spatial_dims self.dtype = dtype self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by Nibabel reader. @@ -350,7 +370,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: suffixes: Sequence[str] = ["nii", "nii.gz"] return has_nib and is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[str], str], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): """ Read image data from specified file or files, it can read a list of `no-channel` images and stack them together as multi-channels data in `get_data()`. @@ -365,7 +385,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): """ img_: List[Nifti1Image] = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -399,6 +419,10 @@ def get_data(self, img): header["affine"] = self._get_affine(i) header["spatial_shape"] = self._get_spatial_shape(i) data = self._get_array_data(i) + if self.squeeze_non_spatial_dims: + for d in range(len(data.shape), len(header["spatial_shape"]), -1): + if data.shape[d - 1] == 1: + data = data.squeeze(axis=d - 1) img_array.append(data) header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) @@ -475,19 +499,21 @@ class NumpyReader(ImageReader): Args: npz_keys: if loading npz file, only load the specified keys, if None, load all the items. stack the loaded items together to construct a new first dimension. + channel_dim: if not None, explicitly specify the channel dim, otherwise, treat the array as no channel. kwargs: additional args for `numpy.load` API except `allow_pickle`. more details about available args: https://numpy.org/doc/stable/reference/generated/numpy.load.html """ - def __init__(self, npz_keys: Optional[KeysCollection] = None, **kwargs): + def __init__(self, npz_keys: Optional[KeysCollection] = None, channel_dim: Optional[int] = None, **kwargs): super().__init__() if npz_keys is not None: npz_keys = ensure_tuple(npz_keys) self.npz_keys = npz_keys + self.channel_dim = channel_dim self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by Numpy reader. @@ -498,7 +524,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: suffixes: Sequence[str] = ["npz", "npy"] return is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[str], str], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): """ Read image data from specified file or files, it can read a list of `no-channel` data files and stack them together as multi-channels data in `get_data()`. @@ -513,12 +539,12 @@ def read(self, data: Union[Sequence[str], str], **kwargs): """ img_: List[Nifti1Image] = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: img = np.load(name, allow_pickle=True, **kwargs_) - if name.endswith(".npz"): + if Path(name).name.endswith(".npz"): # load expected items from NPZ file npz_keys = [f"arr_{i}" for i in range(len(img))] if self.npz_keys is None else self.npz_keys for k in npz_keys: @@ -548,14 +574,19 @@ def get_data(self, img): for i in ensure_tuple(img): header = {} if isinstance(i, np.ndarray): - # can not detect the channel dim of numpy array, use all the dims as spatial_shape - header["spatial_shape"] = i.shape + # if `channel_dim` is None, can not detect the channel dim, use all the dims as spatial_shape + spatial_shape = np.asarray(i.shape) + if isinstance(self.channel_dim, int): + spatial_shape = np.delete(spatial_shape, self.channel_dim) + header["spatial_shape"] = spatial_shape img_array.append(i) + header["original_channel_dim"] = self.channel_dim if isinstance(self.channel_dim, int) else "no_channel" _copy_compatible_dict(header, compatible_meta) return _stack_images(img_array, compatible_meta), compatible_meta +@require_pkg(pkg_name="PIL") class PILReader(ImageReader): """ Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path. @@ -572,7 +603,7 @@ def __init__(self, converter: Optional[Callable] = None, **kwargs): self.converter = converter self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by PIL reader. @@ -583,7 +614,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"] return has_pil and is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ Read image data from specified file or files, it can read a list of `no-channel` images and stack them together as multi-channels data in `get_data()`. @@ -598,7 +629,7 @@ def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): """ img_: List[PILImage.Image] = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -616,6 +647,8 @@ def get_data(self, img): It computes `spatial_shape` and stores it in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the meta data of the first image is used to represent the output meta data. + Note that it will switch axis 0 and 1 after loading the array because the `HW` definition in PIL + is different from other common medical packages. Args: img: a PIL Image object loaded from a file or a list of PIL Image objects. @@ -641,12 +674,7 @@ def _get_meta_dict(self, img) -> Dict: img: a PIL Image object loaded from an image file. """ - return { - "format": img.format, - "mode": img.mode, - "width": img.width, - "height": img.height, - } + return {"format": img.format, "mode": img.mode, "width": img.width, "height": img.height} def _get_spatial_shape(self, img): """ @@ -659,26 +687,38 @@ def _get_spatial_shape(self, img): class WSIReader(ImageReader): """ - Read whole slide imaging and extract patches. + Read whole slide images and extract patches. Args: - reader_lib: backend library to load the images, available options: "OpenSlide" or "cuCIM". - + backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "TiffFile". + level: the whole slide image level at which the image is extracted. (default=0) + This is overridden if the level argument is provided in `get_data`. + + Note: + While "cuCIM" and "OpenSlide" backends both can load patches from large whole slide images + without loading the entire image into memory, "TiffFile" backend needs to load the entire image into memory + before extracting any patch; thus, memory consideration is needed when using "TiffFile" backend for + patch extraction. """ - def __init__(self, reader_lib: str = "OpenSlide"): + def __init__(self, backend: str = "OpenSlide", level: int = 0): super().__init__() - self.reader_lib = reader_lib.lower() - if self.reader_lib == "openslide": - if has_osl: - self.wsi_reader = openslide.OpenSlide - elif self.reader_lib == "cucim": - if has_cim: - self.wsi_reader = cucim.CuImage - else: - raise ValueError('`reader_lib` should be either "cuCIM" or "OpenSlide"') - - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + self.backend = backend.lower() + func = require_pkg(self.backend)(self._set_reader) + self.wsi_reader = func(self.backend) + self.level = level + + @staticmethod + def _set_reader(backend: str): + if backend == "openslide": + return OpenSlide + if backend == "cucim": + return CuImage + if backend == "tifffile": + return TiffFile + raise ValueError("`backend` should be 'cuCIM', 'OpenSlide' or 'TiffFile'.") + + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -688,26 +728,23 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ return is_supported_format(filename, ["tif", "tiff"]) - def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ - Read image data from specified file or files. - Note that the returned object is CuImage or list of CuImage objects. + Read image data from given file or list of files. Args: data: file name or a list of file names to read. - """ - if (self.reader_lib == "openslide") and (not has_osl): - raise ImportError("No module named 'openslide'") - if (self.reader_lib == "cucim") and (not has_cim): - raise ImportError("No module named 'cucim'") + Returns: + image object or list of image objects + """ img_: List = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) for name in filenames: img = self.wsi_reader(name) - if self.reader_lib == "openslide": + if self.backend == "openslide": img.shape = (img.dimensions[1], img.dimensions[0], 3) img_.append(img) @@ -718,7 +755,7 @@ def get_data( img, location: Tuple[int, int] = (0, 0), size: Optional[Tuple[int, int]] = None, - level: int = 0, + level: Optional[int] = None, dtype: DtypeLike = np.uint8, grid_shape: Tuple[int, int] = (1, 1), patch_size: Optional[Union[int, Tuple[int, int]]] = None, @@ -737,33 +774,71 @@ def get_data( grid_shape: (row, columns) tuple define a grid to extract patches on that patch_size: (height, width) the size of extracted patches at the given level """ + # Verify inputs + if level is None: + level = self._check_level(img, level) - if self.reader_lib == "openslide" and size is None: - # the maximum size is set to WxH - size = ( - img.shape[0] // (2 ** level) - location[0], - img.shape[1] // (2 ** level) - location[1], - ) - + # Extract a region or the entire image region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype) + # Add necessary metadata metadata: Dict = {} - metadata["spatial_shape"] = size + metadata["spatial_shape"] = np.asarray(region.shape[:-1]) metadata["original_channel_dim"] = -1 + + # Make it channel first region = EnsureChannelFirst()(region, metadata) + + # Split into patches if patch_size is None: patches = region else: tuple_patch_size = ensure_tuple_rep(patch_size, 2) patches = self._extract_patches( - region, - patch_size=tuple_patch_size, # type: ignore - grid_shape=grid_shape, - dtype=dtype, + region, patch_size=tuple_patch_size, grid_shape=grid_shape, dtype=dtype # type: ignore ) return patches, metadata + def _check_level(self, img, level): + level = self.level + + level_count = 0 + if self.backend == "openslide": + level_count = img.level_count + elif self.backend == "cucim": + level_count = img.resolutions["level_count"] + elif self.backend == "tifffile": + level_count = len(img.pages) + + if level > level_count - 1: + raise ValueError(f"The maximum level of this image is {level_count - 1} while level={level} is requested)!") + + return level + + def _get_image_size(self, img, size, level, location): + """ + Calculate the maximum region size for the given level and starting location (if size is None). + Note that region size in OpenSlide and cuCIM are WxH (but the final image output would be HxW) + """ + if size is not None: + return size[::-1] + + max_size = [] + downsampling_factor = [] + if self.backend == "openslide": + downsampling_factor = img.level_downsamples[level] + max_size = img.level_dimensions[level] + elif self.backend == "cucim": + downsampling_factor = img.resolutions["level_downsamples"][level] + max_size = img.resolutions["level_dimensions"][level] + + # subtract the top left corner of the patch (at given level) from maximum size + location_at_level = (round(location[1] / downsampling_factor), round(location[0] / downsampling_factor)) + size = [max_size[i] - location_at_level[i] for i in range(len(max_size))] + + return size + def _extract_region( self, img_obj, @@ -772,35 +847,42 @@ def _extract_region( level: int = 0, dtype: DtypeLike = np.uint8, ): - # reverse the order of dimensions for size and location to be compatible with image shape - location = location[::-1] - if size is None: - region = img_obj.read_region(location=location, level=level) + if self.backend == "tifffile": + # Read the entire image + if size is not None: + raise ValueError( + f"TiffFile backend reads the entire image only, so size '{size}'' should not be provided!", + "For more flexibility or extracting regions, please use cuCIM or OpenSlide backend.", + ) + if location != (0, 0): + raise ValueError( + f"TiffFile backend reads the entire image only, so location '{location}' should not be provided!", + "For more flexibility and extracting regions, please use cuCIM or OpenSlide backend.", + ) + region = img_obj.asarray(level=level) else: - size = size[::-1] - region = img_obj.read_region(location=location, size=size, level=level) + # Get region size to be extracted + region_size = self._get_image_size(img_obj, size, level, location) + # reverse the order of location's dimensions to become WxH (for cuCIM and OpenSlide) + region_location = location[::-1] + # Extract a region (or the entire image) + region = img_obj.read_region(location=region_location, size=region_size, level=level) region = self.convert_to_rgb_array(region, dtype) return region - def convert_to_rgb_array( - self, - raw_region, - dtype: DtypeLike = np.uint8, - ): + def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8): """Convert to RGB mode and numpy array""" - if self.reader_lib == "openslide": + if self.backend == "openslide": # convert to RGB raw_region = raw_region.convert("RGB") - # convert to numpy - raw_region = np.asarray(raw_region, dtype=dtype) - else: - num_channels = len(raw_region.channel_names) - # convert to numpy - raw_region = np.asarray(raw_region, dtype=dtype) - # remove alpha channel if exist (RGBA) - if num_channels > 3: - raw_region = raw_region[:, :, :3] + + # convert to numpy (if not already in numpy) + raw_region = np.asarray(raw_region, dtype=dtype) + + # remove alpha channel if exist (RGBA) + if raw_region.shape[-1] > 3: + raw_region = raw_region[..., :3] return raw_region diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index c4fc252586..19efc925fc 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,15 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union +import numpy as np from torch.utils.data import IterableDataset as _TorchIterableDataset from torch.utils.data import get_worker_info from monai.data.utils import convert_tables_to_dicts from monai.transforms import apply_transform -from monai.utils import ensure_tuple, optional_import +from monai.transforms.transform import Randomizable +from monai.utils import deprecated_arg, optional_import pd, _ = optional_import("pandas") @@ -25,11 +26,15 @@ class IterableDataset(_TorchIterableDataset): """ A generic dataset for iterable data source and an optional callable data transform - when fetching a data sample. + when fetching a data sample. Inherit from PyTorch IterableDataset: + https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset. For example, typical input data can be web data stream which can support multi-process access. - Note that when used with `DataLoader` and `num_workers > 0`, each worker process will have a - different copy of the dataset object, need to guarantee process-safe from data source or DataLoader. + To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers, + every process executes transforms on part of every loaded data. + Note that the order of output data may not match data source in multi-processing mode. + And each worker process will have a different copy of the dataset object, need to guarantee + process-safe from data source or DataLoader. """ @@ -44,19 +49,91 @@ def __init__(self, data: Iterable, transform: Optional[Callable] = None) -> None self.source = None def __iter__(self): + info = get_worker_info() + num_workers = info.num_workers if info is not None else 1 + id = info.id if info is not None else 0 + self.source = iter(self.data) - for data in self.source: - if self.transform is not None: - data = apply_transform(self.transform, data) - yield data + for i, item in enumerate(self.source): + if i % num_workers == id: + if self.transform is not None: + item = apply_transform(self.transform, item) + yield item + + +class ShuffleBuffer(Randomizable, IterableDataset): + """ + Extend the IterableDataset with a buffer and randomly pop items. + + Args: + data: input data source to load and transform to generate dataset for model. + transform: a callable data transform on input data. + buffer_size: size of the buffer to store items and randomly pop, default to 512. + seed: random seed to initialize the random state of all workers, set `seed += 1` in + every iter() call, refer to the PyTorch idea: + https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98. + + """ + + def __init__(self, data, transform=None, buffer_size: int = 512, seed: int = 0) -> None: + super().__init__(data=data, transform=transform) + self.size = buffer_size + self.seed = seed + self._idx = 0 + + def __iter__(self): + """ + Fetch data from the source, if buffer is not full, fill into buffer, otherwise, + randomly pop items from the buffer. + After loading all the data from source, randomly pop items from the buffer. + + """ + self.seed += 1 + super().set_random_state(seed=self.seed) # make all workers in sync + buffer = [] + source = self.data + + def _pop_item(): + self.randomize(len(buffer)) + # switch random index data and the last index data + ret, buffer[self._idx] = buffer[self._idx], buffer[-1] + buffer.pop() + return ret + + def _get_item(): + for item in source: + if len(buffer) >= self.size: + yield _pop_item() + buffer.append(item) + + while buffer: + yield _pop_item() + + self.data = _get_item() + return super().__iter__() + + def randomize(self, size: int) -> None: + self._idx = self.R.randint(size) + + def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None): + raise NotImplementedError(f"`set_random_state` is not available in {self.__class__.__name__}.") class CSVIterableDataset(IterableDataset): """ Iterable dataset to load CSV files and generate dictionary data. - It can be helpful when loading extremely big CSV files that can't read into memory directly. + It is particularly useful when data come from a stream, inherits from PyTorch IterableDataset: + https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset. + + It also can be helpful when loading extremely big CSV files that can't read into memory directly, + just treat the big CSV file as stream input, call `reset()` of `CSVIterableDataset` for every epoch. + Note that as a stream input, it can't get the length of dataset. + + To effectively shuffle the data in the big dataset, users can set a big buffer to continuously store + the loaded data, then randomly pick data from the buffer for following tasks. + To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers, - every process executes transforms on part of every loaded chunk. + every process executes transforms on part of every loaded data. Note: the order of output data may not match data source in multi-processing mode. It can load data from multiple CSV files and join the tables with additional `kwargs` arg. @@ -70,10 +147,12 @@ class CSVIterableDataset(IterableDataset): ] Args: - filename: the filename of expected CSV file to load. if providing a list - of filenames, it will load all the files and join tables. + src: if provided the filename of CSV file, it can be a str, URL, path object or file-like object to load. + also support to provide iter for stream input directly, will skip loading from filename. + if provided a list of filenames or iters, it will join the tables. chunksize: rows of a chunk when loading iterable data from CSV files, default to 1000. more details: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html. + buffer_size: size of the buffer to store the loaded chunks, if None, set to `2 x chunksize`. col_names: names of the expected columns to load. if None, load all the columns. col_types: `type` and `default value` to convert the loaded columns, if None, use original data. it should be a dictionary, every item maps to an expected column, the `key` is the column @@ -93,50 +172,98 @@ class CSVIterableDataset(IterableDataset): be the new column name, the `value` is the names of columns to combine. for example: `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}` transform: transform to apply on the loaded items of a dictionary data. + shuffle: whether to shuffle all the data in the buffer every time a new chunk loaded. + seed: random seed to initialize the random state for all the workers if `shuffle` is True, + set `seed += 1` in every iter() call, refer to the PyTorch idea: + https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98. kwargs: additional arguments for `pandas.merge()` API to join tables. + .. deprecated:: 0.8.0 + ``filename`` is deprecated, use ``src`` instead. + """ + @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.") def __init__( self, - filename: Union[str, Sequence[str]], + src: Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]], chunksize: int = 1000, + buffer_size: Optional[int] = None, col_names: Optional[Sequence[str]] = None, col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None, col_groups: Optional[Dict[str, Sequence[str]]] = None, transform: Optional[Callable] = None, + shuffle: bool = False, + seed: int = 0, **kwargs, ): - self.files = ensure_tuple(filename) + self.src = src self.chunksize = chunksize - self.iters = self.reset() + self.buffer_size = 2 * chunksize if buffer_size is None else buffer_size self.col_names = col_names self.col_types = col_types self.col_groups = col_groups + self.shuffle = shuffle + self.seed = seed + # in case treating deprecated arg `filename` as kwargs, remove it from `kwargs` + kwargs.pop("filename", None) self.kwargs = kwargs + + self.iters: List[Iterable] = self.reset() super().__init__(data=None, transform=transform) # type: ignore - def reset(self, filename: Optional[Union[str, Sequence[str]]] = None): - if filename is not None: - # update files if necessary - self.files = ensure_tuple(filename) - self.iters = [pd.read_csv(f, chunksize=self.chunksize) for f in self.files] + @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.") + def reset(self, src: Optional[Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]]] = None): + """ + Reset the pandas `TextFileReader` iterable object to read data. For more details, please check: + https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration. + + Args: + src: if not None and provided the filename of CSV file, it can be a str, URL, path object + or file-like object to load. also support to provide iter for stream input directly, + will skip loading from filename. if provided a list of filenames or iters, it will join the tables. + default to `self.src`. + + """ + src = self.src if src is None else src + srcs = (src,) if not isinstance(src, (tuple, list)) else src + self.iters = [] + for i in srcs: + if isinstance(i, str): + self.iters.append(pd.read_csv(i, chunksize=self.chunksize)) + elif isinstance(i, Iterable): + self.iters.append(i) + else: + raise ValueError("`src` must be file path or iterable object.") return self.iters - def __iter__(self): + def close(self): + """ + Close the pandas `TextFileReader` iterable objects. + If the input src is file path, TextFileReader was created internally, need to close it. + If the input src is iterable object, depends on users requirements whether to close it in this function. + For more details, please check: + https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration. + + """ + for i in self.iters: + i.close() + + def _flattened(self): for chunks in zip(*self.iters): - self.data = convert_tables_to_dicts( + yield from convert_tables_to_dicts( dfs=chunks, col_names=self.col_names, col_types=self.col_types, col_groups=self.col_groups, **self.kwargs, ) - info = get_worker_info() - if info is not None: - length = len(self.data) - per_worker = int(math.ceil(length / float(info.num_workers))) - start = info.id * per_worker - self.data = self.data[start : min(start + per_worker, length)] - - return super().__iter__() + + def __iter__(self): + if self.shuffle: + self.seed += 1 + buffer = ShuffleBuffer( + data=self._flattened(), transform=self.transform, buffer_size=self.buffer_size, seed=self.seed + ) + yield from buffer + yield from IterableDataset(data=self._flattened(), transform=self.transform) diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index b7067def73..f31926cb6c 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,13 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path from typing import Dict, Optional, Union import numpy as np import torch -from monai.config import DtypeLike +from monai.config import DtypeLike, PathLike from monai.data.nifti_writer import write_nifti from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode @@ -37,7 +36,7 @@ class NiftiSaver: def __init__( self, - output_dir: Union[Path, str] = "./", + output_dir: PathLike = "./", output_postfix: str = "seg", output_ext: str = ".nii.gz", resample: bool = True, @@ -47,7 +46,7 @@ def __init__( dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, squeeze_end_dims: bool = True, - data_root_dir: str = "", + data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, ) -> None: @@ -56,7 +55,8 @@ def __init__( output_dir: output image directory. output_postfix: a string appended to all output file names. output_ext: output file extension name. - resample: whether to resample before saving the data array. + resample: whether to convert the data array to it's original coordinate system + based on `original_affine` in the `meta_data`. mode: {``"bilinear"``, ``"nearest"``} This option is used when ``resample = True``. Interpolation mode to calculate output values. Defaults to ``"bilinear"``. @@ -107,7 +107,7 @@ def __init__( def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """ - Save data into a Nifti file. + Save data into a NIfTI file. The meta_data could optionally have the following keys: - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object. @@ -116,7 +116,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] - ``'spatial_shape'`` -- for data output shape. - ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename. - When meta_data is specified, the saver will try to resample batch data from the space + When meta_data is specified and `resample=True`, the saver will try to resample batch data from the space defined by "affine" to the space defined by "original_affine". If meta_data is None, use the default index (starting from 0) as the filename. @@ -131,7 +131,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] """ filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) self._data_index += 1 - original_affine = meta_data.get("original_affine", None) if meta_data else None + original_affine = meta_data.get("original_affine", None) if meta_data and self.resample else None affine = meta_data.get("affine", None) if meta_data else None spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None @@ -151,7 +151,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] # change data shape to be (channel, h, w, d) while len(data.shape) < 4: data = np.expand_dims(data, -1) - # change data to "channel last" format and write to nifti format file + # change data to "channel last" format and write to NIfTI format file data = np.moveaxis(np.asarray(data), 0, -1) # if desired, remove trailing singleton dimensions @@ -164,7 +164,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] file_name=path, affine=affine, target_affine=original_affine, - resample=self.resample, + resample=True, output_spatial_shape=spatial_shape, mode=self.mode, padding_mode=self.padding_mode, @@ -178,7 +178,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """ - Save a batch of data into Nifti format files. + Save a batch of data into NIfTI format files. Spatially it supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D respectively (with resampling supports for 2D and 3D only). diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index c56d4c1e8d..35044977e0 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,17 +15,19 @@ import torch from monai.config import DtypeLike +from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform from monai.utils import GridSampleMode, GridSamplePadMode, optional_import +from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") def write_nifti( - data: np.ndarray, + data: NdarrayOrTensor, file_name: str, - affine: Optional[np.ndarray] = None, + affine: Optional[NdarrayOrTensor] = None, target_affine: Optional[np.ndarray] = None, resample: bool = True, output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None, @@ -96,13 +98,17 @@ def write_nifti( If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. """ + if isinstance(data, torch.Tensor): + data, *_ = convert_data_type(data, np.ndarray) + if isinstance(affine, torch.Tensor): + affine, *_ = convert_data_type(affine, np.ndarray) if not isinstance(data, np.ndarray): - raise AssertionError("input data must be numpy array.") + raise AssertionError("input data must be numpy array or torch tensor.") dtype = dtype or data.dtype sr = min(data.ndim, 3) if affine is None: affine = np.eye(4, dtype=np.float64) - affine = to_affine_nd(sr, affine) + affine = to_affine_nd(sr, affine) # type: ignore if target_affine is None: target_affine = affine @@ -110,7 +116,7 @@ def write_nifti( if np.allclose(affine, target_affine, atol=1e-3): # no affine changes, save (data, affine) - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data.astype(output_dtype, copy=False), to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return @@ -122,7 +128,7 @@ def write_nifti( data = nib.orientations.apply_orientation(data, ornt_transform) _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) if np.allclose(_affine, target_affine, atol=1e-3) or not resample: - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, _affine)) + results_img = nib.Nifti1Image(data.astype(output_dtype, copy=False), to_affine_nd(3, _affine)) # type: ignore nib.save(results_img, file_name) return @@ -138,11 +144,11 @@ def write_nifti( while len(output_spatial_shape_) < 3: output_spatial_shape_ = output_spatial_shape_ + [1] spatial_shape, channel_shape = data.shape[:3], data.shape[3:] - data_np = data.reshape(list(spatial_shape) + [-1]) + data_np: np.ndarray = data.reshape(list(spatial_shape) + [-1]) # type: ignore data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch data_torch = affine_xform( - torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)), + torch.as_tensor(np.ascontiguousarray(data_np, dtype=dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(transform, dtype=dtype)), spatial_size=output_spatial_shape_[:3], ) data_np = data_torch.squeeze(0).detach().cpu().numpy() @@ -152,12 +158,12 @@ def write_nifti( while len(output_spatial_shape_) < len(data.shape): output_spatial_shape_ = output_spatial_shape_ + [1] data_torch = affine_xform( - torch.as_tensor(np.ascontiguousarray(data).astype(dtype)[None, None]), - torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)), + torch.as_tensor(np.ascontiguousarray(data, dtype=dtype)[None, None]), + torch.as_tensor(np.ascontiguousarray(transform, dtype=dtype)), spatial_size=output_spatial_shape_[: len(data.shape)], ) data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy() - results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data_np.astype(output_dtype, copy=False), to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index e6fb641cca..2e31597837 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,12 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path from typing import Dict, Optional, Union import numpy as np import torch +from monai.config.type_definitions import PathLike from monai.data.png_writer import write_png from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key @@ -34,13 +34,13 @@ class PNGSaver: def __init__( self, - output_dir: Union[Path, str] = "./", + output_dir: PathLike = "./", output_postfix: str = "seg", output_ext: str = ".png", resample: bool = True, mode: Union[InterpolateMode, str] = InterpolateMode.NEAREST, scale: Optional[int] = None, - data_root_dir: str = "", + data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, ) -> None: @@ -134,11 +134,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] raise ValueError(f"Unsupported number of channels: {data.shape[0]}, available options are [1, 3, 4]") write_png( - np.asarray(data), - file_name=path, - output_spatial_shape=spatial_shape, - mode=self.mode, - scale=self.scale, + np.asarray(data), file_name=path, output_spatial_shape=spatial_shape, mode=self.mode, scale=self.scale ) if self.print_log: diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 2baec3b872..6f3b2ef86e 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -48,7 +48,7 @@ def write_png( """ if not isinstance(data, np.ndarray): - raise AssertionError("input data must be numpy array.") + raise ValueError("input data must be numpy array.") if len(data.shape) == 3 and data.shape[2] == 1: # PIL Image can't save image with 1 channel data = data.squeeze(2) if output_spatial_shape is not None: @@ -59,26 +59,26 @@ def write_png( _min, _max = np.min(data), np.max(data) if len(data.shape) == 3: data = np.moveaxis(data, -1, 0) # to channel first - data = xform(data) + data = xform(data) # type: ignore data = np.moveaxis(data, 0, -1) else: # (H, W) data = np.expand_dims(data, 0) # make a channel - data = xform(data)[0] # first channel + data = xform(data)[0] # type: ignore if mode != InterpolateMode.NEAREST: data = np.clip(data, _min, _max) # type: ignore if scale is not None: data = np.clip(data, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1] if scale == np.iinfo(np.uint8).max: - data = (scale * data).astype(np.uint8) + data = (scale * data).astype(np.uint8, copy=False) elif scale == np.iinfo(np.uint16).max: - data = (scale * data).astype(np.uint16) + data = (scale * data).astype(np.uint16, copy=False) else: raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") # PNG data must be int number if data.dtype not in (np.uint8, np.uint16): # type: ignore - data = data.astype(np.uint8) + data = data.astype(np.uint8, copy=False) data = np.moveaxis(data, 0, 1) img = Image.fromarray(data) diff --git a/monai/data/samplers.py b/monai/data/samplers.py index f69c6091ca..40eed03187 100644 --- a/monai/data/samplers.py +++ b/monai/data/samplers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py index 6eec9fd277..46d555cf11 100644 --- a/monai/data/synthetic.py +++ b/monai/data/synthetic.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -73,7 +73,7 @@ def create_test_image_2d( else: image[circle] = rs.random() * 0.5 + 0.5 - labels = np.ceil(image).astype(np.int32) + labels = np.ceil(image).astype(np.int32, copy=False) norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape) noisyimage: np.ndarray = rescale_array(np.maximum(image, norm)) # type: ignore @@ -148,7 +148,7 @@ def create_test_image_3d( else: image[circle] = rs.random() * 0.5 + 0.5 - labels = np.ceil(image).astype(np.int32) + labels = np.ceil(image).astype(np.int32, copy=False) norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape) noisyimage: np.ndarray = rescale_array(np.maximum(image, norm)) # type: ignore diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 33239ea924..c06d567b54 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from copy import deepcopy from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -23,8 +24,9 @@ from monai.transforms.inverse_batch_transform import BatchInverseTransform from monai.transforms.transform import Randomizable from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils.enums import CommonKeys, InverseKeys +from monai.utils.enums import CommonKeys, TraceKeys from monai.utils.module import optional_import +from monai.utils.type_conversion import convert_data_type if TYPE_CHECKING: from tqdm import tqdm @@ -36,6 +38,10 @@ __all__ = ["TestTimeAugmentation"] +def _identity(x): + return x + + class TestTimeAugmentation: """ Class for performing test time augmentations. This will pass the same image through the network multiple times. @@ -81,7 +87,7 @@ class TestTimeAugmentation: .. code-block:: python transform = RandAffined(keys, ...) - post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) tt_aug = TestTimeAugmentation( transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device @@ -93,8 +99,8 @@ def __init__( self, transform: InvertibleTransform, batch_size: int, - num_workers: int, - inferrer_fn: Callable, + num_workers: int = 0, + inferrer_fn: Callable = _identity, device: Union[str, torch.device] = "cpu", image_key=CommonKeys.IMAGE, orig_key=CommonKeys.LABEL, @@ -133,8 +139,8 @@ def _check_transforms(self): # check that whenever randoms is True, invertibles is also true for r, i in zip(randoms, invertibles): if r and not i: - raise RuntimeError( - f"All applied random transform(s) must be invertible. Problematic transform: {type(r).__name__}" + warnings.warn( + f"Not all applied random transform(s) are invertible. Problematic transform: {type(r).__name__}" ) def __call__( @@ -159,11 +165,11 @@ def __call__( raise ValueError("num_examples should be multiple of batch size.") # generate batch of data of size == batch_size, dataset and dataloader - data_in = [d] * num_examples + data_in = [deepcopy(d) for _ in range(num_examples)] ds = Dataset(data_in, self.transform) - dl = DataLoader(ds, self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) + dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - transform_key = self.orig_key + InverseKeys.KEY_SUFFIX + transform_key = InvertibleTransform.trace_key(self.orig_key) # create inverter inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate) @@ -180,13 +186,13 @@ def __call__( batch_output = batch_output.detach().cpu() if isinstance(batch_output, np.ndarray): batch_output = torch.Tensor(batch_output) - - transform_info = batch_data[transform_key] + transform_info = batch_data.get(transform_key, None) + if transform_info is None: + # no invertible transforms, adding dummy info for identity invertible + transform_info = [[TraceKeys.NONE] for _ in range(self.batch_size)] if self.nearest_interp: transform_info = convert_inverse_interp_mode( - trans_info=deepcopy(transform_info), - mode="nearest", - align_corners=None, + trans_info=deepcopy(transform_info), mode="nearest", align_corners=None ) # create a dictionary containing the inferred batch and their transforms @@ -210,7 +216,8 @@ def __call__( return output # calculate metrics - mode = np.array(torch.mode(torch.Tensor(output.astype(np.int64)), dim=0).values) + output_t, *_ = convert_data_type(output, output_type=torch.Tensor, dtype=np.int64) + mode: np.ndarray = np.asarray(torch.mode(output_t, dim=0).values) # type: ignore mean: np.ndarray = np.mean(output, axis=0) # type: ignore std: np.ndarray = np.std(output, axis=0) # type: ignore vvc: float = (np.std(output) / np.mean(output)).item() diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py index da5847465e..cdd7c05f31 100644 --- a/monai/data/thread_buffer.py +++ b/monai/data/thread_buffer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -83,12 +83,24 @@ class ThreadDataLoader(DataLoader): iterate over data from the loader as expected however the data is generated on a separate thread. Use this class where a `DataLoader` instance is required and not just an iterable object. + The default behaviour with `repeats` set to 1 is to yield each batch as it is generated, however with a higher + value the generated batch is yielded that many times while underlying dataset asynchronously generates the next. + Typically not all relevant information is learned from a batch in a single iteration so training multiple times + on the same batch will still produce good training with minimal short-term overfitting while allowing a slow batch + generation process more time to produce a result. + + See: + * Fischetti et al. "Faster SGD training by minibatch persistency." ArXiv (2018) https://arxiv.org/abs/1806.07353 + * Dami et al., "Faster Neural Network Training with Data Echoing" ArXiv (2020) https://arxiv.org/abs/1907.05550 + * Ramezani et al. "GCN meets GPU: Decoupling "When to Sample" from "How to Sample"." NeurIPS (2020). + https://proceedings.neurips.cc/paper/2020/file/d714d2c5a796d5814c565d78dd16188d-Paper.pdf + Args: dataset: input dataset. buffer_size: number of items to buffer from the data source. buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items. - num_workers: number of the multi-prcessing workers in PyTorch DataLoader. - + num_workers: number of the multi-processing workers in PyTorch DataLoader. + repeats: number of times to yield the same batch """ def __init__( @@ -97,12 +109,17 @@ def __init__( buffer_size: int = 1, buffer_timeout: float = 0.01, num_workers: int = 0, + repeats: int = 1, **kwargs, ): super().__init__(dataset, num_workers, **kwargs) self.buffer_size = buffer_size self.buffer_timeout = buffer_timeout + self.repeats = repeats def __iter__(self): buffer = ThreadBuffer(src=super().__iter__(), buffer_size=self.buffer_size, timeout=self.buffer_timeout) - yield from buffer + + for batch in buffer: + for _ in range(self.repeats): + yield batch diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py new file mode 100644 index 0000000000..585db14712 --- /dev/null +++ b/monai/data/torchscript_utils.py @@ -0,0 +1,149 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import os +from typing import IO, Any, Mapping, Optional, Sequence, Tuple, Union + +import torch + +from monai.config import get_config_values +from monai.utils import JITMetadataKeys +from monai.utils.module import pytorch_after + +METADATA_FILENAME = "metadata.json" + + +def save_net_with_metadata( + jit_obj: torch.nn.Module, + filename_prefix_or_stream: Union[str, IO[Any]], + include_config_vals: bool = True, + append_timestamp: bool = False, + meta_values: Optional[Mapping[str, Any]] = None, + more_extra_files: Optional[Mapping[str, bytes]] = None, +) -> None: + """ + Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata + included as a JSON file. The Torchscript format is a zip file which can contain extra file data which is used + here as a mechanism for storing metadata about the network being saved. The data in `meta_values` should be + compatible with conversion to JSON using the standard library function `dumps`. The intent is this metadata will + include information about the network applicable to some use case, such as describing the input and output format, + a network name and version, a plain language description of what the network does, and other relevant scientific + information. Clients can use this information to determine automatically how to use the network, and users can + read what the network does and keep track of versions. + + Examples:: + + net = torch.jit.script(monai.networks.nets.UNet(2, 1, 1, [8, 16], [2])) + + meta = { + "name": "Test UNet", + "used_for": "demonstration purposes", + "input_dims": 2, + "output_dims": 2 + } + + # save the Torchscript bundle with the above dictionary stored as an extra file + save_net_with_metadata(m, "test", meta_values=meta) + + # load the network back, `loaded_meta` has same data as `meta` plus version information + loaded_net, loaded_meta, _ = load_net_with_metadata("test.pt") + + + Args: + jit_obj: object to save, should be generated by `script` or `trace`. + filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.pt`. + include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. + append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. + meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. + more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. + """ + + now = datetime.datetime.now() + metadict = {} + + if include_config_vals: + metadict.update(get_config_values()) + metadict[JITMetadataKeys.TIMESTAMP.value] = now.astimezone().isoformat() + + if meta_values is not None: + metadict.update(meta_values) + + json_data = json.dumps(metadict) + + # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object + if pytorch_after(1, 7): + extra_files = {METADATA_FILENAME: json_data.encode()} + + if more_extra_files is not None: + extra_files.update(more_extra_files) + else: + extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined] + extra_files[METADATA_FILENAME] = json_data.encode() + + if more_extra_files is not None: + for k, v in more_extra_files.items(): + extra_files[k] = v + + if isinstance(filename_prefix_or_stream, str): + filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) + if ext == "": + ext = ".pt" + + if append_timestamp: + filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") + else: + filename_prefix_or_stream = filename_no_ext + ext + + torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files) + + +def load_net_with_metadata( + filename_prefix_or_stream: Union[str, IO[Any]], + map_location: Optional[torch.device] = None, + more_extra_files: Sequence[str] = (), +) -> Tuple[torch.nn.Module, dict, dict]: + """ + Load the module object from the given Torchscript filename or stream, and convert the stored JSON metadata + back to a dict object. This will produce an empty dict if the metadata file is not present. + + Args: + filename_prefix_or_stream: filename or file-like stream object. + map_location: network map location as in `torch.jit.load`. + more_extra_files: other extra file data names to load from bundle, see `_extra_files` of `torch.jit.load`. + Returns: + Triple containing loaded object, metadata dict, and extra files dict containing other file data if present + """ + # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object + if pytorch_after(1, 7): + extra_files = {f: "" for f in more_extra_files} + extra_files[METADATA_FILENAME] = "" + else: + extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined] + extra_files[METADATA_FILENAME] = "" + + for f in more_extra_files: + extra_files[f] = "" + + jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) + + extra_files = dict(extra_files.items()) # compatibility with ExtraFilesMap + + if METADATA_FILENAME in extra_files: + json_data = extra_files[METADATA_FILENAME] + del extra_files[METADATA_FILENAME] + else: + json_data = "{}" + + json_data_dict = json.loads(json_data) + + return jit_obj, json_data_dict, extra_files diff --git a/monai/data/utils.py b/monai/data/utils.py index aab23217dc..0cd3c1594a 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,14 +18,15 @@ from collections import defaultdict from copy import deepcopy from functools import reduce -from itertools import product, starmap -from pathlib import Path, PurePath +from itertools import product, starmap, zip_longest +from pathlib import PurePath from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from torch.utils.data._utils.collate import default_collate +from monai.config.type_definitions import PathLike from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -66,17 +67,21 @@ "is_supported_format", "partition_dataset", "partition_dataset_classes", + "resample_datalist", "select_cross_validation_folds", "json_hashing", "pickle_hashing", "sorted_dict", "decollate_batch", - "rep_scalar_to_batch", "pad_list_data_collate", "no_collation", "convert_tables_to_dicts", + "SUPPORTED_PICKLE_MOD", ] +# module to be used by `torch.save` +SUPPORTED_PICKLE_MOD = {"pickle": pickle} + def get_random_patch( dims: Sequence[int], patch_size: Sequence[int], rand_state: Optional[np.random.RandomState] = None @@ -134,9 +139,7 @@ def iter_patch_slices( def dense_patch_slices( - image_size: Sequence[int], - patch_size: Sequence[int], - scan_interval: Sequence[int], + image_size: Sequence[int], patch_size: Sequence[int], scan_interval: Sequence[int] ) -> List[Tuple[slice, ...]]: """ Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image. @@ -283,7 +286,7 @@ def list_data_collate(batch: Sequence): + "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its " + "documentation)." ) - raise RuntimeError(re_str) + raise RuntimeError(re_str) from re except TypeError as re: re_str = str(re) if "numpy" in re_str and "Tensor" in re_str: @@ -294,10 +297,35 @@ def list_data_collate(batch: Sequence): + "creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem " + "(check its documentation)." ) - raise TypeError(re_str) + raise TypeError(re_str) from re + + +def _non_zipping_check(batch_data, detach, pad, fill_value): + """ + Utility function based on `decollate_batch`, to identify the largest batch size from the collated data. + returns batch_size, the list of non-iterable items, and the dictionary or list with their items decollated. + + See `decollate_batch` for more details. + """ + if isinstance(batch_data, Mapping): + _deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data} + elif isinstance(batch_data, Iterable): + _deco = [decollate_batch(b, detach, pad=pad, fill_value=fill_value) for b in batch_data] + else: + raise NotImplementedError(f"Unable to de-collate: {batch_data}, type: {type(batch_data)}.") + batch_size, non_iterable = 0, [] + for k, v in _deco.items() if isinstance(_deco, Mapping) else enumerate(_deco): + if not isinstance(v, Iterable) or isinstance(v, (str, bytes)) or (isinstance(v, torch.Tensor) and v.ndim == 0): + # Not running the usual list decollate here: + # don't decollate ['test', 'test'] into [['t', 't'], ['e', 'e'], ['s', 's'], ['t', 't']] + # torch.tensor(0) is iterable but iter(torch.tensor(0)) raises TypeError: iteration over a 0-d tensor + non_iterable.append(k) + elif hasattr(v, "__len__"): + batch_size = max(batch_size, len(v)) + return batch_size, non_iterable, _deco -def decollate_batch(batch, detach: bool = True): +def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): """De-collate a batch of data (for example, as produced by a `DataLoader`). Returns a list of structures with the original tensor's 0-th dimension sliced into elements using `torch.unbind`. @@ -335,10 +363,23 @@ def decollate_batch(batch, detach: bool = True): print(out[0]) >>> tensor([[[4.3549e-01...43e-01]]]) + batch_data = { + "image": [1, 2, 3], "meta": [4, 5], # undetermined batch size + } + out = decollate_batch(batch_data, pad=True, fill_value=0) + print(out) + >>> [{'image': 1, 'meta': 4}, {'image': 2, 'meta': 5}, {'image': 3, 'meta': 0}] + out = decollate_batch(batch_data, pad=False) + print(out) + >>> [{'image': 1, 'meta': 4}, {'image': 2, 'meta': 5}] + Args: batch: data to be de-collated. detach: whether to detach the tensors. Scalars tensors will be detached into number types instead of torch tensors. + pad: when the items in a batch indicate different batch size, whether to pad all the sequences to the longest. + If False, the batch size will be the length of the shortest sequence. + fill_value: when `pad` is True, the `fillvalue` to use when padding, defaults to `None`. """ if batch is None: return batch @@ -353,68 +394,20 @@ def decollate_batch(batch, detach: bool = True): if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) - if isinstance(batch, Mapping): - _dict_list = {key: decollate_batch(batch[key], detach) for key in batch} - return [dict(zip(_dict_list, item)) for item in zip(*_dict_list.values())] - if isinstance(batch, Iterable): - item_0 = first(batch) - if ( - not isinstance(item_0, Iterable) - or isinstance(item_0, (str, bytes)) - or (isinstance(item_0, torch.Tensor) and item_0.ndim == 0) - ): - # Not running the usual list decollate here: - # don't decollate ['test', 'test'] into [['t', 't'], ['e', 'e'], ['s', 's'], ['t', 't']] - # torch.tensor(0) is iterable but iter(torch.tensor(0)) raises TypeError: iteration over a 0-d tensor - return [decollate_batch(b, detach) for b in batch] - return [list(item) for item in zip(*(decollate_batch(b, detach) for b in batch))] - raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.") - -def rep_scalar_to_batch(batch_data: Union[List, Dict]) -> Union[List, Dict]: - """ - Utility tp replicate the scalar items of a list or dictionary to ensure all the items have batch dimension. - It leverages `decollate_batch(detach=False)` to filter out the scalar items. - - """ - - def _detect_batch_size(batch_data: Sequence): - """ - Detect the batch size from a list of data, some items in the list have batch dim, some not. - - """ - for v in batch_data: - if isinstance(v, torch.Tensor) and v.ndim > 0: - return v.shape[0] - for v in batch_data: - if issequenceiterable(v): - warnings.warn("batch_data doesn't contain batched Tensor data, use the length of first sequence data.") - return len(v) - raise RuntimeError("failed to automatically detect the batch size.") - - if isinstance(batch_data, dict): - batch_size = _detect_batch_size(list(batch_data.values())) - dict_batch = {} - for k, v in batch_data.items(): - if decollate_batch(v, detach=False) == v and not isinstance(v, list): - # if decollating a list, the result may be the same list, so should skip this case - dict_batch[k] = [deepcopy(decollate_batch(v, detach=True)) for _ in range(batch_size)] - else: - dict_batch[k] = v - - return dict_batch - if isinstance(batch_data, list): - batch_size = _detect_batch_size(batch_data) - list_batch = [] - for b in batch_data: - if decollate_batch(b, detach=False) == b and not isinstance(b, list): - list_batch.append([deepcopy(decollate_batch(b, detach=True)) for _ in range(batch_size)]) - else: - list_batch.append(b) - - return list_batch - # if not dict or list, just return the original data - return batch_data + b, non_iterable, deco = _non_zipping_check(batch, detach, pad, fill_value) + if b <= 0: # all non-iterable, single item "batch"? {"image": 1, "label": 1} + return deco + if pad: # duplicate non-iterable items to the longest batch + for k in non_iterable: + deco[k] = [deepcopy(deco[k]) for _ in range(b)] + if isinstance(deco, Mapping): + _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values()) + return [dict(zip(deco, item)) for item in _gen] + if isinstance(deco, Iterable): + _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco) + return [list(item) for item in _gen] + raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.") def pad_list_data_collate( @@ -467,9 +460,10 @@ def worker_init_fn(worker_id: int) -> None: def set_rnd(obj, seed: int) -> int: """ - Set seed or random state for all randomisable properties of obj. + Set seed or random state for all randomizable properties of obj. Args: + obj: object to set seed or random state for. seed: set the random state with an integer seed. """ if not hasattr(obj, "__dict__"): @@ -545,7 +539,7 @@ def rectify_header_sform_qform(img_nii): return img_nii -def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = True): +def zoom_affine(affine: np.ndarray, scale: Union[np.ndarray, Sequence[float]], diagonal: bool = True): """ To make column norm of `affine` the same as `scale`. If diagonal is False, returns an affine that combines orthogonal rotation and the new scale. @@ -630,7 +624,7 @@ def compute_shape_offset( # different orientation, the min is the origin corners = corners[:-1] / corners[-1] offset = np.min(corners, 1) - return out_shape.astype(int), offset + return out_shape.astype(int, copy=False), offset def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: @@ -678,18 +672,19 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: def create_file_basename( postfix: str, - input_file_name: str, - folder_path: Union[Path, str], - data_root_dir: str = "", + input_file_name: PathLike, + folder_path: PathLike, + data_root_dir: PathLike = "", separate_folder: bool = True, - patch_index: Optional[int] = None, + patch_index=None, + makedirs: bool = True, ) -> str: """ Utility function to create the path to the output file based on the input filename (file name extension is not added by this function). When `data_root_dir` is not specified, the output file name is: - `folder_path/input_file_name (no ext.) /input_file_name (no ext.)[_postfix]` + `folder_path/input_file_name (no ext.) /input_file_name (no ext.)[_postfix][_patch_index]` otherwise the relative path with respect to `data_root_dir` will be inserted, for example: input_file_name: /foo/bar/test1/image.png, @@ -710,6 +705,7 @@ def create_file_basename( `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. patch_index: if not None, append the patch index to filename. + makedirs: whether to create the folder if it does not exist. """ # get the filename and directory @@ -728,8 +724,10 @@ def create_file_basename( if separate_folder: output = os.path.join(output, filename) - # create target folder if no existing - os.makedirs(output, exist_ok=True) + + if makedirs: + # create target folder if no existing + os.makedirs(output, exist_ok=True) # add the sub-folder plus the postfix name to become the file basename in the output path output = os.path.join(output, (filename + "_" + postfix) if len(postfix) > 0 else filename) @@ -737,7 +735,7 @@ def create_file_basename( if patch_index is not None: output += f"_{patch_index}" - return os.path.abspath(output) + return os.path.normpath(output) def compute_importance_map( @@ -795,7 +793,7 @@ def compute_importance_map( return importance_map -def is_supported_format(filename: Union[Sequence[str], str], suffixes: Sequence[str]) -> bool: +def is_supported_format(filename: Union[Sequence[PathLike], PathLike], suffixes: Sequence[str]) -> bool: """ Verify whether the specified file or files format match supported suffixes. If supported suffixes is None, skip the verification and return True. @@ -806,7 +804,7 @@ def is_supported_format(filename: Union[Sequence[str], str], suffixes: Sequence[ suffixes: all the supported image suffixes of current reader, must be a list of lower case suffixes. """ - filenames: Sequence[str] = ensure_tuple(filename) + filenames: Sequence[PathLike] = ensure_tuple(filename) for name in filenames: tokens: Sequence[str] = PurePath(name).suffixes if len(tokens) == 0 or all("." + s.lower() not in "".join(tokens) for s in suffixes): @@ -993,6 +991,31 @@ def partition_dataset_classes( return datasets +def resample_datalist(data: Sequence, factor: float, random_pick: bool = False, seed: int = 0): + """ + Utility function to resample the loaded datalist for training, for example: + If factor < 1.0, randomly pick part of the datalist and set to Dataset, useful to quickly test the program. + If factor > 1.0, repeat the datalist to enhance the Dataset. + + Args: + data: original datalist to scale. + factor: scale factor for the datalist, for example, factor=4.5, repeat the datalist 4 times and plus + 50% of the original datalist. + random_pick: whether to randomly pick data if scale factor has decimal part. + seed: random seed to randomly pick data. + + """ + scale, repeats = math.modf(factor) + ret: List = list() + + for _ in range(int(repeats)): + ret.extend(list(deepcopy(data))) + if scale > 1e-6: + ret.extend(partition_dataset(data=data, ratios=[scale, 1 - scale], shuffle=random_pick, seed=seed)[0]) + + return ret + + def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[Sequence[int], int]) -> List: """ Select cross validation data based on data partitions and specified fold index. @@ -1029,7 +1052,7 @@ def json_hashing(item) -> bytes: """ # TODO: Find way to hash transforms content as part of the cache cache_key = hashlib.md5(json.dumps(item, sort_keys=True).encode("utf-8")).hexdigest() - return f"{cache_key}".encode("utf-8") + return f"{cache_key}".encode() def pickle_hashing(item, protocol=pickle.HIGHEST_PROTOCOL) -> bytes: @@ -1044,7 +1067,7 @@ def pickle_hashing(item, protocol=pickle.HIGHEST_PROTOCOL) -> bytes: """ cache_key = hashlib.md5(pickle.dumps(sorted_dict(item), protocol=protocol)).hexdigest() - return f"{cache_key}".encode("utf-8") + return f"{cache_key}".encode() def sorted_dict(item, key=None, reverse=False): @@ -1117,7 +1140,7 @@ def convert_tables_to_dicts( # convert data types types = {k: v["type"] for k, v in col_types.items() if v is not None and "type" in v} if types: - data_ = data_.astype(dtype=types) + data_ = data_.astype(dtype=types, copy=False) data: List[Dict] = data_.to_dict(orient="records") # group columns to generate new column diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index 89ebc8b47c..88f094c732 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,9 +15,13 @@ from .utils import ( GanKeys, IterationEvents, + PrepareBatch, + PrepareBatchDefault, + PrepareBatchExtraInput, default_make_latent, default_metric_cmp_fn, default_prepare_batch, engine_apply_transform, get_devices_spec, ) +from .workflow import BaseWorkflow, Workflow diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 1c37da71d4..b329462e24 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from torch.utils.data import DataLoader @@ -45,9 +45,13 @@ class Evaluator(Workflow): epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_val_metric: compute metric when every iteration completed, and save average value to @@ -80,7 +84,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, @@ -147,9 +151,13 @@ class SupervisedEvaluator(Evaluator): epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. @@ -184,7 +192,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, inferer: Optional[Inferer] = None, postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, @@ -219,7 +227,7 @@ def __init__( self.network = network self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -237,7 +245,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -246,15 +254,15 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore # execute forward computation with self.mode(self.network): if self.amp: with torch.cuda.amp.autocast(): - engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore else: - engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) @@ -270,14 +278,18 @@ class EnsembleEvaluator(Evaluator): device: an object representing the device on which to run. val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader. epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. - network: networks to evaluate in order in the evaluator, should be regular PyTorch `torch.nn.Module`. + networks: networks to evaluate in order in the evaluator, should be regular PyTorch `torch.nn.Module`. pred_keys: the keys to store every prediction data. the length must exactly match the number of networks. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. @@ -313,7 +325,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, inferer: Optional[Inferer] = None, postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, @@ -349,7 +361,7 @@ def __init__( self.pred_keys = ensure_tuple(pred_keys) self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -370,7 +382,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -379,17 +391,21 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore for idx, network in enumerate(self.networks): with self.mode(network): if self.amp: with torch.cuda.amp.autocast(): + if isinstance(engine.state.output, dict): + engine.state.output.update( + {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)} + ) + else: + if isinstance(engine.state.output, dict): engine.state.output.update( {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)} ) - else: - engine.state.output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index 3671dbcfd1..7c59b670b7 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,10 +34,7 @@ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") -__all__ = [ - "create_multigpu_supervised_trainer", - "create_multigpu_supervised_evaluator", -] +__all__ = ["create_multigpu_supervised_trainer", "create_multigpu_supervised_evaluator"] def _default_transform(_x: torch.Tensor, _y: torch.Tensor, _y_pred: torch.Tensor, loss: torch.Tensor) -> float: @@ -59,7 +56,7 @@ def create_multigpu_supervised_trainer( prepare_batch: Callable = _prepare_batch, output_transform: Callable = _default_transform, distributed: bool = False, -) -> Engine: +): """ Derived from `create_supervised_trainer` in Ignite. @@ -107,7 +104,7 @@ def create_multigpu_supervised_evaluator( prepare_batch: Callable = _prepare_batch, output_transform: Callable = _default_eval_transform, distributed: bool = False, -) -> Engine: +): """ Derived from `create_supervised_evaluator` in Ignite. diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 44e265be1f..774e535e7f 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from torch.optim.optimizer import Optimizer @@ -26,7 +26,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import PT_BEFORE_1_7, min_version, optional_import +from monai.utils import min_version, optional_import, pytorch_after from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: @@ -75,9 +75,13 @@ class SupervisedTrainer(Trainer): epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. @@ -115,7 +119,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, inferer: Optional[Inferer] = None, postprocessing: Optional[Transform] = None, key_train_metric: Optional[Dict[str, Metric]] = None, @@ -172,7 +176,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -180,7 +184,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): else: inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore def _compute_pred_loss(): engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) @@ -190,7 +194,7 @@ def _compute_pred_loss(): self.network.train() # `set_to_none` only work from PyTorch 1.7.0 - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.optimizer.zero_grad() else: self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -198,13 +202,13 @@ def _compute_pred_loss(): if self.amp and self.scaler is not None: with torch.cuda.amp.autocast(): _compute_pred_loss() - self.scaler.scale(engine.state.output[Keys.LOSS]).backward() + self.scaler.scale(engine.state.output[Keys.LOSS]).backward() # type: ignore engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.scaler.step(self.optimizer) self.scaler.update() else: _compute_pred_loss() - engine.state.output[Keys.LOSS].backward() + engine.state.output[Keys.LOSS].backward() # type: ignore engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.optimizer.step() engine.fire_event(IterationEvents.MODEL_COMPLETED) @@ -241,12 +245,16 @@ class GanTrainer(Trainer): non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. d_prepare_batch: callback function to prepare batchdata for D inferer. - Defaults to return ``GanKeys.REALS`` in batchdata dict. + Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. g_prepare_batch: callback function to create batch of latent input for G inferer. - Defaults to return random latents. + Defaults to return random latents. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_train_metric: compute metric when every iteration completed, and save average value to @@ -286,7 +294,7 @@ def __init__( d_prepare_batch: Callable = default_prepare_batch, g_prepare_batch: Callable = default_make_latent, g_update_latents: bool = True, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, postprocessing: Optional[Transform] = None, key_train_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, @@ -345,18 +353,21 @@ def _iteration( if batchdata is None: raise ValueError("must provide batch data for current iteration.") - d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore batch_size = self.data_loader.batch_size # type: ignore - g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) + g_input = self.g_prepare_batch( + num_latents=batch_size, + latent_size=self.latent_shape, + device=engine.state.device, # type: ignore + non_blocking=engine.non_blocking, # type: ignore + ) g_output = self.g_inferer(g_input, self.g_network) # Train Discriminator - d_total_loss = torch.zeros( - 1, - ) + d_total_loss = torch.zeros(1) for _ in range(self.d_train_steps): # `set_to_none` only work from PyTorch 1.7.0 - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.d_optimizer.zero_grad() else: self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -367,9 +378,14 @@ def _iteration( # Train Generator if self.g_update_latents: - g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) + g_input = self.g_prepare_batch( + num_latents=batch_size, + latent_size=self.latent_shape, + device=engine.state.device, # type: ignore + non_blocking=engine.non_blocking, # type: ignore + ) g_output = self.g_inferer(g_input, self.g_network) - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.g_optimizer.zero_grad() else: self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index c94cc16916..726dfc8e98 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,13 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch from monai.config import IgniteInfo from monai.transforms import apply_transform -from monai.utils import min_version, optional_import +from monai.utils import ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys if TYPE_CHECKING: @@ -28,6 +29,9 @@ "GanKeys", "get_devices_spec", "default_prepare_batch", + "PrepareBatch", + "PrepareBatchDefault", + "PrepareBatchExtraInput", "default_make_latent", "engine_apply_transform", "default_metric_cmp_fn", @@ -100,9 +104,7 @@ def get_devices_spec(devices: Optional[Sequence[torch.device]] = None) -> List[t def default_prepare_batch( - batchdata: Dict[str, torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - non_blocking: bool = False, + batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: """ Default function to prepare the data for current iteration. @@ -125,11 +127,81 @@ def default_prepare_batch( return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), None +class PrepareBatch(ABC): + """ + Interface of customized prepare_batch in the trainer or evaluator workflows. + It takes the data of current batch, target device and non_blocking flag as input. + + """ + + @abstractmethod + def __call__( + self, + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + ): + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class PrepareBatchDefault(PrepareBatch): + """ + Default prepare batch method to return `image` and `label` only, + it's to be consistent with `default_prepare_batch` API. + """ + + def __call__( + self, + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + ): + return default_prepare_batch(batchdata, device, non_blocking) + + +class PrepareBatchExtraInput(PrepareBatch): + """ + Customized prepare_batch for trainer or evaluator that support extra input data for network. + Extra items are specified by the `extra_keys` parameter. + + Args: + extra_keys: if a string or list provided, every item is the key of extra data in current batch, + and will pass the extra data to the `network(*args)` in order. + If a dictionary is provided, every `{k, v}` pair is the key of extra data in current batch, + `k` is the param name in network, `v` is the key of extra data in current batch, + and will pass the `{k1: batch[v1], k2: batch[v2], ...}` as kwargs to the network. + + """ + + def __init__(self, extra_keys: Union[str, Sequence[str], Dict[str, str]]) -> None: + self.extra_keys = extra_keys + + def __call__( + self, + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + ): + image, label = default_prepare_batch(batchdata, device, non_blocking) + args = list() + kwargs = dict() + + def _get_data(key: str): + data = batchdata[key] + return data.to(device=device, non_blocking=non_blocking) if isinstance(data, torch.Tensor) else data + + if isinstance(self.extra_keys, (str, list, tuple)): + for k in ensure_tuple(self.extra_keys): + args.append(_get_data(k)) + elif isinstance(self.extra_keys, dict): + for k, v in self.extra_keys.items(): + kwargs.update({k: _get_data(v)}) + + return image, label, tuple(args), kwargs + + def default_make_latent( - num_latents: int, - latent_size: int, - device: Optional[Union[str, torch.device]] = None, - non_blocking: bool = False, + num_latents: int, latent_size: int, device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False ) -> torch.Tensor: return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index ffb8ce05b3..4222db0593 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,7 +10,8 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union import torch import torch.distributed as dist @@ -19,8 +20,8 @@ from monai.config import IgniteInfo from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch -from monai.transforms import Decollated, Transform -from monai.utils import ensure_tuple, min_version, optional_import +from monai.transforms import Decollated +from monai.utils import ensure_tuple, is_scalar, min_version, optional_import from .utils import engine_apply_transform @@ -37,6 +38,18 @@ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") +class BaseWorkflow(ABC): + """ + Base class for any MONAI style workflow. + `run()` is designed to execute the train, evaluation or inference logic. + + """ + + @abstractmethod + def run(self, *args, **kwargs): + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import """ Workflow defines the core work process inheriting from Ignite engine. @@ -54,9 +67,13 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona epoch_length: number of iterations for one epoch, default to `len(data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for every iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_metric: compute metric when every iteration completed, and save average value to @@ -94,7 +111,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, postprocessing: Optional[Callable] = None, key_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, @@ -152,15 +169,15 @@ def set_sampler_epoch(engine: Engine): self.scaler: Optional[torch.cuda.amp.GradScaler] = None if event_names is None: - event_names = [IterationEvents] + event_names = [IterationEvents] # type: ignore else: if not isinstance(event_names, list): raise ValueError("event_names must be a list or string or EventEnum.") - event_names += [IterationEvents] + event_names += [IterationEvents] # type: ignore for name in event_names: if isinstance(name, str): self.register_events(name, event_to_attr=event_to_attr) - elif issubclass(name, EventEnum): + elif issubclass(name, EventEnum): # type: ignore self.register_events(*name, event_to_attr=event_to_attr) else: raise ValueError("event_names must be a list or string or EventEnum.") @@ -169,8 +186,8 @@ def set_sampler_epoch(engine: Engine): self._register_decollate() if postprocessing is not None: - if not decollate and isinstance(postprocessing, Transform): - warnings.warn("MONAI transforms expect `channel-first` data, `decollate=False` may not work here.") + # tips: if `decollate=False` and `postprocessing` is MONAI transforms, it may not work well + # because all the MONAI transforms expect `channel-first` data self._register_postprocessing(postprocessing) if key_metric is not None: self._register_metrics(key_metric, additional_metrics) @@ -187,8 +204,10 @@ def _register_decollate(self): def _decollate_data(engine: Engine) -> None: # replicate the scalar values to make sure all the items have batch dimension, then decollate transform = Decollated(keys=None, detach=True) - engine.state.batch = transform(engine.state.batch) - engine.state.output = transform(engine.state.output) + if isinstance(engine.state.batch, (list, dict)): + engine.state.batch = transform(engine.state.batch) + if isinstance(engine.state.output, (list, dict)): + engine.state.output = transform(engine.state.output) def _register_postprocessing(self, posttrans: Callable): """ @@ -200,9 +219,7 @@ def _register_postprocessing(self, posttrans: Callable): def _run_postprocessing(engine: Engine) -> None: if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list): engine.state.batch, engine.state.output = engine_apply_transform( - batch=engine.state.batch, - output=engine.state.output, - transform=posttrans, + batch=engine.state.batch, output=engine.state.output, transform=posttrans ) else: for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): @@ -216,7 +233,7 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): if not isinstance(k_metric, dict): raise TypeError(f"key_metric must be None or a dict but is {type(k_metric).__name__}.") self.state.key_metric_name = list(k_metric.keys())[0] - metrics = k_metric + metrics = dict(k_metric) if add_metrics is not None and len(add_metrics) > 0: if not isinstance(add_metrics, dict): raise TypeError(f"additional metrics must be None or a dict but is {type(add_metrics).__name__}.") @@ -226,12 +243,20 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): @self.on(Events.EPOCH_COMPLETED) def _compare_metrics(engine: Engine) -> None: - if engine.state.key_metric_name is not None: - current_val_metric = engine.state.metrics[engine.state.key_metric_name] - if self.metric_cmp_fn(current_val_metric, engine.state.best_metric): - self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}") - engine.state.best_metric = current_val_metric - engine.state.best_metric_epoch = engine.state.epoch + key_metric_name = engine.state.key_metric_name # type: ignore + if key_metric_name is not None: + current_val_metric = engine.state.metrics[key_metric_name] + if not is_scalar(current_val_metric): + warnings.warn( + "key metric is not a scalar value, skip the metric comparison with the current best metric." + "please set other metrics as the key metric, or change the `reduction` mode to 'mean'." + ) + return + + if self.metric_cmp_fn(current_val_metric, engine.state.best_metric): # type: ignore + self.logger.info(f"Got new best metric of {key_metric_name}: {current_val_metric}") + engine.state.best_metric = current_val_metric # type: ignore + engine.state.best_metric_epoch = engine.state.epoch # type: ignore def _register_handlers(self, handlers: Sequence): """ diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index c9eecc6d46..03aaa37412 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,6 +22,7 @@ from .mean_dice import MeanDice from .metric_logger import MetricLogger, MetricLoggerKeys from .metrics_saver import MetricsSaver +from .mlflow_handler import MLFlowHandler from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler from .parameter_scheduler import ParamSchedulerHandler from .postprocessing import PostProcessing @@ -32,13 +33,5 @@ from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler -from .transform_inverter import TransformInverter -from .utils import ( - evenly_divisible_all_gather, - from_engine, - stopping_fn_from_loss, - stopping_fn_from_metric, - string_list_all_gather, - write_metrics_reports, -) +from .utils import from_engine, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index f1f60abf63..91cfca354a 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -126,7 +126,7 @@ def __call__(self, engine: Engine) -> None: # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict) - if engine.state.epoch > prior_max_epochs: + if prior_max_epochs is not None and engine.state.epoch > prior_max_epochs: raise ValueError( f"Epoch count ({engine.state.epoch}) in checkpoint is larger than " f"the `engine.state.max_epochs` ({prior_max_epochs}) of engine. To further train from checkpoint, " diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index f365ff73c4..f7abca4aa0 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,10 +11,10 @@ import logging import warnings -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Mapping, Optional from monai.config import IgniteInfo -from monai.utils import min_version, optional_import +from monai.utils import is_scalar, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") @@ -126,7 +126,7 @@ def __init__(self, dirname: str, filename: Optional[str] = None): super().__init__(dirname=dirname, require_empty=False, atomic=False) self.filename = filename - def __call__(self, checkpoint: Dict, filename: str, metadata: Optional[Dict] = None) -> None: + def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None: if self.filename is not None: filename = self.filename super().__call__(checkpoint=checkpoint, filename=filename, metadata=metadata) @@ -154,14 +154,20 @@ def _final_func(engine: Engine): def _score_func(engine: Engine): if isinstance(key_metric_name, str): metric_name = key_metric_name - elif hasattr(engine.state, "key_metric_name") and isinstance(engine.state.key_metric_name, str): - metric_name = engine.state.key_metric_name + elif hasattr(engine.state, "key_metric_name"): + metric_name = engine.state.key_metric_name # type: ignore else: raise ValueError( f"Incompatible values: save_key_metric=True and key_metric_name={key_metric_name}." ) - - return (-1 if key_metric_negative_sign else 1) * engine.state.metrics[metric_name] + metric = engine.state.metrics[metric_name] + if not is_scalar(metric): + warnings.warn( + "key metric is not a scalar value, skip metric comparison and don't save a model." + "please use other metrics as key metric, or change the `reduction` mode to 'mean'." + ) + return -1 + return (-1 if key_metric_negative_sign else 1) * metric if key_metric_filename is not None and key_metric_n_saved > 1: raise ValueError("if using fixed filename to save the best metric model, we should only save 1 model.") diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 815be87754..bb37e36826 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -55,9 +55,15 @@ def __init__( batch_transform: a callable that is used to extract the `meta_data` dictionary of the input images from `ignite.engine.state.batch`. the purpose is to get the input filenames from the `meta_data` and store with classification results together. + `engine.state` and `batch_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. output_transform: a callable that is used to extract the model prediction data from `ignite.engine.state.output`. the first dimension of its output will be treated as the batch dimension. each item in the batch will be saved individually. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. name: identifier of logging.logger to use, defaulting to `engine.logger`. save_rank: only the handler on specified rank will save to CSV file in multi-gpus validation, default to 0. @@ -92,7 +98,14 @@ def attach(self, engine: Engine) -> None: if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED): engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize) - def _started(self, engine: Engine) -> None: + def _started(self, _engine: Engine) -> None: + """ + Initialize internal buffers. + + Args: + _engine: Ignite Engine, unused argument. + + """ self._outputs = [] self._filenames = [] @@ -114,12 +127,12 @@ def __call__(self, engine: Engine) -> None: o = o.detach() self._outputs.append(o) - def _finalize(self, engine: Engine) -> None: + def _finalize(self, _engine: Engine) -> None: """ All gather classification results from ranks and save to CSV file. Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. + _engine: Ignite Engine, unused argument. """ ws = idist.get_world_size() if self.save_rank >= ws: diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index 368aacc6cb..e3fc4bfbf1 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,11 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from typing import Callable, Union from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import ConfusionMatrixMetric -from monai.metrics.utils import MetricReduction +from monai.utils.enums import MetricReduction class ConfusionMatrix(IgniteMetric): @@ -25,6 +25,8 @@ def __init__( self, include_background: bool = True, metric_name: str = "hit_rate", + compute_sample: bool = False, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -40,11 +42,17 @@ def __init__( ``"informedness"``, ``"markedness"``] Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned), and you can also input those names instead. + compute_sample: when reducing, if ``True``, each sample's metric will be computed based on each confusion matrix first. + if ``False``, compute reduction on the confusion matrices first, defaults to ``False``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. @@ -54,12 +62,8 @@ def __init__( metric_fn = ConfusionMatrixMetric( include_background=include_background, metric_name=metric_name, - compute_sample=False, - reduction=MetricReduction.MEAN, + compute_sample=compute_sample, + reduction=reduction, ) self.metric_name = metric_name - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py index 4e99fc6f04..a0d0ef3ad2 100644 --- a/monai/handlers/decollate_batch.py +++ b/monai/handlers/decollate_batch.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -88,7 +88,7 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - if self.batch_transform is not None: + if self.batch_transform is not None and isinstance(engine.state.batch, (list, dict)): engine.state.batch = self.batch_transform(engine.state.batch) - if self.output_transform is not None: + if self.output_transform is not None and isinstance(engine.state.output, (list, dict)): engine.state.output = self.output_transform(engine.state.output) diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index e194b50d59..8d57526676 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/handlers/garbage_collector.py b/monai/handlers/garbage_collector.py index fffca2a740..74ccac8a72 100644 --- a/monai/handlers/garbage_collector.py +++ b/monai/handlers/garbage_collector.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,6 +42,7 @@ class GarbageCollector: """ def __init__(self, trigger_event: str = "epoch", log_level: int = 10): + self.trigger_event: Events if isinstance(trigger_event, Events): self.trigger_event = trigger_event elif trigger_event.lower() == "epoch": diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index a25ef04383..739c9e9935 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Callable, Optional, Union from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import HausdorffDistanceMetric @@ -27,6 +27,7 @@ def __init__( distance_metric: str = "euclidean", percentile: Optional[float] = None, directed: bool = False, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -41,11 +42,15 @@ def __init__( percentile of the Hausdorff Distance rather than the maximum result will be achieved. Defaults to ``None``. directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: hausdorff distance of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. @@ -55,10 +60,6 @@ def __init__( distance_metric=distance_metric, percentile=percentile, directed=directed, - reduction=MetricReduction.MEAN, - ) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, + reduction=reduction, ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index cbf84e4626..f28923af68 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -41,18 +41,16 @@ class IgniteMetric(Metric): # type: ignore[valid-type, misc] # due to optional_ output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: mean_dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. """ def __init__( - self, - metric_fn: CumulativeIterationMetric, - output_transform: Callable = lambda x: x, - save_details: bool = True, + self, metric_fn: CumulativeIterationMetric, output_transform: Callable = lambda x: x, save_details: bool = True ) -> None: self._is_reduced: bool = False self.metric_fn = metric_fn @@ -101,9 +99,13 @@ def compute(self) -> Any: if self.save_details: if self._engine is None or self._name is None: raise RuntimeError("please call the attach() function to connect expected engine first.") - self._engine.state.metric_details[self._name] = self.metric_fn.get_buffer() + self._engine.state.metric_details[self._name] = self.metric_fn.get_buffer() # type: ignore - return result.item() if isinstance(result, torch.Tensor) else result + if isinstance(result, torch.Tensor): + result = result.squeeze() + if result.ndim == 0: + result = result.item() + return result def attach(self, engine: Engine, name: str) -> None: """ @@ -120,4 +122,4 @@ def attach(self, engine: Engine, name: str) -> None: self._engine = engine self._name = name if self.save_details and not hasattr(engine.state, "metric_details"): - engine.state.metric_details = {} + engine.state.metric_details = {} # type: ignore diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py index 3e57ac7bbd..db186bd73d 100644 --- a/monai/handlers/lr_schedule_handler.py +++ b/monai/handlers/lr_schedule_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index ba5805fc19..c5609c6746 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from typing import Callable, Union from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import DiceMetric @@ -24,6 +24,7 @@ class MeanDice(IgniteMetric): def __init__( self, include_background: bool = True, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -32,20 +33,20 @@ def __init__( Args: include_background: whether to include dice computation on the first channel of the predicted output. Defaults to True. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: mean dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. See also: :py:meth:`monai.metrics.meandice.compute_meandice` """ - metric_fn = DiceMetric(include_background=include_background, reduction=MetricReduction.MEAN) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + metric_fn = DiceMetric(include_background=include_background, reduction=reduction) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index 64553955b7..350d1978de 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -57,6 +57,9 @@ class MetricLogger: Args: loss_transform: Converts the `output` value from the trainer's state into a loss value + `engine.state` and `loss_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. metric_transform: Converts the metric value coming from the trainer/evaluator's state into a storable value evaluator: Optional evaluator to consume metric results from at the end of its evaluation run """ diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 97b080b244..1ec26eece7 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -50,6 +50,9 @@ class MetricsSaver: batch_transform: a callable that is used to extract the `meta_data` dictionary of the input images from `ignite.engine.state.batch` if saving metric details. the purpose is to get the input filenames from the `meta_data` and store with metric details together. + `engine.state` and `batch_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. summary_ops: expected computation operations to generate the summary report. it can be: None, "*" or list of strings, default to None. None - don't generate summary report for every expected metric_details. @@ -102,7 +105,14 @@ def attach(self, engine: Engine) -> None: engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames) engine.add_event_handler(Events.EPOCH_COMPLETED, self) - def _started(self, engine: Engine) -> None: + def _started(self, _engine: Engine) -> None: + """ + Initialize internal buffers. + + Args: + _engine: Ignite Engine, unused argument. + + """ self._filenames = [] def _get_filenames(self, engine: Engine) -> None: @@ -132,10 +142,12 @@ def __call__(self, engine: Engine) -> None: if self.metrics is not None and len(engine.state.metrics) > 0: _metrics = {k: v for k, v in engine.state.metrics.items() if k in self.metrics or "*" in self.metrics} _metric_details = {} - if self.metric_details is not None and len(engine.state.metric_details) > 0: - for k, v in engine.state.metric_details.items(): - if k in self.metric_details or "*" in self.metric_details: - _metric_details[k] = v + if hasattr(engine.state, "metric_details"): + details = engine.state.metric_details # type: ignore + if self.metric_details is not None and len(details) > 0: + for k, v in details.items(): + if k in self.metric_details or "*" in self.metric_details: + _metric_details[k] = v write_metrics_reports( save_dir=self.save_dir, diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py new file mode 100644 index 0000000000..9475b2be35 --- /dev/null +++ b/monai/handlers/mlflow_handler.py @@ -0,0 +1,193 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence + +import torch + +from monai.config import IgniteInfo +from monai.utils import min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +mlflow, _ = optional_import("mlflow") + +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + +DEFAULT_TAG = "Loss" + + +class MLFlowHandler: + """ + MLFlowHandler defines a set of Ignite Event-handlers for the MLFlow tracking logics. + It can be used for any Ignite Engine(trainer, validator and evaluator). + And it can track both epoch level and iteration level logging, then MLFlow can store + the data and visualize. + The expected data source is Ignite ``engine.state.output`` and ``engine.state.metrics``. + + Default behaviors: + - When EPOCH_COMPLETED, track each dictionary item in + ``engine.state.metrics`` in MLFlow. + - When ITERATION_COMPLETED, track expected item in + ``self.output_transform(engine.state.output)`` in MLFlow, default to `Loss`. + + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + + Args: + tracking_uri: connects to a tracking URI. can also set the `MLFLOW_TRACKING_URI` environment + variable to have MLflow find a URI from there. in both cases, the URI can either be + a HTTP/HTTPS URI for a remote server, a database connection string, or a local path + to log data to a directory. The URI defaults to path `mlruns`. + for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri. + epoch_logger: customized callable logger for epoch level logging with MLFlow. + Must accept parameter "engine", use default logger if None. + iteration_logger: customized callable logger for iteration level logging with MLFlow. + Must accept parameter "engine", use default logger if None. + output_transform: a callable that is used to transform the + ``ignite.engine.state.output`` into a scalar to track, or a dictionary of {key: scalar}. + By default this value logging happens when every iteration completed. + The default behavior is to track loss from output[0] as output is a decollated list + and we replicated loss value for every item of the decollated list. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + global_epoch_transform: a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine might want to track synced epoch number + with the trainer engine. + state_attributes: expected attributes from `engine.state`, if provided, will extract them + when epoch completed. + tag_name: when iteration output is a scalar, `tag_name` is used to track, defaults to `'Loss'`. + + For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html. + + """ + + def __init__( + self, + tracking_uri: Optional[str] = None, + epoch_logger: Optional[Callable[[Engine], Any]] = None, + iteration_logger: Optional[Callable[[Engine], Any]] = None, + output_transform: Callable = lambda x: x[0], + global_epoch_transform: Callable = lambda x: x, + state_attributes: Optional[Sequence[str]] = None, + tag_name: str = DEFAULT_TAG, + ) -> None: + if tracking_uri is not None: + mlflow.set_tracking_uri(tracking_uri) + + self.epoch_logger = epoch_logger + self.iteration_logger = iteration_logger + self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform + self.state_attributes = state_attributes + self.tag_name = tag_name + + def attach(self, engine: Engine) -> None: + """ + Register a set of Ignite Event-Handlers to a specified Ignite engine. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if not engine.has_event_handler(self.start, Events.STARTED): + engine.add_event_handler(Events.STARTED, self.start) + if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): + engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) + + def start(self) -> None: + """ + Check MLFlow status and start if not active. + + """ + if mlflow.active_run() is None: + mlflow.start_run() + + def close(self) -> None: + """ + Stop current running logger of MLFlow. + + """ + mlflow.end_run() + + def epoch_completed(self, engine: Engine) -> None: + """ + Handler for train or validation/evaluation epoch completed Event. + Track epoch level log, default values are from Ignite `engine.state.metrics` dict. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.epoch_logger is not None: + self.epoch_logger(engine) + else: + self._default_epoch_log(engine) + + def iteration_completed(self, engine: Engine) -> None: + """ + Handler for train or validation/evaluation iteration completed Event. + Track iteration level log. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.iteration_logger is not None: + self.iteration_logger(engine) + else: + self._default_iteration_log(engine) + + def _default_epoch_log(self, engine: Engine) -> None: + """ + Execute epoch level log operation. + Default to track the values from Ignite `engine.state.metrics` dict and + track the values of specified attributes of `engine.state`. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + log_dict = engine.state.metrics + if not log_dict: + return + + current_epoch = self.global_epoch_transform(engine.state.epoch) + mlflow.log_metrics(log_dict, step=current_epoch) + + if self.state_attributes is not None: + attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes} + mlflow.log_metrics(attrs, step=current_epoch) + + def _default_iteration_log(self, engine: Engine) -> None: + """ + Execute iteration log operation based on Ignite `engine.state.output` data. + Log the values from `self.output_transform(engine.state.output)`. + Since `engine.state.output` is a decollated list and we replicated the loss value for every item + of the decollated list, the default behavior is to track the loss from `output[0]`. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + loss = self.output_transform(engine.state.output) + if loss is None: + return + + if not isinstance(loss, dict): + loss = {self.tag_name: loss.item() if isinstance(loss, torch.Tensor) else loss} + + mlflow.log_metrics(loss, step=engine.state.iteration) diff --git a/monai/handlers/nvtx_handlers.py b/monai/handlers/nvtx_handlers.py index aba7a7ec0e..327c156f63 100644 --- a/monai/handlers/nvtx_handlers.py +++ b/monai/handlers/nvtx_handlers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -50,9 +50,7 @@ class RangeHandler: """ def __init__( - self, - events: Union[str, Tuple[Union[str, Events], Union[str, Events]]], - msg: Optional[str] = None, + self, events: Union[str, Tuple[Union[str, Events], Union[str, Events]]], msg: Optional[str] = None ) -> None: self.events = self.resolve_events(events) if msg is None: @@ -73,10 +71,7 @@ def resolve_events(self, events: Union[str, Tuple]) -> Tuple[Events, Events]: if len(events) == 1: return self.create_paired_events(events[0]) if len(events) == 2: - return ( - self.get_event(events[0]), - self.get_event(events[1]), - ) + return self.get_event(events[0]), self.get_event(events[1]) raise ValueError(f"Exactly two Ignite events should be provided [received {len(events)}].") def create_paired_events(self, event: str) -> Tuple[Events, Events]: @@ -84,22 +79,11 @@ def create_paired_events(self, event: str) -> Tuple[Events, Events]: Create pair of Ignite events from a event prefix name """ event = event.upper() - event_prefix = { - "": "", - "ENGINE": "", - "EPOCH": "EPOCH_", - "ITERATION": "ITERATION_", - "BATCH": "GET_BATCH_", - } - return ( - self.get_event(event_prefix[event] + "STARTED"), - self.get_event(event_prefix[event] + "COMPLETED"), - ) + event_prefix = {"": "", "ENGINE": "", "EPOCH": "EPOCH_", "ITERATION": "ITERATION_", "BATCH": "GET_BATCH_"} + return self.get_event(event_prefix[event] + "STARTED"), self.get_event(event_prefix[event] + "COMPLETED") def get_event(self, event: Union[str, Events]) -> Events: - if isinstance(event, str): - event = event.upper() - return Events[event] + return Events[event.upper()] if isinstance(event, str) else event def attach(self, engine: Engine) -> None: """ @@ -126,10 +110,8 @@ class RangePushHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Events, msg: Optional[str] = None) -> None: - if isinstance(event, str): - event = event.upper() - self.event = Events[event] + def __init__(self, event: Union[str, Events], msg: Optional[str] = None) -> None: + self.event = Events[event.upper()] if isinstance(event, str) else event if msg is None: msg = self.event.name self.msg = msg @@ -156,10 +138,8 @@ class RangePopHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Events) -> None: - if isinstance(event, str): - event = event.upper() - self.event = Events[event] + def __init__(self, event: Union[str, Events]) -> None: + self.event = Events[event.upper()] if isinstance(event, str) else event def attach(self, engine: Engine) -> None: """ @@ -181,10 +161,8 @@ class MarkHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Events, msg: Optional[str] = None) -> None: - if isinstance(event, str): - event = event.upper() - self.event = Events[event] + def __init__(self, event: Union[str, Events], msg: Optional[str] = None) -> None: + self.event = Events[event.upper()] if isinstance(event, str) else event if msg is None: msg = self.event.name self.msg = msg diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py index b6eb35562f..c0e18edcd0 100644 --- a/monai/handlers/parameter_scheduler.py +++ b/monai/handlers/parameter_scheduler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/handlers/postprocessing.py b/monai/handlers/postprocessing.py index 05c6bd414d..4a89c86f47 100644 --- a/monai/handlers/postprocessing.py +++ b/monai/handlers/postprocessing.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -63,9 +63,7 @@ def __call__(self, engine: Engine) -> None: """ if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list): engine.state.batch, engine.state.output = engine_apply_transform( - batch=engine.state.batch, - output=engine.state.output, - transform=self.transform, + batch=engine.state.batch, output=engine.state.output, transform=self.transform ) else: for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): diff --git a/monai/handlers/regression_metrics.py b/monai/handlers/regression_metrics.py index f203439f40..bf4ac3af1d 100644 --- a/monai/handlers/regression_metrics.py +++ b/monai/handlers/regression_metrics.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,29 +23,30 @@ class MeanSquaredError(IgniteMetric): def __init__( self, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: """ Args: + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: mean squared error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. See also: :py:class:`monai.metrics.MSEMetric` """ - metric_fn = MSEMetric(reduction=MetricReduction.MEAN) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + metric_fn = MSEMetric(reduction=reduction) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) class MeanAbsoluteError(IgniteMetric): @@ -55,25 +56,30 @@ class MeanAbsoluteError(IgniteMetric): def __init__( self, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: """ Args: - output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. - save_details: whether to save metric computation details per image, for example: mean absolute error of every image. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: mean squared error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. See also: :py:class:`monai.metrics.MAEMetric` """ - metric_fn = MAEMetric(reduction=MetricReduction.MEAN) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + metric_fn = MAEMetric(reduction=reduction) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) class RootMeanSquaredError(IgniteMetric): @@ -83,25 +89,30 @@ class RootMeanSquaredError(IgniteMetric): def __init__( self, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: """ Args: - output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. - save_details: whether to save metric computation details per image, for example: root mean squared error of every image. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: mean squared error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. See also: :py:class:`monai.metrics.RMSEMetric` """ - metric_fn = RMSEMetric(reduction=MetricReduction.MEAN) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + metric_fn = RMSEMetric(reduction=reduction) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) class PeakSignalToNoiseRatio(IgniteMetric): @@ -112,6 +123,7 @@ class PeakSignalToNoiseRatio(IgniteMetric): def __init__( self, max_val: Union[int, float], + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -120,17 +132,21 @@ def __init__( Args: max_val: The dynamic range of the images/volumes (i.e., the difference between the maximum and the minimum allowed values e.g. 255 for a uint8 image). - output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. - save_details: whether to save metric computation details per image, for example: PSNR of every image. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: mean squared error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, See also: :py:class:`monai.metrics.PSNRMetric` """ - metric_fn = PSNRMetric(max_val=max_val, reduction=MetricReduction.MEAN) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + metric_fn = PSNRMetric(max_val=max_val, reduction=reduction) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 98c8c8f8bc..31b046064e 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -36,8 +36,9 @@ class ROCAUC(IgniteMetric): # type: ignore[valid-type, misc] # due to optional output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. Note: ROCAUC expects y to be comprised of 0's and 1's. @@ -45,14 +46,6 @@ class ROCAUC(IgniteMetric): # type: ignore[valid-type, misc] # due to optional """ - def __init__( - self, - average: Union[Average, str] = Average.MACRO, - output_transform: Callable = lambda x: x, - ) -> None: + def __init__(self, average: Union[Average, str] = Average.MACRO, output_transform: Callable = lambda x: x) -> None: metric_fn = ROCAUCMetric(average=Average(average)) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=False, - ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False) diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 535f58945b..79ebfd3a22 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,7 +26,7 @@ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="Please consider using `SaveImage[d]` transform instead.") +@deprecated(since="0.6.0", removed="0.9.0", msg_suffix="Please consider using `SaveImage[d]` transform instead.") class SegmentationSaver: """ Event handler triggered on completing every iteration to save the segmentation predictions into files. @@ -113,9 +113,15 @@ def __init__( batch_transform: a callable that is used to extract the `meta_data` dictionary of the input images from `ignite.engine.state.batch`. the purpose is to extract necessary information from the meta data: filename, affine, original_shape, etc. + `engine.state` and `batch_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. output_transform: a callable that is used to extract the model prediction data from `ignite.engine.state.output`. the first dimension of its output will be treated as the batch dimension. each item in the batch will be saved individually. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. name: identifier of logging.logger to use, defaulting to `engine.logger`. """ diff --git a/monai/handlers/smartcache_handler.py b/monai/handlers/smartcache_handler.py index e3adcbf4a0..56fee78b1d 100644 --- a/monai/handlers/smartcache_handler.py +++ b/monai/handlers/smartcache_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index d5756074fc..e90d0ebd10 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,7 +11,7 @@ import logging import warnings -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence import torch @@ -31,7 +31,7 @@ class StatsHandler: """ StatsHandler defines a set of Ignite Event-handlers for all the log printing logics. - It's can be used for any Ignite Engine(trainer, validator and evaluator). + It can be used for any Ignite Engine(trainer, validator and evaluator). And it can support logging for epoch level and iteration level with pre-defined loggers. Default behaviors: @@ -39,6 +39,9 @@ class StatsHandler: - When ITERATION_COMPLETED, logs ``self.output_transform(engine.state.output)`` using ``self.logger``. + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + """ def __init__( @@ -47,6 +50,7 @@ def __init__( iteration_print_logger: Optional[Callable[[Engine], Any]] = None, output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, + state_attributes: Optional[Sequence[str]] = None, name: Optional[str] = None, tag_name: str = DEFAULT_TAG, key_var_format: str = DEFAULT_KEY_VAL_FORMAT, @@ -65,9 +69,14 @@ def __init__( By default this value logging happens when every iteration completed. The default behavior is to print loss from output[0] as output is a decollated list and we replicated loss value for every item of the decollated list. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. global_epoch_transform: a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine might want to print synced epoch number with the trainer engine. + state_attributes: expected attributes from `engine.state`, if provided, will extract them + when epoch completed. name: identifier of logging.logger to use, defaulting to ``engine.logger``. tag_name: when iteration output is a scalar, tag_name is used to print tag_name: scalar_value to logger. Defaults to ``'Loss'``. @@ -80,6 +89,7 @@ def __init__( self.iteration_print_logger = iteration_print_logger self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform + self.state_attributes = state_attributes self.logger = logging.getLogger(name) self._name = name @@ -108,7 +118,7 @@ def attach(self, engine: Engine) -> None: def epoch_completed(self, engine: Engine) -> None: """ Handler for train or validation/evaluation epoch completed Event. - Print epoch level log, default values are from Ignite state.metrics dict. + Print epoch level log, default values are from Ignite `engine.state.metrics` dict. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -122,7 +132,7 @@ def epoch_completed(self, engine: Engine) -> None: def iteration_completed(self, engine: Engine) -> None: """ Handler for train or validation/evaluation iteration completed Event. - Print iteration level log, default values are from Ignite state.logs dict. + Print iteration level log, default values are from Ignite `engine.state.output`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -133,14 +143,14 @@ def iteration_completed(self, engine: Engine) -> None: else: self._default_iteration_print(engine) - def exception_raised(self, engine: Engine, e: Exception) -> None: + def exception_raised(self, _engine: Engine, e: Exception) -> None: """ Handler for train or validation/evaluation exception raised Event. Print the exception information and traceback. This callback may be skipped because the logic with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event. Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. + _engine: Ignite Engine, unused argument. e: the exception caught in Ignite during engine.run(). """ @@ -149,39 +159,46 @@ def exception_raised(self, engine: Engine, e: Exception) -> None: def _default_epoch_print(self, engine: Engine) -> None: """ - Execute epoch level log operation based on Ignite engine.state data. - print the values from Ignite state.metrics dict. + Execute epoch level log operation. + Default to print the values from Ignite `engine.state.metrics` dict and + print the values of specified attributes of `engine.state`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - prints_dict = engine.state.metrics - if not prints_dict: - return current_epoch = self.global_epoch_transform(engine.state.epoch) - out_str = f"Epoch[{current_epoch}] Metrics -- " - for name in sorted(prints_dict): - value = prints_dict[name] - out_str += self.key_var_format.format(name, value) - self.logger.info(out_str) + prints_dict = engine.state.metrics + if prints_dict is not None and len(prints_dict) > 0: + out_str = f"Epoch[{current_epoch}] Metrics -- " + for name in sorted(prints_dict): + value = prints_dict[name] + out_str += self.key_var_format.format(name, value) if is_scalar(value) else f"{name}: {str(value)}" + self.logger.info(out_str) if ( hasattr(engine.state, "key_metric_name") and hasattr(engine.state, "best_metric") and hasattr(engine.state, "best_metric_epoch") ): - out_str = f"Key metric: {engine.state.key_metric_name} " - out_str += f"best value: {engine.state.best_metric} at epoch: {engine.state.best_metric_epoch}" - self.logger.info(out_str) + out_str = f"Key metric: {engine.state.key_metric_name} " # type: ignore + out_str += f"best value: {engine.state.best_metric} " # type: ignore + out_str += f"at epoch: {engine.state.best_metric_epoch}" # type: ignore + self.logger.info(out_str) + + if self.state_attributes is not None and len(self.state_attributes) > 0: + out_str = "State values: " + for attr in self.state_attributes: + out_str += f"{attr}: {getattr(engine.state, attr, None)} " + self.logger.info(out_str) def _default_iteration_print(self, engine: Engine) -> None: """ - Execute iteration log operation based on Ignite engine.state data. - Print the values from Ignite state.logs dict. - The default behavior is to print loss from output[0] as output is a decollated list and we replicated loss - value for every item of the decollated list. + Execute iteration log operation based on Ignite `engine.state.output` data. + Print the values from `self.output_transform(engine.state.output)`. + Since `engine.state.output` is a decollated list and we replicated the loss value for every item + of the decollated list, the default behavior is to print the loss from `output[0]`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -220,7 +237,9 @@ def _default_iteration_print(self, engine: Engine) -> None: return # no value to print num_iterations = engine.state.epoch_length - current_iteration = (engine.state.iteration - 1) % num_iterations + 1 + current_iteration = engine.state.iteration + if num_iterations is not None: + current_iteration = (current_iteration - 1) % num_iterations + 1 current_epoch = engine.state.epoch num_epochs = engine.state.max_epochs diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index 4fc5b5a60a..77f0debfe9 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from typing import Callable, Union from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import SurfaceDistanceMetric @@ -26,6 +26,7 @@ def __init__( include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -38,11 +39,15 @@ def __init__( `seg_pred` and `seg_gt`. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - for example: if `ignite.engine.state.output` is `{"pred": xxx, "label": xxx, "other": xxx}`, - output_transform can be `lambda x: (x["pred"], x["label"])`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: surface dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. @@ -51,10 +56,6 @@ def __init__( include_background=include_background, symmetric=symmetric, distance_metric=distance_metric, - reduction=MetricReduction.MEAN, - ) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, + reduction=reduction, ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index a3a0bf76b8..dcf60973b0 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence import numpy as np import torch @@ -20,6 +20,7 @@ from monai.visualize import plot_2d_or_3d_image Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") + if TYPE_CHECKING: from ignite.engine import Engine from torch.utils.tensorboard import SummaryWriter @@ -35,8 +36,8 @@ class TensorBoardHandler: Base class for the handlers to write data into TensorBoard. Args: - summary_writer: user can specify TensorBoard SummaryWriter, - default to create a new writer. + summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter, + default to create a new TensorBoard writer. log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. """ @@ -64,7 +65,7 @@ def close(self): class TensorBoardStatsHandler(TensorBoardHandler): """ TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. - It's can be used for any Ignite Engine(trainer, validator and evaluator). + It can be used for any Ignite Engine(trainer, validator and evaluator). And it can support both epoch level and iteration level with pre-defined TensorBoard event writer. The expected data source is Ignite ``engine.state.output`` and ``engine.state.metrics``. @@ -73,6 +74,10 @@ class TensorBoardStatsHandler(TensorBoardHandler): ``engine.state.metrics`` to TensorBoard. - When ITERATION_COMPLETED, write each dictionary item in ``self.output_transform(engine.state.output)`` to TensorBoard. + + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + """ def __init__( @@ -85,12 +90,13 @@ def __init__( iteration_interval: int = 1, output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, + state_attributes: Optional[Sequence[str]] = None, tag_name: str = DEFAULT_TAG, ) -> None: """ Args: - summary_writer: user can specify TensorBoard SummaryWriter, - default to create a new writer. + summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter, + default to create a new TensorBoard writer. log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. epoch_event_writer: customized callable TensorBoard writer for epoch level. Must accept parameter "engine" and "summary_writer", use default event writer if None. @@ -104,9 +110,14 @@ def __init__( By default this value plotting happens when every iteration completed. The default behavior is to print loss from output[0] as output is a decollated list and we replicated loss value for every item of the decollated list. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. global_epoch_transform: a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine might want to use trainer engines epoch number when plotting epoch vs metric curves. + state_attributes: expected attributes from `engine.state`, if provided, will extract them + when epoch completed. tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``. """ super().__init__(summary_writer=summary_writer, log_dir=log_dir) @@ -116,6 +127,7 @@ def __init__( self.iteration_interval = iteration_interval self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform + self.state_attributes = state_attributes self.tag_name = tag_name def attach(self, engine: Engine) -> None: @@ -136,7 +148,7 @@ def attach(self, engine: Engine) -> None: def epoch_completed(self, engine: Engine) -> None: """ Handler for train or validation/evaluation epoch completed Event. - Write epoch level events, default values are from Ignite state.metrics dict. + Write epoch level events, default values are from Ignite `engine.state.metrics` dict. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -150,7 +162,7 @@ def epoch_completed(self, engine: Engine) -> None: def iteration_completed(self, engine: Engine) -> None: """ Handler for train or validation/evaluation iteration completed Event. - Write iteration level events, default values are from Ignite state.logs dict. + Write iteration level events, default values are from Ignite `engine.state.output`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -161,31 +173,53 @@ def iteration_completed(self, engine: Engine) -> None: else: self._default_iteration_writer(engine, self._writer) + def _write_scalar(self, _engine: Engine, writer: SummaryWriter, tag: str, value: Any, step: int) -> None: + """ + Write scale value into TensorBoard. + Default to call `SummaryWriter.add_scalar()`. + + Args: + _engine: Ignite Engine, unused argument. + writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler. + tag: tag name in the TensorBoard. + value: value of the scalar data for current step. + step: index of current step. + + """ + writer.add_scalar(tag, value, step) + def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None: """ - Execute epoch level event write operation based on Ignite engine.state data. - Default is to write the values from Ignite state.metrics dict. + Execute epoch level event write operation. + Default to write the values from Ignite `engine.state.metrics` dict and + write the values of specified attributes of `engine.state`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. - writer: TensorBoard writer, created in TensorBoardHandler. + writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler. """ current_epoch = self.global_epoch_transform(engine.state.epoch) summary_dict = engine.state.metrics for name, value in summary_dict.items(): - writer.add_scalar(name, value, current_epoch) + if is_scalar(value): + self._write_scalar(engine, writer, name, value, current_epoch) + + if self.state_attributes is not None: + for attr in self.state_attributes: + self._write_scalar(engine, writer, attr, getattr(engine.state, attr, None), current_epoch) writer.flush() def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> None: """ - Execute iteration level event write operation based on Ignite engine.state data. - The default behavior is to print loss from output[0] as output is a decollated list and we replicated loss - value for every item of the decollated list. + Execute iteration level event write operation based on Ignite `engine.state.output` data. + Extract the values from `self.output_transform(engine.state.output)`. + Since `engine.state.output` is a decollated list and we replicated the loss value for every item + of the decollated list, the default behavior is to track the loss from `output[0]`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. - writer: TensorBoard writer, created in TensorBoardHandler. + writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler. """ loss = self.output_transform(engine.state.output) @@ -202,12 +236,20 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No " {}:{}".format(name, type(value)) ) continue # not plot multi dimensional output - writer.add_scalar( - name, value.item() if isinstance(value, torch.Tensor) else value, engine.state.iteration + self._write_scalar( + _engine=engine, + writer=writer, + tag=name, + value=value.item() if isinstance(value, torch.Tensor) else value, + step=engine.state.iteration, ) elif is_scalar(loss): # not printing multi dimensional output - writer.add_scalar( - self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss, engine.state.iteration + self._write_scalar( + _engine=engine, + writer=writer, + tag=self.tag_name, + value=loss.item() if isinstance(loss, torch.Tensor) else loss, + step=engine.state.iteration, ) else: warnings.warn( @@ -225,6 +267,7 @@ class TensorBoardImageHandler(TensorBoardHandler): 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, for 3D to ND output (shape in Batch, channel, H, W, D) input, each of ``self.max_channels`` number of images' last three dimensions will be shown as animated GIF along the last axis (typically Depth). + And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video. It can be used for any Ignite Engine (trainer, validator and evaluator). User can easily add it to engine for any expected Event, for example: ``EPOCH_COMPLETED``, @@ -239,6 +282,9 @@ class TensorBoardImageHandler(TensorBoardHandler): - Expects ``output_transform(engine.state.output)`` to return a torch tensor in format (y_pred[N, channel, ...], loss). + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + """ def __init__( @@ -252,12 +298,13 @@ def __init__( global_iter_transform: Callable = lambda x: x, index: int = 0, max_channels: int = 1, + frame_dim: int = -3, max_frames: int = 64, ) -> None: """ Args: - summary_writer: user can specify TensorBoard SummaryWriter, - default to create a new writer. + summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter, + default to create a new TensorBoard writer. log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. interval: plot content from engine.state every N epochs or every N iterations, default is 1. epoch_level: plot content from engine.state every N epochs or N iterations. `True` is epoch level, @@ -266,13 +313,21 @@ def __init__( then construct `(image, label)` pair. for example: if `ignite.engine.state.batch` is `{"image": xxx, "label": xxx, "other": xxx}`, `batch_transform` can be `lambda x: (x["image"], x["label"])`. will use the result to plot image from `result[0][index]` and plot label from `result[1][index]`. + `engine.state` and `batch_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. output_transform: a callable that is used to extract the `predictions` data from `ignite.engine.state.output`, will use the result to plot output from `result[index]`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. global_iter_transform: a callable that is used to customize global step number for TensorBoard. For example, in evaluation, the evaluator engine needs to know current epoch from trainer. index: plot which element in a data batch, default is the first element. max_channels: number of channels to plot. - max_frames: number of frames for 2D-t plot. + frame_dim: if plotting 3D image as GIF, specify the dimension used as frames, + expect input data shape as `NCHWD`, default to `-3` (the first spatial dim) + max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`. """ super().__init__(summary_writer=summary_writer, log_dir=log_dir) self.interval = interval @@ -281,6 +336,7 @@ def __init__( self.output_transform = output_transform self.global_iter_transform = global_iter_transform self.index = index + self.frame_dim = frame_dim self.max_frames = max_frames self.max_channels = max_channels @@ -320,13 +376,14 @@ def __call__(self, engine: Engine) -> None: ) plot_2d_or_3d_image( # add batch dim and plot the first item - show_images[None], - step, - self._writer, - 0, - self.max_channels, - self.max_frames, - "input_0", + data=show_images[None], + step=step, + writer=self._writer, + index=0, + max_channels=self.max_channels, + frame_dim=self.frame_dim, + max_frames=self.max_frames, + tag="input_0", ) show_labels = self.batch_transform(engine.state.batch)[1][self.index] @@ -338,7 +395,16 @@ def __call__(self, engine: Engine) -> None: "batch_transform(engine.state.batch)[1] must be None or one of " f"(numpy.ndarray, torch.Tensor) but is {type(show_labels).__name__}." ) - plot_2d_or_3d_image(show_labels[None], step, self._writer, 0, self.max_channels, self.max_frames, "input_1") + plot_2d_or_3d_image( + data=show_labels[None], + step=step, + writer=self._writer, + index=0, + max_channels=self.max_channels, + frame_dim=self.frame_dim, + max_frames=self.max_frames, + tag="input_1", + ) show_outputs = self.output_transform(engine.state.output)[self.index] if isinstance(show_outputs, torch.Tensor): @@ -349,6 +415,15 @@ def __call__(self, engine: Engine) -> None: "output_transform(engine.state.output) must be None or one of " f"(numpy.ndarray, torch.Tensor) but is {type(show_outputs).__name__}." ) - plot_2d_or_3d_image(show_outputs[None], step, self._writer, 0, self.max_channels, self.max_frames, "output") + plot_2d_or_3d_image( + data=show_outputs[None], + step=step, + writer=self._writer, + index=0, + max_channels=self.max_channels, + frame_dim=self.frame_dim, + max_frames=self.max_frames, + tag="output", + ) self._writer.flush() diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py deleted file mode 100644 index 83b5f56396..0000000000 --- a/monai/handlers/transform_inverter.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union - -import torch - -from monai.config import IgniteInfo, KeysCollection -from monai.engines.utils import CommonKeys, IterationEvents -from monai.transforms import Invertd, InvertibleTransform -from monai.utils import deprecated, ensure_tuple, ensure_tuple_rep, min_version, optional_import - -Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") -if TYPE_CHECKING: - from ignite.engine import Engine -else: - Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") - - -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="Please consider using `Invertd` transform instead.") -class TransformInverter: - """ - Ignite handler to automatically invert `transforms`. - It takes `engine.state.output` as the input data and uses the transforms information from `engine.state.batch`. - Expect both `engine.state.output` and `engine.state.batch` to be list of dictionaries data. - The inverted data is in-place saved back to `engine.state.output` with key: "{output_key}". - And the inverted meta dict will be stored in `engine.state.batch` - with key: "{meta_keys}" or "{key}_{meta_key_postfix}". - - .. deprecated:: 0.6.0 - Use :class:`monai.transforms.Invertd` instead. - - """ - - def __init__( - self, - transform: InvertibleTransform, - output_keys: KeysCollection = CommonKeys.PRED, - batch_keys: KeysCollection = CommonKeys.IMAGE, - meta_keys: Optional[KeysCollection] = None, - batch_meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", - nearest_interp: Union[bool, Sequence[bool]] = True, - to_tensor: Union[bool, Sequence[bool]] = True, - device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu", - post_func: Union[Callable, Sequence[Callable]] = lambda x: x, - num_workers: Optional[int] = 0, - ) -> None: - """ - Args: - transform: a callable data transform on input data. - output_keys: the key of expected data in `ignite.engine.output`, invert transforms on it. - it also can be a list of keys, will invert transform for each of them. - Default to "pred". it's in-place operation. - batch_keys: the key of input data in `ignite.engine.batch`. will get the applied transforms - for this input data, then invert them for the expected data with `output_keys`. - It can also be a list of keys, each matches to the `output_keys` data. default to "image". - meta_keys: explicitly indicate the key for the inverted meta data dictionary. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `{key}_{meta_key_postfix}`. - batch_meta_keys: the key of the meta data of input data in `ignite.engine.batch`, - will get the `affine`, `data_shape`, etc. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`. - meta data will also be inverted and stored in `meta_keys`. - meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to to fetch the - meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. - default is `meta_dict`, the meta data is a dictionary object. - For example, to handle orig_key `image`, read/write `affine` matrices from the - metadata `image_meta_dict` dictionary's `affine` field. - the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}". - nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms, - default to `True`. If `False`, use the same interpolation mode as the original transform. - it also can be a list of bool, each matches to the `output_keys` data. - to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. - it also can be a list of bool, each matches to the `output_keys` data. - device: if converted to Tensor, move the inverted results to target device before `post_func`, - default to "cpu", it also can be a list of string or `torch.device`, - each matches to the `output_keys` data. - post_func: post processing for the inverted data, should be a callable function. - it also can be a list of callable, each matches to the `output_keys` data. - - """ - self.inverter = Invertd( - keys=output_keys, - transform=transform, - orig_keys=batch_keys, - meta_keys=meta_keys, - orig_meta_keys=batch_meta_keys, - meta_key_postfix=meta_key_postfix, - nearest_interp=nearest_interp, - to_tensor=to_tensor, - device=device, - post_func=post_func, - ) - self.output_keys = ensure_tuple(output_keys) - self.meta_keys = ensure_tuple_rep(None, len(self.output_keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.output_keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as output_keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.output_keys)) - - def attach(self, engine: Engine) -> None: - """ - Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. - """ - engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) - - def __call__(self, engine: Engine) -> None: - """ - Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. - """ - if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list): - warnings.warn("inverter requires `engine.state.batch` and `engine.state.output` to be lists.") - else: - for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): - # combine `batch` and `output` to temporarily act as 1 dict for postprocessing - data = dict(b) - data.update(o) - ret = self.inverter(data) - - for output_key, meta_key, meta_key_postfix in zip( - self.output_keys, self.meta_keys, self.meta_key_postfix - ): - # save the inverted data into state.output - engine.state.output[i][output_key] = ret.get(output_key) - # save the inverted meta dict into state.batch - meta_key = meta_key or f"{output_key}_{meta_key_postfix}" - if meta_key in ret: - # FIXME: we save inverted meta dict into `batch` to be compatible with `SegmentationSaver` - # will deprecate both handlers soon - engine.state.batch[i][meta_key] = ret.get(meta_key) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 13f23c582a..23947d3054 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,13 +11,13 @@ import os from collections import OrderedDict -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union import numpy as np import torch -from monai.config import IgniteInfo, KeysCollection -from monai.utils import deprecated, ensure_tuple, get_torch_version_tuple, look_up_option, min_version, optional_import +from monai.config import IgniteInfo, KeysCollection, PathLike +from monai.utils import ensure_tuple, look_up_option, min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") if TYPE_CHECKING: @@ -25,14 +25,7 @@ else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") -__all__ = [ - "stopping_fn_from_metric", - "stopping_fn_from_loss", - "evenly_divisible_all_gather", - "string_list_all_gather", - "write_metrics_reports", - "from_engine", -] +__all__ = ["stopping_fn_from_metric", "stopping_fn_from_loss", "write_metrics_reports", "from_engine"] def stopping_fn_from_metric(metric_name: str): @@ -52,85 +45,13 @@ def stopping_fn_from_loss(): """ def stopping_fn(engine: Engine): - return -engine.state.output + return -engine.state.output # type:ignore return stopping_fn -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="The API had been moved to monai.utils module.") -def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: - """ - Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. - - Args: - data: source tensor to pad and execute all_gather in distributed data parallel. - - Note: - The input data on different ranks must have exactly same `dtype`. - - .. versionchanged:: 0.6.0 - The API had been moved to `monai.utils`. - - """ - if not isinstance(data, torch.Tensor): - raise ValueError("input data must be PyTorch Tensor.") - - if idist.get_world_size() <= 1: - return data - - # make sure the data is evenly-divisible on multi-GPUs - length = data.shape[0] - all_lens = idist.all_gather(length) - max_len = max(all_lens) - if length < max_len: - size = [max_len - length] + list(data.shape[1:]) - data = torch.cat([data, data.new_full(size, 0)], dim=0) - # all gather across all processes - data = idist.all_gather(data) - # delete the padding NaN items - return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0) - - -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="The API had been moved to monai.utils module.") -def string_list_all_gather(strings: List[str]) -> List[str]: - """ - Utility function for distributed data parallel to all gather a list of strings. - Note that if the item in `strings` is longer than 1024 chars, it will be truncated to 1024: - https://pytorch.org/ignite/v0.4.5/distributed.html#ignite.distributed.utils.all_gather. - - Args: - strings: a list of strings to all gather. - - .. versionchanged:: 0.6.0 - The API had been moved to `monai.utils`. - - """ - world_size = idist.get_world_size() - if world_size <= 1: - return strings - - result: List[List[str]] = [[] for _ in range(world_size)] - # get length of strings - length = len(strings) - all_lens = idist.all_gather(length) - max_len = max(all_lens) - # pad the item to make sure the same length - if length < max_len: - strings += ["" for _ in range(max_len - length)] - - if get_torch_version_tuple() <= (1, 6): - raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.") - - for s in strings: - gathered = idist.all_gather(s) - for i, g in enumerate(gathered): - if len(g) > 0: - result[i].append(g) - return [i for k in result for i in k] - - def write_metrics_reports( - save_dir: str, + save_dir: PathLike, images: Optional[Sequence[str]], metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], @@ -204,12 +125,12 @@ class mean median max 5percentile 95percentile notnans if summary_ops is not None: supported_ops = OrderedDict( { - "mean": lambda x: np.nanmean(x), - "median": lambda x: np.nanmedian(x), - "max": lambda x: np.nanmax(x), - "min": lambda x: np.nanmin(x), + "mean": np.nanmean, + "median": np.nanmedian, + "max": np.nanmax, + "min": np.nanmin, "90percentile": lambda x: np.nanpercentile(x[0], x[1]), - "std": lambda x: np.nanstd(x), + "std": np.nanstd, "notnans": lambda x: (~np.isnan(x)).sum(), } ) @@ -223,7 +144,7 @@ def _compute_op(op: str, d: np.ndarray): return c_op(d) threshold = int(op.split("percentile")[0]) - return supported_ops["90percentile"]((d, threshold)) + return supported_ops["90percentile"]((d, threshold)) # type: ignore with open(os.path.join(save_dir, f"{k}_summary.csv"), "w") as f: f.write(f"class{deli}{deli.join(ops)}\n") diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py index 6214461a4f..171c901fbb 100644 --- a/monai/handlers/validation_handler.py +++ b/monai/handlers/validation_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 030344728d..20d829297f 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index ecb2c2c178..c7b70e06ca 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,13 +42,7 @@ class Inferer(ABC): """ @abstractmethod - def __call__( - self, - inputs: torch.Tensor, - network: Callable[..., torch.Tensor], - *args: Any, - **kwargs: Any, - ): + def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any): """ Run inference on `inputs` with the `network` model. @@ -75,13 +69,7 @@ class SimpleInferer(Inferer): def __init__(self) -> None: Inferer.__init__(self) - def __call__( - self, - inputs: torch.Tensor, - network: Callable[..., torch.Tensor], - *args: Any, - **kwargs: Any, - ): + def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any): """Unified callable function API of Inferers. Args: @@ -161,11 +149,7 @@ def __init__( self.device = device def __call__( - self, - inputs: torch.Tensor, - network: Callable[..., torch.Tensor], - *args: Any, - **kwargs: Any, + self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any ) -> torch.Tensor: """ @@ -217,13 +201,7 @@ def __init__(self, cam_name: str, target_layers: str, class_idx: Optional[int] = self.args = args self.kwargs = kwargs - def __call__( # type: ignore - self, - inputs: torch.Tensor, - network: nn.Module, - *args: Any, - **kwargs: Any, - ): + def __call__(self, inputs: torch.Tensor, network: nn.Module, *args: Any, **kwargs: Any): # type: ignore """Unified callable function API of Inferers. Args: diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 0ca53529c7..b27e13eef6 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 1221cd3041..1922996fb6 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss from .dice import ( Dice, @@ -18,7 +19,6 @@ GeneralizedDiceLoss, GeneralizedWassersteinDiceLoss, MaskedDiceLoss, - dice, dice_ce, dice_focal, generalized_dice, diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py new file mode 100644 index 0000000000..aef596a492 --- /dev/null +++ b/monai/losses/contrastive.py @@ -0,0 +1,88 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.nn import functional as F +from torch.nn.modules.loss import _Loss + +from monai.utils import deprecated_arg + + +class ContrastiveLoss(_Loss): + + """ + Compute the Contrastive loss defined in: + + Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International + conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v119/chen20j.html) + + Adapted from: + https://github.com/Sara-Ahmed/SiT/blob/1aacd6adcd39b71efc903d16b4e9095b97dda76f/losses.py#L5 + + """ + + @deprecated_arg(name="reduction", since="0.8", msg_suffix="`reduction` is no longer supported.") + def __init__(self, temperature: float = 0.5, batch_size: int = 1, reduction="sum") -> None: + """ + Args: + temperature: Can be scaled between 0 and 1 for learning from negative samples, ideally set to 0.5. + batch_size: The number of samples. + + Raises: + ValueError: When an input of dimension length > 2 is passed + ValueError: When input and target are of different shapes + + .. deprecated:: 0.8.0 + + `reduction` is no longer supported. + + """ + super().__init__() + + self.batch_size = batch_size + self.temperature = temperature + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be B[F]. + target: the shape should be B[F]. + """ + if len(target.shape) > 2 or len(input.shape) > 2: + raise ValueError( + f"Either target or input has dimensions greater than 2 where target " + f"shape is ({target.shape}) and input shape is ({input.shape})" + ) + + if target.shape != input.shape: + raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") + + temperature_tensor = torch.tensor(self.temperature).to(input.device) + + norm_i = F.normalize(input, dim=1) + norm_j = F.normalize(target, dim=1) + + negatives_mask = ~torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=torch.bool) + negatives_mask = torch.tensor(negatives_mask, dtype=torch.float) + negatives_mask = torch.clone(torch.as_tensor(negatives_mask)).to(input.device) + + repr = torch.cat([norm_i, norm_j], dim=0) + sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2) + sim_ij = torch.diag(sim_matrix, self.batch_size) + sim_ji = torch.diag(sim_matrix, -self.batch_size) + + positives = torch.cat([sim_ij, sim_ji], dim=0) + nominator = torch.exp(positives / temperature_tensor) + denominator = negatives_mask * torch.exp(sim_matrix / temperature_tensor) + + loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1)) + + return torch.sum(loss_partial) / (2 * self.batch_size) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index d96fa1440a..0f5e263a53 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -52,12 +52,12 @@ class BendingEnergyLoss(_Loss): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__( - self, - reduction: Union[LossReduction, str] = LossReduction.MEAN, - ) -> None: + def __init__(self, normalize: bool = False, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None: """ Args: + normalize: + Whether to divide out spatial sizes in order to make the computation roughly + invariant to image scale (i.e. vector field sampling resolution). Defaults to False. reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. @@ -65,7 +65,8 @@ def __init__( - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. """ - super(BendingEnergyLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) + self.normalize = normalize def forward(self, pred: torch.Tensor) -> torch.Tensor: """ @@ -77,20 +78,35 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: """ if pred.ndim not in [3, 4, 5]: - raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") for i in range(pred.ndim - 2): if pred.shape[-i - 1] <= 4: - raise ValueError("all spatial dimensions must > 4, got pred of shape {pred.shape}") + raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") + if pred.shape[1] != pred.ndim - 2: + raise ValueError( + f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim-2}" + ) # first order gradient first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + # spatial dimensions in a shape suited for broadcasting below + if self.normalize: + spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) + energy = torch.tensor(0) for dim_1, g in enumerate(first_order_gradient): dim_1 += 2 - energy = spatial_gradient(g, dim_1) ** 2 + energy + if self.normalize: + g *= pred.shape[dim_1] / spatial_dims + energy = energy + (spatial_gradient(g, dim_1) * pred.shape[dim_1]) ** 2 + else: + energy = energy + spatial_gradient(g, dim_1) ** 2 for dim_2 in range(dim_1 + 1, pred.ndim): - energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy + if self.normalize: + energy = energy + 2 * (spatial_gradient(g, dim_2) * pred.shape[dim_2]) ** 2 + else: + energy = energy + 2 * spatial_gradient(g, dim_2) ** 2 if self.reduction == LossReduction.MEAN.value: energy = torch.mean(energy) # the batch and channel average diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 325c5300ea..7f2037ca54 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -27,16 +27,17 @@ class DiceLoss(_Loss): """ Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. - Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). - Axis N of `input` is expected to have logit predictions for each class rather than being image channels, - while the same axis of `target` can be 1 or N (one-hot format). The `smooth_nr` and `smooth_dr` parameters are - values added to the intersection and union components of the inter-over-union calculation to smooth results - respectively, these values should be small. The `include_background` class attribute can be set to False for - an instance of DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be - background. If the non-background segmentations are small compared to the total image size they can get - overwhelmed by the signal from the background so excluding it in such cases helps convergence. + The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). - Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric Medical Image Segmentation, 3DV, 2016. + Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input, + must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target` + can be 1 or N (one-hot format). + + The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of + the inter-over-union calculation to smooth results respectively, these values should be small. + + The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric + Medical Image Segmentation, 3DV, 2016. """ @@ -57,6 +58,8 @@ def __init__( """ Args: include_background: if False, channel index 0 (background category) is excluded from the calculation. + if the non-background segmentations are small compared to the total image size they can get overwhelmed + by the signal from the background so excluding it in such cases helps convergence. to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. sigmoid: if True, apply a sigmoid function to the prediction. softmax: if True, apply a softmax function to the prediction. @@ -111,6 +114,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: have different shapes. ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + Example: + >>> from monai.losses.dice import * # NOQA + >>> import torch + >>> from monai.losses.dice import DiceLoss + >>> B, C, H, W = 7, 5, 3, 2 + >>> input = torch.rand(B, C, H, W) + >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() + >>> target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> self = DiceLoss(reduction='none') + >>> loss = self(input, target) + >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape """ if self.sigmoid: input = torch.sigmoid(input) @@ -168,7 +182,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: f = torch.mean(f) # the batch and channel average elif self.reduction == LossReduction.SUM.value: f = torch.sum(f) # sum over the batch and channel dims - elif self.reduction != LossReduction.NONE.value: + elif self.reduction == LossReduction.NONE.value: + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2) + f = f.view(broadcast_shape) + else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return f @@ -234,7 +253,6 @@ def __init__( other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`. - squared_pred: use squared versions of targets and predictions in the denominator or not. w_type: {``"square"``, ``"simple"``, ``"uniform"``} Type of function to transform ground truth volume to a weight factor. Defaults to ``"square"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} @@ -335,15 +353,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: b[infs] = 0.0 b[infs] = torch.max(b) - f: torch.Tensor = 1.0 - (2.0 * (intersection * w).sum(0 if self.batch else 1) + self.smooth_nr) / ( - (denominator * w).sum(0 if self.batch else 1) + self.smooth_dr - ) + final_reduce_dim = 0 if self.batch else 1 + numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr + denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr + f: torch.Tensor = 1.0 - (numer / denom) if self.reduction == LossReduction.MEAN.value: f = torch.mean(f) # the batch and channel average elif self.reduction == LossReduction.SUM.value: f = torch.sum(f) # sum over the batch and channel dims - elif self.reduction != LossReduction.NONE.value: + elif self.reduction == LossReduction.NONE.value: + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2) + f = f.view(broadcast_shape) + else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return f @@ -419,7 +443,7 @@ def __init__( wass_loss(pred_score, grnd) # 0 """ - super(GeneralizedWassersteinDiceLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) if dist_matrix.shape[0] != dist_matrix.shape[1]: raise ValueError(f"dist_matrix must be C x C, got {dist_matrix.shape[0]} x {dist_matrix.shape[1]}.") @@ -478,7 +502,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: wass_dice_loss = torch.mean(wass_dice_loss) # the batch and channel average elif self.reduction == LossReduction.SUM.value: wass_dice_loss = torch.sum(wass_dice_loss) # sum over the batch and channel dims - elif self.reduction != LossReduction.NONE.value: + elif self.reduction == LossReduction.NONE.value: + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + broadcast_shape = input.shape[0:2] + (1,) * (len(input.shape) - 2) + wass_dice_loss = wass_dice_loss.view(broadcast_shape) + else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return wass_dice_loss @@ -536,10 +565,7 @@ def _compute_generalized_true_positive( flat_target_extended = torch.unsqueeze(flat_target, dim=1) alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) - return torch.sum( - alpha_extended * (1.0 - wasserstein_distance_map), - dim=[1, 2], - ) + return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=[1, 2]) def _compute_denominator( self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor @@ -556,10 +582,7 @@ def _compute_denominator( flat_target_extended = torch.unsqueeze(flat_target, dim=1) alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) - return torch.sum( - alpha_extended * (2.0 - wasserstein_distance_map), - dim=[1, 2], - ) + return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=[1, 2]) def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor: """ @@ -657,10 +680,7 @@ def __init__( smooth_dr=smooth_dr, batch=batch, ) - self.cross_entropy = nn.CrossEntropyLoss( - weight=ce_weight, - reduction=reduction, - ) + self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction=reduction) if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.") if lambda_ce < 0.0: @@ -815,11 +835,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: dice_loss = self.dice(input, target) focal_loss = self.focal(input, target) total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss - return total_loss -dice = Dice = DiceLoss +Dice = DiceLoss dice_ce = DiceCELoss dice_focal = DiceFocalLoss generalized_dice = GeneralizedDiceLoss diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index b4b3698e5b..bf31682748 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,11 +22,44 @@ class FocalLoss(_Loss): """ + FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from + high confidence correct predictions. + Reimplementation of the Focal Loss (with a build-in sigmoid activation) described in: - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017 - "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy", Zhu et al., Medical Physics 2018 + + Example: + >>> import torch + >>> from monai.losses import FocalLoss + >>> from torch.nn import BCEWithLogitsLoss + >>> shape = B, N, *DIMS = 2, 3, 5, 7, 11 + >>> input = torch.rand(*shape) + >>> target = torch.rand(*shape) + >>> # Demonstrate equivalence to BCE when gamma=0 + >>> fl_g0_criterion = FocalLoss(reduction='none', gamma=0) + >>> fl_g0_loss = fl_g0_criterion(input, target) + >>> bce_criterion = BCEWithLogitsLoss(reduction='none') + >>> bce_loss = bce_criterion(input, target) + >>> assert torch.allclose(fl_g0_loss, bce_loss) + >>> # Demonstrate "focus" by setting gamma > 0. + >>> fl_g2_criterion = FocalLoss(reduction='none', gamma=2) + >>> fl_g2_loss = fl_g2_criterion(input, target) + >>> # Mark easy and hard cases + >>> is_easy = (target > 0.7) & (input > 0.7) + >>> is_hard = (target > 0.7) & (input < 0.3) + >>> easy_loss_g0 = fl_g0_loss[is_easy].mean() + >>> hard_loss_g0 = fl_g0_loss[is_hard].mean() + >>> easy_loss_g2 = fl_g2_loss[is_easy].mean() + >>> hard_loss_g2 = fl_g2_loss[is_hard].mean() + >>> # Gamma > 0 causes the loss function to "focus" on the hard + >>> # cases. IE, easy cases are downweighted, so hard cases + >>> # receive a higher proportion of the loss. + >>> hard_to_easy_ratio_g2 = hard_loss_g2 / easy_loss_g2 + >>> hard_to_easy_ratio_g0 = hard_loss_g0 / easy_loss_g0 + >>> assert hard_to_easy_ratio_g2 > hard_to_easy_ratio_g0 """ def __init__( @@ -56,18 +89,14 @@ def __init__( - ``"sum"``: the output will be summed. Example: - .. code-block:: python - - import torch - from monai.losses import FocalLoss - - pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32) - grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64) - fl = FocalLoss(to_onehot_y=True) - fl(pred, grnd) - + >>> import torch + >>> from monai.losses import FocalLoss + >>> pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32) + >>> grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64) + >>> fl = FocalLoss(to_onehot_y=True) + >>> fl(pred, grnd) """ - super(FocalLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) self.include_background = include_background self.to_onehot_y = to_onehot_y self.gamma = gamma @@ -147,12 +176,25 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # Compute the loss mini-batch. # (1-p_t)^gamma * log(p_t) with reduced chance of overflow p = F.logsigmoid(-i * (t * 2.0 - 1.0)) - loss = torch.mean((p * self.gamma).exp() * ce, dim=-1) + flat_loss: torch.Tensor = (p * self.gamma).exp() * ce + + # Previously there was a mean over the last dimension, which did not + # return a compatible BCE loss. To maintain backwards compatible + # behavior we have a flag that performs this extra step, disable or + # parameterize if necessary. (Or justify why the mean should be there) + average_spatial_dims = True if self.reduction == LossReduction.SUM.value: - return loss.sum() - if self.reduction == LossReduction.NONE.value: - return loss - if self.reduction == LossReduction.MEAN.value: - return loss.mean() - raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + if average_spatial_dims: + flat_loss = flat_loss.mean(dim=-1) + loss = flat_loss.sum() + elif self.reduction == LossReduction.MEAN.value: + if average_spatial_dims: + flat_loss = flat_loss.mean(dim=-1) + loss = flat_loss.mean() + elif self.reduction == LossReduction.NONE.value: + spacetime_dims = input.shape[2:] + loss = flat_loss.reshape([b, n] + list(spacetime_dims)) + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + return loss diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index eed5808aa3..b527522cd7 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -8,14 +8,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss from monai.networks.layers import gaussian_1d, separable_filtering -from monai.utils import LossReduction +from monai.utils import LossReduction, deprecated_arg +from monai.utils.module import look_up_option def make_rectangular_kernel(kernel_size: int) -> torch.Tensor: @@ -59,18 +60,20 @@ class LocalNormalizedCrossCorrelationLoss(_Loss): DeepReg (https://github.com/DeepRegNet/DeepReg) """ + @deprecated_arg(name="ndim", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - ndim: int = 3, + spatial_dims: int = 3, kernel_size: int = 3, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, + ndim: Optional[int] = None, ) -> None: """ Args: - ndim: number of spatial ndimensions, {``1``, ``2``, ``3``}. Defaults to 3. + spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3. kernel_size: kernel spatial size, must be odd. kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} @@ -81,22 +84,24 @@ def __init__( - ``"sum"``: the output will be summed. smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. + + .. deprecated:: 0.6.0 + ``ndim`` is deprecated, use ``spatial_dims``. """ - super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) - self.ndim = ndim - if self.ndim not in [1, 2, 3]: + if ndim is not None: + spatial_dims = ndim + self.ndim = spatial_dims + if self.ndim not in {1, 2, 3}: raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") self.kernel_size = kernel_size if self.kernel_size % 2 == 0: raise ValueError(f"kernel_size must be odd, got {self.kernel_size}") - if kernel_type not in kernel_dict.keys(): - raise ValueError( - f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' - ) - self.kernel = kernel_dict[kernel_type](self.kernel_size) + _kernel = look_up_option(kernel_type, kernel_dict) + self.kernel = _kernel(self.kernel_size) self.kernel_vol = self.get_kernel_vol() self.smooth_nr = float(smooth_nr) @@ -170,6 +175,7 @@ class GlobalMutualInformationLoss(_Loss): def __init__( self, + kernel_type: str = "gaussian", num_bins: int = 23, sigma_ratio: float = 0.5, reduction: Union[LossReduction, str] = LossReduction.MEAN, @@ -178,6 +184,19 @@ def __init__( ) -> None: """ Args: + kernel_type: {``"gaussian"``, ``"b-spline"``} + ``"gaussian"``: adapted from DeepReg + Reference: https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1. + ``"b-spline"``: based on the method of Mattes et al [1,2] and adapted from ITK + References: + [1] "Nonrigid multimodality image registration" + D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank + Medical Imaging 2001: Image Processing, 2001, pp. 1609-1620. + [2] "PET-CT Image Registration in the Chest Using Free-form Deformations" + D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank + IEEE Transactions in Medical Imaging. Vol.22, No.1, + January 2003. pp.120-128. + num_bins: number of bins for intensity sigma_ratio: a hyper param for gaussian function reduction: {``"none"``, ``"mean"``, ``"sum"``} @@ -189,25 +208,99 @@ def __init__( smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. """ - super(GlobalMutualInformationLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) if num_bins <= 0: raise ValueError("num_bins must > 0, got {num_bins}") bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio - self.preterm = 1 / (2 * sigma ** 2) - self.bin_centers = bin_centers[None, None, ...] + self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"]) + self.num_bins = num_bins + self.kernel_type = kernel_type + if self.kernel_type == "gaussian": + self.preterm = 1 / (2 * sigma ** 2) + self.bin_centers = bin_centers[None, None, ...] self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) - def parzen_windowing(self, pred: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def parzen_windowing( + self, pred: torch.Tensor, target: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if self.kernel_type == "gaussian": + pred_weight, pred_probability = self.parzen_windowing_gaussian(pred) + target_weight, target_probability = self.parzen_windowing_gaussian(target) + elif self.kernel_type == "b-spline": + # a third order BSpline kernel is used for the pred image intensity PDF. + pred_weight, pred_probability = self.parzen_windowing_b_spline(pred, order=3) + # a zero order (box car) BSpline kernel is used for the target image intensity PDF. + target_weight, target_probability = self.parzen_windowing_b_spline(target, order=0) + else: + raise ValueError + return pred_weight, pred_probability, target_weight, target_probability + + def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> Tuple[torch.Tensor, torch.Tensor]: """ + Parzen windowing with b-spline kernel (adapted from ITK) + Args: - pred: the shape should be B[NDHW]. + img: the shape should be B[NDHW]. + order: int. + """ + + # Compute binsize for the histograms. + # + # The binsize for the image intensities needs to be adjusted so that + # we can avoid dealing with boundary conditions using the cubic + # spline as the Parzen window. We do this by increasing the size + # of the bins so that the joint histogram becomes "padded" at the + # borders. Because we are changing the binsize, + # we also need to shift the minimum by the padded amount in order to + # avoid minimum values filling in our padded region. + # + # Note that there can still be non-zero bin values in the padded region, + # it's just that these bins will never be a central bin for the Parzen + # window. + _max, _min = torch.max(img), torch.min(img) + padding = 2 + bin_size = (_max - _min) / (self.num_bins - 2 * padding) + norm_min = torch.div(_min, bin_size) - padding + + # assign bin/window index to each voxel + window_term = torch.div(img, bin_size) - norm_min # B[NDHW] + # make sure the extreme values are in valid (non-padded) bins + window_term = torch.clamp(window_term, padding, self.num_bins - padding - 1) # B[NDHW] + window_term = window_term.reshape(window_term.shape[0], -1, 1) # (batch, num_sample, 1) + bins = torch.arange(self.num_bins, device=window_term.device).reshape(1, 1, -1) # (1, 1, num_bins) + sample_bin_matrix = torch.abs(bins - window_term) # (batch, num_sample, num_bins) + + # b-spleen kernel + # (4 - 6 * abs ** 2 + 3 * abs ** 3) / 6 when 0 <= abs < 1 + # (2 - abs) ** 3 / 6 when 1 <= abs < 2 + weight = torch.zeros_like(sample_bin_matrix, dtype=torch.float) # (batch, num_sample, num_bins) + if order == 0: + weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5 + elif order == 3: + weight = ( + weight + (4 - 6 * sample_bin_matrix ** 2 + 3 * sample_bin_matrix ** 3) * (sample_bin_matrix < 1) / 6 + ) + weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6 + else: + raise ValueError(f"Do not support b-spline {order}-order parzen windowing") + + weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bins) + probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bins) + return weight, probability + + def parzen_windowing_gaussian(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parzen windowing with gaussian kernel (adapted from DeepReg implementation) + Note: the input is expected to range between 0 and 1 + Args: + img: the shape should be B[NDHW]. """ - pred = torch.clamp(pred, 0, 1) - pred = pred.reshape(pred.shape[0], -1, 1) # (batch, num_sample, 1) + img = torch.clamp(img, 0, 1) + img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1) weight = torch.exp( - -self.preterm.to(pred) * (pred - self.bin_centers.to(pred)) ** 2 + -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2 ) # (batch, num_sample, num_bin) weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin) probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin) @@ -223,11 +316,10 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ if target.shape != pred.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") - wa, pa = self.parzen_windowing(pred) # (batch, num_sample, num_bin), (batch, 1, num_bin) - wb, pb = self.parzen_windowing(target) # (batch, num_sample, num_bin), (batch, 1, num_bin) - pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1]) # (batch, num_bins, num_bins) + wa, pa, wb, pb = self.parzen_windowing(pred, target) # (batch, num_sample, num_bin), (batch, 1, num_bin) - papb = torch.bmm(pa.permute(0, 2, 1), pb) # (batch, num_bins, num_bins) + pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins) + papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins) mi = torch.sum( pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2) ) # (batch) diff --git a/monai/losses/multi_scale.py b/monai/losses/multi_scale.py index 6f9326420b..5e80af30bc 100644 --- a/monai/losses/multi_scale.py +++ b/monai/losses/multi_scale.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,12 +21,7 @@ def make_gaussian_kernel(sigma: int) -> torch.Tensor: if sigma <= 0: raise ValueError(f"expecting positive sigma, got sigma={sigma}") - return gaussian_1d( - sigma=torch.tensor(sigma), - truncated=3, - approx="sampled", - normalize=False, - ) + return gaussian_1d(sigma=torch.tensor(sigma), truncated=3, approx="sampled", normalize=False) def make_cauchy_kernel(sigma: int) -> torch.Tensor: @@ -39,10 +34,7 @@ def make_cauchy_kernel(sigma: int) -> torch.Tensor: return k -kernel_fn_dict = { - "gaussian": make_gaussian_kernel, - "cauchy": make_cauchy_kernel, -} +kernel_fn_dict = {"gaussian": make_gaussian_kernel, "cauchy": make_cauchy_kernel} class MultiScaleLoss(_Loss): @@ -67,7 +59,7 @@ def __init__( scales: list of scalars or None, if None, do not apply any scaling. kernel: gaussian or cauchy. """ - super(MultiScaleLoss, self).__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction).value) if kernel not in kernel_fn_dict.keys(): raise ValueError(f"got unsupported kernel type: {kernel}", "only support gaussian and cauchy") self.kernel_fn = kernel_fn_dict[kernel] diff --git a/monai/losses/spatial_mask.py b/monai/losses/spatial_mask.py index 387300e507..aa232f882e 100644 --- a/monai/losses/spatial_mask.py +++ b/monai/losses/spatial_mask.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 1cc0e1d8d7..ee6d7d933b 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index c2197bdf2a..d18c20f7b2 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix +from .cumulative_average import CumulativeAverage from .froc import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance from .meandice import DiceMetric, compute_meandice diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 9568cf6028..dd1f81a4b4 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -47,8 +47,9 @@ class ConfusionMatrixMetric(CumulativeIterationMetric): returned with the same order as input names when calling the class. compute_sample: when reducing, if ``True``, each sample's metric will be computed based on each confusion matrix first. if ``False``, compute reduction on the confusion matrices first, defaults to ``False``. - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns [(metric, not_nans), ...]. If False, aggregate() returns [metric, ...]. Here `not_nans` count the number of not nans for True Positive, False Positive, True Negative and False Negative. @@ -99,11 +100,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor warnings.warn("As for classification task, compute_sample should be False.") self.compute_sample = False - return get_confusion_matrix( - y_pred=y_pred, - y=y, - include_background=self.include_background, - ) + return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background) def aggregate(self): # type: ignore """ @@ -129,11 +126,7 @@ def aggregate(self): # type: ignore return results -def get_confusion_matrix( - y_pred: torch.Tensor, - y: torch.Tensor, - include_background: bool = True, -): +def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True): """ Compute confusion matrix. A tensor with the shape [BC4] will be returned. Where, the third dimension represents the number of true positive, false positive, true negative and false negative values for @@ -153,10 +146,7 @@ def get_confusion_matrix( """ if not include_background: - y_pred, y = ignore_background( - y_pred=y_pred, - y=y, - ) + y_pred, y = ignore_background(y_pred=y_pred, y=y) y = y.float() y_pred = y_pred.float() diff --git a/monai/metrics/cumulative_average.py b/monai/metrics/cumulative_average.py new file mode 100644 index 0000000000..090d65a44c --- /dev/null +++ b/monai/metrics/cumulative_average.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from monai.transforms import isnan +from monai.utils import convert_data_type + +from .metric import Cumulative + + +class CumulativeAverage(Cumulative): + """ + Cumulatively record data value and aggregate for the average value. + It supports single class or multi-class data, for example, + value can be 0.44 (a loss value) or [0.3, 0.4] (metrics of two classes). + It also supports distributed data parallel, sync data when aggregating. + For example, recording loss values and compute the overall average value in every 5 iterations: + + .. code-block:: python + + average = CumulativeAverage() + for i, d in enumerate(dataloader): + loss = ... + average.append(loss) + if i % 5 == 0: + print(f"cumulative average of loss: {average.aggregate()}") + average.reset() + + """ + + def __init__(self) -> None: + super().__init__() + self.sum = None + self.not_nans = None + + def reset(self): + """ + Reset all the running status, including buffers, sum, not nans count, etc. + + """ + super().reset() + self.sum = None + self.not_nans = None + + def aggregate(self): # type: ignore + """ + Sync data from all the ranks and compute the average value with previous sum value. + + """ + data = self.get_buffer() + + # compute SUM across the batch dimension + nans = isnan(data) + not_nans = convert_data_type((~nans), dtype=torch.float32)[0].sum(0) + data[nans] = 0 + f = data.sum(0) + + # clear the buffer for next update + super().reset() + self.sum = f if self.sum is None else (self.sum + f) + self.not_nans = not_nans if self.not_nans is None else (self.not_nans + not_nans) + + return self.sum / self.not_nans diff --git a/monai/metrics/froc.py b/monai/metrics/froc.py index faebbbf7a6..1c3e64579d 100644 --- a/monai/metrics/froc.py +++ b/monai/metrics/froc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -96,7 +96,7 @@ def compute_froc_curve_data( num_images: the number of images under evaluation. """ - if type(fp_probs) is not type(tp_probs): + if not isinstance(fp_probs, type(tp_probs)): raise AssertionError("fp and tp probs should have same type.") if isinstance(fp_probs, torch.Tensor): fp_probs = fp_probs.detach().cpu().numpy() @@ -116,9 +116,7 @@ def compute_froc_curve_data( def compute_froc_score( - fps_per_image: np.ndarray, - total_sensitivity: np.ndarray, - eval_thresholds: Tuple = (0.25, 0.5, 1, 2, 4, 8), + fps_per_image: np.ndarray, total_sensitivity: np.ndarray, eval_thresholds: Tuple = (0.25, 0.5, 1, 2, 4, 8) ): """ This function is modified from the official evaluation code of diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 12f3b49d32..082311aa67 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,9 +42,9 @@ class HausdorffDistanceMetric(CumulativeIterationMetric): percentile of the Hausdorff Distance rather than the maximum result will be achieved. Defaults to ``None``. directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. @@ -141,10 +141,7 @@ def compute_hausdorff_distance( """ if not include_background: - y_pred, y = ignore_background( - y_pred=y_pred, - y=y, - ) + y_pred, y = ignore_background(y_pred=y_pred, y=y) if isinstance(y, torch.Tensor): y = y.float() if isinstance(y_pred, torch.Tensor): @@ -172,10 +169,7 @@ def compute_hausdorff_distance( def compute_percent_hausdorff_distance( - edges_pred: np.ndarray, - edges_gt: np.ndarray, - distance_metric: str = "euclidean", - percentile: Optional[float] = None, + edges_pred: np.ndarray, edges_gt: np.ndarray, distance_metric: str = "euclidean", percentile: Optional[float] = None ): """ This function is used to compute the directed Hausdorff distance. diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 226c106f7e..1e6065b59c 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -35,9 +35,9 @@ class DiceMetric(CumulativeIterationMetric): Args: include_background: whether to skip Dice computation on the first channel of the predicted output. Defaults to ``True``. - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. @@ -77,11 +77,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor if dims < 3: raise ValueError("y_pred should have at least three dimensions.") # compute dice (BxC) for each channel for each batch - return compute_meandice( - y_pred=y_pred, - y=y, - include_background=self.include_background, - ) + return compute_meandice(y_pred=y_pred, y=y, include_background=self.include_background) def aggregate(self): # type: ignore """ @@ -97,11 +93,7 @@ def aggregate(self): # type: ignore return (f, not_nans) if self.get_not_nans else f -def compute_meandice( - y_pred: torch.Tensor, - y: torch.Tensor, - include_background: bool = True, -) -> torch.Tensor: +def compute_meandice(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor: """Computes Dice score metric from full size Tensor and collects average. Args: @@ -122,10 +114,7 @@ def compute_meandice( """ if not include_background: - y_pred, y = ignore_background( - y_pred=y_pred, - y=y, - ) + y_pred, y = ignore_background(y_pred=y_pred, y=y) y = y.float() y_pred = y_pred.float() @@ -142,8 +131,4 @@ def compute_meandice( y_pred_o = torch.sum(y_pred, dim=reduce_axis) denominator = y_o + y_pred_o - return torch.where( - y_o > 0, - (2.0 * intersection) / denominator, - torch.tensor(float("nan"), device=y_o.device), - ) + return torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device)) diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py index bb4aa7c343..60dcd0b52d 100644 --- a/monai/metrics/metric.py +++ b/monai/metrics/metric.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,121 +15,168 @@ import torch from monai.config import TensorOrList -from monai.utils import evenly_divisible_all_gather +from monai.utils import convert_data_type, evenly_divisible_all_gather + +__all__ = ["Metric", "IterationMetric", "Cumulative", "CumulativeIterationMetric"] class Metric(ABC): """ - Base class of all Metrics interface. - `__call__` is designed to execute metric computation. + Base class for metric computation for evaluating the performance of a model. + `__call__` is designed to execute the computation. """ @abstractmethod - def __call__(self, *args: Any, **kwds: Any): + def __call__(self, *args: Any, **kwargs: Any): """ - API to execute the metric computation. - + This method should take raw model outputs as inputs, and return values that measure the models' quality. """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") class IterationMetric(Metric): """ - Base class of Metrics interface for computation on a batch of tensors, usually the data of 1 iteration. - `__call__` is supposed to compute independent logic for several samples of `y_pred` and `y`(optional). - Usually, subclass only needs to implement the `_compute_tensor` function for computation process. - The input data shape should be `list of channel-first tensors` or a `batch-first tensor`. + Base class for metrics computation at the iteration level, that is, on a min-batch of samples + usually using the model outcome of one iteration. + + `__call__` is designed to handle `y_pred` and `y` (optional) in torch tensors or a list/tuple of tensors. + Subclasses typically implement the `_compute_tensor` function for the actual tensor computation logic. """ def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): # type: ignore """ - Execute basic computation for model prediction and ground truth. - It can support both `list of channel-first Tensor` and `batch-first Tensor`. - And users can execute on every batch of data, then accumulate the results, or - accumulate the original `y_pred` and `y`, then execute on the accumulated data. + Execute basic computation for model prediction `y_pred` and ground truth `y` (optional). + It supports inputs of a list of "channel-first" Tensor and a "batch-first" Tensor. Args: - y_pred: the model prediction data to compute, must be a list of `channel-first` Tensor + y_pred: the raw model prediction data at one iteration, must be a list of `channel-first` Tensor or a `batch-first` Tensor. y: the ground truth to compute, must be a list of `channel-first` Tensor or a `batch-first` Tensor. + Returns: + The computed metric values at the iteration level. + The output shape could be a `batch-first` tensor or a list of `batch-first` tensors. + When it's a list of tensors, each item in the list can represent a specific type of metric. + """ ret: TensorOrList + # handling a list of channel-first data if isinstance(y_pred, (list, tuple)) or isinstance(y, (list, tuple)): - # if y_pred or y is a list of channel-first data, add batch dim and compute metric - ret = self._compute_list(y_pred, y) - elif isinstance(y_pred, torch.Tensor): - y_ = y.detach() if y is not None and isinstance(y, torch.Tensor) else None - ret = self._compute_tensor(y_pred.detach(), y_) - else: - raise ValueError("y_pred or y must be a list of `channel-first` Tensors or a `batch-first` Tensor.") - - return ret + return self._compute_list(y_pred, y) + # handling a single batch-first data + if isinstance(y_pred, torch.Tensor): + y_ = y.detach() if isinstance(y, torch.Tensor) else None + return self._compute_tensor(y_pred.detach(), y_) + raise ValueError("y_pred or y must be a list/tuple of `channel-first` Tensors or a `batch-first` Tensor.") def _compute_list(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): """ - Excute the computation for the y_pred and y items of a iteration, the data is in the list shape. - Will concat the results to guarantee the output shape of ret is BCHW[D], otherwise it's list of batch-first, - which is against our principle that data in metrics should be BCHW[D] or list of channel-first. - Note: subclass may enhance the operation with multi-threads to accelerate. + Execute the metric computation for `y_pred` and `y` in a list of "channel-first" tensors. + + The return value is a "batch-first" tensor, or a list of "batch-first" tensors. + When it's a list of tensors, each item in the list can represent a specific type of metric values. + + For example, `self._compute_tensor` may be implemented as returning a list of `batch_size` items, + where each item is a tuple of three values `tp`, `fp`, `fn` for true positives, false positives, + and false negatives respectively. This function will return a list of three items, + (`tp_batched`, `fp_batched`, `fn_batched`), where each item is a `batch_size`-length tensor. + Note: subclass may enhance the operation to have multi-thread support. """ - ret: TensorOrList if y is not None: ret = [self._compute_tensor(p.detach().unsqueeze(0), y_.detach().unsqueeze(0)) for p, y_ in zip(y_pred, y)] else: ret = [self._compute_tensor(p_.detach().unsqueeze(0), None) for p_ in y_pred] - # concat the list of results - if isinstance(ret[0], torch.Tensor): - ret = torch.cat(ret, dim=0) - elif isinstance(ret[0], (list, tuple)) and all(isinstance(i, torch.Tensor) for i in ret[0]): - # if _compute_tensor() returned not only 1 Tensor, concat them separately - ret = [torch.cat([k[i] for k in ret], dim=0) for i in range(len(ret[0]))] + # concat the list of results (e.g. a batch of evaluation scores) + if isinstance(ret[0], torch.Tensor): + return torch.cat(ret, dim=0) + # the result is a list of sequence of tensors (e.g. a batch of multi-class results) + if isinstance(ret[0], (list, tuple)) and all(isinstance(i, torch.Tensor) for i in ret[0]): + return [torch.cat(batch_i, dim=0) for batch_i in zip(*ret)] return ret @abstractmethod def _compute_tensor(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None): """ - computation logic for the y_pred and y of a iteration, the data should be `batch-first` Tensors. - Every subclass metric should implement its own computation logic according to its algorithm. - + Computation logic for `y_pred` and `y` of an iteration, the data should be "batch-first" Tensors. + A subclass should implement its own computation logic. + The return value is usually a "batch_first" tensor, or a list of "batch_first" tensors. """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class Cumulative(ABC): +class Cumulative: """ Utility class for the typical cumulative computation process based on PyTorch Tensors. - It cumulates tensors in the buffer, then sync across distributed ranks and aggregate. + It provides interfaces to accumulate values in the local buffers, synchronize buffers across distributed nodes, + and aggregate the buffered values. + + In multi-processing, PyTorch programs usually distribute data to multiple nodes. Each node runs with a subset + of the data, adds values to its local buffers. Calling `get_buffer` could gather all the results and + `aggregate` can further handle the results to generate the final outcomes. - To speed up computation with multi-processing, PyTorch programs usually split data to distributed ranks - by `DistributedSampler` before an epoch, every rank then computes only based on its own data part and - `add` to the buffers in its process. Eventually, sync the values of all ranks to compute the final results. + Users can implement their own `aggregate` method to handle the results, + using `get_buffer` to get the buffered contents. Note: the data list should have the same length every time calling `add()` in a round, it will automatically create buffers according to the length of data list. - Typically, this class is expected to execute the steps referring to below examples:: + Typically, this class is expected to execute the following steps: + + .. code-block:: python + + from monai.metrics import Cumulative + + c = Cumulative() + c.append(1) # adds a value + c.extend([2, 3]) # adds a batch of values + c.extend([4, 5, 6]) # adds a batch of values + print(c.get_buffer()) # tensor([1, 2, 3, 4, 5, 6]) + print(len(c)) # 6 + c.reset() + print(len(c)) # 0 + + The following is an example of maintaining two internal buffers: + + .. code-block:: python + + from monai.metrics import Cumulative + + c = Cumulative() + c.append(1, 2) # adds a value to two buffers respectively + c.extend([3, 4], [5, 6]) # adds batches of values + print(c.get_buffer()) # [tensor([1, 3, 4]), tensor([2, 5, 6])] + print(len(c)) - cum = Cumulative() - cum.add(x, y) - cum.add(a, b) - cum.add(c, d) - cum.aggregate() - result = cum.get_buffer() - cum.reset() + The following is an example of extending with variable length data: + + .. code-block:: python + + import torch + from monai.metrics import Cumulative + + c = Cumulative() + c.extend(torch.zeros((8, 2)), torch.zeros((6, 2))) # adds batches + c.append(torch.zeros((2, ))) # adds a value + print(c.get_buffer()) # [torch.zeros((9, 2)), torch.zeros((6, 2))] + print(len(c)) """ def __init__(self): - self.buffer_num: int = 0 + """ + Initialize the internal buffers. + `self._buffers` are local buffers, they are not usually used directly. + `self._sync_buffers` are the buffers with all the results across all the nodes. + """ self._buffers: Optional[List[List[torch.Tensor]]] = None self._synced_tensors: Optional[List[Optional[torch.Tensor]]] = None self._synced: bool = False + self.reset() def reset(self): """ @@ -140,33 +187,57 @@ def reset(self): self._synced_tensors = None self._synced = False - def add(self, *data: torch.Tensor): + def extend(self, *data) -> None: + """ + Extend the local buffers with new ("batch-first") data. + A buffer will be allocated for each `data` item. + Compared with `self.append`, this method adds a "batch" of data to the local buffers. + + Args: + data: each item can be a "batch-first" tensor or a list of "channel-first" tensors. + they will be concatenated at the 0-th dimension when `get_buffer()` is called. + """ + if self._buffers is None: + self._buffers = [[] for _ in data] + for b, d in zip(self._buffers, data): + # converting to pytorch tensors so that we can use the distributed API + d_t: torch.Tensor + d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True) # type: ignore + try: + b.extend([x[0] for x in torch.split(d_t, 1, dim=0)]) + except (AttributeError, IndexError, RuntimeError) as e: + raise TypeError( + f"{e}. `data` should be a batch-first tensor or" + f" a list of channel-first tensors, got {type(d_t)}" + ) from e + self._synced = False + + def append(self, *data) -> None: """ - Add samples to the cumulative buffers. + Add samples to the local cumulative buffers. + A buffer will be allocated for each `data` item. + Compared with `self.extend`, this method adds a single sample (instead + of a "batch") to the local buffers. Args: - data: list of input tensor, make sure the input data order is always the same in a round. - every item of data will be added to the corresponding buffer. + data: each item will be converted into a torch tensor. + they will be stacked at the 0-th dim with a new dimension when `get_buffer()` is called. """ - data_len = len(data) if self._buffers is None: - self._buffers = [[] for _ in range(data_len)] - elif len(self._buffers) != data_len: - raise ValueError(f"data length: {data_len} doesn't match buffers length: {len(self._buffers)}.") - if self._synced_tensors is None: - self._synced_tensors = [None for _ in range(data_len)] - - for i, d in enumerate(data): - if not isinstance(d, torch.Tensor): - raise ValueError(f"the data to cumulate in a buffer must be PyTorch Tensor, but got: {type(d)}.") - self._buffers[i].append(d) + self._buffers = [[] for _ in data] + for b, d in zip(self._buffers, data): + # converting to pytorch tensors so that we can use the distributed API + d_t: torch.Tensor + d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True) # type: ignore + b.append(d_t) self._synced = False @abstractmethod - def aggregate(self, *args: Any, **kwds: Any): + def aggregate(self, *args: Any, **kwargs: Any): """ - Aggregate final results based on the buffers. + Aggregate final results based on the gathered buffers. + This method is expected to use `get_buffer` to gather the local buffer contents. """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -174,28 +245,67 @@ def aggregate(self, *args: Any, **kwds: Any): def _sync(self): """ All gather the buffers across distributed ranks for aggregating. - Every buffer will be concatenated as a PyTorch Tensor. + Each buffer will be concatenated as a PyTorch Tensor. """ - self._synced_tensors = [evenly_divisible_all_gather(torch.cat(b, dim=0), concat=True) for b in self._buffers] + if self._synced or self._buffers is None: + return + try: + self._synced_tensors = [ + evenly_divisible_all_gather(torch.stack(b, dim=0), concat=True) for b in self._buffers + ] + except (RuntimeError, TypeError, ValueError) as e: + raise TypeError(f"{e}. unable to sync buffer contents: {self._buffers}.") from e self._synced = True + def __len__(self): + """ + Return the length of the largest buffer. + Note that the method will trigger synchronization of the local buffers. + """ + self._sync() + if not self._synced_tensors: + return 0 + return max(len(x) for x in self._synced_tensors) + def get_buffer(self): """ - Get the synced buffers list. + Get the synchronized list of buffers. A typical usage is to generate the metrics report based on the raw metric details. + Each buffer is a PyTorch Tensor. """ - if not self._synced: - self._sync() + self._sync() return self._synced_tensors[0] if len(self._synced_tensors) == 1 else self._synced_tensors class CumulativeIterationMetric(Cumulative, IterationMetric): """ - Base class of cumulative metric which computes on batch data of every iteration and aggregate. - Typically, it computes some intermediate results for every iteration, cumulates in buffers, - then syncs across all the distributed ranks and aggregates for the final result when epoch completed. + Base class of cumulative metric which collects metrics on each mini-batch data at the iteration level. + + Typically, it computes some intermediate results for each iteration, adds them to the buffers, + then the buffer contents could be gathered and aggregated for the final result when epoch completed. + + For example, `MeanDice` inherits this class and the usage is as follows: + + .. code-block:: python + + dice_metric = DiceMetric(include_background=True, reduction="mean") + + for val_data in val_loader: + val_outputs = model(val_data["img"]) + val_outputs = [postprocessing_transform(i) for i in decollate_batch(val_outputs)] + # compute metric for current iteration + dice_metric(y_pred=val_outputs, y=val_data["seg"]) # callable to add metric to the buffer + + # aggregate the final mean dice result + metric = dice_metric.aggregate().item() + + # reset the status for next computation round + dice_metric.reset() + + And to load `predictions` and `labels` from files, then compute metrics with multi-processing, please refer to: + https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py. """ @@ -212,11 +322,13 @@ def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): # t y: the ground truth to compute, must be a list of `channel-first` Tensor or a `batch-first` Tensor. + Returns: + The computed metric values at the iteration level. """ ret = super().__call__(y_pred=y_pred, y=y) if isinstance(ret, (tuple, list)): - self.add(*ret) + self.extend(*ret) else: - self.add(ret) + self.extend(ret) return ret diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index a2a2f0853d..bd63134d6c 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -30,18 +30,16 @@ class RegressionMetric(CumulativeIterationMetric): `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]). Args: - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. """ def __init__( - self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, - get_not_nans: bool = False, + self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False ) -> None: super().__init__() self.reduction = reduction @@ -57,9 +55,7 @@ def aggregate(self): # type: ignore def _check_shape(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: if y_pred.shape != y.shape: - raise ValueError( - "y_pred and y shapes dont match, received y_pred: [{}] and y: [{}]".format(y_pred.shape, y.shape) - ) + raise ValueError(f"y_pred and y shapes dont match, received y_pred: [{y_pred.shape}] and y: [{y.shape}]") # also check if there is atleast one non-batch dimension i.e. num_dims >= 2 if len(y_pred.shape) < 2: @@ -88,17 +84,15 @@ class MSEMetric(RegressionMetric): Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. Args: - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). """ def __init__( - self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, - get_not_nans: bool = False, + self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False ) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.sq_func = partial(torch.pow, exponent=2.0) @@ -122,17 +116,15 @@ class MAEMetric(RegressionMetric): Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. Args: - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). """ def __init__( - self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, - get_not_nans: bool = False, + self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False ) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.abs_func = torch.abs @@ -157,17 +149,15 @@ class RMSEMetric(RegressionMetric): Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. Args: - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). """ def __init__( - self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, - get_not_nans: bool = False, + self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False ) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.sq_func = partial(torch.pow, exponent=2.0) @@ -198,9 +188,9 @@ class PSNRMetric(RegressionMetric): Args: max_val: The dynamic range of the images/volumes (i.e., the difference between the maximum and the minimum allowed values e.g. 255 for a uint8 image). - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). """ diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index c2679cc2ea..341d4cba2f 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -93,20 +93,18 @@ def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float: return auc / (nneg * (n - nneg)) -def compute_roc_auc( - y_pred: torch.Tensor, - y: torch.Tensor, - average: Union[Average, str] = Average.MACRO, -): +def compute_roc_auc(y_pred: torch.Tensor, y: torch.Tensor, average: Union[Average, str] = Average.MACRO): """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: `sklearn.metrics.roc_auc_score `_. Args: y_pred: input data to compute, typical classification model output. - it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2]. - y: ground truth to compute ROC AUC metric, the first dim is batch. - example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`). + the first dim must be batch, if multi-classes, it must be in One-Hot format. + for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data. + y: ground truth to compute ROC AUC metric, the first dim must be batch. + if multi-classes, it must be in One-Hot format. + for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data. average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``} Type of averaging performed if not binary classification. Defaults to ``"macro"``. diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 6039f1b55e..04eed97a5d 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -37,9 +37,9 @@ class SurfaceDistanceMetric(CumulativeIterationMetric): `seg_pred` and `seg_gt`. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result. Defaults to ``"mean"``. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. @@ -134,10 +134,7 @@ def compute_average_surface_distance( """ if not include_background: - y_pred, y = ignore_background( - y_pred=y_pred, - y=y, - ) + y_pred, y = ignore_background(y_pred=y_pred, y=y) if isinstance(y, torch.Tensor): y = y.float() diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 84de834f74..ccb6d93862 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,12 +25,10 @@ __all__ = ["ignore_background", "do_metric_reduction", "get_mask_edges", "get_surface_distance"] -def ignore_background( - y_pred: Union[np.ndarray, torch.Tensor], - y: Union[np.ndarray, torch.Tensor], -): +def ignore_background(y_pred: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]): """ This function is used to remove background (the first channel) for `y_pred` and `y`. + Args: y_pred: predictions. As for classification tasks, `y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks, @@ -38,24 +36,24 @@ def ignore_background( y: ground truth, the first dim is batch. """ + y = y[:, 1:] if y.shape[1] > 1 else y y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred return y_pred, y -def do_metric_reduction( - f: torch.Tensor, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, -): +def do_metric_reduction(f: torch.Tensor, reduction: Union[MetricReduction, str] = MetricReduction.MEAN): """ - This function is to do the metric reduction for calculated metrics of each example's each class. + This function is to do the metric reduction for calculated `not-nan` metrics of each sample's each class. The function also returns `not_nans`, which counts the number of not nans for the metric. Args: f: a tensor that contains the calculated metric scores per batch and per class. The first two dims should be batch and class. - reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``}, if "none", return the input f tensor and not_nans. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. + if "none", return the input f tensor and not_nans. Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. Raises: @@ -156,7 +154,7 @@ def get_mask_edges( if crop: if not np.any(seg_pred | seg_gt): - return (np.zeros_like(seg_pred), np.zeros_like(seg_gt)) + return np.zeros_like(seg_pred), np.zeros_like(seg_gt) seg_pred, seg_gt = np.expand_dims(seg_pred, 0), np.expand_dims(seg_gt, 0) box_start, box_end = generate_spatial_bounding_box(np.asarray(seg_pred | seg_gt)) @@ -167,14 +165,10 @@ def get_mask_edges( edges_pred = binary_erosion(seg_pred) ^ seg_pred edges_gt = binary_erosion(seg_gt) ^ seg_gt - return (edges_pred, edges_gt) + return edges_pred, edges_gt -def get_surface_distance( - seg_pred: np.ndarray, - seg_gt: np.ndarray, - distance_metric: str = "euclidean", -) -> np.ndarray: +def get_surface_distance(seg_pred: np.ndarray, seg_gt: np.ndarray, distance_metric: str = "euclidean") -> np.ndarray: """ This function is used to compute the surface distances from `seg_pred` to `seg_gt`. diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 3c347dad22..4e607dd298 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. from .utils import ( + convert_to_torchscript, copy_model_state, eval_mode, icnr_init, diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index db723f622d..0fdc944760 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ from .aspp import SimpleASPP from .convolutions import Convolution, ResidualUnit from .crf import CRF +from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock from .downsample import MaxAvgPool from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .fcn import FCN, GCN, MCFCN, Refine diff --git a/monai/networks/blocks/acti_norm.py b/monai/networks/blocks/acti_norm.py index 593ca6baa7..d07d78f1ad 100644 --- a/monai/networks/blocks/acti_norm.py +++ b/monai/networks/blocks/acti_norm.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index a380f8e757..1526b37056 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,7 +19,6 @@ def monai_mish(x, inplace: bool = False): return torch.nn.functional.mish(x, inplace=inplace) - else: def monai_mish(x, inplace: bool = False): @@ -31,7 +30,6 @@ def monai_mish(x, inplace: bool = False): def monai_swish(x, inplace: bool = False): return torch.nn.functional.silu(x, inplace=inplace) - else: def monai_swish(x, inplace: bool = False): @@ -48,8 +46,7 @@ class Swish(nn.Module): Shape: - - Input: :math:`(N, *)` where `*` means, any number of additional - dimensions + - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input @@ -123,7 +120,7 @@ class MemoryEfficientSwish(nn.Module): """ def __init__(self, inplace: bool = False): - super(MemoryEfficientSwish, self).__init__() + super().__init__() # inplace only works when using torch.nn.functional.silu self.inplace = inplace @@ -143,8 +140,7 @@ class Mish(nn.Module): this class will utilize `torch.nn.functional.mish` to do the calculation if meets the version. Shape: - - Input: :math:`(N, *)` where `*` means, any number of additional - dimensions + - Input: :math:`(N, *)` where `*` means, any number of additional dimensions - Output: :math:`(N, *)`, same shape as the input @@ -158,7 +154,7 @@ class Mish(nn.Module): """ def __init__(self, inplace: bool = False): - super(Mish, self).__init__() + super().__init__() # inplace only works when using torch.nn.functional.mish self.inplace = inplace diff --git a/monai/networks/blocks/aspp.py b/monai/networks/blocks/aspp.py index f8bf8a5ba6..8d43530fa7 100644 --- a/monai/networks/blocks/aspp.py +++ b/monai/networks/blocks/aspp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -86,7 +86,7 @@ def __init__( out_channels = conv_out_channels * len(pads) # final conv. output channels self.conv_k1 = Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=1, diff --git a/monai/networks/blocks/convolutions.py b/monai/networks/blocks/convolutions.py index 39ce60e3f8..37530668a3 100644 --- a/monai/networks/blocks/convolutions.py +++ b/monai/networks/blocks/convolutions.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,6 +18,7 @@ from monai.networks.blocks import ADN from monai.networks.layers.convutils import same_padding, stride_minus_kernel_padding from monai.networks.layers.factories import Conv +from monai.utils.deprecate_utils import deprecated_arg class Convolution(nn.Sequential): @@ -59,7 +60,7 @@ class Convolution(nn.Sequential): ) Args: - dimensions: number of spatial dimensions. + spatial_dims: number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. strides: convolution stride. Defaults to 1. @@ -69,13 +70,13 @@ class Convolution(nn.Sequential): act: activation type and arguments. Defaults to PReLU. norm: feature normalization type and arguments. Defaults to instance norm. dropout: dropout ratio. Defaults to no dropout. - dropout_dim: determine the dimensions of dropout. Defaults to 1. + dropout_dim: determine the spatial dimensions of dropout. Defaults to 1. - When dropout_dim = 1, randomly zeroes some of the elements for each channel. - When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map). - When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map). - The value of dropout_dim should be no no larger than the value of `dimensions`. + The value of dropout_dim should be no no larger than the value of `spatial_dims`. dilation: dilation rate. Defaults to 1. groups: controls the connections between inputs and outputs. Defaults to 1. bias: whether to have a bias term. Defaults to True. @@ -86,6 +87,9 @@ class Convolution(nn.Sequential): output_padding: controls the additional size added to one side of the output shape. Defaults to None. + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + See also: :py:class:`monai.networks.layers.Conv` @@ -93,9 +97,12 @@ class Convolution(nn.Sequential): """ + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: int, out_channels: int, strides: Union[Sequence[int], int] = 1, @@ -112,15 +119,16 @@ def __init__( is_transposed: bool = False, padding: Optional[Union[Sequence[int], int]] = None, output_padding: Optional[Union[Sequence[int], int]] = None, + dimensions: Optional[int] = None, ) -> None: super().__init__() - self.dimensions = dimensions + self.dimensions = spatial_dims if dimensions is None else dimensions self.in_channels = in_channels self.out_channels = out_channels self.is_transposed = is_transposed if padding is None: padding = same_padding(kernel_size, dilation) - conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions] + conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, self.dimensions] conv: nn.Module if is_transposed: @@ -159,7 +167,7 @@ def __init__( in_channels=out_channels, act=act, norm=norm, - norm_dim=dimensions, + norm_dim=self.dimensions, dropout=dropout, dropout_dim=dropout_dim, ), @@ -177,7 +185,7 @@ class ResidualUnit(nn.Module): from monai.networks.blocks import ResidualUnit convs = ResidualUnit( - dimensions=3, + spatial_dims=3, in_channels=1, out_channels=1, adn_ordering="AN", @@ -209,7 +217,7 @@ class ResidualUnit(nn.Module): ) Args: - dimensions: number of spatial dimensions. + spatial_dims: number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. strides: convolution stride. Defaults to 1. @@ -234,15 +242,19 @@ class ResidualUnit(nn.Module): padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each dimension. Defaults to None. + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + See also: :py:class:`monai.networks.blocks.Convolution` """ + @deprecated_arg(name="dimensions", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: int, out_channels: int, strides: Union[Sequence[int], int] = 1, @@ -257,9 +269,10 @@ def __init__( bias: bool = True, last_conv_only: bool = False, padding: Optional[Union[Sequence[int], int]] = None, + dimensions: Optional[int] = None, ) -> None: super().__init__() - self.dimensions = dimensions + self.dimensions = spatial_dims if dimensions is None else dimensions self.in_channels = in_channels self.out_channels = out_channels self.conv = nn.Sequential() @@ -273,7 +286,7 @@ def __init__( for su in range(subunits): conv_only = last_conv_only and su == (subunits - 1) unit = Convolution( - dimensions, + self.dimensions, schannels, out_channels, strides=sstrides, @@ -304,7 +317,7 @@ def __init__( rkernel_size = 1 rpadding = 0 - conv_type = Conv[Conv.CONV, dimensions] + conv_type = Conv[Conv.CONV, self.dimensions] self.residual = conv_type(in_channels, out_channels, rkernel_size, strides, rpadding, bias=bias) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index 49ff5bcd04..ccaef17679 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -57,7 +57,7 @@ def __init__( compatibility_matrix: a matrix describing class compatibility, should be NxN where N is the number of classes. """ - super(CRF, self).__init__() + super().__init__() self.iterations = iterations self.bilateral_weight = bilateral_weight self.gaussian_weight = gaussian_weight diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py new file mode 100644 index 0000000000..f76e125fe0 --- /dev/null +++ b/monai/networks/blocks/dints_block.py @@ -0,0 +1,272 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple, Union + +import torch + +from monai.networks.layers.factories import Conv +from monai.networks.layers.utils import get_act_layer, get_norm_layer + +__all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DActiConvNormBlock", "ActiConvNormBlock"] + + +class FactorizedIncreaseBlock(torch.nn.Sequential): + """ + Up-sampling the features by two using linear interpolation and convolutions. + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + """ + Args: + in_channel: number of input channels + out_channel: number of output channels + spatial_dims: number of spatial dimensions + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._spatial_dims = spatial_dims + if self._spatial_dims not in (2, 3): + raise ValueError("spatial_dims must be 2 or 3.") + + conv_type = Conv[Conv.CONV, self._spatial_dims] + mode = "trilinear" if self._spatial_dims == 3 else "bilinear" + self.add_module("up", torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=True)) + self.add_module("acti", get_act_layer(name=act_name)) + self.add_module( + "conv", + conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel, + kernel_size=1, + stride=1, + padding=0, + groups=1, + bias=False, + dilation=1, + ), + ) + self.add_module( + "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + ) + + +class FactorizedReduceBlock(torch.nn.Module): + """ + Down-sampling the feature by 2 using stride. + The length along each spatial dimension must be a multiple of 2. + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + """ + Args: + in_channel: number of input channels + out_channel: number of output channels. + spatial_dims: number of spatial dimensions. + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._spatial_dims = spatial_dims + if self._spatial_dims not in (2, 3): + raise ValueError("spatial_dims must be 2 or 3.") + + conv_type = Conv[Conv.CONV, self._spatial_dims] + + self.act = get_act_layer(name=act_name) + self.conv_1 = conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel // 2, + kernel_size=1, + stride=2, + padding=0, + groups=1, + bias=False, + dilation=1, + ) + self.conv_2 = conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel - self._out_channel // 2, + kernel_size=1, + stride=2, + padding=0, + groups=1, + bias=False, + dilation=1, + ) + self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + The length along each spatial dimension must be a multiple of 2. + """ + x = self.act(x) + if self._spatial_dims == 3: + out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:, 1:])], dim=1) + else: + out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1) + out = self.norm(out) + return out + + +class P3DActiConvNormBlock(torch.nn.Sequential): + """ + -- (act) -- (conv) -- (norm) -- + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int, + padding: int, + mode: int = 0, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + """ + Args: + in_channel: number of input channels. + out_channel: number of output channels. + kernel_size: kernel size to be expanded to 3D. + padding: padding size to be expanded to 3D. + mode: mode for the anisotropic kernels: + + - 0: ``(k, k, 1)``, ``(1, 1, k)``, + - 1: ``(k, 1, k)``, ``(1, k, 1)``, + - 2: ``(1, k, k)``. ``(k, 1, 1)``. + + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._p3dmode = int(mode) + + conv_type = Conv[Conv.CONV, 3] + + if self._p3dmode == 0: # (k, k, 1), (1, 1, k) + kernel_size0 = (kernel_size, kernel_size, 1) + kernel_size1 = (1, 1, kernel_size) + padding0 = (padding, padding, 0) + padding1 = (0, 0, padding) + elif self._p3dmode == 1: # (k, 1, k), (1, k, 1) + kernel_size0 = (kernel_size, 1, kernel_size) + kernel_size1 = (1, kernel_size, 1) + padding0 = (padding, 0, padding) + padding1 = (0, padding, 0) + elif self._p3dmode == 2: # (1, k, k), (k, 1, 1) + kernel_size0 = (1, kernel_size, kernel_size) + kernel_size1 = (kernel_size, 1, 1) + padding0 = (0, padding, padding) + padding1 = (padding, 0, 0) + else: + raise ValueError("`mode` must be 0, 1, or 2.") + + self.add_module("acti", get_act_layer(name=act_name)) + self.add_module( + "conv", + conv_type( + in_channels=self._in_channel, + out_channels=self._in_channel, + kernel_size=kernel_size0, + stride=1, + padding=padding0, + groups=1, + bias=False, + dilation=1, + ), + ) + self.add_module( + "conv_1", + conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel, + kernel_size=kernel_size1, + stride=1, + padding=padding1, + groups=1, + bias=False, + dilation=1, + ), + ) + self.add_module("norm", get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel)) + + +class ActiConvNormBlock(torch.nn.Sequential): + """ + -- (Acti) -- (Conv) -- (Norm) -- + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int = 3, + padding: int = 1, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + """ + Args: + in_channel: number of input channels. + out_channel: number of output channels. + kernel_size: kernel size of the convolution. + padding: padding size of the convolution. + spatial_dims: number of spatial dimensions. + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._spatial_dims = spatial_dims + + conv_type = Conv[Conv.CONV, self._spatial_dims] + self.add_module("acti", get_act_layer(name=act_name)) + self.add_module( + "conv", + conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel, + kernel_size=kernel_size, + stride=1, + padding=padding, + groups=1, + bias=False, + dilation=1, + ), + ) + self.add_module( + "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + ) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 9bee4c596e..9b0d5dd4b9 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index bb654d841c..8b22cb16a9 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,6 +33,8 @@ class UnetResBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. """ @@ -44,33 +46,26 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, ): - super(UnetResBlock, self).__init__() + super().__init__() self.conv1 = get_conv_layer( spatial_dims, in_channels, out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( - spatial_dims, - out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - conv_only=True, + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True ) self.conv3 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=1, - stride=stride, - conv_only=True, + spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True ) - self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) + self.lrelu = get_act_layer(name=act_name) self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) @@ -107,6 +102,8 @@ class UnetBasicBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. """ @@ -118,25 +115,23 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, ): - super(UnetBasicBlock, self).__init__() + super().__init__() self.conv1 = get_conv_layer( spatial_dims, in_channels, out_channels, kernel_size=kernel_size, stride=stride, + dropout=dropout, conv_only=True, ) self.conv2 = get_conv_layer( - spatial_dims, - out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - conv_only=True, + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True ) - self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) + self.lrelu = get_act_layer(name=act_name) self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) @@ -164,6 +159,9 @@ class UnetUpBlock(nn.Module): stride: convolution stride. upsample_kernel_size: convolution kernel size for transposed convolution layers. norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + trans_bias: transposed convolution bias. """ @@ -176,8 +174,11 @@ def __init__( stride: Union[Sequence[int], int], upsample_kernel_size: Union[Sequence[int], int], norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + trans_bias: bool = False, ): - super(UnetUpBlock, self).__init__() + super().__init__() upsample_stride = upsample_kernel_size self.transp_conv = get_conv_layer( spatial_dims, @@ -185,6 +186,8 @@ def __init__( out_channels, kernel_size=upsample_kernel_size, stride=upsample_stride, + dropout=dropout, + bias=trans_bias, conv_only=True, is_transposed=True, ) @@ -194,7 +197,9 @@ def __init__( out_channels, kernel_size=kernel_size, stride=1, + dropout=dropout, norm_name=norm_name, + act_name=act_name, ) def forward(self, inp, skip): @@ -206,10 +211,12 @@ def forward(self, inp, skip): class UnetOutBlock(nn.Module): - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int): - super(UnetOutBlock, self).__init__() + def __init__( + self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None + ): + super().__init__() self.conv = get_conv_layer( - spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, bias=True, conv_only=True + spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True ) def forward(self, inp): @@ -224,6 +231,7 @@ def get_conv_layer( stride: Union[Sequence[int], int] = 1, act: Optional[Union[Tuple, str]] = Act.PRELU, norm: Union[Tuple, str] = Norm.INSTANCE, + dropout: Optional[Union[Tuple, str, float]] = None, bias: bool = False, conv_only: bool = True, is_transposed: bool = False, @@ -240,6 +248,7 @@ def get_conv_layer( kernel_size=kernel_size, act=act, norm=norm, + dropout=dropout, bias=bias, conv_only=conv_only, is_transposed=is_transposed, @@ -249,8 +258,7 @@ def get_conv_layer( def get_padding( - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] ) -> Union[Tuple[int, ...], int]: kernel_size_np = np.atleast_1d(kernel_size) @@ -264,9 +272,7 @@ def get_padding( def get_output_padding( - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - padding: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int] ) -> Union[Tuple[int, ...], int]: kernel_size_np = np.atleast_1d(kernel_size) stride_np = np.atleast_1d(stride) diff --git a/monai/networks/blocks/dynunet_block_v1.py b/monai/networks/blocks/dynunet_block_v1.py deleted file mode 100644 index d5d9bbf3dc..0000000000 --- a/monai/networks/blocks/dynunet_block_v1.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Sequence, Union - -import numpy as np -import torch.nn as nn - -from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_conv_layer -from monai.networks.layers.factories import Norm -from monai.networks.layers.utils import get_act_layer - - -class _UnetResBlockV1(UnetResBlock): - """ - UnetResBlock for backward compatibility purpose. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: str, - ): - nn.Module.__init__(self) - self.conv1 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - conv_only=True, - ) - self.conv2 = get_conv_layer( - spatial_dims, - out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - conv_only=True, - ) - self.conv3 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=1, - stride=stride, - conv_only=True, - ) - self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) - self.norm1 = _get_norm_layer(spatial_dims, out_channels, norm_name) - self.norm2 = _get_norm_layer(spatial_dims, out_channels, norm_name) - self.norm3 = _get_norm_layer(spatial_dims, out_channels, norm_name) - self.downsample = in_channels != out_channels - stride_np = np.atleast_1d(stride) - if not np.all(stride_np == 1): - self.downsample = True - - -class _UnetBasicBlockV1(UnetBasicBlock): - """ - UnetBasicBlock for backward compatibility purpose. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: str, - ): - nn.Module.__init__(self) - self.conv1 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - conv_only=True, - ) - self.conv2 = get_conv_layer( - spatial_dims, - out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - conv_only=True, - ) - self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) - self.norm1 = _get_norm_layer(spatial_dims, out_channels, norm_name) - self.norm2 = _get_norm_layer(spatial_dims, out_channels, norm_name) - - -class _UnetUpBlockV1(UnetUpBlock): - """ - UnetUpBlock for backward compatibility purpose. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: str, - ): - nn.Module.__init__(self) - upsample_stride = upsample_kernel_size - self.transp_conv = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=upsample_kernel_size, - stride=upsample_stride, - conv_only=True, - is_transposed=True, - ) - self.conv_block = _UnetBasicBlockV1( - spatial_dims, - out_channels + out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - norm_name=norm_name, - ) - - -def _get_norm_layer(spatial_dims: int, out_channels: int, norm_name: str, num_groups: int = 16): - if norm_name not in ["batch", "instance", "group"]: - raise ValueError(f"Unsupported normalization mode: {norm_name}") - if norm_name != "group": - return Norm[norm_name, spatial_dims](out_channels, affine=True) - if out_channels % num_groups != 0: - raise AssertionError("out_channels should be divisible by num_groups.") - return Norm[norm_name, spatial_dims](num_groups=num_groups, num_channels=out_channels, affine=True) diff --git a/monai/networks/blocks/fcn.py b/monai/networks/blocks/fcn.py index d84e506774..5833d4a262 100644 --- a/monai/networks/blocks/fcn.py +++ b/monai/networks/blocks/fcn.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -36,7 +36,7 @@ def __init__(self, inplanes: int, planes: int, ks: int = 7): planes: number of output channels. ks: kernel size for one dimension. Defaults to 7. """ - super(GCN, self).__init__() + super().__init__() conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] self.conv_l1 = conv2d_type(in_channels=inplanes, out_channels=planes, kernel_size=(ks, 1), padding=(ks // 2, 0)) @@ -67,7 +67,7 @@ def __init__(self, planes: int): Args: planes: number of input channels. """ - super(Refine, self).__init__() + super().__init__() relu_type: Type[nn.ReLU] = Act[Act.RELU] conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] @@ -116,7 +116,7 @@ class FCN(nn.Module): def __init__( self, out_channels: int = 1, upsample_mode: str = "bilinear", pretrained: bool = True, progress: bool = True ): - super(FCN, self).__init__() + super().__init__() conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] @@ -154,12 +154,7 @@ def __init__( self.transformer = self.conv2d_type(in_channels=256, out_channels=64, kernel_size=1) if self.upsample_mode == "transpose": - self.up_conv = UpSample( - dimensions=2, - in_channels=self.out_channels, - scale_factor=2, - mode="deconv", - ) + self.up_conv = UpSample(spatial_dims=2, in_channels=self.out_channels, scale_factor=2, mode="deconv") def forward(self, x: torch.Tensor): """ @@ -195,14 +190,7 @@ def forward(self, x: torch.Tensor): fs2 = self.refine7(F.interpolate(fs1, fm2.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm3) fs3 = self.refine8(F.interpolate(fs2, pool_x.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm4) fs4 = self.refine9(F.interpolate(fs3, conv_x.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm5) - return self.refine10( - F.interpolate( - fs4, - org_input.size()[2:], - mode=self.upsample_mode, - align_corners=True, - ) - ) + return self.refine10(F.interpolate(fs4, org_input.size()[2:], mode=self.upsample_mode, align_corners=True)) class MCFCN(FCN): @@ -231,12 +219,12 @@ def __init__( pretrained: bool = True, progress: bool = True, ): - super(MCFCN, self).__init__( + super().__init__( out_channels=out_channels, upsample_mode=upsample_mode, pretrained=pretrained, progress=progress ) self.init_proj = Convolution( - dimensions=2, + spatial_dims=2, in_channels=in_channels, out_channels=3, kernel_size=1, @@ -251,4 +239,4 @@ def forward(self, x: torch.Tensor): x: in shape (batch, in_channels, spatial_1, spatial_2). """ x = self.init_proj(x) - return super(MCFCN, self).forward(x) + return super().forward(x) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 3997d42436..41b76c7d4c 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,7 +29,7 @@ def get_conv_block( norm: Optional[Union[Tuple, str]] = "BATCH", ) -> nn.Module: padding = same_padding(kernel_size) - return Convolution( + mod: nn.Module = Convolution( spatial_dims, in_channels, out_channels, @@ -40,33 +40,22 @@ def get_conv_block( conv_only=False, padding=padding, ) + return mod def get_conv_layer( - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, + spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] = 3 ) -> nn.Module: padding = same_padding(kernel_size) - return Convolution( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - bias=False, - conv_only=True, - padding=padding, + mod: nn.Module = Convolution( + spatial_dims, in_channels, out_channels, kernel_size=kernel_size, bias=False, conv_only=True, padding=padding ) + return mod -def get_deconv_block( - spatial_dims: int, - in_channels: int, - out_channels: int, -) -> nn.Module: - return Convolution( - dimensions=spatial_dims, +def get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) -> nn.Module: + mod: nn.Module = Convolution( + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, strides=2, @@ -77,26 +66,20 @@ def get_deconv_block( padding=1, output_padding=1, ) + return mod class ResidualBlock(nn.Module): def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], + self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] ) -> None: - super(ResidualBlock, self).__init__() + super().__init__() if in_channels != out_channels: raise ValueError( f"expecting in_channels == out_channels, " f"got in_channels={in_channels}, out_channels={out_channels}" ) self.conv_block = get_conv_block( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size ) self.conv = get_conv_layer( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size @@ -110,22 +93,13 @@ def forward(self, x) -> torch.Tensor: class LocalNetResidualBlock(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - ) -> None: - super(LocalNetResidualBlock, self).__init__() + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None: + super().__init__() if in_channels != out_channels: raise ValueError( f"expecting in_channels == out_channels, " f"got in_channels={in_channels}, out_channels={out_channels}" ) - self.conv_layer = get_conv_layer( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - ) + self.conv_layer = get_conv_layer(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels) self.norm = Norm[Norm.BATCH, spatial_dims](out_channels) self.relu = nn.ReLU() @@ -147,11 +121,7 @@ class LocalNetDownSampleBlock(nn.Module): """ def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], + self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] ) -> None: """ Args: @@ -162,16 +132,14 @@ def __init__( Raises: NotImplementedError: when ``kernel_size`` is even """ - super(LocalNetDownSampleBlock, self).__init__() + super().__init__() self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size ) self.residual_block = ResidualBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size ) - self.max_pool = Pool[Pool.MAX, spatial_dims]( - kernel_size=2, - ) + self.max_pool = Pool[Pool.MAX, spatial_dims](kernel_size=2) def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -208,12 +176,7 @@ class LocalNetUpSampleBlock(nn.Module): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - ) -> None: + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None: """ Args: spatial_dims: number of spatial dimensions. @@ -222,21 +185,13 @@ def __init__( Raises: ValueError: when ``in_channels != 2 * out_channels`` """ - super(LocalNetUpSampleBlock, self).__init__() + super().__init__() self.deconv_block = get_deconv_block( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - ) - self.conv_block = get_conv_block( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels ) + self.conv_block = get_conv_block(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels) self.residual_block = LocalNetResidualBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, + spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels ) if in_channels / out_channels != 2: raise ValueError( @@ -306,7 +261,7 @@ def __init__( act: activation type and arguments. Defaults to ReLU. kernel_initializer: kernel initializer. Defaults to None. """ - super(LocalNetFeatureExtractorBlock, self).__init__() + super().__init__() self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index 11b5e6fc15..a1728365cf 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,12 +18,7 @@ class MLPBlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ - def __init__( - self, - hidden_size: int, - mlp_dim: int, - dropout_rate: float = 0.0, - ) -> None: + def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) -> None: """ Args: hidden_size: dimension of hidden layer. diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index c1fcfa9af7..063a1fded1 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -62,7 +62,7 @@ def __init__( """ - super(PatchEmbeddingBlock, self).__init__() + super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") @@ -94,11 +94,9 @@ def __init__( to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)" axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)} self.patch_embeddings = nn.Sequential( - Rearrange(f"{from_chars} -> {to_chars}", **axes_len), - nn.Linear(self.patch_dim, hidden_size), + Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size) ) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) - self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) self.dropout = nn.Dropout(dropout_rate) self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) self.apply(self._init_weights) diff --git a/monai/networks/blocks/regunet_block.py b/monai/networks/blocks/regunet_block.py index d2cd3518b9..78e2598b4b 100644 --- a/monai/networks/blocks/regunet_block.py +++ b/monai/networks/blocks/regunet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,7 +32,7 @@ def get_conv_block( ) -> nn.Module: if padding is None: padding = same_padding(kernel_size) - conv_block = Convolution( + conv_block: nn.Module = Convolution( spatial_dims, in_channels, out_channels, @@ -59,21 +59,13 @@ def get_conv_block( def get_conv_layer( - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, + spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] = 3 ) -> nn.Module: padding = same_padding(kernel_size) - return Convolution( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - bias=False, - conv_only=True, - padding=padding, + mod: nn.Module = Convolution( + spatial_dims, in_channels, out_channels, kernel_size=kernel_size, bias=False, conv_only=True, padding=padding ) + return mod class RegistrationResidualConvBlock(nn.Module): @@ -83,12 +75,7 @@ class RegistrationResidualConvBlock(nn.Module): """ def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - num_layers: int = 2, - kernel_size: int = 3, + self, spatial_dims: int, in_channels: int, out_channels: int, num_layers: int = 2, kernel_size: int = 3 ): """ @@ -99,7 +86,7 @@ def __init__( num_layers: number of layers inside the block kernel_size: kernel_size """ - super(RegistrationResidualConvBlock, self).__init__() + super().__init__() self.num_layers = num_layers self.layers = nn.ModuleList( [ @@ -145,19 +132,14 @@ class RegistrationDownSampleBlock(nn.Module): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__( - self, - spatial_dims: int, - channels: int, - pooling: bool, - ) -> None: + def __init__(self, spatial_dims: int, channels: int, pooling: bool) -> None: """ Args: spatial_dims: number of spatial dimensions. channels: channels pooling: use MaxPool if True, strided conv if False """ - super(RegistrationDownSampleBlock, self).__init__() + super().__init__() if pooling: self.layer = Pool[Pool.MAX, spatial_dims](kernel_size=2) else: @@ -188,13 +170,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out -def get_deconv_block( - spatial_dims: int, - in_channels: int, - out_channels: int, -) -> nn.Module: - return Convolution( - dimensions=spatial_dims, +def get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) -> nn.Module: + mod: nn.Module = Convolution( + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, strides=2, @@ -205,6 +183,7 @@ def get_deconv_block( padding=1, output_padding=1, ) + return mod class RegistrationExtractionBlock(nn.Module): @@ -233,7 +212,7 @@ def __init__( kernel_initializer: kernel initializer activation: kernel activation function """ - super(RegistrationExtractionBlock, self).__init__() + super().__init__() self.extract_levels = extract_levels self.max_level = max(extract_levels) self.layers = nn.ModuleList( @@ -261,10 +240,7 @@ def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor: Tensor of shape (batch, `out_channels`, size1, size2, size3), where (size1, size2, size3) = ``image_size`` """ feature_list = [ - F.interpolate( - layer(x[self.max_level - level]), - size=image_size, - ) + F.interpolate(layer(x[self.max_level - level]), size=image_size) for layer, level in zip(self.layers, self.extract_levels) ] out: torch.Tensor = torch.mean(torch.stack(feature_list, dim=0), dim=0) diff --git a/monai/networks/blocks/segresnet_block.py b/monai/networks/blocks/segresnet_block.py index d8f6d7b268..ded270ab52 100644 --- a/monai/networks/blocks/segresnet_block.py +++ b/monai/networks/blocks/segresnet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,8 +15,7 @@ from monai.networks.blocks.convolutions import Convolution from monai.networks.blocks.upsample import UpSample -from monai.networks.layers.factories import Act -from monai.networks.layers.utils import get_norm_layer +from monai.networks.layers.utils import get_act_layer, get_norm_layer from monai.utils import InterpolateMode, UpsampleMode @@ -25,13 +24,7 @@ def get_conv_layer( ): return Convolution( - spatial_dims, - in_channels, - out_channels, - strides=stride, - kernel_size=kernel_size, - bias=bias, - conv_only=True, + spatial_dims, in_channels, out_channels, strides=stride, kernel_size=kernel_size, bias=bias, conv_only=True ) @@ -39,7 +32,7 @@ def get_upsample_layer( spatial_dims: int, in_channels: int, upsample_mode: Union[UpsampleMode, str] = "nontrainable", scale_factor: int = 2 ): return UpSample( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, scale_factor=scale_factor, @@ -62,6 +55,7 @@ def __init__( in_channels: int, norm: Union[Tuple, str], kernel_size: int = 3, + act: Union[Tuple, str] = ("RELU", {"inplace": True}), ) -> None: """ Args: @@ -69,6 +63,7 @@ def __init__( in_channels: number of input channels. norm: feature normalization type and arguments. kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3. + act: activation type and arguments. Defaults to ``RELU``. """ super().__init__() @@ -78,7 +73,7 @@ def __init__( self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels) self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels) - self.relu = Act[Act.RELU](inplace=True) + self.act = get_act_layer(act) self.conv1 = get_conv_layer(spatial_dims, in_channels=in_channels, out_channels=in_channels) self.conv2 = get_conv_layer(spatial_dims, in_channels=in_channels, out_channels=in_channels) @@ -87,11 +82,11 @@ def forward(self, x): identity = x x = self.norm1(x) - x = self.relu(x) + x = self.act(x) x = self.conv1(x) x = self.norm2(x) - x = self.relu(x) + x = self.act(x) x = self.conv2(x) x += identity diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 9dc45cccc8..4a86cd84bc 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,12 +23,7 @@ class SABlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ - def __init__( - self, - hidden_size: int, - num_heads: int, - dropout_rate: float = 0.0, - ) -> None: + def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0) -> None: """ Args: hidden_size: dimension of hidden layer. @@ -37,7 +32,7 @@ def __init__( """ - super(SABlock, self).__init__() + super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") diff --git a/monai/networks/blocks/squeeze_and_excitation.py b/monai/networks/blocks/squeeze_and_excitation.py index 4db6dc30f7..a9ac57aa4f 100644 --- a/monai/networks/blocks/squeeze_and_excitation.py +++ b/monai/networks/blocks/squeeze_and_excitation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -50,7 +50,7 @@ def __init__( :py:class:`monai.networks.layers.Act` """ - super(ChannelSELayer, self).__init__() + super().__init__() self.add_residual = add_residual @@ -181,21 +181,21 @@ def __init__( :py:class:`monai.networks.blocks.ChannelSELayer` """ - super(SEBlock, self).__init__() + super().__init__() if not conv_param_1: conv_param_1 = {"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})} self.conv1 = Convolution( - dimensions=spatial_dims, in_channels=in_channels, out_channels=n_chns_1, **conv_param_1 + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=n_chns_1, **conv_param_1 ) if not conv_param_2: conv_param_2 = {"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})} - self.conv2 = Convolution(dimensions=spatial_dims, in_channels=n_chns_1, out_channels=n_chns_2, **conv_param_2) + self.conv2 = Convolution(spatial_dims=spatial_dims, in_channels=n_chns_1, out_channels=n_chns_2, **conv_param_2) if not conv_param_3: conv_param_3 = {"kernel_size": 1, "norm": Norm.BATCH, "act": None} - self.conv3 = Convolution(dimensions=spatial_dims, in_channels=n_chns_2, out_channels=n_chns_3, **conv_param_3) + self.conv3 = Convolution(spatial_dims=spatial_dims, in_channels=n_chns_2, out_channels=n_chns_3, **conv_param_3) self.se_layer = ChannelSELayer( spatial_dims=spatial_dims, in_channels=n_chns_3, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2 @@ -264,7 +264,7 @@ def __init__( } conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False} - super(SEBottleneck, self).__init__( + super().__init__( spatial_dims=spatial_dims, in_channels=inplanes, n_chns_1=planes * 2, @@ -315,7 +315,7 @@ def __init__( } conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False} - super(SEResNetBottleneck, self).__init__( + super().__init__( spatial_dims=spatial_dims, in_channels=inplanes, n_chns_1=planes, diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index c7a948ed76..616d84e067 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,13 +21,7 @@ class TransformerBlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ - def __init__( - self, - hidden_size: int, - mlp_dim: int, - num_heads: int, - dropout_rate: float = 0.0, - ) -> None: + def __init__(self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0) -> None: """ Args: hidden_size: dimension of hidden layer. diff --git a/monai/networks/blocks/unetr_block.py b/monai/networks/blocks/unetr_block.py index a0852d05e0..a9d871a644 100644 --- a/monai/networks/blocks/unetr_block.py +++ b/monai/networks/blocks/unetr_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -46,7 +46,7 @@ def __init__( """ - super(UnetrUpBlock, self).__init__() + super().__init__() upsample_stride = upsample_kernel_size self.transp_conv = get_conv_layer( spatial_dims, diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index 5320611ce6..6a8498fe6f 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from monai.networks.layers.factories import Conv, Pad, Pool from monai.networks.utils import icnr_init, pixelshuffle -from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option +from monai.utils import InterpolateMode, UpsampleMode, deprecated_arg, ensure_tuple_rep, look_up_option __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"] @@ -34,9 +34,12 @@ class UpSample(nn.Sequential): (often used to map the number of features from `in_channels` to `out_channels`). """ + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: Optional[int] = None, out_channels: Optional[int] = None, scale_factor: Union[Sequence[float], float] = 2, @@ -47,10 +50,11 @@ def __init__( align_corners: Optional[bool] = True, bias: bool = True, apply_pad_pool: bool = True, + dimensions: Optional[int] = None, ) -> None: """ Args: - dimensions: number of spatial dimensions of the input image. + spatial_dims: number of spatial dimensions of the input image. in_channels: number of channels of the input image. out_channels: number of channels of the output image. Defaults to `in_channels`. scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2. @@ -75,16 +79,21 @@ def __init__( apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the size of `scale_factor` with a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`. Only used in the "pixelshuffle" mode. + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. """ super().__init__() - scale_factor_ = ensure_tuple_rep(scale_factor, dimensions) + if dimensions is not None: + spatial_dims = dimensions + scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims) up_mode = look_up_option(mode, UpsampleMode) if up_mode == UpsampleMode.DECONV: if not in_channels: raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.") self.add_module( "deconv", - Conv[Conv.CONVTRANS, dimensions]( + Conv[Conv.CONVTRANS, spatial_dims]( in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=scale_factor_, @@ -98,7 +107,7 @@ def __init__( raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.") self.add_module( "preconv", - Conv[Conv.CONV, dimensions]( + Conv[Conv.CONV, spatial_dims]( in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias ), ) @@ -112,7 +121,7 @@ def __init__( interp_mode = InterpolateMode(interp_mode) linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] if interp_mode in linear_mode: # choose mode based on dimensions - interp_mode = linear_mode[dimensions - 1] + interp_mode = linear_mode[spatial_dims - 1] self.add_module( "upsample_non_trainable", nn.Upsample( @@ -126,7 +135,7 @@ def __init__( self.add_module( "pixelshuffle", SubpixelUpsample( - dimensions=dimensions, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, scale_factor=scale_factor_[0], # isotropic @@ -164,19 +173,23 @@ class SubpixelUpsample(nn.Module): """ + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: Optional[int], out_channels: Optional[int] = None, scale_factor: int = 2, conv_block: Optional[Union[nn.Module, str]] = "default", apply_pad_pool: bool = True, bias: bool = True, + dimensions: Optional[int] = None, ) -> None: """ Args: - dimensions: number of spatial dimensions of the input image. + spatial_dims: number of spatial dimensions of the input image. in_channels: number of channels of the input image. out_channels: optional number of channels of the output image. scale_factor: multiplier for spatial size. Defaults to 2. @@ -190,21 +203,24 @@ def __init__( size of `scale_factor` with a stride of 1. This implements the nearest neighbour resize convolution component of subpixel convolutions described in Aitken et al. bias: whether to have a bias term in the default conv_block. Defaults to True. + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. """ super().__init__() if scale_factor <= 0: raise ValueError(f"The `scale_factor` multiplier must be an integer greater than 0, got {scale_factor}.") - self.dimensions = dimensions + self.dimensions = spatial_dims if dimensions is None else dimensions self.scale_factor = scale_factor if conv_block == "default": out_channels = out_channels or in_channels if not out_channels: raise ValueError("in_channels need to be specified.") - conv_out_channels = out_channels * (scale_factor ** dimensions) - self.conv_block = Conv[Conv.CONV, dimensions]( + conv_out_channels = out_channels * (scale_factor ** self.dimensions) + self.conv_block = Conv[Conv.CONV, self.dimensions]( in_channels=in_channels, out_channels=conv_out_channels, kernel_size=3, stride=1, padding=1, bias=bias ) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index d916c026ff..2cb349f8f0 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -30,11 +30,7 @@ class Warp(nn.Module): Warp an image with given dense displacement field (DDF). """ - def __init__( - self, - mode=GridSampleMode.BILINEAR.value, - padding_mode=GridSamplePadMode.BORDER.value, - ): + def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value): """ For pytorch native APIs, the possible values are: @@ -50,7 +46,7 @@ def __init__( See also: :py:class:`monai.networks.layers.grid_pull` """ - super(Warp, self).__init__() + super().__init__() # resolves _interp_mode for different methods if USE_COMPILED: @@ -123,13 +119,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor): ) # using csrc resampling - return grid_pull( - image, - grid, - bound=self._padding_mode, - extrapolate=True, - interpolation=self._interp_mode, - ) + return grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode) class DVF2DDF(nn.Module): @@ -143,12 +133,9 @@ class DVF2DDF(nn.Module): """ def __init__( - self, - num_steps: int = 7, - mode=GridSampleMode.BILINEAR.value, - padding_mode=GridSamplePadMode.ZEROS.value, + self, num_steps: int = 7, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.ZEROS.value ): - super(DVF2DDF, self).__init__() + super().__init__() if num_steps <= 0: raise ValueError(f"expecting positive num_steps, got {num_steps}") self.num_steps = num_steps diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index b2defc703d..5115c00af3 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,6 +22,7 @@ Reshape, SavitzkyGolayFilter, SkipConnection, + apply_filter, separable_filtering, ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push diff --git a/monai/networks/layers/convutils.py b/monai/networks/layers/convutils.py index 994ca05b85..5efb6e792f 100644 --- a/monai/networks/layers/convutils.py +++ b/monai/networks/layers/convutils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -44,8 +44,7 @@ def same_padding( def stride_minus_kernel_padding( - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] ) -> Union[Tuple[int, ...], int]: kernel_size_np = np.atleast_1d(kernel_size) stride_np = np.atleast_1d(stride) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index d4de08fc50..6379f49449 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 3b2214d59a..bbf925eba9 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/layers/gmm.py b/monai/networks/layers/gmm.py index 3091f95458..eb9a3f91e4 100644 --- a/monai/networks/layers/gmm.py +++ b/monai/networks/layers/gmm.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 52f19aab29..b17eae6c5b 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. import math +from copy import deepcopy from typing import List, Sequence, Union import torch @@ -20,29 +21,30 @@ from monai.networks.layers.convutils import gaussian_1d from monai.networks.layers.factories import Conv from monai.utils import ( - PT_BEFORE_1_7, ChannelMatching, InvalidPyTorchVersionError, SkipMode, - ensure_tuple_rep, look_up_option, optional_import, + pytorch_after, ) +from monai.utils.misc import issequenceiterable _C, _ = optional_import("monai._C") -if not PT_BEFORE_1_7: +if pytorch_after(1, 7): fft, _ = optional_import("torch.fft") __all__ = [ - "SkipConnection", + "ChannelPad", "Flatten", "GaussianFilter", + "HilbertTransform", "LLTM", "Reshape", - "separable_filtering", "SavitzkyGolayFilter", - "HilbertTransform", - "ChannelPad", + "SkipConnection", + "apply_filter", + "separable_filtering", ] @@ -210,25 +212,97 @@ def separable_filtering(x: torch.Tensor, kernels: List[torch.Tensor], mode: str Args: x: the input image. must have shape (batch, channels, H[, W, ...]). kernels: kernel along each spatial dimension. - could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels. + could be a single kernel (duplicated for all spatial dimensions), or + a list of `spatial_dims` number of kernels. mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` - or ``'circular'``. Default: ``'zeros'``. Modes other than ``'zeros'`` require PyTorch version >= 1.5.1. See - torch.nn.Conv1d() for more information. + or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information. Raises: TypeError: When ``x`` is not a ``torch.Tensor``. + + Examples: + + .. code-block:: python + + >>> import torch + >>> from monai.networks.layers import separable_filtering + >>> img = torch.randn(2, 4, 32, 32) # batch_size 2, channels 4, 32x32 2D images + # applying a [-1, 0, 1] filter along each of the spatial dimensions. + # the output shape is the same as the input shape. + >>> out = separable_filtering(img, torch.tensor((-1., 0., 1.))) + # applying `[-1, 0, 1]`, `[1, 0, -1]` filters along two spatial dimensions respectively. + # the output shape is the same as the input shape. + >>> out = separable_filtering(img, [torch.tensor((-1., 0., 1.)), torch.tensor((1., 0., -1.))]) + """ if not isinstance(x, torch.Tensor): raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.") spatial_dims = len(x.shape) - 2 - _kernels = [s.float() for s in kernels] + if isinstance(kernels, torch.Tensor): + kernels = [kernels] * spatial_dims + _kernels = [s.to(x) for s in kernels] _paddings = [(k.shape[0] - 1) // 2 for k in _kernels] n_chs = x.shape[1] pad_mode = "constant" if mode == "zeros" else mode - return _separable_filtering_conv(x, kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs) + return _separable_filtering_conv(x, _kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs) + + +def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Filtering `x` with `kernel` independently for each batch and channel respectively. + + Args: + x: the input image, must have shape (batch, channels, H[, W, D]). + kernel: `kernel` must at least have the spatial shape (H_k[, W_k, D_k]). + `kernel` shape must be broadcastable to the `batch` and `channels` dimensions of `x`. + kwargs: keyword arguments passed to `conv*d()` functions. + + Returns: + The filtered `x`. + + Examples: + + .. code-block:: python + + >>> import torch + >>> from monai.networks.layers import apply_filter + >>> img = torch.rand(2, 5, 10, 10) # batch_size 2, channels 5, 10x10 2D images + >>> out = apply_filter(img, torch.rand(3, 3)) # spatial kernel + >>> out = apply_filter(img, torch.rand(5, 3, 3)) # channel-wise kernels + >>> out = apply_filter(img, torch.rand(2, 5, 3, 3)) # batch-, channel-wise kernels + + """ + if not isinstance(x, torch.Tensor): + raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.") + batch, chns, *spatials = x.shape + n_spatial = len(spatials) + if n_spatial > 3: + raise NotImplementedError(f"Only spatial dimensions up to 3 are supported but got {n_spatial}.") + k_size = len(kernel.shape) + if k_size < n_spatial or k_size > n_spatial + 2: + raise ValueError( + f"kernel must have {n_spatial} ~ {n_spatial + 2} dimensions to match the input shape {x.shape}." + ) + kernel = kernel.to(x) + # broadcast kernel size to (batch chns, spatial_kernel_size) + kernel = kernel.expand(batch, chns, *kernel.shape[(k_size - n_spatial) :]) + kernel = kernel.reshape(-1, 1, *kernel.shape[2:]) # group=1 + x = x.view(1, kernel.shape[0], *spatials) + conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1] + if "padding" not in kwargs: + if pytorch_after(1, 10): + kwargs["padding"] = "same" + else: + # even-sized kernels are not supported + kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]] + + if "stride" not in kwargs: + kwargs["stride"] = 1 + output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs) + return output.view(batch, chns, *output.shape[2:]) class SavitzkyGolayFilter(nn.Module): @@ -307,12 +381,12 @@ class HilbertTransform(nn.Module): Args: axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension). - N: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``. + n: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``. """ def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None: - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) super().__init__() @@ -393,13 +467,18 @@ def __init__( (for example `parameters()` iterator could be used to get the parameters); otherwise this module will fix the kernels using `sigma` as the std. """ + if issequenceiterable(sigma): + if len(sigma) != spatial_dims: # type: ignore + raise ValueError + else: + sigma = [deepcopy(sigma) for _ in range(spatial_dims)] # type: ignore super().__init__() self.sigma = [ torch.nn.Parameter( torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None), requires_grad=requires_grad, ) - for s in ensure_tuple_rep(sigma, int(spatial_dims)) + for s in sigma # type: ignore ] self.truncated = truncated self.approx = approx @@ -449,7 +528,7 @@ class LLTM(nn.Module): """ def __init__(self, input_features: int, state_size: int): - super(LLTM, self).__init__() + super().__init__() self.input_features = input_features self.state_size = state_size self.weights = nn.Parameter(torch.empty(3 * state_size, input_features + state_size)) diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 511c24fcb0..01e45b2e67 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -46,7 +46,9 @@ def backward(ctx, grad): return None, grads[0], None, None, None -def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True): +def grid_pull( + input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True +) -> torch.Tensor: """ Sample an image with respect to a deformation field. @@ -112,8 +114,9 @@ def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) for i in ensure_tuple(interpolation) ] - - return _GridPull.apply(input, grid, interpolation, bound, extrapolate) + out: torch.Tensor + out = _GridPull.apply(input, grid, interpolation, bound, extrapolate) # type: ignore + return out class _GridPush(torch.autograd.Function): diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index 380a77552c..42fac58716 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index ad1ca2418b..22fcef4903 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .ahnet import AHnet, Ahnet, AHNet, ahnet +from .ahnet import AHnet, Ahnet, AHNet from .autoencoder import AutoEncoder from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .classifier import Classifier, Critic, Discriminator @@ -24,13 +24,13 @@ Densenet201, DenseNet264, Densenet264, - densenet, densenet121, densenet169, densenet201, densenet264, ) -from .dynunet import DynUNet, DynUnet, Dynunet, dynunet +from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch +from .dynunet import DynUNet, DynUnet, Dynunet from .efficientnet import ( BlockArgs, EfficientNet, @@ -42,6 +42,7 @@ from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet from .generator import Generator from .highresnet import HighResBlock, HighResNet +from .milmodel import MILModel from .netadapter import NetAdapter from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet @@ -71,7 +72,6 @@ SEResNeXt101, SEresnext101, Seresnext101, - senet, senet154, seresnet50, seresnet101, @@ -80,8 +80,10 @@ seresnext101, ) from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel -from .unet import UNet, Unet, unet +from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex +from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder from .vit import ViT +from .vitautoenc import ViTAutoEnc from .vnet import VNet diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 5ca6813efe..b481374aa1 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,7 +19,7 @@ from monai.networks.blocks.fcn import FCN from monai.networks.layers.factories import Act, Conv, Norm, Pool -__all__ = ["AHnet", "Ahnet", "ahnet", "AHNet"] +__all__ = ["AHnet", "Ahnet", "AHNet"] class Bottleneck3x3x1(nn.Module): @@ -35,7 +35,7 @@ def __init__( downsample: Optional[nn.Sequential] = None, ) -> None: - super(Bottleneck3x3x1, self).__init__() + super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -87,7 +87,7 @@ def forward(self, x): class Projection(nn.Sequential): def __init__(self, spatial_dims: int, num_input_features: int, num_output_features: int): - super(Projection, self).__init__() + super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -108,7 +108,7 @@ def __init__( growth_rate: int, dropout_prob: float, ): - super(DenseBlock, self).__init__() + super().__init__() for i in range(num_layers): layer = Pseudo3DLayer( spatial_dims, num_input_features + i * growth_rate, growth_rate, bn_size, dropout_prob @@ -120,7 +120,7 @@ class UpTransition(nn.Sequential): def __init__( self, spatial_dims: int, num_input_features: int, num_output_features: int, upsample_mode: str = "transpose" ): - super(UpTransition, self).__init__() + super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -145,7 +145,7 @@ class Final(nn.Sequential): def __init__( self, spatial_dims: int, num_input_features: int, num_output_features: int, upsample_mode: str = "transpose" ): - super(Final, self).__init__() + super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -178,7 +178,7 @@ def __init__( class Pseudo3DLayer(nn.Module): def __init__(self, spatial_dims: int, num_input_features: int, growth_rate: int, bn_size: int, dropout_prob: float): - super(Pseudo3DLayer, self).__init__() + super().__init__() # 1x1x1 conv_type = Conv[Conv.CONV, spatial_dims] @@ -244,7 +244,7 @@ def forward(self, x): class PSP(nn.Module): def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_mode: str = "transpose"): - super(PSP, self).__init__() + super().__init__() self.up_modules = nn.ModuleList() conv_type = Conv[Conv.CONV, spatial_dims] pool_type: Type[Union[nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] @@ -256,13 +256,7 @@ def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_m size = (2 ** (i + 3), 2 ** (i + 3), 1)[-spatial_dims:] self.pool_modules.append(pool_type(kernel_size=size, stride=size)) self.project_modules.append( - conv_type( - in_ch, - 1, - kernel_size=(1, 1, 1)[-spatial_dims:], - stride=1, - padding=(1, 1, 0)[-spatial_dims:], - ) + conv_type(in_ch, 1, kernel_size=(1, 1, 1)[-spatial_dims:], stride=1, padding=(1, 1, 0)[-spatial_dims:]) ) self.spatial_dims = spatial_dims @@ -274,15 +268,7 @@ def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_m for i in range(psp_block_num): size = (2 ** (i + 3), 2 ** (i + 3), 1)[-spatial_dims:] pad_size = (2 ** (i + 3), 2 ** (i + 3), 0)[-spatial_dims:] - self.up_modules.append( - conv_trans_type( - 1, - 1, - kernel_size=size, - stride=size, - padding=pad_size, - ) - ) + self.up_modules.append(conv_trans_type(1, 1, kernel_size=size, stride=size, padding=pad_size)) def forward(self, x): outputs = [] @@ -356,7 +342,7 @@ def __init__( progress: bool = True, ): self.inplanes = 64 - super(AHNet, self).__init__() + super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] conv_trans_type = Conv[Conv.CONVTRANS, spatial_dims] @@ -451,13 +437,7 @@ def __init__( net2d = FCN(pretrained=True, progress=progress) self.copy_from(net2d) - def _make_layer( - self, - block: Type[Bottleneck3x3x1], - planes: int, - blocks: int, - stride: int = 1, - ) -> nn.Sequential: + def _make_layer(self, block: Type[Bottleneck3x3x1], planes: int, blocks: int, stride: int = 1) -> nn.Sequential: downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( @@ -559,4 +539,4 @@ def copy_bn_param(module2d, module3d): p3d.data[:] = p2d.data[:] # Two parameter gamma and beta -AHnet = Ahnet = ahnet = AHNet +AHnet = Ahnet = AHNet diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py index d0a54b8148..75edde70eb 100644 --- a/monai/networks/nets/autoencoder.py +++ b/monai/networks/nets/autoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,14 +16,84 @@ from monai.networks.blocks import Convolution, ResidualUnit from monai.networks.layers.factories import Act, Norm +from monai.utils import deprecated_arg __all__ = ["AutoEncoder"] class AutoEncoder(nn.Module): + """ + Simple definition of an autoencoder and base class for the architecture implementing + :py:class:`monai.networks.nets.VarAutoEncoder`. The network is composed of an encode sequence of blocks, followed + by an intermediary sequence of blocks, and finally a decode sequence of blocks. The encode and decode blocks are + default :py:class:`monai.networks.blocks.Convolution` instances with the encode blocks having the given stride + and the decode blocks having transpose convolutions with the same stride. If `num_res_units` is given residual + blocks are used instead. + + By default the intermediary sequence is empty but if `inter_channels` is given to specify the output channels of + blocks then this will be become a sequence of Convolution blocks or of residual blocks if `num_inter_units` is + given. The optional parameter `inter_dilations` can be used to specify the dilation values of the convolutions in + these blocks, this allows a network to use dilated kernels in this middle section. Since the intermediary section + isn't meant to change the size of the output the strides for all these kernels is 1. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + channels: sequence of channels. Top block first. The length of `channels` should be no less than 2. + strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`. + kernel_size: convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + num_res_units: number of residual units. Defaults to 0. + inter_channels: sequence of channels defining the blocks in the intermediate layer between encode and decode. + inter_dilations: defines the dilation value for each block of the intermediate layer. Defaults to 1. + num_inter_units: number of residual units for each block of the intermediate layer. Defaults to 0. + act: activation type and arguments. Defaults to PReLU. + norm: feature normalization type and arguments. Defaults to instance norm. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term in convolution blocks. Defaults to True. + According to `Performance Tuning Guide `_, + if a conv layer is directly followed by a batch norm layer, bias should be False. + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + + Examples:: + + from monai.networks.nets import AutoEncoder + + # 3 layers each down/up sampling their inputs by a factor 2 with no intermediate layer + net = AutoEncoder( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(2, 4, 8), + strides=(2, 2, 2) + ) + + # 1 layer downsampling by 2, followed by a sequence of residual units with 2 convolutions defined by + # progressively increasing dilations, then final upsample layer + net = AutoEncoder( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4,), + strides=(2,), + inter_channels=(8, 8, 8), + inter_dilations=(1, 2, 4), + num_inter_units=2 + ) + + """ + + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int], @@ -38,10 +108,11 @@ def __init__( norm: Union[Tuple, str] = Norm.INSTANCE, dropout: Optional[Union[Tuple, str, float]] = None, bias: bool = True, + dimensions: Optional[int] = None, ) -> None: super().__init__() - self.dimensions = dimensions + self.dimensions = spatial_dims if dimensions is None else dimensions self.in_channels = in_channels self.out_channels = out_channels self.channels = list(channels) @@ -71,6 +142,9 @@ def __init__( def _get_encode_module( self, in_channels: int, channels: Sequence[int], strides: Sequence[int] ) -> Tuple[nn.Sequential, int]: + """ + Returns the encode part of the network by building up a sequence of layers returned by `_get_encode_layer`. + """ encode = nn.Sequential() layer_channels = in_channels @@ -82,6 +156,10 @@ def _get_encode_module( return encode, layer_channels def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tuple[nn.Module, int]: + """ + Returns the intermediate block of the network which accepts input from the encoder and whose output goes + to the decoder. + """ # Define some types intermediate: nn.Module unit: nn.Module @@ -95,7 +173,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tu for i, (dc, di) in enumerate(zip(self.inter_channels, self.inter_dilations)): if self.num_inter_units > 0: unit = ResidualUnit( - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=layer_channels, out_channels=dc, strides=1, @@ -109,7 +187,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tu ) else: unit = Convolution( - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=layer_channels, out_channels=dc, strides=1, @@ -129,6 +207,9 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tu def _get_decode_module( self, in_channels: int, channels: Sequence[int], strides: Sequence[int] ) -> Tuple[nn.Sequential, int]: + """ + Returns the decode part of the network by building up a sequence of layers returned by `_get_decode_layer`. + """ decode = nn.Sequential() layer_channels = in_channels @@ -140,10 +221,13 @@ def _get_decode_module( return decode, layer_channels def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Module: - + """ + Returns a single layer of the encoder part of the network. + """ + mod: nn.Module if self.num_res_units > 0: - return ResidualUnit( - dimensions=self.dimensions, + mod = ResidualUnit( + spatial_dims=self.dimensions, in_channels=in_channels, out_channels=out_channels, strides=strides, @@ -155,8 +239,8 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i bias=self.bias, last_conv_only=is_last, ) - return Convolution( - dimensions=self.dimensions, + mod = Convolution( + spatial_dims=self.dimensions, in_channels=in_channels, out_channels=out_channels, strides=strides, @@ -167,13 +251,16 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i bias=self.bias, conv_only=is_last, ) + return mod def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Sequential: - + """ + Returns a single layer of the decoder part of the network. + """ decode = nn.Sequential() conv = Convolution( - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=in_channels, out_channels=out_channels, strides=strides, @@ -190,7 +277,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i if self.num_res_units > 0: ru = ResidualUnit( - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=out_channels, out_channels=out_channels, strides=1, diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index 63205f45ee..1e46846576 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,27 +16,29 @@ from monai.networks.blocks import Convolution, UpSample from monai.networks.layers.factories import Conv, Pool -from monai.utils import ensure_tuple_rep +from monai.utils import deprecated_arg, ensure_tuple_rep -__all__ = ["BasicUNet", "BasicUnet", "Basicunet"] +__all__ = ["BasicUnet", "Basicunet", "basicunet", "BasicUNet"] class TwoConv(nn.Sequential): """two convolutions.""" + @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - dim: int, + spatial_dims: int, in_chns: int, out_chns: int, act: Union[str, tuple], norm: Union[str, tuple], bias: bool, dropout: Union[float, tuple] = 0.0, + dim: Optional[int] = None, ): """ Args: - dim: number of spatial dimensions. + spatial_dims: number of spatial dimensions. in_chns: number of input channels. out_chns: number of output channels. act: activation type and arguments. @@ -44,11 +46,17 @@ def __init__( bias: whether to have a bias term in convolution blocks. dropout: dropout ratio. Defaults to no dropout. + .. deprecated:: 0.6.0 + ``dim`` is deprecated, use ``spatial_dims`` instead. """ super().__init__() - conv_0 = Convolution(dim, in_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1) - conv_1 = Convolution(dim, out_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1) + if dim is not None: + spatial_dims = dim + conv_0 = Convolution(spatial_dims, in_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1) + conv_1 = Convolution( + spatial_dims, out_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1 + ) self.add_module("conv_0", conv_0) self.add_module("conv_1", conv_1) @@ -56,19 +64,21 @@ def __init__( class Down(nn.Sequential): """maxpooling downsampling and two convolutions.""" + @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - dim: int, + spatial_dims: int, in_chns: int, out_chns: int, act: Union[str, tuple], norm: Union[str, tuple], bias: bool, dropout: Union[float, tuple] = 0.0, + dim: Optional[int] = None, ): """ Args: - dim: number of spatial dimensions. + spatial_dims: number of spatial dimensions. in_chns: number of input channels. out_chns: number of output channels. act: activation type and arguments. @@ -76,11 +86,14 @@ def __init__( bias: whether to have a bias term in convolution blocks. dropout: dropout ratio. Defaults to no dropout. + .. deprecated:: 0.6.0 + ``dim`` is deprecated, use ``spatial_dims`` instead. """ super().__init__() - - max_pooling = Pool["MAX", dim](kernel_size=2) - convs = TwoConv(dim, in_chns, out_chns, act, norm, bias, dropout) + if dim is not None: + spatial_dims = dim + max_pooling = Pool["MAX", spatial_dims](kernel_size=2) + convs = TwoConv(spatial_dims, in_chns, out_chns, act, norm, bias, dropout) self.add_module("max_pooling", max_pooling) self.add_module("convs", convs) @@ -88,9 +101,10 @@ def __init__( class UpCat(nn.Module): """upsampling, concatenation with the encoder feature map, two convolutions""" + @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - dim: int, + spatial_dims: int, in_chns: int, cat_chns: int, out_chns: int, @@ -103,10 +117,11 @@ def __init__( interp_mode: str = "linear", align_corners: Optional[bool] = True, halves: bool = True, + dim: Optional[int] = None, ): """ Args: - dim: number of spatial dimensions. + spatial_dims: number of spatial dimensions. in_chns: number of input channels to be upsampled. cat_chns: number of channels from the decoder. out_chns: number of output channels. @@ -124,14 +139,19 @@ def __init__( Only used in the "nontrainable" mode. halves: whether to halve the number of channels during upsampling. This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`. + + .. deprecated:: 0.6.0 + ``dim`` is deprecated, use ``spatial_dims`` instead. """ super().__init__() + if dim is not None: + spatial_dims = dim if upsample == "nontrainable" and pre_conv is None: up_chns = in_chns else: up_chns = in_chns // 2 if halves else in_chns self.upsample = UpSample( - dim, + spatial_dims, in_chns, up_chns, 2, @@ -140,7 +160,7 @@ def __init__( interp_mode=interp_mode, align_corners=align_corners, ) - self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, bias, dropout) + self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout) def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]): """ @@ -167,9 +187,12 @@ def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]): class BasicUNet(nn.Module): + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int = 3, + spatial_dims: int = 3, in_channels: int = 1, out_channels: int = 2, features: Sequence[int] = (32, 32, 64, 128, 256, 32), @@ -178,6 +201,7 @@ def __init__( bias: bool = True, dropout: Union[float, tuple] = 0.0, upsample: str = "deconv", + dimensions: Optional[int] = None, ): """ A UNet implementation with 1D/2D/3D supports. @@ -189,7 +213,7 @@ def __init__( http://dx.doi.org/10.1038/s41592-018-0261-2 Args: - dimensions: number of spatial dimensions. Defaults to 3 for spatial 3D inputs. + spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs. in_channels: number of input channels. Defaults to 1. out_channels: number of output channels. Defaults to 2. features: six integers as numbers of features. @@ -207,16 +231,19 @@ def __init__( upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + Examples:: # for spatial 2D - >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128)) + >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128)) # for spatial 2D, with group norm - >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4})) + >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4})) # for spatial 3D - >>> net = BasicUNet(dimensions=3, features=(32, 32, 64, 128, 256, 32)) + >>> net = BasicUNet(spatial_dims=3, features=(32, 32, 64, 128, 256, 32)) See Also @@ -225,22 +252,24 @@ def __init__( """ super().__init__() + if dimensions is not None: + spatial_dims = dimensions fea = ensure_tuple_rep(features, 6) print(f"BasicUNet features: {fea}.") - self.conv_0 = TwoConv(dimensions, in_channels, features[0], act, norm, bias, dropout) - self.down_1 = Down(dimensions, fea[0], fea[1], act, norm, bias, dropout) - self.down_2 = Down(dimensions, fea[1], fea[2], act, norm, bias, dropout) - self.down_3 = Down(dimensions, fea[2], fea[3], act, norm, bias, dropout) - self.down_4 = Down(dimensions, fea[3], fea[4], act, norm, bias, dropout) + self.conv_0 = TwoConv(spatial_dims, in_channels, features[0], act, norm, bias, dropout) + self.down_1 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout) + self.down_2 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout) + self.down_3 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout) + self.down_4 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout) - self.upcat_4 = UpCat(dimensions, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample) - self.upcat_3 = UpCat(dimensions, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample) - self.upcat_2 = UpCat(dimensions, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample) - self.upcat_1 = UpCat(dimensions, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False) + self.upcat_4 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample) + self.upcat_3 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample) + self.upcat_2 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample) + self.upcat_1 = UpCat(spatial_dims, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False) - self.final_conv = Conv["conv", dimensions](fea[5], out_channels, kernel_size=1) + self.final_conv = Conv["conv", spatial_dims](fea[5], out_channels, kernel_size=1) def forward(self, x: torch.Tensor): """ diff --git a/monai/networks/nets/classifier.py b/monai/networks/nets/classifier.py index 92fee4f566..7f4e43eedb 100644 --- a/monai/networks/nets/classifier.py +++ b/monai/networks/nets/classifier.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,6 +25,19 @@ class Classifier(Regressor): Defines a classification network from Regressor by specifying the output shape as a single dimensional tensor with size equal to the number of classes to predict. The final activation function can also be specified, eg. softmax or sigmoid. + + Args: + in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) + classes: integer stating the dimension of the final output tensor + channels: tuple of integers stating the output channels of each convolutional layer + strides: tuple of integers stating the stride (downscale factor) of each convolutional layer + kernel_size: integer or tuple of integers stating size of convolutional kernels + num_res_units: integer stating number of convolutions in residual units, 0 means no residual units + act: name or type defining activation layers + norm: name or type defining normalization layers + dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout + bias: boolean stating if convolution layers should have a bias component + last_act: name defining the last activation layer """ def __init__( @@ -41,20 +54,6 @@ def __init__( bias: bool = True, last_act: Optional[str] = None, ) -> None: - """ - Args: - in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) - classes: integer stating the dimension of the final output tensor - channels: tuple of integers stating the output channels of each convolutional layer - strides: tuple of integers stating the stride (downscale factor) of each convolutional layer - kernel_size: integer or tuple of integers stating size of convolutional kernels - num_res_units: integer stating number of convolutions in residual units, 0 means no residual units - act: name or type defining activation layers - norm: name or type defining normalization layers - dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout - bias: boolean stating if convolution layers should have a bias component - last_act: name defining the last activation layer - """ super().__init__(in_shape, (classes,), channels, strides, kernel_size, num_res_units, act, norm, dropout, bias) if last_act is not None: @@ -68,6 +67,18 @@ class Discriminator(Classifier): """ Defines a discriminator network from Classifier with a single output value and sigmoid activation by default. This is meant for use with GANs or other applications requiring a generic discriminator network. + + Args: + in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) + channels: tuple of integers stating the output channels of each convolutional layer + strides: tuple of integers stating the stride (downscale factor) of each convolutional layer + kernel_size: integer or tuple of integers stating size of convolutional kernels + num_res_units: integer stating number of convolutions in residual units, 0 means no residual units + act: name or type defining activation layers + norm: name or type defining normalization layers + dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout + bias: boolean stating if convolution layers should have a bias component + last_act: name defining the last activation layer """ def __init__( @@ -83,19 +94,6 @@ def __init__( bias: bool = True, last_act=Act.SIGMOID, ) -> None: - """ - Args: - in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) - channels: tuple of integers stating the output channels of each convolutional layer - strides: tuple of integers stating the stride (downscale factor) of each convolutional layer - kernel_size: integer or tuple of integers stating size of convolutional kernels - num_res_units: integer stating number of convolutions in residual units, 0 means no residual units - act: name or type defining activation layers - norm: name or type defining normalization layers - dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout - bias: boolean stating if convolution layers should have a bias component - last_act: name defining the last activation layer - """ super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, last_act) @@ -104,6 +102,17 @@ class Critic(Classifier): Defines a critic network from Classifier with a single output value and no final activation. The final layer is `nn.Flatten` instead of `nn.Linear`, the final result is computed as the mean over the first dimension. This is meant to be used with Wasserstein GANs. + + Args: + in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) + channels: tuple of integers stating the output channels of each convolutional layer + strides: tuple of integers stating the stride (downscale factor) of each convolutional layer + kernel_size: integer or tuple of integers stating size of convolutional kernels + num_res_units: integer stating number of convolutions in residual units, 0 means no residual units + act: name or type defining activation layers + norm: name or type defining normalization layers + dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout + bias: boolean stating if convolution layers should have a bias component """ def __init__( @@ -118,18 +127,6 @@ def __init__( dropout: Optional[float] = 0.25, bias: bool = True, ) -> None: - """ - Args: - in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) - channels: tuple of integers stating the output channels of each convolutional layer - strides: tuple of integers stating the stride (downscale factor) of each convolutional layer - kernel_size: integer or tuple of integers stating size of convolutional kernels - num_res_units: integer stating number of convolutions in residual units, 0 means no residual units - act: name or type defining activation layers - norm: name or type defining normalization layers - dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout - bias: boolean stating if convolution layers should have a bias component - """ super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, None) def _get_final_layer(self, in_shape: Sequence[int]): diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index e9f3b6d33e..8fb2c269ab 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,6 @@ __all__ = [ "DenseNet", - "densenet", "Densenet", "DenseNet121", "densenet121", @@ -62,7 +61,7 @@ def __init__( act: activation type and arguments. Defaults to relu. norm: feature normalization type and arguments. Defaults to batch norm. """ - super(_DenseLayer, self).__init__() + super().__init__() out_channels = bn_size * growth_rate conv_type: Callable = Conv[Conv.CONV, spatial_dims] @@ -110,7 +109,7 @@ def __init__( act: activation type and arguments. Defaults to relu. norm: feature normalization type and arguments. Defaults to batch norm. """ - super(_DenseBlock, self).__init__() + super().__init__() for i in range(layers): layer = _DenseLayer(spatial_dims, in_channels, growth_rate, bn_size, dropout_prob, act=act, norm=norm) in_channels += growth_rate @@ -134,7 +133,7 @@ def __init__( act: activation type and arguments. Defaults to relu. norm: feature normalization type and arguments. Defaults to batch norm. """ - super(_Transition, self).__init__() + super().__init__() conv_type: Callable = Conv[Conv.CONV, spatial_dims] pool_type: Callable = Pool[Pool.AVG, spatial_dims] @@ -178,7 +177,7 @@ def __init__( dropout_prob: float = 0.0, ) -> None: - super(DenseNet, self).__init__() + super().__init__() conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] @@ -299,14 +298,13 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(DenseNet121, self).__init__( - init_features=init_features, - growth_rate=growth_rate, - block_config=block_config, - **kwargs, - ) + super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) if pretrained: - # it only worked when `spatial_dims` is 2 + if kwargs["spatial_dims"] > 2: + raise NotImplementedError( + "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" + "provide pretrained models for more than two spatial dimensions." + ) _load_state_dict(self, "densenet121", progress) @@ -322,14 +320,13 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(DenseNet169, self).__init__( - init_features=init_features, - growth_rate=growth_rate, - block_config=block_config, - **kwargs, - ) + super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) if pretrained: - # it only worked when `spatial_dims` is 2 + if kwargs["spatial_dims"] > 2: + raise NotImplementedError( + "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" + "provide pretrained models for more than two spatial dimensions." + ) _load_state_dict(self, "densenet169", progress) @@ -345,14 +342,13 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(DenseNet201, self).__init__( - init_features=init_features, - growth_rate=growth_rate, - block_config=block_config, - **kwargs, - ) + super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) if pretrained: - # it only worked when `spatial_dims` is 2 + if kwargs["spatial_dims"] > 2: + raise NotImplementedError( + "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" + "provide pretrained models for more than two spatial dimensions." + ) _load_state_dict(self, "densenet201", progress) @@ -363,22 +359,17 @@ def __init__( self, init_features: int = 64, growth_rate: int = 32, - block_config: Sequence[int] = (6, 12, 48, 32), + block_config: Sequence[int] = (6, 12, 64, 48), pretrained: bool = False, progress: bool = True, **kwargs, ) -> None: - super(DenseNet264, self).__init__( - init_features=init_features, - growth_rate=growth_rate, - block_config=block_config, - **kwargs, - ) + super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) if pretrained: raise NotImplementedError("Currently PyTorch Hub does not provide densenet264 pretrained models.") -Densenet = densenet = DenseNet +Densenet = DenseNet Densenet121 = densenet121 = DenseNet121 Densenet169 = densenet169 = DenseNet169 Densenet201 = densenet201 = DenseNet201 diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py new file mode 100644 index 0000000000..c024d6e0f1 --- /dev/null +++ b/monai/networks/nets/dints.py @@ -0,0 +1,948 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks.dints_block import ( + ActiConvNormBlock, + FactorizedIncreaseBlock, + FactorizedReduceBlock, + P3DActiConvNormBlock, +) +from monai.networks.layers.factories import Conv +from monai.networks.layers.utils import get_act_layer, get_norm_layer +from monai.utils import optional_import + +# solving shortest path problem +csr_matrix, _ = optional_import("scipy.sparse", name="csr_matrix") +dijkstra, _ = optional_import("scipy.sparse.csgraph", name="dijkstra") + +__all__ = ["DiNTS", "TopologyConstruction", "TopologyInstance", "TopologySearch"] + + +@torch.jit.interface +class CellInterface(torch.nn.Module): + """interface for torchscriptable Cell""" + + def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + pass + + +@torch.jit.interface +class StemInterface(torch.nn.Module): + """interface for torchscriptable Stem""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pass + + +class StemTS(StemInterface): + """wrapper for torchscriptable Stem""" + + def __init__(self, *mod): + super().__init__() + self.mod = torch.nn.Sequential(*mod) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mod(x) # type: ignore + + +def _dfs(node, paths): + """use depth first search to find all path activation combination""" + if node == paths: + return [[0], [1]] + child = _dfs(node + 1, paths) + return [[0] + _ for _ in child] + [[1] + _ for _ in child] + + +class _IdentityWithRAMCost(nn.Identity): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ram_cost = 0 + + +class _CloseWithRAMCost(nn.Module): + def __init__(self): + super().__init__() + self.ram_cost = 0 + + def forward(self, x): + return torch.tensor(0.0, requires_grad=False).to(x) + + +class _ActiConvNormBlockWithRAMCost(ActiConvNormBlock): + """The class wraps monai layers with ram estimation. The ram_cost = total_ram/output_size is estimated. + Here is the estimation: + feature_size = output_size/out_channel + total_ram = ram_cost * output_size + total_ram = in_channel * feature_size (activation map) + + in_channel * feature_size (convolution map) + + out_channel * feature_size (normalization) + = (2*in_channel + out_channel) * output_size/out_channel + ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1 + """ + + def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, spatial_dims: int = 3): + super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims) + self.ram_cost = 1 + in_channel / out_channel * 2 + + +class _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock): + def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, p3dmode: int = 0): + super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode) + # 1 in_channel (activation) + 1 in_channel (convolution) + + # 1 out_channel (convolution) + 1 out_channel (normalization) + self.ram_cost = 2 + 2 * in_channel / out_channel + + +class _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock): + def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): + super().__init__(in_channel, out_channel, spatial_dims) + # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. + # 2 * in_channel * s0 (upsample + activation) + 2 * out_channel * s0 (conv + normalization) + # s0 = output_size/out_channel + self.ram_cost = 2 * in_channel / out_channel + 2 + + +class _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock): + def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): + super().__init__(in_channel, out_channel, spatial_dims) + # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. + # in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization) + # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims) + self.ram_cost = in_channel / out_channel * 2 ** self._spatial_dims + 3 + + +class MixedOp(nn.Module): + """ + The weighted averaging of cell operations. + Args: + c: number of output channels. + ops: a dictionary of operations. See also: ``Cell.OPS2D`` or ``Cell.OPS3D``. + arch_code_c: binary cell operation code. It represents the operation results added to the output. + """ + + def __init__(self, c: int, ops: dict, arch_code_c=None): + super().__init__() + if arch_code_c is None: + arch_code_c = np.ones(len(ops)) + self.ops = nn.ModuleList() + for arch_c, op_name in zip(arch_code_c, ops): + self.ops.append(_CloseWithRAMCost() if arch_c == 0 else ops[op_name](c)) + + def forward(self, x: torch.Tensor, weight: torch.Tensor): + """ + Args: + x: input tensor. + weight: learnable architecture weights for cell operations. arch_code_c are derived from it. + Return: + out: weighted average of the operation results. + """ + out = 0.0 + weight = weight.to(x) + for idx, _op in enumerate(self.ops): + out = out + _op(x) * weight[idx] + return out + + +class Cell(CellInterface): + """ + The basic class for cell operation search, which contains a preprocessing operation and a mixed cell operation. + Each cell is defined on a `path` in the topology search space. + Args: + c_prev: number of input channels + c: number of output channels + rate: resolution change rate. It represents the preprocessing operation before the mixed cell operation. + ``-1`` for 2x downsample, ``1`` for 2x upsample, ``0`` for no change of resolution. + arch_code_c: cell operation code + """ + + DIRECTIONS = 3 + # Possible output paths for `Cell`. + # + # - UpSample + # / + # +--+/ + # | |--- Identity or AlignChannels + # +--+\ + # \ + # - Downsample + + # Define connection operation set, parameterized by the number of channels + ConnOPS = { + "up": _FactorizedIncreaseBlockWithRAMCost, + "down": _FactorizedReduceBlockWithRAMCost, + "identity": _IdentityWithRAMCost, + "align_channels": _ActiConvNormBlockWithRAMCost, + } + + # Define 2D operation set, parameterized by the number of channels + OPS2D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=2), + } + + # Define 3D operation set, parameterized by the number of channels + OPS3D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=3), + "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=0), + "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=1), + "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=2), + } + + def __init__(self, c_prev: int, c: int, rate: int, arch_code_c=None, spatial_dims: int = 3): + super().__init__() + self._spatial_dims = spatial_dims + if rate == -1: # downsample + self.preprocess = self.ConnOPS["down"](c_prev, c, spatial_dims=self._spatial_dims) + elif rate == 1: # upsample + self.preprocess = self.ConnOPS["up"](c_prev, c, spatial_dims=self._spatial_dims) + else: + if c_prev == c: + self.preprocess = self.ConnOPS["identity"]() + else: + self.preprocess = self.ConnOPS["align_channels"](c_prev, c, 1, 0, spatial_dims=self._spatial_dims) + + self.OPS = {} + if self._spatial_dims == 2: + self.OPS = self.OPS2D + elif self._spatial_dims == 3: + self.OPS = self.OPS3D + else: + raise NotImplementedError(f"Spatial dimensions {self._spatial_dims} is not supported.") + + self.op = MixedOp(c, self.OPS, arch_code_c) + + def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Args: + x: input tensor + weight: weights for different operations. + """ + x = self.preprocess(x) + x = self.op(x, weight) + return x + + +class DiNTS(nn.Module): + """ + Reimplementation of DiNTS based on + "DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation + ". + + The model contains a pre-defined multi-resolution stem block (defined in this class) and a + DiNTS space (defined in :py:class:`monai.networks.nets.TopologyInstance` and + :py:class:`monai.networks.nets.TopologySearch`). + + The stem block is for: 1) input downsample and 2) output upsample to original size. + The model downsamples the input image by 2 (if ``use_downsample=True``). + The downsampled image is downsampled by [1, 2, 4, 8] times (``num_depths=4``) and used as input to the + DiNTS search space (``TopologySearch``) or the DiNTS instance (``TopologyInstance``). + + - ``TopologyInstance`` is the final searched model. The initialization requires the searched architecture codes. + - ``TopologySearch`` is a multi-path topology and cell operation search space. + The architecture codes will be initialized as one. + - ``TopologyConstruction`` is the parent class which constructs the instance and search space. + + To meet the requirements of the structure, the input size for each spatial dimension should be: + divisible by 2 ** (num_depths + 1). + + Args: + dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`. + in_channels: number of input image channels. + num_classes: number of output segmentation classes. + act_name: activation name, default to 'RELU'. + norm_name: normalization used in convolution blocks. Default to `InstanceNorm`. + spatial_dims: spatial 2D or 3D inputs. + use_downsample: use downsample in the stem. + If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8], + if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16]. + node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`. + +1 for multi-resolution inputs. + In model searching stage, ``node_a`` can be None. In deployment stage, ``node_a`` cannot be None. + """ + + def __init__( + self, + dints_space, + in_channels: int, + num_classes: int, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + spatial_dims: int = 3, + use_downsample: bool = True, + node_a=None, + ): + super().__init__() + + self.dints_space = dints_space + self.filter_nums = dints_space.filter_nums + self.num_blocks = dints_space.num_blocks + self.num_depths = dints_space.num_depths + if spatial_dims not in (2, 3): + raise NotImplementedError(f"Spatial dimensions {spatial_dims} is not supported.") + self._spatial_dims = spatial_dims + if node_a is None: + self.node_a = torch.ones((self.num_blocks + 1, self.num_depths)) + else: + self.node_a = node_a + + # define stem operations for every block + conv_type = Conv[Conv.CONV, spatial_dims] + self.stem_down = nn.ModuleDict() + self.stem_up = nn.ModuleDict() + self.stem_finals = nn.Sequential( + ActiConvNormBlock( + self.filter_nums[0], + self.filter_nums[0], + act_name=act_name, + norm_name=norm_name, + spatial_dims=spatial_dims, + ), + conv_type( + in_channels=self.filter_nums[0], + out_channels=num_classes, + kernel_size=1, + stride=1, + padding=0, + groups=1, + bias=True, + dilation=1, + ), + ) + mode = "trilinear" if self._spatial_dims == 3 else "bilinear" + for res_idx in range(self.num_depths): + # define downsample stems before DiNTS search + if use_downsample: + self.stem_down[str(res_idx)] = StemTS( + nn.Upsample(scale_factor=1 / (2 ** res_idx), mode=mode, align_corners=True), + conv_type( + in_channels=in_channels, + out_channels=self.filter_nums[res_idx], + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), + get_act_layer(name=act_name), + conv_type( + in_channels=self.filter_nums[res_idx], + out_channels=self.filter_nums[res_idx + 1], + kernel_size=3, + stride=2, + padding=1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx + 1]), + ) + self.stem_up[str(res_idx)] = StemTS( + get_act_layer(name=act_name), + conv_type( + in_channels=self.filter_nums[res_idx + 1], + out_channels=self.filter_nums[res_idx], + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), + nn.Upsample(scale_factor=2, mode=mode, align_corners=True), + ) + + else: + self.stem_down[str(res_idx)] = StemTS( + nn.Upsample(scale_factor=1 / (2 ** res_idx), mode=mode, align_corners=True), + conv_type( + in_channels=in_channels, + out_channels=self.filter_nums[res_idx], + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), + ) + self.stem_up[str(res_idx)] = StemTS( + get_act_layer(name=act_name), + conv_type( + in_channels=self.filter_nums[res_idx], + out_channels=self.filter_nums[max(res_idx - 1, 0)], + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx - 1]), + nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True), + ) + + def weight_parameters(self): + return [param for name, param in self.named_parameters()] + + def forward(self, x: torch.Tensor): + """ + Prediction based on dynamic arch_code. + + Args: + x: input tensor. + """ + inputs = [] + for d in range(self.num_depths): + # allow multi-resolution input + _mod_w: StemInterface = self.stem_down[str(d)] + x_out = _mod_w.forward(x) + if self.node_a[0][d]: + inputs.append(x_out) + else: + inputs.append(torch.zeros_like(x_out)) + + outputs = self.dints_space(inputs) + + blk_idx = self.num_blocks - 1 + start = False + _temp: torch.Tensor = torch.empty(0) + for res_idx in range(self.num_depths - 1, -1, -1): + _mod_up: StemInterface = self.stem_up[str(res_idx)] + if start: + _temp = _mod_up.forward(outputs[res_idx] + _temp) + elif self.node_a[blk_idx + 1][res_idx]: + start = True + _temp = _mod_up.forward(outputs[res_idx]) + prediction = self.stem_finals(_temp) + return prediction + + +class TopologyConstruction(nn.Module): + """ + The base class for `TopologyInstance` and `TopologySearch`. + + Args: + arch_code: `[arch_code_a, arch_code_c]`, numpy arrays. The architecture codes defining the model. + For example, for a ``num_depths=4, num_blocks=12`` search space: + + - `arch_code_a` is a 12x10 (10 paths) binary matrix representing if a path is activated. + - `arch_code_c` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used. + - `arch_code` in ``__init__()`` is used for creating the network and remove unused network blocks. If None, + + all paths and cells operations will be used, and must be in the searching stage (is_search=True). + channel_mul: adjust intermediate channel number, default is 1. + cell: operation of each node. + num_blocks: number of blocks (depth in the horizontal direction) of the DiNTS search space. + num_depths: number of image resolutions of the DiNTS search space: 1, 1/2, 1/4 ... in each dimension. + use_downsample: use downsample in the stem. If False, the search space will be in resolution [1, 1/2, 1/4, 1/8], + if True, the search space will be in resolution [1/2, 1/4, 1/8, 1/16]. + device: `'cpu'`, `'cuda'`, or device ID. + + + Predefined variables: + `filter_nums`: default to 32. Double the number of channels after downsample. + topology related variables: + + - `arch_code2in`: path activation to its incoming node index (resolution). For depth = 4, + arch_code2in = [0, 1, 0, 1, 2, 1, 2, 3, 2, 3]. The first path outputs from node 0 (top resolution), + the second path outputs from node 1 (second resolution in the search space), + the third path outputs from node 0, etc. + - `arch_code2ops`: path activation to operations of upsample 1, keep 0, downsample -1. For depth = 4, + arch_code2ops = [0, 1, -1, 0, 1, -1, 0, 1, -1, 0]. The first path does not change + resolution, the second path perform upsample, the third perform downsample, etc. + - `arch_code2out`: path activation to its output node index. + For depth = 4, arch_code2out = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], + the first and second paths connects to node 0 (top resolution), the 3,4,5 paths connects to node 1, etc. + """ + + def __init__( + self, + arch_code: Optional[list] = None, + channel_mul: float = 1.0, + cell=Cell, + num_blocks: int = 6, + num_depths: int = 3, + spatial_dims: int = 3, + use_downsample: bool = True, + device: str = "cpu", + ): + + super().__init__() + + self.filter_nums = [int(n_feat * channel_mul) for n_feat in (32, 64, 128, 256, 512)] + self.num_blocks = num_blocks + self.num_depths = num_depths + self._spatial_dims = spatial_dims + self.use_downsample = use_downsample + self.device = device + self.num_cell_ops = 0 + if self._spatial_dims == 2: + self.num_cell_ops = len(cell.OPS2D) + elif self._spatial_dims == 3: + self.num_cell_ops = len(cell.OPS3D) + + # Calculate predefined parameters for topology search and decoding + arch_code2in, arch_code2out = [], [] + for i in range(Cell.DIRECTIONS * self.num_depths - 2): + arch_code2in.append((i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS) + arch_code2ops = ([-1, 0, 1] * self.num_depths)[1:-1] + for m in range(self.num_depths): + arch_code2out.extend([m, m, m]) + arch_code2out = arch_code2out[1:-1] + self.arch_code2in = arch_code2in + self.arch_code2ops = arch_code2ops + self.arch_code2out = arch_code2out + + # define NAS search space + if arch_code is None: + arch_code_a = torch.ones((self.num_blocks, len(self.arch_code2out))).to(self.device) + arch_code_c = torch.ones((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)).to(self.device) + else: + arch_code_a = torch.from_numpy(arch_code[0]).to(self.device) + arch_code_c = F.one_hot(torch.from_numpy(arch_code[1]).to(torch.int64), self.num_cell_ops).to(self.device) + + self.arch_code_a = arch_code_a + self.arch_code_c = arch_code_c + # define cell operation on each path + self.cell_tree = nn.ModuleDict() + for blk_idx in range(self.num_blocks): + for res_idx in range(len(self.arch_code2out)): + if self.arch_code_a[blk_idx, res_idx] == 1: + self.cell_tree[str((blk_idx, res_idx))] = cell( + self.filter_nums[self.arch_code2in[res_idx] + int(use_downsample)], + self.filter_nums[self.arch_code2out[res_idx] + int(use_downsample)], + self.arch_code2ops[res_idx], + self.arch_code_c[blk_idx, res_idx], + self._spatial_dims, + ) + + def forward(self, x): + """This function to be implemented by the architecture instances or search spaces.""" + pass + + +class TopologyInstance(TopologyConstruction): + """ + Instance of the final searched architecture. Only used in re-training/inference stage. + """ + + def __init__( + self, + arch_code=None, + channel_mul: float = 1.0, + cell=Cell, + num_blocks: int = 6, + num_depths: int = 3, + spatial_dims: int = 3, + use_downsample: bool = True, + device: str = "cpu", + ): + """ + Initialize DiNTS topology search space of neural architectures. + """ + if arch_code is None: + warnings.warn("arch_code not provided when not searching.") + + super().__init__( + arch_code=arch_code, + channel_mul=channel_mul, + cell=cell, + num_blocks=num_blocks, + num_depths=num_depths, + spatial_dims=spatial_dims, + use_downsample=use_downsample, + device=device, + ) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Args: + x: input tensor. + """ + # generate path activation probability + inputs, outputs = x, [torch.tensor(0.0).to(x[0])] * self.num_depths + for blk_idx in range(self.num_blocks): + outputs = [torch.tensor(0.0).to(x[0])] * self.num_depths + for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data): + if activation: + mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))] + _out = mod.forward( + x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx]) + ) + outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out + inputs = outputs + + return inputs + + +class TopologySearch(TopologyConstruction): + """ + DiNTS topology search space of neural architectures. + + Examples: + + .. code-block:: python + + from monai.networks.nets.dints import TopologySearch + + topology_search_space = TopologySearch( + channel_mul=0.5, num_blocks=8, num_depths=4, use_downsample=True, spatial_dims=3) + topology_search_space.get_ram_cost_usage(in_size=(2, 16, 80, 80, 80), full=True) + multi_res_images = [ + torch.randn(2, 16, 80, 80, 80), + torch.randn(2, 32, 40, 40, 40), + torch.randn(2, 64, 20, 20, 20), + torch.randn(2, 128, 10, 10, 10)] + prediction = topology_search_space(image) + for x in prediction: print(x.shape) + # torch.Size([2, 16, 80, 80, 80]) + # torch.Size([2, 32, 40, 40, 40]) + # torch.Size([2, 64, 20, 20, 20]) + # torch.Size([2, 128, 10, 10, 10]) + + Class method overview: + + - ``get_prob_a()``: convert learnable architecture weights to path activation probabilities. + - ``get_ram_cost_usage()``: get estimated ram cost. + - ``get_topology_entropy()``: get topology entropy loss in searching stage. + - ``decode()``: get final binarized architecture code. + - ``gen_mtx()``: generate variables needed for topology search. + + Predefined variables: + - `tidx`: index used to convert path activation matrix T = (depth,depth) in transfer_mtx to + path activation arch_code (1,3*depth-2), for depth = 4, tidx = [0, 1, 4, 5, 6, 9, 10, 11, 14, 15], + A tidx (10 binary values) represents the path activation. + - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern. + It is used to convert path activation pattern (1, paths) to node activation (1, nodes) + - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation + patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths). + - `all_connect`: All possible path activations. For depth = 4, + all_connection has 1024 vectors of length 10 (10 paths). + The return value will exclude path activation of all 0. + """ + + def __init__( + self, + channel_mul: float = 1.0, + cell=Cell, + arch_code: Optional[list] = None, + num_blocks: int = 6, + num_depths: int = 3, + spatial_dims: int = 3, + use_downsample: bool = True, + device: str = "cpu", + ): + """ + Initialize DiNTS topology search space of neural architectures. + """ + super().__init__( + arch_code=arch_code, + channel_mul=channel_mul, + cell=cell, + num_blocks=num_blocks, + num_depths=num_depths, + spatial_dims=spatial_dims, + use_downsample=use_downsample, + device=device, + ) + + tidx = [] + _d = Cell.DIRECTIONS + for i in range(_d * self.num_depths - 2): + tidx.append((i + 1) // _d * self.num_depths + (i + 1) // _d - 1 + (i + 1) % _d) + self.tidx = tidx + transfer_mtx, node_act_list, child_list = self.gen_mtx(num_depths) + + self.node_act_list = np.asarray(node_act_list) + self.node_act_dict = {str(self.node_act_list[i]): i for i in range(len(self.node_act_list))} + self.transfer_mtx = transfer_mtx + self.child_list = np.asarray(child_list) + + self.ram_cost = np.zeros((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)) + for blk_idx in range(self.num_blocks): + for res_idx in range(len(self.arch_code2out)): + if self.arch_code_a[blk_idx, res_idx] == 1: + self.ram_cost[blk_idx, res_idx] = np.array( + [ + op.ram_cost + self.cell_tree[str((blk_idx, res_idx))].preprocess.ram_cost + for op in self.cell_tree[str((blk_idx, res_idx))].op.ops[: self.num_cell_ops] + ] + ) + + # define cell and macro architecture probabilities + self.log_alpha_c = nn.Parameter( + torch.zeros(self.num_blocks, len(self.arch_code2out), self.num_cell_ops) + .normal_(1, 0.01) + .to(self.device) + .requires_grad_() + ) + self.log_alpha_a = nn.Parameter( + torch.zeros(self.num_blocks, len(self.arch_code2out)).normal_(0, 0.01).to(self.device).requires_grad_() + ) + self._arch_param_names = ["log_alpha_a", "log_alpha_c"] + + def gen_mtx(self, depth: int): + """ + Generate elements needed in decoding and topology. + + - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern. + It is used to convert path activation pattern (1, paths) to node activation (1, nodes) + - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation + patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths). + - `all_connect`: All possible path activations. For depth = 4, + all_connection has 1024 vectors of length 10 (10 paths). + The return value will exclude path activation of all 0. + """ + # total paths in a block, each node has three output paths, + # except the two nodes at the top and the bottom scales + paths = Cell.DIRECTIONS * depth - 2 + + # for 10 paths, all_connect has 1024 possible path activations. [1 0 0 0 0 0 0 0 0 0] means the top + # path is activated. + all_connect = _dfs(0, paths - 1) + + # Save all possible connections in mtx (might be redundant and infeasible) + mtx = [] + for m in all_connect: + # convert path activation [1,paths] to path activation matrix [depth, depth] + ma = np.zeros((depth, depth)) + for i in range(paths): + ma[(i + 1) // Cell.DIRECTIONS, (i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS] = m[i] + mtx.append(ma) + + # define all possible node activation + node_act_list = _dfs(0, depth - 1)[1:] + transfer_mtx = {} + for arch_code in node_act_list: + # make sure each activated node has an active connection, inactivated node has no connection + arch_code_mtx = [_ for _ in mtx if ((np.sum(_, 0) > 0).astype(int) == np.array(arch_code)).all()] + transfer_mtx[str(np.array(arch_code))] = arch_code_mtx + + return transfer_mtx, node_act_list, all_connect[1:] + + def weight_parameters(self): + return [param for name, param in self.named_parameters() if name not in self._arch_param_names] + + def get_prob_a(self, child: bool = False): + """ + Get final path and child model probabilities from architecture weights `log_alpha_a`. + This is used in forward pass, getting training loss, and final decoding. + + Args: + child: return child probability (used in decoding) + Return: + arch_code_prob_a: the path activation probability of size: + `[number of blocks, number of paths in each block]`. + For 12 blocks, 4 depths search space, the size is [12,10] + probs_a: The probability of all child models (size 1023x10). Each child model is a path activation pattern + (1D vector of length 10 for 10 paths). In total 1023 child models (2^10 -1) + """ + _arch_code_prob_a = torch.sigmoid(self.log_alpha_a) + # remove the case where all path are zero, and re-normalize. + norm = 1 - (1 - _arch_code_prob_a).prod(-1) + arch_code_prob_a = _arch_code_prob_a / norm.unsqueeze(1) + if child: + path_activation = torch.from_numpy(self.child_list).to(self.device) + probs_a = [ + ( + path_activation * _arch_code_prob_a[blk_idx] + + (1 - path_activation) * (1 - _arch_code_prob_a[blk_idx]) + ).prod(-1) + / norm[blk_idx] + for blk_idx in range(self.num_blocks) + ] + probs_a = torch.stack(probs_a) # type: ignore + return probs_a, arch_code_prob_a + return None, arch_code_prob_a + + def get_ram_cost_usage(self, in_size, full: bool = False): + """ + Get estimated output tensor size to approximate RAM consumption. + + Args: + in_size: input image shape (4D/5D, ``[BCHW[D]]``) at the highest resolution level. + full: full ram cost usage with all probability of 1. + """ + # convert input image size to feature map size at each level + batch_size = in_size[0] + image_size = np.array(in_size[-self._spatial_dims :]) + sizes = [] + for res_idx in range(self.num_depths): + sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2 ** res_idx)).prod()) + sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample))) + probs_a, arch_code_prob_a = self.get_prob_a(child=False) + cell_prob = F.softmax(self.log_alpha_c, dim=-1) + if full: + arch_code_prob_a = arch_code_prob_a.detach() + arch_code_prob_a.fill_(1) + ram_cost = torch.from_numpy(self.ram_cost).to(torch.float32).to(self.device) + usage = 0.0 + for blk_idx in range(self.num_blocks): + # node activation for input + # cell operation + for path_idx in range(len(self.arch_code2out)): + usage += ( + arch_code_prob_a[blk_idx, path_idx] + * (1 + (ram_cost[blk_idx, path_idx] * cell_prob[blk_idx, path_idx]).sum()) + * sizes[self.arch_code2out[path_idx]] + ) + return usage * 32 / 8 / 1024 ** 2 + + def get_topology_entropy(self, probs): + """ + Get topology entropy loss at searching stage. + + Args: + probs: path activation probabilities + """ + if hasattr(self, "node2in"): + node2in = self.node2in + node2out = self.node2out + else: + # node activation index to feasible input child_idx + node2in = [[] for _ in range(len(self.node_act_list))] + # node activation index to feasible output child_idx + node2out = [[] for _ in range(len(self.node_act_list))] + for child_idx in range(len(self.child_list)): + _node_in, _node_out = np.zeros(self.num_depths), np.zeros(self.num_depths) + for res_idx in range(len(self.arch_code2out)): + _node_out[self.arch_code2out[res_idx]] += self.child_list[child_idx][res_idx] + _node_in[self.arch_code2in[res_idx]] += self.child_list[child_idx][res_idx] + _node_in = (_node_in >= 1).astype(int) + _node_out = (_node_out >= 1).astype(int) + node2in[self.node_act_dict[str(_node_out)]].append(child_idx) + node2out[self.node_act_dict[str(_node_in)]].append(child_idx) + self.node2in = node2in + self.node2out = node2out + # calculate entropy + ent = 0 + for blk_idx in range(self.num_blocks - 1): + blk_ent = 0 + # node activation probability + for node_idx in range(len(self.node_act_list)): + _node_p = probs[blk_idx, node2in[node_idx]].sum() + _out_probs = probs[blk_idx + 1, node2out[node_idx]].sum() + blk_ent += -(_node_p * torch.log(_out_probs + 1e-5) + (1 - _node_p) * torch.log(1 - _out_probs + 1e-5)) + ent += blk_ent + return ent + + def decode(self): + """ + Decode network log_alpha_a/log_alpha_c using dijkstra shortest path algorithm. + + `[node_a, arch_code_a, arch_code_c, arch_code_a_max]` is decoded when using ``self.decode()``. + + For example, for a ``num_depths=4``, ``num_blocks=12`` search space: + + - ``node_a`` is a 4x13 binary matrix representing if a feature node is activated + (13 because of multi-resolution inputs). + - ``arch_code_a`` is a 12x10 (10 paths) binary matrix representing if a path is activated. + - ``arch_code_c`` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used. + + Return: + arch_code with maximum probability + """ + probs, arch_code_prob_a = self.get_prob_a(child=True) + arch_code_a_max = self.child_list[torch.argmax(probs, -1).data.cpu().numpy()] + arch_code_c = torch.argmax(F.softmax(self.log_alpha_c, -1), -1).data.cpu().numpy() + probs = probs.data.cpu().numpy() + + # define adjacency matrix + amtx = np.zeros( + (1 + len(self.child_list) * self.num_blocks + 1, 1 + len(self.child_list) * self.num_blocks + 1) + ) + + # build a path activation to child index searching dictionary + path2child = {str(self.child_list[i]): i for i in range(len(self.child_list))} + + # build a submodel to submodel index + sub_amtx = np.zeros((len(self.child_list), len(self.child_list))) + for child_idx in range(len(self.child_list)): + _node_act = np.zeros(self.num_depths).astype(int) + for path_idx in range(len(self.child_list[child_idx])): + _node_act[self.arch_code2out[path_idx]] += self.child_list[child_idx][path_idx] + _node_act = (_node_act >= 1).astype(int) + for mtx in self.transfer_mtx[str(_node_act)]: + connect_child_idx = path2child[str(mtx.flatten()[self.tidx].astype(int))] + sub_amtx[child_idx, connect_child_idx] = 1 + + # fill in source to first block, add 1e-5/1e-3 to avoid log0 and negative edge weights + amtx[0, 1 : 1 + len(self.child_list)] = -np.log(probs[0] + 1e-5) + 0.001 + + # fill in the rest blocks + for blk_idx in range(1, self.num_blocks): + amtx[ + 1 + (blk_idx - 1) * len(self.child_list) : 1 + blk_idx * len(self.child_list), + 1 + blk_idx * len(self.child_list) : 1 + (blk_idx + 1) * len(self.child_list), + ] = sub_amtx * np.tile(-np.log(probs[blk_idx] + 1e-5) + 0.001, (len(self.child_list), 1)) + + # fill in the last to the sink + amtx[1 + (self.num_blocks - 1) * len(self.child_list) : 1 + self.num_blocks * len(self.child_list), -1] = 0.001 + + graph = csr_matrix(amtx) + dist_matrix, predecessors, sources = dijkstra( + csgraph=graph, directed=True, indices=0, min_only=True, return_predecessors=True + ) + index, a_idx = -1, -1 + arch_code_a = np.zeros((self.num_blocks, len(self.arch_code2out))) + node_a = np.zeros((self.num_blocks + 1, self.num_depths)) + + # decoding to paths + while True: + index = predecessors[index] + if index == 0: + break + child_idx = (index - 1) % len(self.child_list) + arch_code_a[a_idx, :] = self.child_list[child_idx] + for res_idx in range(len(self.arch_code2out)): + node_a[a_idx, self.arch_code2out[res_idx]] += arch_code_a[a_idx, res_idx] + a_idx -= 1 + for res_idx in range(len(self.arch_code2out)): + node_a[a_idx, self.arch_code2in[res_idx]] += arch_code_a[0, res_idx] + node_a = (node_a >= 1).astype(int) + return node_a, arch_code_a, arch_code_c, arch_code_a_max + + def forward(self, x): + """ + Prediction based on dynamic arch_code. + + Args: + x: a list of `num_depths` input tensors as a multi-resolution input. + tensor is of shape `BCHW[D]` where `C` must match `self.filter_nums`. + """ + # generate path activation probability + probs_a, arch_code_prob_a = self.get_prob_a(child=False) + inputs = x + for blk_idx in range(self.num_blocks): + outputs = [0.0] * self.num_depths + for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data.cpu().numpy()): + if activation: + _w = F.softmax(self.log_alpha_c[blk_idx, res_idx], dim=-1) + outputs[self.arch_code2out[res_idx]] += ( + self.cell_tree[str((blk_idx, res_idx))](inputs[self.arch_code2in[res_idx]], weight=_w) + * arch_code_prob_a[blk_idx, res_idx] + ) + inputs = outputs + + return inputs diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 4af70b22c7..337a99acd8 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,7 +18,7 @@ from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock -__all__ = ["DynUNet", "DynUnet", "Dynunet", "dynunet"] +__all__ = ["DynUNet", "DynUnet", "Dynunet"] class DynUNetSkipLayer(nn.Module): @@ -31,13 +31,13 @@ class DynUNetSkipLayer(nn.Module): forward passes of the network. """ - heads: List[torch.Tensor] + heads: Optional[List[torch.Tensor]] - def __init__(self, index, heads, downsample, upsample, super_head, next_layer): + def __init__(self, index, downsample, upsample, next_layer, heads=None, super_head=None): super().__init__() self.downsample = downsample - self.upsample = upsample self.next_layer = next_layer + self.upsample = upsample self.super_head = super_head self.heads = heads self.index = index @@ -46,8 +46,8 @@ def forward(self, x): downout = self.downsample(x) nextout = self.next_layer(downout) upout = self.upsample(nextout, downout) - - self.heads[self.index] = self.super_head(upout) + if self.super_head is not None and self.heads is not None and self.index > 0: + self.heads[self.index - 1] = self.super_head(upout) return upout @@ -57,6 +57,7 @@ class DynUNet(nn.Module): This reimplementation of a dynamic UNet (DynUNet) is based on: `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + `Optimized U-Net for Brain Tumor Segmentation `_. This model is more flexible compared with ``monai.networks.nets.UNet`` in three places: @@ -74,10 +75,15 @@ class DynUNet(nn.Module): To meet the requirements of the structure, the input size for each spatial dimension should be divisible by `2 * the product of all strides in the corresponding dimension`. The output size for each spatial dimension - equals to the input size of the correponding dimension divided by the stride in strides[0]. + equals to the input size of the corresponding dimension divided by the stride in strides[0]. For example, if `strides=((1, 2, 4), 2, 1, 1)`, the minimal spatial size of the input is `(8, 16, 32)`, and the spatial size of the output is `(8, 8, 8)`. + For backwards compatibility with old weights, please set `strict=False` when calling `load_state_dict`. + + Usage example with medical segmentation decathlon dataset is available at: + https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline. + Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. @@ -86,25 +92,32 @@ class DynUNet(nn.Module): strides: convolution strides for each blocks. upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should equal to strides[1:]. + filters: number of output channels for each blocks. Different from nnU-Net, in this implementation we add + this argument to make the network more flexible. As shown in the third reference, one way to determine + this argument is like: + ``[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)]``. + The above way is used in the network that wins task 1 in the BraTS21 Challenge. + If not specified, the way which nnUNet used will be employed. Defaults to ``None``. + dropout: dropout ratio. Defaults to no dropout. norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. + act_name: activation layer type and arguments. Defaults to ``leakyrelu``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. - If ``True``, in training mode, the forward function will output not only the last feature - map, but also the previous feature maps that come from the intermediate up sample layers. + If ``True``, in training mode, the forward function will output not only the final feature map + (from `output_block`), but also the feature maps that come from the intermediate up sample layers. In order to unify the return type (the restriction of TorchScript), all intermediate - feature maps are interpolated into the same size as the last feature map and stacked together + feature maps are interpolated into the same size as the final feature map and stacked together (with a new dimension in the first axis)into one single tensor. - For instance, if there are three feature maps with shapes: (1, 2, 32, 24), (1, 2, 16, 12) and - (1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor - will has the shape (1, 3, 2, 8, 6). + For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and + (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps + will be interpolated into (1, 2, 32, 24), and the stacked tensor will has the shape (1, 3, 2, 32, 24). When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss one by one with the ground truth, then do a weighted average for all losses to achieve the final loss. - (To be added: a corresponding tutorial link) - deep_supr_num: number of feature maps that will output during deep supervision head. The value should be larger than 0 and less than the number of up sample layers. Defaults to 1. res_block: whether to use residual connection based convolution blocks during the network. Defaults to ``False``. + trans_bias: whether to set the bias parameter in transposed convolution layers. Defaults to ``False``. """ def __init__( @@ -115,12 +128,16 @@ def __init__( kernel_size: Sequence[Union[Sequence[int], int]], strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], + filters: Optional[Sequence[int]] = None, + dropout: Optional[Union[Tuple, str, float]] = None, norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), deep_supervision: bool = False, deep_supr_num: int = 1, res_block: bool = False, + trans_bias: bool = False, ): - super(DynUNet, self).__init__() + super().__init__() self.spatial_dims = spatial_dims self.in_channels = in_channels self.out_channels = out_channels @@ -128,24 +145,32 @@ def __init__( self.strides = strides self.upsample_kernel_size = upsample_kernel_size self.norm_name = norm_name + self.act_name = act_name + self.dropout = dropout self.conv_block = UnetResBlock if res_block else UnetBasicBlock - self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] + self.trans_bias = trans_bias + if filters is not None: + self.filters = filters + self.check_filters() + else: + self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] self.input_block = self.get_input_block() self.downsamples = self.get_downsamples() self.bottleneck = self.get_bottleneck() self.upsamples = self.get_upsamples() self.output_block = self.get_output_block(0) self.deep_supervision = deep_supervision - self.deep_supervision_heads = self.get_deep_supervision_heads() self.deep_supr_num = deep_supr_num + # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on + self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num + if self.deep_supervision: + self.deep_supervision_heads = self.get_deep_supervision_heads() + self.check_deep_supr_num() + self.apply(self.initialize_weights) self.check_kernel_stride() - self.check_deep_supr_num() - - # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on - self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) - def create_skips(index, downsamples, upsamples, superheads, bottleneck): + def create_skips(index, downsamples, upsamples, bottleneck, superheads=None): """ Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is done recursively from the top down since a recursive nn.Module subclass is being used to be compatible @@ -155,64 +180,90 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): """ if len(downsamples) != len(upsamples): - raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") - if (len(downsamples) - len(superheads)) not in (1, 0): - raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}") + raise ValueError(f"{len(downsamples)} != {len(upsamples)}") if len(downsamples) == 0: # bottom of the network, pass the bottleneck block return bottleneck + + if superheads is None: + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck) + return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer) + + super_head_flag = False if index == 0: # don't associate a supervision head with self.input_block - current_head, rest_heads = nn.Identity(), superheads - elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - current_head, rest_heads = nn.Identity(), superheads[1:] + rest_heads = superheads else: - current_head, rest_heads = superheads[0], superheads[1:] + if len(superheads) > 0: + super_head_flag = True + rest_heads = superheads[1:] + else: + rest_heads = nn.ModuleList() # create the next layer down, this will stop at the bottleneck layer - next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) - - return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) - - self.skip_layers = create_skips( - 0, - [self.input_block] + list(self.downsamples), - self.upsamples[::-1], - self.deep_supervision_heads, - self.bottleneck, - ) + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck, superheads=rest_heads) + if super_head_flag: + return DynUNetSkipLayer( + index, + downsample=downsamples[0], + upsample=upsamples[0], + next_layer=next_layer, + heads=self.heads, + super_head=superheads[0], + ) + + return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer) + + if not self.deep_supervision: + self.skip_layers = create_skips( + 0, [self.input_block] + list(self.downsamples), self.upsamples[::-1], self.bottleneck + ) + else: + self.skip_layers = create_skips( + 0, + [self.input_block] + list(self.downsamples), + self.upsamples[::-1], + self.bottleneck, + superheads=self.deep_supervision_heads, + ) def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." - if not (len(kernels) == len(strides) and len(kernels) >= 3): - raise AssertionError(error_msg) + if len(kernels) != len(strides) or len(kernels) < 3: + raise ValueError(error_msg) for idx, k_i in enumerate(kernels): kernel, stride = k_i, strides[idx] if not isinstance(kernel, int): - error_msg = "length of kernel_size in block {} should be the same as spatial_dims.".format(idx) + error_msg = f"length of kernel_size in block {idx} should be the same as spatial_dims." if len(kernel) != self.spatial_dims: - raise AssertionError(error_msg) + raise ValueError(error_msg) if not isinstance(stride, int): - error_msg = "length of stride in block {} should be the same as spatial_dims.".format(idx) + error_msg = f"length of stride in block {idx} should be the same as spatial_dims." if len(stride) != self.spatial_dims: - raise AssertionError(error_msg) + raise ValueError(error_msg) def check_deep_supr_num(self): deep_supr_num, strides = self.deep_supr_num, self.strides num_up_layers = len(strides) - 1 if deep_supr_num >= num_up_layers: - raise AssertionError("deep_supr_num should be less than the number of up sample layers.") + raise ValueError("deep_supr_num should be less than the number of up sample layers.") if deep_supr_num < 1: - raise AssertionError("deep_supr_num should be larger than 0.") + raise ValueError("deep_supr_num should be larger than 0.") + + def check_filters(self): + filters = self.filters + if len(filters) < len(self.strides): + raise ValueError("length of filters should be no less than the length of strides.") + else: + self.filters = filters[: len(self.strides)] def forward(self, x): out = self.skip_layers(x) out = self.output_block(out) if self.training and self.deep_supervision: out_all = [out] - feature_maps = self.heads[1 : self.deep_supr_num + 1] - for feature_map in feature_maps: + for feature_map in self.heads: out_all.append(interpolate(feature_map, out.shape[2:])) return torch.stack(out_all, dim=1) return out @@ -225,6 +276,8 @@ def get_input_block(self): self.kernel_size[0], self.strides[0], self.norm_name, + self.act_name, + dropout=self.dropout, ) def get_bottleneck(self): @@ -235,14 +288,12 @@ def get_bottleneck(self): self.kernel_size[-1], self.strides[-1], self.norm_name, + self.act_name, + dropout=self.dropout, ) def get_output_block(self, idx: int): - return UnetOutBlock( - self.spatial_dims, - self.filters[idx], - self.out_channels, - ) + return UnetOutBlock(self.spatial_dims, self.filters[idx], self.out_channels, dropout=self.dropout) def get_downsamples(self): inp, out = self.filters[:-2], self.filters[1:-1] @@ -253,7 +304,9 @@ def get_upsamples(self): inp, out = self.filters[1:][::-1], self.filters[:-1][::-1] strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1] upsample_kernel_size = self.upsample_kernel_size[::-1] - return self.get_module_list(inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size) + return self.get_module_list( + inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size, trans_bias=self.trans_bias + ) def get_module_list( self, @@ -263,6 +316,7 @@ def get_module_list( strides: Sequence[Union[Sequence[int], int]], conv_block: nn.Module, upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None, + trans_bias: bool = False, ): layers = [] if upsample_kernel_size is not None: @@ -276,7 +330,10 @@ def get_module_list( "kernel_size": kernel, "stride": stride, "norm_name": self.norm_name, + "act_name": self.act_name, + "dropout": self.dropout, "upsample_kernel_size": up_kernel, + "trans_bias": trans_bias, } layer = conv_block(**params) layers.append(layer) @@ -289,13 +346,15 @@ def get_module_list( "kernel_size": kernel, "stride": stride, "norm_name": self.norm_name, + "act_name": self.act_name, + "dropout": self.dropout, } layer = conv_block(**params) layers.append(layer) return nn.ModuleList(layers) def get_deep_supervision_heads(self): - return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)]) + return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)]) @staticmethod def initialize_weights(module): @@ -305,4 +364,4 @@ def initialize_weights(module): module.bias = nn.init.constant_(module.bias, 0) -DynUnet = Dynunet = dynunet = DynUNet +DynUnet = Dynunet = DynUNet diff --git a/monai/networks/nets/dynunet_v1.py b/monai/networks/nets/dynunet_v1.py deleted file mode 100644 index feb05d1762..0000000000 --- a/monai/networks/nets/dynunet_v1.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Sequence, Union - -import torch -import torch.nn as nn - -from monai.networks.blocks.dynunet_block_v1 import _UnetBasicBlockV1, _UnetResBlockV1, _UnetUpBlockV1 -from monai.networks.nets.dynunet import DynUNet, DynUNetSkipLayer -from monai.utils import deprecated - -__all__ = ["DynUNetV1", "DynUnetV1", "DynunetV1"] - - -@deprecated( - since="0.6.0", - removed="0.7.0", - msg_suffix="This module is for backward compatibility purpose only. Please use `DynUNet` instead.", -) -class DynUNetV1(DynUNet): - """ - This a deprecated reimplementation of a dynamic UNet (DynUNet), please use `monai.networks.nets.DynUNet` instead. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - strides: convolution strides for each blocks. - upsample_kernel_size: convolution kernel size for transposed convolution layers. - norm_name: [``"batch"``, ``"instance"``, ``"group"``]. Defaults to "instance". - deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. - deep_supr_num: number of feature maps that will output during deep supervision head. Defaults to 1. - res_block: whether to use residual connection based convolution blocks during the network. - Defaults to ``False``. - - .. deprecated:: 0.6.0 - Use :class:`monai.networks.nets.DynUNet` instead. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Sequence[Union[Sequence[int], int]], - strides: Sequence[Union[Sequence[int], int]], - upsample_kernel_size: Sequence[Union[Sequence[int], int]], - norm_name: str = "instance", - deep_supervision: bool = False, - deep_supr_num: int = 1, - res_block: bool = False, - ): - nn.Module.__init__(self) - self.spatial_dims = spatial_dims - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.strides = strides - self.upsample_kernel_size = upsample_kernel_size - self.norm_name = norm_name - self.conv_block = _UnetResBlockV1 if res_block else _UnetBasicBlockV1 # type: ignore - self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] - self.input_block = self.get_input_block() - self.downsamples = self.get_downsamples() - self.bottleneck = self.get_bottleneck() - self.upsamples = self.get_upsamples() - self.output_block = self.get_output_block(0) - self.deep_supervision = deep_supervision - self.deep_supervision_heads = self.get_deep_supervision_heads() - self.deep_supr_num = deep_supr_num - self.apply(self.initialize_weights) - self.check_kernel_stride() - self.check_deep_supr_num() - - # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on - self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) - - def create_skips(index, downsamples, upsamples, superheads, bottleneck): - """ - Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is - done recursively from the top down since a recursive nn.Module subclass is being used to be compatible - with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads` - since the `input_block` is passed to this function as the first item in `downsamples`, however this - shouldn't be associated with a supervision head. - """ - - if len(downsamples) != len(upsamples): - raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") - if (len(downsamples) - len(superheads)) not in (1, 0): - raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}") - - if len(downsamples) == 0: # bottom of the network, pass the bottleneck block - return bottleneck - if index == 0: # don't associate a supervision head with self.input_block - current_head, rest_heads = nn.Identity(), superheads - elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - current_head, rest_heads = nn.Identity(), superheads[1:] - else: - current_head, rest_heads = superheads[0], superheads[1:] - - # create the next layer down, this will stop at the bottleneck layer - next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) - - return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) - - self.skip_layers = create_skips( - 0, - [self.input_block] + list(self.downsamples), - self.upsamples[::-1], - self.deep_supervision_heads, - self.bottleneck, - ) - - def get_upsamples(self): - inp, out = self.filters[1:][::-1], self.filters[:-1][::-1] - strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1] - upsample_kernel_size = self.upsample_kernel_size[::-1] - return self.get_module_list(inp, out, kernel_size, strides, _UnetUpBlockV1, upsample_kernel_size) - - @staticmethod - def initialize_weights(module): - name = module.__class__.__name__.lower() - if "conv3d" in name or "conv2d" in name: - nn.init.kaiming_normal_(module.weight, a=0.01) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - elif "norm" in name: - nn.init.normal_(module.weight, 1.0, 0.02) - nn.init.zeros_(module.bias) - - -DynUnetV1 = DynunetV1 = DynUNetV1 diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index 453916758a..fa5efbc4ef 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,14 @@ from monai.networks.layers.utils import get_norm_layer from monai.utils.module import look_up_option -__all__ = ["EfficientNet", "EfficientNetBN", "get_efficientnet_image_size", "drop_connect"] +__all__ = [ + "EfficientNet", + "EfficientNetBN", + "get_efficientnet_image_size", + "drop_connect", + "EfficientNetBNFeatures", + "BlockArgs", +] efficientnet_params = { # model_name: (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate) @@ -82,7 +89,7 @@ def __init__( Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. - out_classes: number of output channels. + out_channels: number of output channels. kernel_size: size of the kernel for conv ops. stride: stride to use for conv ops. image_size: input image resolution. @@ -369,10 +376,7 @@ def __init__( ) idx += 1 # increment blocks index counter - self._blocks.add_module( - str(stack_idx), - sub_stack, - ) + self._blocks.add_module(str(stack_idx), sub_stack) # sanity check to see if len(self._blocks) equal expected num_blocks if idx != num_blocks: @@ -534,7 +538,7 @@ def __init__( weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name] # create model and initialize random weights - super(EfficientNetBN, self).__init__( + super().__init__( blocks_args_str=blocks_args_str, spatial_dims=spatial_dims, in_channels=in_channels, @@ -594,7 +598,7 @@ def __init__( weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name] # create model and initialize random weights - super(EfficientNetBNFeatures, self).__init__( + super().__init__( blocks_args_str=blocks_args_str, spatial_dims=spatial_dims, in_channels=in_channels, @@ -669,7 +673,7 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor e.g. 1D activations [B, C, H], 2D activations [B, C, H, W] and 3D activations [B, C, H, W, D] Args: - input: input tensor with [B, C, dim_1, dim_2, ..., dim_N] where N=spatial_dims. + inputs: input tensor with [B, C, dim_1, dim_2, ..., dim_N] where N=spatial_dims. p: probability to use for dropping connections. training: whether in training or evaluation mode. @@ -677,7 +681,7 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor output: output tensor after applying drop connection. """ if p < 0.0 or p > 1.0: - raise ValueError("p must be in range of [0, 1], found {}".format(p)) + raise ValueError(f"p must be in range of [0, 1], found {p}") # eval mode: drop_connect is switched off - so return input without modifying if not training: @@ -708,7 +712,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool, adv_prop: bool arch = arch.split("efficientnet-")[-1] + "-ap" model_url = look_up_option(arch, url_map, None) if model_url is None: - print("pretrained weights of {} is not provided".format(arch)) + print(f"pretrained weights of {arch} is not provided") else: # load state dict from url model_url = url_map[arch] @@ -852,7 +856,7 @@ def _calculate_output_image_size(input_image_size: List[int], stride: Union[int, if isinstance(stride, tuple): all_strides_equal = all(stride[0] == s for s in stride) if not all_strides_equal: - raise ValueError("unequal strides are not possible, got {}".format(stride)) + raise ValueError(f"unequal strides are not possible, got {stride}") stride = stride[0] diff --git a/monai/networks/nets/fullyconnectednet.py b/monai/networks/nets/fullyconnectednet.py index b906bab015..810c07431b 100644 --- a/monai/networks/nets/fullyconnectednet.py +++ b/monai/networks/nets/fullyconnectednet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -30,9 +30,24 @@ def _get_adn_layer( class FullyConnectedNet(nn.Sequential): """ - Plain full-connected layer neural network + Simple full-connected layer neural network composed of a sequence of linear layers with PReLU activation and + dropout. The network accepts input with `in_channels` channels, has output with `out_channels` channels, and + hidden layer output channels given in `hidden_channels`. If `bias` is True then linear units have a bias term. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + hidden_channels: number of output channels for each hidden layer. + dropout: dropout ratio. Defaults to no dropout. + act: activation type and arguments. Defaults to PReLU. + bias: whether to have a bias term in linear units. Defaults to True. + adn_ordering: order of operations in :py:class:`monai.networks.blocks.ADN`. + + Examples:: + + # accepts 4 values and infers 3 values as output, has 3 hidden layers with 10, 20, 10 values as output + net = FullyConnectedNet(4, 3, [10, 20, 10], dropout=0.2) - The network uses dropout and, by default, PReLU activation """ def __init__( @@ -53,8 +68,11 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels self.hidden_channels = list(hidden_channels) + self.act = act + self.dropout = dropout + self.adn_ordering = adn_ordering + self.add_module("flatten", nn.Flatten()) - self.adn_layer = _get_adn_layer(act, dropout, adn_ordering) prev_channels = self.in_channels for i, c in enumerate(hidden_channels): @@ -64,13 +82,34 @@ def __init__( self.add_module("output", nn.Linear(prev_channels, out_channels, bias)) def _get_layer(self, in_channels: int, out_channels: int, bias: bool) -> nn.Sequential: - seq = nn.Sequential(nn.Linear(in_channels, out_channels, bias)) - seq.add_module("ADN", self.adn_layer) + seq = nn.Sequential( + nn.Linear(in_channels, out_channels, bias), _get_adn_layer(self.act, self.dropout, self.adn_ordering) + ) return seq class VarFullyConnectedNet(nn.Module): - """Variational fully-connected network.""" + """ + Variational fully-connected network. This is composed of an encode layer, reparameterization layer, and then a + decode layer. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + latent_size: number of latent variables to use. + encode_channels: number of output channels for each hidden layer of the encode half. + decode_channels: number of output channels for each hidden layer of the decode half. + dropout: dropout ratio. Defaults to no dropout. + act: activation type and arguments. Defaults to PReLU. + bias: whether to have a bias term in linear units. Defaults to True. + adn_ordering: order of operations in :py:class:`monai.networks.blocks.ADN`. + + Examples:: + + # accepts inputs with 4 values, uses a latent space of 2 variables, and produces outputs of 3 values + net = VarFullyConnectedNet(4, 3, 2, [5, 10], [10, 5]) + + """ def __init__( self, diff --git a/monai/networks/nets/generator.py b/monai/networks/nets/generator.py index 1f24944a63..a69cae4d7b 100644 --- a/monai/networks/nets/generator.py +++ b/monai/networks/nets/generator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,13 +25,35 @@ class Generator(nn.Module): """ Defines a simple generator network accepting a latent vector and through a sequence of convolution layers constructs an output tensor of greater size and high dimensionality. The method `_get_layer` is used to - create each of these layers, override this method to define layers beyond the default Convolution or - ResidualUnit layers. + create each of these layers, override this method to define layers beyond the default + :py:class:`monai.networks.blocks.Convolution` or :py:class:`monai.networks.blocks.ResidualUnit` layers. + + The layers are constructed using the values in the `channels` and `strides` arguments, the number being defined by + the length of these (which must match). Input is first passed through a :py:class:`torch.nn.Linear` layer to + convert the input vector to an image tensor with dimensions `start_shape`. This passes through the convolution + layers and is progressively upsampled if the `strides` values are greater than 1 using transpose convolutions. The + size of the final output is defined by the `start_shape` dimension and the amount of upsampling done through + strides. In the default definition the size of the output's spatial dimensions will be that of `start_shape` + multiplied by the product of `strides`, thus the example network below upsamples an starting size of (64, 8, 8) + to (1, 64, 64) since its `strides` are (2, 2, 2). + + Args: + latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension) + start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork + channels: tuple of integers stating the output channels of each convolutional layer + strides: tuple of integers stating the stride (upscale factor) of each convolutional layer + kernel_size: integer or tuple of integers stating size of convolutional kernels + num_res_units: integer stating number of convolutions in residual units, 0 means no residual units + act: name or type defining activation layers + norm: name or type defining normalization layers + dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout + bias: boolean stating if convolution layers should have a bias component + + Examples:: + + # 3 layers, latent input vector of shape (42, 24), output volume of shape (1, 64, 64) + net = Generator((42, 24), (64, 8, 8), (32, 16, 1), (2, 2, 2)) - For example, a generator accepting a latent vector if shape (42,24) and producing an output volume of - shape (1,64,64) can be constructed as: - - gen = Generator((42, 24), (64, 8, 8), (32, 16, 1), (2, 2, 2)) """ def __init__( @@ -47,26 +69,6 @@ def __init__( dropout: Optional[float] = None, bias: bool = True, ) -> None: - """ - Construct the generator network with the number of layers defined by `channels` and `strides`. In the - forward pass a `nn.Linear` layer relates the input latent vector to a tensor of dimensions `start_shape`, - this is then fed forward through the sequence of convolutional layers. The number of layers is defined by - the length of `channels` and `strides` which must match, each layer having the number of output channels - given in `channels` and an upsample factor given in `strides` (ie. a transpose convolution with that stride - size). - - Args: - latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension) - start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork - channels: tuple of integers stating the output channels of each convolutional layer - strides: tuple of integers stating the stride (upscale factor) of each convolutional layer - kernel_size: integer or tuple of integers stating size of convolutional kernels - num_res_units: integer stating number of convolutions in residual units, 0 means no residual units - act: name or type defining activation layers - norm: name or type defining normalization layers - dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout - bias: boolean stating if convolution layers should have a bias component - """ super().__init__() self.in_channels, *self.start_shape = ensure_tuple(start_shape) @@ -112,7 +114,7 @@ def _get_layer( strides=strides, is_transposed=True, conv_only=is_last or self.num_res_units > 0, - dimensions=self.dimensions, + spatial_dims=self.dimensions, out_channels=out_channels, kernel_size=self.kernel_size, act=self.act, @@ -126,7 +128,7 @@ def _get_layer( in_channels=out_channels, subunits=self.num_res_units, last_conv_only=is_last, - dimensions=self.dimensions, + spatial_dims=self.dimensions, out_channels=out_channels, kernel_size=self.kernel_size, act=self.act, diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index f644a7835a..95c0c758af 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -70,7 +70,7 @@ def __init__( ValueError: When ``channel_matching=pad`` and ``in_channels > out_channels``. Incompatible values. """ - super(HighResBlock, self).__init__() + super().__init__() self.chn_pad = ChannelPad( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, mode=channel_matching ) @@ -84,7 +84,7 @@ def __init__( ) layers.append( Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=_in_chns, out_channels=_out_chns, kernel_size=kernel_size, @@ -146,7 +146,7 @@ def __init__( channel_matching: Union[ChannelMatching, str] = ChannelMatching.PAD, ) -> None: - super(HighResNet, self).__init__() + super().__init__() blocks = nn.ModuleList() # initial conv layer @@ -154,7 +154,7 @@ def __init__( _in_chns, _out_chns = in_channels, params["n_features"] blocks.append( Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=_in_chns, out_channels=_out_chns, kernel_size=params["kernel_size"], @@ -190,7 +190,7 @@ def __init__( _in_chns, _out_chns = _out_chns, params["n_features"] blocks.append( Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=_in_chns, out_channels=_out_chns, kernel_size=params["kernel_size"], @@ -206,7 +206,7 @@ def __init__( _in_chns = _out_chns blocks.append( Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=_in_chns, out_channels=out_channels, kernel_size=params["kernel_size"], diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py new file mode 100644 index 0000000000..2f4afaffbe --- /dev/null +++ b/monai/networks/nets/milmodel.py @@ -0,0 +1,244 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Union, cast + +import torch +import torch.nn as nn + +from monai.utils.module import optional_import + +models, _ = optional_import("torchvision.models") + + +class MILModel(nn.Module): + """ + Multiple Instance Learning (MIL) model, with a backbone classification model. + Currently, it only works for 2D images, a typical use case is for classification of the + digital pathology whole slide images. The expected shape of input data is `[B, N, C, H, W]`, + where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances + extracted from every original image in the batch. A tutorial example is available at: + https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning. + + Args: + num_classes: number of output classes. + mil_mode: MIL algorithm, available values (Defaults to ``"att"``): + + - ``"mean"`` - average features from all instances, equivalent to pure CNN (non MIL). + - ``"max"`` - retain only the instance with the max probability for loss calculation. + - ``"att"`` - attention based MIL https://arxiv.org/abs/1802.04712. + - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556. + - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556. + + pretrained: init backbone with pretrained weights, defaults to ``True``. + backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features, + or a string name of a torchvision model). + Defaults to ``None``, in which case ResNet50 is used. + backbone_num_features: Number of output features of the backbone CNN + Defaults to ``None`` (necessary only when using a custom backbone) + trans_blocks: number of the blocks in `TransformEncoder` layer. + trans_dropout: dropout rate in `TransformEncoder` layer. + + """ + + def __init__( + self, + num_classes: int, + mil_mode: str = "att", + pretrained: bool = True, + backbone: Optional[Union[str, nn.Module]] = None, + backbone_num_features: Optional[int] = None, + trans_blocks: int = 4, + trans_dropout: float = 0.0, + ) -> None: + + super().__init__() + + if num_classes <= 0: + raise ValueError("Number of classes must be positive: " + str(num_classes)) + + if mil_mode.lower() not in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]: + raise ValueError("Unsupported mil_mode: " + str(mil_mode)) + + self.mil_mode = mil_mode.lower() + self.attention = nn.Sequential() + self.transformer = None # type: Optional[nn.Module] + + if backbone is None: + + net = models.resnet50(pretrained=pretrained) + nfc = net.fc.in_features # save the number of final features + net.fc = torch.nn.Identity() # remove final linear layer + + self.extra_outputs = {} # type: Dict[str, torch.Tensor] + + if mil_mode == "att_trans_pyramid": + # register hooks to capture outputs of intermediate layers + def forward_hook(layer_name): + def hook(module, input, output): + self.extra_outputs[layer_name] = output + + return hook + + net.layer1.register_forward_hook(forward_hook("layer1")) + net.layer2.register_forward_hook(forward_hook("layer2")) + net.layer3.register_forward_hook(forward_hook("layer3")) + net.layer4.register_forward_hook(forward_hook("layer4")) + + elif isinstance(backbone, str): + + # assume torchvision model string is provided + torch_model = getattr(models, backbone, None) + if torch_model is None: + raise ValueError("Unknown torch vision model" + str(backbone)) + net = torch_model(pretrained=pretrained) + + if getattr(net, "fc", None) is not None: + nfc = net.fc.in_features # save the number of final features + net.fc = torch.nn.Identity() # remove final linear layer + else: + raise ValueError( + "Unable to detect FC layer for the torchvision model " + str(backbone), + ". Please initialize the backbone model manually.", + ) + + elif isinstance(backbone, nn.Module): + # use a custom backbone + net = backbone + nfc = backbone_num_features + + if backbone_num_features is None: + raise ValueError("Number of endencoder features must be provided for a custom backbone model") + + else: + raise ValueError("Unsupported backbone") + + if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]: + raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode)) + + if self.mil_mode in ["mean", "max"]: + pass + elif self.mil_mode == "att": + self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) + + elif self.mil_mode == "att_trans": + transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout) + self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks) + self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) + + elif self.mil_mode == "att_trans_pyramid": + + transformer_list = nn.ModuleList( + [ + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=trans_blocks + ), + nn.Sequential( + nn.Linear(768, 256), + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), + num_layers=trans_blocks, + ), + ), + nn.Sequential( + nn.Linear(1280, 256), + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), + num_layers=trans_blocks, + ), + ), + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout), + num_layers=trans_blocks, + ), + ] + ) + self.transformer = transformer_list + nfc = nfc + 256 + self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) + + else: + raise ValueError("Unsupported mil_mode: " + str(mil_mode)) + + self.myfc = nn.Linear(nfc, num_classes) + self.net = net + + def calc_head(self, x: torch.Tensor) -> torch.Tensor: + + sh = x.shape + + if self.mil_mode == "mean": + x = self.myfc(x) + x = torch.mean(x, dim=1) + + elif self.mil_mode == "max": + x = self.myfc(x) + x, _ = torch.max(x, dim=1) + + elif self.mil_mode == "att": + + a = self.attention(x) + a = torch.softmax(a, dim=1) + x = torch.sum(x * a, dim=1) + + x = self.myfc(x) + + elif self.mil_mode == "att_trans" and self.transformer is not None: + + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + + a = self.attention(x) + a = torch.softmax(a, dim=1) + x = torch.sum(x * a, dim=1) + + x = self.myfc(x) + + elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None: + + l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + + transformer_list = cast(nn.ModuleList, self.transformer) + + x = transformer_list[0](l1) + x = transformer_list[1](torch.cat((x, l2), dim=2)) + x = transformer_list[2](torch.cat((x, l3), dim=2)) + x = transformer_list[3](torch.cat((x, l4), dim=2)) + + x = x.permute(1, 0, 2) + + a = self.attention(x) + a = torch.softmax(a, dim=1) + x = torch.sum(x * a, dim=1) + + x = self.myfc(x) + + else: + raise ValueError("Wrong model mode" + str(self.mil_mode)) + + return x + + def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor: + + sh = x.shape + x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4]) + + x = self.net(x) + x = x.reshape(sh[0], sh[1], -1) + + if not no_head: + x = self.calc_head(x) + + return x diff --git a/monai/networks/nets/netadapter.py b/monai/networks/nets/netadapter.py index 80288f7945..425c1d5820 100644 --- a/monai/networks/nets/netadapter.py +++ b/monai/networks/nets/netadapter.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,11 +24,12 @@ class NetAdapter(torch.nn.Module): then replace the model's last two layers with an optional `pooling` and a `conv` or `linear` layer. Args: - model: a PyTorch model, support both 2D and 3D models. typically, it can be a pretrained model in Torchvision, - like: ``resnet18``, ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, etc. + model: a PyTorch model, which can be both 2D and 3D models. typically, it can be a pretrained model + in Torchvision, like: ``resnet18``, ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, etc. more details: https://pytorch.org/vision/stable/models.html. num_classes: number of classes for the last classification layer. Default to 1. - dim: number of spatial dimensions, default to 2. + dim: number of supported spatial dimensions in the specified model, depends on the model implementation. + default to 2 as most Torchvision models are for 2D image processing. in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer. use_conv: whether use convolutional layer to replace the last layer, default to False. pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer, @@ -37,6 +38,9 @@ class NetAdapter(torch.nn.Module): bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias, default to True. + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``num_classes`` instead. + """ @deprecated_arg("n_classes", since="0.6") @@ -67,32 +71,23 @@ def __init__( in_channels_ = in_channels if pool is None: - self.pool = None # remove the last layer self.features = torch.nn.Sequential(*layers[:-1]) + self.pool = None else: - self.pool = get_pool_layer(name=pool, spatial_dims=dim) # remove the last 2 layers self.features = torch.nn.Sequential(*layers[:-2]) + self.pool = get_pool_layer(name=pool, spatial_dims=dim) self.fc: Union[torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv3d] if use_conv: # add 1x1 conv (it behaves like a FC layer) - self.fc = Conv[Conv.CONV, dim]( - in_channels=in_channels_, - out_channels=num_classes, - kernel_size=1, - bias=bias, - ) + self.fc = Conv[Conv.CONV, dim](in_channels=in_channels_, out_channels=num_classes, kernel_size=1, bias=bias) else: # remove the last Linear layer (fully connected) self.features = torch.nn.Sequential(*layers[:-1]) # replace the out_features of FC layer - self.fc = torch.nn.Linear( - in_features=in_channels_, - out_features=num_classes, - bias=bias, - ) + self.fc = torch.nn.Linear(in_features=in_channels_, out_features=num_classes, bias=bias) self.use_conv = use_conv def forward(self, x): diff --git a/monai/networks/nets/regressor.py b/monai/networks/nets/regressor.py index 25acb9bfa5..0a1e6258a9 100644 --- a/monai/networks/nets/regressor.py +++ b/monai/networks/nets/regressor.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,6 +29,30 @@ class Regressor(nn.Module): This defines a network for relating large-sized input tensors to small output tensors, ie. regressing large values to a prediction. An output of a single dimension can be used as value regression or multi-label classification prediction, an output of a single value can be used as a discriminator or critic prediction. + + The network is constructed as a sequence of layers, either :py:class:`monai.networks.blocks.Convolution` or + :py:class:`monai.networks.blocks.ResidualUnit`, with a final fully-connected layer resizing the output from the + blocks to the final size. Each block is defined with a stride value typically used to downsample the input using + strided convolutions. In this way each block progressively condenses information from the input into a deep + representation the final fully-connected layer relates to a final result. + + Args: + in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) + out_shape: tuple of integers stating the dimension of the final output tensor (minus batch dimension) + channels: tuple of integers stating the output channels of each convolutional layer + strides: tuple of integers stating the stride (downscale factor) of each convolutional layer + kernel_size: integer or tuple of integers stating size of convolutional kernels + num_res_units: integer stating number of convolutions in residual units, 0 means no residual units + act: name or type defining activation layers + norm: name or type defining normalization layers + dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout + bias: boolean stating if convolution layers should have a bias component + + Examples:: + + # infers a 2-value result (eg. a 2D cartesian coordinate) from a 64x64 image + net = Regressor((1, 64, 64), (2,), (2, 4, 8), (2, 2, 2)) + """ def __init__( @@ -44,23 +68,6 @@ def __init__( dropout: Optional[float] = None, bias: bool = True, ) -> None: - """ - Construct the regressor network with the number of layers defined by `channels` and `strides`. Inputs are - first passed through the convolutional layers in the forward pass, the output from this is then pass - through a fully connected layer to relate them to the final output tensor. - - Args: - in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension) - out_shape: tuple of integers stating the dimension of the final output tensor - channels: tuple of integers stating the output channels of each convolutional layer - strides: tuple of integers stating the stride (downscale factor) of each convolutional layer - kernel_size: integer or tuple of integers stating size of convolutional kernels - num_res_units: integer stating number of convolutions in residual units, 0 means no residual units - act: name or type defining activation layers - norm: name or type defining normalization layers - dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout - bias: boolean stating if convolution layers should have a bias component - """ super().__init__() self.in_channels, *self.in_shape = ensure_tuple(in_shape) @@ -107,7 +114,7 @@ def _get_layer( layer = ResidualUnit( subunits=self.num_res_units, last_conv_only=is_last, - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=in_channels, out_channels=out_channels, strides=strides, @@ -120,7 +127,7 @@ def _get_layer( else: layer = Convolution( conv_only=is_last, - dimensions=self.dimensions, + spatial_dims=self.dimensions, in_channels=in_channels, out_channels=out_channels, strides=strides, diff --git a/monai/networks/nets/regunet.py b/monai/networks/nets/regunet.py index 4cf747f650..174431cc3c 100644 --- a/monai/networks/nets/regunet.py +++ b/monai/networks/nets/regunet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -67,7 +67,7 @@ def __init__( concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition encode_kernel_sizes: kernel size for down-sampling """ - super(RegUNet, self).__init__() + super().__init__() if not extract_levels: extract_levels = (depth,) if max(extract_levels) != depth: @@ -106,9 +106,7 @@ def __init__( # build layers self.build_layers() - def build_layers( - self, - ): + def build_layers(self): self.build_encode_layers() self.build_decode_layers() @@ -125,23 +123,13 @@ def build_encode_layers(self): ] ) self.encode_pools = nn.ModuleList( - [ - self.build_down_sampling_block( - channels=self.num_channels[d], - ) - for d in range(self.depth) - ] + [self.build_down_sampling_block(channels=self.num_channels[d]) for d in range(self.depth)] ) self.bottom_block = self.build_bottom_block( in_channels=self.num_channels[-2], out_channels=self.num_channels[-1] ) - def build_conv_block( - self, - in_channels, - out_channels, - kernel_size, - ): + def build_conv_block(self, in_channels, out_channels, kernel_size): return nn.Sequential( get_conv_block( spatial_dims=self.spatial_dims, @@ -157,10 +145,7 @@ def build_conv_block( ), ) - def build_down_sampling_block( - self, - channels: int, - ): + def build_down_sampling_block(self, channels: int): return RegistrationDownSampleBlock(spatial_dims=self.spatial_dims, channels=channels, pooling=self.pooling) def build_bottom_block(self, in_channels: int, out_channels: int): @@ -203,11 +188,7 @@ def build_decode_layers(self): # extraction self.output_block = self.build_output_block() - def build_up_sampling_block( - self, - in_channels: int, - out_channels: int, - ) -> nn.Module: + def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module: return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels) def build_output_block(self) -> nn.Module: @@ -255,14 +236,8 @@ def forward(self, x): class AffineHead(nn.Module): - def __init__( - self, - spatial_dims: int, - image_size: List[int], - decode_size: List[int], - in_channels: int, - ): - super(AffineHead, self).__init__() + def __init__(self, spatial_dims: int, image_size: List[int], decode_size: List[int], in_channels: int): + super().__init__() self.spatial_dims = spatial_dims if spatial_dims == 2: in_features = in_channels * decode_size[0] * decode_size[1] @@ -365,13 +340,8 @@ def build_output_block(self): class AdditiveUpSampleBlock(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - ): - super(AdditiveUpSampleBlock, self).__init__() + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int): + super().__init__() self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -435,17 +405,10 @@ def __init__( def build_bottom_block(self, in_channels: int, out_channels: int): kernel_size = self.encode_kernel_sizes[self.depth] return get_conv_block( - spatial_dims=self.spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, + spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size ) - def build_up_sampling_block( - self, - in_channels: int, - out_channels: int, - ) -> nn.Module: + def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module: if self._use_additive_upsampling: return AdditiveUpSampleBlock( spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index a5e6b7ab81..a263c8e8b3 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,9 +14,10 @@ import torch import torch.nn as nn -import torch.nn.functional as F from monai.networks.layers.factories import Conv, Norm, Pool +from monai.networks.layers.utils import get_pool_layer +from monai.utils.module import look_up_option __all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] @@ -28,14 +29,14 @@ def get_inplanes(): def get_avgpool(): - return [(0), (1), (1, 1), (1, 1, 1)] + return [0, 1, (1, 1), (1, 1, 1)] def get_conv1(conv1_t_size: int, conv1_t_stride: int): return ( - [(0), (conv1_t_size), (conv1_t_size, 7), (conv1_t_size, 7, 7)], - [(0), (conv1_t_stride), (conv1_t_stride, 2), (conv1_t_stride, 2, 2)], - [(0), (conv1_t_size // 2), (conv1_t_size // 2, 3), (conv1_t_size // 2, 3, 3)], + [0, conv1_t_size, (conv1_t_size, 7), (conv1_t_size, 7, 7)], + [0, conv1_t_stride, (conv1_t_stride, 2), (conv1_t_stride, 2, 2)], + [0, (conv1_t_size // 2), (conv1_t_size // 2, 3), (conv1_t_size // 2, 3, 3)], ) @@ -58,7 +59,7 @@ def __init__( stride: stride to use for first conv layer. downsample: which downsample layer to use. """ - super(ResNetBlock, self).__init__() + super().__init__() conv_type: Callable = Conv[Conv.CONV, spatial_dims] norm_type: Callable = Norm[Norm.BATCH, spatial_dims] @@ -110,7 +111,7 @@ def __init__( downsample: which downsample layer to use. """ - super(ResNetBottleneck, self).__init__() + super().__init__() conv_type: Callable = Conv[Conv.CONV, spatial_dims] norm_type: Callable = Norm[Norm.BATCH, spatial_dims] @@ -153,6 +154,7 @@ class ResNet(nn.Module): ResNet based on: `Deep Residual Learning for Image Recognition `_ and `Can Spatiotemporal 3D CNNs Retrace the History of 2D CNNs and ImageNet? `_. Adapted from ``_. + Args: block: which ResNet block to use, either Basic or Bottleneck. layers: how many layers to use. @@ -162,9 +164,16 @@ class ResNet(nn.Module): conv1_t_size: size of first convolution layer, determines kernel and padding. conv1_t_stride: stride of first convolution layer. no_max_pool: bool argument to determine if to use maxpool layer. - shortcut_type: which downsample block to use. + shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'. + - 'A': using `self._downsample_basic_block`. + - 'B': kernel_size 1 conv + norm. widen_factor: widen output for each layer. - num_classes: number of output (classifications) + num_classes: number of output (classifications). + feed_forward: whether to add the FC layer for the output, default to `True`. + + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``num_classes`` instead. + """ @deprecated_arg("n_classes", since="0.6") @@ -185,7 +194,7 @@ def __init__( n_classes: Optional[int] = None, ) -> None: - super(ResNet, self).__init__() + super().__init__() # in case the new num_classes is default but you still call deprecated n_classes if n_classes is not None and num_classes == 400: num_classes = n_classes @@ -198,7 +207,7 @@ def __init__( ] block_avgpool = get_avgpool() - conv1_kernel, conv1_stride, con1_padding = get_conv1(conv1_t_size, conv1_t_stride) + conv1_kernel, conv1_stride, conv1_padding = get_conv1(conv1_t_size, conv1_t_stride) block_inplanes = [int(x * widen_factor) for x in block_inplanes] self.in_planes = block_inplanes[0] @@ -209,7 +218,7 @@ def __init__( self.in_planes, kernel_size=conv1_kernel[spatial_dims], stride=conv1_stride[spatial_dims], - padding=con1_padding[spatial_dims], + padding=conv1_padding[spatial_dims], bias=False, ) self.bn1 = norm_type(self.in_planes) @@ -220,9 +229,7 @@ def __init__( self.layer3 = self._make_layer(block, block_inplanes[2], layers[2], spatial_dims, shortcut_type, stride=2) self.layer4 = self._make_layer(block, block_inplanes[3], layers[3], spatial_dims, shortcut_type, stride=2) self.avgpool = avgp_type(block_avgpool[spatial_dims]) - - if feed_forward: - self.fc = nn.Linear(block_inplanes[3] * block.expansion, num_classes) + self.fc = nn.Linear(block_inplanes[3] * block.expansion, num_classes) if feed_forward else None for m in self.modules(): if isinstance(m, conv_type): @@ -234,14 +241,9 @@ def __init__( nn.init.constant_(torch.as_tensor(m.bias), 0) def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spatial_dims: int = 3) -> torch.Tensor: - assert spatial_dims == 3 - out: torch.Tensor = F.avg_pool3d(x, kernel_size=1, stride=stride) - zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)) - if isinstance(out.data, torch.FloatTensor): - zero_pads = zero_pads.cuda() - + out: torch.Tensor = get_pool_layer(("avg", {"kernel_size": 1, "stride": stride}), spatial_dims=spatial_dims)(x) + zero_pads = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype, device=out.device) out = torch.cat([out.data, zero_pads], dim=1) - return out def _make_layer( @@ -259,9 +261,12 @@ def _make_layer( downsample: Union[nn.Module, partial, None] = None if stride != 1 or self.in_planes != planes * block.expansion: - if shortcut_type == "A": + if look_up_option(shortcut_type, {"A", "B"}) == "A": downsample = partial( - self._downsample_basic_block, planes=planes * block.expansion, kernel_size=1, stride=stride + self._downsample_basic_block, + planes=planes * block.expansion, + stride=stride, + spatial_dims=spatial_dims, ) else: downsample = nn.Sequential( @@ -269,12 +274,12 @@ def _make_layer( norm_type(planes * block.expansion), ) - layers = [] - layers.append( + layers = [ block( in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample ) - ) + ] + self.in_planes = planes * block.expansion for _i in range(1, blocks): layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims)) @@ -296,7 +301,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.avgpool(x) x = x.view(x.size(0), -1) - x = self.fc(x) + if self.fc is not None: + x = self.fc(x) return x diff --git a/monai/networks/nets/segresnet.py b/monai/networks/nets/segresnet.py index 8be562aadd..d2c45dd3a3 100644 --- a/monai/networks/nets/segresnet.py +++ b/monai/networks/nets/segresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -73,7 +73,7 @@ def __init__( super().__init__() if spatial_dims not in (2, 3): - raise AssertionError("spatial_dims can only be 2 or 3.") + raise ValueError("`spatial_dims` can only be 2 or 3.") self.spatial_dims = spatial_dims self.init_filters = init_filters @@ -81,7 +81,8 @@ def __init__( self.blocks_down = blocks_down self.blocks_up = blocks_up self.dropout_prob = dropout_prob - self.act = get_act_layer(act) + self.act = act # input options + self.act_mod = get_act_layer(act) if norm_name: if norm_name.lower() != "group": raise ValueError(f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead.") @@ -99,12 +100,7 @@ def __init__( def _make_down_layers(self): down_layers = nn.ModuleList() - blocks_down, spatial_dims, filters, norm = ( - self.blocks_down, - self.spatial_dims, - self.init_filters, - self.norm, - ) + blocks_down, spatial_dims, filters, norm = (self.blocks_down, self.spatial_dims, self.init_filters, self.norm) for i in range(len(blocks_down)): layer_in_channels = filters * 2 ** i pre_conv = ( @@ -114,7 +110,7 @@ def _make_down_layers(self): ) down_layer = nn.Sequential( pre_conv, - *[ResBlock(spatial_dims, layer_in_channels, norm=norm) for _ in range(blocks_down[i])], + *[ResBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act) for _ in range(blocks_down[i])], ) down_layers.append(down_layer) return down_layers @@ -133,7 +129,10 @@ def _make_up_layers(self): sample_in_channels = filters * 2 ** (n_up - i) up_layers.append( nn.Sequential( - *[ResBlock(spatial_dims, sample_in_channels // 2, norm=norm) for _ in range(blocks_up[i])] + *[ + ResBlock(spatial_dims, sample_in_channels // 2, norm=norm, act=self.act) + for _ in range(blocks_up[i]) + ] ) ) up_samples.append( @@ -149,11 +148,11 @@ def _make_up_layers(self): def _make_final_conv(self, out_channels: int): return nn.Sequential( get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters), - self.act, + self.act_mod, get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True), ) - def forward(self, x): + def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: x = self.convInit(x) if self.dropout_prob is not None: x = self.dropout(x) @@ -164,14 +163,23 @@ def forward(self, x): x = down(x) down_x.append(x) - down_x.reverse() + return x, down_x + def decode(self, x: torch.Tensor, down_x: List[torch.Tensor]) -> torch.Tensor: for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)): x = up(x) + down_x[i + 1] x = upl(x) if self.use_conv_final: x = self.conv_final(x) + + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, down_x = self.encode(x) + down_x.reverse() + + x = self.decode(x, down_x) return x @@ -226,12 +234,13 @@ def __init__( blocks_up: tuple = (1, 1, 1), upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE, ): - super(SegResNetVAE, self).__init__( + super().__init__( spatial_dims=spatial_dims, init_filters=init_filters, in_channels=in_channels, out_channels=out_channels, dropout_prob=dropout_prob, + act=act, norm=norm, use_conv_final=use_conv_final, blocks_down=blocks_down, @@ -258,10 +267,10 @@ def _prepare_vae_modules(self): self.vae_down = nn.Sequential( get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=v_filters), - self.act, + self.act_mod, get_conv_layer(self.spatial_dims, v_filters, self.smallest_filters, stride=2, bias=True), get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.smallest_filters), - self.act, + self.act_mod, ) self.vae_fc1 = nn.Linear(total_elements, self.vae_nz) self.vae_fc2 = nn.Linear(total_elements, self.vae_nz) @@ -271,7 +280,7 @@ def _prepare_vae_modules(self): get_conv_layer(self.spatial_dims, self.smallest_filters, v_filters, kernel_size=1), get_upsample_layer(self.spatial_dims, v_filters, upsample_mode=self.upsample_mode), get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=v_filters), - self.act, + self.act_mod, ) def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): @@ -300,7 +309,7 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): x_vae = z_mean + z_sigma * z_mean_rand x_vae = self.vae_fc3(x_vae) - x_vae = self.act(x_vae) + x_vae = self.act_mod(x_vae) x_vae = x_vae.view([-1, self.smallest_filters] + self.fc_insize) x_vae = self.vae_fc_up_sample(x_vae) @@ -315,25 +324,11 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): def forward(self, x): net_input = x - x = self.convInit(x) - if self.dropout_prob is not None: - x = self.dropout(x) - - down_x = [] - for down in self.down_layers: - x = down(x) - down_x.append(x) - + x, down_x = self.encode(x) down_x.reverse() vae_input = x - - for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)): - x = up(x) + down_x[i + 1] - x = upl(x) - - if self.use_conv_final: - x = self.conv_final(x) + x = self.decode(x, down_x) if self.training: vae_loss = self._get_vae_loss(net_input, vae_input) diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index 9b7035c259..a85d32ba5a 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,12 +17,32 @@ import torch.nn as nn from torch.hub import load_state_dict_from_url +from monai.apps.utils import download_url from monai.networks.blocks.convolutions import Convolution from monai.networks.blocks.squeeze_and_excitation import SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck from monai.networks.layers.factories import Act, Conv, Dropout, Norm, Pool from monai.utils.module import look_up_option -__all__ = ["SENet", "SENet154", "SEResNet50", "SEResNet101", "SEResNet152", "SEResNeXt50", "SEResNext101"] +__all__ = [ + "SENet", + "SENet154", + "SEResNet50", + "SEResNet101", + "SEResNet152", + "SEResNeXt50", + "SEResNext101", + "SE_NET_MODELS", +] + + +SE_NET_MODELS = { + "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", + "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", + "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", + "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", + "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", + "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", +} class SENet(nn.Module): @@ -87,7 +107,7 @@ def __init__( num_classes: int = 1000, ) -> None: - super(SENet, self).__init__() + super().__init__() relu_type: Type[nn.ReLU] = Act[Act.RELU] conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] @@ -192,7 +212,7 @@ def _make_layer( downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = Convolution( - dimensions=self.spatial_dims, + spatial_dims=self.spatial_dims, in_channels=self.inplanes, out_channels=planes * block.expansion, strides=stride, @@ -254,15 +274,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool): """ This function is used to load pretrained models. """ - model_urls = { - "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", - "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", - "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", - "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", - "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", - "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", - } - model_url = look_up_option(arch, model_urls, None) + model_url = look_up_option(arch, SE_NET_MODELS, None) if model_url is None: raise ValueError( "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', " @@ -276,7 +288,11 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool): pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") - state_dict = load_state_dict_from_url(model_url, progress=progress) + if isinstance(model_url, dict): + download_url(model_url["url"], filepath=model_url["filename"]) + state_dict = torch.load(model_url["filename"], map_location=None) + else: + state_dict = load_state_dict_from_url(model_url, progress=progress) for key in list(state_dict.keys()): new_key = None if pattern_conv.match(key): @@ -317,13 +333,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SENet154, self).__init__( - block=SEBottleneck, - layers=layers, - groups=groups, - reduction=reduction, - **kwargs, - ) + super().__init__(block=SEBottleneck, layers=layers, groups=groups, reduction=reduction, **kwargs) if pretrained: # it only worked when `spatial_dims` is 2 _load_state_dict(self, "senet154", progress) @@ -345,7 +355,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SEResNet50, self).__init__( + super().__init__( block=SEResNetBottleneck, layers=layers, groups=groups, @@ -378,7 +388,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SEResNet101, self).__init__( + super().__init__( block=SEResNetBottleneck, layers=layers, groups=groups, @@ -410,7 +420,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SEResNet152, self).__init__( + super().__init__( block=SEResNetBottleneck, layers=layers, groups=groups, @@ -443,7 +453,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SEResNext50, self).__init__( + super().__init__( block=SEResNeXtBottleneck, layers=layers, groups=groups, @@ -477,7 +487,7 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super(SEResNext101, self).__init__( + super().__init__( block=SEResNeXtBottleneck, layers=layers, groups=groups, @@ -493,7 +503,7 @@ def __init__( _load_state_dict(self, "se_resnext101_32x4d", progress) -SEnet = Senet = senet = SENet +SEnet = Senet = SENet SEnet154 = Senet154 = senet154 = SENet154 SEresnet50 = Seresnet50 = seresnet50 = SEResNet50 SEresnet101 = Seresnet101 = seresnet101 = SEResNet101 diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index 1619f877e7..e93019d050 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,11 +26,12 @@ class TorchVisionFCModel(NetAdapter): Args: model_name: name of any torchvision model with fully connected layer at the end. - ``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, + ``resnet18`` (default), ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``. model details: https://pytorch.org/vision/stable/models.html. num_classes: number of classes for the last classification layer. Default to 1. - dim: number of spatial dimensions, default to 2. + dim: number of supported spatial dimensions in the specified model, depends on the model implementation. + default to 2 as most Torchvision models are for 2D image processing. in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer. use_conv: whether use convolutional layer to replace the last layer, default to False. pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer, @@ -73,14 +74,14 @@ def __init__( ) -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="Please consider using `TorchVisionFCModel` instead.") +@deprecated(since="0.6.0", removed="0.9.0", msg_suffix="Please consider using `TorchVisionFCModel` instead.") class TorchVisionFullyConvModel(TorchVisionFCModel): """ Customize TorchVision models to replace fully connected layer by convolutional layer. Args: model_name: name of any torchvision with adaptive avg pooling and fully connected layer at the end. - ``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, + ``resnet18`` (default), ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``. num_classes: number of classes for the last classification layer. Default to 1. pool_size: the kernel size for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to (7, 7). diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py new file mode 100644 index 0000000000..7b3afe3e5e --- /dev/null +++ b/monai/networks/nets/transchex.py @@ -0,0 +1,378 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import shutil +import tarfile +import tempfile +from typing import Sequence, Tuple, Union + +import torch +from torch import nn + +from monai.utils import optional_import + +transformers = optional_import("transformers") +load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert") +cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] +BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] + +__all__ = ["BertPreTrainedModel", "BertAttention", "BertOutput", "BertMixedLayer", "Pooler", "MultiModal", "Transchex"] + + +class BertPreTrainedModel(nn.Module): + """Module to load BERT pre-trained weights. + Based on: + LXMERT + https://github.com/airsplay/lxmert + BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, *inputs, **kwargs) -> None: + super().__init__() + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, torch.nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained( + cls, + num_language_layers, + num_vision_layers, + num_mixed_layers, + bert_config, + state_dict=None, + cache_dir=None, + from_tf=False, + *inputs, + **kwargs, + ): + archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + tempdir = tempfile.mkdtemp() + with tarfile.open(resolved_archive_file, "r:gz") as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, "pytorch_model.bin") + state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) + if tempdir: + shutil.rmtree(tempdir) + if from_tf: + weights_path = os.path.join(serialization_dir, "model.ckpt") + return load_tf_weights_in_bert(model, weights_path) + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + start_prefix = "" + if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()): + start_prefix = "bert." + load(model, prefix=start_prefix) + return model + + +class BertAttention(nn.Module): + """BERT attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, context): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(context) + mixed_value_layer = self.value(context) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.dropout(nn.Softmax(dim=-1)(attention_scores)) + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertOutput(nn.Module): + """BERT output layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertMixedLayer(nn.Module): + """BERT cross attention layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, config) -> None: + super().__init__() + self.att = BertAttention(config) + self.output = BertOutput(config) + + def forward(self, x, y): + output = self.att(x, y) + return self.output(output, x) + + +class Pooler(nn.Module): + """BERT pooler layer. + Based on: BERT (pytorch-transformer) + https://github.com/huggingface/transformers + """ + + def __init__(self, hidden_size) -> None: + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class MultiModal(BertPreTrainedModel): + """ + Multimodal Transformers From Pretrained BERT Weights" + """ + + def __init__( + self, num_language_layers: int, num_vision_layers: int, num_mixed_layers: int, bert_config: dict # type: ignore + ) -> None: + """ + Args: + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + bert_config: configuration for bert language transformer encoder. + + """ + super().__init__() + self.config = type("obj", (object,), bert_config) + self.embeddings = BertEmbeddings(self.config) + self.language_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_language_layers)]) + self.vision_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_vision_layers)]) + self.mixed_encoder = nn.ModuleList([BertMixedLayer(self.config) for _ in range(num_mixed_layers)]) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None): + language_features = self.embeddings(input_ids, token_type_ids) + for layer in self.vision_encoder: + hidden_state_vision = layer(vision_feats, None)[0] + for layer in self.language_encoder: + hidden_state_language = layer(language_features, attention_mask)[0] + for layer in self.mixed_encoder: + hidden_state_mixed = layer(hidden_state_language, hidden_state_vision) + return hidden_state_mixed + + +class Transchex(torch.nn.Module): + """ + TransChex based on: "Hatamizadeh et al.,TransCheX: Self-Supervised Pretraining of Vision-Language + Transformers for Chest X-ray Analysis" + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], # type: ignore + patch_size: Union[int, Tuple[int, int]], # type: ignore + num_classes: int, + num_language_layers: int, + num_vision_layers: int, + num_mixed_layers: int, + hidden_size: int = 768, + drop_out: float = 0.0, + attention_probs_dropout_prob: float = 0.1, + gradient_checkpointing: bool = False, + hidden_act: str = "gelu", + hidden_dropout_prob: float = 0.1, + initializer_range: float = 0.02, + intermediate_size: int = 3072, + layer_norm_eps: float = 1e-12, + max_position_embeddings: int = 512, + model_type: str = "bert", + num_attention_heads: int = 12, + num_hidden_layers: int = 12, + pad_token_id: int = 0, + position_embedding_type: str = "absolute", + transformers_version: str = "4.10.2", + type_vocab_size: int = 2, + use_cache: bool = True, + vocab_size: int = 30522, + chunk_size_feed_forward: int = 0, + is_decoder: bool = False, + add_cross_attention: bool = False, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + img_size: dimension of input image. + patch_size: dimension of patch size. + num_classes: number of classes if classification is used. + num_language_layers: number of language transformer layers. + num_vision_layers: number of vision transformer layers. + num_mixed_layers: number of mixed transformer layers. + drop_out: faction of the input units to drop. + + The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`. + + Examples: + + .. code-block:: python + + # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers, + # 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head + net = Transchex(in_channels=3, + img_size=(224, 224), + num_classes=3, + num_language_layers=2, + num_vision_layers=2, + num_mixed_layers=2, + drop_out=0.2) + + """ + super().__init__() + bert_config = { + "attention_probs_dropout_prob": attention_probs_dropout_prob, + "classifier_dropout": None, + "gradient_checkpointing": gradient_checkpointing, + "hidden_act": hidden_act, + "hidden_dropout_prob": hidden_dropout_prob, + "hidden_size": hidden_size, + "initializer_range": initializer_range, + "intermediate_size": intermediate_size, + "layer_norm_eps": layer_norm_eps, + "max_position_embeddings": max_position_embeddings, + "model_type": model_type, + "num_attention_heads": num_attention_heads, + "num_hidden_layers": num_hidden_layers, + "pad_token_id": pad_token_id, + "position_embedding_type": position_embedding_type, + "transformers_version": transformers_version, + "type_vocab_size": type_vocab_size, + "use_cache": use_cache, + "vocab_size": vocab_size, + "chunk_size_feed_forward": chunk_size_feed_forward, + "is_decoder": is_decoder, + "add_cross_attention": add_cross_attention, + } + if not (0 <= drop_out <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0): # type: ignore + raise ValueError("img_size should be divisible by patch_size.") + + self.multimodal = MultiModal.from_pretrained( + num_language_layers=num_language_layers, + num_vision_layers=num_vision_layers, + num_mixed_layers=num_mixed_layers, + bert_config=bert_config, + ) + + self.patch_size = patch_size + self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore + self.vision_proj = nn.Conv2d( + in_channels=in_channels, + out_channels=hidden_size, + kernel_size=self.patch_size, # type: ignore + stride=self.patch_size, # type: ignore + ) + self.norm_vision_pos = nn.LayerNorm(hidden_size) + self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size)) + self.pooler = Pooler(hidden_size=hidden_size) + self.drop = torch.nn.Dropout(drop_out) + self.cls_head = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, input_ids, token_type_ids=None, vision_feats=None): + attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) + attention_mask = (1.0 - attention_mask) * -10000.0 + vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2) + vision_feats = self.norm_vision_pos(vision_feats) + vision_feats = vision_feats + self.pos_embed_vis + hidden_state_mixed = self.multimodal( + input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask + ) + pooled_features = self.pooler(hidden_state_mixed) + logits = self.cls_head(self.drop(pooled_features)) + return logits diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 70cc816fe9..21259936e7 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Sequence, Tuple, Union +from typing import Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -18,17 +18,99 @@ from monai.networks.blocks.convolutions import Convolution, ResidualUnit from monai.networks.layers.factories import Act, Norm from monai.networks.layers.simplelayers import SkipConnection -from monai.utils import alias, export +from monai.utils import alias, deprecated_arg, export -__all__ = ["UNet", "Unet", "unet"] +__all__ = ["UNet", "Unet"] @export("monai.networks.nets") @alias("Unet") class UNet(nn.Module): + """ + Enhanced version of UNet which has residual units implemented with the ResidualUnit class. + The residual part uses a convolution to change the input dimensions to match the output dimensions + if this is necessary but will use nn.Identity if not. + Refer to: https://link.springer.com/chapter/10.1007/978-3-030-12029-0_40. + + Each layer of the network has a encode and decode path with a skip connection between them. Data in the encode path + is downsampled using strided convolutions (if `strides` is given values greater than 1) and in the decode path + upsampled using strided transpose convolutions. These down or up sampling operations occur at the beginning of each + block rather than afterwards as is typical in UNet implementations. + + To further explain this consider the first example network given below. This network has 3 layers with strides + of 2 for each of the middle layers (the last layer is the bottom connection which does not down/up sample). Input + data to this network is immediately reduced in the spatial dimensions by a factor of 2 by the first convolution of + the residual unit defining the first layer of the encode part. The last layer of the decode part will upsample its + input (data from the previous layer concatenated with data from the skip connection) in the first convolution. this + ensures the final output of the network has the same shape as the input. + + Padding values for the convolutions are chosen to ensure output sizes are even divisors/multiples of the input + sizes if the `strides` value for a layer is a factor of the input sizes. A typical case is to use `strides` values + of 2 and inputs that are multiples of powers of 2. An input can thus be downsampled evenly however many times its + dimensions can be divided by 2, so for the example network inputs would have to have dimensions that are multiples + of 4. In the second example network given below the input to the bottom layer will have shape (1, 64, 15, 15) for + an input of shape (1, 1, 240, 240) demonstrating the input being reduced in size spatially by 2**4. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + channels: sequence of channels. Top block first. The length of `channels` should be no less than 2. + strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`. + kernel_size: convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + num_res_units: number of residual units. Defaults to 0. + act: activation type and arguments. Defaults to PReLU. + norm: feature normalization type and arguments. Defaults to instance norm. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term in convolution blocks. Defaults to True. + According to `Performance Tuning Guide `_, + if a conv layer is directly followed by a batch norm layer, bias should be False. + + Examples:: + + from monai.networks.nets import UNet + + # 3 layer network with down/upsampling by a factor of 2 at each layer with 2-convolution residual units + net = UNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4, 8, 16), + strides=(2, 2), + num_res_units=2 + ) + + # 5 layer network with simple convolution/normalization/dropout/activation blocks defining the layers + net=UNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4, 8, 16, 32, 64), + strides=(2, 2, 2, 2), + ) + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + + Note: The acceptable spatial size of input data depends on the parameters of the network, + to set appropriate spatial size, please check the tutorial for more details: + https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb. + Typically, when using a stride of 2 in down / up sampling, the output dimensions are either half of the + input when downsampling, or twice when upsampling. In this case with N numbers of layers in the network, + the inputs must have spatial dimensions that are all multiples of 2^N. + Usually, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data. + + """ + + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int], @@ -40,40 +122,9 @@ def __init__( norm: Union[Tuple, str] = Norm.INSTANCE, dropout: float = 0.0, bias: bool = True, + dimensions: Optional[int] = None, ) -> None: - """ - Enhanced version of UNet which has residual units implemented with the ResidualUnit class. - The residual part uses a convolution to change the input dimensions to match the output dimensions - if this is necessary but will use nn.Identity if not. - Refer to: https://link.springer.com/chapter/10.1007/978-3-030-12029-0_40. - - Args: - dimensions: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - channels: sequence of channels. Top block first. The length of `channels` should be no less than 2. - strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`. - kernel_size: convolution kernel size, the value(s) should be odd. If sequence, - its length should equal to dimensions. Defaults to 3. - up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence, - its length should equal to dimensions. Defaults to 3. - num_res_units: number of residual units. Defaults to 0. - act: activation type and arguments. Defaults to PReLU. - norm: feature normalization type and arguments. Defaults to instance norm. - dropout: dropout ratio. Defaults to no dropout. - bias: whether to have a bias term in convolution blocks. Defaults to True. - According to `Performance Tuning Guide `_, - if a conv layer is directly followed by a batch norm layer, bias should be False. - - Note: The acceptable spatial size of input data depends on the parameters of the network, - to set appropriate spatial size, please check the tutorial for more details: - https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb. - Typically, when using a stride of 2 in down / up sampling, the output dimensions are either half of the - input when downsampling, or twice when upsampling. In this case with N numbers of layers in the network, - the inputs must have spatial dimensions that are all multiples of 2^N. - Usually, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data. - """ super().__init__() if len(channels) < 2: @@ -83,14 +134,16 @@ def __init__( raise ValueError("the length of `strides` should equal to `len(channels) - 1`.") if delta > 0: warnings.warn(f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.") + if dimensions is not None: + spatial_dims = dimensions if isinstance(kernel_size, Sequence): - if len(kernel_size) != dimensions: + if len(kernel_size) != spatial_dims: raise ValueError("the length of `kernel_size` should equal to `dimensions`.") if isinstance(up_kernel_size, Sequence): - if len(up_kernel_size) != dimensions: + if len(up_kernel_size) != spatial_dims: raise ValueError("the length of `up_kernel_size` should equal to `dimensions`.") - self.dimensions = dimensions + self.dimensions = spatial_dims self.in_channels = in_channels self.out_channels = out_channels self.channels = channels @@ -145,8 +198,10 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_ strides: convolution stride. is_top: True if this is the top block. """ + mod: nn.Module if self.num_res_units > 0: - return ResidualUnit( + + mod = ResidualUnit( self.dimensions, in_channels, out_channels, @@ -158,7 +213,8 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_ dropout=self.dropout, bias=self.bias, ) - return Convolution( + return mod + mod = Convolution( self.dimensions, in_channels, out_channels, @@ -169,6 +225,7 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_ dropout=self.dropout, bias=self.bias, ) + return mod def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: """ @@ -225,4 +282,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -Unet = unet = UNet +Unet = UNet diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py index 9990cb6643..b22f0584a2 100644 --- a/monai/networks/nets/unetr.py +++ b/monai/networks/nets/unetr.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -70,7 +70,7 @@ def __init__( """ - super(UNETR, self).__init__() + super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py index 7f54890992..7386883124 100644 --- a/monai/networks/nets/varautoencoder.py +++ b/monai/networks/nets/varautoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,16 +19,65 @@ from monai.networks.layers.convutils import calculate_out_shape, same_padding from monai.networks.layers.factories import Act, Norm from monai.networks.nets import AutoEncoder +from monai.utils import deprecated_arg __all__ = ["VarAutoEncoder"] class VarAutoEncoder(AutoEncoder): - """Variational Autoencoder based on the paper - https://arxiv.org/abs/1312.6114""" + """ + Variational Autoencoder based on the paper - https://arxiv.org/abs/1312.6114 + + Args: + spatial_dims: number of spatial dimensions. + in_shape: shape of input data starting with channel dimension. + out_channels: number of output channels. + latent_size: size of the latent variable. + channels: sequence of channels. Top block first. The length of `channels` should be no less than 2. + strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`. + kernel_size: convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + num_res_units: number of residual units. Defaults to 0. + inter_channels: sequence of channels defining the blocks in the intermediate layer between encode and decode. + inter_dilations: defines the dilation value for each block of the intermediate layer. Defaults to 1. + num_inter_units: number of residual units for each block of the intermediate layer. Defaults to 0. + act: activation type and arguments. Defaults to PReLU. + norm: feature normalization type and arguments. Defaults to instance norm. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term in convolution blocks. Defaults to True. + According to `Performance Tuning Guide `_, + if a conv layer is directly followed by a batch norm layer, bias should be False. + + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + + Examples:: + + from monai.networks.nets import VarAutoEncoder + + # 3 layer network accepting images with dimensions (1, 32, 32) and using a latent vector with 2 values + model = VarAutoEncoder( + dimensions=2, + in_shape=(32, 32), # image spatial shape + out_channels=1, + latent_size=2, + channels=(16, 32, 64), + strides=(1, 2, 2), + ) + + see also: + - Variational autoencoder network with MedNIST Dataset + https://github.com/Project-MONAI/tutorials/blob/master/modules/varautoencoder_mednist.ipynb + """ + @deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) def __init__( self, - dimensions: int, + spatial_dims: int, in_shape: Sequence[int], out_channels: int, latent_size: int, @@ -44,15 +93,18 @@ def __init__( norm: Union[Tuple, str] = Norm.INSTANCE, dropout: Optional[Union[Tuple, str, float]] = None, bias: bool = True, + dimensions: Optional[int] = None, ) -> None: self.in_channels, *self.in_shape = in_shape self.latent_size = latent_size self.final_size = np.asarray(self.in_shape, dtype=int) + if dimensions is not None: + spatial_dims = dimensions super().__init__( - dimensions, + spatial_dims, self.in_channels, out_channels, channels, diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 3a5d94cc37..62e92603ab 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,6 +18,8 @@ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock +__all__ = ["ViT"] + class ViT(nn.Module): """ @@ -68,7 +70,7 @@ def __init__( """ - super(ViT, self).__init__() + super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py new file mode 100644 index 0000000000..3ec89488ff --- /dev/null +++ b/monai/networks/nets/vitautoenc.py @@ -0,0 +1,119 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Sequence, Union + +import torch +import torch.nn as nn + +from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.networks.blocks.transformerblock import TransformerBlock +from monai.networks.layers import Conv + +__all__ = ["ViTAutoEnc"] + + +class ViTAutoEnc(nn.Module): + """ + Vision Transformer (ViT), based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Modified to also give same dimension outputs as the input size of the image + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], + patch_size: Union[Sequence[int], int], + out_channels: int = 1, + deconv_chns: int = 16, + hidden_size: int = 768, + mlp_dim: int = 3072, + num_layers: int = 12, + num_heads: int = 12, + pos_embed: str = "conv", + dropout_rate: float = 0.0, + spatial_dims: int = 3, + ) -> None: + """ + Args: + in_channels: dimension of input channels or the number of channels for input + img_size: dimension of input image. + patch_size: dimension of patch size. + hidden_size: dimension of hidden layer. + out_channels: number of output channels. + deconv_chns: number of channels for the deconvolution layers. + mlp_dim: dimension of feedforward layer. + num_layers: number of transformer blocks. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + dropout_rate: faction of the input units to drop. + spatial_dims: number of spatial dimensions. + + Examples:: + + # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone + # It will provide an output of same size as that of the input + >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv') + + # for 3-channel with image size of (128,128,128), output will be same size as of input + >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv') + + """ + + super().__init__() + + self.spatial_dims = spatial_dims + + self.patch_embedding = PatchEmbeddingBlock( + in_channels=in_channels, + img_size=img_size, + patch_size=patch_size, + hidden_size=hidden_size, + num_heads=num_heads, + pos_embed=pos_embed, + dropout_rate=dropout_rate, + spatial_dims=self.spatial_dims, + ) + self.blocks = nn.ModuleList( + [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] + ) + self.norm = nn.LayerNorm(hidden_size) + + new_patch_size = [4] * self.spatial_dims + conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims] + # self.conv3d_transpose* is to be compatible with existing 3d model weights. + self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size) + self.conv3d_transpose_1 = conv_trans( + in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size + ) + + def forward(self, x): + """ + Args: + x: input tensor must have isotropic spatial dimensions, + such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``. + """ + x = self.patch_embedding(x) + hidden_states_out = [] + for blk in self.blocks: + x = blk(x) + hidden_states_out.append(x) + x = self.norm(x) + x = x.transpose(1, 2) + d = [round(math.pow(x.shape[2], 1 / self.spatial_dims))] * self.spatial_dims + x = torch.reshape(x, [x.shape[0], x.shape[1], *d]) + x = self.conv3d_transpose(x) + x = self.conv3d_transpose_1(x) + return x, hidden_states_out diff --git a/monai/networks/nets/vnet.py b/monai/networks/nets/vnet.py index 72f3290a89..ede5f688aa 100644 --- a/monai/networks/nets/vnet.py +++ b/monai/networks/nets/vnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -30,11 +30,11 @@ def get_acti_layer(act: Union[Tuple[str, Dict], str], nchan: int = 0): class LUConv(nn.Module): def __init__(self, spatial_dims: int, nchan: int, act: Union[Tuple[str, Dict], str], bias: bool = False): - super(LUConv, self).__init__() + super().__init__() self.act_function = get_acti_layer(act, nchan) self.conv_block = Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=nchan, out_channels=nchan, kernel_size=5, @@ -65,7 +65,7 @@ def __init__( act: Union[Tuple[str, Dict], str], bias: bool = False, ): - super(InputTransition, self).__init__() + super().__init__() if 16 % in_channels != 0: raise ValueError(f"16 should be divisible by in_channels, got in_channels={in_channels}.") @@ -74,7 +74,7 @@ def __init__( self.in_channels = in_channels self.act_function = get_acti_layer(act, 16) self.conv_block = Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=16, kernel_size=5, @@ -102,7 +102,7 @@ def __init__( dropout_dim: int = 3, bias: bool = False, ): - super(DownTransition, self).__init__() + super().__init__() conv_type: Type[Union[nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -138,7 +138,7 @@ def __init__( dropout_prob: Optional[float] = None, dropout_dim: int = 3, ): - super(UpTransition, self).__init__() + super().__init__() conv_trans_type: Type[Union[nn.ConvTranspose2d, nn.ConvTranspose3d]] = Conv[Conv.CONVTRANS, spatial_dims] norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] @@ -174,13 +174,13 @@ def __init__( act: Union[Tuple[str, Dict], str], bias: bool = False, ): - super(OutputTransition, self).__init__() + super().__init__() conv_type: Type[Union[nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] self.act_function1 = get_acti_layer(act, out_channels) self.conv_block = Convolution( - dimensions=spatial_dims, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=5, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 9d20d2a83b..f7fd2e2956 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,11 +15,15 @@ import warnings from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Callable, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union import torch import torch.nn as nn +from monai.utils.deprecate_utils import deprecated_arg +from monai.utils.misc import ensure_tuple, set_determinism +from monai.utils.module import pytorch_after + __all__ = [ "one_hot", "slice_channels", @@ -32,6 +36,7 @@ "eval_mode", "train_mode", "copy_model_state", + "convert_to_torchscript", ] @@ -225,9 +230,14 @@ def icnr_init(conv, upsample_factor, init=nn.init.kaiming_normal_): conv.weight.data.copy_(kernel) -def pixelshuffle(x: torch.Tensor, dimensions: int, scale_factor: int) -> torch.Tensor: +@deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." +) +def pixelshuffle( + x: torch.Tensor, spatial_dims: int, scale_factor: int, dimensions: Optional[int] = None +) -> torch.Tensor: """ - Apply pixel shuffle to the tensor `x` with spatial dimensions `dimensions` and scaling factor `scale_factor`. + Apply pixel shuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`. See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution Using a nEfficient Sub-Pixel Convolutional Neural Network." @@ -236,17 +246,21 @@ def pixelshuffle(x: torch.Tensor, dimensions: int, scale_factor: int) -> torch.T Args: x: Input tensor - dimensions: number of spatial dimensions, typically 2 or 3 for 2D or 3D + spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D scale_factor: factor to rescale the spatial dimensions by, must be >=1 + .. deprecated:: 0.6.0 + ``dimensions`` is deprecated, use ``spatial_dims`` instead. + Returns: Reshuffled version of `x`. Raises: - ValueError: When input channels of `x` are not divisible by (scale_factor ** dimensions) + ValueError: When input channels of `x` are not divisible by (scale_factor ** spatial_dims) """ - - dim, factor = dimensions, scale_factor + if dimensions is not None: + spatial_dims = dimensions + dim, factor = spatial_dims, scale_factor input_size = list(x.size()) batch_size, channels = input_size[:2] scale_divisor = factor ** dim @@ -413,3 +427,70 @@ def copy_model_state( if inplace and isinstance(dst, torch.nn.Module): dst.load_state_dict(dst_dict) return dst_dict, updated_keys, unchanged_keys + + +def convert_to_torchscript( + model: nn.Module, + filename_or_obj: Optional[Any] = None, + extra_files: Optional[Dict] = None, + verify: bool = False, + inputs: Optional[Sequence[Any]] = None, + device: Optional[torch.device] = None, + rtol: float = 1e-4, + atol: float = 0.0, + **kwargs, +): + """ + Utility to convert a model into TorchScript model and save to file, + with optional input / output data verification. + + Args: + model: source PyTorch model to save. + filename_or_obj: if not None, specify a file-like object (has to implement write and flush) + or a string containing a file path name to save the TorchScript model. + extra_files: map from filename to contents which will be stored as part of the save model file. + works for PyTorch 1.7 or later. + for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html. + verify: whether to verify the input and output of TorchScript model. + if `filename_or_obj` is not None, load the saved TorchScript model and verify. + inputs: input test data to verify model, should be a sequence of data, every item maps to a argument + of `model()` function. + device: target device to verify the model, if None, use CUDA if available. + rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. + atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. + kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: + https://pytorch.org/docs/master/generated/torch.jit.script.html. + + """ + model.eval() + with torch.no_grad(): + script_module = torch.jit.script(model, **kwargs) + if filename_or_obj is not None: + if not pytorch_after(1, 7): + torch.jit.save(m=script_module, f=filename_or_obj) + else: + torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files) + + if verify: + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if inputs is None: + raise ValueError("missing input data for verification.") + + inputs = [i.to(device) if isinstance(i, torch.Tensor) else i for i in inputs] + ts_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else script_module + ts_model.eval().to(device) + model = model.to(device) + + with torch.no_grad(): + set_determinism(seed=0) + torch_out = ensure_tuple(model(*inputs)) + set_determinism(seed=0) + torchscript_out = ensure_tuple(ts_model(*inputs)) + set_determinism(seed=None) + # compare TorchScript and PyTorch results + for r1, r2 in zip(torch_out, torchscript_out): + if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor): + torch.testing.assert_allclose(r1, r2, rtol=rtol, atol=atol) + + return script_module diff --git a/monai/optimizers/__init__.py b/monai/optimizers/__init__.py index e53aa8d468..8ce5d3f925 100644 --- a/monai/optimizers/__init__.py +++ b/monai/optimizers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,5 +10,6 @@ # limitations under the License. from .lr_finder import LearningRateFinder +from .lr_scheduler import ExponentialLR, LinearLR, WarmupCosineSchedule from .novograd import Novograd from .utils import generate_param_groups diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 49d4427b3d..ce092d33ab 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pickle import warnings from functools import partial from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union @@ -17,6 +18,7 @@ import torch import torch.nn as nn from torch.optim import Optimizer +from torch.serialization import DEFAULT_PROTOCOL from torch.utils.data import DataLoader from monai.networks.utils import eval_mode @@ -120,7 +122,7 @@ def __iter__(self): def __next__(self): self.run_counter += 1 - return super(ValDataLoaderIter, self).__next__() + return super().__next__() def default_image_extractor(x: Any) -> torch.Tensor: @@ -144,30 +146,30 @@ class LearningRateFinder: and what is the optimal learning rate. Example (fastai approach): - >>> lr_finder = LearningRateFinder(net, optimizer, criterion) - >>> lr_finder.range_test(data_loader, end_lr=100, num_iter=100) - >>> lr_finder.get_steepest_gradient() - >>> lr_finder.plot() # to inspect the loss-learning rate graph + >>> lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> lr_finder.range_test(data_loader, end_lr=100, num_iter=100) + >>> lr_finder.get_steepest_gradient() + >>> lr_finder.plot() # to inspect the loss-learning rate graph Example (Leslie Smith's approach): - >>> lr_finder = LearningRateFinder(net, optimizer, criterion) - >>> lr_finder.range_test(train_loader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear") + >>> lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> lr_finder.range_test(train_loader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear") Gradient accumulation is supported; example: - >>> train_data = ... # prepared dataset - >>> desired_bs, real_bs = 32, 4 # batch size - >>> accumulation_steps = desired_bs // real_bs # required steps for accumulation - >>> data_loader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True) - >>> acc_lr_finder = LearningRateFinder(net, optimizer, criterion) - >>> acc_lr_finder.range_test(data_loader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps) + >>> train_data = ... # prepared dataset + >>> desired_bs, real_bs = 32, 4 # batch size + >>> accumulation_steps = desired_bs // real_bs # required steps for accumulation + >>> data_loader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True) + >>> acc_lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> acc_lr_finder.range_test(data_loader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps) By default, image will be extracted from data loader with x["image"] and x[0], depending on whether batch data is a dictionary or not (and similar behaviour for extracting the label). If your data loader returns something other than this, pass a callable function to extract it, e.g.: - >>> image_extractor = lambda x: x["input"] - >>> label_extractor = lambda x: x[100] - >>> lr_finder = LearningRateFinder(net, optimizer, criterion) - >>> lr_finder.range_test(train_loader, val_loader, image_extractor, label_extractor) + >>> image_extractor = lambda x: x["input"] + >>> label_extractor = lambda x: x[100] + >>> lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> lr_finder.range_test(train_loader, val_loader, image_extractor, label_extractor) References: Modified from: https://github.com/davidtvs/pytorch-lr-finder. @@ -183,6 +185,8 @@ def __init__( memory_cache: bool = True, cache_dir: Optional[str] = None, amp: bool = False, + pickle_module=pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, verbose: bool = True, ) -> None: """Constructor. @@ -202,6 +206,12 @@ def __init__( specified, system-wide temporary directory is used. Notice that this parameter will be ignored if `memory_cache` is True. amp: use Automatic Mixed Precision + pickle_module: module used for pickling metadata and objects, default to `pickle`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. + pickle_protocol: can be specified to override the default protocol, default to `2`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. verbose: verbose output Returns: None @@ -221,7 +231,9 @@ def __init__( # Save the original state of the model and optimizer so they can be restored if # needed self.model_device = next(self.model.parameters()).device - self.state_cacher = StateCacher(memory_cache, cache_dir=cache_dir) + self.state_cacher = StateCacher( + in_memory=memory_cache, cache_dir=cache_dir, pickle_module=pickle_module, pickle_protocol=pickle_protocol + ) self.state_cacher.store("model", self.model.state_dict()) self.state_cacher.store("optimizer", self.optimizer.state_dict()) @@ -328,11 +340,7 @@ def range_test( print(f"Computing optimal learning rate, iteration {iteration + 1}/{num_iter}") # Train on batch and retrieve loss - loss = self._train_batch( - train_iter, - accumulation_steps, - non_blocking_transfer=non_blocking_transfer, - ) + loss = self._train_batch(train_iter, accumulation_steps, non_blocking_transfer=non_blocking_transfer) if val_loader: loss = self._validate(val_iter, non_blocking_transfer=non_blocking_transfer) @@ -429,11 +437,7 @@ def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = T return running_loss / len(val_iter.dataset) - def get_lrs_and_losses( - self, - skip_start: int = 0, - skip_end: int = 0, - ) -> Tuple[list, list]: + def get_lrs_and_losses(self, skip_start: int = 0, skip_end: int = 0) -> Tuple[list, list]: """Get learning rates and their corresponding losses Args: @@ -454,9 +458,7 @@ def get_lrs_and_losses( return lrs, losses def get_steepest_gradient( - self, - skip_start: int = 0, - skip_end: int = 0, + self, skip_start: int = 0, skip_end: int = 0 ) -> Union[Tuple[float, float], Tuple[None, None]]: """Get learning rate which has steepest gradient and its corresponding loss @@ -476,14 +478,7 @@ def get_steepest_gradient( print("Failed to compute the gradients, there might not be enough points.") return None, None - def plot( - self, - skip_start: int = 0, - skip_end: int = 0, - log_lr: bool = True, - ax=None, - steepest_lr: bool = True, - ): + def plot(self, skip_start: int = 0, skip_end: int = 0, log_lr: bool = True, ax=None, steepest_lr: bool = True): """Plots the learning rate range test. Args: diff --git a/monai/optimizers/lr_scheduler.py b/monai/optimizers/lr_scheduler.py index 9416b583f7..83412c61ea 100644 --- a/monai/optimizers/lr_scheduler.py +++ b/monai/optimizers/lr_scheduler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,7 +33,7 @@ def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoc """ self.end_lr = end_lr self.num_iter = num_iter - super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) + super().__init__(optimizer, last_epoch) class LinearLR(_LRSchedulerMONAI): @@ -77,7 +77,7 @@ def __init__( self.warmup_steps = warmup_steps self.t_total = t_total self.cycles = cycles - super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch) + super().__init__(optimizer, self.lr_lambda, last_epoch) def lr_lambda(self, step): if step < self.warmup_steps: diff --git a/monai/optimizers/novograd.py b/monai/optimizers/novograd.py index 62e42cc9ab..07a6aff90a 100644 --- a/monai/optimizers/novograd.py +++ b/monai/optimizers/novograd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,7 +20,7 @@ class Novograd(Optimizer): Novograd based on `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks `_. The code is adapted from the implementations in `Jasper for PyTorch - `_, + `_, and `OpenSeq2Seq `_. Args: @@ -45,28 +45,23 @@ def __init__( amsgrad: bool = False, ): if 0.0 > lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if 0.0 > eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") if 0.0 > weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - grad_averaging=grad_averaging, - amsgrad=amsgrad, + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, amsgrad=amsgrad ) - super(Novograd, self).__init__(params, defaults) + super().__init__(params, defaults) def __setstate__(self, state): - super(Novograd, self).__setstate__(state) + super().__setstate__(state) for group in self.param_groups: group.setdefault("amsgrad", False) diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py index c52ab07a04..1a040927d8 100644 --- a/monai/optimizers/utils.py +++ b/monai/optimizers/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -47,7 +47,7 @@ def generate_param_groups( .. code-block:: python - net = Unet(dimensions=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1]) + net = Unet(spatial_dims=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1]) print(net) # print out network components to select expected items print(net.named_parameters()) # print out all the named parameters to filter out expected items params = generate_param_groups( diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2ea7e3aa63..c9cd1b4e0d 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,6 +18,7 @@ CenterSpatialCrop, CropForeground, DivisiblePad, + Pad, RandCropByLabelClasses, RandCropByPosNegLabel, RandScaleCrop, @@ -48,7 +49,7 @@ DivisiblePadd, DivisiblePadD, DivisiblePadDict, - NumpyPadModeSequence, + PadModeSequence, RandCropByLabelClassesd, RandCropByLabelClassesD, RandCropByLabelClassesDict, @@ -85,12 +86,13 @@ GibbsNoise, HistogramNormalize, KSpaceSpikeNoise, - LocalPatchShuffling, MaskIntensity, NormalizeIntensity, RandAdjustContrast, RandBiasField, RandCoarseDropout, + RandCoarseShuffle, + RandCoarseTransform, RandGaussianNoise, RandGaussianSharpen, RandGaussianSmooth, @@ -143,6 +145,9 @@ RandCoarseDropoutd, RandCoarseDropoutD, RandCoarseDropoutDict, + RandCoarseShuffled, + RandCoarseShuffleD, + RandCoarseShuffleDict, RandGaussianNoised, RandGaussianNoiseD, RandGaussianNoiseDict, @@ -173,6 +178,9 @@ RandStdShiftIntensityd, RandStdShiftIntensityD, RandStdShiftIntensityDict, + SavitzkyGolaySmoothd, + SavitzkyGolaySmoothD, + SavitzkyGolaySmoothDict, ScaleIntensityd, ScaleIntensityD, ScaleIntensityDict, @@ -192,7 +200,7 @@ ThresholdIntensityD, ThresholdIntensityDict, ) -from .inverse import InvertibleTransform +from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict @@ -269,11 +277,18 @@ VoteEnsembled, VoteEnsembleDict, ) +from .smooth_field.array import ( + RandSmoothDeform, + RandSmoothFieldAdjustContrast, + RandSmoothFieldAdjustIntensity, + SmoothField, +) +from .smooth_field.dictionary import RandSmoothDeformd, RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd from .spatial.array import ( - AddCoordinateChannels, Affine, AffineGrid, Flip, + GridDistortion, Orientation, Rand2DElastic, Rand3DElastic, @@ -282,6 +297,7 @@ RandAxisFlip, RandDeformGrid, RandFlip, + RandGridDistortion, RandRotate, RandRotate90, RandZoom, @@ -293,15 +309,15 @@ Zoom, ) from .spatial.dictionary import ( - AddCoordinateChannelsd, - AddCoordinateChannelsD, - AddCoordinateChannelsDict, Affined, AffineD, AffineDict, Flipd, FlipD, FlipDict, + GridDistortiond, + GridDistortionD, + GridDistortionDict, Orientationd, OrientationD, OrientationDict, @@ -320,6 +336,9 @@ RandFlipd, RandFlipD, RandFlipDict, + RandGridDistortiond, + RandGridDistortionD, + RandGridDistortionDict, RandRotate90d, RandRotate90D, RandRotate90Dict, @@ -348,12 +367,14 @@ from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform from .utility.array import ( AddChannel, + AddCoordinateChannels, AddExtremePointsChannel, AsChannelFirst, AsChannelLast, CastToType, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, + CuCIM, DataStats, EnsureChannelFirst, EnsureType, @@ -363,6 +384,7 @@ LabelToMask, Lambda, MapLabelValue, + RandCuCIM, RandLambda, RemoveRepeatedChannel, RepeatChannel, @@ -381,6 +403,9 @@ AddChanneld, AddChannelD, AddChannelDict, + AddCoordinateChannelsd, + AddCoordinateChannelsD, + AddCoordinateChannelsDict, AddExtremePointsChanneld, AddExtremePointsChannelD, AddExtremePointsChannelDict, @@ -405,6 +430,9 @@ CopyItemsd, CopyItemsD, CopyItemsDict, + CuCIMd, + CuCIMD, + CuCIMDict, DataStatsd, DataStatsD, DataStatsDict, @@ -435,6 +463,9 @@ MapLabelValued, MapLabelValueD, MapLabelValueDict, + RandCuCIMd, + RandCuCIMD, + RandCuCIMDict, RandLambdad, RandLambdaD, RandLambdaDict, @@ -486,6 +517,7 @@ allow_missing_keys_mode, compute_divisible_spatial_size, convert_inverse_interp_mode, + convert_pad_mode, copypaste_arrays, create_control_grid, create_grid, @@ -518,4 +550,21 @@ weighted_patch_samples, zero_margins, ) -from .utils_pytorch_numpy_unification import in1d, moveaxis +from .utils_pytorch_numpy_unification import ( + any_np_pt, + clip, + concatenate, + cumsum, + floor_divide, + in1d, + isfinite, + isnan, + maximum, + moveaxis, + nonzero, + percentile, + ravel, + repeat, + unravel_index, + where, +) diff --git a/monai/transforms/adaptors.py b/monai/transforms/adaptors.py index 434d1f1c05..92fd11cf79 100644 --- a/monai/transforms/adaptors.py +++ b/monai/transforms/adaptors.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 4bf175769b..165d9b732f 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,7 +28,7 @@ apply_transform, ) from monai.utils import MAX_SEED, ensure_tuple, get_seed -from monai.utils.enums import InverseKeys +from monai.utils.enums import TraceKeys __all__ = ["Compose", "OneOf"] @@ -173,15 +173,15 @@ def inverse(self, data): class OneOf(Compose): """ - ``OneOf`` provides the ability to radomly choose one transform out of a - list of callables with predfined probabilities for each. + ``OneOf`` provides the ability to randomly choose one transform out of a + list of callables with pre-defined probabilities for each. Args: transforms: sequence of callables. weights: probabilities corresponding to each callable in transforms. Probabilities are normalized to sum to one. - OneOf inherits from Compose and uses args map_items and unpack_items in + ``OneOf`` inherits from ``Compose`` and uses args ``map_items`` and ``unpack_items`` in the same way. """ @@ -204,14 +204,13 @@ def __init__( def _normalize_probabilities(self, weights): if len(weights) == 0: return weights - else: - weights = np.array(weights) - if np.any(weights < 0): - raise AssertionError("Probabilities must be greater than or equal to zero.") - if np.all(weights == 0): - raise AssertionError("At least one probability must be greater than zero.") - weights = weights / weights.sum() - return list(weights) + weights = np.array(weights) + if np.any(weights < 0): + raise AssertionError("Probabilities must be greater than or equal to zero.") + if np.all(weights == 0): + raise AssertionError("At least one probability must be greater than zero.") + weights = weights / weights.sum() + return list(weights) def flatten(self): transforms = [] @@ -232,16 +231,15 @@ def flatten(self): def __call__(self, data): if len(self.transforms) == 0: return data - else: - index = self.R.multinomial(1, self.weights).argmax() - _transform = self.transforms[index] - data = apply_transform(_transform, data, self.map_items, self.unpack_items) - # if the data is a mapping (dictionary), append the OneOf transform to the end - if isinstance(data, Mapping): - for key in data.keys(): - if key + InverseKeys.KEY_SUFFIX in data: - self.push_transform(data, key, extra_info={"index": index}) - return data + index = self.R.multinomial(1, self.weights).argmax() + _transform = self.transforms[index] + data = apply_transform(_transform, data, self.map_items, self.unpack_items) + # if the data is a mapping (dictionary), append the OneOf transform to the end + if isinstance(data, Mapping): + for key in data.keys(): + if self.trace_key(key) in data: + self.push_transform(data, key, extra_info={"index": index}) + return data def inverse(self, data): if len(self.transforms) == 0: @@ -252,18 +250,15 @@ def inverse(self, data): # loop until we get an index and then break (since they'll all be the same) index = None for key in data.keys(): - if key + InverseKeys.KEY_SUFFIX in data: + if self.trace_key(key) in data: # get the index of the applied OneOf transform - index = self.get_most_recent_transform(data, key)[InverseKeys.EXTRA_INFO]["index"] + index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] # and then remove the OneOf transform self.pop_transform(data, key) if index is None: - raise RuntimeError("No invertible transforms have been applied") + # no invertible transforms have been applied + return data - # if applied transform is not InvertibleTransform, throw error _transform = self.transforms[index] - if not isinstance(_transform, InvertibleTransform): - raise RuntimeError(f"Applied OneOf transform is not invertible (applied index: {index}).") - # apply the inverse - return _transform.inverse(data) + return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data diff --git a/monai/transforms/croppad/__init__.py b/monai/transforms/croppad/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/croppad/__init__.py +++ b/monai/transforms/croppad/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 74f556cc1a..faf5306ce0 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,11 +22,12 @@ from torch.nn.functional import pad as pad_pt from monai.config import IndexSelection -from monai.config.type_definitions import NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( compute_divisible_spatial_size, + convert_pad_mode, generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -35,11 +36,21 @@ map_classes_to_indices, weighted_patch_samples, ) -from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option +from monai.transforms.utils_pytorch_numpy_unification import floor_divide, maximum +from monai.utils import ( + Method, + NumpyPadMode, + PytorchPadMode, + ensure_tuple, + ensure_tuple_rep, + fall_back_tuple, + look_up_option, +) from monai.utils.enums import TransformBackends -from monai.utils.type_conversion import convert_data_type +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type __all__ = [ + "Pad", "SpatialPad", "BorderPad", "DivisiblePad", @@ -61,16 +72,18 @@ class Pad(Transform): """ Perform padding for a given an amount of padding in each dimension. - If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used. - Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary). - Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad - for additional details. + If input is `torch.Tensor`, `torch.nn.functional.pad` will be used, otherwise, `np.pad` will be used. + Args: to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -78,43 +91,44 @@ class Pad(Transform): def __init__( self, to_pad: List[Tuple[int, int]], - mode: Union[NumpyPadMode, str, None] = NumpyPadMode.CONSTANT, - **np_kwargs, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **kwargs, ) -> None: self.to_pad = to_pad - self.mode = mode or NumpyPadMode.CONSTANT - self.np_kwargs = np_kwargs + self.mode = mode + self.kwargs = kwargs @staticmethod - def _np_pad(img: np.ndarray, all_pad_width, mode, **np_kwargs) -> np.ndarray: - img_np, *_ = convert_data_type(img, np.ndarray) - return np.pad(img_np, all_pad_width, mode=mode, **np_kwargs) # type: ignore + def _np_pad(img: np.ndarray, all_pad_width, mode, **kwargs) -> np.ndarray: + return np.pad(img, all_pad_width, mode=mode, **kwargs) # type: ignore @staticmethod - def _pt_pad(img: torch.Tensor, all_pad_width, mode, **np_kwargs) -> torch.Tensor: - pt_pad_width = [val for sublist in all_pad_width for val in sublist[::-1]][::-1] - return pad_pt(img, pt_pad_width, mode=mode, **np_kwargs) + def _pt_pad(img: torch.Tensor, all_pad_width, mode, **kwargs) -> torch.Tensor: + pt_pad_width = [val for sublist in all_pad_width[1:] for val in sublist[::-1]][::-1] + # torch.pad expects `[B, C, H, W, [D]]` shape + return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"`` or ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + """ if not np.asarray(self.to_pad).any(): # all zeros, skip padding return img - mode = mode or self.mode - mode = mode.value if isinstance(mode, NumpyPadMode) else mode - if isinstance(img, torch.Tensor) and mode == "constant" and not self.np_kwargs: - pad = self._pt_pad - else: - pad = self._np_pad # type: ignore - return pad(img, self.to_pad, mode, **self.np_kwargs) + mode = convert_pad_mode(dst=img, mode=mode or self.mode).value + pad = self._pt_pad if isinstance(img, torch.Tensor) else self._np_pad + return pad(img, self.to_pad, mode, **self.kwargs) # type: ignore class SpatialPad(Transform): @@ -135,12 +149,14 @@ class SpatialPad(Transform): `spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ @@ -150,13 +166,13 @@ def __init__( self, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, - **np_kwargs, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **kwargs, ) -> None: self.spatial_size = spatial_size self.method: Method = look_up_option(method, Method) - self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) - self.np_kwargs = np_kwargs + self.mode = mode + self.kwargs = kwargs def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]: spatial_size = fall_back_tuple(self.spatial_size, data_shape) @@ -168,15 +184,20 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int return pad_width return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + """ data_pad_width = self._determine_data_pad_width(img.shape[1:]) all_pad_width = [(0, 0)] + data_pad_width @@ -184,8 +205,7 @@ def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] # all zeros, skip padding return img - mode = look_up_option(mode or self.mode, NumpyPadMode) - padder = Pad(all_pad_width, mode, **self.np_kwargs) + padder = Pad(all_pad_width, mode or self.mode, **self.kwargs) return padder(img) @@ -204,13 +224,14 @@ class BorderPad(Transform): for example, image shape(CHW) is [1, 4, 4], spatial_border is [1, 2, 3, 4], pad top of H dim with 1, pad bottom of H dim with 2, pad left of W dim with 3, pad right of W dim with 4. the result shape is [1, 7, 11]. - - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ @@ -219,22 +240,26 @@ class BorderPad(Transform): def __init__( self, spatial_border: Union[Sequence[int], int], - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, - **np_kwargs, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **kwargs, ) -> None: self.spatial_border = spatial_border - self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) - self.np_kwargs = np_kwargs + self.mode = mode + self.kwargs = kwargs - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html Raises: ValueError: When ``self.spatial_border`` does not contain ints. @@ -261,8 +286,7 @@ def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] ) all_pad_width = [(0, 0)] + data_pad_width - mode = look_up_option(mode or self.mode, NumpyPadMode) - padder = Pad(all_pad_width, mode, **self.np_kwargs) + padder = Pad(all_pad_width, mode or self.mode, **self.kwargs) return padder(img) @@ -276,48 +300,50 @@ class DivisiblePad(Transform): def __init__( self, k: Union[Sequence[int], int], - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, method: Union[Method, str] = Method.SYMMETRIC, - **np_kwargs, + **kwargs, ) -> None: """ Args: k: the target k for each spatial dimension. if `k` is negative or 0, the original size is preserved. if `k` is an int, the same `k` be applied to all the input spatial dimensions. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. See also :py:class:`monai.transforms.SpatialPad` """ self.k = k self.mode: NumpyPadMode = NumpyPadMode(mode) self.method: Method = Method(method) - self.np_kwargs = np_kwargs + self.kwargs = kwargs - def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor: + def __call__( + self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``self.mode``. + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to `self.mode`. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + """ new_size = compute_divisible_spatial_size(spatial_shape=img.shape[1:], k=self.k) - spatial_pad = SpatialPad( - spatial_size=new_size, - method=self.method, - mode=mode or self.mode, - **self.np_kwargs, - ) + spatial_pad = SpatialPad(spatial_size=new_size, method=self.method, mode=mode or self.mode, **self.kwargs) return spatial_pad(img) @@ -336,12 +362,14 @@ class SpatialCrop(Transform): - the start and end coordinates of the ROI """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, - roi_center: Union[Sequence[int], np.ndarray, None] = None, - roi_size: Union[Sequence[int], np.ndarray, None] = None, - roi_start: Union[Sequence[int], np.ndarray, None] = None, - roi_end: Union[Sequence[int], np.ndarray, None] = None, + roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None, roi_slices: Optional[Sequence[slice]] = None, ) -> None: """ @@ -354,28 +382,37 @@ def __init__( use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. """ + roi_start_torch: torch.Tensor + if roi_slices: if not all(s.step is None or s.step == 1 for s in roi_slices): raise ValueError("Only slice steps of 1/None are currently supported") self.slices = list(roi_slices) else: if roi_center is not None and roi_size is not None: - roi_center = np.asarray(roi_center, dtype=np.int16) - roi_size = np.asarray(roi_size, dtype=np.int16) - roi_start_np = np.maximum(roi_center - np.floor_divide(roi_size, 2), 0) - roi_end_np = np.maximum(roi_start_np + roi_size, roi_start_np) + roi_center, *_ = convert_data_type( + data=roi_center, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True + ) + roi_size, *_ = convert_to_dst_type(src=roi_size, dst=roi_center, wrap_sequence=True) + _zeros = torch.zeros_like(roi_center) # type: ignore + roi_start_torch = maximum(roi_center - floor_divide(roi_size, 2), _zeros) # type: ignore + roi_end_torch = maximum(roi_start_torch + roi_size, roi_start_torch) else: if roi_start is None or roi_end is None: raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start_np = np.maximum(np.asarray(roi_start, dtype=np.int16), 0) - roi_end_np = np.maximum(np.asarray(roi_end, dtype=np.int16), roi_start_np) - # Allow for 1D by converting back to np.array (since np.maximum will convert to int) - roi_start_np = roi_start_np if isinstance(roi_start_np, np.ndarray) else np.array([roi_start_np]) - roi_end_np = roi_end_np if isinstance(roi_end_np, np.ndarray) else np.array([roi_end_np]) - # convert to slices - self.slices = [slice(s, e) for s, e in zip(roi_start_np, roi_end_np)] - - def __call__(self, img: Union[np.ndarray, torch.Tensor]): + roi_start_torch, *_ = convert_data_type( # type: ignore + data=roi_start, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True + ) + roi_start_torch = maximum(roi_start_torch, torch.zeros_like(roi_start_torch)) # type: ignore + roi_end_torch, *_ = convert_to_dst_type(src=roi_end, dst=roi_start_torch, wrap_sequence=True) + roi_end_torch = maximum(roi_end_torch, roi_start_torch) + # convert to slices (accounting for 1d) + if roi_start_torch.numel() == 1: + self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))] + else: + self.slices = [slice(int(s), int(e)) for s, e in zip(roi_start_torch.tolist(), roi_end_torch.tolist())] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -400,10 +437,12 @@ class CenterSpatialCrop(Transform): the spatial size of output data will be [32, 40, 40]. """ + backend = SpatialCrop.backend + def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -424,10 +463,12 @@ class CenterScaleCrop(Transform): """ + backend = CenterSpatialCrop.backend + def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img_size = img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -459,6 +500,8 @@ class RandSpatialCrop(Randomizable, Transform): if True, the actual size is sampled from `randint(roi_size, max_roi_size + 1)`. """ + backend = CenterSpatialCrop.backend + def __init__( self, roi_size: Union[Sequence[int], int], @@ -479,19 +522,19 @@ def randomize(self, img_size: Sequence[int]) -> None: max_size = img_size if self.max_roi_size is None else fall_back_tuple(self.max_roi_size, img_size) if any(i > j for i, j in zip(self._size, max_size)): raise ValueError(f"min ROI size: {self._size} is bigger than max ROI size: {max_size}.") - self._size = tuple((self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size)))) + self._size = tuple(self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size))) if self.random_center: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ self.randomize(img.shape[1:]) if self._size is None: - raise AssertionError + raise RuntimeError("self._size not specified.") if self.random_center: return img[self._slices] cropper = CenterSpatialCrop(self._size) @@ -530,7 +573,7 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -576,6 +619,8 @@ class RandSpatialCropSamples(Randomizable, Transform): """ + backend = RandSpatialCrop.backend + def __init__( self, roi_size: Union[Sequence[int], int], @@ -591,15 +636,15 @@ def __init__( def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Randomizable": - super().set_random_state(seed=seed, state=state) - self.cropper.set_random_state(state=self.R) + ) -> "RandSpatialCropSamples": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) return self def randomize(self, data: Optional[Any] = None) -> None: pass - def __call__(self, img: np.ndarray) -> List[np.ndarray]: + def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: """ Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. @@ -639,6 +684,8 @@ def threshold_at_one(x): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, select_fn: Callable = is_positive, @@ -646,7 +693,7 @@ def __init__( margin: Union[Sequence[int], int] = 0, return_coords: bool = False, k_divisible: Union[Sequence[int], int] = 1, - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = NumpyPadMode.CONSTANT, **np_kwargs, ) -> None: """ @@ -658,10 +705,12 @@ def __init__( return_coords: whether return the coordinates of spatial bounding box for foreground. k_divisible: make each spatial dimension to be divisible by k, default to 1. if `k_divisible` is an int, the same `k` be applied to all the input spatial dimensions. - mode: padding mode {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - one of the listed string values or a user supplied function. Defaults to ``"constant"``. - see also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html @@ -674,18 +723,18 @@ def __init__( self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) self.np_kwargs = np_kwargs - def compute_bounding_box(self, img: np.ndarray): + def compute_bounding_box(self, img: NdarrayOrTensor): """ Compute the start points and end points of bounding box to crop. And adjust bounding box coords to be divisible by `k`. """ box_start, box_end = generate_spatial_bounding_box(img, self.select_fn, self.channel_indices, self.margin) - box_start_ = np.asarray(box_start, dtype=np.int16) - box_end_ = np.asarray(box_end, dtype=np.int16) + box_start_, *_ = convert_data_type(box_start, output_type=np.ndarray, dtype=np.int16, wrap_sequence=True) + box_end_, *_ = convert_data_type(box_end, output_type=np.ndarray, dtype=np.int16, wrap_sequence=True) orig_spatial_size = box_end_ - box_start_ # make the spatial size divisible by `k` - spatial_size = np.asarray(compute_divisible_spatial_size(spatial_shape=orig_spatial_size, k=self.k_divisible)) + spatial_size = np.asarray(compute_divisible_spatial_size(orig_spatial_size.tolist(), k=self.k_divisible)) # update box_start and box_end box_start_ = box_start_ - np.floor_divide(np.asarray(spatial_size) - orig_spatial_size, 2) box_end_ = box_start_ + spatial_size @@ -693,10 +742,10 @@ def compute_bounding_box(self, img: np.ndarray): def crop_pad( self, - img: np.ndarray, + img: NdarrayOrTensor, box_start: np.ndarray, box_end: np.ndarray, - mode: Optional[Union[NumpyPadMode, str]] = None, + mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, ): """ Crop and pad based on the bounding box. @@ -708,7 +757,7 @@ def crop_pad( pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) return BorderPad(spatial_border=pad, mode=mode or self.mode, **self.np_kwargs)(cropped) - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None): + def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, str]] = None): """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. @@ -734,20 +783,25 @@ class RandWeightedCrop(Randomizable, Transform): It should be a single-channel array in shape, for example, `(1, spatial_dim_0, spatial_dim_1, ...)`. """ + backend = SpatialCrop.backend + def __init__( - self, spatial_size: Union[Sequence[int], int], num_samples: int = 1, weight_map: Optional[np.ndarray] = None + self, + spatial_size: Union[Sequence[int], int], + num_samples: int = 1, + weight_map: Optional[NdarrayOrTensor] = None, ): self.spatial_size = ensure_tuple(spatial_size) self.num_samples = int(num_samples) self.weight_map = weight_map self.centers: List[np.ndarray] = [] - def randomize(self, weight_map: np.ndarray) -> None: + def randomize(self, weight_map: NdarrayOrTensor) -> None: self.centers = weighted_patch_samples( spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) # using only the first channel as weight map - def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> List[np.ndarray]: + def __call__(self, img: NdarrayOrTensor, weight_map: Optional[NdarrayOrTensor] = None) -> List[NdarrayOrTensor]: """ Args: img: input image to sample patches from. assuming `img` is a channel-first array. @@ -764,9 +818,10 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> raise ValueError("weight map must be provided for weighted patch sampling.") if img.shape[1:] != weight_map.shape[1:]: raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.") + self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) - results = [] + results: List[NdarrayOrTensor] = [] for center in self.centers: cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) results.append(cropper(img)) @@ -816,6 +871,9 @@ class RandCropByPosNegLabel(Randomizable, Transform): `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `FgBgToIndices` transform first and cache the results. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). Raises: ValueError: When ``pos`` or ``neg`` are negative. @@ -823,17 +881,20 @@ class RandCropByPosNegLabel(Randomizable, Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, spatial_size: Union[Sequence[int], int], - label: Optional[np.ndarray] = None, + label: Optional[NdarrayOrTensor] = None, pos: float = 1.0, neg: float = 1.0, num_samples: int = 1, - image: Optional[np.ndarray] = None, + image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, + fg_indices: Optional[NdarrayOrTensor] = None, + bg_indices: Optional[NdarrayOrTensor] = None, + allow_smaller: bool = False, ) -> None: self.spatial_size = ensure_tuple(spatial_size) self.label = label @@ -845,16 +906,17 @@ def __init__( self.num_samples = num_samples self.image = image self.image_threshold = image_threshold - self.centers: Optional[List[List[np.ndarray]]] = None + self.centers: Optional[List[List[int]]] = None self.fg_indices = fg_indices self.bg_indices = bg_indices + self.allow_smaller = allow_smaller def randomize( self, - label: np.ndarray, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + fg_indices: Optional[NdarrayOrTensor] = None, + bg_indices: Optional[NdarrayOrTensor] = None, + image: Optional[NdarrayOrTensor] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) if fg_indices is None or bg_indices is None: @@ -867,17 +929,24 @@ def randomize( fg_indices_ = fg_indices bg_indices_ = bg_indices self.centers = generate_pos_neg_label_crop_centers( - self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R + self.spatial_size, + self.num_samples, + self.pos_ratio, + label.shape[1:], + fg_indices_, + bg_indices_, + self.R, + self.allow_smaller, ) def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, - ) -> List[np.ndarray]: + img: NdarrayOrTensor, + label: Optional[NdarrayOrTensor] = None, + image: Optional[NdarrayOrTensor] = None, + fg_indices: Optional[NdarrayOrTensor] = None, + bg_indices: Optional[NdarrayOrTensor] = None, + ) -> List[NdarrayOrTensor]: """ Args: img: input data to crop samples from based on the pos/neg ratio of `label` and `image`. @@ -900,10 +969,10 @@ def __call__( image = self.image self.randomize(label, fg_indices, bg_indices, image) - results: List[np.ndarray] = [] + results: List[NdarrayOrTensor] = [] if self.centers is not None: for center in self.centers: - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) results.append(cropper(img)) return results @@ -965,19 +1034,25 @@ class RandCropByLabelClasses(Randomizable, Transform): `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first and cache the results for better performance. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will remain + unchanged. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, spatial_size: Union[Sequence[int], int], ratios: Optional[List[Union[float, int]]] = None, - label: Optional[np.ndarray] = None, + label: Optional[NdarrayOrTensor] = None, num_classes: Optional[int] = None, num_samples: int = 1, - image: Optional[np.ndarray] = None, + image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0, - indices: Optional[List[np.ndarray]] = None, + indices: Optional[List[NdarrayOrTensor]] = None, + allow_smaller: bool = False, ) -> None: self.spatial_size = ensure_tuple(spatial_size) self.ratios = ratios @@ -986,17 +1061,18 @@ def __init__( self.num_samples = num_samples self.image = image self.image_threshold = image_threshold - self.centers: Optional[List[List[np.ndarray]]] = None + self.centers: Optional[List[List[int]]] = None self.indices = indices + self.allow_smaller = allow_smaller def randomize( self, - label: np.ndarray, - indices: Optional[List[np.ndarray]] = None, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + indices: Optional[List[NdarrayOrTensor]] = None, + image: Optional[NdarrayOrTensor] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - indices_: List[np.ndarray] + indices_: Sequence[NdarrayOrTensor] if indices is None: if self.indices is not None: indices_ = self.indices @@ -1005,16 +1081,16 @@ def randomize( else: indices_ = indices self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R + self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller ) def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, - indices: Optional[List[np.ndarray]] = None, - ) -> List[np.ndarray]: + img: NdarrayOrTensor, + label: Optional[NdarrayOrTensor] = None, + image: Optional[NdarrayOrTensor] = None, + indices: Optional[List[NdarrayOrTensor]] = None, + ) -> List[NdarrayOrTensor]: """ Args: img: input data to crop samples from based on the ratios of every class, assumes `img` is a @@ -1033,10 +1109,10 @@ def __call__( image = self.image self.randomize(label, indices, image) - results: List[np.ndarray] = [] + results: List[NdarrayOrTensor] = [] if self.centers is not None: for center in self.centers: - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) results.append(cropper(img)) return results @@ -1063,6 +1139,8 @@ class ResizeWithPadOrCrop(Transform): """ + backend = list(set(SpatialPad.backend) & set(CenterSpatialCrop.backend)) + def __init__( self, spatial_size: Union[Sequence[int], int], @@ -1073,7 +1151,7 @@ def __init__( self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **np_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayOrTensor: """ Args: img: data to pad or crop, assuming `img` is channel-first and @@ -1084,7 +1162,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N If None, defaults to the ``mode`` in construction. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ - return self.padder(self.cropper(img), mode=mode) + return self.padder(self.cropper(img), mode=mode) # type: ignore class BoundingRect(Transform): @@ -1111,10 +1189,12 @@ class BoundingRect(Transform): select_fn: function to select expected foreground, default is to select values > 0. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, select_fn: Callable = is_positive) -> None: self.select_fn = select_fn - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> np.ndarray: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 956dff7881..6edaf4622d 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,15 +20,11 @@ import torch from monai.data.utils import list_data_collate -from monai.transforms.compose import Compose from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.inverse import InvertibleTransform -from monai.transforms.utility.array import ToTensor -from monai.utils.enums import InverseKeys, Method, NumpyPadMode +from monai.utils.enums import Method, NumpyPadMode, TraceKeys -__all__ = [ - "PadListDataCollate", -] +__all__ = ["PadListDataCollate"] def replace_element(to_replace, batch, idx, key_or_idx): @@ -97,20 +93,12 @@ def __call__(self, batch: Any): # If all same size, skip if np.all(np.array(max_shapes).min(axis=0) == max_shape): continue - # Do we need to convert output to Tensor? - output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor) - - # Use `SpatialPadd` or `SpatialPad` to match sizes - # Default params are central padding, padding with 0's - # If input is dictionary, use the dictionary version so that the transformation is recorded + # Use `SpatialPad` to match sizes, Default params are central padding, padding with 0's padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.np_kwargs) - transform = padder if not output_to_tensor else Compose([padder, ToTensor()]) - for idx, batch_i in enumerate(batch): - im = batch_i[key_or_idx] - orig_size = im.shape[1:] - padded = transform(batch_i[key_or_idx]) + orig_size = batch_i[key_or_idx].shape[1:] + padded = padder(batch_i[key_or_idx]) batch = replace_element(padded, batch, idx, key_or_idx) # If we have a dictionary of data, append to list @@ -127,11 +115,13 @@ def inverse(data: dict) -> Dict[Hashable, np.ndarray]: d = deepcopy(data) for key in d: - transform_key = str(key) + InverseKeys.KEY_SUFFIX + transform_key = InvertibleTransform.trace_key(key) if transform_key in d: transform = d[transform_key][-1] - if transform[InverseKeys.CLASS_NAME] == PadListDataCollate.__name__: - d[key] = CenterSpatialCrop(transform["orig_size"])(d[key]) + if not isinstance(transform, Dict): + continue + if transform.get(TraceKeys.CLASS_NAME) == PadListDataCollate.__name__: + d[key] = CenterSpatialCrop(transform.get("orig_size", -1))(d[key]) # fallback to image size # remove transform d[transform_key].pop() return d diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 9e33ab2db1..03c75705ab 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,7 +25,7 @@ import numpy as np from monai.config import IndexSelection, KeysCollection -from monai.config.type_definitions import NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.croppad.array import ( BorderPad, @@ -33,6 +33,8 @@ CenterSpatialCrop, CropForeground, DivisiblePad, + RandCropByLabelClasses, + RandCropByPosNegLabel, ResizeWithPadOrCrop, SpatialCrop, SpatialPad, @@ -49,11 +51,11 @@ weighted_patch_samples, ) from monai.utils import ImageMetaKey as Key -from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple -from monai.utils.enums import InverseKeys +from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple +from monai.utils.enums import TraceKeys __all__ = [ - "NumpyPadModeSequence", + "PadModeSequence", "SpatialPadd", "BorderPadd", "DivisiblePadd", @@ -96,9 +98,13 @@ "ResizeWithPadOrCropDict", "BoundingRectD", "BoundingRectDict", + "RandCropByLabelClassesd", + "RandCropByLabelClassesD", + "RandCropByLabelClassesDict", ] NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] +PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str] class SpatialPadd(MapTransform, InvertibleTransform): @@ -114,9 +120,9 @@ def __init__( keys: KeysCollection, spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: PadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: """ Args: @@ -129,33 +135,35 @@ def __init__( the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = SpatialPad(spatial_size, method, **np_kwargs) + self.padder = SpatialPad(spatial_size, method, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InverseKeys.ORIG_SIZE] + orig_size = transform[TraceKeys.ORIG_SIZE] if self.padder.method == Method.SYMMETRIC: current_size = d[key].shape[1:] roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] @@ -183,9 +191,9 @@ def __init__( self, keys: KeysCollection, spatial_border: Union[Sequence[int], int], - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: PadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: """ Args: @@ -202,34 +210,36 @@ def __init__( pad bottom of H dim with 2, pad left of W dim with 3, pad right of W dim with 4. the result shape is [1, 7, 11]. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = BorderPad(spatial_border=spatial_border, **np_kwargs) + self.padder = BorderPad(spatial_border=spatial_border, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) roi_start = np.array(self.padder.spatial_border) # Need to convert single value to [min1,min2,...] if roi_start.size == 1: @@ -237,7 +247,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # need to convert [min1,max1,min2,...] to [min1,min2,...] elif roi_start.size == 2 * orig_size.size: roi_start = roi_start[::2] - roi_end = np.array(transform[InverseKeys.ORIG_SIZE]) + roi_start + roi_end = np.array(transform[TraceKeys.ORIG_SIZE]) + roi_start inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) # Apply inverse transform @@ -260,10 +270,10 @@ def __init__( self, keys: KeysCollection, k: Union[Sequence[int], int], - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: PadModeSequence = NumpyPadMode.CONSTANT, method: Union[Method, str] = Method.SYMMETRIC, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: """ Args: @@ -272,38 +282,40 @@ def __init__( k: the target k for each spatial dimension. if `k` is negative or 0, the original size is preserved. if `k` is an int, the same `k` be applied to all the input spatial dimensions. - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. See also :py:class:`monai.transforms.SpatialPad` """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = DivisiblePad(k=k, method=method, **np_kwargs) + self.padder = DivisiblePad(k=k, method=method, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) roi_start = np.floor((current_size - orig_size) / 2) roi_end = orig_size + roi_start @@ -331,6 +343,8 @@ class SpatialCropd(MapTransform, InvertibleTransform): - the start and end coordinates of the ROI """ + backend = SpatialCrop.backend + def __init__( self, keys: KeysCollection, @@ -357,20 +371,20 @@ def __init__( super().__init__(keys, allow_missing_keys) self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(self.cropper.slices, orig_size)]) @@ -404,13 +418,15 @@ class CenterSpatialCropd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = CenterSpatialCrop.backend + def __init__( self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False ) -> None: super().__init__(keys, allow_missing_keys) self.cropper = CenterSpatialCrop(roi_size) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): orig_size = d[key].shape[1:] @@ -418,13 +434,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.push_transform(d, key, orig_size=orig_size) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) # in each direction, if original size is even and current size is odd, += 1 @@ -454,16 +470,22 @@ class CenterScaleCropd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = CenterSpatialCrop.backend + def __init__( self, keys: KeysCollection, roi_scale: Union[Sequence[float], float], allow_missing_keys: bool = False ) -> None: super().__init__(keys, allow_missing_keys=allow_missing_keys) self.roi_scale = roi_scale - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + # use the spatial size of first image to scale, expect all images have the same spatial size - img_size = data[self.keys[0]].shape[1:] + img_size = d[first_key].shape[1:] # type: ignore ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] cropper = CenterSpatialCrop(roi_size) @@ -473,13 +495,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) # in each direction, if original size is even and current size is odd, += 1 @@ -525,6 +547,8 @@ class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = CenterSpatialCrop.backend + def __init__( self, keys: KeysCollection, @@ -553,11 +577,15 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.randomize(d[first_key].shape[1:]) # type: ignore if self._size is None: - raise AssertionError + raise RuntimeError("self._size not specified.") for key in self.key_iterator(d): if self.random_center: self.push_transform(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: ignore @@ -568,18 +596,18 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InverseKeys.ORIG_SIZE] + orig_size = transform[TraceKeys.ORIG_SIZE] random_center = self.random_center pad_to_start = np.empty((len(orig_size)), dtype=np.int32) pad_to_end = np.empty((len(orig_size)), dtype=np.int32) if random_center: - for i, _slice in enumerate(transform[InverseKeys.EXTRA_INFO]["slices"]): + for i, _slice in enumerate(transform[TraceKeys.EXTRA_INFO]["slices"]): pad_to_start[i] = _slice[0] pad_to_end[i] = orig_size[i] - _slice[1] else: @@ -615,7 +643,7 @@ class RandScaleCropd(RandSpatialCropd): roi_scale: if `random_size` is True, it specifies the minimum crop size: `roi_scale * image spatial size`. if `random_size` is False, it specifies the expected scale of image size to crop. e.g. [0.3, 0.4, 0.5]. If its components have non-positive values, will use `1.0` instead, which means the input image size. - max_roi_size: if `random_size` is True and `roi_scale` specifies the min crop region size, `max_roi_scale` + max_roi_scale: if `random_size` is True and `roi_scale` specifies the min crop region size, `max_roi_scale` can specify the max crop region size: `max_roi_scale * image spatial size`. if None, defaults to the input image size. if its components have non-positive values, will use `1.0` instead, which means the input image size. @@ -626,6 +654,8 @@ class RandScaleCropd(RandSpatialCropd): allow_missing_keys: don't raise exception if key is missing. """ + backend = RandSpatialCropd.backend + def __init__( self, keys: KeysCollection, @@ -643,12 +673,15 @@ def __init__( random_size=random_size, allow_missing_keys=allow_missing_keys, ) - MapTransform.__init__(self, keys, allow_missing_keys) self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - img_size = data[self.keys[0]].shape[1:] + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + first_key: Union[Hashable, List] = self.first_key(data) # type: ignore + if first_key == []: + return data # type: ignore + + img_size = data[first_key].shape[1:] # type: ignore ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] if self.max_roi_scale is not None: @@ -711,6 +744,8 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): """ + backend = RandSpatialCropd.backend + def __init__( self, keys: KeysCollection, @@ -735,15 +770,15 @@ def __init__( def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Randomizable": - super().set_random_state(seed=seed, state=state) - self.cropper.set_random_state(state=self.R) + ) -> "RandSpatialCropSamplesd": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) return self def randomize(self, data: Optional[Any] = None) -> None: pass - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: ret = [] for i in range(self.num_samples): d = dict(data) @@ -753,24 +788,24 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n cropped = self.cropper(d) # self.cropper will have added RandSpatialCropd to the list. Change to RandSpatialCropSamplesd for key in self.key_iterator(cropped): - cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__ - cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self) + cropped[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore + cropped[self.trace_key(key)][-1][TraceKeys.ID] = id(self) # type: ignore # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" if meta_key not in cropped: cropped[meta_key] = {} # type: ignore - cropped[meta_key][Key.PATCH_INDEX] = i + cropped[meta_key][Key.PATCH_INDEX] = i # type: ignore ret.append(cropped) return ret - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = deepcopy(dict(data)) # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd # Need to revert that since we're calling RandSpatialCropd's inverse for key in self.key_iterator(d): - d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.cropper.__class__.__name__ - d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self.cropper) + d[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.cropper.__class__.__name__ + d[self.trace_key(key)][-1][TraceKeys.ID] = id(self.cropper) context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext with context_manager(self.cropper): return self.cropper.inverse(d) @@ -789,6 +824,8 @@ class CropForegroundd(MapTransform, InvertibleTransform): channels. And it can also add margin to every dim of the bounding box of foreground object. """ + backend = CropForeground.backend + def __init__( self, keys: KeysCollection, @@ -797,7 +834,7 @@ def __init__( channel_indices: Optional[IndexSelection] = None, margin: Union[Sequence[int], int] = 0, k_divisible: Union[Sequence[int], int] = 1, - mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = NumpyPadMode.CONSTANT, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", allow_missing_keys: bool = False, @@ -814,10 +851,12 @@ def __init__( margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. k_divisible: make each spatial dimension to be divisible by k, default to 1. if `k_divisible` is an int, the same `k` be applied to all the input spatial dimensions. - mode: padding mode {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - one of the listed string values or a user supplied function. Defaults to ``"constant"``. - see also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html it also can be a sequence of string, each element corresponds to a key in ``keys``. start_coord_key: key to record the start coordinate of spatial bounding box for foreground. end_coord_key: key to record the end coordinate of spatial bounding box for foreground. @@ -831,15 +870,11 @@ def __init__( self.start_coord_key = start_coord_key self.end_coord_key = end_coord_key self.cropper = CropForeground( - select_fn=select_fn, - channel_indices=channel_indices, - margin=margin, - k_divisible=k_divisible, - **np_kwargs, + select_fn=select_fn, channel_indices=channel_indices, margin=margin, k_divisible=k_divisible, **np_kwargs ) self.mode = ensure_tuple_rep(mode, len(self.keys)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) d[self.start_coord_key] = box_start @@ -849,14 +884,14 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) cur_size = np.asarray(d[key].shape[1:]) - extra_info = transform[InverseKeys.EXTRA_INFO] + extra_info = transform[TraceKeys.EXTRA_INFO] box_start = np.asarray(extra_info["box_start"]) box_end = np.asarray(extra_info["box_end"]) # first crop the padding part @@ -906,6 +941,8 @@ class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): :py:class:`monai.transforms.RandWeightedCrop` """ + backend = SpatialCrop.backend + def __init__( self, keys: KeysCollection, @@ -928,18 +965,18 @@ def __init__( self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) self.centers: List[np.ndarray] = [] - def randomize(self, weight_map: np.ndarray) -> None: + def randomize(self, weight_map: NdarrayOrTensor) -> None: self.centers = weighted_patch_samples( spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) self.randomize(d[self.w_key]) _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)] + results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(data) for _ in range(self.num_samples)] # fill in the extra keys with unmodified data for i in range(self.num_samples): for key in set(data.keys()).difference(set(self.keys)): @@ -965,19 +1002,19 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n meta_key = meta_key or f"{key}_{meta_key_postfix}" if meta_key not in results[i]: results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i + results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) current_size = np.asarray(d[key].shape[1:]) - center = transform[InverseKeys.EXTRA_INFO]["center"] - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) + center = transform[TraceKeys.EXTRA_INFO]["center"] + cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) pad_to_end = orig_size - current_size - pad_to_start @@ -1040,6 +1077,9 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). allow_missing_keys: don't raise exception if key is missing. Raises: @@ -1048,6 +1088,8 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): """ + backend = RandCropByPosNegLabel.backend + def __init__( self, keys: KeysCollection, @@ -1062,6 +1104,7 @@ def __init__( bg_indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, meta_key_postfix: str = "meta_dict", + allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1081,14 +1124,15 @@ def __init__( if len(self.keys) != len(self.meta_keys): raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: Optional[List[List[np.ndarray]]] = None + self.centers: Optional[List[List[int]]] = None + self.allow_smaller = allow_smaller def randomize( self, - label: np.ndarray, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + fg_indices: Optional[NdarrayOrTensor] = None, + bg_indices: Optional[NdarrayOrTensor] = None, + image: Optional[NdarrayOrTensor] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) if fg_indices is None or bg_indices is None: @@ -1097,10 +1141,17 @@ def randomize( fg_indices_ = fg_indices bg_indices_ = bg_indices self.centers = generate_pos_neg_label_crop_centers( - self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R + self.spatial_size, + self.num_samples, + self.pos_ratio, + label.shape[1:], + fg_indices_, + bg_indices_, + self.R, + self.allow_smaller, ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) label = d[self.label_key] image = d[self.image_key] if self.image_key else None @@ -1114,7 +1165,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, np.ndarray]] = [dict(d) for _ in range(self.num_samples)] + results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)] for i, center in enumerate(self.centers): # fill in the extra keys with unmodified data @@ -1122,7 +1173,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n results[i][key] = deepcopy(d[key]) for key in self.key_iterator(d): img = d[key] - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) orig_size = img.shape[1:] results[i][key] = cropper(img) self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) @@ -1131,18 +1182,18 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n meta_key = meta_key or f"{key}_{meta_key_postfix}" if meta_key not in results[i]: results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i + results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) current_size = np.asarray(d[key].shape[1:]) - center = transform[InverseKeys.EXTRA_INFO]["center"] + center = transform[TraceKeys.EXTRA_INFO]["center"] cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) @@ -1230,10 +1281,15 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will remain + unchanged. allow_missing_keys: don't raise exception if key is missing. """ + backend = RandCropByLabelClasses.backend + def __init__( self, keys: KeysCollection, @@ -1247,6 +1303,7 @@ def __init__( indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, meta_key_postfix: str = "meta_dict", + allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1262,25 +1319,25 @@ def __init__( if len(self.keys) != len(self.meta_keys): raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: Optional[List[List[np.ndarray]]] = None + self.centers: Optional[List[List[int]]] = None + self.allow_smaller = allow_smaller def randomize( self, - label: np.ndarray, - indices: Optional[List[np.ndarray]] = None, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + indices: Optional[List[NdarrayOrTensor]] = None, + image: Optional[NdarrayOrTensor] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - indices_: List[np.ndarray] if indices is None: indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) else: indices_ = indices self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R + self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller ) - def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) label = d[self.label_key] image = d[self.image_key] if self.image_key else None @@ -1293,7 +1350,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, np.ndarray]] = [dict(d) for _ in range(self.num_samples)] + results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)] for i, center in enumerate(self.centers): # fill in the extra keys with unmodified data @@ -1301,7 +1358,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr results[i][key] = deepcopy(d[key]) for key in self.key_iterator(d): img = d[key] - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) orig_size = img.shape[1:] results[i][key] = cropper(img) self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) @@ -1310,18 +1367,18 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr meta_key = meta_key or f"{key}_{meta_key_postfix}" if meta_key not in results[i]: results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i + results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) current_size = np.asarray(d[key].shape[1:]) - center = transform[InverseKeys.EXTRA_INFO]["center"] + center = transform[TraceKeys.EXTRA_INFO]["center"] cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) @@ -1359,6 +1416,8 @@ class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): """ + backend = ResizeWithPadOrCrop.backend + def __init__( self, keys: KeysCollection, @@ -1372,27 +1431,20 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **np_kwargs) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): orig_size = d[key].shape[1:] d[key] = self.padcropper(d[key], mode=m) - self.push_transform( - d, - key, - orig_size=orig_size, - extra_info={ - "mode": m.value if isinstance(m, Enum) else m, - }, - ) + self.push_transform(d, key, orig_size=orig_size, extra_info={"mode": m.value if isinstance(m, Enum) else m}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) # Unfortunately, we can't just use ResizeWithPadOrCrop with original size because of odd/even rounding. # Instead, we first pad any smaller dimensions, and then we crop any larger dimensions. @@ -1436,6 +1488,8 @@ class BoundingRectd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = BoundingRect.backend + def __init__( self, keys: KeysCollection, @@ -1447,7 +1501,7 @@ def __init__( self.bbox = BoundingRect(select_fn=select_fn) self.bbox_key_postfix = bbox_key_postfix - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ diff --git a/monai/transforms/intensity/__init__.py b/monai/transforms/intensity/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/intensity/__init__.py +++ b/monai/transforms/intensity/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 20d306be04..33da679df5 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,7 +13,7 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -import copy +from abc import abstractmethod from collections.abc import Iterable from functools import partial from typing import Any, Callable, List, Optional, Sequence, Tuple, Union @@ -23,13 +23,13 @@ import torch from monai.config import DtypeLike -from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array +from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where from monai.utils import ( - PT_BEFORE_1_7, InvalidPyTorchVersionError, convert_data_type, convert_to_dst_type, @@ -37,7 +37,9 @@ ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, + pytorch_after, ) +from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_to_tensor, get_equivalent_dtype @@ -69,9 +71,10 @@ "RandGibbsNoise", "KSpaceSpikeNoise", "RandKSpaceSpikeNoise", + "RandCoarseTransform", "RandCoarseDropout", + "RandCoarseShuffle", "HistogramNormalize", - "LocalPatchShuffling", ] @@ -83,30 +86,42 @@ class RandGaussianNoise(RandomizableTransform): prob: Probability to add Gaussian noise. mean: Mean or “centre” of the distribution. std: Standard deviation (spread) of distribution. + dtype: output data type, if None, same as input image. defaults to float32. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: float = 0.1) -> None: + def __init__(self, prob: float = 0.1, mean: float = 0.0, std: float = 0.1, dtype: DtypeLike = np.float32) -> None: RandomizableTransform.__init__(self, prob) self.mean = mean self.std = std - self._noise: np.ndarray + self.dtype = dtype + self.noise: Optional[np.ndarray] = None - def randomize(self, im_shape: Sequence[int]) -> None: + def randomize(self, img: NdarrayOrTensor, mean: Optional[float] = None) -> None: super().randomize(None) - self._noise = self.R.normal(self.mean, self.R.uniform(0, self.std), size=im_shape) + if not self._do_transform: + return None + rand_std = self.R.uniform(0, self.std) + noise = self.R.normal(self.mean if mean is None else mean, rand_std, size=img.shape) + # noise is float64 array, convert to the output dtype to save memory + self.noise, *_ = convert_data_type(noise, dtype=self.dtype) # type: ignore - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, mean: Optional[float] = None, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - self.randomize(img.shape) - if self._noise is None: - raise RuntimeError("randomized factor should not be None.") + if randomize: + self.randomize(img=img, mean=self.mean if mean is None else mean) + if not self._do_transform: return img - noise, *_ = convert_to_dst_type(self._noise, img) + + if self.noise is None: + raise RuntimeError("please call the `randomize()` function first.") + img, *_ = convert_data_type(img, dtype=self.dtype) + noise, *_ = convert_to_dst_type(self.noise, img) return img + noise @@ -114,10 +129,10 @@ class RandRicianNoise(RandomizableTransform): """ Add Rician noise to image. Rician noise in MRI is the result of performing a magnitude operation on complex - data with Gaussian noise of the same variance in both channels, as described in `Noise in Magnitude Magnetic Resonance Images - `_. This transform is adapted from - `DIPY`_. See also: `The rician distribution of noisy mri data - `_. + data with Gaussian noise of the same variance in both channels, as described in + `Noise in Magnitude Magnetic Resonance Images `_. + This transform is adapted from `DIPY `_. + See also: `The rician distribution of noisy mri data `_. Args: prob: Probability to add Rician noise. @@ -131,6 +146,8 @@ class RandRicianNoise(RandomizableTransform): histogram. sample_std: If True, sample the spread of the Gaussian distributions uniformly from 0 to std. + dtype: output data type, if None, same as input image. defaults to float32. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -143,6 +160,7 @@ def __init__( channel_wise: bool = False, relative: bool = False, sample_std: bool = True, + dtype: DtypeLike = np.float32, ) -> None: RandomizableTransform.__init__(self, prob) self.prob = prob @@ -151,15 +169,16 @@ def __init__( self.channel_wise = channel_wise self.relative = relative self.sample_std = sample_std + self.dtype = dtype self._noise1: NdarrayOrTensor self._noise2: NdarrayOrTensor - def _add_noise(self, img: NdarrayTensor, mean: float, std: float): + def _add_noise(self, img: NdarrayOrTensor, mean: float, std: float): dtype_np = get_equivalent_dtype(img.dtype, np.ndarray) im_shape = img.shape _std = self.R.uniform(0, std) if self.sample_std else std - self._noise1 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np) - self._noise2 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np) + self._noise1 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np, copy=False) + self._noise2 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np, copy=False) if isinstance(img, torch.Tensor): n1 = torch.tensor(self._noise1, device=img.device) n2 = torch.tensor(self._noise2, device=img.device) @@ -167,13 +186,17 @@ def _add_noise(self, img: NdarrayTensor, mean: float, std: float): return np.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) - def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - super().randomize(None) + if randomize: + super().randomize(None) + if not self._do_transform: return img + + img, *_ = convert_data_type(img, dtype=self.dtype) if self.channel_wise: _mean = ensure_tuple_rep(self.mean, len(img)) _std = ensure_tuple_rep(self.std, len(img)) @@ -211,9 +234,9 @@ def __call__(self, img: NdarrayOrTensor, offset: Optional[float] = None) -> Ndar offset = self.offset if offset is None else offset out = img + offset - if isinstance(out, torch.Tensor): - return out.type(img.dtype) - return out.astype(img.dtype) # type: ignore + out, *_ = convert_data_type(data=out, dtype=img.dtype) + + return out class RandShiftIntensity(RandomizableTransform): @@ -241,10 +264,12 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 self._shfiter = ShiftIntensity(self._offset) def randomize(self, data: Optional[Any] = None) -> None: - self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) super().randomize(None) + if not self._do_transform: + return None + self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) - def __call__(self, img: NdarrayOrTensor, factor: Optional[float] = None) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, factor: Optional[float] = None, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. @@ -254,9 +279,12 @@ def __call__(self, img: NdarrayOrTensor, factor: Optional[float] = None) -> Ndar can be some image specific value at runtime, like: max(img), etc. """ - self.randomize() + if randomize: + self.randomize() + if not self._do_transform: return img + return self._shfiter(img, self._offset if factor is None else self._offset * factor) @@ -272,7 +300,7 @@ class StdShiftIntensity(Transform): nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. Please ensure that the first dimension represents the channel of the image if True. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -305,7 +333,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - img, *_ = convert_data_type(img, dtype=self.dtype) + if self.dtype is not None: + img, *_ = convert_data_type(img, dtype=self.dtype) if self.channel_wise: for i, d in enumerate(img): img[i] = self._stdshift(d) # type: ignore @@ -337,7 +366,7 @@ def __init__( prob: probability of std shift. nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. """ RandomizableTransform.__init__(self, prob) @@ -353,20 +382,25 @@ def __init__( self.dtype = dtype def randomize(self, data: Optional[Any] = None) -> None: - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) + if not self._do_transform: + return None + self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - self.randomize() + if randomize: + self.randomize() + if not self._do_transform: return img + shifter = StdShiftIntensity( factor=self.factor, nonzero=self.nonzero, channel_wise=self.channel_wise, dtype=self.dtype ) - return shifter(img) + return shifter(img=img) class ScaleIntensity(Transform): @@ -378,18 +412,28 @@ class ScaleIntensity(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, factor: Optional[float] = None + self, + minv: Optional[float] = 0.0, + maxv: Optional[float] = 1.0, + factor: Optional[float] = None, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, ) -> None: """ Args: minv: minimum value of output data. maxv: maximum value of output data. factor: factor scale by ``v = v * (1 + factor)``. In order to use - this parameter, please set `minv` and `maxv` into None. + this parameter, please set both `minv` and `maxv` into None. + channel_wise: if True, scale on each channel separately. Please ensure + that the first dimension represents the channel of the image if True. + dtype: output data type, if None, same as input image. defaults to float32. """ self.minv = minv self.maxv = maxv self.factor = factor + self.channel_wise = channel_wise + self.dtype = dtype def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ @@ -399,13 +443,18 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: ValueError: When ``self.minv=None`` or ``self.maxv=None`` and ``self.factor=None``. Incompatible values. """ - if self.minv is not None and self.maxv is not None: - return rescale_array(img, self.minv, self.maxv, img.dtype) - if self.factor is not None: - out = img * (1 + self.factor) - out, *_ = convert_data_type(out, dtype=img.dtype) - return out - raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") + ret: NdarrayOrTensor + if self.minv is not None or self.maxv is not None: + if self.channel_wise: + out = [rescale_array(d, self.minv, self.maxv, dtype=self.dtype) for d in img] + ret = torch.stack(out) if isinstance(img, torch.Tensor) else np.stack(out) # type: ignore + else: + ret = rescale_array(img, self.minv, self.maxv, dtype=self.dtype) + else: + ret = (img * (1 + self.factor)) if self.factor is not None else img + + ret, *_ = convert_data_type(ret, dtype=self.dtype or img.dtype) + return ret class RandScaleIntensity(RandomizableTransform): @@ -416,12 +465,15 @@ class RandScaleIntensity(RandomizableTransform): backend = ScaleIntensity.backend - def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1) -> None: + def __init__( + self, factors: Union[Tuple[float, float], float], prob: float = 0.1, dtype: DtypeLike = np.float32 + ) -> None: """ Args: factors: factor range to randomly scale by ``v = v * (1 + factor)``. if single number, factor value is picked from (-factors, factors). prob: probability of scale. + dtype: output data type, if None, same as input image. defaults to float32. """ RandomizableTransform.__init__(self, prob) @@ -432,20 +484,25 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 else: self.factors = (min(factors), max(factors)) self.factor = self.factors[0] + self.dtype = dtype def randomize(self, data: Optional[Any] = None) -> None: - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) + if not self._do_transform: + return None + self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - self.randomize() + if randomize: + self.randomize() + if not self._do_transform: return img - scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor) - return scaler(img) + + return ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img) class RandBiasField(RandomizableTransform): @@ -463,17 +520,19 @@ class RandBiasField(RandomizableTransform): degree: degree of freedom of the polynomials. The value should be no less than 1. Defaults to 3. coeff_range: range of the random coefficients. Defaults to (0.0, 0.1). - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. prob: probability to do random bias field. """ + backend = [TransformBackends.NUMPY] + def __init__( self, degree: int = 3, coeff_range: Tuple[float, float] = (0.0, 0.1), dtype: DtypeLike = np.float32, - prob: float = 1.0, + prob: float = 0.1, ) -> None: RandomizableTransform.__init__(self, prob) if degree < 1: @@ -507,18 +566,23 @@ def _generate_random_field(self, spatial_shape: Sequence[int], degree: int, coef return np.polynomial.legendre.leggrid3d(coords[0], coords[1], coords[2], coeff_mat) raise NotImplementedError("only supports 2D or 3D fields") - def randomize(self, data: np.ndarray) -> None: + def randomize(self, img_size: Sequence[int]) -> None: super().randomize(None) - n_coeff = int(np.prod([(self.degree + k) / k for k in range(1, len(data.shape[1:]) + 1)])) + if not self._do_transform: + return None + n_coeff = int(np.prod([(self.degree + k) / k for k in range(1, len(img_size) + 1)])) self._coeff = self.R.uniform(*self.coeff_range, n_coeff).tolist() - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - self.randomize(data=img) + if randomize: + self.randomize(img_size=img.shape[1:]) + if not self._do_transform: return img + num_channels, *spatial_shape = img.shape _bias_fields = np.stack( [ @@ -527,7 +591,10 @@ def __call__(self, img: np.ndarray): ], axis=0, ) - return (img * np.exp(_bias_fields)).astype(self.dtype) + img_np, *_ = convert_data_type(img, np.ndarray) + out: NdarrayOrTensor = img_np * np.exp(_bias_fields) + out, *_ = convert_to_dst_type(src=out, dst=img, dtype=self.dtype or img.dtype) + return out class NormalizeIntensity(Transform): @@ -542,9 +609,9 @@ class NormalizeIntensity(Transform): subtrahend: the amount to subtract by (usually the mean). divisor: the amount to divide by (usually the standard deviation). nonzero: whether only normalize non-zero values. - channel_wise: if using calculated mean and std, calculate on each channel separately - or calculate on the entire image directly. - dtype: output data type, defaults to float32. + channel_wise: if True, calculate on each channel separately, otherwise, calculate on + the entire image directly. default to False. + dtype: output data type, if None, same as input image. defaults to float32. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -611,6 +678,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True, """ + dtype = self.dtype or img.dtype if self.channel_wise: if self.subtrahend is not None and len(self.subtrahend) != len(img): raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.") @@ -626,7 +694,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: else: img = self._normalize(img, self.subtrahend, self.divisor) - out, *_ = convert_data_type(img, dtype=self.dtype) + out, *_ = convert_data_type(img, dtype=dtype) return out @@ -641,6 +709,8 @@ class ThresholdIntensity(Transform): cval: value to fill the remaining parts of the image, default is 0. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> None: if not isinstance(threshold, (int, float)): raise ValueError("threshold must be a float or int number.") @@ -648,13 +718,14 @@ def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> N self.above = above self.cval = cval - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.asarray( - np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval), dtype=img.dtype - ) + mask = img > self.threshold if self.above else img < self.threshold + res = where(mask, img, self.cval) + res, *_ = convert_data_type(res, dtype=img.dtype) + return res class ScaleIntensityRange(Transform): @@ -662,34 +733,55 @@ class ScaleIntensityRange(Transform): Apply specific intensity scaling to the whole numpy array. Scaling from [a_min, a_max] to [b_min, b_max] with clip option. + When `b_min` or `b_max` are `None`, `scacled_array * (b_max - b_min) + b_min` will be skipped. + If `clip=True`, when `b_min`/`b_max` is None, the clipping is not performed on the corresponding edge. + Args: a_min: intensity original range min. a_max: intensity original range max. b_min: intensity target range min. b_max: intensity target range max. clip: whether to perform clip after scaling. + dtype: output data type, if None, same as input image. defaults to float32. """ - def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False) -> None: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + a_min: float, + a_max: float, + b_min: Optional[float] = None, + b_max: Optional[float] = None, + clip: bool = False, + dtype: DtypeLike = np.float32, + ) -> None: self.a_min = a_min self.a_max = a_max self.b_min = b_min self.b_max = b_max self.clip = clip + self.dtype = dtype - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ + dtype = self.dtype or img.dtype if self.a_max - self.a_min == 0.0: warn("Divide by zero (a_min == a_max)", Warning) + if self.b_min is None: + return img - self.a_min return img - self.a_min + self.b_min img = (img - self.a_min) / (self.a_max - self.a_min) - img = img * (self.b_max - self.b_min) + self.b_min + if (self.b_min is not None) and (self.b_max is not None): + img = img * (self.b_max - self.b_min) + self.b_min if self.clip: - img = np.asarray(np.clip(img, self.b_min, self.b_max)) - return img + img = clip(img, self.b_min, self.b_max) + ret, *_ = convert_data_type(img, dtype=dtype) + + return ret class AdjustContrast(Transform): @@ -702,19 +794,22 @@ class AdjustContrast(Transform): gamma: gamma value to adjust the contrast as function. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, gamma: float) -> None: if not isinstance(gamma, (int, float)): raise ValueError("gamma must be a float or int number.") self.gamma = gamma - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ epsilon = 1e-7 img_min = img.min() img_range = img.max() - img_min - return np.power(((img - img_min) / float(img_range + epsilon)), self.gamma) * img_range + img_min + ret: NdarrayOrTensor = ((img - img_min) / float(img_range + epsilon)) ** self.gamma * img_range + img_min + return ret class RandAdjustContrast(RandomizableTransform): @@ -729,6 +824,8 @@ class RandAdjustContrast(RandomizableTransform): If single number, value is picked from (0.5, gamma), default is (0.5, 4.5). """ + backend = AdjustContrast.backend + def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0.5, 4.5)) -> None: RandomizableTransform.__init__(self, prob) @@ -743,34 +840,39 @@ def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0. else: self.gamma = (min(gamma), max(gamma)) - self.gamma_value: float + self.gamma_value: Optional[float] = None def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None self.gamma_value = self.R.uniform(low=self.gamma[0], high=self.gamma[1]) - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - self.randomize() - if self.gamma_value is None: - raise ValueError("gamma_value is not set.") + if randomize: + self.randomize() + if not self._do_transform: return img - adjuster = AdjustContrast(self.gamma_value) - return adjuster(img) + + if self.gamma_value is None: + raise RuntimeError("gamma_value is not set, please call `randomize` function first.") + return AdjustContrast(self.gamma_value)(img) class ScaleIntensityRangePercentiles(Transform): """ Apply range scaling to a numpy array based on the intensity distribution of the input. - By default this transform will scale from [lower_intensity_percentile, upper_intensity_percentile] to [b_min, b_max], where - {lower,upper}_intensity_percentile are the intensity values at the corresponding percentiles of ``img``. + By default this transform will scale from [lower_intensity_percentile, upper_intensity_percentile] to + `[b_min, b_max]`, where {lower,upper}_intensity_percentile are the intensity values at the corresponding + percentiles of ``img``. - The ``relative`` parameter can also be set to scale from [lower_intensity_percentile, upper_intensity_percentile] to the - lower and upper percentiles of the output range [b_min, b_max] + The ``relative`` parameter can also be set to scale from [lower_intensity_percentile, upper_intensity_percentile] + to the lower and upper percentiles of the output range [b_min, b_max]. For example: @@ -807,6 +909,9 @@ class ScaleIntensityRangePercentiles(Transform): [20., 60., 100., 140., 180.], [20., 60., 100., 140., 180.]]] + See Also: + + - :py:class:`monai.transforms.ScaleIntensityRange` Args: lower: lower intensity percentile. @@ -815,10 +920,23 @@ class ScaleIntensityRangePercentiles(Transform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max]. + channel_wise: if True, compute intensity percentile and normalize every channel separately. + default to False. + dtype: output data type, if None, same as input image. defaults to float32. """ + backend = ScaleIntensityRange.backend + def __init__( - self, lower: float, upper: float, b_min: float, b_max: float, clip: bool = False, relative: bool = False + self, + lower: float, + upper: float, + b_min: Optional[float], + b_max: Optional[float], + clip: bool = False, + relative: bool = False, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, ) -> None: if lower < 0.0 or lower > 100.0: raise ValueError("Percentiles must be in the range [0, 100]") @@ -830,25 +948,36 @@ def __init__( self.b_max = b_max self.clip = clip self.relative = relative + self.channel_wise = channel_wise + self.dtype = dtype - def __call__(self, img: np.ndarray): - """ - Apply the transform to `img`. - """ - a_min = np.percentile(img, self.lower) - a_max = np.percentile(img, self.upper) + def _normalize(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + a_min: float = percentile(img, self.lower) # type: ignore + a_max: float = percentile(img, self.upper) # type: ignore b_min = self.b_min b_max = self.b_max if self.relative: + if (self.b_min is None) or (self.b_max is None): + raise ValueError("If it is relative, b_min and b_max should not be None.") b_min = ((self.b_max - self.b_min) * (self.lower / 100.0)) + self.b_min b_max = ((self.b_max - self.b_min) * (self.upper / 100.0)) + self.b_min - scalar = ScaleIntensityRange(a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=False) + scalar = ScaleIntensityRange( + a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=self.clip, dtype=self.dtype + ) img = scalar(img) + return img - if self.clip: - img = np.asarray(np.clip(img, self.b_min, self.b_max)) + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply the transform to `img`. + """ + if self.channel_wise: + for i, d in enumerate(img): + img[i] = self._normalize(img=d) # type: ignore + else: + img = self._normalize(img=img) return img @@ -871,11 +1000,13 @@ class MaskIntensity(Transform): """ - def __init__(self, mask_data: Optional[np.ndarray] = None, select_fn: Callable = is_positive) -> None: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, mask_data: Optional[NdarrayOrTensor] = None, select_fn: Callable = is_positive) -> None: self.mask_data = mask_data self.select_fn = select_fn - def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor, mask_data: Optional[NdarrayOrTensor] = None) -> NdarrayOrTensor: """ Args: mask_data: if mask data is single channel, apply to every channel @@ -892,14 +1023,16 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n if mask_data is None: raise ValueError("must provide the mask_data when initializing the transform or at runtime.") - mask_data = np.asarray(self.select_fn(mask_data)) - if mask_data.shape[0] != 1 and mask_data.shape[0] != img.shape[0]: + mask_data_, *_ = convert_to_dst_type(src=mask_data, dst=img) + + mask_data_ = self.select_fn(mask_data_) + if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]: raise ValueError( "When mask_data is not single channel, mask_data channels must match img, " - f"got img channels={img.shape[0]} mask_data channels={mask_data.shape[0]}." + f"got img channels={img.shape[0]} mask_data channels={mask_data_.shape[0]}." ) - return np.asarray(img * mask_data) + return img * mask_data_ class SavitzkyGolaySmooth(Transform): @@ -914,7 +1047,7 @@ class SavitzkyGolaySmooth(Transform): or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information. """ - backend = [TransformBackends.NUMPY] + backend = [TransformBackends.TORCH] def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "zeros"): @@ -927,7 +1060,7 @@ def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "z self.mode = mode self.img_t: torch.Tensor = torch.tensor(0.0) - def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: array containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. @@ -941,7 +1074,9 @@ def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: # add one to transform axis because a batch axis will be added at dimension 0 savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) # convert to Tensor and add Batch axis expected by HilbertTransform - out: torch.Tensor = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0) + smoothed = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0) + out, *_ = convert_to_dst_type(smoothed, dst=img) + return out @@ -952,14 +1087,16 @@ class DetectEnvelope(Transform): Args: axis: Axis along which to detect the envelope. Default 1, i.e. the first spatial dimension. - N: FFT size. Default img.shape[axis]. Input will be zero-padded or truncated to this size along dimension + n: FFT size. Default img.shape[axis]. Input will be zero-padded or truncated to this size along dimension ``axis``. """ + backend = [TransformBackends.TORCH] + def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) if axis < 0: @@ -968,7 +1105,7 @@ def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: self.axis = axis self.n = n - def __call__(self, img: np.ndarray): + def __call__(self, img: NdarrayOrTensor): """ Args: @@ -978,11 +1115,15 @@ def __call__(self, img: np.ndarray): np.ndarray containing envelope of data in img along the specified axis. """ + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore # add one to transform axis because a batch axis will be added at dimension 0 hilbert_transform = HilbertTransform(self.axis + 1, self.n) # convert to Tensor and add Batch axis expected by HilbertTransform - input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) - return np.abs(hilbert_transform(input_data).squeeze(0).numpy()) + out = hilbert_transform(img_t.unsqueeze(0)).squeeze(0).abs() + out, *_ = convert_to_dst_type(src=out, dst=img) + + return out class GaussianSmooth(Transform): @@ -999,14 +1140,25 @@ class GaussianSmooth(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "erf") -> None: self.sigma = sigma self.approx = approx - def __call__(self, img: np.ndarray): - gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx) - input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) - return gaussian_filter(input_data).squeeze(0).detach().numpy() + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) # type: ignore + sigma: Union[Sequence[torch.Tensor], torch.Tensor] + if isinstance(self.sigma, Sequence): + sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma] + else: + sigma = torch.as_tensor(self.sigma, device=img_t.device) + gaussian_filter = GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx) + out_t: torch.Tensor = gaussian_filter(img_t.unsqueeze(0)).squeeze(0) + out, *_ = convert_data_type(out_t, type(img), device=img.device if isinstance(img, torch.Tensor) else None) + + return out class RandGaussianSmooth(RandomizableTransform): @@ -1023,6 +1175,8 @@ class RandGaussianSmooth(RandomizableTransform): """ + backend = GaussianSmooth.backend + def __init__( self, sigma_x: Tuple[float, float] = (0.25, 1.5), @@ -1043,14 +1197,19 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1]) self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1]) self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) - def __call__(self, img: np.ndarray): - self.randomize() + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + if randomize: + self.randomize() + if not self._do_transform: return img + sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=img.ndim - 1) return GaussianSmooth(sigma=sigma, approx=self.approx)(img) @@ -1082,6 +1241,8 @@ class GaussianSharpen(Transform): """ + backend = [TransformBackends.TORCH] + def __init__( self, sigma1: Union[Sequence[float], float] = 3.0, @@ -1094,13 +1255,19 @@ def __init__( self.alpha = alpha self.approx = approx - def __call__(self, img: np.ndarray): - gaussian_filter1 = GaussianFilter(img.ndim - 1, self.sigma1, approx=self.approx) - gaussian_filter2 = GaussianFilter(img.ndim - 1, self.sigma2, approx=self.approx) - input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) - blurred_f = gaussian_filter1(input_data) - filter_blurred_f = gaussian_filter2(blurred_f) - return (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0).detach().numpy() + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore + + gf1, gf2 = ( + GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device) + for sigma in (self.sigma1, self.sigma2) + ) + blurred_f = gf1(img_t.unsqueeze(0)) + filter_blurred_f = gf2(blurred_f) + out_t: torch.Tensor = (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0) + out, *_ = convert_data_type(out_t, type(img), device=img.device if isinstance(img, torch.Tensor) else None) + return out class RandGaussianSharpen(RandomizableTransform): @@ -1125,6 +1292,8 @@ class RandGaussianSharpen(RandomizableTransform): """ + backend = GaussianSharpen.backend + def __init__( self, sigma1_x: Tuple[float, float] = (0.5, 1.0), @@ -1146,9 +1315,18 @@ def __init__( self.sigma2_z = sigma2_z self.alpha = alpha self.approx = approx + self.x1: Optional[float] = None + self.y1: Optional[float] = None + self.z1: Optional[float] = None + self.x2: Optional[float] = None + self.y2: Optional[float] = None + self.z2: Optional[float] = None + self.a: Optional[float] = None def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None self.x1 = self.R.uniform(low=self.sigma1_x[0], high=self.sigma1_x[1]) self.y1 = self.R.uniform(low=self.sigma1_y[0], high=self.sigma1_y[1]) self.z1 = self.R.uniform(low=self.sigma1_z[0], high=self.sigma1_z[1]) @@ -1160,10 +1338,15 @@ def randomize(self, data: Optional[Any] = None) -> None: self.z2 = self.R.uniform(low=sigma2_z[0], high=sigma2_z[1]) self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1]) - def __call__(self, img: np.ndarray): - self.randomize() + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + if randomize: + self.randomize() + if not self._do_transform: return img + + if self.x2 is None or self.y2 is None or self.z2 is None or self.a is None: + raise RuntimeError("please call the `randomize()` function first.") sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=img.ndim - 1) sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=img.ndim - 1) return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img) @@ -1180,6 +1363,8 @@ class RandHistogramShift(RandomizableTransform): prob: probability of histogram shift. """ + backend = [TransformBackends.NUMPY] + def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) @@ -1193,9 +1378,13 @@ def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: f if min(num_control_points) <= 2: raise ValueError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) + self.reference_control_points: np.ndarray + self.floating_control_points: np.ndarray def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None num_control_point = self.R.randint(self.num_control_points[0], self.num_control_points[1] + 1) self.reference_control_points = np.linspace(0, 1, num_control_point) self.floating_control_points = np.copy(self.reference_control_points) @@ -1204,79 +1393,26 @@ def randomize(self, data: Optional[Any] = None) -> None: self.floating_control_points[i - 1], self.floating_control_points[i + 1] ) - def __call__(self, img: np.ndarray) -> np.ndarray: - self.randomize() + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + if randomize: + self.randomize() + if not self._do_transform: return img - img_min, img_max = img.min(), img.max() + + if self.reference_control_points is None or self.floating_control_points is None: + raise RuntimeError("please call the `randomize()` function first.") + img_np: np.ndarray + img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + img_min, img_max = img_np.min(), img_np.max() reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min - return np.asarray( - np.interp(img, reference_control_points_scaled, floating_control_points_scaled), dtype=img.dtype + img_np = np.asarray( + np.interp(img_np, reference_control_points_scaled, floating_control_points_scaled), dtype=img_np.dtype ) - - -class RandGibbsNoise(RandomizableTransform): - """ - Naturalistic image augmentation via Gibbs artifacts. The transform - randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts - are one of the common type of type artifacts appearing in MRI scans. - - The transform is applied to all the channels in the data. - - For general information on Gibbs artifacts, please refer to: - https://pubs.rsna.org/doi/full/10.1148/rg.313105115 - https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949 - - - Args: - prob (float): probability of applying the transform. - alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes - values in the interval [0,1] with alpha = 0 acting as the identity mapping. - If a length-2 list is given as [a,b] then the value of alpha will be - sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. - """ - - def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None: - - if len(alpha) != 2: - raise ValueError("alpha length must be 2.") - if alpha[1] > 1 or alpha[0] < 0: - raise ValueError("alpha must take values in the interval [0,1]") - if alpha[0] > alpha[1]: - raise ValueError("When alpha = [a,b] we need a < b.") - - self.alpha = alpha - self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() - self.as_tensor_output = as_tensor_output - - RandomizableTransform.__init__(self, prob=prob) - - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: - - # randomize application and possibly alpha - self._randomize(None) - - if self._do_transform: - # apply transform - transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output) - img = transform(img) - else: - if isinstance(img, np.ndarray) and self.as_tensor_output: - img = torch.Tensor(img) - elif isinstance(img, torch.Tensor) and not self.as_tensor_output: - img = img.detach().cpu().numpy() + img, *_ = convert_to_dst_type(img_np, dst=img) return img - def _randomize(self, _: Any) -> None: - """ - (1) Set random variable to apply the transform. - (2) Get alpha from uniform distribution. - """ - super().randomize(None) - self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) - class GibbsNoise(Transform, Fourier): """ @@ -1296,21 +1432,20 @@ class GibbsNoise(Transform, Fourier): Args: alpha: Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. Default: True. """ - def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + @deprecated_arg(name="as_tensor_output", since="0.6") + def __init__(self, alpha: float = 0.1, as_tensor_output: bool = True) -> None: if alpha > 1 or alpha < 0: - raise ValueError("alpha must take values in the interval [0,1].") + raise ValueError("alpha must take values in the interval [0, 1].") self.alpha = alpha - self.as_tensor_output = as_tensor_output - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: n_dims = len(img.shape[1:]) - if isinstance(img, np.ndarray): - img = torch.Tensor(img) # FT k = self.shift_fourier(img, n_dims) # build and apply mask @@ -1318,13 +1453,13 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, # map back img = self.inv_shift_fourier(k, n_dims) - return img if self.as_tensor_output else img.cpu().detach().numpy() + return img - def _apply_mask(self, k: torch.Tensor) -> torch.Tensor: + def _apply_mask(self, k: NdarrayOrTensor) -> NdarrayOrTensor: """Builds and applies a mask on the spatial dimensions. Args: - k (np.ndarray): k-space version of the image. + k: k-space version of the image. Returns: masked version of the k-space image. """ @@ -1345,11 +1480,73 @@ def _apply_mask(self, k: torch.Tensor) -> torch.Tensor: # add channel dimension into mask mask = np.repeat(mask[None], k.shape[0], axis=0) + if isinstance(k, torch.Tensor): + mask, *_ = convert_data_type(mask, torch.Tensor, device=k.device) + # apply binary mask - k_masked = k * torch.tensor(mask, device=k.device) + k_masked: NdarrayOrTensor + k_masked = k * mask return k_masked +class RandGibbsNoise(RandomizableTransform): + """ + Naturalistic image augmentation via Gibbs artifacts. The transform + randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts + are one of the common type of type artifacts appearing in MRI scans. + + The transform is applied to all the channels in the data. + + For general information on Gibbs artifacts, please refer to: + https://pubs.rsna.org/doi/full/10.1148/rg.313105115 + https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949 + + + Args: + prob (float): probability of applying the transform. + alpha (Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes + values in the interval [0,1] with alpha = 0 acting as the identity mapping. + If a length-2 list is given as [a,b] then the value of alpha will be + sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1. + """ + + backend = GibbsNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") + def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None: + if len(alpha) != 2: + raise ValueError("alpha length must be 2.") + if alpha[1] > 1 or alpha[0] < 0: + raise ValueError("alpha must take values in the interval [0, 1]") + if alpha[0] > alpha[1]: + raise ValueError("When alpha = [a,b] we need a < b.") + + self.alpha = alpha + self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() + + RandomizableTransform.__init__(self, prob=prob) + + def randomize(self, data: Any) -> None: + """ + (1) Set random variable to apply the transform. + (2) Get alpha from uniform distribution. + """ + super().randomize(None) + if not self._do_transform: + return None + self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True): + if randomize: + # randomize application and possibly alpha + self.randomize(None) + + if not self._do_transform: + return img + + return GibbsNoise(self.sampled_alpha)(img) + + class KSpaceSpikeNoise(Transform, Fourier): """ Apply localized spikes in `k`-space at the given locations and intensities. @@ -1377,8 +1574,6 @@ class KSpaceSpikeNoise(Transform, Fourier): receive a sequence of intensities. This value should be tested as it is data-dependent. The default values are the 2.5 the mean of the log-intensity for each channel. - as_tensor_output: if ``True`` return torch.Tensor, else return np.array. - Default: ``True``. Example: When working with 4D data, ``KSpaceSpikeNoise(loc = ((3,60,64,32), (64,60,32)), k_intensity = (13,14))`` @@ -1387,6 +1582,9 @@ class KSpaceSpikeNoise(Transform, Fourier): with `log-intensity = 14`. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, loc: Union[Tuple, Sequence[Tuple]], @@ -1395,7 +1593,6 @@ def __init__( ): self.loc = ensure_tuple(loc) - self.as_tensor_output = as_tensor_output self.k_intensity = k_intensity # assert one-to-one relationship between factors and locations @@ -1409,7 +1606,7 @@ def __init__( if isinstance(self.loc[0], Sequence) and k_intensity is not None and not isinstance(self.k_intensity, Sequence): raise ValueError("There must be one intensity_factor value for each tuple of indices in loc.") - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: image with dimensions (C, H, W) or (C, H, W, D) @@ -1421,22 +1618,21 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, raise RuntimeError("Image needs a channel direction.") if isinstance(self.loc[0], int) and len(img.shape) == 4 and len(self.loc) == 2: raise RuntimeError("Input images of dimension 4 need location tuple to be length 3 or 4") - if isinstance(self.loc[0], Sequence) and len(img.shape) == 4 and min(map(lambda x: len(x), self.loc)) == 2: + if isinstance(self.loc[0], Sequence) and len(img.shape) == 4 and min(map(len, self.loc)) == 2: raise RuntimeError("Input images of dimension 4 need location tuple to be length 3 or 4") n_dims = len(img.shape[1:]) - if isinstance(img, np.ndarray): - img = torch.Tensor(img) # FT k = self.shift_fourier(img, n_dims) - log_abs = torch.log(torch.absolute(k) + 1e-10) - phase = torch.angle(k) + lib = np if isinstance(k, np.ndarray) else torch + log_abs = lib.log(lib.abs(k) + 1e-10) # type: ignore + phase = lib.angle(k) # type: ignore k_intensity = self.k_intensity # default log intensity if k_intensity is None: - k_intensity = tuple(torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5) + k_intensity = tuple(lib.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5) # type: ignore # highlight if isinstance(self.loc[0], Sequence): @@ -1445,10 +1641,10 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, else: self._set_spike(log_abs, self.loc, k_intensity) # map back - k = torch.exp(log_abs) * torch.exp(1j * phase) - img = self.inv_shift_fourier(k, n_dims) + k = lib.exp(log_abs) * lib.exp(1j * phase) # type: ignore + img, *_ = convert_to_dst_type(self.inv_shift_fourier(k, n_dims), dst=img) - return img if self.as_tensor_output else img.cpu().detach().numpy() + return img def _check_indices(self, img) -> None: """Helper method to check consistency of self.loc and input image. @@ -1468,7 +1664,7 @@ def _check_indices(self, img) -> None: f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image." ) - def _set_spike(self, k: torch.Tensor, idx: Tuple, val: Union[Sequence[float], float]): + def _set_spike(self, k: NdarrayOrTensor, idx: Tuple, val: Union[Sequence[float], float]): """ Helper function to introduce a given intensity at given location. @@ -1504,18 +1700,14 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): Args: prob: probability of applying the transform, either on all channels at once, or channel-wise if ``channel_wise = True``. - intensity_range: pass a tuple - (a, b) to sample the log-intensity from the interval (a, b) - uniformly for all channels. Or pass sequence of intevals + intensity_range: pass a tuple (a, b) to sample the log-intensity from the interval (a, b) + uniformly for all channels. Or pass sequence of intervals ((a0, b0), (a1, b1), ...) to sample for each respective channel. - In the second case, the number of 2-tuples must match the number of - channels. + In the second case, the number of 2-tuples must match the number of channels. Default ranges is `(0.95x, 1.10x)` where `x` is the mean log-intensity for each channel. channel_wise: treat each channel independently. True by default. - as_tensor_output: if True return torch.Tensor, else - return np.array. default: True. Example: To apply `k`-space spikes randomly with probability 0.5, and @@ -1524,17 +1716,19 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): ``RandKSpaceSpikeNoise(prob=0.5, intensity_range=(11, 12), channel_wise=True)`` """ + backend = KSpaceSpikeNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, prob: float = 0.1, intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, - channel_wise=True, + channel_wise: bool = True, as_tensor_output: bool = True, ): self.intensity_range = intensity_range self.channel_wise = channel_wise - self.as_tensor_output = as_tensor_output self.sampled_k_intensity: List = [] self.sampled_locs: List[Tuple] = [] @@ -1543,13 +1737,14 @@ def __init__( super().__init__(prob) - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True): """ Apply transform to `img`. Assumes data is in channel-first form. Args: img: image with dimensions (C, H, W) or (C, H, W, D) """ + if ( self.intensity_range is not None and isinstance(self.intensity_range[0], Sequence) @@ -1562,20 +1757,16 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, self.sampled_k_intensity = [] self.sampled_locs = [] - if not isinstance(img, torch.Tensor): - img = torch.Tensor(img) - - intensity_range = self._make_sequence(img) - self._randomize(img, intensity_range) + if randomize: + intensity_range = self._make_sequence(img) + self.randomize(img, intensity_range) - # build/appy transform only if there are spike locations - if self.sampled_locs: - transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output) - return transform(img) + if not self._do_transform: + return img - return img if self.as_tensor_output else img.detach().numpy() + return KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity)(img) - def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float]]) -> None: + def randomize(self, img: NdarrayOrTensor, intensity_range: Sequence[Sequence[float]]) -> None: # type: ignore """ Helper method to sample both the location and intensity of the spikes. When not working channel wise (channel_wise=False) it use the random @@ -1585,25 +1776,24 @@ def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float When working channel wise, the method randomly samples a location and intensity for each channel depending on ``self._do_transform``. """ - # randomizing per channel + super().randomize(None) + if not self._do_transform: + return None if self.channel_wise: + # randomizing per channel for i, chan in enumerate(img): - super().randomize(None) - if self._do_transform: - self.sampled_locs.append((i,) + tuple(self.R.randint(0, k) for k in chan.shape)) - self.sampled_k_intensity.append(self.R.uniform(intensity_range[i][0], intensity_range[i][1])) - # working with all channels together + self.sampled_locs.append((i,) + tuple(self.R.randint(0, k) for k in chan.shape)) + self.sampled_k_intensity.append(self.R.uniform(intensity_range[i][0], intensity_range[i][1])) else: - super().randomize(None) - if self._do_transform: - spatial = tuple(self.R.randint(0, k) for k in img.shape[1:]) - self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])] - if isinstance(intensity_range[0], Sequence): - self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range] - else: - self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) + # working with all channels together + spatial = tuple(self.R.randint(0, k) for k in img.shape[1:]) + self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])] + if isinstance(intensity_range[0], Sequence): + self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range] + else: + self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) - def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: + def _make_sequence(self, x: NdarrayOrTensor) -> Sequence[Sequence[float]]: """ Formats the sequence of intensities ranges to Sequence[Sequence[float]]. """ @@ -1615,7 +1805,7 @@ def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: return (ensure_tuple(self.intensity_range),) * x.shape[0] return ensure_tuple(self.intensity_range) - def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]: + def _set_default_range(self, img: NdarrayOrTensor) -> Sequence[Sequence[float]]: """ Sets default intensity ranges to be sampled. @@ -1625,18 +1815,17 @@ def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]: n_dims = len(img.shape[1:]) k = self.shift_fourier(img, n_dims) - log_abs = torch.log(torch.absolute(k) + 1e-10) - shifted_means = torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5 + mod = torch if isinstance(k, torch.Tensor) else np + log_abs = mod.log(mod.absolute(k) + 1e-10) # type: ignore + shifted_means = mod.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5 # type: ignore return tuple((i * 0.95, i * 1.1) for i in shifted_means) -class RandCoarseDropout(RandomizableTransform): +class RandCoarseTransform(RandomizableTransform): """ - Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value. - Or keep the rectangular regions and fill in the other areas with specified value. - Refer to papers: https://arxiv.org/abs/1708.04552, https://arxiv.org/pdf/1604.07379 - And other implementation: https://albumentations.ai/docs/api_reference/augmentations/transforms/ - #albumentations.augmentations.transforms.CoarseDropout. + Randomly select coarse regions in the image, then execute transform operations for the regions. + It's the base class of all kinds of region transforms. + Refer to papers: https://arxiv.org/abs/1708.04552 Args: holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to @@ -1646,12 +1835,6 @@ class RandCoarseDropout(RandomizableTransform): if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. - dropout_holes: if `True`, dropout the regions of holes and fill value, if `False`, keep the holes and - dropout the outside and fill value. default to `True`. - fill_value: target value to fill the dropout regions, if providing a number, will use it as constant - value to fill all the regions. if providing a tuple for the `min` and `max`, will randomly select - value for every pixel / voxel from the range `[min, max)`. if None, will compute the `min` and `max` - value of input image then randomly select value to fill, default to None. max_holes: if not None, define the maximum number to randomly select the expected number of regions. max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. if some components of the `max_spatial_size` are non-positive values, the transform will use the @@ -1661,12 +1844,12 @@ class RandCoarseDropout(RandomizableTransform): """ + backend = [TransformBackends.NUMPY] + def __init__( self, holes: int, spatial_size: Union[Sequence[int], int], - dropout_holes: bool = True, - fill_value: Optional[Union[Tuple[float, float], float]] = None, max_holes: Optional[int] = None, max_spatial_size: Optional[Union[Sequence[int], int]] = None, prob: float = 0.1, @@ -1676,17 +1859,14 @@ def __init__( raise ValueError("number of holes must be greater than 0.") self.holes = holes self.spatial_size = spatial_size - self.dropout_holes = dropout_holes - if isinstance(fill_value, (tuple, list)): - if len(fill_value) != 2: - raise ValueError("fill value should contain 2 numbers if providing the `min` and `max`.") - self.fill_value = fill_value self.max_holes = max_holes self.max_spatial_size = max_spatial_size self.hole_coords: List = [] def randomize(self, img_size: Sequence[int]) -> None: super().randomize(None) + if not self._do_transform: + return None size = fall_back_tuple(self.spatial_size, img_size) self.hole_coords = [] # clear previously computed coords num_holes = self.holes if self.max_holes is None else self.R.randint(self.holes, self.max_holes + 1) @@ -1697,28 +1877,143 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, size) self.hole_coords.append((slice(None),) + get_random_patch(img_size, valid_size, self.R)) - def __call__(self, img: np.ndarray): - self.randomize(img.shape[1:]) - ret = img - if self._do_transform: - fill_value = (img.min(), img.max()) if self.fill_value is None else self.fill_value - - if self.dropout_holes: - for h in self.hole_coords: - if isinstance(fill_value, (tuple, list)): - ret[h] = self.R.uniform(fill_value[0], fill_value[1], size=img[h].shape) - else: - ret[h] = fill_value - else: + @abstractmethod + def _transform_holes(self, img: np.ndarray) -> np.ndarray: + """ + Transform the randomly selected `self.hole_coords` in input images. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + if randomize: + self.randomize(img.shape[1:]) + + if not self._do_transform: + return img + + img_np: np.ndarray + img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + out = self._transform_holes(img=img_np) + ret, *_ = convert_to_dst_type(src=out, dst=img) + return ret + + +class RandCoarseDropout(RandCoarseTransform): + """ + Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value. + Or keep the rectangular regions and fill in the other areas with specified value. + Refer to papers: https://arxiv.org/abs/1708.04552, https://arxiv.org/pdf/1604.07379 + And other implementation: https://albumentations.ai/docs/api_reference/augmentations/transforms/ + #albumentations.augmentations.transforms.CoarseDropout. + + Args: + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + dropout_holes: if `True`, dropout the regions of holes and fill value, if `False`, keep the holes and + dropout the outside and fill value. default to `True`. + fill_value: target value to fill the dropout regions, if providing a number, will use it as constant + value to fill all the regions. if providing a tuple for the `min` and `max`, will randomly select + value for every pixel / voxel from the range `[min, max)`. if None, will compute the `min` and `max` + value of input image then randomly select value to fill, default to None. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + + """ + + def __init__( + self, + holes: int, + spatial_size: Union[Sequence[int], int], + dropout_holes: bool = True, + fill_value: Optional[Union[Tuple[float, float], float]] = None, + max_holes: Optional[int] = None, + max_spatial_size: Optional[Union[Sequence[int], int]] = None, + prob: float = 0.1, + ) -> None: + super().__init__( + holes=holes, spatial_size=spatial_size, max_holes=max_holes, max_spatial_size=max_spatial_size, prob=prob + ) + self.dropout_holes = dropout_holes + if isinstance(fill_value, (tuple, list)): + if len(fill_value) != 2: + raise ValueError("fill value should contain 2 numbers if providing the `min` and `max`.") + self.fill_value = fill_value + + def _transform_holes(self, img: np.ndarray): + """ + Fill the randomly selected `self.hole_coords` in input images. + Please note that we usually only use `self.R` in `randomize()` method, here is a special case. + + """ + fill_value = (img.min(), img.max()) if self.fill_value is None else self.fill_value + + if self.dropout_holes: + for h in self.hole_coords: if isinstance(fill_value, (tuple, list)): - ret = self.R.uniform(fill_value[0], fill_value[1], size=img.shape).astype(img.dtype) + img[h] = self.R.uniform(fill_value[0], fill_value[1], size=img[h].shape) else: - ret = np.full_like(img, fill_value) - for h in self.hole_coords: - ret[h] = img[h] + img[h] = fill_value + ret = img + else: + if isinstance(fill_value, (tuple, list)): + ret = self.R.uniform(fill_value[0], fill_value[1], size=img.shape).astype(img.dtype, copy=False) + else: + ret = np.full_like(img, fill_value) + for h in self.hole_coords: + ret[h] = img[h] return ret +class RandCoarseShuffle(RandCoarseTransform): + """ + Randomly select regions in the image, then shuffle the pixels within every region. + It shuffles every channel separately. + Refer to paper: + Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017). + https://arxiv.org/abs/1707.07103 + + Args: + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + + """ + + def _transform_holes(self, img: np.ndarray): + """ + Shuffle the content of randomly selected `self.hole_coords` in input images. + Please note that we usually only use `self.R` in `randomize()` method, here is a special case. + + """ + for h in self.hole_coords: + # shuffle every channel separately + for i, c in enumerate(img[h]): + patch_channel = c.flatten() + self.R.shuffle(patch_channel) + img[h][i] = patch_channel.reshape(c.shape) + return img + + class HistogramNormalize(Transform): """ Apply the histogram normalization to input image. @@ -1732,16 +2027,18 @@ class HistogramNormalize(Transform): mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`. only points at which `mask==True` are used for the equalization. can also provide the mask along with img at runtime. - dtype: data type of the output, default to `float32`. + dtype: data type of the output, if None, same as input image. default to `float32`. """ + backend = [TransformBackends.NUMPY] + def __init__( self, num_bins: int = 256, min: int = 0, max: int = 255, - mask: Optional[np.ndarray] = None, + mask: Optional[NdarrayOrTensor] = None, dtype: DtypeLike = np.float32, ) -> None: self.num_bins = num_bins @@ -1750,104 +2047,15 @@ def __init__( self.mask = mask self.dtype = dtype - def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray: - return equalize_hist( - img=img, - mask=mask if mask is not None else self.mask, - num_bins=self.num_bins, - min=self.min, - max=self.max, - dtype=self.dtype, - ) - - -class LocalPatchShuffling(RandomizableTransform): - """ - Takes a 3D image and based on input of the local patch size, shuffles the pixels of the local patch within it. - This process is repeated a for N number of times where every time a different random block is selected for local - pixel shuffling. - - Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017). - """ - - def __init__( - self, - prob: float = 1.0, - number_blocks: int = 1000, - blocksize_ratio: int = 10, - channel_wise: bool = True, - device: Optional[torch.device] = None, - image_only: bool = False, - ) -> None: - """ - Args: - prob: The chance of this transform occuring on the given volume. - number_blocks: Total number of time a random 3D block will be selected for local shuffling of pixels/voxels - contained in the block. - blocksize_ratio: This ratio can be used to estimate the local 3D block sizes that will be selected. - channel_wise: If True, treats each channel of the image separately. - device: device on which the tensor will be allocated. - image_only: if True return only the image volume, otherwise return (image, affine). - """ - RandomizableTransform.__init__(self, prob) - self.prob = prob - self.number_blocks = number_blocks - self.blocksize_ratio = blocksize_ratio - self.channel_wise = channel_wise - - def _local_patch_shuffle(self, img: Union[torch.Tensor, np.ndarray], number_blocks: int, blocksize_ratio: int): - im_shape = img.shape - img_copy = copy.deepcopy(img) - for _each_block in range(number_blocks): - - block_size_x = self.R.randint(1, im_shape[0] // blocksize_ratio) - block_size_y = self.R.randint(1, im_shape[1] // blocksize_ratio) - block_size_z = self.R.randint(1, im_shape[2] // blocksize_ratio) - - noise_x = self.R.randint(0, im_shape[0] - block_size_x) - noise_y = self.R.randint(0, im_shape[1] - block_size_y) - noise_z = self.R.randint(0, im_shape[2] - block_size_z) - - local_patch = img[ - noise_x : noise_x + block_size_x, - noise_y : noise_y + block_size_y, - noise_z : noise_z + block_size_z, - ] + def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None) -> NdarrayOrTensor: + img_np: np.ndarray + img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + mask = mask if mask is not None else self.mask + mask_np: Optional[np.ndarray] = None + if mask is not None: + mask_np, *_ = convert_data_type(mask, np.ndarray) # type: ignore - local_patch = local_patch.flatten() - self.R.shuffle(local_patch) - local_patch = local_patch.reshape((block_size_x, block_size_y, block_size_z)) + ret = equalize_hist(img=img_np, mask=mask_np, num_bins=self.num_bins, min=self.min, max=self.max) + out, *_ = convert_to_dst_type(src=ret, dst=img, dtype=self.dtype or img.dtype) - img_copy[ - noise_x : noise_x + block_size_x, noise_y : noise_y + block_size_y, noise_z : noise_z + block_size_z - ] = local_patch - - shuffled_image = img_copy - return shuffled_image - - def __call__( - self, - img: Union[np.ndarray, torch.Tensor], - # spatial_size: Optional[Union[Sequence[int], int]] = None, - # mode: Optional[Union[GridSampleMode, str]] = None, - # padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ): - """ - Args: - img: shape must be (num_channels, H, W[, D]), - - """ - - super().randomize(None) - if not self._do_transform: - return img - - if self.channel_wise: - # img = self._local_patch_shuffle(img=img) - for i, _d in enumerate(img): - img[i] = self._local_patch_shuffle( - img=img[i], blocksize_ratio=self.blocksize_ratio, number_blocks=self.number_blocks - ) - else: - raise AssertionError("If channel_wise is False, the image needs to be set to channel first") - return img + return out diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index bc53fb6b7b..48a259dbc5 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,13 +15,11 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from collections.abc import Iterable -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np -import torch -from monai.config import DtypeLike, KeysCollection, NdarrayTensor +from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.intensity.array import ( AdjustContrast, @@ -32,11 +30,21 @@ KSpaceSpikeNoise, MaskIntensity, NormalizeIntensity, + RandAdjustContrast, RandBiasField, RandCoarseDropout, + RandCoarseShuffle, RandGaussianNoise, + RandGaussianSharpen, + RandGaussianSmooth, + RandGibbsNoise, + RandHistogramShift, RandKSpaceSpikeNoise, RandRicianNoise, + RandScaleIntensity, + RandShiftIntensity, + RandStdShiftIntensity, + SavitzkyGolaySmooth, ScaleIntensity, ScaleIntensityRange, ScaleIntensityRangePercentiles, @@ -44,9 +52,10 @@ StdShiftIntensity, ThresholdIntensity, ) -from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform +from monai.transforms.transform import MapTransform, RandomizableTransform from monai.transforms.utils import is_positive -from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size +from monai.utils import ensure_tuple, ensure_tuple_rep +from monai.utils.deprecate_utils import deprecated_arg __all__ = [ "RandGaussianNoised", @@ -65,6 +74,7 @@ "RandAdjustContrastd", "ScaleIntensityRangePercentilesd", "MaskIntensityd", + "SavitzkyGolaySmoothd", "GaussianSmoothd", "RandGaussianSmoothd", "GaussianSharpend", @@ -75,6 +85,7 @@ "RandKSpaceSpikeNoised", "RandHistogramShiftd", "RandCoarseDropoutd", + "RandCoarseShuffled", "HistogramNormalized", "RandGaussianNoiseD", "RandGaussianNoiseDict", @@ -106,6 +117,8 @@ "ScaleIntensityRangePercentilesDict", "MaskIntensityD", "MaskIntensityDict", + "SavitzkyGolaySmoothD", + "SavitzkyGolaySmoothDict", "GaussianSmoothD", "GaussianSmoothDict", "RandGaussianSmoothD", @@ -126,8 +139,12 @@ "RandRicianNoiseDict", "RandCoarseDropoutD", "RandCoarseDropoutDict", + "RandCoarseShuffleD", + "RandCoarseShuffleDict", "HistogramNormalizeD", "HistogramNormalizeDict", + "RandKSpaceSpikeNoiseD", + "RandKSpaceSpikeNoiseDict", ] @@ -143,6 +160,7 @@ class RandGaussianNoised(RandomizableTransform, MapTransform): prob: Probability to add Gaussian noise. mean: Mean or “centre” of the distribution. std: Standard deviation (spread) of distribution. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -152,34 +170,36 @@ def __init__( self, keys: KeysCollection, prob: float = 0.1, - mean: Union[Sequence[float], float] = 0.0, + mean: float = 0.0, std: float = 0.1, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.mean = ensure_tuple_rep(mean, len(self.keys)) - self.std = std - self._noise: List[np.ndarray] = [] + self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype) - def randomize(self, im_shape: Sequence[int]) -> None: - super().randomize(None) - self._noise.clear() - for m in self.mean: - self._noise.append(self.R.normal(m, self.R.uniform(0, self.std), size=im_shape)) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGaussianNoised": + super().set_random_state(seed, state) + self.rand_gaussian_noise.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - - image_shape = d[self.keys[0]].shape # image shape from the first data key - self.randomize(image_shape) - if len(self._noise) != len(self.keys): - raise RuntimeError("inconsistent noise items and keys.") + self.randomize(None) if not self._do_transform: return d - for key, noise in self.key_iterator(d, self._noise): - noise, *_ = convert_to_dst_type(noise, d[key]) - d[key] = d[key] + noise + + # all the keys share the same random noise + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.rand_gaussian_noise.randomize(d[first_key]) # type: ignore + for key in self.key_iterator(d): + d[key] = self.rand_gaussian_noise(img=d[key], randomize=False) return d @@ -192,9 +212,7 @@ class RandRicianNoised(RandomizableTransform, MapTransform): Args: keys: Keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - global_prob: Probability to add Rician noise to the dictionary. - prob: Probability to add Rician noise to each item in the dictionary, - once asserted that noise will be added to the dictionary at all. + prob: Probability to add Rician noise to the dictionary. mean: Mean or "centre" of the Gaussian distributions sampled to make up the Rician noise. std: Standard deviation (spread) of the Gaussian distributions sampled @@ -205,38 +223,52 @@ class RandRicianNoised(RandomizableTransform, MapTransform): histogram. sample_std: If True, sample the spread of the Gaussian distributions uniformly from 0 to std. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: Don't raise exception if key is missing. """ backend = RandRicianNoise.backend + @deprecated_arg("global_prob", since="0.7") def __init__( self, keys: KeysCollection, - global_prob: float = 0.1, - prob: float = 1.0, + prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: Union[Sequence[float], float] = 1.0, channel_wise: bool = False, relative: bool = False, sample_std: bool = True, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - RandomizableTransform.__init__(self, global_prob) - self.rand_rician_noise = RandRicianNoise(prob, mean, std, channel_wise, relative, sample_std) + RandomizableTransform.__init__(self, prob) + self.rand_rician_noise = RandRicianNoise( + prob=1.0, + mean=mean, + std=std, + channel_wise=channel_wise, + relative=relative, + sample_std=sample_std, + dtype=dtype, + ) - def set_random_state(self, seed=None, state=None): + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandRicianNoised": super().set_random_state(seed, state) self.rand_rician_noise.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - super().randomize(None) + self.randomize(None) if not self._do_transform: return d + for key in self.key_iterator(d): - d[key] = self.rand_rician_noise(d[key]) + d[key] = self.rand_rician_noise(d[key], randomize=True) return d @@ -302,7 +334,7 @@ class RandShiftIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandShiftIntensity`. """ - backend = ShiftIntensity.backend + backend = RandShiftIntensity.backend def __init__( self, @@ -341,36 +373,34 @@ def __init__( MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - if isinstance(offsets, (int, float)): - self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) - else: - if len(offsets) != 2: - raise ValueError("offsets should be a number or pair of numbers.") - self.offsets = (min(offsets), max(offsets)) - self._offset = self.offsets[0] self.factor_key = ensure_tuple_rep(factor_key, len(self.keys)) self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) if len(self.keys) != len(self.meta_keys): raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.shifter = ShiftIntensity(self._offset) + self.shifter = RandShiftIntensity(offsets=offsets, prob=1.0) - def randomize(self, data: Optional[Any] = None) -> None: - self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) - super().randomize(None) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandShiftIntensityd": + super().set_random_state(seed, state) + self.shifter.set_random_state(seed, state) + return self def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d + + # all the keys share the same random shift factor + self.shifter.randomize(None) for key, factor_key, meta_key, meta_key_postfix in self.key_iterator( d, self.factor_key, self.meta_keys, self.meta_key_postfix ): meta_key = meta_key or f"{key}_{meta_key_postfix}" factor: Optional[float] = d[meta_key].get(factor_key) if meta_key in d else None - offset = self._offset if factor is None else self._offset * factor - d[key] = self.shifter(d[key], offset=offset) + d[key] = self.shifter(d[key], factor=factor, randomize=False) return d @@ -398,7 +428,7 @@ def __init__( nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. Please ensure that the first dimension represents the channel of the image if True. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) @@ -416,7 +446,7 @@ class RandStdShiftIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandStdShiftIntensity`. """ - backend = StdShiftIntensity.backend + backend = RandStdShiftIntensity.backend def __init__( self, @@ -437,35 +467,32 @@ def __init__( prob: probability of std shift. nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.shifter = RandStdShiftIntensity( + factors=factors, nonzero=nonzero, channel_wise=channel_wise, dtype=dtype, prob=1.0 + ) - if isinstance(factors, (int, float)): - self.factors = (min(-factors, factors), max(-factors, factors)) - elif len(factors) != 2: - raise ValueError("factors should be a number or pair of numbers.") - else: - self.factors = (min(factors), max(factors)) - self.factor = self.factors[0] - self.nonzero = nonzero - self.channel_wise = channel_wise - self.dtype = dtype - - def randomize(self, data: Optional[Any] = None) -> None: - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) - super().randomize(None) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandStdShiftIntensityd": + super().set_random_state(seed, state) + self.shifter.set_random_state(seed, state) + return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d - shifter = StdShiftIntensity(self.factor, self.nonzero, self.channel_wise, self.dtype) + + # all the keys share the same random shift factor + self.shifter.randomize(None) for key in self.key_iterator(d): - d[key] = shifter(d[key]) + d[key] = self.shifter(d[key], randomize=False) return d @@ -484,6 +511,8 @@ def __init__( minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, factor: Optional[float] = None, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -493,12 +522,15 @@ def __init__( minv: minimum value of output data. maxv: maximum value of output data. factor: factor scale by ``v = v * (1 + factor)``. In order to use - this parameter, please set `minv` and `maxv` into None. + this parameter, please set both `minv` and `maxv` into None. + channel_wise: if True, scale on each channel separately. Please ensure + that the first dimension represents the channel of the image if True. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.scaler = ScaleIntensity(minv, maxv, factor) + self.scaler = ScaleIntensity(minv, maxv, factor, channel_wise, dtype) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -512,13 +544,14 @@ class RandScaleIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`. """ - backend = ScaleIntensity.backend + backend = RandScaleIntensity.backend def __init__( self, keys: KeysCollection, factors: Union[Tuple[float, float], float], prob: float = 0.1, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -529,32 +562,31 @@ def __init__( if single number, factor value is picked from (-factors, factors). prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0) - if isinstance(factors, (int, float)): - self.factors = (min(-factors, factors), max(-factors, factors)) - elif len(factors) != 2: - raise ValueError("factors should be a number or pair of numbers.") - else: - self.factors = (min(factors), max(factors)) - self.factor = self.factors[0] - - def randomize(self, data: Optional[Any] = None) -> None: - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) - super().randomize(None) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandScaleIntensityd": + super().set_random_state(seed, state) + self.scaler.set_random_state(seed, state) + return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d - scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor) + + # all the keys share the same random scale factor + self.scaler.randomize(None) for key in self.key_iterator(d): - d[key] = scaler(d[key]) + d[key] = self.scaler(d[key], randomize=False) return d @@ -563,13 +595,15 @@ class RandBiasFieldd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandBiasField`. """ + backend = RandBiasField.backend + def __init__( self, keys: KeysCollection, degree: int = 3, coeff_range: Tuple[float, float] = (0.0, 0.1), dtype: DtypeLike = np.float32, - prob: float = 1.0, + prob: float = 0.1, allow_missing_keys: bool = False, ) -> None: """ @@ -579,7 +613,7 @@ def __init__( degree: degree of freedom of the polynomials. The value should be no less than 1. Defaults to 3. coeff_range: range of the random coefficients. Defaults to (0.0, 0.1). - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. prob: probability to do random bias field. allow_missing_keys: don't raise exception if key is missing. @@ -587,18 +621,29 @@ def __init__( MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.rand_bias_field = RandBiasField(degree, coeff_range, dtype, prob) + self.rand_bias_field = RandBiasField(degree=degree, coeff_range=coeff_range, dtype=dtype, prob=1.0) - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandBiasFieldd": + super().set_random_state(seed, state) + self.rand_bias_field.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d + + # all the keys share the same random bias factor + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.rand_bias_field.randomize(img_size=d[first_key].shape[1:]) # type: ignore for key in self.key_iterator(d): - d[key] = self.rand_bias_field(d[key]) + d[key] = self.rand_bias_field(d[key], randomize=False) return d @@ -614,9 +659,9 @@ class NormalizeIntensityd(MapTransform): subtrahend: the amount to subtract by (usually the mean) divisor: the amount to divide by (usually the standard deviation) nonzero: whether only normalize non-zero values. - channel_wise: if using calculated mean and std, calculate on each channel separately - or calculate on the entire image directly. - dtype: output data type, defaults to float32. + channel_wise: if True, calculate on each channel separately, otherwise, calculate on + the entire image directly. default to False. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -655,6 +700,8 @@ class ThresholdIntensityd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = ThresholdIntensity.backend + def __init__( self, keys: KeysCollection, @@ -666,7 +713,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.filter = ThresholdIntensity(threshold, above, cval) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.filter(d[key]) @@ -685,23 +732,27 @@ class ScaleIntensityRanged(MapTransform): b_min: intensity target range min. b_max: intensity target range max. clip: whether to perform clip after scaling. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ + backend = ScaleIntensityRange.backend + def __init__( self, keys: KeysCollection, a_min: float, a_max: float, - b_min: float, - b_max: float, + b_min: Optional[float] = None, + b_max: Optional[float] = None, clip: bool = False, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip) + self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip, dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -722,11 +773,13 @@ class AdjustContrastd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = AdjustContrast.backend + def __init__(self, keys: KeysCollection, gamma: float, allow_missing_keys: bool = False) -> None: super().__init__(keys, allow_missing_keys) self.adjuster = AdjustContrast(gamma) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.adjuster(d[key]) @@ -749,6 +802,8 @@ class RandAdjustContrastd(RandomizableTransform, MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = RandAdjustContrast.backend + def __init__( self, keys: KeysCollection, @@ -758,34 +813,25 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.adjuster = RandAdjustContrast(gamma=gamma, prob=1.0) - if isinstance(gamma, (int, float)): - if gamma <= 0.5: - raise ValueError( - "if gamma is single number, must greater than 0.5 and value is picked from (0.5, gamma)" - ) - self.gamma = (0.5, gamma) - elif len(gamma) != 2: - raise ValueError("gamma should be a number or pair of numbers.") - else: - self.gamma = (min(gamma), max(gamma)) - - self.gamma_value: Optional[float] = None - - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self.gamma_value = self.R.uniform(low=self.gamma[0], high=self.gamma[1]) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandAdjustContrastd": + super().set_random_state(seed, state) + self.adjuster.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() - if self.gamma_value is None: - raise RuntimeError("gamma_value is not set.") + self.randomize(None) if not self._do_transform: return d - adjuster = AdjustContrast(self.gamma_value) + + # all the keys share the same random gamma value + self.adjuster.randomize(None) for key in self.key_iterator(d): - d[key] = adjuster(d[key]) + d[key] = self.adjuster(d[key], randomize=False) return d @@ -802,24 +848,31 @@ class ScaleIntensityRangePercentilesd(MapTransform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max] + channel_wise: if True, compute intensity percentile and normalize every channel separately. + default to False. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ + backend = ScaleIntensityRangePercentiles.backend + def __init__( self, keys: KeysCollection, lower: float, upper: float, - b_min: float, - b_max: float, + b_min: Optional[float], + b_max: Optional[float], clip: bool = False, relative: bool = False, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative) + self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -847,10 +900,12 @@ class MaskIntensityd(MapTransform): """ + backend = MaskIntensity.backend + def __init__( self, keys: KeysCollection, - mask_data: Optional[np.ndarray] = None, + mask_data: Optional[NdarrayOrTensor] = None, mask_key: Optional[str] = None, select_fn: Callable = is_positive, allow_missing_keys: bool = False, @@ -859,13 +914,50 @@ def __init__( self.converter = MaskIntensity(mask_data=mask_data, select_fn=select_fn) self.mask_key = mask_key if mask_data is None else None - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key], d[self.mask_key]) if self.mask_key is not None else self.converter(d[key]) return d +class SavitzkyGolaySmoothd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.SavitzkyGolaySmooth`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + window_length: length of the filter window, must be a positive odd integer. + order: order of the polynomial to fit to each window, must be less than ``window_length``. + axis: optional axis along which to apply the filter kernel. Default 1 (first spatial dimension). + mode: optional padding mode, passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = SavitzkyGolaySmooth.backend + + def __init__( + self, + keys: KeysCollection, + window_length: int, + order: int, + axis: int = 1, + mode: str = "zeros", + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.converter = SavitzkyGolaySmooth(window_length=window_length, order=order, axis=axis, mode=mode) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + class GaussianSmoothd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.GaussianSmooth`. @@ -882,6 +974,8 @@ class GaussianSmoothd(MapTransform): """ + backend = GaussianSmooth.backend + def __init__( self, keys: KeysCollection, @@ -892,7 +986,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = GaussianSmooth(sigma, approx=approx) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -916,6 +1010,8 @@ class RandGaussianSmoothd(RandomizableTransform, MapTransform): """ + backend = RandGaussianSmooth.backend + def __init__( self, keys: KeysCollection, @@ -928,25 +1024,27 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.sigma_x, self.sigma_y, self.sigma_z = sigma_x, sigma_y, sigma_z - self.approx = approx - - self.x, self.y, self.z = self.sigma_x[0], self.sigma_y[0], self.sigma_z[0] + self.rand_smooth = RandGaussianSmooth( + sigma_x=sigma_x, sigma_y=sigma_y, sigma_z=sigma_z, approx=approx, prob=1.0 + ) - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1]) - self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1]) - self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGaussianSmoothd": + super().set_random_state(seed, state) + self.rand_smooth.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d + + # all the keys share the same random sigma + self.rand_smooth.randomize(None) for key in self.key_iterator(d): - sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=d[key].ndim - 1) - d[key] = GaussianSmooth(sigma=sigma, approx=self.approx)(d[key]) + d[key] = self.rand_smooth(d[key], randomize=False) return d @@ -970,6 +1068,8 @@ class GaussianSharpend(MapTransform): """ + backend = GaussianSharpen.backend + def __init__( self, keys: KeysCollection, @@ -982,7 +1082,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = GaussianSharpen(sigma1, sigma2, alpha, approx=approx) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1013,6 +1113,8 @@ class RandGaussianSharpend(RandomizableTransform, MapTransform): """ + backend = RandGaussianSharpen.backend + def __init__( self, keys: KeysCollection, @@ -1029,37 +1131,35 @@ def __init__( ): MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.sigma1_x = sigma1_x - self.sigma1_y = sigma1_y - self.sigma1_z = sigma1_z - self.sigma2_x = sigma2_x - self.sigma2_y = sigma2_y - self.sigma2_z = sigma2_z - self.alpha = alpha - self.approx = approx - - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self.x1 = self.R.uniform(low=self.sigma1_x[0], high=self.sigma1_x[1]) - self.y1 = self.R.uniform(low=self.sigma1_y[0], high=self.sigma1_y[1]) - self.z1 = self.R.uniform(low=self.sigma1_z[0], high=self.sigma1_z[1]) - sigma2_x = (self.sigma2_x, self.x1) if not isinstance(self.sigma2_x, Iterable) else self.sigma2_x - sigma2_y = (self.sigma2_y, self.y1) if not isinstance(self.sigma2_y, Iterable) else self.sigma2_y - sigma2_z = (self.sigma2_z, self.z1) if not isinstance(self.sigma2_z, Iterable) else self.sigma2_z - self.x2 = self.R.uniform(low=sigma2_x[0], high=sigma2_x[1]) - self.y2 = self.R.uniform(low=sigma2_y[0], high=sigma2_y[1]) - self.z2 = self.R.uniform(low=sigma2_z[0], high=sigma2_z[1]) - self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1]) + self.rand_sharpen = RandGaussianSharpen( + sigma1_x=sigma1_x, + sigma1_y=sigma1_y, + sigma1_z=sigma1_z, + sigma2_x=sigma2_x, + sigma2_y=sigma2_y, + sigma2_z=sigma2_z, + alpha=alpha, + approx=approx, + prob=1.0, + ) - def __call__(self, data): + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGaussianSharpend": + super().set_random_state(seed, state) + self.rand_sharpen.set_random_state(seed, state) + return self + + def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d + + # all the keys share the same random sigma1, sigma2, etc. + self.rand_sharpen.randomize(None) for key in self.key_iterator(d): - sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=d[key].ndim - 1) - sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=d[key].ndim - 1) - d[key] = GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(d[key]) + d[key] = self.rand_sharpen(d[key], randomize=False) return d @@ -1078,6 +1178,8 @@ class RandHistogramShiftd(RandomizableTransform, MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = RandHistogramShift.backend + def __init__( self, keys: KeysCollection, @@ -1087,38 +1189,25 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - if isinstance(num_control_points, int): - if num_control_points <= 2: - raise ValueError("num_control_points should be greater than or equal to 3") - self.num_control_points = (num_control_points, num_control_points) - else: - if len(num_control_points) != 2: - raise ValueError("num_control points should be a number or a pair of numbers") - if min(num_control_points) <= 2: - raise ValueError("num_control_points should be greater than or equal to 3") - self.num_control_points = (min(num_control_points), max(num_control_points)) - - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - num_control_point = self.R.randint(self.num_control_points[0], self.num_control_points[1] + 1) - self.reference_control_points = np.linspace(0, 1, num_control_point) - self.floating_control_points = np.copy(self.reference_control_points) - for i in range(1, num_control_point - 1): - self.floating_control_points[i] = self.R.uniform( - self.floating_control_points[i - 1], self.floating_control_points[i + 1] - ) - - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + self.shifter = RandHistogramShift(num_control_points=num_control_points, prob=1.0) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandHistogramShiftd": + super().set_random_state(seed, state) + self.shifter.set_random_state(seed, state) + return self + + def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + self.randomize(None) if not self._do_transform: return d + + # all the keys share the same random shift params + self.shifter.randomize(None) for key in self.key_iterator(d): - img_min, img_max = d[key].min(), d[key].max() - reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min - floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min - dtype = d[key].dtype - d[key] = np.interp(d[key], reference_control_points_scaled, floating_control_points_scaled).astype(dtype) + d[key] = self.shifter(d[key], randomize=False) return d @@ -1144,56 +1233,43 @@ class RandGibbsNoised(RandomizableTransform, MapTransform): values in the interval [0,1] with alpha = 0 acting as the identity mapping. If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. allow_missing_keys: do not raise exception if key is missing. """ + backend = RandGibbsNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), - as_tensor_output: bool = True, allow_missing_keys: bool = False, + as_tensor_output: bool = True, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob=prob) - self.alpha = alpha - self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() - self.as_tensor_output = as_tensor_output + self.rand_gibbs_noise = RandGibbsNoise(alpha=alpha, prob=1.0) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGibbsNoised": + super().set_random_state(seed, state) + self.rand_gibbs_noise.set_random_state(seed, state) + return self + def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self._randomize(None) - - for i, key in enumerate(self.key_iterator(d)): - if self._do_transform: - if i == 0: - transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output) - d[key] = transform(d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) - return d - - def _randomize(self, _: Any) -> None: - """ - (1) Set random variable to apply the transform. - (2) Get alpha from uniform distribution. - """ - super().randomize(None) - self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) + self.randomize(None) + if not self._do_transform: + return d - def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: - if isinstance(d, torch.Tensor): - d_numpy: np.ndarray = d.cpu().detach().numpy() - return d_numpy + # all the keys share the same random noise params + self.rand_gibbs_noise.randomize(None) + for key in self.key_iterator(d): + d[key] = self.rand_gibbs_noise(d[key], randomize=False) + return d class GibbsNoised(MapTransform): @@ -1212,20 +1288,20 @@ class GibbsNoised(MapTransform): you need to transform. alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. allow_missing_keys: do not raise exception if key is missing. """ + backend = GibbsNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( - self, keys: KeysCollection, alpha: float = 0.5, as_tensor_output: bool = True, allow_missing_keys: bool = False + self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False, as_tensor_output: bool = True ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - self.transform = GibbsNoise(alpha, as_tensor_output) + self.transform = GibbsNoise(alpha) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): @@ -1264,8 +1340,6 @@ class KSpaceSpikeNoised(MapTransform): receive a sequence of intensities. This value should be tested as it is data-dependent. The default values are the 2.5 the mean of the log-intensity for each channel. - as_tensor_output: if ``True`` return torch.Tensor, else return np.array. - Default: ``True``. allow_missing_keys: do not raise exception if key is missing. Example: @@ -1276,21 +1350,22 @@ class KSpaceSpikeNoised(MapTransform): with `log-intensity = 14`. """ + backend = KSpaceSpikeNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None, - as_tensor_output: bool = True, allow_missing_keys: bool = False, + as_tensor_output: bool = True, ) -> None: super().__init__(keys, allow_missing_keys) - self.transform = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + self.transform = KSpaceSpikeNoise(loc, k_intensity) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: """ Args: data: Expects image/label to have dimensions (C, H, W) or @@ -1320,110 +1395,66 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): Args: keys: "image", "label", or ["image", "label"] depending on which data you need to transform. - global_prob: probability of applying transform to the dictionary. prob: probability to add spike artifact to each item in the dictionary provided it is realized that the noise will be applied to the dictionary. - intensity_ranges: Dictionary with intensity - ranges to sample for each key. Given a dictionary value of `(a, b)` the - transform will sample the log-intensity from the interval `(a, b)` uniformly for all - channels of the respective key. If a sequence of intevals `((a0, b0), (a1, b1), ...)` - is given, then the transform will sample from each interval for each - respective channel. In the second case, the number of 2-tuples must - match the number of channels. Default ranges is `(0.95x, 1.10x)` - where `x` is the mean log-intensity for each channel. - channel_wise: treat each channel independently. True by - default. - common_sampling: If ``True`` same values for location and log-intensity - will be sampled for the image and label. - common_seed: Seed to be used in case ``common_sampling = True``. - as_tensor_output: if ``True`` return torch.Tensor, else return - np.array. Default: ``True``. + intensity_range: pass a tuple (a, b) to sample the log-intensity from the interval (a, b) + uniformly for all channels. Or pass sequence of intervals + ((a0, b0), (a1, b1), ...) to sample for each respective channel. + In the second case, the number of 2-tuples must match the number of channels. + Default ranges is `(0.95x, 1.10x)` where `x` is the mean + log-intensity for each channel. + channel_wise: treat each channel independently. True by default. allow_missing_keys: do not raise exception if key is missing. Example: To apply `k`-space spikes randomly on the image only, with probability 0.5, and log-intensity sampled from the interval [13, 15] for each channel independently, one uses - ``RandKSpaceSpikeNoised("image", prob=0.5, intensity_ranges={"image":(13,15)}, channel_wise=True)``. + ``RandKSpaceSpikeNoised("image", prob=0.5, intensity_ranges=(13, 15), channel_wise=True)``. """ + backend = RandKSpaceSpikeNoise.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") + @deprecated_arg(name="common_sampling", since="0.6") + @deprecated_arg(name="common_seed", since="0.6") + @deprecated_arg(name="global_prob", since="0.6") def __init__( self, keys: KeysCollection, global_prob: float = 1.0, prob: float = 0.1, - intensity_ranges: Optional[Mapping[Hashable, Sequence[Union[Sequence[float], float]]]] = None, + intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, channel_wise: bool = True, common_sampling: bool = False, common_seed: int = 42, - as_tensor_output: bool = True, allow_missing_keys: bool = False, + as_tensor_output: bool = True, ): - MapTransform.__init__(self, keys, allow_missing_keys) - RandomizableTransform.__init__(self, global_prob) - - self.common_sampling = common_sampling - self.common_seed = common_seed - self.as_tensor_output = as_tensor_output - # the spikes artifact is amplitude dependent so we instantiate one per key - self.transforms = {} - if isinstance(intensity_ranges, Mapping): - for k in self.keys: - self.transforms[k] = RandKSpaceSpikeNoise( - prob, intensity_ranges[k], channel_wise, self.as_tensor_output - ) - else: - for k in self.keys: - self.transforms[k] = RandKSpaceSpikeNoise(prob, None, channel_wise, self.as_tensor_output) - - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: - """ - Args: - data: Expects image/label to have dimensions (C, H, W) or - (C, H, W, D), where C is the channel. - """ - d = dict(data) - super().randomize(None) - - # In case the same spikes are desired for both image and label. - if self.common_sampling: - for k in self.keys: - self.transforms[k].set_random_state(self.common_seed) - - for key, t in self.key_iterator(d, self.transforms): - if self._do_transform: - d[key] = self.transforms[t](d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) - return d - - def set_rand_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> None: - """ - Set the random state locally to control the randomness. - User should use this method instead of ``set_random_state``. + RandomizableTransform.__init__(self, prob=prob) + self.rand_noise = RandKSpaceSpikeNoise(prob=1.0, intensity_range=intensity_range, channel_wise=channel_wise) - Args: - seed: set the random state with an integer seed. - state: set the random state with a `np.random.RandomState` object.""" + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandKSpaceSpikeNoised": + super().set_random_state(seed, state) + self.rand_noise.set_random_state(seed, state) + return self - self.set_random_state(seed, state) - for key in self.keys: - self.transforms[key].set_random_state(seed, state) + def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if not self._do_transform: + return d - def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: - if isinstance(d, torch.Tensor): - d_numpy: np.ndarray = d.cpu().detach().numpy() - return d_numpy + for key in self.key_iterator(d): + d[key] = self.rand_noise(d[key], randomize=True) + return d -class RandCoarseDropoutd(Randomizable, MapTransform): +class RandCoarseDropoutd(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseDropout`. Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions @@ -1455,6 +1486,8 @@ class RandCoarseDropoutd(Randomizable, MapTransform): """ + backend = RandCoarseDropout.backend + def __init__( self, keys: KeysCollection, @@ -1468,6 +1501,7 @@ def __init__( allow_missing_keys: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob=prob) self.dropper = RandCoarseDropout( holes=holes, spatial_size=spatial_size, @@ -1475,19 +1509,99 @@ def __init__( fill_value=fill_value, max_holes=max_holes, max_spatial_size=max_spatial_size, - prob=prob, + prob=1.0, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCoarseDropoutd": + super().set_random_state(seed, state) + self.dropper.set_random_state(seed, state) + return self + + def __call__(self, data): + d = dict(data) + self.randomize(None) + if not self._do_transform: + return d + + # expect all the specified keys have same spatial shape and share same random holes + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.dropper.randomize(d[first_key].shape[1:]) # type: ignore + for key in self.key_iterator(d): + d[key] = self.dropper(img=d[key], randomize=False) + + return d + + +class RandCoarseShuffled(RandomizableTransform, MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseShuffle`. + Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions + for every key, if want to shuffle different regions for every key, please use this transform separately. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = RandCoarseShuffle.backend + + def __init__( + self, + keys: KeysCollection, + holes: int, + spatial_size: Union[Sequence[int], int], + max_holes: Optional[int] = None, + max_spatial_size: Optional[Union[Sequence[int], int]] = None, + prob: float = 0.1, + allow_missing_keys: bool = False, + ): + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob=prob) + self.shuffle = RandCoarseShuffle( + holes=holes, spatial_size=spatial_size, max_holes=max_holes, max_spatial_size=max_spatial_size, prob=1.0 ) - def randomize(self, img_size: Sequence[int]) -> None: - self.dropper.randomize(img_size=img_size) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCoarseShuffled": + super().set_random_state(seed, state) + self.shuffle.set_random_state(seed, state) + return self def __call__(self, data): d = dict(data) - # expect all the specified keys have same spatial shape - self.randomize(d[self.keys[0]].shape[1:]) - if self.dropper._do_transform: - for key in self.key_iterator(d): - d[key] = self.dropper(img=d[key]) + self.randomize(None) + if not self._do_transform: + return d + + # expect all the specified keys have same spatial shape and share same random holes + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.shuffle.randomize(d[first_key].shape[1:]) # type: ignore + for key in self.key_iterator(d): + d[key] = self.shuffle(img=d[key], randomize=False) return d @@ -1507,18 +1621,20 @@ class HistogramNormalized(MapTransform): only points at which `mask==True` are used for the equalization. can also provide the mask by `mask_key` at runtime. mask_key: if mask is None, will try to get the mask with `mask_key`. - dtype: data type of the output, default to `float32`. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: do not raise exception if key is missing. """ + backend = HistogramNormalize.backend + def __init__( self, keys: KeysCollection, num_bins: int = 256, min: int = 0, max: int = 255, - mask: Optional[np.ndarray] = None, + mask: Optional[NdarrayOrTensor] = None, mask_key: Optional[str] = None, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, @@ -1527,7 +1643,7 @@ def __init__( self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, mask=mask, dtype=dtype) self.mask_key = mask_key if mask is None else None - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.transform(d[key], d[self.mask_key]) if self.mask_key is not None else self.transform(d[key]) @@ -1551,6 +1667,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda RandAdjustContrastD = RandAdjustContrastDict = RandAdjustContrastd ScaleIntensityRangePercentilesD = ScaleIntensityRangePercentilesDict = ScaleIntensityRangePercentilesd MaskIntensityD = MaskIntensityDict = MaskIntensityd +SavitzkyGolaySmoothD = SavitzkyGolaySmoothDict = SavitzkyGolaySmoothd GaussianSmoothD = GaussianSmoothDict = GaussianSmoothd RandGaussianSmoothD = RandGaussianSmoothDict = RandGaussianSmoothd GaussianSharpenD = GaussianSharpenDict = GaussianSharpend @@ -1562,3 +1679,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda RandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd HistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized +RandCoarseShuffleD = RandCoarseShuffleDict = RandCoarseShuffled diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 58f3526086..c8bfeeca05 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -8,18 +8,80 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Hashable, Optional, Tuple +import os +from typing import Hashable, Mapping, Optional, Tuple import torch -from monai.transforms.transform import RandomizableTransform, Transform -from monai.utils.enums import InverseKeys +from monai.transforms.transform import Transform +from monai.utils.enums import TraceKeys + +__all__ = ["TraceableTransform", "InvertibleTransform"] + + +class TraceableTransform(Transform): + """ + Maintains a stack of applied transforms. The stack is inserted as pairs of + `trace_key: list of transforms` to each data dictionary. + + The ``__call__`` method of this transform class must be implemented so + that the transformation information for each key is stored when + ``__call__`` is called. If the transforms were applied to keys "image" and + "label", there will be two extra keys in the dictionary: "image_transforms" + and "label_transforms" (based on `TraceKeys.KEY_SUFFIX`). Each list + contains a list of the transforms applied to that key. + + The information in ``data[key_transform]`` will be compatible with the + default collate since it only stores strings, numbers and arrays. -__all__ = ["InvertibleTransform"] + `tracing` could be enabled by `self.set_tracing` or setting + `MONAI_TRACE_TRANSFORM` when initializing the class. + """ + + tracing = False if os.environ.get("MONAI_TRACE_TRANSFORM", "1") == "0" else True + + def set_tracing(self, tracing: bool) -> None: + """Set whether to trace transforms.""" + self.tracing = tracing + + @staticmethod + def trace_key(key: Hashable = None): + """The key to store the stack of applied transforms.""" + if key is None: + return TraceKeys.KEY_SUFFIX + return str(key) + TraceKeys.KEY_SUFFIX + + def push_transform( + self, data: Mapping, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None + ) -> None: + """PUsh to a stack of applied transforms for that key.""" + if not self.tracing: + return + info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)} + if orig_size is not None: + info[TraceKeys.ORIG_SIZE] = orig_size + elif key in data and hasattr(data[key], "shape"): + info[TraceKeys.ORIG_SIZE] = data[key].shape[1:] + if extra_info is not None: + info[TraceKeys.EXTRA_INFO] = extra_info + # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) + if hasattr(self, "_do_transform"): # RandomizableTransform + info[TraceKeys.DO_TRANSFORM] = self._do_transform # type: ignore + # If this is the first, create list + if self.trace_key(key) not in data: + if not isinstance(data, dict): + data = dict(data) + data[self.trace_key(key)] = [] + data[self.trace_key(key)].append(info) + + def pop_transform(self, data: Mapping, key: Hashable = None): + """Remove the most recent applied transform.""" + if not self.tracing: + return + return data.get(self.trace_key(key), []).pop() -class InvertibleTransform(Transform): +class InvertibleTransform(TraceableTransform): """Classes for invertible transforms. This class exists so that an ``invert`` method can be implemented. This allows, for @@ -27,28 +89,21 @@ class InvertibleTransform(Transform): and after be returned to their original size before saving to file for comparison in an external viewer. - When the ``__call__`` method is called, the transformation information for each key is - stored. If the transforms were applied to keys "image" and "label", there will be two - extra keys in the dictionary: "image_transforms" and "label_transforms". Each list - contains a list of the transforms applied to that key. When the ``inverse`` method is - called, the inverse is called on each key individually, which allows for different - parameters being passed to each label (e.g., different interpolation for image and - label). + When the ``inverse`` method is called: - When the ``inverse`` method is called, the inverse transforms are applied in a last- - in-first-out order. As the inverse is applied, its entry is removed from the list - detailing the applied transformations. That is to say that during the forward pass, - the list of applied transforms grows, and then during the inverse it shrinks back - down to an empty list. + - the inverse is called on each key individually, which allows for + different parameters being passed to each label (e.g., different + interpolation for image and label). - The information in ``data[key_transform]`` will be compatible with the default collate - since it only stores strings, numbers and arrays. + - the inverse transforms are applied in a last- in-first-out order. As + the inverse is applied, its entry is removed from the list detailing + the applied transformations. That is to say that during the forward + pass, the list of applied transforms grows, and then during the + inverse it shrinks back down to an empty list. We currently check that the ``id()`` of the transform is the same in the forward and inverse directions. This is a useful check to ensure that the inverses are being - processed in the correct order. However, this may cause issues if the ``id()`` of the - object changes (such as multiprocessing on Windows). If you feel this issue affects - you, please raise a GitHub issue. + processed in the correct order. Note to developers: When converting a transform to an invertible transform, you need to: @@ -63,55 +118,25 @@ class InvertibleTransform(Transform): """ - def push_transform( - self, - data: dict, - key: Hashable, - extra_info: Optional[dict] = None, - orig_size: Optional[Tuple] = None, - ) -> None: - """Append to list of applied transforms for that key.""" - key_transform = str(key) + InverseKeys.KEY_SUFFIX - info = { - InverseKeys.CLASS_NAME: self.__class__.__name__, - InverseKeys.ID: id(self), - } - if orig_size is not None: - info[InverseKeys.ORIG_SIZE] = orig_size - elif hasattr(data[key], "shape"): - info[InverseKeys.ORIG_SIZE] = data[key].shape[1:] - if extra_info is not None: - info[InverseKeys.EXTRA_INFO] = extra_info - # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) - if isinstance(self, RandomizableTransform): - info[InverseKeys.DO_TRANSFORM] = self._do_transform - # If this is the first, create list - if key_transform not in data: - data[key_transform] = [] - data[key_transform].append(info) - - def check_transforms_match(self, transform: dict) -> None: + def check_transforms_match(self, transform: Mapping) -> None: """Check transforms are of same instance.""" - if transform[InverseKeys.ID] == id(self): + xform_name = transform.get(TraceKeys.CLASS_NAME, "") + xform_id = transform.get(TraceKeys.ID, "") + if xform_id == id(self): return # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) - if ( - torch.multiprocessing.get_start_method() in ("spawn", None) - and transform[InverseKeys.CLASS_NAME] == self.__class__.__name__ - ): + if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__: return - raise RuntimeError("Should inverse most recently applied invertible transform first") + raise RuntimeError(f"Error inverting the most recently applied invertible transform {xform_name} {xform_id}.") - def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: + def get_most_recent_transform(self, data: Mapping, key: Hashable = None): """Get most recent transform.""" - transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX][-1]) + if not self.tracing: + raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.") + transform = data[self.trace_key(key)][-1] self.check_transforms_match(transform) return transform - def pop_transform(self, data: dict, key: Hashable) -> None: - """Remove most recent transform.""" - data[str(key) + InverseKeys.KEY_SUFFIX].pop() - def inverse(self, data: dict) -> dict: """ Inverse of ``__call__``. diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index d9c6790840..ae0317cea8 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,7 @@ from monai.config import KeysCollection from monai.data.dataloader import DataLoader -from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate, rep_scalar_to_batch +from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Transform @@ -27,12 +27,7 @@ class _BatchInverseDataset(Dataset): - def __init__( - self, - data: Sequence[Any], - transform: InvertibleTransform, - pad_collation_used: bool, - ) -> None: + def __init__(self, data: Sequence[Any], transform: InvertibleTransform, pad_collation_used: bool) -> None: self.data = data self.invertible_transform = transform self.pad_collation_used = pad_collation_used @@ -65,6 +60,8 @@ def __init__( collate_fn: Optional[Callable] = no_collation, num_workers: Optional[int] = 0, detach: bool = True, + pad_batch: bool = True, + fill_value=None, ) -> None: """ Args: @@ -78,6 +75,10 @@ def __init__( if set to `None`, use the `num_workers` of the transform data loader. detach: whether to detach the tensors. Scalars tensors will be detached into number types instead of torch tensors. + pad_batch: when the items in a batch indicate different batch size, + whether to pad all the sequences to the longest. + If False, the batch size will be the length of the shortest sequence. + fill_value: the value to fill the padded sequences when `pad_batch=True`. """ self.transform = transform @@ -85,10 +86,12 @@ def __init__( self.num_workers = loader.num_workers if num_workers is None else num_workers self.collate_fn = collate_fn self.detach = detach + self.pad_batch = pad_batch + self.fill_value = fill_value self.pad_collation_used = loader.collate_fn.__doc__ == pad_list_data_collate.__doc__ def __call__(self, data: Dict[str, Any]) -> Any: - decollated_data = decollate_batch(data, detach=self.detach) + decollated_data = decollate_batch(data, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value) inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used) inv_loader = DataLoader( inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn @@ -99,21 +102,25 @@ def __call__(self, data: Dict[str, Any]) -> Any: re_str = str(re) if "equal size" in re_str: re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." - raise RuntimeError(re_str) + raise RuntimeError(re_str) from re class Decollated(MapTransform): """ - Decollate a batch of data, if input a dictionary, it can also support to only decollate specified keys. - Note that unlike most MapTransforms, it will delete other keys not specified and if keys=None, will decollate - all the data in the input. - And it replicates the scalar values to every item of the decollated list. + Decollate a batch of data. If input is a dictionary, it also supports to only decollate specified keys. + Note that unlike most MapTransforms, it will delete the other keys that are not specified. + if `keys=None`, it will decollate all the data in the input. + It replicates the scalar values to every item of the decollated list. Args: keys: keys of the corresponding items to decollate, note that it will delete other keys not specified. if None, will decollate all the keys. see also: :py:class:`monai.transforms.compose.MapTransform`. detach: whether to detach the tensors. Scalars tensors will be detached into number types instead of torch tensors. + pad_batch: when the items in a batch indicate different batch size, + whether to pad all the sequences to the longest. + If False, the batch size will be the length of the shortest sequence. + fill_value: the value to fill the padded sequences when `pad_batch=True`. allow_missing_keys: don't raise exception if key is missing. """ @@ -122,10 +129,14 @@ def __init__( self, keys: Optional[KeysCollection] = None, detach: bool = True, + pad_batch: bool = True, + fill_value=None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.detach = detach + self.pad_batch = pad_batch + self.fill_value = fill_value def __call__(self, data: Union[Dict, List]): d: Union[Dict, List] @@ -139,4 +150,4 @@ def __call__(self, data: Union[Dict, List]): for key in self.key_iterator(data): d[key] = data[key] - return decollate_batch(rep_scalar_to_batch(d), detach=self.detach) + return decollate_batch(d, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value) diff --git a/monai/transforms/io/__init__.py b/monai/transforms/io/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/io/__init__.py +++ b/monai/transforms/io/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 38a2861c8b..19fafbcbf4 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,15 +23,14 @@ import numpy as np import torch -from monai.config import DtypeLike +from monai.config import DtypeLike, PathLike from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.data.nifti_saver import NiftiSaver from monai.data.png_saver import PNGSaver from monai.transforms.transform import Transform from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, ensure_tuple, optional_import -from monai.utils.module import look_up_option +from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -126,6 +125,10 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. for r in SUPPORTED_READERS: # set predefined readers as default try: self.register(SUPPORTED_READERS[r](*args, **kwargs)) + except OptionalImportError: + logging.getLogger(self.__class__.__name__).debug( + f"required package for reader {r} is not installed, or the version doesn't match requirement." + ) except TypeError: # the reader doesn't have the corresponding args/kwargs logging.getLogger(self.__class__.__name__).debug( f"{r} is not supported with the given parameters {args} {kwargs}." @@ -139,6 +142,10 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. the_reader = look_up_option(_r.lower(), SUPPORTED_READERS) try: self.register(the_reader(*args, **kwargs)) + except OptionalImportError: + warnings.warn( + f"required package for reader {r} is not installed, or the version doesn't match requirement." + ) except TypeError: # the reader doesn't have the corresponding args/kwargs warnings.warn(f"{r} is not supported with the given parameters {args} {kwargs}.") self.register(the_reader()) @@ -160,7 +167,7 @@ def register(self, reader: ImageReader): warnings.warn(f"Preferably the reader should inherit ImageReader, but got {type(reader)}.") self.readers.append(reader) - def __call__(self, filename: Union[Sequence[str], str, Path, Sequence[Path]], reader: Optional[ImageReader] = None): + def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Optional[ImageReader] = None): """ Load image file and meta data from the given filename(s). If `reader` is not specified, this class automatically chooses readers based on the @@ -176,7 +183,7 @@ def __call__(self, filename: Union[Sequence[str], str, Path, Sequence[Path]], re reader: runtime reader to load image file and meta data. """ - filename = tuple(str(s) for s in ensure_tuple(filename)) # allow Path objects + filename = tuple(f"{Path(s).expanduser()}" for s in ensure_tuple(filename)) # allow Path objects img = None if reader is not None: img = reader.read(filename) # runtime specified reader @@ -197,19 +204,21 @@ def __call__(self, filename: Union[Sequence[str], str, Path, Sequence[Path]], re break if img is None or reader is None: + if isinstance(filename, tuple) and len(filename) == 1: + filename = filename[0] raise RuntimeError( - f"can not find a suitable reader for file: {filename}.\n" + f"cannot find a suitable reader for file: {filename}.\n" " Please install the reader libraries, see also the installation instructions:\n" " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" f" The current registered: {self.readers}.\n" ) img_array, meta_data = reader.get_data(img) - img_array = img_array.astype(self.dtype) + img_array = img_array.astype(self.dtype, copy=False) if self.image_only: return img_array - meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0] + meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader # make sure all elements in metadata are little endian meta_data = switch_endianness(meta_data, "<") @@ -285,7 +294,7 @@ class SaveImage(Transform): def __init__( self, - output_dir: Union[Path, str] = "./", + output_dir: PathLike = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", resample: bool = True, @@ -295,7 +304,7 @@ def __init__( dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, squeeze_end_dims: bool = True, - data_root_dir: str = "", + data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, ) -> None: diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 764e20f838..cb73567afb 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,14 +26,7 @@ from monai.transforms.transform import MapTransform from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple, ensure_tuple_rep -__all__ = [ - "LoadImaged", - "LoadImageD", - "LoadImageDict", - "SaveImaged", - "SaveImageD", - "SaveImageDict", -] +__all__ = ["LoadImaged", "LoadImageD", "LoadImageDict", "SaveImaged", "SaveImageD", "SaveImageDict"] class LoadImaged(MapTransform): diff --git a/monai/transforms/nvtx.py b/monai/transforms/nvtx.py index 6dd5c3b0a3..f00145efbc 100644 --- a/monai/transforms/nvtx.py +++ b/monai/transforms/nvtx.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/post/__init__.py b/monai/transforms/post/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/post/__init__.py +++ b/monai/transforms/post/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 631947025c..c5fe05d220 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,14 +18,15 @@ import numpy as np import torch -import torch.nn.functional as F -from monai.config import NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.networks import one_hot -from monai.networks.layers import GaussianFilter +from monai.networks.layers import GaussianFilter, apply_filter from monai.transforms.transform import Transform from monai.transforms.utils import fill_holes, get_largest_connected_component_mask -from monai.utils import deprecated_arg, ensure_tuple, look_up_option +from monai.transforms.utils_pytorch_numpy_unification import unravel_index +from monai.utils import TransformBackends, convert_data_type, deprecated_arg, ensure_tuple, look_up_option +from monai.utils.type_conversion import convert_to_dst_type __all__ = [ "Activations", @@ -57,6 +58,8 @@ class Activations(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Optional[Callable] = None) -> None: self.sigmoid = sigmoid self.softmax = softmax @@ -66,11 +69,11 @@ def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Optional def __call__( self, - img: torch.Tensor, + img: NdarrayOrTensor, sigmoid: Optional[bool] = None, softmax: Optional[bool] = None, other: Optional[Callable] = None, - ) -> torch.Tensor: + ) -> NdarrayOrTensor: """ Args: sigmoid: whether to execute sigmoid function on model output before transform. @@ -92,17 +95,18 @@ def __call__( raise TypeError(f"other must be None or callable but is {type(other).__name__}.") # convert to float as activation must operate on float tensor - img = img.float() + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) # type: ignore if sigmoid or self.sigmoid: - img = torch.sigmoid(img) + img_t = torch.sigmoid(img_t) if softmax or self.softmax: - img = torch.softmax(img, dim=0) + img_t = torch.softmax(img_t, dim=0) act_func = self.other if other is None else other if act_func is not None: - img = act_func(img) - - return img + img_t = act_func(img_t) + out, *_ = convert_to_dst_type(img_t, img) + return out class AsDiscrete(Transform): @@ -118,91 +122,138 @@ class AsDiscrete(Transform): Args: argmax: whether to execute argmax function on input data before transform. Defaults to ``False``. - to_onehot: whether to convert input data into the one-hot format. - Defaults to ``False``. - num_classes: the number of classes to convert to One-Hot format. + to_onehot: if not None, convert input data into the one-hot format with specified number of classes. + Defaults to ``None``. + threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold. Defaults to ``None``. - threshold_values: whether threshold the float value to int number 0 or 1. - Defaults to ``False``. - logit_thresh: the threshold value for thresholding operation.. - Defaults to ``0.5``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. + Example: + + >>> transform = AsDiscrete(argmax=True) + >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]]))) + # [[[1.0, 1.0]]] + + >>> transform = AsDiscrete(threshold=0.6) + >>> print(transform(np.array([[[0.0, 0.5], [0.8, 3.0]]]))) + # [[[0.0, 0.0], [1.0, 1.0]]] + + >>> transform = AsDiscrete(argmax=True, to_onehot=2, threshold=0.5) + >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]]))) + # [[[0.0, 0.0]], [[1.0, 1.0]]] + + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``to_onehot`` instead. + + .. deprecated:: 0.7.0 + ``num_classes`` is deprecated, use ``to_onehot`` instead. + ``logit_thresh`` is deprecated, use ``threshold`` instead. + ``threshold_values`` is deprecated, use ``threshold`` instead. + """ - @deprecated_arg("n_classes", since="0.6") + backend = [TransformBackends.TORCH] + + @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.") + @deprecated_arg( + name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead." + ) def __init__( self, argmax: bool = False, - to_onehot: bool = False, - num_classes: Optional[int] = None, - threshold_values: bool = False, - logit_thresh: float = 0.5, + to_onehot: Optional[int] = None, + threshold: Optional[float] = None, rounding: Optional[str] = None, - n_classes: Optional[int] = None, + n_classes: Optional[int] = None, # deprecated + num_classes: Optional[int] = None, # deprecated + logit_thresh: float = 0.5, # deprecated + threshold_values: Optional[bool] = False, # deprecated ) -> None: - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes self.argmax = argmax + if isinstance(to_onehot, bool): # for backward compatibility + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + to_onehot = num_classes if to_onehot else None self.to_onehot = to_onehot - self.num_classes = num_classes - self.threshold_values = threshold_values - self.logit_thresh = logit_thresh + + if isinstance(threshold, bool): # for backward compatibility + warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") + threshold = logit_thresh if threshold else None + self.threshold = threshold + self.rounding = rounding - @deprecated_arg("n_classes", since="0.6") + @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.") + @deprecated_arg( + name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead." + ) def __call__( self, - img: torch.Tensor, + img: NdarrayOrTensor, argmax: Optional[bool] = None, - to_onehot: Optional[bool] = None, - num_classes: Optional[int] = None, - threshold_values: Optional[bool] = None, - logit_thresh: Optional[float] = None, + to_onehot: Optional[int] = None, + threshold: Optional[float] = None, rounding: Optional[str] = None, - n_classes: Optional[int] = None, - ) -> torch.Tensor: + n_classes: Optional[int] = None, # deprecated + num_classes: Optional[int] = None, # deprecated + logit_thresh: Optional[float] = None, # deprecated + threshold_values: Optional[bool] = None, # deprecated + ) -> NdarrayOrTensor: """ Args: img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`, will automatically add it. argmax: whether to execute argmax function on input data before transform. Defaults to ``self.argmax``. - to_onehot: whether to convert input data into the one-hot format. + to_onehot: if not None, convert input data into the one-hot format with specified number of classes. Defaults to ``self.to_onehot``. - num_classes: the number of classes to convert to One-Hot format. - Defaults to ``self.num_classes``. - threshold_values: whether threshold the float value to int number 0 or 1. - Defaults to ``self.threshold_values``. - logit_thresh: the threshold value for thresholding operation.. - Defaults to ``self.logit_thresh``. + threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value. + Defaults to ``self.threshold``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``to_onehot`` instead. + + .. deprecated:: 0.7.0 + ``num_classes`` is deprecated, use ``to_onehot`` instead. + ``logit_thresh`` is deprecated, use ``threshold`` instead. + ``threshold_values`` is deprecated, use ``threshold`` instead. + """ - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes + if isinstance(to_onehot, bool): + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + to_onehot = num_classes if to_onehot else None + if isinstance(threshold, bool): + warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") + threshold = logit_thresh if threshold else None + + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore if argmax or self.argmax: - img = torch.argmax(img, dim=0, keepdim=True) + img_t = torch.argmax(img_t, dim=0, keepdim=True) - if to_onehot or self.to_onehot: - _nclasses = self.num_classes if num_classes is None else num_classes - if not isinstance(_nclasses, int): - raise AssertionError("One of self.num_classes or num_classes must be an integer") - img = one_hot(img, num_classes=_nclasses, dim=0) + to_onehot = self.to_onehot if to_onehot is None else to_onehot + if to_onehot is not None: + if not isinstance(to_onehot, int): + raise AssertionError("the number of classes for One-Hot must be an integer.") + img_t = one_hot(img_t, num_classes=to_onehot, dim=0) - if threshold_values or self.threshold_values: - img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh) + threshold = self.threshold if threshold is None else threshold + if threshold is not None: + img_t = img_t >= threshold rounding = self.rounding if rounding is None else rounding if rounding is not None: - rounding = look_up_option(rounding, ["torchrounding"]) - img = torch.round(img) + look_up_option(rounding, ["torchrounding"]) + img_t = torch.round(img_t) - return img.float() + img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float) + return img class KeepLargestConnectedComponent(Transform): @@ -251,17 +302,21 @@ class KeepLargestConnectedComponent(Transform): """ + backend = [TransformBackends.NUMPY] + def __init__( self, applied_labels: Union[Sequence[int], int], independent: bool = True, connectivity: Optional[int] = None ) -> None: """ Args: - applied_labels: Labels for applying the connected component on. - If only one channel. The pixel whose value is not in this list will remain unchanged. - If the data is in one-hot format, this is used to determine what channels to apply. - independent: consider several labels as a whole or independent, default is `True`. - Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case - you want this "independent" to be specified as False. + applied_labels: Labels for applying the connected component analysis on. + If only one channel. The pixel whose value is in this list will be analyzed. + If the data is in one-hot format, this is used to determine which channels to apply. + independent: whether to treat ``applied_labels`` as a union of foreground labels. + If ``True``, the connected component analysis will be performed on each foreground label independently + and return the intersection of the largest components. + If ``False``, the analysis will be performed on the union of foreground labels. + default is `True`. connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. @@ -271,48 +326,37 @@ def __init__( self.independent = independent self.connectivity = connectivity - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: shape must be (C, spatial_dim1[, spatial_dim2, ...]). Returns: - A PyTorch Tensor with shape (C, spatial_dim1[, spatial_dim2, ...]). + An array with shape (C, spatial_dim1[, spatial_dim2, ...]). """ - if img.shape[0] == 1: - img = torch.squeeze(img, dim=0) - - if self.independent: - for i in self.applied_labels: - foreground = (img == i).type(torch.uint8) - mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[foreground != mask] = 0 - else: - foreground = torch.zeros_like(img) - for i in self.applied_labels: - foreground += (img == i).type(torch.uint8) - mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[foreground != mask] = 0 - output = torch.unsqueeze(img, dim=0) - else: - # one-hot data is assumed to have binary value in each channel - if self.independent: - for i in self.applied_labels: - foreground = img[i, ...].type(torch.uint8) - mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[i, ...][foreground != mask] = 0 - else: - applied_img = img[self.applied_labels, ...].type(torch.uint8) - foreground = torch.any(applied_img, dim=0) + is_onehot = img.shape[0] > 1 + if self.independent: + for i in self.applied_labels: + foreground = img[i] > 0 if is_onehot else img[0] == i mask = get_largest_connected_component_mask(foreground, self.connectivity) - background_mask = torch.unsqueeze(foreground != mask, dim=0) - background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0) - applied_img[background_mask] = 0 - img[self.applied_labels, ...] = applied_img.type(img.type()) - output = img - - return output + if is_onehot: + img[i][foreground != mask] = 0 + else: + img[0][foreground != mask] = 0 + return img + if not is_onehot: # not one-hot, union of labels + labels, *_ = convert_to_dst_type(self.applied_labels, dst=img, wrap_sequence=True) + foreground = (img[..., None] == labels).any(-1)[0] + mask = get_largest_connected_component_mask(foreground, self.connectivity) + img[0][foreground != mask] = 0 + return img + # one-hot, union of labels + foreground = (img[self.applied_labels, ...] == 1).any(0) + mask = get_largest_connected_component_mask(foreground, self.connectivity) + for i in self.applied_labels: + img[i][foreground != mask] = 0 + return img class LabelFilter: @@ -333,6 +377,8 @@ class LabelFilter: [7, 8, 9] [0, 0, 9] """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, applied_labels: Union[Iterable[int], int]) -> None: """ Initialize the LabelFilter class with the labels to filter on. @@ -342,7 +388,7 @@ def __init__(self, applied_labels: Union[Iterable[int], int]) -> None: """ self.applied_labels = ensure_tuple(applied_labels) - def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Filter the image on the `applied_labels`. @@ -355,13 +401,18 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor: Returns: Pytorch tensor or numpy array of the same shape as the input. """ - if isinstance(img, np.ndarray): - return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0)) + if not isinstance(img, (np.ndarray, torch.Tensor)): + raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + if isinstance(img, torch.Tensor): - img_arr = img.detach().cpu().numpy() - img_arr = self(img_arr) - return torch.as_tensor(img_arr, device=img.device) - raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + if hasattr(torch, "isin"): # `isin` is new in torch 1.10.0 + appl_lbls = torch.as_tensor(self.applied_labels, device=img.device) + return torch.where(torch.isin(img, appl_lbls), img, torch.tensor(0.0).to(img)) + else: + out = self(img.detach().cpu().numpy()) + out, *_ = convert_to_dst_type(out, img) + return out + return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0)) class FillHoles(Transform): @@ -404,6 +455,8 @@ class FillHoles(Transform): The background label near label 2 and 3 is not fully enclosed and therefore not filled. """ + backend = [TransformBackends.NUMPY] + def __init__( self, applied_labels: Optional[Union[Iterable[int], int]] = None, connectivity: Optional[int] = None ) -> None: @@ -419,7 +472,7 @@ def __init__( self.applied_labels = ensure_tuple(applied_labels) if applied_labels else None self.connectivity = connectivity - def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Fill the holes in the provided image. @@ -435,18 +488,18 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor: Returns: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]]. """ - if isinstance(img, np.ndarray): - return fill_holes(img, self.applied_labels, self.connectivity) - if isinstance(img, torch.Tensor): - img_arr = img.detach().cpu().numpy() - img_arr = self(img_arr) - return torch.as_tensor(img_arr, device=img.device) - raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + if not isinstance(img, (np.ndarray, torch.Tensor)): + raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + img_np: np.ndarray + img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + out_np: np.ndarray = fill_holes(img_np, self.applied_labels, self.connectivity) + out, *_ = convert_to_dst_type(out_np, img) + return out class LabelToContour(Transform): """ - Return the contour of binary input images that only compose of 0 and 1, with Laplace kernel + Return the contour of binary input images that only compose of 0 and 1, with Laplacian kernel set as default for edge detection. Typical usage is to plot the edge of label or segmentation output. Args: @@ -457,12 +510,14 @@ class LabelToContour(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, kernel_type: str = "Laplace") -> None: if kernel_type != "Laplace": raise NotImplementedError('Currently only kernel_type="Laplace" is supported.') self.kernel_type = kernel_type - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: torch tensor data to extract the contour, with shape: [channels, height, width[, depth]] @@ -478,25 +533,41 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: ideally the edge should be thin enough, but now it has a thickness. """ - channels = img.shape[0] - img_ = img.unsqueeze(0) - if img.ndimension() == 3: - kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32, device=img.device) - kernel = kernel.repeat(channels, 1, 1, 1) - contour_img = F.conv2d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) - elif img.ndimension() == 4: - kernel = -1 * torch.ones(3, 3, 3, dtype=torch.float32, device=img.device) - kernel[1, 1, 1] = 26 - kernel = kernel.repeat(channels, 1, 1, 1, 1) - contour_img = F.conv3d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) + img_: torch.Tensor = convert_data_type(img, torch.Tensor)[0] # type: ignore + spatial_dims = len(img_.shape) - 1 + img_ = img_.unsqueeze(0) # adds a batch dim + if spatial_dims == 2: + kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32) + elif spatial_dims == 3: + kernel = -1.0 * torch.ones(3, 3, 3, dtype=torch.float32) + kernel[1, 1, 1] = 26.0 else: - raise ValueError(f"Unsupported img dimension: {img.ndimension()}, available options are [4, 5].") - + raise ValueError(f"{self.__class__} can only handle 2D or 3D images.") + contour_img = apply_filter(img_, kernel) contour_img.clamp_(min=0.0, max=1.0) - return contour_img.squeeze(0) + output, *_ = convert_to_dst_type(contour_img.squeeze(0), img) + return output + +class Ensemble: + @staticmethod + def get_stacked_torch(img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> torch.Tensor: + """Get either a sequence or single instance of np.ndarray/torch.Tensor. Return single torch.Tensor.""" + if isinstance(img, Sequence) and isinstance(img[0], np.ndarray): + img = [torch.as_tensor(i) for i in img] + elif isinstance(img, np.ndarray): + img = torch.as_tensor(img) + out: torch.Tensor = torch.stack(img) if isinstance(img, Sequence) else img # type: ignore + return out -class MeanEnsemble(Transform): + @staticmethod + def post_convert(img: torch.Tensor, orig_img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> NdarrayOrTensor: + orig_img_ = orig_img[0] if isinstance(orig_img, Sequence) else orig_img + out, *_ = convert_to_dst_type(img, orig_img_) + return out + + +class MeanEnsemble(Ensemble, Transform): """ Execute mean ensemble on the input data. The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], @@ -519,11 +590,13 @@ class MeanEnsemble(Transform): """ - def __init__(self, weights: Optional[Union[Sequence[float], torch.Tensor, np.ndarray]] = None) -> None: + backend = [TransformBackends.TORCH] + + def __init__(self, weights: Optional[Union[Sequence[float], NdarrayOrTensor]] = None) -> None: self.weights = torch.as_tensor(weights, dtype=torch.float) if weights is not None else None - def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: - img_ = torch.stack(img) if isinstance(img, (tuple, list)) else torch.as_tensor(img) + def __call__(self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> NdarrayOrTensor: + img_ = self.get_stacked_torch(img) if self.weights is not None: self.weights = self.weights.to(img_.device) shape = tuple(self.weights.shape) @@ -533,10 +606,11 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te img_ = img_ * weights / weights.mean(dim=0, keepdim=True) - return torch.mean(img_, dim=0) + out_pt = torch.mean(img_, dim=0) + return self.post_convert(out_pt, img) -class VoteEnsemble(Transform): +class VoteEnsemble(Ensemble, Transform): """ Execute vote ensemble on the input data. The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], @@ -556,11 +630,14 @@ class VoteEnsemble(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, num_classes: Optional[int] = None) -> None: self.num_classes = num_classes - def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: - img_ = torch.stack(img) if isinstance(img, (tuple, list)) else torch.as_tensor(img) + def __call__(self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> NdarrayOrTensor: + img_ = self.get_stacked_torch(img) + if self.num_classes is not None: has_ch_dim = True if img_.ndimension() > 1 and img_.shape[1] > 1: @@ -575,9 +652,11 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te if self.num_classes is not None: # if not One-Hot, use "argmax" to vote the most common class - return torch.argmax(img_, dim=0, keepdim=has_ch_dim) - # for One-Hot data, round the float number to 0 or 1 - return torch.round(img_) + out_pt = torch.argmax(img_, dim=0, keepdim=has_ch_dim) + else: + # for One-Hot data, round the float number to 0 or 1 + out_pt = torch.round(img_) + return self.post_convert(out_pt, img) class ProbNMS(Transform): @@ -611,6 +690,8 @@ class ProbNMS(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, spatial_dims: int = 2, @@ -627,9 +708,9 @@ def __init__( self.prob_threshold = prob_threshold if isinstance(box_size, int): self.box_size = np.asarray([box_size] * spatial_dims) + elif len(box_size) != spatial_dims: + raise ValueError("the sequence length of box_size should be the same as spatial_dims.") else: - if len(box_size) != spatial_dims: - raise ValueError("the sequence length of box_size should be the same as spatial_dims.") self.box_size = np.asarray(box_size) if self.box_size.min() <= 0: raise ValueError("box_size should be larger than 0.") @@ -637,10 +718,7 @@ def __init__( self.box_lower_bd = self.box_size // 2 self.box_upper_bd = self.box_size - self.box_lower_bd - def __call__( - self, - prob_map: Union[np.ndarray, torch.Tensor], - ): + def __call__(self, prob_map: NdarrayOrTensor): """ prob_map: the input probabilities map, it must have shape (H[, W, ...]). """ @@ -649,24 +727,19 @@ def __call__( prob_map = torch.as_tensor(prob_map, dtype=torch.float) self.filter.to(prob_map) prob_map = self.filter(prob_map) - else: - if not isinstance(prob_map, torch.Tensor): - prob_map = prob_map.copy() - - if isinstance(prob_map, torch.Tensor): - prob_map = prob_map.detach().cpu().numpy() prob_map_shape = prob_map.shape outputs = [] - while np.max(prob_map) > self.prob_threshold: - max_idx = np.unravel_index(prob_map.argmax(), prob_map_shape) - prob_max = prob_map[max_idx] - max_idx_arr = np.asarray(max_idx) - outputs.append([prob_max] + list(max_idx_arr)) - - idx_min_range = (max_idx_arr - self.box_lower_bd).clip(0, None) - idx_max_range = (max_idx_arr + self.box_upper_bd).clip(None, prob_map_shape) + while prob_map.max() > self.prob_threshold: + max_idx = unravel_index(prob_map.argmax(), prob_map_shape) + prob_max = prob_map[tuple(max_idx)] + max_idx = max_idx.cpu().numpy() if isinstance(max_idx, torch.Tensor) else max_idx + prob_max = prob_max.item() if isinstance(prob_max, torch.Tensor) else prob_max + outputs.append([prob_max] + list(max_idx)) + + idx_min_range = (max_idx - self.box_lower_bd).clip(0, None) + idx_max_range = (max_idx + self.box_upper_bd).clip(None, prob_map_shape) # for each dimension, set values during index ranges to 0 slices = tuple(slice(idx_min_range[i], idx_max_range[i]) for i in range(self.spatial_dims)) prob_map[slices] = 0 diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 2fc3993e3e..9d7c9652bf 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,10 +19,9 @@ from copy import deepcopy from typing import Any, Callable, Dict, Hashable, Iterable, List, Mapping, Optional, Sequence, Union -import numpy as np import torch -from monai.config import KeysCollection, NdarrayTensor +from monai.config.type_definitions import KeysCollection, NdarrayOrTensor, PathLike from monai.data.csv_saver import CSVSaver from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.array import ( @@ -40,7 +39,6 @@ from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode from monai.utils import deprecated_arg, ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys __all__ = [ "ActivationsD", @@ -86,6 +84,8 @@ class Activationsd(MapTransform): Add activation layers to the input data specified by `keys`. """ + backend = Activations.backend + def __init__( self, keys: KeysCollection, @@ -114,7 +114,7 @@ def __init__( self.other = ensure_tuple_rep(other, len(self.keys)) self.converter = Activations() - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, sigmoid, softmax, other in self.key_iterator(d, self.sigmoid, self.softmax, self.other): d[key] = self.converter(d[key], sigmoid, softmax, other) @@ -126,18 +126,26 @@ class AsDiscreted(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`. """ - @deprecated_arg("n_classes", since="0.6") + backend = AsDiscrete.backend + + @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.") + @deprecated_arg( + name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead." + ) def __init__( self, keys: KeysCollection, argmax: Union[Sequence[bool], bool] = False, - to_onehot: Union[Sequence[bool], bool] = False, - num_classes: Optional[Union[Sequence[int], int]] = None, - threshold_values: Union[Sequence[bool], bool] = False, - logit_thresh: Union[Sequence[float], float] = 0.5, + to_onehot: Union[Sequence[Optional[int]], Optional[int]] = None, + threshold: Union[Sequence[Optional[float]], Optional[float]] = None, rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, - n_classes: Optional[int] = None, + n_classes: Optional[Union[Sequence[int], int]] = None, # deprecated + num_classes: Optional[Union[Sequence[int], int]] = None, # deprecated + logit_thresh: Union[Sequence[float], float] = 0.5, # deprecated + threshold_values: Union[Sequence[bool], bool] = False, # deprecated ) -> None: """ Args: @@ -145,46 +153,55 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` argmax: whether to execute argmax function on input data before transform. it also can be a sequence of bool, each element corresponds to a key in ``keys``. - to_onehot: whether to convert input data into the one-hot format. Defaults to False. - it also can be a sequence of bool, each element corresponds to a key in ``keys``. - num_classes: the number of classes to convert to One-Hot format. it also can be a - sequence of int, each element corresponds to a key in ``keys``. - threshold_values: whether threshold the float value to int number 0 or 1, default is False. - it also can be a sequence of bool, each element corresponds to a key in ``keys``. - logit_thresh: the threshold value for thresholding operation, default is 0.5. - it also can be a sequence of float, each element corresponds to a key in ``keys``. + to_onehot: if not None, convert input data into the one-hot format with specified number of classes. + defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``. + threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold value. + defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. it also can be a sequence of str or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + .. deprecated:: 0.6.0 + ``n_classes`` is deprecated, use ``to_onehot`` instead. + + .. deprecated:: 0.7.0 + ``num_classes`` is deprecated, use ``to_onehot`` instead. + ``logit_thresh`` is deprecated, use ``threshold`` instead. + ``threshold_values`` is deprecated, use ``threshold`` instead. + """ - # in case the new num_classes is default but you still call deprecated n_classes - if n_classes is not None and num_classes is None: - num_classes = n_classes super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) - self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) - self.num_classes = ensure_tuple_rep(num_classes, len(self.keys)) - self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys)) - self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) + to_onehot_ = ensure_tuple_rep(to_onehot, len(self.keys)) + num_classes = ensure_tuple_rep(num_classes, len(self.keys)) + self.to_onehot = [] + for flag, val in zip(to_onehot_, num_classes): + if isinstance(flag, bool): + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + self.to_onehot.append(val if flag else None) + else: + self.to_onehot.append(flag) + + threshold_ = ensure_tuple_rep(threshold, len(self.keys)) + logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) + self.threshold = [] + for flag, val in zip(threshold_, logit_thresh): + if isinstance(flag, bool): + warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") + self.threshold.append(val if flag else None) + else: + self.threshold.append(flag) + self.rounding = ensure_tuple_rep(rounding, len(self.keys)) self.converter = AsDiscrete() - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key, argmax, to_onehot, num_classes, threshold_values, logit_thresh, rounding in self.key_iterator( - d, self.argmax, self.to_onehot, self.num_classes, self.threshold_values, self.logit_thresh, self.rounding + for key, argmax, to_onehot, threshold, rounding in self.key_iterator( + d, self.argmax, self.to_onehot, self.threshold, self.rounding ): - d[key] = self.converter( - d[key], - argmax, - to_onehot, - num_classes, - threshold_values, - logit_thresh, - rounding, - ) + d[key] = self.converter(d[key], argmax, to_onehot, threshold, rounding) return d @@ -193,6 +210,8 @@ class KeepLargestConnectedComponentd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.KeepLargestConnectedComponent`. """ + backend = KeepLargestConnectedComponent.backend + def __init__( self, keys: KeysCollection, @@ -208,9 +227,11 @@ def __init__( applied_labels: Labels for applying the connected component on. If only one channel. The pixel whose value is not in this list will remain unchanged. If the data is in one-hot format, this is the channel indices to apply transform. - independent: consider several labels as a whole or independent, default is `True`. - Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case - you want this "independent" to be specified as False. + independent: whether to treat ``applied_labels`` as a union of foreground labels. + If ``True``, the connected component analysis will be performed on each foreground label independently + and return the intersection of the largest components. + If ``False``, the analysis will be performed on the union of foreground labels. + default is `True`. connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. @@ -220,7 +241,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = KeepLargestConnectedComponent(applied_labels, independent, connectivity) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -232,11 +253,10 @@ class LabelFilterd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.LabelFilter`. """ + backend = LabelFilter.backend + def __init__( - self, - keys: KeysCollection, - applied_labels: Union[Sequence[int], int], - allow_missing_keys: bool = False, + self, keys: KeysCollection, applied_labels: Union[Sequence[int], int], allow_missing_keys: bool = False ) -> None: """ Args: @@ -249,7 +269,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = LabelFilter(applied_labels) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -261,6 +281,8 @@ class FillHolesd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.FillHoles`. """ + backend = FillHoles.backend + def __init__( self, keys: KeysCollection, @@ -284,7 +306,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = FillHoles(applied_labels=applied_labels, connectivity=connectivity) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -296,6 +318,8 @@ class LabelToContourd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.LabelToContour`. """ + backend = LabelToContour.backend + def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace", allow_missing_keys: bool = False) -> None: """ Args: @@ -308,7 +332,7 @@ def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace", allow_mis super().__init__(keys, allow_missing_keys) self.converter = LabelToContour(kernel_type=kernel_type) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -321,10 +345,12 @@ class Ensembled(MapTransform): """ + backend = list(set(VoteEnsemble.backend) & set(MeanEnsemble.backend)) + def __init__( self, keys: KeysCollection, - ensemble: Callable[[Union[Sequence[torch.Tensor], torch.Tensor]], torch.Tensor], + ensemble: Callable[[Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]], NdarrayOrTensor], output_key: Optional[str] = None, allow_missing_keys: bool = False, ) -> None: @@ -350,14 +376,16 @@ def __init__( raise ValueError("Incompatible values: len(self.keys) > 1 and output_key=None.") self.output_key = output_key if output_key is not None else self.keys[0] - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - items: Union[List[torch.Tensor], torch.Tensor] - if len(self.keys) == 1: + items: Union[List[NdarrayOrTensor], NdarrayOrTensor] + if len(self.keys) == 1 and self.keys[0] in d: items = d[self.keys[0]] else: items = [d[key] for key in self.key_iterator(d)] - d[self.output_key] = self.ensemble(items) + + if len(items) > 0: + d[self.output_key] = self.ensemble(items) return d @@ -367,11 +395,13 @@ class MeanEnsembled(Ensembled): Dictionary-based wrapper of :py:class:`monai.transforms.MeanEnsemble`. """ + backend = MeanEnsemble.backend + def __init__( self, keys: KeysCollection, output_key: Optional[str] = None, - weights: Optional[Union[Sequence[float], torch.Tensor, np.ndarray]] = None, + weights: Optional[Union[Sequence[float], NdarrayOrTensor]] = None, ) -> None: """ Args: @@ -400,6 +430,8 @@ class VoteEnsembled(Ensembled): Dictionary-based wrapper of :py:class:`monai.transforms.VoteEnsemble`. """ + backend = VoteEnsemble.backend + def __init__( self, keys: KeysCollection, output_key: Optional[str] = None, num_classes: Optional[int] = None ) -> None: @@ -448,6 +480,8 @@ class ProbNMSd(MapTransform): """ + backend = ProbNMS.backend + def __init__( self, keys: KeysCollection, @@ -459,13 +493,10 @@ def __init__( ) -> None: super().__init__(keys, allow_missing_keys) self.prob_nms = ProbNMS( - spatial_dims=spatial_dims, - sigma=sigma, - prob_threshold=prob_threshold, - box_size=box_size, + spatial_dims=spatial_dims, sigma=sigma, prob_threshold=prob_threshold, box_size=box_size ) - def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]): + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]): d = dict(data) for key in self.key_iterator(d): d[key] = self.prob_nms(d[key]) @@ -589,7 +620,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: self.device, self.post_func, ): - transform_key = f"{orig_key}{InverseKeys.KEY_SUFFIX}" + transform_key = InvertibleTransform.trace_key(orig_key) if transform_key not in d: warnings.warn(f"transform info of `{orig_key}` is not available or no InvertibleTransform applied.") continue @@ -597,19 +628,14 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: transform_info = d[transform_key] if nearest_interp: transform_info = convert_inverse_interp_mode( - trans_info=deepcopy(transform_info), - mode="nearest", - align_corners=None, + trans_info=deepcopy(transform_info), mode="nearest", align_corners=None ) input = d[key] if isinstance(input, torch.Tensor): input = input.detach() # construct the input dict data for BatchInverseTransform - input_dict = { - orig_key: input, - transform_key: transform_info, - } + input_dict = {orig_key: input, transform_key: transform_info} orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}" meta_key = meta_key or f"{key}_{meta_key_postfix}" if orig_meta_key in d: @@ -639,7 +665,7 @@ def __init__( meta_keys: Optional[KeysCollection] = None, meta_key_postfix: str = "meta_dict", saver: Optional[CSVSaver] = None, - output_dir: str = "./", + output_dir: PathLike = "./", filename: str = "predictions.csv", overwrite: bool = True, flush: bool = True, diff --git a/monai/transforms/smooth_field/__init__.py b/monai/transforms/smooth_field/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/transforms/smooth_field/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/smooth_field/array.py b/monai/transforms/smooth_field/array.py new file mode 100644 index 0000000000..31ce76e5b5 --- /dev/null +++ b/monai/transforms/smooth_field/array.py @@ -0,0 +1,461 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transforms using a smooth spatial field generated by interpolating from smaller randomized fields.""" + +from typing import Any, Optional, Sequence, Union + +import numpy as np +import torch +from torch.nn.functional import grid_sample, interpolate + +import monai +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import Randomizable, RandomizableTransform +from monai.transforms.utils_pytorch_numpy_unification import moveaxis +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode +from monai.utils.enums import TransformBackends +from monai.utils.module import look_up_option, pytorch_after +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor + +__all__ = ["SmoothField", "RandSmoothFieldAdjustContrast", "RandSmoothFieldAdjustIntensity", "RandSmoothDeform"] + + +class SmoothField(Randomizable): + """ + Generate a smooth field array by defining a smaller randomized field and then reinterpolating to the desired size. + + This exploits interpolation to create a smoothly varying field used for other applications. An initial randomized + field is defined with `rand_size` dimensions with `pad` number of values padding it along each dimension using + `pad_val` as the value. If `spatial_size` is given this is interpolated to that size, otherwise if None the random + array is produced uninterpolated. The output is always a Pytorch tensor allocated on the specified device. + + Args: + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with `pad_val` + pad_val: value with which to pad field edges + low: low value for randomized field + high: high value for randomized field + channels: number of channels of final output + spatial_size: final output size of the array, None to produce original uninterpolated field + mode: interpolation mode for resizing the field + align_corners: if True align the corners when upsampling field + device: Pytorch device to define field on + """ + + def __init__( + self, + rand_size: Sequence[int], + pad: int = 0, + pad_val: float = 0, + low: float = -1.0, + high: float = 1.0, + channels: int = 1, + spatial_size: Optional[Sequence[int]] = None, + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + device: Optional[torch.device] = None, + ): + self.rand_size = tuple(rand_size) + self.pad = pad + self.low = low + self.high = high + self.channels = channels + self.mode = mode + self.align_corners = align_corners + self.device = device + + self.spatial_size: Optional[Sequence[int]] = None + self.spatial_zoom: Optional[Sequence[float]] = None + + if low >= high: + raise ValueError("Value for `low` must be less than `high` otherwise field will be zeros") + + self.total_rand_size = tuple(rs + self.pad * 2 for rs in self.rand_size) + + self.field = torch.ones((1, self.channels) + self.total_rand_size, device=self.device) * pad_val + + self.crand_size = (self.channels,) + self.rand_size + + pad_slice = slice(None) if self.pad == 0 else slice(self.pad, -self.pad) + self.rand_slices = (0, slice(None)) + (pad_slice,) * len(self.rand_size) + + self.set_spatial_size(spatial_size) + + def randomize(self, data: Optional[Any] = None) -> None: + self.field[self.rand_slices] = torch.from_numpy(self.R.uniform(self.low, self.high, self.crand_size)) + + def set_spatial_size(self, spatial_size: Optional[Sequence[int]]) -> None: + """ + Set the `spatial_size` and `spatial_zoom` attributes used for interpolating the field to the given + dimension, or not interpolate at all if None. + + Args: + spatial_size: new size to interpolate to, or None to not interpolate + """ + if spatial_size is None: + self.spatial_size = None + self.spatial_zoom = None + else: + self.spatial_size = tuple(spatial_size) + self.spatial_zoom = tuple(s / f for s, f in zip(self.spatial_size, self.total_rand_size)) + + def set_mode(self, mode: Union[monai.utils.InterpolateMode, str]) -> None: + self.mode = mode + + def __call__(self, randomize=False) -> torch.Tensor: + if randomize: + self.randomize() + + field = self.field.clone() + + if self.spatial_zoom is not None: + resized_field = interpolate( # type: ignore + input=field, # type: ignore + scale_factor=self.spatial_zoom, + mode=look_up_option(self.mode, InterpolateMode).value, + align_corners=self.align_corners, + recompute_scale_factor=False, + ) + + mina = resized_field.min() + maxa = resized_field.max() + minv = self.field.min() + maxv = self.field.max() + + # faster than rescale_array, this uses in-place operations and doesn't perform unneeded range checks + norm_field = (resized_field.squeeze(0) - mina).div_(maxa - mina) + field = norm_field.mul_(maxv - minv).add_(minv) + + return field + + +class RandSmoothFieldAdjustContrast(RandomizableTransform): + """ + Randomly adjust the contrast of input images by calculating a randomized smooth field for each invocation. + + This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the + edges of the input volume of that width will be mostly unchanged. Contrast is changed by raising input + values by the power of the smooth field so the range of values given by `gamma` should be chosen with this + in mind. For example, a minimum value of 0 in `gamma` will produce white areas so this should be avoided. + Afte the contrast is adjusted the values of the result are rescaled to the range of the original input. + + Args: + spatial_size: size of input array's spatial dimensions + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 1 + mode: interpolation mode to use when upsampling + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + gamma: (min, max) range for exponential field + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + gamma: Union[Sequence[float], float] = (0.5, 4.5), + device: Optional[torch.device] = None, + ): + super().__init__(prob) + + if isinstance(gamma, (int, float)): + self.gamma = (0.5, gamma) + else: + if len(gamma) != 2: + raise ValueError("Argument `gamma` should be a number or pair of numbers.") + + self.gamma = (min(gamma), max(gamma)) + + self.sfield = SmoothField( + rand_size=rand_size, + pad=pad, + pad_val=1, + low=self.gamma[0], + high=self.gamma[1], + channels=1, + spatial_size=spatial_size, + mode=mode, + align_corners=align_corners, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothFieldAdjustContrast": + super().set_random_state(seed, state) + self.sfield.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + + if self._do_transform: + self.sfield.randomize() + + def set_mode(self, mode: Union[monai.utils.InterpolateMode, str]) -> None: + self.sfield.set_mode(mode) + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + """ + Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous. + """ + if randomize: + self.randomize() + + if not self._do_transform: + return img + + img_min = img.min() + img_max = img.max() + img_rng = img_max - img_min + + field = self.sfield() + rfield, *_ = convert_to_dst_type(field, img) + + # everything below here is to be computed using the destination type (numpy, tensor, etc.) + + img = (img - img_min) / (img_rng + 1e-10) # rescale to unit values + img = img ** rfield # contrast is changed by raising image data to a power, in this case the field + + out = (img * img_rng) + img_min # rescale back to the original image value range + + return out + + +class RandSmoothFieldAdjustIntensity(RandomizableTransform): + """ + Randomly adjust the intensity of input images by calculating a randomized smooth field for each invocation. + + This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the + edges of the input volume of that width will be mostly unchanged. Intensity is changed by multiplying the + inputs by the smooth field, so the values of `gamma` should be chosen with this in mind. The default values + of `(0.1, 1.0)` are sensible in that values will not be zeroed out by the field nor multiplied greater than + the original value range. + + Args: + spatial_size: size of input array + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 1 + mode: interpolation mode to use when upsampling + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + gamma: (min, max) range of intensity multipliers + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + gamma: Union[Sequence[float], float] = (0.1, 1.0), + device: Optional[torch.device] = None, + ): + super().__init__(prob) + + if isinstance(gamma, (int, float)): + self.gamma = (0.5, gamma) + else: + if len(gamma) != 2: + raise ValueError("Argument `gamma` should be a number or pair of numbers.") + + self.gamma = (min(gamma), max(gamma)) + + self.sfield = SmoothField( + rand_size=rand_size, + pad=pad, + pad_val=1, + low=self.gamma[0], + high=self.gamma[1], + channels=1, + spatial_size=spatial_size, + mode=mode, + align_corners=align_corners, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothFieldAdjustIntensity": + super().set_random_state(seed, state) + self.sfield.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + + if self._do_transform: + self.sfield.randomize() + + def set_mode(self, mode: Union[InterpolateMode, str]) -> None: + self.sfield.set_mode(mode) + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + """ + Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous. + """ + + if randomize: + self.randomize() + + if not self._do_transform: + return img + + field = self.sfield() + rfield, *_ = convert_to_dst_type(field, img) + + # everything below here is to be computed using the destination type (numpy, tensor, etc.) + + out = img * rfield + + return out + + +class RandSmoothDeform(RandomizableTransform): + """ + Deform an image using a random smooth field and Pytorch's grid_sample. + + The amount of deformation is given by `def_range` in fractions of the size of the image. The size of each dimension + of the input image is always defined as 2 regardless of actual image voxel dimensions, that is the coordinates in + every dimension range from -1 to 1. A value of 0.1 means pixels/voxels can be moved by up to 5% of the image's size. + + Args: + spatial_size: input array size to which deformation grid is interpolated + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 0 + field_mode: interpolation mode to use when upsampling the deformation field + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + def_range: value of the deformation range in image size fractions, single min/max value or min/max pair + grid_dtype: type for the deformation grid calculated from the field + grid_mode: interpolation mode used for sampling input using deformation grid + grid_padding_mode: padding mode used for sampling input using deformation grid + grid_align_corners: if True align the corners when sampling the deformation grid + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH] + + def __init__( + self, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + field_mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + def_range: Union[Sequence[float], float] = 1.0, + grid_dtype=torch.float32, + grid_mode: Union[GridSampleMode, str] = GridSampleMode.NEAREST, + grid_padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + grid_align_corners: Optional[bool] = False, + device: Optional[torch.device] = None, + ): + super().__init__(prob) + + self.grid_dtype = grid_dtype + self.grid_mode = grid_mode + self.def_range = def_range + self.device = device + self.grid_align_corners = grid_align_corners + self.grid_padding_mode = grid_padding_mode + + if isinstance(def_range, (int, float)): + self.def_range = (-def_range, def_range) + else: + if len(def_range) != 2: + raise ValueError("Argument `def_range` should be a number or pair of numbers.") + + self.def_range = (min(def_range), max(def_range)) + + self.sfield = SmoothField( + spatial_size=spatial_size, + rand_size=rand_size, + pad=pad, + low=self.def_range[0], + high=self.def_range[1], + channels=len(rand_size), + mode=field_mode, + align_corners=align_corners, + device=device, + ) + + grid_space = spatial_size if spatial_size is not None else self.sfield.field.shape[2:] + grid_ranges = [torch.linspace(-1, 1, d) for d in grid_space] + + if pytorch_after(1, 10): + grid = torch.meshgrid(*grid_ranges, indexing="ij") + else: + grid = torch.meshgrid(*grid_ranges) + + self.grid = torch.stack(grid).unsqueeze(0).to(self.device, self.grid_dtype) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "Randomizable": + super().set_random_state(seed, state) + self.sfield.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + + if self._do_transform: + self.sfield.randomize() + + def set_field_mode(self, mode: Union[monai.utils.InterpolateMode, str]) -> None: + self.sfield.set_mode(mode) + + def set_grid_mode(self, mode: Union[monai.utils.GridSampleMode, str]) -> None: + self.grid_mode = mode + + def __call__( + self, img: NdarrayOrTensor, randomize: bool = True, device: Optional[torch.device] = None + ) -> NdarrayOrTensor: + if randomize: + self.randomize() + + if not self._do_transform: + return img + + device = device if device is not None else self.device + + field = self.sfield() + + dgrid = self.grid + field.to(self.grid_dtype) + dgrid = moveaxis(dgrid, 1, -1) # type: ignore + + img_t = convert_to_tensor(img[None], torch.float32, device) + + out = grid_sample( + input=img_t, + grid=dgrid, + mode=look_up_option(self.grid_mode, GridSampleMode).value, + align_corners=self.grid_align_corners, + padding_mode=look_up_option(self.grid_padding_mode, GridSamplePadMode).value, + ) + + out_t, *_ = convert_to_dst_type(out.squeeze(0), img) + + return out_t diff --git a/monai/transforms/smooth_field/dictionary.py b/monai/transforms/smooth_field/dictionary.py new file mode 100644 index 0000000000..c129d14f32 --- /dev/null +++ b/monai/transforms/smooth_field/dictionary.py @@ -0,0 +1,278 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Hashable, Mapping, Optional, Sequence, Union + +import numpy as np +import torch + +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.smooth_field.array import ( + RandSmoothDeform, + RandSmoothFieldAdjustContrast, + RandSmoothFieldAdjustIntensity, +) +from monai.transforms.transform import MapTransform, RandomizableTransform +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple_rep +from monai.utils.enums import TransformBackends + +__all__ = ["RandSmoothFieldAdjustContrastd", "RandSmoothFieldAdjustIntensityd", "RandSmoothDeformd"] + + +InterpolateModeType = Union[InterpolateMode, str] +GridSampleModeType = Union[GridSampleMode, str] + + +class RandSmoothFieldAdjustContrastd(RandomizableTransform, MapTransform): + """ + Dictionary version of RandSmoothFieldAdjustContrast. + + The field is randomized once per invocation by default so the same field is applied to every selected key. The + `mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values with + one for each key in `keys`. + + Args: + keys: key names to apply the augment to + spatial_size: size of input arrays, all arrays stated in `keys` must have same dimensions + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 0 + mode: interpolation mode to use when upsampling + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + gamma: (min, max) range for exponential field + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + keys: KeysCollection, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateModeType, Sequence[InterpolateModeType]] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + gamma: Union[Sequence[float], float] = (0.5, 4.5), + device: Optional[torch.device] = None, + ): + RandomizableTransform.__init__(self, prob) + MapTransform.__init__(self, keys) + + self.mode = ensure_tuple_rep(mode, len(self.keys)) + + self.trans = RandSmoothFieldAdjustContrast( + spatial_size=spatial_size, + rand_size=rand_size, + pad=pad, + mode=self.mode[0], + align_corners=align_corners, + prob=1.0, + gamma=gamma, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothFieldAdjustContrastd": + super().set_random_state(seed, state) + self.trans.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + + if self._do_transform: + self.trans.randomize() + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + self.randomize() + + if not self._do_transform: + return data + + d = dict(data) + + for idx, key in enumerate(self.key_iterator(d)): + self.trans.set_mode(self.mode[idx % len(self.mode)]) + d[key] = self.trans(d[key], False) + + return d + + +class RandSmoothFieldAdjustIntensityd(RandomizableTransform, MapTransform): + """ + Dictionary version of RandSmoothFieldAdjustIntensity. + + The field is randomized once per invocation by default so the same field is applied to every selected key. The + `mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values with + one for each key in `keys`. + + Args: + keys: key names to apply the augment to + spatial_size: size of input arrays, all arrays stated in `keys` must have same dimensions + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 0 + mode: interpolation mode to use when upsampling + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + gamma: (min, max) range of intensity multipliers + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + keys: KeysCollection, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateModeType, Sequence[InterpolateModeType]] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + gamma: Union[Sequence[float], float] = (0.1, 1.0), + device: Optional[torch.device] = None, + ): + RandomizableTransform.__init__(self, prob) + MapTransform.__init__(self, keys) + + self.mode = ensure_tuple_rep(mode, len(self.keys)) + + self.trans = RandSmoothFieldAdjustIntensity( + spatial_size=spatial_size, + rand_size=rand_size, + pad=pad, + mode=self.mode[0], + align_corners=align_corners, + prob=1.0, + gamma=gamma, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothFieldAdjustIntensityd": + super().set_random_state(seed, state) + self.trans.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + self.trans.randomize() + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + self.randomize() + + if not self._do_transform: + return data + + d = dict(data) + + for idx, key in enumerate(self.key_iterator(d)): + self.trans.set_mode(self.mode[idx % len(self.mode)]) + d[key] = self.trans(d[key], False) + + return d + + +class RandSmoothDeformd(RandomizableTransform, MapTransform): + """ + Dictionary version of RandSmoothDeform. + + The field is randomized once per invocation by default so the same field is applied to every selected key. The + `field_mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values + with one for each key in `keys`. Similarly the `grid_mode` parameter can be one value or one per key. + + Args: + keys: key names to apply the augment to + spatial_size: input array size to which deformation grid is interpolated + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 0 + field_mode: interpolation mode to use when upsampling the deformation field + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + def_range: value of the deformation range in image size fractions + grid_dtype: type for the deformation grid calculated from the field + grid_mode: interpolation mode used for sampling input using deformation grid + grid_padding_mode: padding mode used for sampling input using deformation grid + grid_align_corners: if True align the corners when sampling the deformation grid + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + keys: KeysCollection, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + field_mode: Union[InterpolateModeType, Sequence[InterpolateModeType]] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + def_range: Union[Sequence[float], float] = 1.0, + grid_dtype=torch.float32, + grid_mode: Union[GridSampleModeType, Sequence[GridSampleModeType]] = GridSampleMode.NEAREST, + grid_padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + grid_align_corners: Optional[bool] = False, + device: Optional[torch.device] = None, + ): + RandomizableTransform.__init__(self, prob) + MapTransform.__init__(self, keys) + + self.field_mode = ensure_tuple_rep(field_mode, len(self.keys)) + self.grid_mode = ensure_tuple_rep(grid_mode, len(self.keys)) + + self.trans = RandSmoothDeform( + rand_size=rand_size, + spatial_size=spatial_size, + pad=pad, + field_mode=self.field_mode[0], + align_corners=align_corners, + prob=1.0, + def_range=def_range, + grid_dtype=grid_dtype, + grid_mode=self.grid_mode[0], + grid_padding_mode=grid_padding_mode, + grid_align_corners=grid_align_corners, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothDeformd": + super().set_random_state(seed, state) + self.trans.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + self.trans.randomize() + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + self.randomize() + + if not self._do_transform: + return data + + d = dict(data) + + for idx, key in enumerate(self.key_iterator(d)): + self.trans.set_field_mode(self.field_mode[idx % len(self.field_mode)]) + self.trans.set_grid_mode(self.grid_mode[idx % len(self.grid_mode)]) + + d[key] = self.trans(d[key], False, self.trans.device) + + return d diff --git a/monai/transforms/spatial/__init__.py b/monai/transforms/spatial/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/spatial/__init__.py +++ b/monai/transforms/spatial/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c3bd4a3433..4e5cac2c85 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,7 +13,7 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ import warnings -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -22,7 +22,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.croppad.array import CenterSpatialCrop, Pad from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform from monai.transforms.utils import ( create_control_grid, @@ -38,6 +38,7 @@ GridSamplePadMode, InterpolateMode, NumpyPadMode, + PytorchPadMode, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, @@ -45,8 +46,10 @@ issequenceiterable, optional_import, ) +from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import TransformBackends from monai.utils.module import look_up_option +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type nib, _ = optional_import("nibabel") @@ -54,6 +57,7 @@ "Spacing", "Orientation", "Flip", + "GridDistortion", "Resize", "Rotate", "Zoom", @@ -61,6 +65,7 @@ "RandRotate90", "RandRotate", "RandFlip", + "RandGridDistortion", "RandAxisFlip", "RandZoom", "AffineGrid", @@ -71,7 +76,6 @@ "RandAffine", "Rand2DElastic", "Rand3DElastic", - "AddCoordinateChannels", ] RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] @@ -82,6 +86,8 @@ class Spacing(Transform): Resample input image into the specified `pixdim`. """ + backend = [TransformBackends.TORCH] + def __init__( self, pixdim: Union[Sequence[float], float], @@ -90,6 +96,7 @@ def __init__( padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike = np.float64, + image_only: bool = False, ) -> None: """ Args: @@ -122,6 +129,7 @@ def __init__( dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. + image_only: return just the image or the image, the old affine and new affine. Default is `False`. """ self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64) @@ -130,17 +138,18 @@ def __init__( self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.dtype = dtype + self.image_only = image_only def __call__( self, - data_array: np.ndarray, - affine: Optional[np.ndarray] = None, + data_array: NdarrayOrTensor, + affine: Optional[NdarrayOrTensor] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, output_spatial_shape: Optional[np.ndarray] = None, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: """ Args: data_array: in shape (num_channels, H[, W, ...]). @@ -169,15 +178,16 @@ def __call__( """ _dtype = dtype or self.dtype or data_array.dtype - sr = data_array.ndim - 1 + sr = int(data_array.ndim - 1) if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") if affine is None: # default to identity - affine = np.eye(sr + 1, dtype=np.float64) + affine_np = affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_np, *_ = convert_data_type(affine, np.ndarray) # type: ignore + affine_ = to_affine_nd(sr, affine_np) out_d = self.pixdim[:sr] if out_d.size < sr: @@ -193,27 +203,31 @@ def __call__( # no resampling if it's identity transform if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): - output_data = data_array.copy().astype(np.float32) - new_affine = to_affine_nd(affine, new_affine) - return output_data, affine, new_affine - - # resample - affine_xform = AffineTransform( - normalized=False, - mode=look_up_option(mode or self.mode, GridSampleMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), - align_corners=self.align_corners if align_corners is None else align_corners, - reverse_indexing=True, - ) - output_data = affine_xform( - # AffineTransform requires a batch dim - torch.as_tensor(np.ascontiguousarray(data_array).astype(_dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), - spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, - ) - output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore - new_affine = to_affine_nd(affine, new_affine) - + output_data = data_array + else: + # resample + affine_xform = AffineTransform( + normalized=False, + mode=look_up_option(mode or self.mode, GridSampleMode), + padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), + align_corners=self.align_corners if align_corners is None else align_corners, + reverse_indexing=True, + ) + data_array_t: torch.Tensor + data_array_t, *_ = convert_data_type(data_array, torch.Tensor, dtype=_dtype) # type: ignore + output_data = affine_xform( + # AffineTransform requires a batch dim + data_array_t.unsqueeze(0), + convert_data_type(transform, torch.Tensor, data_array_t.device, dtype=_dtype)[0], + spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, + ).squeeze(0) + + output_data, *_ = convert_to_dst_type(output_data, data_array, dtype=torch.float32) + new_affine = to_affine_nd(affine_np, new_affine) # type: ignore + new_affine, *_ = convert_to_dst_type(src=new_affine, dst=affine, dtype=torch.float32) + + if self.image_only: + return output_data return output_data, affine, new_affine @@ -222,11 +236,14 @@ class Orientation(Transform): Change the input image's orientation into the specified based on `axcodes`. """ + backend = [TransformBackends.NUMPY] + def __init__( self, axcodes: Optional[str] = None, as_closest_canonical: bool = False, labels: Optional[Sequence[Tuple[str, str]]] = tuple(zip("LPI", "RAS")), + image_only: bool = False, ) -> None: """ Args: @@ -239,6 +256,7 @@ def __init__( labels: optional, None or sequence of (2,) sequences (2,) sequences are labels for (beginning, end) of output axis. Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``. + image_only: if True return only the image volume, otherwise return (image, affine, new_affine). Raises: ValueError: When ``axcodes=None`` and ``as_closest_canonical=True``. Incompatible values. @@ -253,10 +271,11 @@ def __init__( self.axcodes = axcodes self.as_closest_canonical = as_closest_canonical self.labels = labels + self.image_only = image_only def __call__( - self, data_array: np.ndarray, affine: Optional[np.ndarray] = None - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + self, data_array: NdarrayOrTensor, affine: Optional[NdarrayOrTensor] = None + ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: """ original orientation of `data_array` is defined by `affine`. @@ -269,17 +288,22 @@ def __call__( ValueError: When ``axcodes`` spatiality differs from ``data_array``. Returns: - data_array (reoriented in `self.axcodes`), original axcodes, current axcodes. + data_array [reoriented in `self.axcodes`] if `self.image_only`, else + (data_array [reoriented in `self.axcodes`], original axcodes, current axcodes). """ - sr = data_array.ndim - 1 + data_array_np, *_ = convert_data_type(data_array, np.ndarray) # type: ignore + sr = data_array_np.ndim - 1 if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") if affine is None: - affine = np.eye(sr + 1, dtype=np.float64) + # default to identity + affine_np = affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_np, *_ = convert_data_type(affine, np.ndarray) # type: ignore + affine_ = to_affine_nd(sr, affine_np) + src = nib.io_orientation(affine_) if self.as_closest_canonical: spatial_ornt = src @@ -295,12 +319,16 @@ def __call__( ornt = spatial_ornt.copy() ornt[:, 0] += 1 # skip channel dim ornt = np.concatenate([np.array([[0, 1]]), ornt]) - shape = data_array.shape[1:] - data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) + shape = data_array_np.shape[1:] + data_array_np = np.ascontiguousarray(nib.orientations.apply_orientation(data_array_np, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) - new_affine = to_affine_nd(affine, new_affine) + new_affine = to_affine_nd(affine_np, new_affine) + out, *_ = convert_to_dst_type(src=data_array_np, dst=data_array) + new_affine, *_ = convert_to_dst_type(src=new_affine, dst=affine, dtype=torch.float32) - return data_array, affine, new_affine + if self.image_only: + return out + return out, affine, new_affine class Flip(Transform): @@ -330,8 +358,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ if isinstance(img, np.ndarray): return np.ascontiguousarray(np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis))) - else: - return torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) + return torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) class Resize(Transform): @@ -357,6 +384,8 @@ class Resize(Transform): See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate """ + backend = [TransformBackends.TORCH] + def __init__( self, spatial_size: Union[Sequence[int], int], @@ -371,10 +400,10 @@ def __init__( def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, align_corners: Optional[bool] = None, - ) -> np.ndarray: + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). @@ -389,32 +418,33 @@ def __call__( ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ + img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) # type: ignore if self.size_mode == "all": - input_ndim = img.ndim - 1 # spatial ndim + input_ndim = img_.ndim - 1 # spatial ndim output_ndim = len(ensure_tuple(self.spatial_size)) if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) - img = img.reshape(input_shape) + input_shape = ensure_tuple_size(img_.shape, output_ndim + 1, 1) + img_ = img_.reshape(input_shape) elif output_ndim < input_ndim: raise ValueError( "len(spatial_size) must be greater or equal to img spatial dimensions, " f"got spatial_size={output_ndim} img={input_ndim}." ) - spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) + spatial_size_ = fall_back_tuple(self.spatial_size, img_.shape[1:]) else: # for the "longest" mode - img_size = img.shape[1:] + img_size = img_.shape[1:] if not isinstance(self.spatial_size, int): raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) resized = torch.nn.functional.interpolate( # type: ignore - input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), + input=img_.unsqueeze(0), # type: ignore size=spatial_size_, mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - resized = resized.squeeze(0).detach().cpu().numpy() - return np.asarray(resized) + out, *_ = convert_to_dst_type(resized.squeeze(0), img) + return out class Rotate(Transform, ThreadUnsafe): @@ -439,6 +469,8 @@ class Rotate(Transform, ThreadUnsafe): the output data type is always ``np.float32``. """ + backend = [TransformBackends.TORCH] + def __init__( self, angle: Union[Sequence[float], float], @@ -446,7 +478,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: DtypeLike = np.float64, + dtype: Union[DtypeLike, torch.dtype] = np.float64, ) -> None: self.angle = angle self.keep_size = keep_size @@ -454,16 +486,16 @@ def __init__( self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.dtype = dtype - self._rotation_matrix: Optional[np.ndarray] = None + self._rotation_matrix: Optional[NdarrayOrTensor] = None def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None, - ) -> np.ndarray: + dtype: Union[DtypeLike, torch.dtype] = None, + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. @@ -486,7 +518,11 @@ def __call__( """ _dtype = dtype or self.dtype or img.dtype - im_shape = np.asarray(img.shape[1:]) # spatial dimensions + + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype) # type: ignore + + im_shape = np.asarray(img_t.shape[1:]) # spatial dimensions input_ndim = len(im_shape) if input_ndim not in (2, 3): raise ValueError(f"Unsupported img dimension: {input_ndim}, available options are [2, 3].") @@ -499,11 +535,14 @@ def __call__( corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape( (len(im_shape), -1) ) - corners = transform[:-1, :-1] @ corners + corners = transform[:-1, :-1] @ corners # type: ignore output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist()) transform = shift @ transform @ shift_1 + transform_t: torch.Tensor + transform_t, *_ = convert_to_dst_type(transform, img_t) # type: ignore + xform = AffineTransform( normalized=False, mode=look_up_option(mode or self.mode, GridSampleMode), @@ -511,15 +550,13 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, reverse_indexing=True, ) - output = xform( - torch.as_tensor(np.ascontiguousarray(img).astype(_dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), - spatial_size=output_shape, - ) + output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).float().squeeze(0) self._rotation_matrix = transform - return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + out: NdarrayOrTensor + out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype) + return out - def get_rotation_matrix(self) -> Optional[np.ndarray]: + def get_rotation_matrix(self) -> Optional[NdarrayOrTensor]: """ Get the most recently applied rotation matrix This is not thread-safe. @@ -542,82 +579,96 @@ class Zoom(Transform): mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate keep_size: Should keep original size (padding/slicing if needed), default is True. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ + backend = [TransformBackends.TORCH] + def __init__( self, zoom: Union[Sequence[float], float], mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - padding_mode: Union[NumpyPadMode, str] = NumpyPadMode.EDGE, + padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: bool = True, - **np_kwargs, + **kwargs, ) -> None: self.zoom = zoom self.mode: InterpolateMode = InterpolateMode(mode) - self.padding_mode: NumpyPadMode = NumpyPadMode(padding_mode) + self.padding_mode = padding_mode self.align_corners = align_corners self.keep_size = keep_size - self.np_kwargs = np_kwargs + self.kwargs = kwargs def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, - padding_mode: Optional[Union[NumpyPadMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, - ): + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} - The mode to pad data after zooming, default to ``self.padding_mode``. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + The mode to pad data after zooming. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate """ + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore + _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim - zoomed = torch.nn.functional.interpolate( # type: ignore + zoomed: NdarrayOrTensor = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, - input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), + input=img_t.unsqueeze(0), scale_factor=list(_zoom), mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - zoomed = zoomed.squeeze(0).detach().cpu().numpy() - if not self.keep_size or np.allclose(img.shape, zoomed.shape): - return zoomed + zoomed = zoomed.squeeze(0) + + if self.keep_size and not np.allclose(img_t.shape, zoomed.shape): - pad_vec = [[0, 0]] * len(img.shape) - slice_vec = [slice(None)] * len(img.shape) - for idx, (od, zd) in enumerate(zip(img.shape, zoomed.shape)): - diff = od - zd - half = abs(diff) // 2 - if diff > 0: # need padding - pad_vec[idx] = [half, diff - half] - elif diff < 0: # need slicing - slice_vec[idx] = slice(half, half + od) + pad_vec = [(0, 0)] * len(img_t.shape) + slice_vec = [slice(None)] * len(img_t.shape) + for idx, (od, zd) in enumerate(zip(img_t.shape, zoomed.shape)): + diff = od - zd + half = abs(diff) // 2 + if diff > 0: # need padding + pad_vec[idx] = (half, diff - half) + elif diff < 0: # need slicing + slice_vec[idx] = slice(half, half + od) - padding_mode = look_up_option(self.padding_mode if padding_mode is None else padding_mode, NumpyPadMode) - zoomed = np.pad(zoomed, pad_vec, mode=padding_mode.value, **self.np_kwargs) # type: ignore - return zoomed[tuple(slice_vec)] + padder = Pad(pad_vec, padding_mode or self.padding_mode) + zoomed = padder(zoomed) + zoomed = zoomed[tuple(slice_vec)] + + out, *_ = convert_to_dst_type(zoomed, dst=img) + return out class Rotate90(Transform): @@ -628,6 +679,8 @@ class Rotate90(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: """ Args: @@ -642,14 +695,15 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - - result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) - return result.astype(img.dtype) + rot90: Callable = torch.rot90 if isinstance(img, torch.Tensor) else np.rot90 # type: ignore + out: NdarrayOrTensor = rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) + out, *_ = convert_data_type(out, dtype=img.dtype) + return out class RandRotate90(RandomizableTransform): @@ -658,6 +712,8 @@ class RandRotate90(RandomizableTransform): in the plane specified by `spatial_axes`. """ + backend = Rotate90.backend + def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, int] = (0, 1)) -> None: """ Args: @@ -674,19 +730,24 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i self._rand_k = 0 def randomize(self, data: Optional[Any] = None) -> None: - self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) + if not self._do_transform: + return None + self._rand_k = self.R.randint(self.max_k) + 1 - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + randomize: whether to execute `randomize()` function first, default to True. """ - self.randomize() + if randomize: + self.randomize() + if not self._do_transform: return img - rotator = Rotate90(self._rand_k, self.spatial_axes) - return rotator(img) + + return Rotate90(self._rand_k, self.spatial_axes)(img) class RandRotate(RandomizableTransform): @@ -717,6 +778,8 @@ class RandRotate(RandomizableTransform): the output data type is always ``np.float32``. """ + backend = Rotate.backend + def __init__( self, range_x: Union[Tuple[float, float], float] = 0.0, @@ -727,7 +790,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: DtypeLike = np.float64, + dtype: Union[DtypeLike, torch.dtype] = np.float64, ) -> None: RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) @@ -752,18 +815,22 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None, - ) -> np.ndarray: + dtype: Union[DtypeLike, torch.dtype] = None, + randomize: bool = True, + get_matrix: bool = False, + ): """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). @@ -778,10 +845,15 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. + randomize: whether to execute `randomize()` function first, default to True. + get_matrix: whether to return the rotated image and rotate matrix together, default to False. """ - self.randomize() + if randomize: + self.randomize() + if not self._do_transform: return img + rotator = Rotate( angle=self.x if img.ndim == 3 else (self.x, self.y, self.z), keep_size=self.keep_size, @@ -790,7 +862,8 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) - return np.array(rotator(img)) + img = rotator(img) + return (img, rotator.get_rotation_matrix()) if get_matrix else img class RandFlip(RandomizableTransform): @@ -810,14 +883,18 @@ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int] RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + randomize: whether to execute `randomize()` function first, default to True. """ - self.randomize(None) + if randomize: + self.randomize(None) + if not self._do_transform: return img + return self.flipper(img) @@ -840,18 +917,23 @@ def __init__(self, prob: float = 0.1) -> None: def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) + if not self._do_transform: + return None self._axis = self.R.randint(data.ndim - 1) - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + randomize: whether to execute `randomize()` function first, default to True. """ - self.randomize(data=img) + if randomize: + self.randomize(data=img) + if not self._do_transform: return img - flipper = Flip(spatial_axis=self._axis) - return flipper(img) + + return Flip(spatial_axis=self._axis)(img) class RandZoom(RandomizableTransform): @@ -873,29 +955,34 @@ class RandZoom(RandomizableTransform): mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate keep_size: Should keep original size (pad if needed), default is True. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ + backend = Zoom.backend + def __init__( self, prob: float = 0.1, min_zoom: Union[Sequence[float], float] = 0.9, max_zoom: Union[Sequence[float], float] = 1.1, mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - padding_mode: Union[NumpyPadMode, str] = NumpyPadMode.EDGE, + padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE, align_corners: Optional[bool] = None, keep_size: bool = True, - **np_kwargs, + **kwargs, ) -> None: RandomizableTransform.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) @@ -903,59 +990,67 @@ def __init__( if len(self.min_zoom) != len(self.max_zoom): raise AssertionError("min_zoom and max_zoom must have same length.") self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) - self.padding_mode: NumpyPadMode = look_up_option(padding_mode, NumpyPadMode) + self.padding_mode = padding_mode self.align_corners = align_corners self.keep_size = keep_size - self.np_kwargs = np_kwargs + self.kwargs = kwargs self._zoom: Sequence[float] = [1.0] - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, img: NdarrayOrTensor) -> None: super().randomize(None) + if not self._do_transform: + return None self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] + if len(self._zoom) == 1: + # to keep the spatial shape ratio, use same random zoom factor for all dims + self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1) + elif len(self._zoom) == 2 and img.ndim > 3: + # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim + self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1]) def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[InterpolateMode, str]] = None, - padding_mode: Optional[Union[NumpyPadMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, - ) -> np.ndarray: + randomize: bool = True, + ) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} - The mode to pad data after zooming, default to ``self.padding_mode``. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + The mode to pad data after zooming. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + randomize: whether to execute `randomize()` function first, default to True. + """ # match the spatial image dim - self.randomize() - _dtype = np.float32 + if randomize: + self.randomize(img=img) + if not self._do_transform: - return img.astype(_dtype) - if len(self._zoom) == 1: - # to keep the spatial shape ratio, use same random zoom factor for all dims - self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1) - elif len(self._zoom) == 2 and img.ndim > 3: - # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim - self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1]) - zoomer = Zoom(self._zoom, keep_size=self.keep_size, **self.np_kwargs) - return np.asarray( - zoomer( - img, - mode=look_up_option(mode or self.mode, InterpolateMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, NumpyPadMode), - align_corners=self.align_corners if align_corners is None else align_corners, - ), - dtype=_dtype, - ) + return img + + return Zoom( + self._zoom, + keep_size=self.keep_size, + mode=look_up_option(mode or self.mode, InterpolateMode), + padding_mode=padding_mode or self.padding_mode, + align_corners=align_corners or self.align_corners, + **self.kwargs, + )(img) class AffineGrid(Transform): @@ -979,14 +1074,18 @@ class AffineGrid(Transform): pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. - as_tensor_output: whether to output tensor instead of numpy array, defaults to True. - device: device to store the output grid data. affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_params: Optional[Union[Sequence[float], float]] = None, @@ -995,24 +1094,23 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, - affine: Optional[Union[np.ndarray, torch.Tensor]] = None, + affine: Optional[NdarrayOrTensor] = None, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params self.translate_params = translate_params self.scale_params = scale_params - - self.as_tensor_output = as_tensor_output self.device = device - self.affine = affine def __call__( - self, - spatial_size: Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: + self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[NdarrayOrTensor] = None + ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ + The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`. + Therefore, either `spatial_size` or `grid` must be provided. + When initialising from `spatial_size`, the backend "torch" will be used. + Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. @@ -1023,36 +1121,36 @@ def __call__( """ if grid is None: if spatial_size is not None: - grid = create_grid(spatial_size) + grid = create_grid(spatial_size, device=self.device, backend="torch") else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - affine: Union[torch.Tensor, np.ndarray] + _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY + _device = grid.device if isinstance(grid, torch.Tensor) else self.device + affine: NdarrayOrTensor if self.affine is None: spatial_dims = len(grid.shape) - 1 - affine = np.eye(spatial_dims + 1) + affine = ( + torch.eye(spatial_dims + 1, device=_device) + if _b == TransformBackends.TORCH + else np.eye(spatial_dims + 1) + ) if self.rotate_params: - affine = affine @ create_rotate(spatial_dims, self.rotate_params) + affine = affine @ create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b) if self.shear_params: - affine = affine @ create_shear(spatial_dims, self.shear_params) + affine = affine @ create_shear(spatial_dims, self.shear_params, device=_device, backend=_b) if self.translate_params: - affine = affine @ create_translate(spatial_dims, self.translate_params) + affine = affine @ create_translate(spatial_dims, self.translate_params, device=_device, backend=_b) if self.scale_params: - affine = affine @ create_scale(spatial_dims, self.scale_params) + affine = affine @ create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) else: affine = self.affine - if isinstance(affine, np.ndarray): - affine = torch.as_tensor(np.ascontiguousarray(affine)) + grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=float) + affine, *_ = convert_to_dst_type(affine, grid) - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - affine = affine.to(self.device) - grid = grid.to(self.device) - grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) - if grid is None or not isinstance(grid, torch.Tensor): - raise ValueError("Unknown grid.") - return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine + grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) + return grid, affine class RandAffineGrid(Randomizable, Transform): @@ -1061,6 +1159,9 @@ class RandAffineGrid(Randomizable, Transform): """ + backend = AffineGrid.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_range: RandRange = None, @@ -1094,8 +1195,6 @@ def __init__( scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select the scale factor to translate for every spatial dims. A value of 1.0 is added to the result. This allows 0 to correspond to no change (i.e., a scaling of 1.0). - as_tensor_output: whether to output tensor instead of numpy array. - defaults to True. device: device to store the output grid data. See also: @@ -1103,6 +1202,10 @@ def __init__( - :py:meth:`monai.transforms.utils.create_shear` - :py:meth:`monai.transforms.utils.create_translate` - :py:meth:`monai.transforms.utils.create_scale` + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ self.rotate_range = ensure_tuple(rotate_range) self.shear_range = ensure_tuple(shear_range) @@ -1114,9 +1217,8 @@ def __init__( self.translate_params: Optional[List[float]] = None self.scale_params: Optional[List[float]] = None - self.as_tensor_output = as_tensor_output self.device = device - self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None + self.affine: Optional[NdarrayOrTensor] = None def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] @@ -1136,10 +1238,8 @@ def randomize(self, data: Optional[Any] = None) -> None: self.scale_params = self._get_rand_param(self.scale_range, 1.0) def __call__( - self, - spatial_size: Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[NdarrayOrTensor] = None + ) -> NdarrayOrTensor: """ Args: spatial_size: output grid size. @@ -1154,13 +1254,13 @@ def __call__( shear_params=self.shear_params, translate_params=self.translate_params, scale_params=self.scale_params, - as_tensor_output=self.as_tensor_output, device=self.device, ) - grid, self.affine = affine_grid(spatial_size, grid) - return grid + _grid: NdarrayOrTensor + _grid, self.affine = affine_grid(spatial_size, grid) + return _grid - def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: + def get_transformation_matrix(self) -> Optional[NdarrayOrTensor]: """Get the most recently applied transformation matrix""" return self.affine @@ -1170,6 +1270,8 @@ class RandDeformGrid(Randomizable, Transform): Generate random deformation grid. """ + backend = [TransformBackends.TORCH] + def __init__( self, spacing: Union[Sequence[float], float], @@ -1198,7 +1300,7 @@ def __init__( self.device = device def randomize(self, grid_size: Sequence[int]) -> None: - self.random_offset = self.R.normal(size=([len(grid_size)] + list(grid_size))).astype(np.float32) + self.random_offset = self.R.normal(size=([len(grid_size)] + list(grid_size))).astype(np.float32, copy=False) self.rand_mag = self.R.uniform(self.magnitude[0], self.magnitude[1]) def __call__(self, spatial_size: Sequence[int]): @@ -1207,20 +1309,25 @@ def __call__(self, spatial_size: Sequence[int]): spatial_size: spatial size of the grid. """ self.spacing = fall_back_tuple(self.spacing, (1.0,) * len(spatial_size)) - control_grid = create_control_grid(spatial_size, self.spacing) + control_grid = create_control_grid(spatial_size, self.spacing, device=self.device, backend="torch") self.randomize(control_grid.shape[1:]) - control_grid[: len(spatial_size)] += self.rand_mag * self.random_offset - if self.as_tensor_output: - control_grid = torch.as_tensor(np.ascontiguousarray(control_grid), device=self.device) + _offset, *_ = convert_to_dst_type(self.rand_mag * self.random_offset, control_grid) + control_grid[: len(spatial_size)] += _offset + if not self.as_tensor_output: + control_grid, *_ = convert_data_type(control_grid, output_type=np.ndarray, dtype=np.float32) return control_grid class Resample(Transform): + + backend = [TransformBackends.TORCH] + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, - as_tensor_output: bool = False, + as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1234,21 +1341,23 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: whether to return a torch tensor. Defaults to False. device: device on which the tensor will be allocated. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) - self.as_tensor_output = as_tensor_output self.device = device def __call__( self, - img: Union[np.ndarray, torch.Tensor], - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + img: NdarrayOrTensor, + grid: Optional[NdarrayOrTensor] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W[, D]). @@ -1260,21 +1369,19 @@ def __call__( Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ - - if not isinstance(img, torch.Tensor): - img = torch.as_tensor(np.ascontiguousarray(img)) if grid is None: - raise AssertionError("Error, grid argument must be supplied as an ndarray or tensor ") - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - img = img.to(self.device) - grid = grid.to(self.device) + raise ValueError("Unknown grid.") + _device = img.device if isinstance(img, torch.Tensor) else self.device + img_t: torch.Tensor + grid_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, device=_device, dtype=torch.float32) # type: ignore + grid_t, *_ = convert_to_dst_type(grid, img_t) # type: ignore if USE_COMPILED: - for i, dim in enumerate(img.shape[1:]): - grid[i] += (dim - 1.0) / 2.0 - grid = grid[:-1] / grid[-1:] - grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) + for i, dim in enumerate(img_t.shape[1:]): + grid_t[i] += (dim - 1.0) / 2.0 + grid_t = grid_t[:-1] / grid_t[-1:] + grid_t = grid_t.permute(list(range(grid_t.ndimension()))[1:] + [0]) _padding_mode = look_up_option( self.padding_mode if padding_mode is None else padding_mode, GridSamplePadMode ).value @@ -1286,29 +1393,29 @@ def __call__( bound = 1 _interp_mode = look_up_option(self.mode if mode is None else mode, GridSampleMode).value out = grid_pull( - img.unsqueeze(0).float(), - grid.unsqueeze(0).float(), + img_t.unsqueeze(0), + grid_t.unsqueeze(0), bound=bound, extrapolate=True, interpolation=1 if _interp_mode == "bilinear" else _interp_mode, )[0] else: - for i, dim in enumerate(img.shape[1:]): - grid[i] = 2.0 * grid[i] / (dim - 1.0) - grid = grid[:-1] / grid[-1:] - index_ordering: List[int] = list(range(img.ndimension() - 2, -1, -1)) - grid = grid[index_ordering] - grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) + for i, dim in enumerate(img_t.shape[1:]): + grid_t[i] = 2.0 * grid_t[i] / (dim - 1.0) + grid_t = grid_t[:-1] / grid_t[-1:] + index_ordering: List[int] = list(range(img_t.ndimension() - 2, -1, -1)) + grid_t = grid_t[index_ordering] + grid_t = grid_t.permute(list(range(grid_t.ndimension()))[1:] + [0]) out = torch.nn.functional.grid_sample( - img.unsqueeze(0).float(), - grid.unsqueeze(0).float(), + img_t.unsqueeze(0), + grid_t.unsqueeze(0), mode=self.mode.value if mode is None else GridSampleMode(mode).value, padding_mode=self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value, align_corners=True, )[0] - if self.as_tensor_output: - return torch.as_tensor(out) - return np.asarray(out.cpu().numpy()) + out_val: NdarrayOrTensor + out_val, *_ = convert_to_dst_type(out, dst=img, dtype=out.dtype) + return out_val class Affine(Transform): @@ -1318,16 +1425,20 @@ class Affine(Transform): """ + backend = list(set(AffineGrid.backend) & set(Resample.backend)) + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_params: Optional[Union[Sequence[float], float]] = None, shear_params: Optional[Union[Sequence[float], float]] = None, translate_params: Optional[Union[Sequence[float], float]] = None, scale_params: Optional[Union[Sequence[float], float]] = None, + affine: Optional[NdarrayOrTensor] = None, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, + as_tensor_output: bool = True, device: Optional[torch.device] = None, image_only: bool = False, ) -> None: @@ -1351,6 +1462,9 @@ def __init__( pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. + affine: If applied, ignore the params (`rotate_params`, etc.) and use the + supplied matrix. Should be square with each side = num of image spatial + dimensions + 1. spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. @@ -1363,32 +1477,34 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. image_only: if True return only the image volume, otherwise return (image, affine). + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ self.affine_grid = AffineGrid( rotate_params=rotate_params, shear_params=shear_params, translate_params=translate_params, scale_params=scale_params, - as_tensor_output=True, + affine=affine, device=device, ) self.image_only = image_only - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ): + ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1418,6 +1534,9 @@ class RandAffine(RandomizableTransform): """ + backend = Affine.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, prob: float = 0.1, @@ -1473,13 +1592,15 @@ def __init__( cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ RandomizableTransform.__init__(self, prob) @@ -1488,10 +1609,9 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.cache_grid = cache_grid @@ -1519,7 +1639,7 @@ def _init_identity_cache(self): f"'spatial_size={self.spatial_size}', please specify 'spatial_size'." ) return None - return torch.tensor(create_grid(spatial_size=_sp_size)).to(self.rand_affine_grid.device) + return create_grid(spatial_size=_sp_size, device=self.rand_affine_grid.device, backend="torch") def get_identity_grid(self, spatial_size: Sequence[int]): """ @@ -1533,7 +1653,11 @@ def get_identity_grid(self, spatial_size: Sequence[int]): spatial_size, [2] * ndim ): raise RuntimeError(f"spatial_size should not be dynamic, got {spatial_size}.") - return create_grid(spatial_size=spatial_size) if self._cached_grid is None else self._cached_grid + return ( + create_grid(spatial_size=spatial_size, device=self.rand_affine_grid.device, backend="torch") + if self._cached_grid is None + else self._cached_grid + ) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -1544,15 +1668,18 @@ def set_random_state( def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) + if not self._do_transform: + return None self.rand_affine_grid.randomize() def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + randomize: bool = True, + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1567,21 +1694,25 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + randomize: whether to execute `randomize()` function first, default to True. + """ - self.randomize() + if randomize: + self.randomize() + # if not doing transform and spatial size doesn't change, nothing to do - # except convert to float and convert numpy/torch + # except convert to float and device sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) do_resampling = self._do_transform or (sp_size != ensure_tuple(img.shape[1:])) if not do_resampling: - img = img.float() if isinstance(img, torch.Tensor) else img.astype("float32") - return torch.Tensor(img) if self.resampler.as_tensor_output else np.array(img) + img, *_ = convert_data_type(img, dtype=torch.float32, device=self.resampler.device) grid = self.get_identity_grid(sp_size) if self._do_transform: grid = self.rand_affine_grid(grid=grid) - return self.resampler( + out: NdarrayOrTensor = self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) + return out class Rand2DElastic(RandomizableTransform): @@ -1591,6 +1722,9 @@ class Rand2DElastic(RandomizableTransform): """ + backend = Resample.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, spacing: Union[Tuple[float, float], float], @@ -1645,13 +1779,15 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ RandomizableTransform.__init__(self, prob) self.deform_grid = RandDeformGrid( @@ -1662,11 +1798,11 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) + self.device = device self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) @@ -1681,16 +1817,19 @@ def set_random_state( def randomize(self, spatial_size: Sequence[int]) -> None: super().randomize(None) + if not self._do_transform: + return None self.deform_grid.randomize(spatial_size) self.rand_affine_grid.randomize() def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Tuple[int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + randomize: bool = True, + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W), @@ -1703,23 +1842,30 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + randomize: whether to execute `randomize()` function first, default to True. """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - self.randomize(spatial_size=sp_size) + if randomize: + self.randomize(spatial_size=sp_size) + if self._do_transform: grid = self.deform_grid(spatial_size=sp_size) grid = self.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, - input=torch.as_tensor(grid).unsqueeze(0), + input=grid.unsqueeze(0), scale_factor=list(ensure_tuple(self.deform_grid.spacing)), mode=InterpolateMode.BICUBIC.value, align_corners=False, ) grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: - grid = create_grid(spatial_size=sp_size) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) + _device = img.device if isinstance(img, torch.Tensor) else self.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") + out: NdarrayOrTensor = self.resampler( + img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + ) + return out class Rand3DElastic(RandomizableTransform): @@ -1729,6 +1875,9 @@ class Rand3DElastic(RandomizableTransform): """ + backend = Resample.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, sigma_range: Tuple[float, float], @@ -1786,17 +1935,25 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ RandomizableTransform.__init__(self, prob) - self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.rand_affine_grid = RandAffineGrid( + rotate_range=rotate_range, + shear_range=shear_range, + translate_range=translate_range, + scale_range=scale_range, + device=device, + ) + self.resampler = Resample(device=device) self.sigma_range = sigma_range self.magnitude_range = magnitude_range @@ -1818,19 +1975,21 @@ def set_random_state( def randomize(self, grid_size: Sequence[int]) -> None: super().randomize(None) - if self._do_transform: - self.rand_offset = self.R.uniform(-1.0, 1.0, [3] + list(grid_size)).astype(np.float32) + if not self._do_transform: + return None + self.rand_offset = self.R.uniform(-1.0, 1.0, [3] + list(grid_size)).astype(np.float32, copy=False) self.magnitude = self.R.uniform(self.magnitude_range[0], self.magnitude_range[1]) self.sigma = self.R.uniform(self.sigma_range[0], self.sigma_range[1]) self.rand_affine_grid.randomize() def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + randomize: bool = True, + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W, D), @@ -1843,59 +2002,185 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + randomize: whether to execute `randomize()` function first, default to True. """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - self.randomize(grid_size=sp_size) - grid = create_grid(spatial_size=sp_size) + if randomize: + self.randomize(grid_size=sp_size) + + _device = img.device if isinstance(img, torch.Tensor) else self.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") if self._do_transform: if self.rand_offset is None: - raise AssertionError - grid = torch.as_tensor(np.ascontiguousarray(grid), device=self.device) - gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=self.device) - offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) + raise RuntimeError("rand_offset is not initialized.") + gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=_device) + offset = torch.as_tensor(self.rand_offset, device=_device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude grid = self.rand_affine_grid(grid=grid) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) + out: NdarrayOrTensor = self.resampler( + img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + ) + return out -class AddCoordinateChannels(Transform): - """ - Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling, - to allow feeding of the patch's location into the network. +class GridDistortion(Transform): - This can be seen as a input-only version of CoordConv: + backend = [TransformBackends.TORCH] - Liu, R. et al. An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution, NeurIPS 2018. - """ + def __init__( + self, + num_cells: Union[Tuple[int], int], + distort_steps: Sequence[Sequence[float]], + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, + ) -> None: + """ + Grid distortion transform. Refer to: + https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py + + Args: + num_cells: number of grid cells on each dimension. + distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the + corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`. + Each value in the tuple represents the distort step of the related cell. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + device: device on which the tensor will be allocated. + + """ + self.resampler = Resample(mode=mode, padding_mode=padding_mode, device=device) + self.num_cells = num_cells + self.distort_steps = distort_steps + self.device = device + + def __call__( + self, + img: NdarrayOrTensor, + distort_steps: Optional[Sequence[Sequence]] = None, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + ) -> NdarrayOrTensor: + """ + Args: + img: shape must be (num_channels, H, W[, D]). + distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the + corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`. + Each value in the tuple represents the distort step of the related cell. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + + """ + distort_steps = self.distort_steps if distort_steps is None else distort_steps + if len(img.shape) != len(distort_steps) + 1: + raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`") + + all_ranges = [] + num_cells = ensure_tuple_rep(self.num_cells, len(img.shape) - 1) + for dim_idx, dim_size in enumerate(img.shape[1:]): + dim_distort_steps = distort_steps[dim_idx] + ranges = torch.zeros(dim_size, dtype=torch.float32) + cell_size = dim_size // num_cells[dim_idx] + prev = 0 + for idx in range(num_cells[dim_idx] + 1): + start = int(idx * cell_size) + end = start + cell_size + if end > dim_size: + end = dim_size + cur = dim_size + else: + cur = prev + cell_size * dim_distort_steps[idx] + ranges[start:end] = torch.linspace(prev, cur, end - start) + prev = cur + ranges = ranges - (dim_size - 1.0) / 2.0 + all_ranges.append(ranges) + + coords = torch.meshgrid(*all_ranges) + grid = torch.stack([*coords, torch.ones_like(coords[0])]) + + return self.resampler(img, grid=grid, mode=mode, padding_mode=padding_mode) # type: ignore + + +class RandGridDistortion(RandomizableTransform): + + backend = [TransformBackends.TORCH] def __init__( self, - spatial_channels: Sequence[int], + num_cells: Union[Tuple[int], int] = 5, + prob: float = 0.1, + distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, ) -> None: """ + Random grid distortion transform. Refer to: + https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py + Args: - spatial_channels: the spatial dimensions that are to have their coordinates encoded in a channel and - appended to the input. E.g., `(1,2,3)` will append three channels to the input, encoding the - coordinates of the input's three spatial dimensions (0 is reserved for the channel dimension). + num_cells: number of grid cells on each dimension. + prob: probability of returning a randomized grid distortion transform. Defaults to 0.1. + distort_limit: range to randomly distort. + If single number, distort_limit is picked from (-distort_limit, distort_limit). + Defaults to (-0.03, 0.03). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + device: device on which the tensor will be allocated. + """ - self.spatial_channels = spatial_channels + RandomizableTransform.__init__(self, prob) + self.num_cells = num_cells + if isinstance(distort_limit, (int, float)): + self.distort_limit = (min(-distort_limit, distort_limit), max(-distort_limit, distort_limit)) + else: + self.distort_limit = (min(distort_limit), max(distort_limit)) + self.distort_steps: Sequence[Sequence[float]] = ((1.0,),) + self.grid_distortion = GridDistortion( + num_cells=num_cells, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode, device=device + ) + + def randomize(self, spatial_shape: Sequence[int]) -> None: + super().randomize(None) + if not self._do_transform: + return + self.distort_steps = tuple( + tuple(1.0 + self.R.uniform(low=self.distort_limit[0], high=self.distort_limit[1], size=n_cells + 1)) + for n_cells in ensure_tuple_rep(self.num_cells, len(spatial_shape)) + ) - def __call__(self, img: Union[np.ndarray, torch.Tensor]): + def __call__( + self, + img: NdarrayOrTensor, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + randomize: bool = True, + ) -> NdarrayOrTensor: """ Args: - img: data to be transformed, assuming `img` is channel first. + img: shape must be (num_channels, H, W[, D]). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + randomize: whether to shuffle the random factors using `randomize()`, default to True. """ - if max(self.spatial_channels) > img.ndim - 1: - raise ValueError( - f"input has {img.ndim-1} spatial dimensions, cannot add AddCoordinateChannels channel for " - f"dim {max(self.spatial_channels)}." - ) - if 0 in self.spatial_channels: - raise ValueError("cannot add AddCoordinateChannels channel for dimension 0, as 0 is channel dim.") - - spatial_dims = img.shape[1:] - coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_dims), indexing="ij")) - # only keep required dimensions. need to subtract 1 since im will be 0-based - # but user input is 1-based (because channel dim is 0) - coord_channels = coord_channels[[s - 1 for s in self.spatial_channels]] - return np.concatenate((img, coord_channels), axis=0) + if randomize: + self.randomize(img.shape[1:]) + if not self._do_transform: + return img + return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index b0558a6556..d1cf70f92c 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,7 @@ from copy import deepcopy from enum import Enum -from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -29,14 +29,19 @@ from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import ( - AddCoordinateChannels, Affine, AffineGrid, Flip, + GridDistortion, Orientation, Rand2DElastic, Rand3DElastic, RandAffine, + RandAxisFlip, + RandFlip, + RandGridDistortion, + RandRotate, + RandZoom, Resize, Rotate, Rotate90, @@ -50,12 +55,15 @@ GridSamplePadMode, InterpolateMode, NumpyPadMode, + PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, ) -from monai.utils.enums import InverseKeys +from monai.utils.deprecate_utils import deprecated_arg +from monai.utils.enums import TraceKeys from monai.utils.module import optional_import +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type nib, _ = optional_import("nibabel") @@ -71,6 +79,8 @@ "Rand3DElasticd", "Flipd", "RandFlipd", + "GridDistortiond", + "RandGridDistortiond", "RandAxisFlipd", "Rotated", "RandRotated", @@ -98,6 +108,10 @@ "FlipDict", "RandFlipD", "RandFlipDict", + "GridDistortionD", + "GridDistortionDict", + "RandGridDistortionD", + "RandGridDistortionDict", "RandAxisFlipD", "RandAxisFlipDict", "RotateD", @@ -108,14 +122,12 @@ "ZoomDict", "RandZoomD", "RandZoomDict", - "AddCoordinateChannelsD", - "AddCoordinateChannelsDict", ] GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] InterpolateModeSequence = Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] -NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] +PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str] class Spacingd(MapTransform, InvertibleTransform): @@ -132,6 +144,8 @@ class Spacingd(MapTransform, InvertibleTransform): :py:class:`monai.transforms.Spacing` """ + backend = Spacing.backend + def __init__( self, keys: KeysCollection, @@ -208,8 +222,8 @@ def __init__( self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) def __call__( - self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] - ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: + self, data: Mapping[Union[Hashable, str], Dict[str, NdarrayOrTensor]] + ) -> Dict[Hashable, NdarrayOrTensor]: d: Dict = dict(data) for key, mode, padding_mode, align_corners, dtype, meta_key, meta_key_postfix in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.meta_keys, self.meta_key_postfix @@ -223,7 +237,7 @@ def __call__( # using affine fetched from d[affine_key] original_spatial_shape = d[key].shape[1:] d[key], old_affine, new_affine = self.spacing_transform( - data_array=np.asarray(d[key]), + data_array=d[key], affine=meta_data["affine"], mode=mode, padding_mode=padding_mode, @@ -238,7 +252,7 @@ def __call__( "old_affine": old_affine, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, orig_size=original_spatial_shape, ) @@ -246,7 +260,7 @@ def __call__( meta_data["affine"] = new_affine return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -256,25 +270,25 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar + "Please raise a github issue if you need this feature" ) # Create inverse transform - meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_key"]] - old_affine = np.array(transform[InverseKeys.EXTRA_INFO]["old_affine"]) - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] - orig_size = transform[InverseKeys.ORIG_SIZE] + meta_data = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] + old_affine = np.array(transform[TraceKeys.EXTRA_INFO]["old_affine"]) + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + orig_size = transform[TraceKeys.ORIG_SIZE] orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse d[key], _, new_affine = inverse_transform( - data_array=np.asarray(d[key]), - affine=meta_data["affine"], + data_array=d[key], + affine=meta_data["affine"], # type: ignore mode=mode, padding_mode=padding_mode, - align_corners=False if align_corners == "none" else align_corners, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, dtype=dtype, output_spatial_shape=orig_size, ) - meta_data["affine"] = new_affine + meta_data["affine"] = new_affine # type: ignore # Remove the applied transform self.pop_transform(d, key) @@ -292,6 +306,8 @@ class Orientationd(MapTransform, InvertibleTransform): to the `affine` field of metadata which is formed by ``key_{meta_key_postfix}``. """ + backend = Orientation.backend + def __init__( self, keys: KeysCollection, @@ -341,8 +357,8 @@ def __init__( self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) def __call__( - self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] - ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: + self, data: Mapping[Union[Hashable, str], Dict[str, NdarrayOrTensor]] + ) -> Dict[Hashable, NdarrayOrTensor]: d: Dict = dict(data) for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" @@ -355,18 +371,16 @@ def __call__( d[meta_key]["affine"] = new_affine return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_key"]] - orig_affine = transform[InverseKeys.EXTRA_INFO]["old_affine"] + meta_data: Dict = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] # type: ignore + orig_affine = transform[TraceKeys.EXTRA_INFO]["old_affine"] orig_axcodes = nib.orientations.aff2axcodes(orig_affine) inverse_transform = Orientation( - axcodes=orig_axcodes, - as_closest_canonical=False, - labels=self.ornt_transform.labels, + axcodes=orig_axcodes, as_closest_canonical=False, labels=self.ornt_transform.labels ) # Apply inverse d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) @@ -382,6 +396,8 @@ class Rotate90d(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ + backend = Rotate90.backend + def __init__( self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1), allow_missing_keys: bool = False ) -> None: @@ -395,14 +411,14 @@ def __init__( super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.rotator(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) @@ -411,9 +427,6 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar num_times_rotated = self.rotator.k num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -429,6 +442,8 @@ class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): in the plane specified by `spatial_axes`. """ + backend = Rotate90.backend + def __init__( self, keys: KeysCollection, @@ -461,10 +476,12 @@ def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: self.randomize() d = dict(data) + # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need + # to be compatible with the random status of some previous integration tests rotator = Rotate90(self._rand_k, self.spatial_axes) for key in self.key_iterator(d): if self._do_transform: @@ -472,19 +489,16 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Create inverse transform - num_times_rotated = transform[InverseKeys.EXTRA_INFO]["rand_k"] + num_times_rotated = transform[TraceKeys.EXTRA_INFO]["rand_k"] num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -520,6 +534,8 @@ class Resized(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = Resize.backend + def __init__( self, keys: KeysCollection, @@ -534,7 +550,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): self.push_transform( @@ -542,24 +558,24 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda key, extra_info={ "mode": mode.value if isinstance(mode, Enum) else mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InverseKeys.ORIG_SIZE] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + orig_size = transform[TraceKeys.ORIG_SIZE] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] # Create inverse transform inverse_transform = Resize( spatial_size=orig_size, mode=mode, - align_corners=None if align_corners == "none" else align_corners, + align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Apply inverse transform d[key] = inverse_transform(d[key]) @@ -574,6 +590,9 @@ class Affined(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Affine`. """ + backend = Affine.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -581,10 +600,11 @@ def __init__( shear_params: Optional[Union[Sequence[float], float]] = None, translate_params: Optional[Union[Sequence[float], float]] = None, scale_params: Optional[Union[Sequence[float], float]] = None, + affine: Optional[NdarrayOrTensor] = None, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, + as_tensor_output: bool = True, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -607,6 +627,9 @@ def __init__( pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. + affine: if applied, ignore the params (`rotate_params`, etc.) and use the + supplied matrix. Should be square with each side = num of image spatial + dimensions + 1. spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. @@ -621,14 +644,16 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`monai.transforms.compose.MapTransform` - :py:class:`RandAffineGrid` for the random affine parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ MapTransform.__init__(self, keys, allow_missing_keys) self.affine = Affine( @@ -636,16 +661,14 @@ def __init__( shear_params=shear_params, translate_params=translate_params, scale_params=scale_params, + affine=affine, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] @@ -662,26 +685,23 @@ def __call__( ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InverseKeys.ORIG_SIZE] + orig_size = transform[TraceKeys.ORIG_SIZE] # Create inverse transform - fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) grid, _ = affine_grid(orig_size) # type: ignore # Apply inverse transform - out = self.affine.resampler(d[key], grid, mode, padding_mode) - - # Convert to numpy - d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + d[key] = self.affine.resampler(d[key], grid, mode, padding_mode) # Remove the applied transform self.pop_transform(d, key) @@ -694,6 +714,9 @@ class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. """ + backend = RandAffine.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -754,14 +777,16 @@ def __init__( cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`monai.transforms.compose.MapTransform` - :py:class:`RandAffineGrid` for the random affine parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) @@ -773,7 +798,6 @@ def __init__( scale_range=scale_range, spatial_size=spatial_size, cache_grid=cache_grid, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -786,22 +810,23 @@ def set_random_state( super().set_random_state(seed, state) return self - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self.rand_affine.randomize() - - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize() + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d - sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) - # change image size or do random transform - do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:])) + self.randomize(None) + # all the keys share the same random Affine factor + self.rand_affine.randomize() - # to be consistent with the self._do_transform case (dtype and device) - affine = torch.as_tensor(np.eye(len(sp_size) + 1), device=self.rand_affine.rand_affine_grid.device) + device = self.rand_affine.resampler.device + spatial_size = d[first_key].shape[1:] # type: ignore + sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) + # change image size or do random transform + do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size)) + affine: torch.Tensor = torch.eye(len(sp_size) + 1, dtype=torch.float64, device=device) + # converting affine to tensor because the resampler currently only support torch backend grid = None if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) @@ -822,39 +847,28 @@ def __call__( # do the transform if do_resampling: d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) - # if not doing transform and and spatial size is unchanged, only need to do numpy/torch conversion - else: - if self.rand_affine.resampler.as_tensor_output and not isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]) - elif not self.rand_affine.resampler.as_tensor_output and isinstance(d[key], torch.Tensor): - d[key] = d[key].detach().cpu().numpy() # type: ignore[union-attr] return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # if transform was not performed and spatial size is None, nothing to do. - if not transform[InverseKeys.DO_TRANSFORM] and self.rand_affine.spatial_size is None: - out: Union[np.ndarray, torch.Tensor] = d[key] - else: - orig_size = transform[InverseKeys.ORIG_SIZE] + if transform[TraceKeys.DO_TRANSFORM] or self.rand_affine.spatial_size is not None: + orig_size = transform[TraceKeys.ORIG_SIZE] # Create inverse transform - fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) grid, _ = affine_grid(orig_size) # type: ignore # Apply inverse transform - out = self.rand_affine.resampler(d[key], grid, mode, padding_mode) - - # Convert to numpy - d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + d[key] = self.rand_affine.resampler(d[key], grid, mode, padding_mode) # Remove the applied transform self.pop_transform(d, key) @@ -867,6 +881,9 @@ class Rand2DElasticd(RandomizableTransform, MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. """ + backend = Rand2DElastic.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -927,14 +944,16 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) @@ -947,7 +966,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -960,17 +978,17 @@ def set_random_state( super().set_random_state(seed, state) return self - def randomize(self, spatial_size: Sequence[int]) -> None: - super().randomize(None) - self.rand_2d_elastic.randomize(spatial_size) - - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.randomize(None) - sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) - self.randomize(spatial_size=sp_size) + sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first_key].shape[1:]) # type: ignore + # all the keys share the same random elastic factor + self.rand_2d_elastic.randomize(sp_size) if self._do_transform: grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) @@ -984,7 +1002,8 @@ def __call__( ) grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: - grid = create_grid(spatial_size=sp_size) + _device = self.rand_2d_elastic.deform_grid.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) @@ -996,6 +1015,9 @@ class Rand3DElasticd(RandomizableTransform, MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`. """ + backend = Rand3DElastic.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -1058,14 +1080,16 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. + + .. deprecated:: 0.6.0 + ``as_tensor_output`` is deprecated. + """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) @@ -1078,7 +1102,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -1091,23 +1114,24 @@ def set_random_state( super().set_random_state(seed, state) return self - def randomize(self, grid_size: Sequence[int]) -> None: - super().randomize(None) - self.rand_3d_elastic.randomize(grid_size) - - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.randomize(None) - self.randomize(grid_size=sp_size) - grid = create_grid(spatial_size=sp_size) + sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first_key].shape[1:]) # type: ignore + # all the keys share the same random elastic factor + self.rand_3d_elastic.randomize(sp_size) + + _device = self.rand_3d_elastic.device + grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") if self._do_transform: device = self.rand_3d_elastic.device - grid = torch.tensor(grid).to(device) gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) - offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) + offset = torch.as_tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) @@ -1173,7 +1197,7 @@ class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ - backend = Flip.backend + backend = RandFlip.backend def __init__( self, @@ -1184,16 +1208,22 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.spatial_axis = spatial_axis + self.flipper = RandFlip(prob=1.0, spatial_axis=spatial_axis) - self.flipper = Flip(spatial_axis=spatial_axis) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandFlipd": + super().set_random_state(seed, state) + self.flipper.set_random_state(seed, state) + return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - self.randomize(None) d = dict(data) + self.randomize(None) + for key in self.key_iterator(d): if self._do_transform: - d[key] = self.flipper(d[key]) + d[key] = self.flipper(d[key], randomize=False) self.push_transform(d, key) return d @@ -1202,9 +1232,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Inverse is same as forward - d[key] = self.flipper(d[key]) + d[key] = self.flipper(d[key], randomize=False) # Remove the applied transform self.pop_transform(d, key) return d @@ -1224,26 +1254,34 @@ class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ - backend = Flip.backend + backend = RandAxisFlip.backend def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self._axis: Optional[int] = None + self.flipper = RandAxisFlip(prob=1.0) - def randomize(self, data: NdarrayOrTensor) -> None: - super().randomize(None) - self._axis = self.R.randint(data.ndim - 1) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandAxisFlipd": + super().set_random_state(seed, state) + self.flipper.set_random_state(seed, state) + return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - self.randomize(data=data[self.keys[0]]) - flipper = Flip(spatial_axis=self._axis) - d = dict(data) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.randomize(None) + + # all the keys share the same random selected axis + self.flipper.randomize(d[first_key]) # type: ignore for key in self.key_iterator(d): if self._do_transform: - d[key] = flipper(d[key]) - self.push_transform(d, key, extra_info={"axis": self._axis}) + d[key] = self.flipper(d[key], randomize=False) + self.push_transform(d, key, extra_info={"axis": self.flipper._axis}) return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: @@ -1251,8 +1289,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: - flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO]["axis"]) + if transform[TraceKeys.DO_TRANSFORM]: + flipper = Flip(spatial_axis=transform[TraceKeys.EXTRA_INFO]["axis"]) # Inverse is same as forward d[key] = flipper(d[key]) # Remove the applied transform @@ -1288,6 +1326,8 @@ class Rotated(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = Rotate.backend + def __init__( self, keys: KeysCollection, @@ -1296,7 +1336,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1307,18 +1347,14 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): orig_size = d[key].shape[1:] d[key] = self.rotator( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - dtype=dtype, + d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype ) rot_mat = self.rotator.get_rotation_matrix() self.push_transform( @@ -1329,35 +1365,37 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda "rot_mat": rot_mat, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) # Create inverse transform - fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + fwd_rot_mat = transform[TraceKeys.EXTRA_INFO]["rot_mat"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( normalized=False, mode=mode, padding_mode=padding_mode, - align_corners=False if align_corners == "none" else align_corners, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, reverse_indexing=True, ) - output = xform( - torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), - spatial_size=transform[InverseKeys.ORIG_SIZE], - ) - d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + img_t: torch.Tensor + img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) # type: ignore + transform_t: torch.Tensor + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore + + out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[TraceKeys.ORIG_SIZE]).squeeze(0) + out, *_ = convert_to_dst_type(out, dst=d[key], dtype=out.dtype) + d[key] = out # Remove the applied transform self.pop_transform(d, key) @@ -1399,6 +1437,8 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = RandRotate.backend + def __init__( self, keys: KeysCollection, @@ -1410,99 +1450,86 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.range_x = ensure_tuple(range_x) - if len(self.range_x) == 1: - self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) - self.range_y = ensure_tuple(range_y) - if len(self.range_y) == 1: - self.range_y = tuple(sorted([-self.range_y[0], self.range_y[0]])) - self.range_z = ensure_tuple(range_z) - if len(self.range_z) == 1: - self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) - - self.keep_size = keep_size + self.rand_rotate = RandRotate(range_x=range_x, range_y=range_y, range_z=range_z, prob=1.0, keep_size=keep_size) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self.x = 0.0 - self.y = 0.0 - self.z = 0.0 - - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1]) - self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) - self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandRotated": + super().set_random_state(seed, state) + self.rand_rotate.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - self.randomize() + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) - rotator = Rotate( - angle=angle, - keep_size=self.keep_size, - ) + self.randomize(None) + + # all the keys share the same random rotate angle + self.rand_rotate.randomize() for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): - orig_size = d[key].shape[1:] if self._do_transform: - d[key] = rotator( + d[key], rot_mat = self.rand_rotate( d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + randomize=False, + get_matrix=True, ) - rot_mat = rotator.get_rotation_matrix() else: rot_mat = np.eye(d[key].ndim) self.push_transform( d, key, - orig_size=orig_size, + orig_size=d[key].shape[1:], extra_info={ "rot_mat": rot_mat, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Create inverse transform - fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + fwd_rot_mat = transform[TraceKeys.EXTRA_INFO]["rot_mat"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( normalized=False, mode=mode, padding_mode=padding_mode, - align_corners=False if align_corners == "none" else align_corners, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, reverse_indexing=True, ) - output = xform( - torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), - spatial_size=transform[InverseKeys.ORIG_SIZE], - ) - d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + img_t: torch.Tensor + img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) # type: ignore + transform_t: torch.Tensor + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore + output: torch.Tensor + out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[TraceKeys.ORIG_SIZE]).squeeze(0) + out, *_ = convert_to_dst_type(out, dst=d[key], dtype=out.dtype) + d[key] = out # Remove the applied transform self.pop_transform(d, key) @@ -1522,39 +1549,44 @@ class Zoomd(MapTransform, InvertibleTransform): The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of string, each element corresponds to a key in ``keys``. - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. - more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. """ + backend = Zoom.backend + def __init__( self, keys: KeysCollection, zoom: Union[Sequence[float], float], mode: InterpolateModeSequence = InterpolateMode.AREA, - padding_mode: NumpyPadModeSequence = NumpyPadMode.EDGE, + padding_mode: PadModeSequence = NumpyPadMode.EDGE, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **np_kwargs) + self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners @@ -1565,36 +1597,31 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda extra_info={ "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) - d[key] = self.zoomer( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - ) + d[key] = self.zoomer(d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform zoom = np.array(self.zoomer.zoom) inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.zoomer.keep_size) - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] # Apply inverse d[key] = inverse_transform( d[key], mode=mode, padding_mode=padding_mode, - align_corners=None if align_corners == "none" else align_corners, + align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) + d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore # Remove the applied transform self.pop_transform(d, key) @@ -1622,21 +1649,26 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): The interpolation mode. Defaults to ``"area"``. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of string, each element corresponds to a key in ``keys``. - padding_mode: {``"constant"``, ``"edge``", ``"linear_ramp``", ``"maximum``", ``"mean``", `"median``", - ``"minimum``", `"reflect``", ``"symmetric``", ``"wrap``", ``"empty``", ``"``"} + padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. The mode to pad data after zooming. - See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. - np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ + backend = RandZoom.backend + def __init__( self, keys: KeysCollection, @@ -1644,118 +1676,196 @@ def __init__( min_zoom: Union[Sequence[float], float] = 0.9, max_zoom: Union[Sequence[float], float] = 1.1, mode: InterpolateModeSequence = InterpolateMode.AREA, - padding_mode: NumpyPadModeSequence = NumpyPadMode.EDGE, + padding_mode: PadModeSequence = NumpyPadMode.EDGE, align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, allow_missing_keys: bool = False, - **np_kwargs, + **kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.min_zoom = ensure_tuple(min_zoom) - self.max_zoom = ensure_tuple(max_zoom) - if len(self.min_zoom) != len(self.max_zoom): - raise AssertionError("min_zoom and max_zoom must have same length.") - + self.rand_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, **kwargs) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) - self.keep_size = keep_size - self.np_kwargs = np_kwargs - - self._zoom: Sequence[float] = [1.0] - def randomize(self, data: Optional[Any] = None) -> None: - super().randomize(None) - self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandZoomd": + super().set_random_state(seed, state) + self.rand_zoom.set_random_state(seed, state) + return self - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - # match the spatial dim of first item - self.randomize() + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d - img_dims = data[self.keys[0]].ndim - if len(self._zoom) == 1: - # to keep the spatial shape ratio, use same random zoom factor for all dims - self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 1) - elif len(self._zoom) == 2 and img_dims > 3: - # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim - self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 2) + ensure_tuple(self._zoom[-1]) - zoomer = Zoom(self._zoom, keep_size=self.keep_size, **self.np_kwargs) + self.randomize(None) + + # all the keys share the same random zoom factor + self.rand_zoom.randomize(d[first_key]) # type: ignore for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): + if self._do_transform: + d[key] = self.rand_zoom( + d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, randomize=False + ) self.push_transform( d, key, extra_info={ - "zoom": self._zoom, + "zoom": self.rand_zoom._zoom, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else "none", + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) - if self._do_transform: - d[key] = zoomer( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Create inverse transform - zoom = np.array(transform[InverseKeys.EXTRA_INFO]["zoom"]) - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] - inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.keep_size) + zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"]) + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.rand_zoom.keep_size) # Apply inverse d[key] = inverse_transform( d[key], mode=mode, padding_mode=padding_mode, - align_corners=None if align_corners == "none" else align_corners, + align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) + d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore # Remove the applied transform self.pop_transform(d, key) return d -class AddCoordinateChannelsd(MapTransform): +class GridDistortiond(MapTransform): """ - Dictionary-based wrapper of :py:class:`monai.transforms.AddCoordinateChannels`. + Dictionary-based wrapper of :py:class:`monai.transforms.GridDistortion`. """ - def __init__(self, keys: KeysCollection, spatial_channels: Sequence[int], allow_missing_keys: bool = False) -> None: + backend = GridDistortion.backend + + def __init__( + self, + keys: KeysCollection, + num_cells: Union[Tuple[int], int], + distort_steps: List[Tuple], + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` + num_cells: number of grid cells on each dimension. + distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the + corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`. + Each value in the tuple represents the distort step of the related cell. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"reflection"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. - spatial_channels: the spatial dimensions that are to have their coordinates encoded in a channel and - appended to the input. E.g., `(1,2,3)` will append three channels to the input, encoding the - coordinates of the input's three spatial dimensions. It is assumed dimension 0 is the channel. """ super().__init__(keys, allow_missing_keys) - self.add_coordinate_channels = AddCoordinateChannels(spatial_channels) + self.grid_distortion = GridDistortion(num_cells=num_cells, distort_steps=distort_steps, device=device) + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key in self.key_iterator(d): - d[key] = self.add_coordinate_channels(d[key]) + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + d[key] = self.grid_distortion(d[key], mode=mode, padding_mode=padding_mode) + return d + + +class RandGridDistortiond(RandomizableTransform, MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandGridDistortion`. + """ + + backend = RandGridDistortion.backend + + def __init__( + self, + keys: KeysCollection, + num_cells: Union[Tuple[int], int] = 5, + prob: float = 0.1, + distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + num_cells: number of grid cells on each dimension. + prob: probability of returning a randomized grid distortion transform. Defaults to 0.1. + distort_limit: range to randomly distort. + If single number, distort_limit is picked from (-distort_limit, distort_limit). + Defaults to (-0.03, 0.03). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"reflection"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + device: device on which the tensor will be allocated. + allow_missing_keys: don't raise exception if key is missing. + + """ + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.rand_grid_distortion = RandGridDistortion( + num_cells=num_cells, prob=1.0, distort_limit=distort_limit, device=device + ) + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGridDistortiond": + super().set_random_state(seed, state) + self.rand_grid_distortion.set_random_state(seed, state) + return self + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if not self._do_transform: + return d + + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return d + + self.rand_grid_distortion.randomize(d[first_key].shape[1:]) # type: ignore + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False) return d @@ -1770,9 +1880,10 @@ def __call__( Rand3DElasticD = Rand3DElasticDict = Rand3DElasticd FlipD = FlipDict = Flipd RandFlipD = RandFlipDict = RandFlipd +GridDistortionD = GridDistortionDict = GridDistortiond +RandGridDistortionD = RandGridDistortionDict = RandGridDistortiond RandAxisFlipD = RandAxisFlipDict = RandAxisFlipd RotateD = RotateDict = Rotated RandRotateD = RandRotateDict = RandRotated ZoomD = ZoomDict = Zoomd RandZoomD = RandZoomDict = RandZoomd -AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index ef49bc706c..430e659c95 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,17 +21,10 @@ from monai import transforms from monai.config import KeysCollection -from monai.utils import MAX_SEED, ensure_tuple +from monai.utils import MAX_SEED, ensure_tuple, first from monai.utils.enums import TransformBackends -__all__ = [ - "ThreadUnsafe", - "apply_transform", - "Randomizable", - "RandomizableTransform", - "Transform", - "MapTransform", -] +__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] ReturnType = TypeVar("ReturnType") @@ -47,9 +40,9 @@ def _apply_transform( Otherwise `parameters` is considered as single argument to `transform`. Args: - transform (Callable[..., ReturnType]): a callable to be used to transform `data`. - parameters (Any): parameters for the `transform`. - unpack_parameters (bool, optional): whether to unpack parameters for `transform`. Defaults to False. + transform: a callable to be used to transform `data`. + parameters: parameters for the `transform`. + unpack_parameters: whether to unpack parameters for `transform`. Defaults to False. Returns: ReturnType: The return type of `transform`. @@ -61,10 +54,7 @@ def _apply_transform( def apply_transform( - transform: Callable[..., ReturnType], - data: Any, - map_items: bool = True, - unpack_items: bool = False, + transform: Callable[..., ReturnType], data: Any, map_items: bool = True, unpack_items: bool = False ) -> Union[List[ReturnType], ReturnType]: """ Transform `data` with `transform`. @@ -74,11 +64,11 @@ def apply_transform( otherwise transform will be applied once with `data` as the argument. Args: - transform (Callable[..., ReturnType]): a callable to be used to transform `data`. - data (Any): an object to be transformed. - map_items (bool, optional): whether to apply transform to each item in `data`, + transform: a callable to be used to transform `data`. + data: an object to be transformed. + map_items: whether to apply transform to each item in `data`, if `data` is a list or tuple. Defaults to True. - unpack_items (bool, optional): [description]. Defaults to False. + unpack_items: whether to unpack parameters using `*`. Defaults to False. Raises: Exception: When ``transform`` raises an exception. @@ -213,10 +203,10 @@ class Transform(ABC): :py:class:`monai.transforms.Compose` """ + # Transforms should add data types to this list if they are capable of performing a transform without + # modifying the input type. For example, ["torch.Tensor", "np.ndarray"] means that no copies of the data + # are required if the input is either `torch.Tensor` or `np.ndarray`. backend: List[TransformBackends] = [] - """Transforms should add data types to this list if they are capable of performing a transform without - modifying the input type. For example, [\"torch.Tensor\", \"np.ndarray\"] means that no copies of the data - are required if the input is either \"torch.Tensor\" or \"np.ndarray\".""" @abstractmethod def __call__(self, data: Any): @@ -226,17 +216,15 @@ def __call__(self, data: Any): return an updated version of ``data``. To simplify the input validations, most of the transforms assume that - - ``data`` is a Numpy ndarray, PyTorch Tensor or string + - ``data`` is a Numpy ndarray, PyTorch Tensor or string, - the data shape can be: - #. string data without shape, `LoadImage` transform expects file paths - #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, - except that `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and - `AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) - #. most of the post-processing transforms expect - ``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])`` + #. string data without shape, `LoadImage` transform expects file paths, + #. most of the pre-/post-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, + except for example: `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and + `AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels), - - the channel dimension is not omitted even if number of channels is one + - the channel dimension is often not omitted even if number of channels is one. This method can optionally take additional arguments to help execute transformation operation. @@ -333,18 +321,16 @@ def __call__(self, data): To simplify the input validations, this method assumes: - - ``data`` is a Python dictionary + - ``data`` is a Python dictionary, - ``data[key]`` is a Numpy ndarray, PyTorch Tensor or string, where ``key`` is an element of ``self.keys``, the data shape can be: - #. string data without shape, `LoadImaged` transform expects file paths - #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, - except that `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and + #. string data without shape, `LoadImaged` transform expects file paths, + #. most of the pre-/post-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, + except for example: `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and `AsChannelFirstd` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) - #. most of the post-processing transforms expect - ``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])`` - - the channel dimension is not omitted even if number of channels is one + - the channel dimension is often not omitted even if number of channels is one. Raises: NotImplementedError: When the subclass does not override this method. @@ -355,11 +341,7 @@ def __call__(self, data): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def key_iterator( - self, - data: Dict[Hashable, Any], - *extra_iterables: Optional[Iterable], - ) -> Generator: + def key_iterator(self, data: Dict[Hashable, Any], *extra_iterables: Optional[Iterable]) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. @@ -379,3 +361,14 @@ def key_iterator( yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False") + + def first_key(self, data: Dict[Hashable, Any]): + """ + Get the first available key of `self.keys` in the input `data` dictionary. + If no available key, return an empty list `[]`. + + Args: + data: data that the transform will be applied to. + + """ + return first(self.key_iterator(data), []) diff --git a/monai/transforms/utility/__init__.py b/monai/transforms/utility/__init__.py index 14ae193634..1e97f89407 100644 --- a/monai/transforms/utility/__init__.py +++ b/monai/transforms/utility/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2eb6c447c6..b869d2489c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -31,22 +31,33 @@ map_binary_to_indices, map_classes_to_indices, ) -from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis -from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import +from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices +from monai.utils import ( + convert_data_type, + convert_to_cupy, + convert_to_numpy, + convert_to_tensor, + deprecated_arg, + ensure_tuple, + look_up_option, + min_version, + optional_import, +) from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least -from monai.utils.type_conversion import convert_data_type +from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") cp, has_cp = optional_import("cupy") -cp_ndarray, _ = optional_import("cupy", name="ndarray") + __all__ = [ "Identity", "AsChannelFirst", "AsChannelLast", "AddChannel", + "AddCoordinateChannels", "EnsureChannelFirst", "EnsureType", "RepeatChannel", @@ -71,6 +82,9 @@ "MapLabelValue", "IntensityStats", "ToDevice", + "CuCIM", + "RandCuCIM", + "ToCupy", ] @@ -234,8 +248,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. """ - repeeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat - return repeeat_fn(img, self.repeats, 0) # type: ignore + repeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat + return repeat_fn(img, self.repeats, 0) # type: ignore class RemoveRepeatedChannel(Transform): @@ -321,8 +335,6 @@ def __call__(self, img: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch. TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``. """ - if not isinstance(img, (torch.Tensor, np.ndarray)): - raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") img_out, *_ = convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype) return img_out @@ -330,78 +342,131 @@ def __call__(self, img: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch. class ToTensor(Transform): """ Converts the input image to a tensor without applying any other transformations. + Input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. + Will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original. + For dictionary, list or tuple, convert every item to a Tensor if applicable and `wrap_sequence=False`. + + Args: + dtype: target data type to when converting to Tensor. + device: target device to put the converted Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: + def __init__( + self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = True + ) -> None: + super().__init__() + self.dtype = dtype + self.device = device + self.wrap_sequence = wrap_sequence + + def __call__(self, img: NdarrayOrTensor): """ Apply the transform to `img` and make it contiguous. """ - return convert_to_tensor(img, wrap_sequence=True) # type: ignore + return convert_to_tensor(img, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence) class EnsureType(Transform): """ Ensure the input data to be a PyTorch Tensor or numpy array, support: `numpy array`, `PyTorch Tensor`, `float`, `int`, `bool`, `string` and `object` keep the original. - If passing a dictionary, list or tuple, still return dictionary, list or tuple and recursively convert - every item to the expected data type. + If passing a dictionary, list or tuple, still return dictionary, list or tuple will recursively convert + every item to the expected data type if `wrap_sequence=False`. Args: data_type: target data type to convert, should be "tensor" or "numpy". + dtype: target data content type to convert, for example: np.float32, torch.float, etc. + device: for Tensor data type, specify the target device. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, data_type: str = "tensor") -> None: - data_type = data_type.lower() - if data_type not in ("tensor", "numpy"): - raise ValueError("`data type` must be 'tensor' or 'numpy'.") - - self.data_type = data_type + def __init__( + self, + data_type: str = "tensor", + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + device: Optional[torch.device] = None, + wrap_sequence: bool = True, + ) -> None: + self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"}) + self.dtype = dtype + self.device = device + self.wrap_sequence = wrap_sequence - def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, data: NdarrayOrTensor): """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and objects keep the original. for dictionary, list or tuple, ensure every item as expected type - if applicable. + if applicable and `wrap_sequence=False`. """ - return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore + output_type = torch.Tensor if self.data_type == "tensor" else np.ndarray + out, *_ = convert_data_type( + data=data, output_type=output_type, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence + ) + return out class ToNumpy(Transform): """ Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor. + + Args: + dtype: target data type when converting to numpy array. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, img: NdarrayOrTensor) -> np.ndarray: + def __init__(self, dtype: DtypeLike = None, wrap_sequence: bool = True) -> None: + super().__init__() + self.dtype = dtype + self.wrap_sequence = wrap_sequence + + def __call__(self, img: NdarrayOrTensor): """ Apply the transform to `img` and make it contiguous. """ - return convert_to_numpy(img) # type: ignore + return convert_to_numpy(img, dtype=self.dtype, wrap_sequence=self.wrap_sequence) class ToCupy(Transform): """ Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor. + + Args: + dtype: data type specifier. It is inferred from the input by default. + if not None, must be an argument of `numpy.dtype`, for more details: + https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __init__(self, dtype: Optional[np.dtype] = None, wrap_sequence: bool = True) -> None: + super().__init__() + self.dtype = dtype + self.wrap_sequence = wrap_sequence + + def __call__(self, data: NdarrayOrTensor): """ - Apply the transform to `img` and make it contiguous. + Create a CuPy array from `data` and make it contiguous """ - if isinstance(img, torch.Tensor): - img = img.detach().cpu().numpy() - return cp.ascontiguousarray(cp.asarray(img)) # type: ignore + return convert_to_cupy(data, dtype=self.dtype, wrap_sequence=self.wrap_sequence) class ToPIL(Transform): @@ -547,7 +612,7 @@ def __call__( lines = [f"{prefix or self.prefix} statistics:"] if self.data_type if data_type is None else data_type: - lines.append(f"Type: {type(img)}") + lines.append(f"Type: {type(img)} {img.dtype if hasattr(img, 'dtype') else None}") if self.data_shape if data_shape is None else data_shape: lines.append(f"Shape: {img.shape}") if self.value_range if value_range is None else value_range: @@ -697,9 +762,7 @@ class LabelToMask(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( # pytype: disable=annotation-type-mismatch - self, - select_labels: Union[Sequence[int], int], - merge_channels: bool = False, + self, select_labels: Union[Sequence[int], int], merge_channels: bool = False ) -> None: # pytype: disable=annotation-type-mismatch self.select_labels = ensure_tuple(select_labels) self.merge_channels = merge_channels @@ -726,7 +789,7 @@ def __call__( if img.shape[0] > 1: data = img[[*select_labels]] else: - where = np.where if isinstance(img, np.ndarray) else torch.where + where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): data = where(in1d(img, select_labels), True, False).reshape(img.shape) # pre pytorch 1.8.0, need to use 1/0 instead of True/False @@ -759,16 +822,18 @@ class FgBgToIndices(Transform): """ + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None: self.image_threshold = image_threshold self.output_shape = output_shape def __call__( self, - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + image: Optional[NdarrayOrTensor] = None, output_shape: Optional[Sequence[int]] = None, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Args: label: input data to compute foreground and background indices. @@ -781,13 +846,15 @@ def __call__( output_shape = self.output_shape fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) if output_shape is not None: - fg_indices = np.stack([np.unravel_index(i, output_shape) for i in fg_indices]) - bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices]) - + fg_indices = unravel_indices(fg_indices, output_shape) + bg_indices = unravel_indices(bg_indices, output_shape) return fg_indices, bg_indices class ClassesToIndices(Transform): + + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + def __init__( self, num_classes: Optional[int] = None, @@ -814,10 +881,10 @@ def __init__( def __call__( self, - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + image: Optional[NdarrayOrTensor] = None, output_shape: Optional[Sequence[int]] = None, - ) -> List[np.ndarray]: + ) -> List[NdarrayOrTensor]: """ Args: label: input data to compute the indices of every class. @@ -826,11 +893,13 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. """ + if output_shape is None: output_shape = self.output_shape + indices: List[NdarrayOrTensor] indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) if output_shape is not None: - indices = [np.stack([np.unravel_index(i, output_shape) for i in array]) for array in indices] + indices = [unravel_indices(cls_indices, output_shape) for cls_indices in indices] return indices @@ -845,19 +914,17 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): and ET (Enhancing tumor). """ - def __call__(self, img: np.ndarray) -> np.ndarray: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: - img = np.squeeze(img, axis=0) + img = img.squeeze(0) - result = [] - # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC - result.append(np.logical_or(img == 1, img == 4)) + result = [(img == 1) | (img == 4), (img == 1) | (img == 4) | (img == 2), img == 4] # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT - result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2)) # label 4 is ET - result.append(img == 4) - return np.stack(result, axis=0) + return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0) class AddExtremePointsChannel(Randomizable, Transform): @@ -880,22 +947,24 @@ class AddExtremePointsChannel(Randomizable, Transform): ValueError: When label image is not single channel. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, background: int = 0, pert: float = 0.0) -> None: self._background = background self._pert = pert self._points: List[Tuple[int, ...]] = [] - def randomize(self, label: np.ndarray) -> None: + def randomize(self, label: NdarrayOrTensor) -> None: self._points = get_extreme_points(label, rand_state=self.R, background=self._background, pert=self._pert) def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, + img: NdarrayOrTensor, + label: Optional[NdarrayOrTensor] = None, sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0, rescale_min: float = -1.0, rescale_max: float = 1.0, - ): + ) -> NdarrayOrTensor: """ Args: img: the image that we want to add new channel to. @@ -918,8 +987,8 @@ def __call__( points_image = extreme_points_to_image( points=self._points, label=label, sigma=sigma, rescale_min=rescale_min, rescale_max=rescale_max ) - - return np.concatenate([img, points_image], axis=0) + points_image, *_ = convert_to_dst_type(points_image, img) # type: ignore + return concatenate((img, points_image), axis=0) class TorchVision: @@ -930,6 +999,8 @@ class TorchVision: """ + backend = [TransformBackends.TORCH] + def __init__(self, name: str, *args, **kwargs) -> None: """ Args: @@ -939,16 +1010,20 @@ def __init__(self, name: str, *args, **kwargs) -> None: """ super().__init__() + self.name = name transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) self.trans = transform(*args, **kwargs) - def __call__(self, img: torch.Tensor): + def __call__(self, img: NdarrayOrTensor): """ Args: img: PyTorch Tensor data for the TorchVision transform. """ - return self.trans(img) + img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore + out = self.trans(img_t) + out, *_ = convert_to_dst_type(src=out, dst=img) + return out class MapLabelValue: @@ -960,6 +1035,8 @@ class MapLabelValue: """ + backend = [TransformBackends.NUMPY] + def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None: """ Args: @@ -975,13 +1052,13 @@ def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeL self.orig_labels = orig_labels self.target_labels = target_labels - self.dtype = dtype + self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray) - def __call__(self, img: np.ndarray): - img = np.asarray(img) - img_flat = img.flatten() + def __call__(self, img: NdarrayOrTensor): + img_np, *_ = convert_data_type(img, np.ndarray) + img_flat = img_np.flatten() try: - out_flat = np.copy(img_flat).astype(self.dtype) + out_flat = np.array(img_flat, dtype=self.dtype) except ValueError: # can't copy unchanged labels as the expected dtype is not supported, must map all the label values out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype) @@ -991,7 +1068,9 @@ def __call__(self, img: np.ndarray): continue np.place(out_flat, img_flat == o, t) - return out_flat.reshape(img.shape) + reshaped = out_flat.reshape(img_np.shape) + out, *_ = convert_to_dst_type(src=reshaped, dst=img, dtype=self.dtype) + return out class IntensityStats(Transform): @@ -1013,17 +1092,16 @@ class IntensityStats(Transform): """ + backend = [TransformBackends.NUMPY] + def __init__(self, ops: Sequence[Union[str, Callable]], key_prefix: str, channel_wise: bool = False) -> None: self.ops = ensure_tuple(ops) self.key_prefix = key_prefix self.channel_wise = channel_wise def __call__( - self, - img: np.ndarray, - meta_data: Optional[Dict] = None, - mask: Optional[np.ndarray] = None, - ) -> Tuple[np.ndarray, Dict]: + self, img: NdarrayOrTensor, meta_data: Optional[Dict] = None, mask: Optional[np.ndarray] = None + ) -> Tuple[NdarrayOrTensor, Dict]: """ Compute statistics for the intensity of input image. @@ -1034,21 +1112,22 @@ def __call__( mask must have the same shape as input `img`. """ + img_np: np.ndarray + img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore if meta_data is None: meta_data = {} - img_: np.ndarray = img if mask is not None: - if mask.shape != img.shape or mask.dtype != bool: + if mask.shape != img_np.shape or mask.dtype != bool: raise TypeError("mask must be bool array with the same shape as input `img`.") - img_ = img[mask] + img_np = img_np[mask] supported_ops = { - "mean": lambda x: np.nanmean(x), - "median": lambda x: np.nanmedian(x), - "max": lambda x: np.nanmax(x), - "min": lambda x: np.nanmin(x), - "std": lambda x: np.nanstd(x), + "mean": np.nanmean, + "median": np.nanmedian, + "max": np.nanmax, + "min": np.nanmin, + "std": np.nanstd, } def _compute(op: Callable, data: np.ndarray): @@ -1060,9 +1139,9 @@ def _compute(op: Callable, data: np.ndarray): for o in self.ops: if isinstance(o, str): o = look_up_option(o, supported_ops.keys()) - meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_) + meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_np) # type: ignore elif callable(o): - meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_) + meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_np) custom_index += 1 else: raise ValueError("ops must be key string for predefined operations or callable function.") @@ -1083,6 +1162,8 @@ class ToDevice(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, device: Union[torch.device, str], **kwargs) -> None: """ Args: @@ -1099,3 +1180,121 @@ def __call__(self, img: torch.Tensor): raise ValueError("img must be PyTorch Tensor, consider converting img by `EnsureType` transform first.") return img.to(self.device, **self.kwargs) + + +class CuCIM(Transform): + """ + Wrap a non-randomized cuCIM transform, defined based on the transform name and args. + For randomized transforms (or randomly applying a transform) use :py:class:`monai.transforms.RandCuCIM`. + + Args: + name: the transform name in CuCIM package + args: parameters for the CuCIM transform + kwargs: parameters for the CuCIM transform + + Note: + CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`. + Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array. + """ + + def __init__(self, name: str, *args, **kwargs) -> None: + super().__init__() + self.name = name + self.transform, _ = optional_import("cucim.core.operations.expose.transform", name=name) + self.args = args + self.kwargs = kwargs + + def __call__(self, data): + """ + Args: + data: a CuPy array (`cupy.ndarray`) for the cuCIM transform + + Returns: + `cupy.ndarray` + + """ + return self.transform(data, *self.args, **self.kwargs) + + +class RandCuCIM(CuCIM, RandomizableTransform): + """ + Wrap a randomized cuCIM transform, defined based on the transform name and args, + or randomly apply a non-randomized transform. + For deterministic non-randomized transforms use :py:class:`monai.transforms.CuCIM`. + + Args: + name: the transform name in CuCIM package. + apply_prob: the probability to apply the transform (default=1.0) + args: parameters for the CuCIM transform. + kwargs: parameters for the CuCIM transform. + + Note: + - CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`. + Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array. + - If the cuCIM transform is already randomized the `apply_prob` argument has nothing to do with + the randomness of the underlying cuCIM transform. `apply_prob` defines if the transform (either randomized + or non-randomized) being applied randomly, so it can apply non-randomized transforms randomly but be careful + with setting `apply_prob` to anything than 1.0 when using along with cuCIM's randomized transforms. + - If the random factor of the underlying cuCIM transform is not derived from `self.R`, + the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`. + """ + + def __init__(self, name: str, apply_prob: float = 1.0, *args, **kwargs) -> None: + CuCIM.__init__(self, name, *args, **kwargs) + RandomizableTransform.__init__(self, prob=apply_prob) + + def __call__(self, data): + """ + Args: + data: a CuPy array (`cupy.ndarray`) for the cuCIM transform + + Returns: + `cupy.ndarray` + + """ + self.randomize(data) + if not self._do_transform: + return data + return super().__call__(data) + + +class AddCoordinateChannels(Transform): + """ + Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling, + to allow feeding of the patch's location into the network. + + This can be seen as a input-only version of CoordConv: + + Liu, R. et al. An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution, NeurIPS 2018. + + Args: + spatial_dims: the spatial dimensions that are to have their coordinates encoded in a channel and + appended to the input image. E.g., `(0, 1, 2)` represents `H, W, D` dims and append three channels + to the input image, encoding the coordinates of the input's three spatial dimensions. + + .. deprecated:: 0.8.0 + ``spatial_channels`` is deprecated, use ``spatial_dims`` instead. + + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + @deprecated_arg( + name="spatial_channels", new_name="spatial_dims", since="0.8", msg_suffix="please use `spatial_dims` instead." + ) + def __init__(self, spatial_dims: Sequence[int]) -> None: + self.spatial_dims = spatial_dims + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Args: + img: data to be transformed, assuming `img` is channel first. + """ + if max(self.spatial_dims) > img.ndim - 2 or min(self.spatial_dims) < 0: + raise ValueError(f"`spatial_dims` values must be within [0, {img.ndim - 2}]") + + spatial_size = img.shape[1:] + coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_size), indexing="ij")) + coord_channels, *_ = convert_to_dst_type(coord_channels, img) # type: ignore + coord_channels = coord_channels[list(self.spatial_dims)] + return concatenate((img, coord_channels), axis=0) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e9bcce93b0..b74e63f683 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,8 +15,8 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -import copy import logging +import re from copy import deepcopy from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union @@ -30,11 +30,14 @@ from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utility.array import ( AddChannel, + AddCoordinateChannels, + AddExtremePointsChannel, AsChannelFirst, AsChannelLast, CastToType, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, + CuCIM, DataStats, EnsureChannelFirst, EnsureType, @@ -58,13 +61,18 @@ Transpose, ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points -from monai.utils import convert_to_numpy, ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys, TransformBackends +from monai.transforms.utils_pytorch_numpy_unification import concatenate +from monai.utils import convert_to_numpy, deprecated_arg, ensure_tuple, ensure_tuple_rep +from monai.utils.enums import TraceKeys, TransformBackends +from monai.utils.type_conversion import convert_to_dst_type __all__ = [ "AddChannelD", "AddChannelDict", "AddChanneld", + "AddCoordinateChannelsD", + "AddCoordinateChannelsDict", + "AddCoordinateChannelsd", "AddExtremePointsChannelD", "AddExtremePointsChannelDict", "AddExtremePointsChanneld", @@ -86,6 +94,9 @@ "CopyItemsD", "CopyItemsDict", "CopyItemsd", + "CuCIMd", + "CuCIMD", + "CuCIMDict", "DataStatsD", "DataStatsDict", "DataStatsd", @@ -116,6 +127,9 @@ "MapLabelValueD", "MapLabelValueDict", "MapLabelValued", + "RandCuCIMd", + "RandCuCIMD", + "RandCuCIMDict", "RandLambdaD", "RandLambdaDict", "RandLambdad", @@ -161,6 +175,9 @@ "Transposed", "TransposeDict", "TransposeD", + "ClassesToIndicesd", + "ClassesToIndicesD", + "ClassesToIndicesDict", ] @@ -442,15 +459,26 @@ class ToTensord(MapTransform, InvertibleTransform): backend = ToTensor.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + wrap_sequence: bool = True, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: target data content type to convert, for example: torch.float, etc. + device: specify the target device to put the Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = ToTensor() + self.converter = ToTensor(dtype=dtype, device=device, wrap_sequence=wrap_sequence) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -478,7 +506,7 @@ class EnsureTyped(MapTransform, InvertibleTransform): Ensure the input data to be a PyTorch Tensor or numpy array, support: `numpy array`, `PyTorch Tensor`, `float`, `int`, `bool`, `string` and `object` keep the original. If passing a dictionary, list or tuple, still return dictionary, list or tuple and recursively convert - every item to the expected data type. + every item to the expected data type if `wrap_sequence=False`. Note: Currently, we only convert tensor data to numpy array or scalar number in the inverse operation. @@ -486,16 +514,28 @@ class EnsureTyped(MapTransform, InvertibleTransform): backend = EnsureType.backend - def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + data_type: str = "tensor", + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + device: Optional[torch.device] = None, + wrap_sequence: bool = True, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` data_type: target data type to convert, should be "tensor" or "numpy". + dtype: target data content type to convert, for example: np.float32, torch.float, etc. + device: for Tensor data type, specify the target device. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = EnsureType(data_type=data_type) + self.converter = EnsureType(data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -522,15 +562,24 @@ class ToNumpyd(MapTransform): backend = ToNumpy.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + dtype: DtypeLike = None, + wrap_sequence: bool = True, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: target data type when converting to numpy array. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = ToNumpy() + self.converter = ToNumpy(dtype=dtype, wrap_sequence=wrap_sequence) def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = dict(data) @@ -542,19 +591,29 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: class ToCupyd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ToCupy`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: data type specifier. It is inferred from the input by default. + if not None, must be an argument of `numpy.dtype`, for more details: + https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html. + wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. + E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`. + allow_missing_keys: don't raise exception if key is missing. """ backend = ToCupy.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: - """ - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - allow_missing_keys: don't raise exception if key is missing. - """ + def __init__( + self, + keys: KeysCollection, + dtype: Optional[np.dtype] = None, + wrap_sequence: bool = True, + allow_missing_keys: bool = False, + ) -> None: super().__init__(keys, allow_missing_keys) - self.converter = ToCupy() + self.converter = ToCupy(dtype=dtype, wrap_sequence=wrap_sequence) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -614,7 +673,7 @@ def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - fwd_indices = np.array(transform[InverseKeys.EXTRA_INFO]["indices"]) + fwd_indices = np.array(transform[TraceKeys.EXTRA_INFO]["indices"]) inv_indices = np.argsort(fwd_indices) inverse_transform = Transpose(inv_indices.tolist()) # Apply inverse @@ -630,8 +689,35 @@ class DeleteItemsd(MapTransform): It will remove the key-values and copy the others to construct a new dictionary. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, keys: KeysCollection, sep: str = ".", use_re: Union[Sequence[bool], bool] = False) -> None: + """ + Args: + keys: keys of the corresponding items to delete, can be "A{sep}B{sep}C" + to delete key `C` in nested dictionary, `C` can be regular expression. + See also: :py:class:`monai.transforms.compose.MapTransform` + sep: the separator tag to define nested dictionary keys, default to ".". + use_re: whether the specified key is a regular expression, it also can be + a list of bool values, map the to keys. + """ + super().__init__(keys) + self.sep = sep + self.use_re = ensure_tuple_rep(use_re, len(self.keys)) + def __call__(self, data): - return {key: val for key, val in data.items() if key not in self.key_iterator(data)} + def _delete_item(keys, d, use_re: bool = False): + key = keys[0] + if len(keys) > 1: + d[key] = _delete_item(keys[1:], d[key], use_re) + return d + return {k: v for k, v in d.items() if (use_re and not re.search(key, k)) or (not use_re and k != key)} + + d = dict(data) + for key, use_re in zip(self.keys, self.use_re): + d = _delete_item(key.split(self.sep), d, use_re) + + return d class SelectItemsd(MapTransform): @@ -640,9 +726,10 @@ class SelectItemsd(MapTransform): It will copy the selected key-values and construct and new dictionary. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __call__(self, data): - result = {key: data[key] for key in self.key_iterator(data)} - return result + return {key: data[key] for key in self.key_iterator(data)} class SqueezeDimd(MapTransform): @@ -728,15 +815,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info ): - d[key] = self.printer( - d[key], - prefix, - data_type, - data_shape, - value_range, - data_value, - additional_info, - ) + d[key] = self.printer(d[key], prefix, data_type, data_shape, value_range, data_value, additional_info) return d @@ -825,7 +904,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if isinstance(val, torch.Tensor): d[new_key] = val.detach().clone() else: - d[new_key] = copy.deepcopy(val) + d[new_key] = deepcopy(val) return d @@ -866,6 +945,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N elif not isinstance(d[key], data_type): raise TypeError("All items in data must have the same type.") output.append(d[key]) + + if len(output) == 0: + return d + if data_type is np.ndarray: d[self.name] = np.concatenate(output, axis=self.dim) elif data_type is torch.Tensor: @@ -1000,7 +1083,7 @@ def __call__(self, data): return super().__call__(data) def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable): - return self._lambd(data, func=func) if transform_info[InverseKeys.DO_TRANSFORM] else data + return self._lambd(data, func=func) if transform_info[TraceKeys.DO_TRANSFORM] else data class LabelToMaskd(MapTransform): @@ -1059,6 +1142,8 @@ class FgBgToIndicesd(MapTransform): """ + backend = FgBgToIndices.backend + def __init__( self, keys: KeysCollection, @@ -1075,7 +1160,7 @@ def __init__( self.image_key = image_key self.converter = FgBgToIndices(image_threshold, output_shape) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) image = d[self.image_key] if self.image_key else None for key in self.key_iterator(d): @@ -1103,6 +1188,8 @@ class ClassesToIndicesd(MapTransform): """ + backend = ClassesToIndices.backend + def __init__( self, keys: KeysCollection, @@ -1118,7 +1205,7 @@ def __init__( self.image_key = image_key self.converter = ClassesToIndices(num_classes, image_threshold, output_shape) - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, Any]): d = dict(data) image = d[self.image_key] if self.image_key else None for key in self.key_iterator(d): @@ -1138,11 +1225,13 @@ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): and ET (Enhancing tumor). """ + backend = ConvertToMultiChannelBasedOnBratsClasses.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) self.converter = ConvertToMultiChannelBasedOnBratsClasses() - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1168,6 +1257,8 @@ class AddExtremePointsChanneld(Randomizable, MapTransform): """ + backend = AddExtremePointsChannel.backend + def __init__( self, keys: KeysCollection, @@ -1188,10 +1279,10 @@ def __init__( self.rescale_min = rescale_min self.rescale_max = rescale_max - def randomize(self, label: np.ndarray) -> None: + def randomize(self, label: NdarrayOrTensor) -> None: self.points = get_extreme_points(label, rand_state=self.R, background=self.background, pert=self.pert) - def __call__(self, data): + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) label = d[self.label_key] if label.shape[0] != 1: @@ -1209,7 +1300,8 @@ def __call__(self, data): rescale_min=self.rescale_min, rescale_max=self.rescale_max, ) - d[key] = np.concatenate([img, points_image], axis=0) + points_image, *_ = convert_to_dst_type(points_image, img) # type: ignore + d[key] = concatenate([img, points_image], axis=0) return d @@ -1223,14 +1315,9 @@ class TorchVisiond(MapTransform): data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor. """ - def __init__( - self, - keys: KeysCollection, - name: str, - allow_missing_keys: bool = False, - *args, - **kwargs, - ) -> None: + backend = TorchVision.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -1242,9 +1329,10 @@ def __init__( """ super().__init__(keys, allow_missing_keys) + self.name = name self.trans = TorchVision(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.trans(d[key]) @@ -1267,14 +1355,9 @@ class RandTorchVisiond(Randomizable, MapTransform): """ - def __init__( - self, - keys: KeysCollection, - name: str, - allow_missing_keys: bool = False, - *args, - **kwargs, - ) -> None: + backend = TorchVision.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -1286,9 +1369,10 @@ def __init__( """ MapTransform.__init__(self, keys, allow_missing_keys) + self.name = name self.trans = TorchVision(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.trans(d[key]) @@ -1300,6 +1384,8 @@ class MapLabelValued(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`. """ + backend = MapLabelValue.backend + def __init__( self, keys: KeysCollection, @@ -1321,7 +1407,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.mapper(d[key]) @@ -1363,6 +1449,8 @@ class IntensityStatsd(MapTransform): """ + backend = IntensityStats.backend + def __init__( self, keys: KeysCollection, @@ -1382,16 +1470,14 @@ def __init__( raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__(self, data) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mask_key, meta_key, meta_key_postfix in self.key_iterator( d, self.mask_keys, self.meta_keys, self.meta_key_postfix ): meta_key = meta_key or f"{key}_{meta_key_postfix}" d[key], d[meta_key] = self.stats( - img=d[key], - meta_data=d.get(meta_key), - mask=d.get(mask_key) if mask_key is not None else None, + img=d[key], meta_data=d.get(meta_key), mask=d.get(mask_key) if mask_key is not None else None ) return d @@ -1401,12 +1487,10 @@ class ToDeviced(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`. """ + backend = ToDevice.backend + def __init__( - self, - keys: KeysCollection, - device: Union[torch.device, str], - allow_missing_keys: bool = False, - **kwargs, + self, keys: KeysCollection, device: Union[torch.device, str], allow_missing_keys: bool = False, **kwargs ) -> None: """ Args: @@ -1427,6 +1511,121 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class CuCIMd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.CuCIM` for non-randomized transforms. + For randomized transforms of CuCIM use :py:class:`monai.transforms.RandCuCIMd`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in CuCIM package. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the CuCIM transform. + kwargs: parameters for the CuCIM transform. + + Note: + CuCIM transforms only work with CuPy arrays, this transform expects input data to be `cupy.ndarray`. + Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array. + """ + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: + super().__init__(keys=keys, allow_missing_keys=allow_missing_keys) + self.name = name + self.trans = CuCIM(name, *args, **kwargs) + + def __call__(self, data): + """ + Args: + data: Dict[Hashable, `cupy.ndarray`] + + Returns: + Dict[Hashable, `cupy.ndarray`] + + """ + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.trans(d[key]) + return d + + +class RandCuCIMd(CuCIMd, RandomizableTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.CuCIM` for randomized transforms. + For deterministic non-randomized transforms of CuCIM use :py:class:`monai.transforms.CuCIMd`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in CuCIM package. + apply_prob: the probability to apply the transform (default=1.0) + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the CuCIM transform. + kwargs: parameters for the CuCIM transform. + + Note: + - CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`. + Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array. + - If the cuCIM transform is already randomized the `apply_prob` argument has nothing to do with + the randomness of the underlying cuCIM transform. `apply_prob` defines if the transform (either randomized + or non-randomized) being applied randomly, so it can apply non-randomized transforms randomly but be careful + with setting `apply_prob` to anything than 1.0 when using along with cuCIM's randomized transforms. + - If the random factor of the underlying cuCIM transform is not derived from `self.R`, + the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`. + """ + + def __init__(self, apply_prob: float = 1.0, *args, **kwargs) -> None: + CuCIMd.__init__(self, *args, **kwargs) + RandomizableTransform.__init__(self, prob=apply_prob) + + def __call__(self, data): + """ + Args: + data: Dict[Hashable, `cupy.ndarray`] + + Returns: + Dict[Hashable, `cupy.ndarray`] + + """ + self.randomize(data) + if not self._do_transform: + return dict(data) + return super().__call__(data) + + +class AddCoordinateChannelsd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.AddCoordinateChannels`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + spatial_dims: the spatial dimensions that are to have their coordinates encoded in a channel and + appended to the input image. E.g., `(0, 1, 2)` represents `H, W, D` dims and append three channels + to the input image, encoding the coordinates of the input's three spatial dimensions. + allow_missing_keys: don't raise exception if key is missing. + + .. deprecated:: 0.8.0 + ``spatial_channels`` is deprecated, use ``spatial_dims`` instead. + + """ + + backend = AddCoordinateChannels.backend + + @deprecated_arg( + name="spatial_channels", new_name="spatial_dims", since="0.8", msg_suffix="please use `spatial_dims` instead." + ) + def __init__(self, keys: KeysCollection, spatial_dims: Sequence[int], allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) + self.add_coordinate_channels = AddCoordinateChannels(spatial_dims=spatial_dims) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.add_coordinate_channels(d[key]) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -1463,3 +1662,6 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc MapLabelValueD = MapLabelValueDict = MapLabelValued IntensityStatsD = IntensityStatsDict = IntensityStatsd ToDeviceD = ToDeviceDict = ToDeviced +CuCIMD = CuCIMDict = CuCIMd +RandCuCIMD = RandCuCIMDict = RandCuCIMd +AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 30aa5e7b99..739d98e5c0 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,26 +20,39 @@ import torch import monai -import monai.transforms.transform from monai.config import DtypeLike, IndexSelection from monai.config.type_definitions import NdarrayOrTensor from monai.networks.layers import GaussianFilter from monai.transforms.compose import Compose, OneOf -from monai.transforms.transform import MapTransform, Transform +from monai.transforms.transform import MapTransform, Transform, apply_transform +from monai.transforms.utils_pytorch_numpy_unification import ( + any_np_pt, + cumsum, + isfinite, + nonzero, + ravel, + searchsorted, + unravel_index, + where, +) from monai.utils import ( GridSampleMode, InterpolateMode, - InverseKeys, + NumpyPadMode, + PytorchPadMode, + TraceKeys, + deprecated_arg, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, issequenceiterable, + look_up_option, min_version, optional_import, ) from monai.utils.enums import TransformBackends -from monai.utils.type_conversion import convert_data_type +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type measure, _ = optional_import("skimage.measure", "0.14.2", min_version) ndimage, _ = optional_import("scipy.ndimage") @@ -84,6 +97,7 @@ "get_number_image_type_conversions", "get_transform_backends", "print_transform_backends", + "convert_pad_mode", ] @@ -135,31 +149,43 @@ def zero_margins(img: np.ndarray, margin: int) -> bool: def rescale_array( - arr: NdarrayOrTensor, minv: float = 0.0, maxv: float = 1.0, dtype: Union[DtypeLike, torch.dtype] = np.float32 + arr: NdarrayOrTensor, + minv: Optional[float] = 0.0, + maxv: Optional[float] = 1.0, + dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32, ) -> NdarrayOrTensor: """ Rescale the values of numpy array `arr` to be from `minv` to `maxv`. + If either `minv` or `maxv` is None, it returns `(a - min_a) / (max_a - min_a)`. + + Args: + arr: input array to rescale. + minv: minimum value of target rescaled array. + maxv: maxmum value of target rescaled array. + dtype: if not None, convert input array to dtype before computation. + """ if dtype is not None: arr, *_ = convert_data_type(arr, dtype=dtype) - mina = arr.min() maxa = arr.max() if mina == maxa: - return arr * minv + return arr * minv if minv is not None else arr norm = (arr - mina) / (maxa - mina) # normalize the array first + if (minv is None) or (maxv is None): + return norm return (norm * (maxv - minv)) + minv # rescale by minv and maxv, which is the normalized array by default def rescale_instance_array( - arr: np.ndarray, minv: float = 0.0, maxv: float = 1.0, dtype: DtypeLike = np.float32 + arr: np.ndarray, minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, dtype: DtypeLike = np.float32 ) -> np.ndarray: """ Rescale each array slice along the first dimension of `arr` independently. """ - out: np.ndarray = np.zeros(arr.shape, dtype) + out: np.ndarray = np.zeros(arr.shape, dtype or arr.dtype) for i in range(arr.shape[0]): out[i] = rescale_array(arr[i], minv, maxv, dtype) @@ -170,16 +196,12 @@ def rescale_array_int_max(arr: np.ndarray, dtype: DtypeLike = np.uint16) -> np.n """ Rescale the array `arr` to be between the minimum and maximum values of the type `dtype`. """ - info: np.iinfo = np.iinfo(dtype) - return np.asarray(rescale_array(arr, info.min, info.max), dtype=dtype) + info: np.iinfo = np.iinfo(dtype or arr.dtype) + return np.asarray(rescale_array(arr, info.min, info.max), dtype=dtype or arr.dtype) def copypaste_arrays( - src_shape, - dest_shape, - srccenter: Sequence[int], - destcenter: Sequence[int], - dims: Sequence[Optional[int]], + src_shape, dest_shape, srccenter: Sequence[int], destcenter: Sequence[int], dims: Sequence[Optional[int]] ) -> Tuple[Tuple[slice, ...], Tuple[slice, ...]]: """ Calculate the slices to copy a sliced area of array in `src_shape` into array in `dest_shape`. @@ -256,10 +278,8 @@ def resize_center(img: np.ndarray, *resize_dims: Optional[int], fill_value: floa def map_binary_to_indices( - label: np.ndarray, - image: Optional[np.ndarray] = None, - image_threshold: float = 0.0, -) -> Tuple[np.ndarray, np.ndarray]: + label: NdarrayOrTensor, image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0 +) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Compute the foreground and background of input label data, return the indices after fattening. For example: @@ -272,28 +292,32 @@ def map_binary_to_indices( to define background. so the output items will not map to all the voxels in the label. image_threshold: if enabled `image`, use ``image > image_threshold`` to determine the valid image content area and select background only in this area. - """ + # Prepare fg/bg indices if label.shape[0] > 1: label = label[1:] # for One-Hot format data, remove the background channel - label_flat = np.any(label, axis=0).ravel() # in case label has multiple dimensions - fg_indices = np.nonzero(label_flat)[0] + label_flat = ravel(any_np_pt(label, 0)) # in case label has multiple dimensions + fg_indices = nonzero(label_flat) if image is not None: - img_flat = np.any(image > image_threshold, axis=0).ravel() - bg_indices = np.nonzero(np.logical_and(img_flat, ~label_flat))[0] + img_flat = ravel(any_np_pt(image > image_threshold, 0)) + img_flat, *_ = convert_to_dst_type(img_flat, label, dtype=img_flat.dtype) + bg_indices = nonzero(img_flat & ~label_flat) else: - bg_indices = np.nonzero(~label_flat)[0] + bg_indices = nonzero(~label_flat) + # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices + fg_indices, *_ = convert_data_type(fg_indices, device=torch.device("cpu")) + bg_indices, *_ = convert_data_type(bg_indices, device=torch.device("cpu")) return fg_indices, bg_indices def map_classes_to_indices( - label: np.ndarray, + label: NdarrayOrTensor, num_classes: Optional[int] = None, - image: Optional[np.ndarray] = None, + image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0, -) -> List[np.ndarray]: +) -> List[NdarrayOrTensor]: """ Filter out indices of every class of the input label data, return the indices after fattening. It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for @@ -313,11 +337,11 @@ def map_classes_to_indices( determine the valid image content area and select class indices only in this area. """ - img_flat: Optional[np.ndarray] = None + img_flat: Optional[NdarrayOrTensor] = None if image is not None: - img_flat = np.any(image > image_threshold, axis=0).ravel() + img_flat = ravel((image > image_threshold).any(0)) - indices: List[np.ndarray] = [] + indices: List[NdarrayOrTensor] = [] # assuming the first dimension is channel channels = len(label) @@ -328,16 +352,18 @@ def map_classes_to_indices( num_classes_ = num_classes for c in range(num_classes_): - label_flat = np.any(label[c : c + 1] if channels > 1 else label == c, axis=0).ravel() - label_flat = np.logical_and(img_flat, label_flat) if img_flat is not None else label_flat - indices.append(np.nonzero(label_flat)[0]) + label_flat = ravel(any_np_pt(label[c : c + 1] if channels > 1 else label == c, 0)) + label_flat = img_flat & label_flat if img_flat is not None else label_flat + # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices + cls_indices, *_ = convert_data_type(nonzero(label_flat), device=torch.device("cpu")) + indices.append(cls_indices) return indices def weighted_patch_samples( spatial_size: Union[int, Sequence[int]], - w: np.ndarray, + w: NdarrayOrTensor, n_samples: int = 1, r_state: Optional[np.random.RandomState] = None, ) -> List: @@ -366,34 +392,45 @@ def weighted_patch_samples( s = tuple(slice(w // 2, m - w + w // 2) if m > w else slice(m // 2, m // 2 + 1) for w, m in zip(win_size, img_size)) v = w[s] # weight map in the 'valid' mode v_size = v.shape - v = v.ravel() - if np.any(v < 0): - v -= np.min(v) # shifting to non-negative - v = v.cumsum() - if not v[-1] or not np.isfinite(v[-1]) or v[-1] < 0: # uniform sampling + v = ravel(v) + if (v < 0).any(): + v -= v.min() # shifting to non-negative + v = cumsum(v) + if not v[-1] or not isfinite(v[-1]) or v[-1] < 0: # uniform sampling idx = r_state.randint(0, len(v), size=n_samples) else: - idx = v.searchsorted(r_state.random(n_samples) * v[-1], side="right") + r, *_ = convert_to_dst_type(r_state.random(n_samples), v) # type: ignore + idx = searchsorted(v, r * v[-1], right=True) + idx, *_ = convert_to_dst_type(idx, v, dtype=torch.int) # type: ignore # compensate 'valid' mode diff = np.minimum(win_size, img_size) // 2 - return [np.unravel_index(i, v_size) + diff for i in np.asarray(idx, dtype=int)] + diff, *_ = convert_to_dst_type(diff, v) # type: ignore + return [unravel_index(i, v_size) + diff for i in idx] def correct_crop_centers( - centers: List[np.ndarray], spatial_size: Union[Sequence[int], int], label_spatial_shape: Sequence[int] -) -> List[np.ndarray]: + centers: List[Union[int, torch.Tensor]], + spatial_size: Union[Sequence[int], int], + label_spatial_shape: Sequence[int], + allow_smaller: bool = False, +): """ - Utility to correct the crop center if the crop size is bigger than the image size. + Utility to correct the crop center if the crop size and centers are not compatible with the image size. Args: - ceters: pre-computed crop centers, will correct based on the valid region. + centers: pre-computed crop centers of every dim, will correct based on the valid region. spatial_size: spatial size of the ROIs to be sampled. label_spatial_shape: spatial shape of the original label data to compare with ROI. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). """ spatial_size = fall_back_tuple(spatial_size, default=label_spatial_shape) - if not (np.subtract(label_spatial_shape, spatial_size) >= 0).all(): - raise ValueError("The size of the proposed random crop ROI is larger than the image size.") + if any(np.subtract(label_spatial_shape, spatial_size) < 0): + if not allow_smaller: + raise ValueError("The size of the proposed random crop ROI is larger than the image size.") + spatial_size = tuple(min(l, s) for l, s in zip(label_spatial_shape, spatial_size)) # Select subregion to assure valid roi valid_start = np.floor_divide(spatial_size, 2) @@ -405,16 +442,12 @@ def correct_crop_centers( # need this because np.random.randint does not work with same start and end if valid_s == valid_end[i]: valid_end[i] += 1 - - for i, c in enumerate(centers): - center_i = c - if c < valid_start[i]: - center_i = valid_start[i] - if c >= valid_end[i]: - center_i = valid_end[i] - 1 - centers[i] = center_i - - return centers + valid_centers = [] + for c, v_s, v_e in zip(centers, valid_start, valid_end): + _c = int(convert_data_type(c, np.ndarray)[0]) # type: ignore + center_i = min(max(_c, v_s), v_e - 1) + valid_centers.append(int(center_i)) + return valid_centers def generate_pos_neg_label_crop_centers( @@ -422,10 +455,11 @@ def generate_pos_neg_label_crop_centers( num_samples: int, pos_ratio: float, label_spatial_shape: Sequence[int], - fg_indices: np.ndarray, - bg_indices: np.ndarray, + fg_indices: NdarrayOrTensor, + bg_indices: NdarrayOrTensor, rand_state: Optional[np.random.RandomState] = None, -) -> List[List[np.ndarray]]: + allow_smaller: bool = False, +) -> List[List[int]]: """ Generate valid sample locations based on the label with option for specifying foreground ratio Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] @@ -438,6 +472,9 @@ def generate_pos_neg_label_crop_centers( fg_indices: pre-computed foreground indices in 1 dimension. bg_indices: pre-computed background indices in 1 dimension. rand_state: numpy randomState object to align with other modules. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). Raises: ValueError: When the proposed roi is larger than the image. @@ -448,11 +485,12 @@ def generate_pos_neg_label_crop_centers( rand_state = np.random.random.__self__ # type: ignore centers = [] - fg_indices, bg_indices = np.asarray(fg_indices), np.asarray(bg_indices) - if fg_indices.size == 0 and bg_indices.size == 0: + fg_indices = np.asarray(fg_indices) if isinstance(fg_indices, Sequence) else fg_indices + bg_indices = np.asarray(bg_indices) if isinstance(bg_indices, Sequence) else bg_indices + if len(fg_indices) == 0 and len(bg_indices) == 0: raise ValueError("No sampling location available.") - if fg_indices.size == 0 or bg_indices.size == 0: + if len(fg_indices) == 0 or len(bg_indices) == 0: warnings.warn( f"N foreground {len(fg_indices)}, N background {len(bg_indices)}," "unable to generate class balanced samples." @@ -462,10 +500,10 @@ def generate_pos_neg_label_crop_centers( for _ in range(num_samples): indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices random_int = rand_state.randint(len(indices_to_use)) - center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + idx = indices_to_use[random_int] + center = unravel_index(idx, label_spatial_shape) # shift center to range of valid centers - center_ori = list(center) - centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) + centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller)) return centers @@ -474,10 +512,11 @@ def generate_label_classes_crop_centers( spatial_size: Union[Sequence[int], int], num_samples: int, label_spatial_shape: Sequence[int], - indices: List[np.ndarray], + indices: Sequence[NdarrayOrTensor], ratios: Optional[List[Union[float, int]]] = None, rand_state: Optional[np.random.RandomState] = None, -) -> List[List[np.ndarray]]: + allow_smaller: bool = False, +) -> List[List[int]]: """ Generate valid sample locations based on the specified ratios of label classes. Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] @@ -490,6 +529,9 @@ def generate_label_classes_crop_centers( ratios: ratios of every class in the label to generate crop centers, including background class. if None, every class will have the same ratio to generate crop centers. rand_state: numpy randomState object to align with other modules. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). """ if rand_state is None: @@ -499,12 +541,10 @@ def generate_label_classes_crop_centers( raise ValueError("num_samples must be an int number and greater than 0.") ratios_: List[Union[float, int]] = ([1] * len(indices)) if ratios is None else ratios if len(ratios_) != len(indices): - raise ValueError("random crop radios must match the number of indices of classes.") + raise ValueError("random crop ratios must match the number of indices of classes.") if any(i < 0 for i in ratios_): raise ValueError("ratios should not contain negative number.") - # ensure indices are numpy array - indices = [np.asarray(i) for i in indices] for i, array in enumerate(indices): if len(array) == 0: warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.") @@ -516,10 +556,10 @@ def generate_label_classes_crop_centers( # randomly select the indices of a class based on the ratios indices_to_use = indices[i] random_int = rand_state.randint(len(indices_to_use)) - center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + center = unravel_index(indices_to_use[random_int], label_spatial_shape) # shift center to range of valid centers center_ori = list(center) - centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) + centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape, allow_smaller)) return centers @@ -528,7 +568,9 @@ def create_grid( spatial_size: Sequence[int], spacing: Optional[Sequence[float]] = None, homogeneous: bool = True, - dtype: DtypeLike = float, + dtype=float, + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, ): """ compute a `spatial_size` mesh. @@ -538,6 +580,26 @@ def create_grid( spacing: same len as ``spatial_size``, defaults to 1.0 (dense grid). homogeneous: whether to make homogeneous coordinates. dtype: output grid data type. + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + + """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_grid_numpy(spatial_size, spacing, homogeneous, dtype) + if _backend == TransformBackends.TORCH: + return _create_grid_torch(spatial_size, spacing, homogeneous, dtype, device) + raise ValueError(f"backend {backend} is not supported") + + +def _create_grid_numpy( + spatial_size: Sequence[int], + spacing: Optional[Sequence[float]] = None, + homogeneous: bool = True, + dtype: DtypeLike = float, +): + """ + compute a `spatial_size` mesh with the numpy API. """ spacing = spacing or tuple(1.0 for _ in spatial_size) ranges = [np.linspace(-(d - 1.0) / 2.0 * s, (d - 1.0) / 2.0 * s, int(d)) for d, s in zip(spatial_size, spacing)] @@ -547,23 +609,58 @@ def create_grid( return np.concatenate([coords, np.ones_like(coords[:1])]) +def _create_grid_torch( + spatial_size: Sequence[int], + spacing: Optional[Sequence[float]] = None, + homogeneous: bool = True, + dtype=torch.float32, + device: Optional[torch.device] = None, +): + """ + compute a `spatial_size` mesh with the torch API. + """ + spacing = spacing or tuple(1.0 for _ in spatial_size) + ranges = [ + torch.linspace(-(d - 1.0) / 2.0 * s, (d - 1.0) / 2.0 * s, int(d), device=device, dtype=dtype) + for d, s in zip(spatial_size, spacing) + ] + coords = torch.meshgrid(*ranges) + if not homogeneous: + return torch.stack(coords) + return torch.stack([*coords, torch.ones_like(coords[0])]) + + def create_control_grid( - spatial_shape: Sequence[int], spacing: Sequence[float], homogeneous: bool = True, dtype: DtypeLike = float + spatial_shape: Sequence[int], + spacing: Sequence[float], + homogeneous: bool = True, + dtype: DtypeLike = float, + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, ): """ control grid with two additional point in each direction """ + torch_backend = look_up_option(backend, TransformBackends) == TransformBackends.TORCH + ceil_func: Callable = torch.ceil if torch_backend else np.ceil # type: ignore grid_shape = [] for d, s in zip(spatial_shape, spacing): - d = int(d) + d = torch.as_tensor(d, device=device) if torch_backend else int(d) # type: ignore if d % 2 == 0: - grid_shape.append(np.ceil((d - 1.0) / (2.0 * s) + 0.5) * 2.0 + 2.0) + grid_shape.append(ceil_func((d - 1.0) / (2.0 * s) + 0.5) * 2.0 + 2.0) else: - grid_shape.append(np.ceil((d - 1.0) / (2.0 * s)) * 2.0 + 3.0) - return create_grid(grid_shape, spacing, homogeneous, dtype) + grid_shape.append(ceil_func((d - 1.0) / (2.0 * s)) * 2.0 + 3.0) + return create_grid( + spatial_size=grid_shape, spacing=spacing, homogeneous=homogeneous, dtype=dtype, device=device, backend=backend + ) -def create_rotate(spatial_dims: int, radians: Union[Sequence[float], float]) -> np.ndarray: +def create_rotate( + spatial_dims: int, + radians: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, +) -> NdarrayOrTensor: """ create a 2D or 3D rotation matrix @@ -572,48 +669,83 @@ def create_rotate(spatial_dims: int, radians: Union[Sequence[float], float]) -> radians: rotation radians when spatial_dims == 3, the `radians` sequence corresponds to rotation in the 1st, 2nd, and 3rd dim respectively. + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. Raises: ValueError: When ``radians`` is empty. ValueError: When ``spatial_dims`` is not one of [2, 3]. """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_rotate( + spatial_dims=spatial_dims, radians=radians, sin_func=np.sin, cos_func=np.cos, eye_func=np.eye + ) + if _backend == TransformBackends.TORCH: + return _create_rotate( + spatial_dims=spatial_dims, + radians=radians, + sin_func=lambda th: torch.sin(torch.as_tensor(th, dtype=torch.float32, device=device)), + cos_func=lambda th: torch.cos(torch.as_tensor(th, dtype=torch.float32, device=device)), + eye_func=lambda rank: torch.eye(rank, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_rotate( + spatial_dims: int, + radians: Union[Sequence[float], float], + sin_func: Callable = np.sin, + cos_func: Callable = np.cos, + eye_func: Callable = np.eye, +) -> NdarrayOrTensor: radians = ensure_tuple(radians) if spatial_dims == 2: if len(radians) >= 1: - sin_, cos_ = np.sin(radians[0]), np.cos(radians[0]) - return np.array([[cos_, -sin_, 0.0], [sin_, cos_, 0.0], [0.0, 0.0, 1.0]]) + sin_, cos_ = sin_func(radians[0]), cos_func(radians[0]) + out = eye_func(3) + out[0, 0], out[0, 1] = cos_, -sin_ + out[1, 0], out[1, 1] = sin_, cos_ + return out # type: ignore raise ValueError("radians must be non empty.") if spatial_dims == 3: affine = None if len(radians) >= 1: - sin_, cos_ = np.sin(radians[0]), np.cos(radians[0]) - affine = np.array( - [[1.0, 0.0, 0.0, 0.0], [0.0, cos_, -sin_, 0.0], [0.0, sin_, cos_, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) + sin_, cos_ = sin_func(radians[0]), cos_func(radians[0]) + affine = eye_func(4) + affine[1, 1], affine[1, 2] = cos_, -sin_ + affine[2, 1], affine[2, 2] = sin_, cos_ if len(radians) >= 2: - sin_, cos_ = np.sin(radians[1]), np.cos(radians[1]) + sin_, cos_ = sin_func(radians[1]), cos_func(radians[1]) if affine is None: raise ValueError("Affine should be a matrix.") - affine = affine @ np.array( - [[cos_, 0.0, sin_, 0.0], [0.0, 1.0, 0.0, 0.0], [-sin_, 0.0, cos_, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) + _affine = eye_func(4) + _affine[0, 0], _affine[0, 2] = cos_, sin_ + _affine[2, 0], _affine[2, 2] = -sin_, cos_ + affine = affine @ _affine if len(radians) >= 3: - sin_, cos_ = np.sin(radians[2]), np.cos(radians[2]) + sin_, cos_ = sin_func(radians[2]), cos_func(radians[2]) if affine is None: raise ValueError("Affine should be a matrix.") - affine = affine @ np.array( - [[cos_, -sin_, 0.0, 0.0], [sin_, cos_, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) + _affine = eye_func(4) + _affine[0, 0], _affine[0, 1] = cos_, -sin_ + _affine[1, 0], _affine[1, 1] = sin_, cos_ + affine = affine @ _affine if affine is None: raise ValueError("radians must be non empty.") - return affine + return affine # type: ignore raise ValueError(f"Unsupported spatial_dims: {spatial_dims}, available options are [2, 3].") -def create_shear(spatial_dims: int, coefs: Union[Sequence[float], float]) -> np.ndarray: +def create_shear( + spatial_dims: int, + coefs: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, +) -> NdarrayOrTensor: """ create a shearing matrix @@ -629,61 +761,119 @@ def create_shear(spatial_dims: int, coefs: Union[Sequence[float], float]) -> np. [0.0, 0.0, 0.0, 1.0], ] + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + Raises: NotImplementedError: When ``spatial_dims`` is not one of [2, 3]. """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_shear(spatial_dims=spatial_dims, coefs=coefs, eye_func=np.eye) + if _backend == TransformBackends.TORCH: + return _create_shear( + spatial_dims=spatial_dims, coefs=coefs, eye_func=lambda rank: torch.eye(rank, device=device) + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_shear(spatial_dims: int, coefs: Union[Sequence[float], float], eye_func=np.eye) -> NdarrayOrTensor: if spatial_dims == 2: coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0) - return np.array([[1, coefs[0], 0.0], [coefs[1], 1.0, 0.0], [0.0, 0.0, 1.0]]) + out = eye_func(3) + out[0, 1], out[1, 0] = coefs[0], coefs[1] + return out # type: ignore if spatial_dims == 3: coefs = ensure_tuple_size(coefs, dim=6, pad_val=0.0) - return np.array( - [ - [1.0, coefs[0], coefs[1], 0.0], - [coefs[2], 1.0, coefs[3], 0.0], - [coefs[4], coefs[5], 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) + out = eye_func(4) + out[0, 1], out[0, 2] = coefs[0], coefs[1] + out[1, 0], out[1, 2] = coefs[2], coefs[3] + out[2, 0], out[2, 1] = coefs[4], coefs[5] + return out # type: ignore raise NotImplementedError("Currently only spatial_dims in [2, 3] are supported.") -def create_scale(spatial_dims: int, scaling_factor: Union[Sequence[float], float]): +def create_scale( + spatial_dims: int, + scaling_factor: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, +) -> NdarrayOrTensor: """ create a scaling matrix Args: spatial_dims: spatial rank scaling_factor: scaling factors for every spatial dim, defaults to 1. - """ + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_scale(spatial_dims=spatial_dims, scaling_factor=scaling_factor, array_func=np.diag) + if _backend == TransformBackends.TORCH: + return _create_scale( + spatial_dims=spatial_dims, + scaling_factor=scaling_factor, + array_func=lambda x: torch.diag(torch.as_tensor(x, device=device)), + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_scale( + spatial_dims: int, scaling_factor: Union[Sequence[float], float], array_func=np.diag +) -> NdarrayOrTensor: scaling_factor = ensure_tuple_size(scaling_factor, dim=spatial_dims, pad_val=1.0) - return np.diag(scaling_factor[:spatial_dims] + (1.0,)) + return array_func(scaling_factor[:spatial_dims] + (1.0,)) # type: ignore -def create_translate(spatial_dims: int, shift: Union[Sequence[float], float]) -> np.ndarray: +def create_translate( + spatial_dims: int, + shift: Union[Sequence[float], float], + device: Optional[torch.device] = None, + backend=TransformBackends.NUMPY, +) -> NdarrayOrTensor: """ create a translation matrix Args: spatial_dims: spatial rank shift: translate pixel/voxel for every spatial dim, defaults to 0. - """ + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_translate(spatial_dims=spatial_dims, shift=shift, eye_func=np.eye, array_func=np.asarray) + if _backend == TransformBackends.TORCH: + return _create_translate( + spatial_dims=spatial_dims, + shift=shift, + eye_func=lambda x: torch.eye(torch.as_tensor(x), device=device), # type: ignore + array_func=lambda x: torch.as_tensor(x, device=device), # type: ignore + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_translate( + spatial_dims: int, shift: Union[Sequence[float], float], eye_func=np.eye, array_func=np.asarray +) -> NdarrayOrTensor: shift = ensure_tuple(shift) - affine = np.eye(spatial_dims + 1) + affine = eye_func(spatial_dims + 1) for i, a in enumerate(shift[:spatial_dims]): affine[i, spatial_dims] = a - return np.asarray(affine) + return array_func(affine) # type: ignore def generate_spatial_bounding_box( - img: np.ndarray, + img: NdarrayOrTensor, select_fn: Callable = is_positive, channel_indices: Optional[IndexSelection] = None, margin: Union[Sequence[int], int] = 0, ) -> Tuple[List[int], List[int]]: """ - generate the spatial bounding box of foreground in the image with start-end positions. + generate the spatial bounding box of foreground in the image with start-end positions (inclusive). Users can define arbitrary function to select expected foreground from the whole image or specified channels. And it can also add margin to every dim of the bounding box. The output format of the coordinates is: @@ -702,7 +892,7 @@ def generate_spatial_bounding_box( margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. """ data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img - data = np.any(select_fn(data), axis=0) + data = select_fn(data).any(0) ndim = len(data.shape) margin = ensure_tuple_rep(margin, ndim) for m in margin: @@ -713,19 +903,25 @@ def generate_spatial_bounding_box( box_end = [0] * ndim for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)): - dt = data.any(axis=ax) - if not np.any(dt): + dt = data + if len(ax) != 0: + dt = any_np_pt(dt, ax) + + if not dt.any(): # if no foreground, return all zero bounding box coords return [0] * ndim, [0] * ndim - min_d = max(np.argmax(dt) - margin[di], 0) - max_d = max(data.shape[di] - max(np.argmax(dt[::-1]) - margin[di], 0), min_d + 1) - box_start[di], box_end[di] = min_d, max_d + arg_max = where(dt == dt.max())[0] + min_d = max(arg_max[0] - margin[di], 0) + max_d = arg_max[-1] + margin[di] + 1 + + box_start[di] = min_d.detach().cpu().item() if isinstance(min_d, torch.Tensor) else min_d # type: ignore + box_end[di] = max_d.detach().cpu().item() if isinstance(max_d, torch.Tensor) else max_d # type: ignore return box_start, box_end -def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor: +def get_largest_connected_component_mask(img: NdarrayOrTensor, connectivity: Optional[int] = None) -> NdarrayOrTensor: """ Gets the largest connected component mask of an image. @@ -735,13 +931,13 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. """ - img_arr = img.detach().cpu().numpy() - largest_cc = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype) + img_arr: np.ndarray = convert_data_type(img, np.ndarray)[0] # type: ignore + largest_cc: np.ndarray = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype) img_arr = measure.label(img_arr, connectivity=connectivity) if img_arr.max() != 0: largest_cc[...] = img_arr == (np.argmax(np.bincount(img_arr.flat)[1:]) + 1) - - return torch.as_tensor(largest_cc, device=img.device) + largest_cc = convert_to_dst_type(largest_cc, dst=img, dtype=largest_cc.dtype)[0] # type: ignore + return largest_cc def fill_holes( @@ -804,7 +1000,7 @@ def fill_holes( def get_extreme_points( - img: np.ndarray, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 + img: NdarrayOrTensor, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 ) -> List[Tuple[int, ...]]: """ Generate extreme points from an image. These are used to generate initial segmentation @@ -828,7 +1024,7 @@ def get_extreme_points( """ if rand_state is None: rand_state = np.random.random.__self__ # type: ignore - indices = np.where(img != background) + indices = where(img != background) if np.size(indices[0]) == 0: raise ValueError("get_extreme_points: no foreground object in mask!") @@ -840,7 +1036,9 @@ def _get_point(val, dim): val : value for comparison dim : dimension in which to look for value """ - idx = rand_state.choice(np.where(indices[dim] == val)[0]) + idx = where(indices[dim] == val)[0] + idx = idx.cpu() if isinstance(idx, torch.Tensor) else idx + idx = rand_state.choice(idx) pt = [] for j in range(img.ndim): # add +- pert to each dimension @@ -852,19 +1050,19 @@ def _get_point(val, dim): points = [] for i in range(img.ndim): - points.append(tuple(_get_point(np.min(indices[i][...]), i))) - points.append(tuple(_get_point(np.max(indices[i][...]), i))) + points.append(tuple(_get_point(indices[i].min(), i))) + points.append(tuple(_get_point(indices[i].max(), i))) return points def extreme_points_to_image( points: List[Tuple[int, ...]], - label: np.ndarray, + label: NdarrayOrTensor, sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, rescale_min: float = -1.0, rescale_max: float = 1.0, -): +) -> torch.Tensor: """ Please refer to :py:class:`monai.transforms.AddExtremePointsChannel` for the usage. @@ -882,27 +1080,30 @@ def extreme_points_to_image( rescale_max: maximum value of output data. """ # points to image - points_image = torch.zeros(label.shape[1:], dtype=torch.float) + # points_image = torch.zeros(label.shape[1:], dtype=torch.float) + points_image = torch.zeros_like(torch.as_tensor(label[0]), dtype=torch.float) for p in points: points_image[p] = 1.0 + if isinstance(sigma, Sequence): + sigma = [torch.as_tensor(s, device=points_image.device) for s in sigma] + else: + sigma = torch.as_tensor(sigma, device=points_image.device) + # add channel and add batch points_image = points_image.unsqueeze(0).unsqueeze(0) gaussian_filter = GaussianFilter(label.ndim - 1, sigma=sigma) - points_image = gaussian_filter(points_image).squeeze(0).detach().numpy() + points_image = gaussian_filter(points_image).squeeze(0).detach() # rescale the points image to [rescale_min, rescale_max] - min_intensity = np.min(points_image) - max_intensity = np.max(points_image) + min_intensity = points_image.min() + max_intensity = points_image.max() points_image = (points_image - min_intensity) / (max_intensity - min_intensity) - points_image = points_image * (rescale_max - rescale_min) + rescale_min - return points_image + return points_image * (rescale_max - rescale_min) + rescale_min def map_spatial_axes( - img_ndim: int, - spatial_axes: Optional[Union[Sequence[int], int]] = None, - channel_first: bool = True, + img_ndim: int, spatial_axes: Optional[Union[Sequence[int], int]] = None, channel_first: bool = True ) -> List[int]: """ Utility to map the spatial axes to real axes in channel first/last shape. @@ -989,7 +1190,7 @@ def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTra def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_corners: Optional[bool] = None): """ Change the interpolation mode when inverting spatial transforms, default to "nearest". - This function modifies trans_info's `InverseKeys.EXTRA_INFO`. + This function modifies trans_info's `TraceKeys.EXTRA_INFO`. See also: :py:class:`monai.transform.inverse.InvertibleTransform` @@ -1002,21 +1203,21 @@ def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_c interp_modes = [i.value for i in InterpolateMode] + [i.value for i in GridSampleMode] # set to string for DataLoader collation - align_corners_ = "none" if align_corners is None else align_corners + align_corners_ = TraceKeys.NONE if align_corners is None else align_corners for item in ensure_tuple(trans_info): - if InverseKeys.EXTRA_INFO in item: - orig_mode = item[InverseKeys.EXTRA_INFO].get("mode", None) + if TraceKeys.EXTRA_INFO in item: + orig_mode = item[TraceKeys.EXTRA_INFO].get("mode", None) if orig_mode is not None: if orig_mode[0] in interp_modes: - item[InverseKeys.EXTRA_INFO]["mode"] = [mode for _ in range(len(mode))] + item[TraceKeys.EXTRA_INFO]["mode"] = [mode for _ in range(len(mode))] elif orig_mode in interp_modes: - item[InverseKeys.EXTRA_INFO]["mode"] = mode - if "align_corners" in item[InverseKeys.EXTRA_INFO]: - if issequenceiterable(item[InverseKeys.EXTRA_INFO]["align_corners"]): - item[InverseKeys.EXTRA_INFO]["align_corners"] = [align_corners_ for _ in range(len(mode))] + item[TraceKeys.EXTRA_INFO]["mode"] = mode + if "align_corners" in item[TraceKeys.EXTRA_INFO]: + if issequenceiterable(item[TraceKeys.EXTRA_INFO]["align_corners"]): + item[TraceKeys.EXTRA_INFO]["align_corners"] = [align_corners_ for _ in range(len(mode))] else: - item[InverseKeys.EXTRA_INFO]["align_corners"] = align_corners_ + item[TraceKeys.EXTRA_INFO]["align_corners"] = align_corners_ return trans_info @@ -1041,12 +1242,7 @@ def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Union[Sequen def equalize_hist( - img: np.ndarray, - mask: Optional[np.ndarray] = None, - num_bins: int = 256, - min: int = 0, - max: int = 255, - dtype: DtypeLike = np.float32, + img: np.ndarray, mask: Optional[np.ndarray] = None, num_bins: int = 256, min: int = 0, max: int = 255 ) -> np.ndarray: """ Utility to equalize input image based on the histogram. @@ -1061,9 +1257,9 @@ def equalize_hist( https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. min: the min value to normalize input image, default to `0`. max: the max value to normalize input image, default to `255`. - dtype: data type of the output, default to `float32`. """ + orig_shape = img.shape hist_img = img[np.array(mask, dtype=bool)] if mask is not None else img if has_skimage: @@ -1078,8 +1274,7 @@ def equalize_hist( # apply linear interpolation img = np.interp(img.flatten(), bins, cum) - - return img.reshape(orig_shape).astype(dtype) + return img.reshape(orig_shape) class Fourier: @@ -1088,38 +1283,70 @@ class Fourier: """ @staticmethod - def shift_fourier(x: torch.Tensor, n_dims: int) -> torch.Tensor: + @deprecated_arg( + name="n_dims", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) + def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] = None) -> NdarrayOrTensor: """ Applies fourier transform and shifts the zero-frequency component to the center of the spectrum. Only the spatial dimensions get transformed. Args: x: Image to transform. - n_dims: Number of spatial dimensions. + spatial_dims: Number of spatial dimensions. + + .. deprecated:: 0.6.0 + ``n_dims`` is deprecated, use ``spatial_dims`` instead. + Returns k: K-space data. """ - k: torch.Tensor = torch.fft.fftshift( - torch.fft.fftn(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) - ) + if n_dims is not None: + spatial_dims = n_dims + dims = tuple(range(-spatial_dims, 0)) + k: NdarrayOrTensor + if isinstance(x, torch.Tensor): + if hasattr(torch.fft, "fftshift"): # `fftshift` is new in torch 1.8.0 + k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims) + else: + # if using old PyTorch, will convert to numpy array and return + k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims) + else: + k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims) return k @staticmethod - def inv_shift_fourier(k: torch.Tensor, n_dims: int) -> torch.Tensor: + @deprecated_arg( + name="n_dims", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." + ) + def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] = None) -> NdarrayOrTensor: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. Args: k: K-space data. - n_dims: Number of spatial dimensions. + spatial_dims: Number of spatial dimensions. + + .. deprecated:: 0.6.0 + ``n_dims`` is deprecated, use ``spatial_dims`` instead. + Returns: x: Tensor in image space. """ - x: torch.Tensor = torch.fft.ifftn( - torch.fft.ifftshift(k, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) - ).real - return x + if n_dims is not None: + spatial_dims = n_dims + dims = tuple(range(-spatial_dims, 0)) + out: NdarrayOrTensor + if isinstance(k, torch.Tensor): + if hasattr(torch.fft, "ifftshift"): # `ifftshift` is new in torch 1.8.0 + out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward").real + else: + # if using old PyTorch, will convert to numpy array and return + out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real + else: + out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real + return out def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Optional[Hashable] = None) -> int: @@ -1149,9 +1376,7 @@ def _get_data(obj, key): prev_data = _get_data(test_data, key) prev_type = type(prev_data) prev_device = prev_data.device if isinstance(prev_data, torch.Tensor) else None - test_data = monai.transforms.transform.apply_transform( - _transform, test_data, transform.map_items, transform.unpack_items - ) + test_data = apply_transform(_transform, test_data, transform.map_items, transform.unpack_items) # every time the type or device changes, increment the counter curr_data = _get_data(test_data, key) curr_device = curr_data.device if isinstance(curr_data, torch.Tensor) else None @@ -1178,24 +1403,30 @@ def get_transform_backends(): continue unique_transforms.append(obj) - if isclass(obj) and issubclass(obj, Transform): - if n in [ - "Transform", + if ( + isclass(obj) + and issubclass(obj, Transform) + and n + not in [ + "BatchInverseTransform", + "Compose", + "Decollated", + "InvertD", "InvertibleTransform", "Lambda", "LambdaD", - "Compose", - "RandomizableTransform", + "MapTransform", "OneOf", - "BatchInverseTransform", - "InverteD", - ]: - continue - - backends[n] = [ - TransformBackends.TORCH in obj.backend, - TransformBackends.NUMPY in obj.backend, + "PadListDataCollate", + "RandLambda", + "RandLambdaD", + "RandTorchVisionD", + "RandomizableTransform", + "TorchVisionD", + "Transform", ] + ): + backends[n] = [TransformBackends.TORCH in obj.backend, TransformBackends.NUMPY in obj.backend] return backends @@ -1212,7 +1443,7 @@ def print_color(t, color): print(f"\033[{color}m{t}\033[00m") def print_table_column(name, torch, numpy, color=Colors.none): - print_color("{:<50} {:<8} {:<8}".format(name, torch, numpy), color) + print_color(f"{name:<50} {torch:<8} {numpy:<8}", color) backends = get_transform_backends() n_total = len(backends) @@ -1240,5 +1471,30 @@ def print_table_column(name, torch, numpy, color=Colors.none): print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) +def convert_pad_mode(dst: NdarrayOrTensor, mode: Union[NumpyPadMode, PytorchPadMode, str]): + """ + Utility to convert padding mode between numpy array and PyTorch Tensor. + + Args: + dst: target data to convert padding mode for, should be numpy array or PyTorch Tensor. + mode: current padding mode. + + """ + mode = mode.value if isinstance(mode, (NumpyPadMode, PytorchPadMode)) else mode + if isinstance(dst, torch.Tensor): + if mode == "wrap": + mode = "circular" + if mode == "edge": + mode = "replicate" + return look_up_option(mode, PytorchPadMode) + if isinstance(dst, np.ndarray): + if mode == "circular": + mode = "wrap" + if mode == "replicate": + mode = "edge" + return look_up_option(mode, NumpyPadMode) + raise ValueError(f"unsupported data type: {type(dst)}.") + + if __name__ == "__main__": print_transform_backends() diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py new file mode 100644 index 0000000000..ab282d5332 --- /dev/null +++ b/monai/transforms/utils_create_transform_ims.py @@ -0,0 +1,734 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib +import tempfile +import textwrap +from copy import deepcopy +from glob import glob +from typing import TYPE_CHECKING, Callable + +import numpy as np +import torch + +from monai.apps import download_and_extract +from monai.transforms import ( + AddChanneld, + Affine, + Affined, + AsDiscrete, + Compose, + Flip, + Flipd, + LoadImaged, + MapTransform, + Orientation, + Orientationd, + Rand3DElastic, + Rand3DElasticd, + RandFlip, + RandFlipd, + Randomizable, + RandRotate, + RandRotated, + RandZoom, + RandZoomd, + Rotate, + Rotate90, + Rotate90d, + Rotated, + ScaleIntensity, + ScaleIntensityd, + SpatialPadd, + Zoom, + Zoomd, +) +from monai.transforms.croppad.array import ( + BorderPad, + CenterScaleCrop, + CenterSpatialCrop, + CropForeground, + DivisiblePad, + RandCropByLabelClasses, + RandCropByPosNegLabel, + RandScaleCrop, + RandSpatialCrop, + RandSpatialCropSamples, + RandWeightedCrop, + ResizeWithPadOrCrop, + SpatialCrop, + SpatialPad, +) +from monai.transforms.croppad.dictionary import ( + BorderPadd, + CenterScaleCropd, + CenterSpatialCropd, + CropForegroundd, + DivisiblePadd, + RandCropByLabelClassesd, + RandCropByPosNegLabeld, + RandScaleCropd, + RandSpatialCropd, + RandSpatialCropSamplesd, + RandWeightedCropd, + ResizeWithPadOrCropd, + SpatialCropd, +) +from monai.transforms.intensity.array import ( + AdjustContrast, + GaussianSharpen, + GaussianSmooth, + GibbsNoise, + HistogramNormalize, + KSpaceSpikeNoise, + MaskIntensity, + NormalizeIntensity, + RandAdjustContrast, + RandBiasField, + RandCoarseDropout, + RandCoarseShuffle, + RandGaussianNoise, + RandGaussianSharpen, + RandGaussianSmooth, + RandGibbsNoise, + RandHistogramShift, + RandKSpaceSpikeNoise, + RandRicianNoise, + RandScaleIntensity, + RandShiftIntensity, + RandStdShiftIntensity, + SavitzkyGolaySmooth, + ScaleIntensityRange, + ScaleIntensityRangePercentiles, + ShiftIntensity, + StdShiftIntensity, + ThresholdIntensity, +) +from monai.transforms.intensity.dictionary import ( + AdjustContrastd, + GaussianSharpend, + GaussianSmoothd, + GibbsNoised, + HistogramNormalized, + KSpaceSpikeNoised, + MaskIntensityd, + NormalizeIntensityd, + RandAdjustContrastd, + RandBiasFieldd, + RandCoarseDropoutd, + RandCoarseShuffled, + RandGaussianNoised, + RandGaussianSharpend, + RandGaussianSmoothd, + RandGibbsNoised, + RandHistogramShiftd, + RandKSpaceSpikeNoised, + RandRicianNoised, + RandScaleIntensityd, + RandShiftIntensityd, + RandStdShiftIntensityd, + SavitzkyGolaySmoothd, + ScaleIntensityRanged, + ScaleIntensityRangePercentilesd, + ShiftIntensityd, + StdShiftIntensityd, + ThresholdIntensityd, +) +from monai.transforms.post.array import KeepLargestConnectedComponent, LabelFilter, LabelToContour +from monai.transforms.post.dictionary import AsDiscreted, KeepLargestConnectedComponentd, LabelFilterd, LabelToContourd +from monai.transforms.smooth_field.array import ( + RandSmoothDeform, + RandSmoothFieldAdjustContrast, + RandSmoothFieldAdjustIntensity, +) +from monai.transforms.smooth_field.dictionary import ( + RandSmoothDeformd, + RandSmoothFieldAdjustContrastd, + RandSmoothFieldAdjustIntensityd, +) +from monai.transforms.spatial.array import ( + GridDistortion, + Rand2DElastic, + RandAffine, + RandAxisFlip, + RandGridDistortion, + RandRotate90, + Resize, + Spacing, +) +from monai.transforms.spatial.dictionary import ( + GridDistortiond, + Rand2DElasticd, + RandAffined, + RandAxisFlipd, + RandGridDistortiond, + RandRotate90d, + Resized, + Spacingd, +) +from monai.utils.enums import CommonKeys +from monai.utils.module import optional_import + +if TYPE_CHECKING: + import matplotlib.pyplot as plt + + has_matplotlib = True + +else: + plt, has_matplotlib = optional_import("matplotlib.pyplot") + + +def get_data(keys): + """Get the example data to be used. + + Use MarsAtlas as it only contains 1 image for quick download and + that image is parcellated. + """ + cache_dir = os.environ.get("MONAI_DATA_DIRECTORY") or tempfile.mkdtemp() + fname = "MarsAtlas-MNI-Colin27.zip" + url = "https://www.dropbox.com/s/ndz8qtqblkciole/" + fname + "?dl=1" + out_path = os.path.join(cache_dir, "MarsAtlas-MNI-Colin27") + zip_path = os.path.join(cache_dir, fname) + + download_and_extract(url, zip_path, out_path) + + image, label = sorted(glob(os.path.join(out_path, "*.nii"))) + + data = {CommonKeys.IMAGE: image, CommonKeys.LABEL: label} + + transforms = Compose( + [LoadImaged(keys), AddChanneld(keys), ScaleIntensityd(CommonKeys.IMAGE), Rotate90d(keys, spatial_axes=[0, 2])] + ) + data = transforms(data) + max_size = max(data[keys[0]].shape) + padder = SpatialPadd(keys, (max_size, max_size, max_size)) + return padder(data) + + +def update_docstring(code_path, transform_name): + """ + Find the documentation for a given transform and if it's missing, + add a pointer to the transform's example image. + """ + with open(code_path) as f: + contents = f.readlines() + doc_start = None + for i, line in enumerate(contents): + # find the line containing start of the transform documentation + if "`" + transform_name + "`" in line: + doc_start = i + break + if doc_start is None: + raise RuntimeError("Couldn't find transform documentation") + + # if image is already in docs, nothing to do + image_line = doc_start + 2 + if ".. image" in contents[image_line]: + return + + # add the line for the image and the alt text + contents_orig = deepcopy(contents) + contents.insert( + image_line, + ".. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/" + transform_name + ".png\n", + ) + contents.insert(image_line + 1, " :alt: example of " + transform_name + "\n") + + # check that we've only added two lines + assert len(contents) == len(contents_orig) + 2 + + # write the updated doc to overwrite the original + with open(code_path, "w") as f: + f.writelines(contents) + + +def pre_process_data(data, ndim, is_map, is_post): + """If transform requires 2D data, then convert to 2D""" + if ndim == 2: + for k in keys: + data[k] = data[k][..., data[k].shape[-1] // 2] + + if is_map: + return data + return data[CommonKeys.LABEL] if is_post else data[CommonKeys.IMAGE] + + +def get_2d_slice(image, view, is_label): + """If image is 3d, get the central slice. If is already 2d, return as-is. + If image is label, set 0 to np.nan. + """ + if image.ndim == 2: + out = image + else: + shape = image.shape + slices = [slice(0, s) for s in shape] + _slice = shape[view] // 2 + slices[view] = slice(_slice, _slice + 1) + slices = tuple(slices) + out = np.squeeze(image[slices], view) + if is_label: + out[out == 0] = np.nan + return out + + +def get_stacked_2d_ims(im, is_label): + """Get the 3 orthogonal views and stack them into 1 image. + Requires that all images be same size, but this is taken care + of by the `SpatialPadd` earlier. + """ + return [get_2d_slice(im, i, is_label) for i in range(3)] + + +def get_stacked_before_after(before, after, is_label=False): + """Stack before and after images into 1 image if 3d. + Requires that before and after images be the same size. + """ + return [get_stacked_2d_ims(d, is_label) for d in (before, after)] + + +def save_image(images, labels, filename, transform_name, transform_args, shapes, colorbar=False): + """Save image to file, ensuring there's no whitespace around the edge.""" + plt.rcParams.update({"font.family": "monospace"}) + plt.style.use("dark_background") + nrow = len(images) # before and after (should always be 2) + ncol = len(images[0]) # num orthogonal views (either 1 or 3) + # roughly estimate the height_ratios of the first:second row + hs = [float(r[0].shape[0]) for r in images] + fig = plt.figure(tight_layout=True) + spec = fig.add_gridspec(nrow, ncol, hspace=0, wspace=0, height_ratios=hs) + for row in range(nrow): + vmin = min(i.min() for i in images[row]) + vmax = max(i.max() for i in images[row]) + for col in range(ncol): + ax = fig.add_subplot(spec[row, col]) + imshow = ax.imshow(images[row][col], cmap="gray", vmin=vmin, vmax=vmax) + ax.set_aspect("equal") + if colorbar and col == ncol - 1: + plt.colorbar(imshow, ax=ax) + if col == 0: + y_label = "After" if row else "Before" + y_label += ("\n" + shapes[row]) if shapes[0] != shapes[1] else "" + ax.set_ylabel(y_label) + # print yticks for the right most column + if col != ncol - 1 or colorbar: + ax.set_yticks([]) + else: + ax.yaxis.tick_right() + for n, label in enumerate(ax.yaxis.get_ticklabels()): + if n > 2: + label.set_visible(False) + ax.set_xticks([]) + ax.set_frame_on(False) + if labels is not None: + ax.imshow(labels[row][col], cmap="hsv", alpha=0.9, interpolation="nearest") + # title is e.g., Flipd(keys=keys, spatial_axis=0) + title = transform_name + "(" + for k, v in transform_args.items(): + title += k + "=" + if isinstance(v, str): + title += "'" + v + "'" + elif isinstance(v, (np.ndarray, torch.Tensor)): + title += "[array]" + elif isinstance(v, Callable): + title += "[callable]" + else: + title += str(v) + title += ", " + if len(transform_args) > 0: + title = title[:-2] + title += ")" + # shorten the lines + title = textwrap.fill(title, 50, break_long_words=False, subsequent_indent=" " * (len(transform_name) + 1)) + fig.suptitle(title, x=0.1, horizontalalignment="left") + fig.savefig(filename) + plt.close(fig) + + +def get_images(data, is_label=False): + """Get image. If is dictionary, extract key. If is list, stack. If both dictionary and list, do both. + Also return the image size as string to be used im the imshow. If it's a list, return `N x (H,W,D)`. + """ + # If not a list, convert + if not isinstance(data, list): + data = [data] + key = CommonKeys.LABEL if is_label else CommonKeys.IMAGE + is_map = isinstance(data[0], dict) + # length of the list will be equal to number of samples produced. This will be 1 except for transforms that + # produce `num_samples`. + data = [d[key] if is_map else d for d in data] + data = [d[0] for d in data] # remove channel component + + # for each sample, create a list of the orthogonal views. If image is 2d, length will be 1. If 3d, there + # will be three orthogonal views + num_samples = len(data) + num_orthog_views = 3 if data[0].ndim == 3 else 1 + shape_str = (f"{num_samples} x " if num_samples > 1 else "") + str(data[0].shape) + for i in range(num_samples): + data[i] = [get_2d_slice(data[i], view, is_label) for view in range(num_orthog_views)] + + out = [] + if num_samples == 1: + out = data[0] + else: + # we might need to panel the images. this happens if a transform produces e.g. 4 output images. + # In this case, we create a 2-by-2 grid from them. Output will be a list containing n_orthog_views, + # each element being either the image (if num_samples is 1) or the panelled image. + nrows = int(np.floor(num_samples ** 0.5)) + for view in range(num_orthog_views): + result = np.asarray([d[view] for d in data]) + nindex, height, width = result.shape + ncols = nindex // nrows + # only implemented for square number of images (e.g. 4 images goes to a 2-by-2 panel) + if nindex != nrows * ncols: + raise NotImplementedError + # want result.shape = (height*nrows, width*ncols), have to be careful about striding + result = result.reshape(nrows, ncols, height, width).swapaxes(1, 2).reshape(height * nrows, width * ncols) + out.append(result) + return out, shape_str + + +def create_transform_im( + transform, transform_args, data, ndim=3, colorbar=False, update_doc=True, seed=0, is_post=False +): + """Create an image with the before and after of the transform. + Also update the transform's documentation to point to this image.""" + + transform = transform(**transform_args) + + if not has_matplotlib: + raise RuntimeError + + if isinstance(transform, Randomizable): + # increment the seed for map transforms so they're different to the array versions. + seed = seed + 1 if isinstance(transform, MapTransform) else seed + transform.set_random_state(seed) + + out_dir = os.environ.get("MONAI_DOC_IMAGES") + if out_dir is None: + raise RuntimeError( + "Please git clone https://github.com/Project-MONAI/DocImages" + + " and then set the environment variable `MONAI_DOC_IMAGES`" + ) + out_dir = os.path.join(out_dir, "transforms") + + # Path is transform name + transform_name = transform.__class__.__name__ + out_fname = transform_name + ".png" + out_file = os.path.join(out_dir, out_fname) + + is_map = isinstance(transform, MapTransform) + data_in = pre_process_data(deepcopy(data), ndim, is_map, is_post) + + data_tr = transform(deepcopy(data_in)) + + images_before, before_shape = get_images(data_in) + images_after, after_shape = get_images(data_tr) + images = (images_before, images_after) + shapes = (before_shape, after_shape) + + labels = None + if is_map: + labels_before, *_ = get_images(data_in, is_label=True) + labels_after, *_ = get_images(data_tr, is_label=True) + labels = (labels_before, labels_after) + + save_image(images, labels, out_file, transform_name, transform_args, shapes, colorbar) + + if update_doc: + base_dir = pathlib.Path(__file__).parent.parent.parent + rst_path = os.path.join(base_dir, "docs", "source", "transforms.rst") + update_docstring(rst_path, transform_name) + + +if __name__ == "__main__": + + keys = [CommonKeys.IMAGE, CommonKeys.LABEL] + data = get_data(keys) + create_transform_im(RandFlip, dict(prob=1, spatial_axis=1), data) + create_transform_im(RandFlipd, dict(keys=keys, prob=1, spatial_axis=2), data) + create_transform_im(Flip, dict(spatial_axis=1), data) + create_transform_im(Flipd, dict(keys=keys, spatial_axis=2), data) + create_transform_im(Flipd, dict(keys=keys, spatial_axis=2), data) + create_transform_im(Orientation, dict(axcodes="RPI", image_only=True), data) + create_transform_im(Orientationd, dict(keys=keys, axcodes="RPI"), data) + create_transform_im( + Rand3DElastic, dict(prob=1.0, sigma_range=(1, 2), magnitude_range=(0.5, 0.5), shear_range=(1, 1, 1)), data + ) + create_transform_im(Affine, dict(shear_params=(0, 0.5, 0), image_only=True, padding_mode="zeros"), data) + create_transform_im( + Affined, dict(keys=keys, shear_params=(0, 0.5, 0), mode=["bilinear", "nearest"], padding_mode="zeros"), data + ) + create_transform_im(RandAffine, dict(prob=1, shear_range=(0.5, 0.5), padding_mode="zeros"), data) + create_transform_im( + RandAffined, + dict(keys=keys, prob=1, shear_range=(0.5, 0.5), mode=["bilinear", "nearest"], padding_mode="zeros"), + data, + ) + create_transform_im( + Rand3DElastic, dict(sigma_range=(5, 7), magnitude_range=(50, 150), prob=1, padding_mode="zeros"), data + ) + create_transform_im( + Rand2DElastic, dict(prob=1, spacing=(20, 20), magnitude_range=(1, 2), padding_mode="zeros"), data, 2 + ) + create_transform_im( + Rand2DElasticd, + dict( + keys=keys, + prob=1, + spacing=(20, 20), + magnitude_range=(1, 2), + padding_mode="zeros", + mode=["bilinear", "nearest"], + ), + data, + 2, + ) + create_transform_im( + Rand3DElasticd, + dict( + keys=keys, + sigma_range=(5, 7), + magnitude_range=(50, 150), + prob=1, + padding_mode="zeros", + mode=["bilinear", "nearest"], + ), + data, + ) + create_transform_im(Rotate90, dict(spatial_axes=(1, 2)), data) + create_transform_im(Rotate90d, dict(keys=keys, spatial_axes=(1, 2)), data) + create_transform_im(RandRotate90, dict(prob=1), data) + create_transform_im(RandRotate90d, dict(keys=keys, prob=1), data) + create_transform_im(Rotate, dict(angle=0.1), data) + create_transform_im(Rotated, dict(keys=keys, angle=0.1, mode=["bilinear", "nearest"]), data) + create_transform_im(RandRotate, dict(prob=1, range_x=[0.4, 0.4]), data) + create_transform_im(RandRotated, dict(keys=keys, prob=1, range_x=[0.4, 0.4], mode=["bilinear", "nearest"]), data) + create_transform_im(Zoom, dict(zoom=0.6), data) + create_transform_im(Zoomd, dict(keys=keys, zoom=1.3, mode=["area", "nearest"]), data) + create_transform_im(RandZoom, dict(prob=1, min_zoom=0.6, max_zoom=0.8), data) + create_transform_im(RandZoomd, dict(keys=keys, prob=1, min_zoom=1.3, max_zoom=1.5, mode=["area", "nearest"]), data) + create_transform_im(ScaleIntensity, dict(minv=0, maxv=10), data, colorbar=True) + create_transform_im(ScaleIntensityd, dict(keys=CommonKeys.IMAGE, minv=0, maxv=10), data, colorbar=True) + create_transform_im(RandScaleIntensity, dict(prob=1.0, factors=(5, 10)), data, colorbar=True) + create_transform_im( + RandScaleIntensityd, dict(keys=CommonKeys.IMAGE, prob=1.0, factors=(5, 10)), data, colorbar=True + ) + create_transform_im(DivisiblePad, dict(k=64), data) + create_transform_im(DivisiblePadd, dict(keys=keys, k=64), data) + create_transform_im(CropForeground, dict(), data) + create_transform_im(CropForegroundd, dict(keys=keys, source_key=CommonKeys.IMAGE), data) + create_transform_im(RandGaussianNoise, dict(prob=1, mean=0, std=0.1), data) + create_transform_im(RandGaussianNoised, dict(keys=CommonKeys.IMAGE, prob=1, mean=0, std=0.1), data) + create_transform_im(KSpaceSpikeNoise, dict(loc=(100, 100, 100), k_intensity=13), data) + create_transform_im(KSpaceSpikeNoised, dict(keys=CommonKeys.IMAGE, loc=(100, 100, 100), k_intensity=13), data) + create_transform_im(RandKSpaceSpikeNoise, dict(prob=1, intensity_range=(10, 13)), data) + create_transform_im( + RandKSpaceSpikeNoised, + dict(keys=CommonKeys.IMAGE, global_prob=1, prob=1, common_sampling=True, intensity_range=(13, 15)), + data, + ) + create_transform_im(RandRicianNoise, dict(prob=1.0, mean=1, std=0.5), data) + create_transform_im(RandRicianNoised, dict(keys=CommonKeys.IMAGE, prob=1.0, mean=1, std=0.5), data) + create_transform_im(SavitzkyGolaySmooth, dict(window_length=5, order=1), data) + create_transform_im(SavitzkyGolaySmoothd, dict(keys=CommonKeys.IMAGE, window_length=5, order=1), data) + create_transform_im(GibbsNoise, dict(alpha=0.8), data) + create_transform_im(GibbsNoised, dict(keys=CommonKeys.IMAGE, alpha=0.8), data) + create_transform_im(RandGibbsNoise, dict(prob=1.0, alpha=(0.6, 0.8)), data) + create_transform_im(RandGibbsNoised, dict(keys=CommonKeys.IMAGE, prob=1.0, alpha=(0.6, 0.8)), data) + create_transform_im(ShiftIntensity, dict(offset=1), data, colorbar=True) + create_transform_im(ShiftIntensityd, dict(keys=CommonKeys.IMAGE, offset=1), data, colorbar=True) + create_transform_im(RandShiftIntensity, dict(prob=1.0, offsets=(10, 20)), data, colorbar=True) + create_transform_im( + RandShiftIntensityd, dict(keys=CommonKeys.IMAGE, prob=1.0, offsets=(10, 20)), data, colorbar=True + ) + create_transform_im(StdShiftIntensity, dict(factor=10), data, colorbar=True) + create_transform_im(StdShiftIntensityd, dict(keys=CommonKeys.IMAGE, factor=10), data, colorbar=True) + create_transform_im(RandStdShiftIntensity, dict(prob=1.0, factors=(5, 10)), data, colorbar=True) + create_transform_im( + RandStdShiftIntensityd, dict(keys=CommonKeys.IMAGE, prob=1.0, factors=(5, 10)), data, colorbar=True + ) + create_transform_im(RandBiasField, dict(prob=1, coeff_range=(0.2, 0.3)), data) + create_transform_im(RandBiasFieldd, dict(keys=CommonKeys.IMAGE, prob=1, coeff_range=(0.2, 0.3)), data) + create_transform_im(NormalizeIntensity, dict(subtrahend=0, divisor=10), data, colorbar=True) + create_transform_im(NormalizeIntensityd, dict(keys=CommonKeys.IMAGE, subtrahend=0, divisor=10), data, colorbar=True) + create_transform_im(ThresholdIntensity, dict(threshold=0.4, above=False, cval=0.9), data, colorbar=True) + create_transform_im( + ThresholdIntensityd, dict(keys=CommonKeys.IMAGE, threshold=0.4, above=False, cval=0.9), data, colorbar=True + ) + create_transform_im(ScaleIntensityRange, dict(a_min=0, a_max=1, b_min=1, b_max=10), data, colorbar=True) + create_transform_im( + ScaleIntensityRanged, dict(keys=CommonKeys.IMAGE, a_min=0, a_max=1, b_min=1, b_max=10), data, colorbar=True + ) + create_transform_im(ScaleIntensityRangePercentiles, dict(lower=5, upper=95, b_min=1, b_max=10), data, colorbar=True) + create_transform_im( + ScaleIntensityRangePercentilesd, + dict(keys=CommonKeys.IMAGE, lower=5, upper=95, b_min=1, b_max=10), + data, + colorbar=True, + ) + create_transform_im(AdjustContrast, dict(gamma=2), data, colorbar=True) + create_transform_im(AdjustContrastd, dict(keys=CommonKeys.IMAGE, gamma=2), data, colorbar=True) + create_transform_im(RandAdjustContrast, dict(prob=1, gamma=(1.5, 2)), data, colorbar=True) + create_transform_im(RandAdjustContrastd, dict(keys=CommonKeys.IMAGE, prob=1, gamma=(1.5, 2)), data, colorbar=True) + create_transform_im(MaskIntensity, dict(mask_data=data[CommonKeys.IMAGE], select_fn=lambda x: x > 0.3), data) + create_transform_im( + MaskIntensityd, dict(keys=CommonKeys.IMAGE, mask_key=CommonKeys.IMAGE, select_fn=lambda x: x > 0.3), data + ) + create_transform_im(GaussianSmooth, dict(sigma=2), data) + create_transform_im(GaussianSmoothd, dict(keys=CommonKeys.IMAGE, sigma=2), data) + create_transform_im(RandGaussianSmooth, dict(prob=1.0, sigma_x=(1, 2)), data) + create_transform_im(RandGaussianSmoothd, dict(keys=CommonKeys.IMAGE, prob=1.0, sigma_x=(1, 2)), data) + create_transform_im(GaussianSharpen, dict(), GaussianSmoothd(CommonKeys.IMAGE, 2)(data)) + create_transform_im(GaussianSharpend, dict(keys=CommonKeys.IMAGE), GaussianSmoothd(CommonKeys.IMAGE, 2)(data)) + create_transform_im(RandGaussianSharpen, dict(prob=1), GaussianSmoothd(CommonKeys.IMAGE, 2)(data)) + create_transform_im( + RandGaussianSharpend, dict(keys=CommonKeys.IMAGE, prob=1), GaussianSmoothd(CommonKeys.IMAGE, 2)(data) + ) + create_transform_im(RandHistogramShift, dict(prob=1, num_control_points=3), data, colorbar=True) + create_transform_im( + RandHistogramShiftd, dict(keys=CommonKeys.IMAGE, prob=1, num_control_points=3), data, colorbar=True + ) + create_transform_im(RandCoarseDropout, dict(prob=1, holes=200, spatial_size=20, fill_value=0), data) + create_transform_im( + RandCoarseDropoutd, dict(keys=CommonKeys.IMAGE, prob=1, holes=200, spatial_size=20, fill_value=0), data + ) + create_transform_im(RandCoarseShuffle, dict(prob=1, holes=200, spatial_size=20), data) + create_transform_im(RandCoarseShuffled, dict(keys=CommonKeys.IMAGE, prob=1, holes=200, spatial_size=20), data) + create_transform_im(HistogramNormalize, dict(num_bins=10), data) + create_transform_im(HistogramNormalized, dict(keys=CommonKeys.IMAGE, num_bins=10), data) + create_transform_im(SpatialPad, dict(spatial_size=(300, 300, 300)), data) + create_transform_im(SpatialPadd, dict(keys=keys, spatial_size=(300, 300, 300)), data) + create_transform_im(BorderPad, dict(spatial_border=10), data) + create_transform_im(BorderPadd, dict(keys=keys, spatial_border=10), data) + create_transform_im(SpatialCrop, dict(roi_center=(75, 75, 75), roi_size=(100, 100, 100)), data) + create_transform_im(SpatialCropd, dict(keys=keys, roi_center=(75, 75, 75), roi_size=(100, 100, 100)), data) + create_transform_im(CenterSpatialCrop, dict(roi_size=(100, 100, 100)), data) + create_transform_im(CenterSpatialCropd, dict(keys=keys, roi_size=(100, 100, 100)), data) + create_transform_im(RandSpatialCrop, dict(roi_size=(100, 100, 100), random_size=False), data) + create_transform_im(RandSpatialCropd, dict(keys=keys, roi_size=(100, 100, 100), random_size=False), data) + create_transform_im(RandSpatialCropSamples, dict(num_samples=4, roi_size=(100, 100, 100), random_size=False), data) + create_transform_im( + RandSpatialCropSamplesd, dict(keys=keys, num_samples=4, roi_size=(100, 100, 100), random_size=False), data + ) + create_transform_im( + RandWeightedCrop, dict(spatial_size=(100, 100, 100), num_samples=4, weight_map=data[CommonKeys.IMAGE] > 0), data + ) + create_transform_im( + RandWeightedCropd, dict(keys=keys, spatial_size=(100, 100, 100), num_samples=4, w_key=CommonKeys.IMAGE), data + ) + create_transform_im( + RandCropByPosNegLabel, + dict(spatial_size=(100, 100, 100), label=data[CommonKeys.LABEL], neg=0, num_samples=4), + data, + ) + create_transform_im( + RandCropByPosNegLabeld, + dict(keys=keys, spatial_size=(100, 100, 100), label_key=CommonKeys.LABEL, neg=0, num_samples=4), + data, + ) + create_transform_im( + RandCropByLabelClasses, + dict( + spatial_size=(100, 100, 100), label=data[CommonKeys.LABEL] > 0, num_classes=2, ratios=[0, 1], num_samples=4 + ), + data, + ) + create_transform_im( + RandCropByLabelClassesd, + dict( + keys=keys, + spatial_size=(100, 100, 100), + label_key=CommonKeys.LABEL, + num_classes=2, + ratios=[0, 1], + num_samples=4, + ), + data, + ) + create_transform_im(ResizeWithPadOrCrop, dict(spatial_size=(100, 100, 100)), data) + create_transform_im(ResizeWithPadOrCropd, dict(keys=keys, spatial_size=(100, 100, 100)), data) + create_transform_im(RandScaleCrop, dict(roi_scale=0.4), data) + create_transform_im(RandScaleCropd, dict(keys=keys, roi_scale=0.4), data) + create_transform_im(CenterScaleCrop, dict(roi_scale=0.4), data) + create_transform_im(CenterScaleCropd, dict(keys=keys, roi_scale=0.4), data) + create_transform_im(AsDiscrete, dict(to_onehot=None, threshold=10), data, is_post=True, colorbar=True) + create_transform_im(AsDiscreted, dict(keys=CommonKeys.LABEL, to_onehot=None, threshold=10), data, is_post=True) + create_transform_im(LabelFilter, dict(applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True) + create_transform_im( + LabelFilterd, dict(keys=CommonKeys.LABEL, applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True + ) + create_transform_im(LabelToContour, dict(), data, is_post=True) + create_transform_im(LabelToContourd, dict(keys=CommonKeys.LABEL), data, is_post=True) + create_transform_im(Spacing, dict(pixdim=(5, 5, 5), image_only=True), data) + create_transform_im(Spacingd, dict(keys=keys, pixdim=(5, 5, 5), mode=["bilinear", "nearest"]), data) + create_transform_im(RandAxisFlip, dict(prob=1), data) + create_transform_im(RandAxisFlipd, dict(keys=keys, prob=1), data) + create_transform_im(Resize, dict(spatial_size=(100, 100, 100)), data) + create_transform_im(Resized, dict(keys=keys, spatial_size=(100, 100, 100), mode=["area", "nearest"]), data) + data_binary = deepcopy(data) + data_binary[CommonKeys.LABEL] = (data_binary[CommonKeys.LABEL] > 0).astype(np.float32) + create_transform_im(KeepLargestConnectedComponent, dict(applied_labels=1), data_binary, is_post=True, ndim=2) + create_transform_im( + KeepLargestConnectedComponentd, dict(keys=CommonKeys.LABEL, applied_labels=1), data_binary, is_post=True, ndim=2 + ) + create_transform_im( + GridDistortion, dict(num_cells=3, distort_steps=[(1.5,) * 4] * 3, mode="nearest", padding_mode="zeros"), data + ) + create_transform_im( + GridDistortiond, + dict( + keys=keys, num_cells=3, distort_steps=[(1.5,) * 4] * 3, mode=["bilinear", "nearest"], padding_mode="zeros" + ), + data, + ) + create_transform_im(RandGridDistortion, dict(num_cells=3, prob=1.0, distort_limit=(-0.1, 0.1)), data) + create_transform_im( + RandGridDistortiond, + dict(keys=keys, num_cells=4, prob=1.0, distort_limit=(-0.2, 0.2), mode=["bilinear", "nearest"]), + data, + ) + create_transform_im( + RandSmoothFieldAdjustContrast, dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0), data + ) + create_transform_im( + RandSmoothFieldAdjustContrastd, + dict(keys=keys, spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0), + data, + ) + create_transform_im( + RandSmoothFieldAdjustIntensity, + dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, gamma=(0.5, 4.5)), + data, + ) + create_transform_im( + RandSmoothFieldAdjustIntensityd, + dict(keys=keys, spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, gamma=(0.5, 4.5)), + data, + ) + + create_transform_im( + RandSmoothDeform, + dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, def_range=0.05, grid_mode="blinear"), + data, + ) + create_transform_im( + RandSmoothDeformd, + dict( + keys=keys, + spatial_size=(217, 217, 217), + rand_size=(10, 10, 10), + prob=1.0, + def_range=0.05, + grid_mode="blinear", + ), + data, + ) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 2eebe3eda3..25d59ac89a 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,21 +9,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Sequence, Union + import numpy as np import torch from monai.config.type_definitions import NdarrayOrTensor +from monai.utils.misc import is_module_ver_at_least __all__ = [ "moveaxis", "in1d", + "clip", + "percentile", + "where", + "nonzero", + "floor_divide", + "unravel_index", + "unravel_indices", + "ravel", + "any_np_pt", + "maximum", + "concatenate", + "cumsum", + "isfinite", + "searchsorted", + "repeat", + "isnan", ] def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: """`moveaxis` for pytorch and numpy, using `permute` for pytorch ver < 1.8""" if isinstance(x, torch.Tensor): - if hasattr(torch, "moveaxis"): + if hasattr(torch, "moveaxis"): # `moveaxis` is new in torch 1.8.0 return torch.moveaxis(x, src, dst) return _moveaxis_with_permute(x, src, dst) # type: ignore if isinstance(x, np.ndarray): @@ -50,3 +69,255 @@ def in1d(x, y): if isinstance(x, np.ndarray): return np.in1d(x, y) return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1) + + +def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor: + """`np.clip` with equivalent implementation for torch.""" + result: NdarrayOrTensor + if isinstance(a, np.ndarray): + result = np.clip(a, a_min, a_max) + else: + result = torch.clamp(a, a_min, a_max) + return result + + +def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[NdarrayOrTensor, float, int]: + """`np.percentile` with equivalent implementation for torch. + + Pytorch uses `quantile`, but this functionality is only available from v1.7. + For earlier methods, we calculate it ourselves. This doesn't do interpolation, + so is the equivalent of ``numpy.percentile(..., interpolation="nearest")``. + For more details, please refer to: + https://pytorch.org/docs/stable/generated/torch.quantile.html. + https://numpy.org/doc/stable/reference/generated/numpy.percentile.html. + + Args: + x: input data + q: percentile to compute (should in range 0 <= q <= 100) + dim: the dim along which the percentiles are computed. default is to compute the percentile + along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0. + + Returns: + Resulting value (scalar) + """ + if np.isscalar(q): + if not 0 <= q <= 100: # type: ignore + raise ValueError + elif any(q < 0) or any(q > 100): + raise ValueError + result: Union[NdarrayOrTensor, float, int] + if isinstance(x, np.ndarray): + result = np.percentile(x, q, axis=dim) + else: + q = torch.tensor(q, device=x.device) + if hasattr(torch, "quantile"): # `quantile` is new in torch 1.7.0 + result = torch.quantile(x, q / 100.0, dim=dim) + else: + # Note that ``kthvalue()`` works one-based, i.e., the first sorted value + # corresponds to k=1, not k=0. Thus, we need the `1 +`. + k = 1 + (0.01 * q * (x.numel() - 1)).round().int() + if k.numel() > 1: + r = [x.view(-1).kthvalue(int(_k)).values.item() for _k in k] + result = torch.tensor(r, device=x.device) + else: + result = x.view(-1).kthvalue(int(k)).values.item() + + return result + + +def where(condition: NdarrayOrTensor, x=None, y=None) -> NdarrayOrTensor: + """ + Note that `torch.where` may convert y.dtype to x.dtype. + """ + result: NdarrayOrTensor + if isinstance(condition, np.ndarray): + if x is not None: + result = np.where(condition, x, y) + else: + result = np.where(condition) # type: ignore + else: + if x is not None: + x = torch.as_tensor(x, device=condition.device) + y = torch.as_tensor(y, device=condition.device, dtype=x.dtype) + result = torch.where(condition, x, y) + else: + result = torch.where(condition) # type: ignore + return result + + +def nonzero(x: NdarrayOrTensor): + """`np.nonzero` with equivalent implementation for torch. + + Args: + x: array/tensor + + Returns: + Index unravelled for given shape + """ + if isinstance(x, np.ndarray): + return np.nonzero(x)[0] + return torch.nonzero(x).flatten() + + +def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor: + """`np.floor_divide` with equivalent implementation for torch. + + As of pt1.8, use `torch.div(..., rounding_mode="floor")`, and + before that, use `torch.floor_divide`. + + Args: + a: first array/tensor + b: scalar to divide by + + Returns: + Element-wise floor division between two arrays/tensors. + """ + if isinstance(a, torch.Tensor): + if is_module_ver_at_least(torch, (1, 8, 0)): + return torch.div(a, b, rounding_mode="floor") + return torch.floor_divide(a, b) + return np.floor_divide(a, b) + + +def unravel_index(idx, shape): + """`np.unravel_index` with equivalent implementation for torch. + + Args: + idx: index to unravel + shape: shape of array/tensor + + Returns: + Index unravelled for given shape + """ + if isinstance(idx, torch.Tensor): + coord = [] + for dim in reversed(shape): + coord.append(idx % dim) + idx = floor_divide(idx, dim) + return torch.stack(coord[::-1]) + return np.asarray(np.unravel_index(idx, shape)) + + +def unravel_indices(idx, shape): + """Computing unravel coordinates from indices. + + Args: + idx: a sequence of indices to unravel + shape: shape of array/tensor + + Returns: + Stacked indices unravelled for given shape + """ + lib_stack = torch.stack if isinstance(idx[0], torch.Tensor) else np.stack + return lib_stack([unravel_index(i, shape) for i in idx]) + + +def ravel(x: NdarrayOrTensor): + """`np.ravel` with equivalent implementation for torch. + + Args: + x: array/tensor to ravel + + Returns: + Return a contiguous flattened array/tensor. + """ + if isinstance(x, torch.Tensor): + if hasattr(torch, "ravel"): # `ravel` is new in torch 1.8.0 + return x.ravel() + return x.flatten().contiguous() + return np.ravel(x) + + +def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]): + """`np.any` with equivalent implementation for torch. + + For pytorch, convert to boolean for compatibility with older versions. + + Args: + x: input array/tensor + axis: axis to perform `any` over + + Returns: + Return a contiguous flattened array/tensor. + """ + if isinstance(x, np.ndarray): + return np.any(x, axis) + + # pytorch can't handle multiple dimensions to `any` so loop across them + axis = [axis] if not isinstance(axis, Sequence) else axis + for ax in axis: + try: + x = torch.any(x, ax) + except RuntimeError: + # older versions of pytorch require the input to be cast to boolean + x = torch.any(x.bool(), ax) + return x + + +def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor: + """`np.maximum` with equivalent implementation for torch. + + `torch.maximum` only available from pt>1.6, else use `torch.stack` and `torch.max`. + + Args: + a: first array/tensor + b: second array/tensor + + Returns: + Element-wise maximum between two arrays/tensors. + """ + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + # is torch and has torch.maximum (pt>1.6) + if hasattr(torch, "maximum"): # `maximum` is new in torch 1.7.0 + return torch.maximum(a, b) + return torch.stack((a, b)).max(dim=0)[0] + return np.maximum(a, b) + + +def concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> NdarrayOrTensor: + """`np.concatenate` with equivalent implementation for torch (`torch.cat`).""" + if isinstance(to_cat[0], np.ndarray): + return np.concatenate(to_cat, axis, out) # type: ignore + return torch.cat(to_cat, dim=axis, out=out) # type: ignore + + +def cumsum(a: NdarrayOrTensor, axis=None): + """`np.cumsum` with equivalent implementation for torch.""" + if isinstance(a, np.ndarray): + return np.cumsum(a, axis) + if axis is None: + return torch.cumsum(a[:], 0) + return torch.cumsum(a, dim=axis) + + +def isfinite(x): + """`np.isfinite` with equivalent implementation for torch.""" + if not isinstance(x, torch.Tensor): + return np.isfinite(x) + return torch.isfinite(x) + + +def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None): + side = "right" if right else "left" + if isinstance(a, np.ndarray): + return np.searchsorted(a, v, side, sorter) # type: ignore + return torch.searchsorted(a, v, right=right) # type: ignore + + +def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None): + """`np.repeat` with equivalent implementation for torch (`repeat_interleave`).""" + if isinstance(a, np.ndarray): + return np.repeat(a, repeats, axis) + return torch.repeat_interleave(a, repeats, dim=axis) + + +def isnan(x: NdarrayOrTensor): + """`np.isnan` with equivalent implementation for torch. + + Args: + x: array/tensor + + """ + if isinstance(x, np.ndarray): + return np.isnan(x) + return torch.isnan(x) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index aa8f02f815..0c04680234 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,7 +12,7 @@ # have to explicitly bring these in here to resolve circular import issues from .aliases import alias, resolve_name from .decorators import MethodReplacer, RestartGenerator -from .deprecated import DeprecatedError, deprecated, deprecated_arg +from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg from .dist import evenly_divisible_all_gather, get_dist_device, string_list_all_gather from .enums import ( Average, @@ -24,12 +24,14 @@ GridSamplePadMode, InterpolateMode, InverseKeys, + JITMetadataKeys, LossReduction, Method, MetricReduction, NumpyPadMode, PytorchPadMode, SkipMode, + TraceKeys, TransformBackends, UpsampleMode, Weight, @@ -57,7 +59,6 @@ zip_with, ) from .module import ( - PT_BEFORE_1_7, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, @@ -70,6 +71,8 @@ look_up_option, min_version, optional_import, + pytorch_after, + require_pkg, version_leq, ) from .nvtx import Range @@ -77,6 +80,7 @@ from .state_cacher import StateCacher from .type_conversion import ( convert_data_type, + convert_to_cupy, convert_to_dst_type, convert_to_numpy, convert_to_tensor, diff --git a/monai/utils/aliases.py b/monai/utils/aliases.py index 2b7b29eeb5..0ae79e26ff 100644 --- a/monai/utils/aliases.py +++ b/monai/utils/aliases.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -70,8 +70,8 @@ def resolve_name(name): try: mod = importlib.import_module(modname) obj = getattr(mod, declname, None) - except ModuleNotFoundError: - raise ValueError(f"Module {modname!r} not found.") + except ModuleNotFoundError as not_found_err: + raise ValueError(f"Module {modname!r} not found.") from not_found_err if obj is None: raise ValueError(f"Module {modname!r} does not have member {declname!r}.") diff --git a/monai/utils/decorators.py b/monai/utils/decorators.py index 1931d703c9..0856c0fc1a 100644 --- a/monai/utils/decorators.py +++ b/monai/utils/decorators.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/utils/deprecated.py b/monai/utils/deprecate_utils.py similarity index 82% rename from monai/utils/deprecated.py rename to monai/utils/deprecate_utils.py index 3a4568b06c..a6092c1b63 100644 --- a/monai/utils/deprecated.py +++ b/monai/utils/deprecate_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. import inspect +import sys import warnings from functools import wraps from types import FunctionType @@ -60,6 +61,9 @@ def deprecated( Decorated definition which warns or raises exception when used """ + # if version_val.startswith("0+"): + # # version unknown, set version_val to a large value (assuming the latest version) + # version_val = f"{sys.maxsize}" if since is not None and removed is not None and not version_leq(since, removed): raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.") is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) @@ -116,6 +120,7 @@ def deprecated_arg( removed: Optional[str] = None, msg_suffix: str = "", version_val: str = __version__, + new_name: Optional[str] = None, ): """ Marks a particular named argument of a callable as deprecated. The same conditions for `since` and `removed` as @@ -130,6 +135,8 @@ def deprecated_arg( using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`. https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded + In the current implementation type annotations are not preserved. + Args: name: name of position or keyword argument to mark as deprecated. @@ -137,17 +144,23 @@ def deprecated_arg( removed: version at which the argument was removed and no longer usable. msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead. version_val: (used for testing) version to compare since and removed against, default is MONAI version. + new_name: name of position or keyword argument to replace the deprecated argument. + if it is specified and the signature of the decorated function has a `kwargs`, the value to the + deprecated argument `name` will be removed. Returns: - Decorated callable which warns or raises exception when deprecated argument used + Decorated callable which warns or raises exception when deprecated argument used. """ + + if version_val.startswith("0+") or not f"{version_val}".strip()[0].isdigit(): + # version unknown, set version_val to a large value (assuming the latest version) + version_val = f"{sys.maxsize}" if since is not None and removed is not None and not version_leq(since, removed): raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.") is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) if is_not_yet_deprecated: # smaller than `since`, do nothing return lambda obj: obj - if since is None and removed is None: # raise a DeprecatedError directly is_removed = True @@ -157,9 +170,6 @@ def deprecated_arg( is_deprecated = since is not None and version_leq(since, version_val) is_removed = removed is not None and version_leq(removed, version_val) - if is_not_yet_deprecated: - return lambda obj: obj - def _decorator(func): argname = f"{func.__name__}_{name}" @@ -180,10 +190,23 @@ def _decorator(func): @wraps(func) def _wrapper(*args, **kwargs): + if new_name is not None and name in kwargs and new_name not in kwargs: + # replace the deprecated arg "name" with "new_name" + # if name is specified and new_name is not specified + kwargs[new_name] = kwargs[name] + try: + sig.bind(*args, **kwargs).arguments + except TypeError: + # multiple values for new_name using both args and kwargs + kwargs.pop(new_name, None) binding = sig.bind(*args, **kwargs).arguments - positional_found = name in binding - kw_found = "kwargs" in binding and name in binding["kwargs"] + kw_found = False + for k, param in sig.parameters.items(): + if param.kind == inspect.Parameter.VAR_KEYWORD and k in binding and name in binding[k]: + kw_found = True + # if the deprecated arg is found in the **kwargs, it should be removed + kwargs.pop(name, None) if positional_found or kw_found: if is_removed: diff --git a/monai/utils/dist.py b/monai/utils/dist.py index beb958a5c8..37536bfe83 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 847df9e2d3..08a59c598c 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,8 @@ from enum import Enum +from monai.utils.deprecate_utils import deprecated + __all__ = [ "NumpyPadMode", "GridSampleMode", @@ -26,6 +28,7 @@ "ChannelMatching", "SkipMode", "Method", + "TraceKeys", "InverseKeys", "CommonKeys", "ForwardMode", @@ -208,8 +211,27 @@ class ForwardMode(Enum): EVAL = "eval" +class TraceKeys: + """Extra meta data keys used for traceable transforms.""" + + CLASS_NAME = "class" + ID = "id" + ORIG_SIZE = "orig_size" + EXTRA_INFO = "extra_info" + DO_TRANSFORM = "do_transforms" + KEY_SUFFIX = "_transforms" + NONE = "none" + + +@deprecated(since="0.8.0", msg_suffix="use monai.utils.enums.TraceKeys instead.") class InverseKeys: - """Extra meta data keys used for inverse transforms.""" + """ + Extra meta data keys used for inverse transforms. + + .. deprecated:: 0.8.0 + Use :class:`monai.utils.enums.TraceKeys` instead. + + """ CLASS_NAME = "class" ID = "id" @@ -217,6 +239,7 @@ class InverseKeys: EXTRA_INFO = "extra_info" DO_TRANSFORM = "do_transforms" KEY_SUFFIX = "_transforms" + NONE = "none" class CommonKeys: @@ -243,3 +266,15 @@ class TransformBackends(Enum): TORCH = "torch" NUMPY = "numpy" + + +class JITMetadataKeys(Enum): + """ + Keys stored in the metadata file for saved Torchscript models. Some of these are generated by the routines + and others are optionally provided by users. + """ + + NAME = "name" + TIMESTAMP = "timestamp" + VERSION = "version" + DESCRIPTION = "description" diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index 26487083b1..366d11ebd8 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,11 +16,14 @@ from enum import Enum from threading import RLock, Thread -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch +from monai.config import IgniteInfo +from monai.utils.module import min_version, optional_import + try: import matplotlib.pyplot as plt @@ -28,14 +31,11 @@ except ImportError: has_matplotlib = False -try: +if TYPE_CHECKING: from ignite.engine import Engine, Events - - has_ignite = True -except ImportError: - Engine = object - Events = object - has_ignite = False +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") LOSS_NAME = "loss" @@ -128,7 +128,7 @@ def plot_metric_images( else: im.imshow(np.squeeze(imagemap[n]), cmap="gray") - im.set_title("%s\n%.3g -> %.3g" % (n, imagemap[n].min(), imagemap[n].max())) + im.set_title(f"{n}\n{imagemap[n].min():.3g} -> {imagemap[n].max():.3g}") im.axis("off") axes.append(im) @@ -161,6 +161,7 @@ def plot_engine_status( window_fraction: int = 20, image_fn: Optional[Callable] = tensor_to_images, fig=None, + selected_inst: int = 0, ) -> Tuple: """ Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics @@ -177,6 +178,7 @@ def plot_engine_status( window_fraction: for metric plot, what fraction of the graph value length to use as the running average window image_fn: callable converting tensors keyed to a name in the Engine to a tuple of images to plot fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing + selected_inst: index of the instance to show in the image plot Returns: Figure object (or `fig` if given), list of Axes objects for graph and images @@ -189,22 +191,36 @@ def plot_engine_status( graphmap = {LOSS_NAME: logger.loss} graphmap.update(logger.metrics) - imagemap = {} + imagemap: Dict = {} if image_fn is not None and engine.state is not None and engine.state.batch is not None: for src in (engine.state.batch, engine.state.output): + label = "Batch" if src is engine.state.batch else "Output" + batch_selected_inst = selected_inst # selected batch index, set to 0 when src is decollated + + # if the src object is a list of elements, ie. a decollated batch, select an element and keep it as + # a dictionary of tensors with a batch dimension added if isinstance(src, list): - for i, s in enumerate(src): - if isinstance(s, dict): - for k, v in s.items(): - if isinstance(v, torch.Tensor): - image = image_fn(k, v) - if image is not None: - imagemap[f"{k}_{i}"] = image - elif isinstance(s, torch.Tensor): - label = "Batch" if src is engine.state.batch else "Output" - image = image_fn(label, s) + selected_dict = src[selected_inst] # select this element + batch_selected_inst = 0 # set the selection to be the single index in the batch dimension + # store each tensor that is interpretable as an image with an added batch dimension + src = {k: v[None] for k, v in selected_dict.items() if isinstance(v, torch.Tensor) and v.ndim >= 3} + + # images will be generated from the batch item selected above only, or from the single item given as `src` + + if isinstance(src, dict): + for k, v in src.items(): + if isinstance(v, torch.Tensor) and v.ndim >= 4: + image = image_fn(k, v[batch_selected_inst]) + + # if we have images add each one separately to the map if image is not None: - imagemap[f"{label}_{i}"] = image + for i, im in enumerate(image): + imagemap[f"{k}_{i}"] = im + + elif isinstance(src, torch.Tensor): + image = image_fn(label, src) + if image is not None: + imagemap[f"{label}_{i}"] = image axes = plot_metric_images(fig, title, graphmap, imagemap, yscale, avg_keys, window_fraction) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index a31452f6ae..eae0580696 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,7 +22,7 @@ import numpy as np import torch -from monai.utils.module import get_torch_version_tuple, version_leq +from monai.utils.module import version_leq __all__ = [ "zip_with", @@ -43,6 +43,7 @@ "copy_to_device", "ImageMetaKey", "is_module_ver_at_least", + "has_option", ] _seed = None @@ -88,7 +89,7 @@ def ensure_tuple(vals: Any) -> Tuple[Any, ...]: Returns a tuple of `vals`. """ if not issequenceiterable(vals): - vals = (vals,) + return (vals,) return tuple(vals) @@ -97,8 +98,8 @@ def ensure_tuple_size(tup: Any, dim: int, pad_val: Any = 0) -> Tuple[Any, ...]: """ Returns a copy of `tup` with `dim` values by either shortened or padded with `pad_val` as necessary. """ - tup = ensure_tuple(tup) + (pad_val,) * dim - return tuple(tup[:dim]) + new_tup = ensure_tuple(tup) + (pad_val,) * dim + return new_tup[:dim] def ensure_tuple_rep(tup: Any, dim: int) -> Tuple[Any, ...]: @@ -250,19 +251,21 @@ def set_determinism( for func in additional_settings: func(seed) + if torch.backends.flags_frozen(): + warnings.warn("PyTorch global flag support of backends is disabled, enable it to set global `cudnn` flags.") + torch.backends.__allow_nonbracketed_mutation_flag = True + if seed is not None: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False else: # restore the original flags torch.backends.cudnn.deterministic = _flag_deterministic torch.backends.cudnn.benchmark = _flag_cudnn_benchmark - if use_deterministic_algorithms is not None: - torch_ver = get_torch_version_tuple() - if torch_ver >= (1, 9): + if hasattr(torch, "use_deterministic_algorithms"): # `use_deterministic_algorithms` is new in torch 1.8.0 torch.use_deterministic_algorithms(use_deterministic_algorithms) - elif torch_ver >= (1, 7): - torch.set_deterministic(use_deterministic_algorithms) # beta feature + elif hasattr(torch, "set_deterministic"): # `set_deterministic` is new in torch 1.7.0 + torch.set_deterministic(use_deterministic_algorithms) # type: ignore else: warnings.warn("use_deterministic_algorithms=True, but PyTorch version is too old to set the mode.") @@ -279,9 +282,7 @@ def list_to_dict(items): def _parse_var(s): items = s.split("=", maxsplit=1) key = items[0].strip(" \n\r\t'") - value = None - if len(items) > 1: - value = items[1].strip(" \n\r\t'") + value = items[1].strip(" \n\r\t'") if len(items) > 1 else None return key, value d = {} @@ -302,10 +303,7 @@ def _parse_var(s): def copy_to_device( - obj: Any, - device: Optional[Union[str, torch.device]], - non_blocking: bool = True, - verbose: bool = False, + obj: Any, device: Optional[Union[str, torch.device]], non_blocking: bool = True, verbose: bool = False ) -> Any: """ Copy object or tuple/list/dictionary of objects to ``device``. @@ -314,7 +312,7 @@ def copy_to_device( obj: object or tuple/list/dictionary of objects to move to ``device``. device: move ``obj`` to this device. Can be a string (e.g., ``cpu``, ``cuda``, ``cuda:0``, etc.) or of type ``torch.device``. - non_blocking_transfer: when `True`, moves data to device asynchronously if + non_blocking: when `True`, moves data to device asynchronously if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. verbose: when `True`, will print a warning for any elements of incompatible type not copied to ``device``. diff --git a/monai/utils/module.py b/monai/utils/module.py index 33314fb0e3..c0fc10a7c0 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -8,12 +8,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import enum +import os +import re import sys import warnings +from functools import wraps from importlib import import_module from pkgutil import walk_packages from re import match +from types import FunctionType from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, cast import torch @@ -29,12 +34,13 @@ "look_up_option", "min_version", "optional_import", + "require_pkg", "load_submodules", "get_full_type_name", "get_package_version", "get_torch_version_tuple", - "PT_BEFORE_1_7", "version_leq", + "pytorch_after", ] @@ -136,9 +142,7 @@ def damerau_levenshtein_distance(s1: str, s2: str): for j, s2j in enumerate(s2): cost = 0 if s1i == s2j else 1 d[(i, j)] = min( - d[(i - 1, j)] + 1, # deletion - d[(i, j - 1)] + 1, # insertion - d[(i - 1, j - 1)] + cost, # substitution + d[(i - 1, j)] + 1, d[(i, j - 1)] + 1, d[(i - 1, j - 1)] + cost # deletion # insertion # substitution ) if i and j and s1i == s2[j - 1] and s1[i - 1] == s2j: d[(i, j)] = min(d[(i, j)], d[i - 2, j - 2] + cost) # transposition @@ -349,6 +353,45 @@ def __call__(self, *_args, **_kwargs): return _LazyRaise(), False +def require_pkg( + pkg_name: str, version: str = "", version_checker: Callable[..., bool] = min_version, raise_error: bool = True +): + """ + Decorator function to check the required package installation. + + Args: + pkg_name: required package name, like: "itk", "nibabel", etc. + version: required version string used by the version_checker. + version_checker: a callable to check the module version, defaults to `monai.utils.min_version`. + raise_error: if True, raise `OptionalImportError` error if the required package is not installed + or the version doesn't match requirement, if False, print the error in a warning. + + """ + + def _decorator(obj): + is_func = isinstance(obj, FunctionType) + call_obj = obj if is_func else obj.__init__ + _, has = optional_import(module=pkg_name, version=version, version_checker=version_checker) + + @wraps(call_obj) + def _wrapper(*args, **kwargs): + if not has: + err_msg = f"required package `{pkg_name}` is not installed or the version doesn't match requirement." + if raise_error: + raise OptionalImportError(err_msg) + else: + warnings.warn(err_msg) + + return call_obj(*args, **kwargs) + + if is_func: + return _wrapper + obj.__init__ = _wrapper + return obj + + return _decorator + + def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."): """ Try to load package and get version. If not found, return `default`. @@ -364,17 +407,25 @@ def get_torch_version_tuple(): Returns: tuple of ints represents the pytorch major/minor version. """ - return tuple((int(x) for x in torch.__version__.split(".")[:2])) + return tuple(int(x) for x in torch.__version__.split(".")[:2]) -def version_leq(lhs, rhs): - """Returns True if version `lhs` is earlier or equal to `rhs`.""" +def version_leq(lhs: str, rhs: str): + """ + Returns True if version `lhs` is earlier or equal to `rhs`. + + Args: + lhs: version name to compare with `rhs`, return True if earlier or equal to `rhs`. + rhs: version name to compare with `lhs`, return True if later or equal to `lhs`. + """ + + lhs, rhs = str(lhs), str(rhs) ver, has_ver = optional_import("pkg_resources", name="parse_version") if has_ver: return ver(lhs) <= ver(rhs) - def _try_cast(val): + def _try_cast(val: str): val = val.strip() try: m = match("(\\d+)(.*)", val) @@ -390,10 +441,10 @@ def _try_cast(val): rhs = rhs.split("+", 1)[0] # parse the version strings in this basic way without `packaging` package - lhs = map(_try_cast, lhs.split(".")) - rhs = map(_try_cast, rhs.split(".")) + lhs_ = map(_try_cast, lhs.split(".")) + rhs_ = map(_try_cast, rhs.split(".")) - for l, r in zip(lhs, rhs): + for l, r in zip(lhs_, rhs_): if l != r: if isinstance(l, int) and isinstance(r, int): return l < r @@ -402,7 +453,51 @@ def _try_cast(val): return True -try: - PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0") -except (AttributeError, TypeError): - PT_BEFORE_1_7 = True +def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool: + """ + Compute whether the current pytorch version is after or equal to the specified version. + The current system pytorch version is determined by `torch.__version__` or + via system environment variable `PYTORCH_VER`. + + Args: + major: major version number to be compared with + minor: minor version number to be compared with + patch: patch version number to be compared with + current_ver_string: if None, `torch.__version__` will be used. + + Returns: + True if the current pytorch version is greater than or equal to the specified version. + """ + + try: + if current_ver_string is None: + _env_var = os.environ.get("PYTORCH_VER", "") + current_ver_string = _env_var if _env_var else torch.__version__ + ver, has_ver = optional_import("pkg_resources", name="parse_version") + if has_ver: + return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore + parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3) + while len(parts) < 3: + parts += ["0"] + c_major, c_minor, c_patch = parts[:3] + except (AttributeError, ValueError, TypeError): + c_major, c_minor = get_torch_version_tuple() + c_patch = "0" + c_mn = int(c_major), int(c_minor) + mn = int(major), int(minor) + if c_mn != mn: + return c_mn > mn + is_prerelease = ("a" in f"{c_patch}".lower()) or ("rc" in f"{c_patch}".lower()) + c_p = 0 + try: + p_reg = re.search(r"\d+", f"{c_patch}") + if p_reg: + c_p = int(p_reg.group()) + except (AttributeError, TypeError, ValueError): + is_prerelease = True + patch = int(patch) + if c_p != patch: + return c_p > patch # type: ignore + if is_prerelease: + return False + return True diff --git a/monai/utils/nvtx.py b/monai/utils/nvtx.py index 1980ceef71..691f900c7d 100644 --- a/monai/utils/nvtx.py +++ b/monai/utils/nvtx.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -62,8 +62,15 @@ def __call__(self, obj: Any): # Define the name to be associated to the range if not provided if self.name is None: name = type(obj).__name__ + # If CuCIM or TorchVision transform wrappers are being used, + # append the underlying transform to the name for more clarity + if "CuCIM" in name or "TorchVision" in name: + name = f"{name}_{obj.name}" self.name_counter[name] += 1 - self.name = f"{name}_{self.name_counter[name]}" + if self.name_counter[name] > 1: + self.name = f"{name}_{self.name_counter[name]}" + else: + self.name = name # Define the methods to be wrapped if not provided if self.methods is None: @@ -138,7 +145,7 @@ def _get_method(self, obj: Any) -> tuple: if len(method_list) < 1: raise ValueError( f"The method to be wrapped for this object [{type(obj)}] is not recognized." - "The name of the method should be provied or the object should have one of these methods:" + "The name of the method should be provided or the object should have one of these methods:" f"{default_methods}" ) return ensure_tuple(method_list) diff --git a/monai/utils/profiling.py b/monai/utils/profiling.py index 695653e897..8e0742268f 100644 --- a/monai/utils/profiling.py +++ b/monai/utils/profiling.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -56,7 +56,7 @@ def wrapper(*args, **kwargs): cpu_time = torch.autograd.profiler.format_time(cpu_time) gpu_time = torch.autograd.profiler.format_time(gpu_time) - print("cpu time: {}, gpu time: {}".format(cpu_time, gpu_time), flush=True) + print(f"cpu time: {cpu_time}, gpu time: {gpu_time}", flush=True) return result @@ -83,7 +83,7 @@ def wrapper(*args, **kwargs): total_time = (end - start) * 1e6 total_time_str = torch.autograd.profiler.format_time(total_time) - print("end to end time: {}".format(total_time_str), flush=True) + print(f"end to end time: {total_time_str}", flush=True) return result diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index 94943a8c37..3e392ab979 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,10 +11,14 @@ import copy import os +import pickle import tempfile from typing import Dict, Optional import torch +from torch.serialization import DEFAULT_PROTOCOL + +from monai.config.type_definitions import PathLike __all__ = ["StateCacher"] @@ -37,8 +41,10 @@ class StateCacher: def __init__( self, in_memory: bool, - cache_dir: Optional[str] = None, + cache_dir: Optional[PathLike] = None, allow_overwrite: bool = True, + pickle_module=pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, ) -> None: """Constructor. @@ -51,20 +57,39 @@ def __init__( allow_overwrite: allow the cache to be overwritten. If set to `False`, an error will be thrown if a matching already exists in the list of cached objects. + pickle_module: module used for pickling metadata and objects, default to `pickle`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. + pickle_protocol: can be specified to override the default protocol, default to `2`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. + """ self.in_memory = in_memory - self.cache_dir = cache_dir + self.cache_dir = tempfile.gettempdir() if cache_dir is None else cache_dir + if not os.path.isdir(self.cache_dir): + raise ValueError("Given `cache_dir` is not a valid directory.") + self.allow_overwrite = allow_overwrite + self.pickle_module = pickle_module + self.pickle_protocol = pickle_protocol + self.cached: Dict = {} - if self.cache_dir is None: - self.cache_dir = tempfile.gettempdir() - elif not os.path.isdir(self.cache_dir): - raise ValueError("Given `cache_dir` is not a valid directory.") + def store(self, key, data_obj, pickle_module=None, pickle_protocol: Optional[int] = None): + """ + Store a given object with the given key name. - self.cached: Dict[str, str] = {} + Args: + key: key of the data object to store. + data_obj: data object to store. + pickle_module: module used for pickling metadata and objects, default to `self.pickle_module`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. + pickle_protocol: can be specified to override the default protocol, default to `self.pickle_protocol`. + this arg is used by `torch.save`, for more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. - def store(self, key, data_obj): - """Store a given object with the given key name.""" + """ if key in self.cached and not self.allow_overwrite: raise RuntimeError("Cached key already exists and overwriting is disabled.") if self.in_memory: @@ -72,7 +97,12 @@ def store(self, key, data_obj): else: fn = os.path.join(self.cache_dir, f"state_{key}_{id(self)}.pt") self.cached.update({key: {"obj": fn}}) - torch.save(data_obj, fn) + torch.save( + obj=data_obj, + f=fn, + pickle_module=self.pickle_module if pickle_module is None else pickle_module, + pickle_protocol=self.pickle_protocol if pickle_protocol is None else pickle_protocol, + ) # store object's device if relevant if hasattr(data_obj, "device"): self.cached[key]["device"] = data_obj.device diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index b0ce187e38..8686557176 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re from typing import Any, Optional, Sequence, Tuple, Union @@ -6,6 +17,7 @@ from monai.config.type_definitions import DtypeLike, NdarrayOrTensor from monai.utils import optional_import +from monai.utils.module import look_up_option cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") @@ -16,6 +28,7 @@ "get_equivalent_dtype", "convert_data_type", "get_dtype", + "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", "convert_to_dst_type", @@ -40,31 +53,34 @@ def dtype_torch_to_numpy(dtype): """Convert a torch dtype to its numpy equivalent.""" - if dtype not in _torch_to_np_dtype: - raise ValueError(f"Unsupported torch to numpy dtype '{dtype}'.") - return _torch_to_np_dtype[dtype] + return look_up_option(dtype, _torch_to_np_dtype) def dtype_numpy_to_torch(dtype): """Convert a numpy dtype to its torch equivalent.""" # np dtypes can be given as np.float32 and np.dtype(np.float32) so unify them - dtype = np.dtype(dtype) if type(dtype) is type else dtype - if dtype not in _np_to_torch_dtype: - raise ValueError(f"Unsupported numpy to torch dtype '{dtype}'.") - return _np_to_torch_dtype[dtype] + dtype = np.dtype(dtype) if isinstance(dtype, (type, str)) else dtype + return look_up_option(dtype, _np_to_torch_dtype) def get_equivalent_dtype(dtype, data_type): """Convert to the `dtype` that corresponds to `data_type`. - Example: + + Example:: + im = torch.tensor(1) dtype = get_equivalent_dtype(np.float32, type(im)) + """ + if dtype is None: + return None if data_type is torch.Tensor: - if type(dtype) is torch.dtype: + if isinstance(dtype, torch.dtype): + # already a torch dtype and target `data_type` is torch.Tensor return dtype return dtype_numpy_to_torch(dtype) - if type(dtype) is not torch.dtype: + if not isinstance(dtype, torch.dtype): + # assuming the dtype is ok if it is not a torch dtype and target `data_type` is not torch.Tensor return dtype return dtype_torch_to_numpy(dtype) @@ -83,43 +99,49 @@ def get_dtype(data: Any): return type(data) -def convert_to_tensor(data, wrap_sequence: bool = False): +def convert_to_tensor( + data, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = False +): """ Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor. Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. - will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original. + will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original. for dictionary, list or tuple, convert every item to a Tensor if applicable. - wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. - If `True`, then `[1, 2]` -> `tensor([1, 2])`. + dtype: target data type to when converting to Tensor. + device: target device to put the converted Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. """ if isinstance(data, torch.Tensor): - return data.contiguous() + return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore if isinstance(data, np.ndarray): # skip array of string classes and object, refer to: # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13 if re.search(r"[SaUO]", data.dtype.str) is None: # numpy array with 0 dims is also sequence iterable, # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) - elif isinstance(data, (float, int, bool)): - return torch.as_tensor(data) - elif isinstance(data, Sequence) and wrap_sequence: - return torch.as_tensor(data) + if data.ndim > 0: + data = np.ascontiguousarray(data) + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore + elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)): + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore elif isinstance(data, list): - return [convert_to_tensor(i) for i in data] + list_ret = [convert_to_tensor(i, dtype=dtype, device=device) for i in data] + return torch.as_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret # type: ignore elif isinstance(data, tuple): - return tuple(convert_to_tensor(i) for i in data) + tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data) + return torch.as_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret # type: ignore elif isinstance(data, dict): - return {k: convert_to_tensor(v) for k, v in data.items()} + return {k: convert_to_tensor(v, dtype=dtype, device=device) for k, v in data.items()} return data -def convert_to_numpy(data, wrap_sequence: bool = False): +def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array. @@ -128,23 +150,24 @@ def convert_to_numpy(data, wrap_sequence: bool = False): data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original. for dictionary, list or tuple, convert every item to a numpy array if applicable. - wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. - If `True`, then `[1, 2]` -> `array([1, 2])`. + dtype: target data type when converting to numpy array. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. """ if isinstance(data, torch.Tensor): - data = data.detach().cpu().numpy() + data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy() elif has_cp and isinstance(data, cp_ndarray): - data = cp.asnumpy(data) - elif isinstance(data, (float, int, bool)): - data = np.asarray(data) - elif isinstance(data, Sequence) and wrap_sequence: - return np.asarray(data) + data = cp.asnumpy(data).astype(dtype, copy=False) + elif isinstance(data, (np.ndarray, float, int, bool)): + data = np.asarray(data, dtype=dtype) elif isinstance(data, list): - return [convert_to_numpy(i) for i in data] + list_ret = [convert_to_numpy(i, dtype=dtype) for i in data] + return np.asarray(list_ret) if wrap_sequence else list_ret elif isinstance(data, tuple): - return tuple(convert_to_numpy(i) for i in data) + tuple_ret = tuple(convert_to_numpy(i, dtype=dtype) for i in data) + return np.asarray(tuple_ret) if wrap_sequence else tuple_ret elif isinstance(data, dict): - return {k: convert_to_numpy(v) for k, v in data.items()} + return {k: convert_to_numpy(v, dtype=dtype) for k, v in data.items()} if isinstance(data, np.ndarray) and data.ndim > 0: data = np.ascontiguousarray(data) @@ -152,11 +175,47 @@ def convert_to_numpy(data, wrap_sequence: bool = False): return data +def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool = False): + """ + Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple, + recursively check every item and convert it to cupy array. + + Args: + data: input data can be PyTorch Tensor, numpy array, cupy array, list, dictionary, int, float, bool, str, etc. + Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays, + for dictionary, list or tuple, convert every item to a numpy array if applicable. + dtype: target data type when converting to Cupy array, tt must be an argument of `numpy.dtype`, + for more details: https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. + """ + + # direct calls + if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)): + data = cp.asarray(data, dtype) + elif isinstance(data, list): + list_ret = [convert_to_cupy(i, dtype) for i in data] + return cp.asarray(list_ret) if wrap_sequence else list_ret + elif isinstance(data, tuple): + tuple_ret = tuple(convert_to_cupy(i, dtype) for i in data) + return cp.asarray(tuple_ret) if wrap_sequence else tuple_ret + elif isinstance(data, dict): + return {k: convert_to_cupy(v, dtype) for k, v in data.items()} + # make it contiguous + if not isinstance(data, cp.ndarray): + raise ValueError(f"The input data type [{type(data)}] cannot be converted into cupy arrays!") + + if data.ndim > 0: + data = cp.ascontiguousarray(data) + return data + + def convert_data_type( data: Any, output_type: Optional[type] = None, device: Optional[torch.device] = None, dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + wrap_sequence: bool = False, ) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: """ Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc. @@ -168,14 +227,27 @@ def convert_data_type( dtype: dtype of output data. Converted to correct library type (e.g., `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`). If left blank, it remains unchanged. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. Returns: modified data, orig_type, orig_device + + Note: + When both `output_type` and `dtype` are specified with different backend + (e.g., `torch.Tensor` and `np.float32`), the `output_type` will be used as the primary type, + for example:: + + >>> convert_data_type(1, torch.Tensor, dtype=np.float32) + (1.0, , None) + """ orig_type: Any if isinstance(data, torch.Tensor): orig_type = torch.Tensor elif isinstance(data, np.ndarray): orig_type = np.ndarray + elif has_cp and isinstance(data, cp.ndarray): + orig_type = cp.ndarray else: orig_type = type(data) @@ -183,33 +255,47 @@ def convert_data_type( output_type = output_type or orig_type - dtype = get_equivalent_dtype(dtype or get_dtype(data), output_type) + dtype_ = get_equivalent_dtype(dtype, output_type) if output_type is torch.Tensor: - if orig_type is not torch.Tensor: - data = convert_to_tensor(data) - if dtype != data.dtype: - data = data.to(dtype) - if device is not None: - data = data.to(device) + data = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence) elif output_type is np.ndarray: - if orig_type is not np.ndarray: - data = convert_to_numpy(data) - if data is not None and dtype != data.dtype: - data = data.astype(dtype) + data = convert_to_numpy(data, dtype=dtype_, wrap_sequence=wrap_sequence) + elif has_cp and output_type is cp.ndarray: + data = convert_to_cupy(data, dtype=dtype_, wrap_sequence=wrap_sequence) else: raise ValueError(f"Unsupported output type: {output_type}") return data, orig_type, orig_device -def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: +def convert_to_dst_type( + src: Any, dst: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None, wrap_sequence: bool = False +) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: """ - Convert `src` to the same `torch.Tensor`/`np.ndarray` and data type as `dst`. + Convert source data to the same data type and device as the destination data. + If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, + if `dst` is an instance of `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`, + otherwise, convert to the type of `dst` directly. + + Args: + src: source data to convert type. + dst: destination data that convert to the same data type as it. + dtype: an optional argument if the target `dtype` is different from the original `dst`'s data type. + wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. + If `True`, then `[1, 2]` -> `array([1, 2])`. See Also: :func:`convert_data_type` """ - device = None + device = dst.device if isinstance(dst, torch.Tensor) else None + if dtype is None: + dtype = dst.dtype + + output_type: Any if isinstance(dst, torch.Tensor): - device = dst.device - return convert_data_type(data=src, output_type=type(dst), device=device, dtype=dst.dtype) + output_type = torch.Tensor + elif isinstance(dst, np.ndarray): + output_type = np.ndarray + else: + output_type = type(dst) + return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence) diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py index 9ad61fa3f2..cd980846b3 100644 --- a/monai/visualize/__init__.py +++ b/monai/visualize/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,11 +10,7 @@ # limitations under the License. from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer -from .img2tensorboard import ( - add_animated_gif, - add_animated_gif_no_channels, - make_animated_gif_summary, - plot_2d_or_3d_image, -) +from .img2tensorboard import add_animated_gif, make_animated_gif_summary, plot_2d_or_3d_image from .occlusion_sensitivity import OcclusionSensitivity +from .utils import blend_images, matshow3d from .visualizer import default_upsampler diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 992eaecdac..b7e433d6fa 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,7 +19,7 @@ from monai.config import NdarrayTensor from monai.transforms import ScaleIntensity -from monai.utils import ensure_tuple, get_torch_version_tuple +from monai.utils import ensure_tuple, pytorch_after from monai.visualize.visualizer import default_upsampler __all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks", "default_normalizer"] @@ -80,13 +80,13 @@ def __init__( continue _registered.append(name) if self.register_backward: - if get_torch_version_tuple() < (1, 8): - mod.register_backward_hook(self.backward_hook(name)) - else: + if pytorch_after(1, 8): if "inplace" in mod.__dict__ and mod.__dict__["inplace"]: # inplace=True causes errors for register_full_backward_hook mod.__dict__["inplace"] = False mod.register_full_backward_hook(self.backward_hook(name)) + else: + mod.register_backward_hook(self.backward_hook(name)) if self.register_forward: mod.register_forward_hook(self.forward_hook(name)) if len(_registered) != len(self.target_layers): @@ -137,6 +137,11 @@ def __call__(self, x, class_idx=None, retain_graph=False): self.score = self.class_score(logits, self.class_idx) self.model.zero_grad() self.score.sum().backward(retain_graph=retain_graph) + for layer in self.target_layers: + if layer not in self.gradients: + raise RuntimeError( + f"Backward hook for {layer} is not triggered; `requires_grad` of {layer} should be `True`." + ) grad = tuple(self.gradients[layer] for layer in self.target_layers) if train: self.model.train() @@ -221,6 +226,8 @@ class CAM(CAMBase): .. code-block:: python + import torch + # densenet 2d from monai.networks.nets import DenseNet121 from monai.visualize import CAM @@ -319,6 +326,8 @@ class GradCAM(CAMBase): .. code-block:: python + import torch + # densenet 2d from monai.networks.nets import DenseNet121 from monai.visualize import GradCAM diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index ccdbdc2396..4d4e91c6fd 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,17 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, List, Optional, Union import numpy as np import torch from monai.config import NdarrayTensor from monai.transforms import rescale_array -from monai.utils import optional_import +from monai.utils import convert_data_type, optional_import PIL, _ = optional_import("PIL") GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image") +SummaryX, _ = optional_import("tensorboardX.proto.summary_pb2", name="Summary") +SummaryWriterX, has_tensorboardx = optional_import("tensorboardX", name="SummaryWriter") if TYPE_CHECKING: from tensorboard.compat.proto.summary_pb2 import Summary @@ -28,23 +30,28 @@ Summary, _ = optional_import("tensorboard.compat.proto.summary_pb2", name="Summary") SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") +__all__ = ["make_animated_gif_summary", "add_animated_gif", "plot_2d_or_3d_image"] -__all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"] - -def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0) -> Summary: +def _image3_animated_gif( + tag: str, image: Union[np.ndarray, torch.Tensor], writer, frame_dim: int = 0, scale_factor: float = 1.0 +): """Function to actually create the animated gif. Args: tag: Data identifier image: 3D image tensors expected to be in `HWD` format + writer: the tensorboard writer to plot image + frame_dim: the dimension used as frames for GIF image, expect data shape as `HWD`, default to `0`. scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will scale it to displayable range """ if len(image.shape) != 3: raise AssertionError("3D image tensors expected to be in `HWD` format, len(image.shape) != 3") - ims = [(np.asarray((image[:, :, i])) * scale_factor).astype(np.uint8) for i in range(image.shape[2])] + image_np: np.ndarray + image_np, *_ = convert_data_type(image, output_type=np.ndarray) # type: ignore + ims = [(i * scale_factor).astype(np.uint8, copy=False) for i in np.moveaxis(image_np, frame_dim, 0)] ims = [GifImage.fromarray(im) for im in ims] img_str = b"" for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]: @@ -54,52 +61,46 @@ def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale for b_data in PIL.GifImagePlugin.getdata(i): img_str += b_data img_str += b"\x3B" - summary_image_str = Summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str) - image_summary = Summary.Value(tag=tag, image=summary_image_str) - return Summary(value=[image_summary]) + + summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary + summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str) + image_summary = summary.Value(tag=tag, image=summary_image_str) + return summary(value=[image_summary]) def make_animated_gif_summary( tag: str, image: Union[np.ndarray, torch.Tensor], + writer=None, max_out: int = 3, - animation_axes: Sequence[int] = (3,), - image_axes: Sequence[int] = (1, 2), - other_indices: Optional[Dict] = None, + frame_dim: int = -3, scale_factor: float = 1.0, ) -> Summary: """Creates an animated gif out of an image tensor in 'CHWD' format and returns Summary. Args: tag: Data identifier - image: The image, expected to be in CHWD format - max_out: maximum number of slices to animate through - animation_axes: axis to animate on (not currently used) - image_axes: axes of image (not currently used) - other_indices: (not currently used) + image: The image, expected to be in `CHWD` format + writer: the tensorboard writer to plot image + max_out: maximum number of image channels to animate through + frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`, + default to `-3` (the first spatial dim) scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will scale it to displayable range """ suffix = "/image" if max_out == 1 else "/image/{}" - if other_indices is None: - other_indices = {} - axis_order = [0] + list(animation_axes) + list(image_axes) - - slicing = [] - for i in range(len(image.shape)): - if i in axis_order: - slicing.append(slice(None)) - else: - other_ind = other_indices.get(i, 0) - slicing.append(slice(other_ind, other_ind + 1)) - image = image[tuple(slicing)] + # GIF image has no channel dim, reduce the spatial dim index if positive + frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim + summary_op = [] for it_i in range(min(max_out, list(image.shape)[0])): one_channel_img: Union[torch.Tensor, np.ndarray] = ( image[it_i, :, :, :].squeeze(dim=0) if isinstance(image, torch.Tensor) else image[it_i, :, :, :] ) - summary_op = _image3_animated_gif(tag + suffix.format(it_i), one_channel_img, scale_factor) + summary_op.append( + _image3_animated_gif(tag + suffix.format(it_i), one_channel_img, writer, frame_dim, scale_factor) + ) return summary_op @@ -107,8 +108,9 @@ def add_animated_gif( writer: SummaryWriter, tag: str, image_tensor: Union[np.ndarray, torch.Tensor], - max_out: int, - scale_factor: float, + max_out: int = 3, + frame_dim: int = -3, + scale_factor: float = 1.0, global_step: Optional[int] = None, ) -> None: """Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter. @@ -116,47 +118,20 @@ def add_animated_gif( Args: writer: Tensorboard SummaryWriter to write to tag: Data identifier - image_tensor: tensor for the image to add, expected to be in CHWD format - max_out: maximum number of slices to animate through + image_tensor: tensor for the image to add, expected to be in `CHWD` format + max_out: maximum number of image channels to animate through + frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`, + default to `-3` (the first spatial dim) scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will scale it to displayable range global_step: Global step value to record """ - writer._get_file_writer().add_summary( - make_animated_gif_summary( - tag, image_tensor, max_out=max_out, animation_axes=[1], image_axes=[2, 3], scale_factor=scale_factor - ), - global_step, - ) - - -def add_animated_gif_no_channels( - writer: SummaryWriter, - tag: str, - image_tensor: Union[np.ndarray, torch.Tensor], - max_out: int, - scale_factor: float, - global_step: Optional[int] = None, -) -> None: - """Creates an animated gif out of an image tensor in 'HWD' format that does not have - a channel dimension and writes it with SummaryWriter. This is similar to the "add_animated_gif" - after inserting a channel dimension of 1. - - Args: - writer: Tensorboard SummaryWriter to write to - tag: Data identifier - image_tensor: tensor for the image to add, expected to be in HWD format - max_out: maximum number of slices to animate through - scale_factor: amount to multiply values by. If the image data is between 0 and 1, - using 255 for this value will scale it to displayable range - global_step: Global step value to record - """ - writer._get_file_writer().add_summary( - make_animated_gif_summary( - tag, image_tensor, max_out=max_out, animation_axes=[1], image_axes=[1, 2], scale_factor=scale_factor - ), - global_step, + summary = make_animated_gif_summary( + tag=tag, image=image_tensor, writer=writer, max_out=max_out, frame_dim=frame_dim, scale_factor=scale_factor ) + for s in summary: + # add GIF for every channel separately + writer._get_file_writer().add_summary(s, global_step) def plot_2d_or_3d_image( @@ -165,26 +140,33 @@ def plot_2d_or_3d_image( writer: SummaryWriter, index: int = 0, max_channels: int = 1, - max_frames: int = 64, + frame_dim: int = -3, + max_frames: int = 24, tag: str = "output", ) -> None: """Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image. Note: Plot 3D or 2D image(with more than 3 channels) as separate images. + And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video. Args: data: target data to be plotted as image on the TensorBoard. The data is expected to have 'NCHW[D]' dimensions or a list of data with `CHW[D]` dimensions, and only plot the first in the batch. step: current step to plot in a chart. - writer: specify TensorBoard SummaryWriter to plot the image. + writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image. index: plot which element in the input data batch, default is the first element. max_channels: number of channels to plot. - max_frames: number of frames for 2D-t plot. + frame_dim: if plotting 3D image as GIF, specify the dimension used as frames, + expect input data shape as `NCHWD`, default to `-3` (the first spatial dim) + max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`. tag: tag of the plotted image on TensorBoard. """ data_index = data[index] + # as the `d` data has no batch dim, reduce the spatial dim index if positive + frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim + d: np.ndarray = data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else data_index if d.ndim == 2: @@ -206,7 +188,15 @@ def plot_2d_or_3d_image( if d.ndim >= 4: spatial = d.shape[-3:] - for j, d3 in enumerate(d.reshape([-1] + list(spatial))[:max_channels]): - d3 = rescale_array(d3, 0, 255) - add_animated_gif(writer, f"{tag}_HWD_{j}", d3[None], max_frames, 1.0, step) + d = d.reshape([-1] + list(spatial)) + if d.shape[0] == 3 and max_channels == 3 and has_tensorboardx and isinstance(writer, SummaryWriterX): # RGB + # move the expected frame dim to the end as `T` dim for video + d = np.moveaxis(d, frame_dim, -1) + writer.add_video(tag, d[None], step, fps=max_frames, dataformats="NCHWT") + return + # scale data to 0 - 255 for visualization + max_channels = min(max_channels, d.shape[0]) + d = np.stack([rescale_array(i, 0, 255) for i in d[:max_channels]], axis=0) + # will plot every channel as a separate GIF image + add_animated_gif(writer, f"{tag}_HWD", d, max_out=max_channels, frame_dim=frame_dim, global_step=step) return diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 61b84bb406..184fa50e61 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -167,7 +167,7 @@ def __init__( upsampler: An upsampling method to upsample the output image. Default is N-dimensional linear (bilinear, trilinear, etc.) depending on num spatial dimensions of input. - verbose: Use ``tdqm.trange`` output (if available). + verbose: Use ``tqdm.trange`` output (if available). """ self.nn_module = nn_module @@ -265,9 +265,7 @@ def _compute_occlusion_sensitivity(self, x, b_box): return sensitivity_ims, output_im_shape def __call__( # type: ignore - self, - x: torch.Tensor, - b_box: Optional[Sequence] = None, + self, x: torch.Tensor, b_box: Optional[Sequence] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: diff --git a/monai/visualize/utils.py b/monai/visualize/utils.py new file mode 100644 index 0000000000..63cac7ea35 --- /dev/null +++ b/monai/visualize/utils.py @@ -0,0 +1,199 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import numpy as np + +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.croppad.array import SpatialPad +from monai.transforms.utils import rescale_array +from monai.transforms.utils_pytorch_numpy_unification import repeat, where +from monai.utils.module import optional_import +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type + +plt, _ = optional_import("matplotlib", name="pyplot") +cm, _ = optional_import("matplotlib", name="cm") + +__all__ = ["matshow3d", "blend_images"] + + +def matshow3d( + volume, + fig=None, + title: Optional[str] = None, + figsize=(10, 10), + frames_per_row: Optional[int] = None, + frame_dim: int = -3, + channel_dim: Optional[int] = None, + vmin=None, + vmax=None, + every_n: int = 1, + interpolation: str = "none", + show=False, + fill_value=np.nan, + margin: int = 1, + dtype=np.float32, + **kwargs, +): + """ + Create a 3D volume figure as a grid of images. + + Args: + volume: 3D volume to display. data shape can be `BCHWD`, `CHWD` or `HWD`. + Higher dimensional arrays will be reshaped into (-1, H, W, [C]), `C` depends on `channel_dim` arg. + A list of channel-first (C, H[, W, D]) arrays can also be passed in, + in which case they will be displayed as a padded and stacked volume. + fig: matplotlib figure to use. If None, a new figure will be created. + title: title of the figure. + figsize: size of the figure. + frames_per_row: number of frames to display in each row. If None, sqrt(firstdim) will be used. + frame_dim: for higher dimensional arrays, which dimension from (`-1`, `-2`, `-3`) is moved to + the `-3` dimension. dim and reshape to (-1, H, W) shape to construct frames, default to `-3`. + channel_dim: if not None, explicitly specify the channel dimension to be transposed to the + last dimensionas shape (-1, H, W, C). this can be used to plot RGB color image. + if None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as shape (-1, H, W). + note that it can only support 3D input image. default is None. + vmin: `vmin` for the matplotlib `imshow`. + vmax: `vmax` for the matplotlib `imshow`. + every_n: factor to subsample the frames so that only every n-th frame is displayed. + interpolation: interpolation to use for the matplotlib `matshow`. + show: if True, show the figure. + fill_value: value to use for the empty part of the grid. + margin: margin to use for the grid. + dtype: data type of the output stacked frames. + kwargs: additional keyword arguments to matplotlib `matshow` and `imshow`. + + See Also: + - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html + - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.matshow.html + + Example: + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from monai.visualize import matshow3d + # create a figure of a 3D volume + >>> volume = np.random.rand(10, 10, 10) + >>> fig = plt.figure() + >>> matshow3d(volume, fig=fig, title="3D Volume") + >>> plt.show() + # create a figure of a list of channel-first 3D volumes + >>> volumes = [np.random.rand(1, 10, 10, 10), np.random.rand(1, 10, 10, 10)] + >>> fig = plt.figure() + >>> matshow3d(volumes, fig=fig, title="List of Volumes") + >>> plt.show() + + """ + vol: np.ndarray = convert_data_type(data=volume, output_type=np.ndarray)[0] # type: ignore + if channel_dim is not None: + if channel_dim not in [0, 1] or vol.shape[channel_dim] not in [1, 3, 4]: + raise ValueError("channel_dim must be: None, 0 or 1, and channels of image must be 1, 3 or 4.") + + if isinstance(vol, (list, tuple)): + # a sequence of channel-first volumes + if not isinstance(vol[0], np.ndarray): + raise ValueError("volume must be a list of arrays.") + pad_size = np.max(np.asarray([v.shape for v in vol]), axis=0) + pad = SpatialPad(pad_size[1:]) # assuming channel-first for item in vol + vol = np.concatenate([pad(v) for v in vol], axis=0) + else: # ndarray + while len(vol.shape) < 3: + vol = np.expand_dims(vol, 0) # so that we display 2d as well + + if channel_dim is not None: + vol = np.moveaxis(vol, frame_dim, -4) # move the expected dim to construct frames with `B` dim + vol = vol.reshape((-1, vol.shape[-3], vol.shape[-2], vol.shape[-1])) + else: + vol = np.moveaxis(vol, frame_dim, -3) + vol = vol.reshape((-1, vol.shape[-2], vol.shape[-1])) + vmin = np.nanmin(vol) if vmin is None else vmin + vmax = np.nanmax(vol) if vmax is None else vmax + + # subsample every_n-th frame of the 3D volume + vol = vol[:: max(every_n, 1)] + if not frames_per_row: + frames_per_row = int(np.ceil(np.sqrt(len(vol)))) + # create the grid of frames + cols = max(min(len(vol), frames_per_row), 1) + rows = int(np.ceil(len(vol) / cols)) + width = [[0, cols * rows - len(vol)]] + if channel_dim is not None: + width += [[0, 0]] # add pad width for the channel dim + width += [[margin, margin]] * 2 + vol = np.pad(vol.astype(dtype, copy=False), width, mode="constant", constant_values=fill_value) + im = np.block([[vol[i * cols + j] for j in range(cols)] for i in range(rows)]) + if channel_dim is not None: + # move channel dim to the end + im = np.moveaxis(im, 0, -1) + + # figure related configurations + if fig is None: + fig = plt.figure(tight_layout=True) + if not fig.axes: + fig.add_subplot(111) + ax = fig.axes[0] + ax.matshow(im, vmin=vmin, vmax=vmax, interpolation=interpolation, **kwargs) + ax.axis("off") + + if title is not None: + ax.set_title(title) + if figsize is not None: + fig.set_size_inches(figsize) + if show: + plt.show() + return fig, im + + +def blend_images( + image: NdarrayOrTensor, label: NdarrayOrTensor, alpha: float = 0.5, cmap: str = "hsv", rescale_arrays: bool = True +): + """ + Blend an image and a label. Both should have the shape CHW[D]. + The image may have C==1 or 3 channels (greyscale or RGB). + The label is expected to have C==1. + + Args: + image: the input image to blend with label data. + label: the input label to blend with image data. + alpha: when blending image and label, `alpha` is the weight for the image region mapping to `label != 0`, + and `1 - alpha` is the weight for the label region that `label != 0`, default to `0.5`. + cmap: specify colormap in the matplotlib, default to `hsv`, for more details, please refer to: + https://matplotlib.org/2.0.2/users/colormaps.html. + rescale_arrays: whether to rescale the array to [0, 1] first, default to `True`. + + """ + + if label.shape[0] != 1: + raise ValueError("Label should have 1 channel") + if image.shape[0] not in (1, 3): + raise ValueError("Image should have 1 or 3 channels") + # rescale arrays to [0, 1] if desired + if rescale_arrays: + image = rescale_array(image) + label = rescale_array(label) + # convert image to rgb (if necessary) and then rgba + if image.shape[0] == 1: + image = repeat(image, 3, axis=0) + + def get_label_rgb(cmap: str, label: NdarrayOrTensor): + _cmap = cm.get_cmap(cmap) + label_np: np.ndarray + label_np, *_ = convert_data_type(label, np.ndarray) # type: ignore + label_rgb_np = _cmap(label_np[0]) + label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3] + label_rgb, *_ = convert_to_dst_type(label_rgb_np, label) + return label_rgb + + label_rgb = get_label_rgb(cmap, label) + w_image = where(label == 0, 1.0, alpha) + w_label = where(label == 0, 0.0, 1 - alpha) + return w_image * image + w_label * label_rgb diff --git a/monai/visualize/visualizer.py b/monai/visualize/visualizer.py index bbb01f5c5e..5f19e4f63f 100644 --- a/monai/visualize/visualizer.py +++ b/monai/visualize/visualizer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/pyproject.toml b/pyproject.toml index 008af45f97..03e9f49ab5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,13 +2,13 @@ requires = [ "wheel", "setuptools", - "torch>=1.5", + "torch>=1.6", "ninja", ] [tool.black] line-length = 120 -target-version = ['py36', 'py37', 'py38'] +target-version = ['py37', 'py38', 'py39'] include = '\.pyi?$' exclude = ''' ( diff --git a/requirements-dev.txt b/requirements-dev.txt index 785454ad5d..fc752a7627 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ # Full requirements for developments -r requirements-min.txt -pytorch-ignite==0.4.5 +pytorch-ignite==0.4.7 gdown>=3.6.4 scipy itk>=5.2 @@ -31,8 +31,14 @@ Sphinx==3.5.3 recommonmark==0.6.0 sphinx-autodoc-typehints==1.11.1 sphinx-rtd-theme==0.5.2 -cucim~=0.19.0; platform_system == "Linux" +cucim>=21.8.2; platform_system == "Linux" openslide-python==1.1.2 +imagecodecs; platform_system == "Linux" +tifffile; platform_system == "Linux" pandas requests einops +transformers +mlflow +matplotlib!=3.5.0 +tensorboardX diff --git a/requirements-min.txt b/requirements-min.txt index 5db219c840..195f6f49f4 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -1,5 +1,5 @@ # Requirements for minimal tests -r requirements.txt -setuptools>=50.3.0 +setuptools>=50.3.0,!=60.0.0 coverage>=5.5 parameterized diff --git a/requirements.txt b/requirements.txt index 5d96284307..e4ea34b5d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch>=1.5 +torch>=1.6 numpy>=1.17 diff --git a/runtests.sh b/runtests.sh index f10e888543..70ca8df5c2 100755 --- a/runtests.sh +++ b/runtests.sh @@ -1,6 +1,6 @@ #! /bin/bash -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,12 +38,14 @@ doNetTests=false doDryRun=false doZooTests=false doUnitTests=false +doBuild=false doBlackFormat=false doBlackFix=false doIsortFormat=false doIsortFix=false doFlake8Format=false doClangFormat=false +doCopyRight=false doPytypeFormat=false doMypyFormat=false doCleanup=false @@ -55,7 +57,8 @@ PY_EXE=${MONAI_PY_EXE:-$(which python)} function print_usage { echo "runtests.sh [--codeformat] [--autofix] [--black] [--isort] [--flake8] [--clangformat] [--pytype] [--mypy]" - echo " [--unittests] [--disttests] [--coverage] [--quick] [--min] [--net] [--dryrun] [-j number] [--clean] [--help] [--version]" + echo " [--unittests] [--disttests] [--coverage] [--quick] [--min] [--net] [--dryrun] [-j number] [--list_tests]" + echo " [--copyright] [--build] [--clean] [--help] [--version]" echo "" echo "MONAI unit testing utilities." echo "" @@ -86,10 +89,12 @@ function print_usage { echo " -q, --quick : skip long running unit tests and integration tests" echo " -m, --min : only run minimal unit tests which do not require optional packages" echo " --net : perform integration testing" + echo " -b, --build : compile and install the source code folder an editable release." echo " --list_tests : list unit tests and exit" echo "" echo "Misc. options:" echo " --dryrun : display the commands to the screen without running" + echo " --copyright : check whether every source code has a copyright header" echo " -f, --codeformat : shorthand to run all code style and static analysis tests" echo " -c, --clean : clean temporary files from tests and exit" echo " -h, --help : show this help message and exit" @@ -103,7 +108,7 @@ function print_usage { } function check_import { - echo "python: ${PY_EXE}" + echo "Python: ${PY_EXE}" ${cmdPrefix}${PY_EXE} -c "import monai" } @@ -238,6 +243,7 @@ do doFlake8Format=true doPytypeFormat=true doMypyFormat=true + doCopyRight=true ;; --disttests) doDistTests=true @@ -250,6 +256,7 @@ do doBlackFix=true doIsortFormat=true doBlackFormat=true + doCopyRight=true ;; --clangformat) doClangFormat=true @@ -270,6 +277,12 @@ do NUM_PARALLEL=$2 shift ;; + --copyright) + doCopyRight=true + ;; + -b|--build) + doBuild=true + ;; -c|--clean) doCleanup=true ;; @@ -314,6 +327,14 @@ else check_import fi +if [ $doBuild = true ] +then + echo "${separator}${blue}compile and install${noColor}" + # try to compile MONAI cpp + compile_cpp + + echo "${green}done! (to uninstall and clean up, please use \"./runtests.sh --clean\")${noColor}" +fi if [ $doCleanup = true ] then @@ -335,12 +356,33 @@ then exit fi -# try to compile MONAI cpp -compile_cpp - # unconditionally report on the state of monai print_version +if [ $doCopyRight = true ] +then + # check copyright headers + copyright_bad=0 + copyright_all=0 + while read -r fname; do + copyright_all=$((copyright_all + 1)) + if ! grep "http://www.apache.org/licenses/LICENSE-2.0" "$fname" > /dev/null; then + print_error_msg "Missing the license header in file: $fname" + copyright_bad=$((copyright_bad + 1)) + fi + done <<< "$(find "$(pwd)/monai" "$(pwd)/tests" -type f \ + ! -wholename "*_version.py" -and -name "*.py" -or -name "*.cpp" -or -name "*.cu" -or -name "*.h")" + if [[ ${copyright_bad} -eq 0 ]]; + then + echo "${green}Source code copyright headers checked ($copyright_all).${noColor}" + else + echo "Please add the licensing header to the file ($copyright_bad of $copyright_all files)." + echo " See also: https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md#checking-the-coding-style" + echo "" + exit 1 + fi +fi + if [ $doIsortFormat = true ] then @@ -397,9 +439,9 @@ then if [ $doBlackFix = true ] then - ${cmdPrefix}${PY_EXE} -m black "$(pwd)" + ${cmdPrefix}${PY_EXE} -m black --skip-magic-trailing-comma "$(pwd)" else - ${cmdPrefix}${PY_EXE} -m black --check "$(pwd)" + ${cmdPrefix}${PY_EXE} -m black --skip-magic-trailing-comma --check "$(pwd)" fi black_status=$? @@ -535,7 +577,7 @@ if [ $doUnitTests = true ] then echo "${separator}${blue}unittests${noColor}" torch_validate - ${cmdPrefix}${cmd} ./tests/runner.py -p "test_((?!integration).)" + ${cmdPrefix}${cmd} ./tests/runner.py -p "^(?!test_integration).*(?= 3.6 +python_requires = >= 3.7 # for compiling and develop setup only # no need to specify the versions so that we could # compile for multiple targeted versions. @@ -24,7 +24,7 @@ setup_requires = torch ninja install_requires = - torch>=1.5 + torch>=1.6 numpy>=1.17 [options.extras_require] @@ -34,16 +34,22 @@ all = pillow tensorboard gdown>=3.6.4 - pytorch-ignite==0.4.5 + pytorch-ignite==0.4.7 torchvision itk>=5.2 tqdm>=4.47.0 lmdb psutil - cucim~=0.19.0 + cucim>=21.8.2 openslide-python==1.1.2 + tifffile + imagecodecs pandas einops + transformers + mlflow + matplotlib + tensorboardX nibabel = nibabel skimage = @@ -55,7 +61,7 @@ tensorboard = gdown = gdown>=3.6.4 ignite = - pytorch-ignite==0.4.5 + pytorch-ignite==0.4.7 torchvision = torchvision itk = @@ -67,22 +73,48 @@ lmdb = psutil = psutil cucim = - cucim~=0.19.0 + cucim>=21.8.2 openslide = openslide-python==1.1.2 +tifffile = + tifffile +imagecodecs = + imagecodecs pandas = pandas einops = einops +transformers = + transformers +mlflow = + mlflow +matplotlib = + matplotlib +tensorboardX = + tensorboardX + [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 # C408 ignored because we like the dict keyword argument syntax # E501 is not flexible enough, we're using B950 instead ignore = - E203,E305,E402,E501,E721,E741,F821,F841,F999,W503,W504,C408,E302,W291,E303, - # N812 lowercase 'torch.nn.functional' imported as non lowercase 'F' - N812 + E203 + E305 + E402 + E501 + E721 + E741 + F821 + F841 + F999 + W503 + W504 + C408 + E302 + W291 + E303 + N812 # lowercase 'torch.nn.functional' imported as non lowercase 'F' per_file_ignores = __init__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py diff --git a/setup.py b/setup.py index eeaffb7823..83a60129e6 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -53,11 +53,7 @@ def torch_parallel_backend(): try: - match = re.search( - "^ATen parallel backend: (?P.*)$", - torch._C._parallel_info(), - re.MULTILINE, - ) + match = re.search("^ATen parallel backend: (?P.*)$", torch._C._parallel_info(), re.MULTILINE) if match is None: return None backend = match.group("backend") diff --git a/tests/__init__.py b/tests/__init__.py index 5093d1f72d..4639a58496 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/clang_format_utils.py b/tests/clang_format_utils.py index 41902eb272..1b13ce0ac3 100644 --- a/tests/clang_format_utils.py +++ b/tests/clang_format_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,8 +34,8 @@ # This dictionary maps each platform to a relative path to a file containing its reference hash. # github/pytorch/pytorch/tree/63d62d3e44a0a4ec09d94f30381d49b78cc5b095/tools/clang_format_hash PLATFORM_TO_HASH = { - "Darwin": "b24cc8972344c4e01afbbae78d6a414f7638ff6f", - "Linux": "9073602de1c4e1748f2feea5a0782417b20e3043", + "Darwin": "1485a242a96c737ba7cdd9f259114f2201accdb46d87ac7a8650b1a814cd4d4d", + "Linux": "e1c8b97b919541a99e0a355df5c3f9e8abebc64259dbee6f8c68e1ef90582856", } # Directory and file paths for the clang-format binary. @@ -50,15 +50,15 @@ def get_and_check_clang_format(): """ # If the host platform is not in PLATFORM_TO_HASH, it is unsupported. if HOST_PLATFORM not in PLATFORM_TO_HASH: - print("Unsupported platform: {}".format(HOST_PLATFORM)) + print(f"Unsupported platform: {HOST_PLATFORM}") return False if HOST_PLATFORM not in PLATFORM_TO_CF_URL: - print("Unsupported platform: {}".format(HOST_PLATFORM)) + print(f"Unsupported platform: {HOST_PLATFORM}") return False try: download_url( - PLATFORM_TO_CF_URL[HOST_PLATFORM], CLANG_FORMAT_PATH, PLATFORM_TO_HASH[HOST_PLATFORM], hash_type="sha1" + PLATFORM_TO_CF_URL[HOST_PLATFORM], CLANG_FORMAT_PATH, PLATFORM_TO_HASH[HOST_PLATFORM], hash_type="sha256" ) except Exception as e: print(f"Download {CLANG_FORMAT_PATH} failed: {e}") @@ -69,7 +69,7 @@ def get_and_check_clang_format(): mode = os.stat(CLANG_FORMAT_PATH).st_mode mode |= stat.S_IXUSR os.chmod(CLANG_FORMAT_PATH, mode) - print("Using clang-format located at {}".format(CLANG_FORMAT_PATH)) + print(f"Using clang-format located at {CLANG_FORMAT_PATH}") return True diff --git a/tests/hvd_evenly_divisible_all_gather.py b/tests/hvd_evenly_divisible_all_gather.py index 42b2e9530d..f4c39e4e73 100644 --- a/tests/hvd_evenly_divisible_all_gather.py +++ b/tests/hvd_evenly_divisible_all_gather.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/min_tests.py b/tests/min_tests.py index 5b376d7b57..00f3e49850 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,8 +33,11 @@ def run_testsuit(): "test_cachedataset_parallel", "test_cachedataset_persistent_workers", "test_cachentransdataset", + "test_contrastive_loss", + "test_check_missing_files", "test_csv_dataset", "test_csv_iterable_dataset", + "test_cumulative_average_dist", "test_dataset", "test_dataset_summary", "test_deepedit_transforms", @@ -42,11 +45,14 @@ def run_testsuit(): "test_deepgrow_interaction", "test_deepgrow_transforms", "test_detect_envelope", + "test_dints_network", "test_efficientnet", "test_ensemble_evaluator", "test_ensure_channel_first", "test_ensure_channel_firstd", "test_fill_holes", + "test_fill_holesd", + "test_global_mutual_information_loss", "test_handler_checkpoint_loader", "test_handler_checkpoint_saver", "test_handler_classification_saver", @@ -61,6 +67,7 @@ def run_testsuit(): "test_handler_mean_dice", "test_handler_metrics_saver", "test_handler_metrics_saver_dist", + "test_handler_mlflow", "test_handler_nvtx", "test_handler_parameter_scheduler", "test_handler_post_processing", @@ -75,13 +82,13 @@ def run_testsuit(): "test_handler_surface_distance", "test_handler_tb_image", "test_handler_tb_stats", - "test_handler_transform_inverter", "test_handler_validation", "test_hausdorff_distance", "test_header_correct", "test_hilbert_transform", "test_image_dataset", "test_img2tensorboard", + "test_integration_fast_train", "test_integration_segmentation_3d", "test_integration_sliding_window", "test_integration_unet_2d", @@ -94,10 +101,12 @@ def run_testsuit(): "test_label_filter", "test_lltm", "test_lmdbdataset", + "test_lmdbdataset_dist", "test_load_image", "test_load_imaged", "test_load_spacing_orientation", "test_mednistdataset", + "test_milmodel", "test_mlp", "test_nifti_header_revise", "test_nifti_rw", @@ -112,6 +121,8 @@ def run_testsuit(): "test_plot_2d_or_3d_image", "test_png_rw", "test_png_saver", + "test_prepare_batch_default", + "test_prepare_batch_extra_input", "test_rand_rotate", "test_rand_rotated", "test_rand_zoom", @@ -132,11 +143,14 @@ def run_testsuit(): "test_testtimeaugmentation", "test_torchvision", "test_torchvisiond", + "test_transchex", "test_transformerblock", "test_unetr", "test_unetr_block", "test_vit", + "test_vitautoenc", "test_write_metrics_reports", + "test_wsireader", "test_zoom", "test_zoom_affine", "test_zoomd", diff --git a/tests/ngc_mmar_loading.py b/tests/ngc_mmar_loading.py new file mode 100644 index 0000000000..a371917e59 --- /dev/null +++ b/tests/ngc_mmar_loading.py @@ -0,0 +1,37 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.mmars import MODEL_DESC, load_from_mmar +from monai.config import print_debug_info + + +class TestAllDownloadingMMAR(unittest.TestCase): + def setUp(self): + print_debug_info() + self.test_dir = "./" + + @parameterized.expand((item,) for item in MODEL_DESC) + def test_loading_mmar(self, item): + pretrained_model = load_from_mmar(item=item, mmar_dir="./", map_location="cpu") + self.assertTrue(isinstance(pretrained_model, torch.nn.Module)) + + def tearDown(self): + print(os.listdir(self.test_dir)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/runner.py b/tests/runner.py index b340d60719..7356581365 100644 --- a/tests/runner.py +++ b/tests/runner.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_acn_block.py b/tests/test_acn_block.py new file mode 100644 index 0000000000..4c12155fd8 --- /dev/null +++ b/tests/test_acn_block.py @@ -0,0 +1,38 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import ActiConvNormBlock + +TEST_CASES = [ + [{"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1}, (7, 32, 16, 31, 7), (7, 16, 16, 31, 7)], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1, "spatial_dims": 2}, + (7, 32, 13, 32), + (7, 16, 13, 32), + ], +] + + +class TestACNBlock(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_acn_block(self, input_param, input_shape, expected_shape): + net = ActiConvNormBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_activations.py b/tests/test_activations.py index 7d8b3e4c38..a67e6f8cb6 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,27 +16,36 @@ from monai.networks.layers.factories import Act from monai.transforms import Activations - -TEST_CASE_1 = [ - {"sigmoid": True, "softmax": False, "other": None}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.5000, 0.7311], [0.8808, 0.9526]]]), - (1, 2, 2), -] - -TEST_CASE_2 = [ - {"sigmoid": False, "softmax": True, "other": None}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), - (2, 1, 2), -] - -TEST_CASE_3 = [ - {"sigmoid": False, "softmax": False, "other": torch.tanh}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), - (1, 2, 2), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"sigmoid": True, "softmax": False, "other": None}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + p([[[0.5000, 0.7311], [0.8808, 0.9526]]]), + (1, 2, 2), + ] + ) + + TEST_CASES.append( + [ + {"sigmoid": False, "softmax": True, "other": None}, + p([[[0.0, 1.0]], [[2.0, 3.0]]]), + p([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), + (2, 1, 2), + ] + ) + + TEST_CASES.append( + [ + {"sigmoid": False, "softmax": False, "other": torch.tanh}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + p([[[0.0000, 0.7616], [0.9640, 0.9951]]]), + (1, 2, 2), + ] + ) TEST_CASE_4 = [ "swish", @@ -67,12 +76,12 @@ class TestActivations(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TEST_CASES[:3]) def test_value_shape(self, input_param, img, out, expected_shape): result = Activations(**input_param)(img) def _compare(ret, out, shape): - torch.testing.assert_allclose(ret, out) + assert_allclose(ret, out, rtol=1e-3) self.assertTupleEqual(ret.shape, shape) if isinstance(result, (list, tuple)): diff --git a/tests/test_activationsd.py b/tests/test_activationsd.py index 355c50f389..557d68de90 100644 --- a/tests/test_activationsd.py +++ b/tests/test_activationsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,43 +15,46 @@ from parameterized import parameterized from monai.transforms import Activationsd - -TEST_CASE_1 = [ - {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None}, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])}, - { - "pred": torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), - "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - }, - (2, 1, 2), -] - -TEST_CASE_2 = [ - {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [torch.tanh, None]}, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])}, - { - "pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), - "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - }, - (1, 2, 2), -] - -TEST_CASE_3 = [ - {"keys": "pred", "sigmoid": False, "softmax": False, "other": torch.tanh}, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])}, - {"pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]])}, - (1, 2, 2), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None}, + {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": p([[[0.0, 1.0]], [[2.0, 3.0]]])}, + {"pred": p([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), "label": p([[[0.0, 1.0]], [[2.0, 3.0]]])}, + (2, 1, 2), + ] + ) + + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [torch.tanh, None]}, + {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0.0, 1.0], [2.0, 3.0]]])}, + {"pred": p([[[0.0000, 0.7616], [0.9640, 0.9951]]]), "label": p([[[0.0, 1.0], [2.0, 3.0]]])}, + (1, 2, 2), + ] + ) + + TEST_CASES.append( + [ + {"keys": "pred", "sigmoid": False, "softmax": False, "other": torch.tanh}, + {"pred": p([[[0.0, 1.0], [2.0, 3.0]]])}, + {"pred": p([[[0.0000, 0.7616], [0.9640, 0.9951]]])}, + (1, 2, 2), + ] + ) class TestActivationsd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TEST_CASES) def test_value_shape(self, input_param, test_input, output, expected_shape): result = Activationsd(**input_param)(test_input) - torch.testing.assert_allclose(result["pred"], output["pred"]) + assert_allclose(result["pred"], output["pred"], rtol=1e-3) self.assertTupleEqual(result["pred"].shape, expected_shape) if "label" in result: - torch.testing.assert_allclose(result["label"], output["label"]) + assert_allclose(result["label"], output["label"], rtol=1e-3) self.assertTupleEqual(result["label"].shape, expected_shape) diff --git a/tests/test_adaptors.py b/tests/test_adaptors.py index 9bcd01feb7..3c351d9e20 100644 --- a/tests/test_adaptors.py +++ b/tests/test_adaptors.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_add_channeld.py b/tests/test_add_channeld.py index 8bdd89a4ae..9dc984aff3 100644 --- a/tests/test_add_channeld.py +++ b/tests/test_add_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_add_coordinate_channels.py b/tests/test_add_coordinate_channels.py index 3399008e02..5a483e25b9 100644 --- a/tests/test_add_coordinate_channels.py +++ b/tests/test_add_coordinate_channels.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,32 +12,36 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddCoordinateChannels +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [{"spatial_channels": (1, 2, 3)}, np.random.randint(0, 2, size=(1, 3, 3, 3)), (4, 3, 3, 3)] - -TEST_CASE_2 = [{"spatial_channels": (1,)}, np.random.randint(0, 2, size=(1, 3, 3, 3)), (2, 3, 3, 3)] - -TEST_CASE_ERROR_3 = [{"spatial_channels": (3,)}, np.random.randint(0, 2, size=(1, 3, 3))] - -TEST_CASE_ERROR_4 = [{"spatial_channels": (0, 1, 2)}, np.random.randint(0, 2, size=(1, 3, 3))] +TESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], [] +for p in TEST_NDARRAYS: + TESTS.append([{"spatial_dims": (0, 1, 2)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (4, 3, 3, 3)]) + TESTS.append([{"spatial_dims": (0,)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (2, 3, 3, 3)]) + TEST_CASES_ERROR_1.append([{"spatial_dims": (2,)}, p(np.random.randint(0, 2, size=(1, 3, 3)))]) + TEST_CASES_ERROR_2.append([{"spatial_dims": (-1, 0, 1)}, p(np.random.randint(0, 2, size=(1, 3, 3)))]) class TestAddCoordinateChannels(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input, expected_shape): result = AddCoordinateChannels(**input_param)(input) + self.assertEqual(type(result), type(input)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input.device) self.assertEqual(list(result.shape), list(expected_shape)) - np.testing.assert_array_equal(input[0, ...], result[0, ...]) + assert_allclose(input[0, ...], result[0, ...]) - @parameterized.expand([TEST_CASE_ERROR_3]) + @parameterized.expand(TEST_CASES_ERROR_1) def test_max_channel(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannels(**input_param)(input) - @parameterized.expand([TEST_CASE_ERROR_4]) + @parameterized.expand(TEST_CASES_ERROR_2) def test_channel_dim(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannels(**input_param)(input) diff --git a/tests/test_add_coordinate_channelsd.py b/tests/test_add_coordinate_channelsd.py index 0fa6aae1c9..c14ff0ba64 100644 --- a/tests/test_add_coordinate_channelsd.py +++ b/tests/test_add_coordinate_channelsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,40 +12,50 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddCoordinateChannelsd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"spatial_channels": (1, 2, 3), "keys": ["img"]}, - {"img": np.random.randint(0, 2, size=(1, 3, 3, 3))}, - (4, 3, 3, 3), -] +TESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"spatial_dims": (0, 1, 2), "keys": ["img"]}, + {"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))}, + (4, 3, 3, 3), + ] + ) + TESTS.append( + [{"spatial_dims": (0,), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))}, (2, 3, 3, 3)] + ) -TEST_CASE_2 = [ - {"spatial_channels": (1,), "keys": ["img"]}, - {"img": np.random.randint(0, 2, size=(1, 3, 3, 3))}, - (2, 3, 3, 3), -] - -TEST_CASE_ERROR_3 = [{"spatial_channels": (3,), "keys": ["img"]}, {"img": np.random.randint(0, 2, size=(1, 3, 3))}] - -TEST_CASE_ERROR_4 = [{"spatial_channels": (0, 1, 2), "keys": ["img"]}, {"img": np.random.randint(0, 2, size=(1, 3, 3))}] + TEST_CASES_ERROR_1.append( + [{"spatial_dims": (2,), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}] + ) + TEST_CASES_ERROR_2.append( + [{"spatial_dims": (-1, 0, 1), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}] + ) class TestAddCoordinateChannels(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input, expected_shape): - result = AddCoordinateChannelsd(**input_param)(input) - self.assertEqual(list(result["img"].shape), list(expected_shape)) - np.testing.assert_array_equal(input["img"][0, ...], result["img"][0, ...]) + result = AddCoordinateChannelsd(**input_param)(input)["img"] + input = input["img"] + self.assertEqual(type(result), type(input)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input.device) + self.assertEqual(result.shape, expected_shape) + assert_allclose(input[0, ...], result[0, ...]) - @parameterized.expand([TEST_CASE_ERROR_3]) + @parameterized.expand(TEST_CASES_ERROR_1) def test_max_channel(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannelsd(**input_param)(input) - @parameterized.expand([TEST_CASE_ERROR_4]) + @parameterized.expand(TEST_CASES_ERROR_2) def test_channel_dim(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannelsd(**input_param)(input) diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py index ecf2c83d3c..d2c8a627b6 100644 --- a/tests/test_add_extreme_points_channel.py +++ b/tests/test_add_extreme_points_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,52 +15,63 @@ from parameterized import parameterized from monai.transforms import AddExtremePointsChannel +from tests.utils import TEST_NDARRAYS, assert_allclose IMG_CHANNEL = 3 +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])), + "sigma": 1.0, + "rescale_min": 0.0, + "rescale_max": 1.0, + }, + p( + np.array( + [ + [0.38318458, 0.98615628, 0.85551184], + [0.35422316, 0.94430935, 1.0], + [0.46000731, 0.57319659, 0.46000722], + [0.64577687, 0.38318464, 0.0], + ] + ) + ), + ] + ) -TEST_CASE_1 = [ - { - "img": np.zeros((IMG_CHANNEL, 4, 3)), - "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]]), - "sigma": 1.0, - "rescale_min": 0.0, - "rescale_max": 1.0, - }, - np.array( - [ - [0.38318458, 0.98615628, 0.85551184], - [0.35422316, 0.94430935, 1.0], - [0.46000731, 0.57319659, 0.46000722], - [0.64577687, 0.38318464, 0.0], - ] - ), -] - -TEST_CASE_2 = [ - { - "img": np.zeros((IMG_CHANNEL, 4, 3)), - "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]]), - "sigma": 1.0, - "rescale_min": 0.0, - "rescale_max": 1.0, - }, - np.array( - [ - [0.44628328, 0.80495411, 0.44628328], - [0.6779086, 1.0, 0.67790854], - [0.33002687, 0.62079221, 0.33002687], - [0.0, 0.31848389, 0.0], - ] - ), -] + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])), + "sigma": 1.0, + "rescale_min": 0.0, + "rescale_max": 1.0, + }, + p( + np.array( + [ + [0.44628328, 0.80495411, 0.44628328], + [0.6779086, 1.0, 0.67790854], + [0.33002687, 0.62079221, 0.33002687], + [0.0, 0.31848389, 0.0], + ] + ) + ), + ] + ) class TestAddExtremePointsChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChannel() result = add_extreme_points_channel(**input_data) - np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4) + assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py index e33bb0838c..39d221596f 100644 --- a/tests/test_add_extreme_points_channeld.py +++ b/tests/test_add_extreme_points_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,42 +15,60 @@ from parameterized import parameterized from monai.transforms import AddExtremePointsChanneld +from tests.utils import TEST_NDARRAYS, assert_allclose IMG_CHANNEL = 3 -TEST_CASE_1 = [ - {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])}, - np.array( - [ - [0.38318458, 0.98615628, 0.85551184], - [0.35422316, 0.94430935, 1.0], - [0.46000731, 0.57319659, 0.46000722], - [0.64577687, 0.38318464, 0.0], - ] - ), -] - -TEST_CASE_2 = [ - {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])}, - np.array( - [ - [0.44628328, 0.80495411, 0.44628328], - [0.6779086, 1.0, 0.67790854], - [0.33002687, 0.62079221, 0.33002687], - [0.0, 0.31848389, 0.0], - ] - ), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])), + }, + p( + np.array( + [ + [0.38318458, 0.98615628, 0.85551184], + [0.35422316, 0.94430935, 1.0], + [0.46000731, 0.57319659, 0.46000722], + [0.64577687, 0.38318464, 0.0], + ] + ) + ), + ] + ) + + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])), + }, + p( + np.array( + [ + [0.44628328, 0.80495411, 0.44628328], + [0.6779086, 1.0, 0.67790854], + [0.33002687, 0.62079221, 0.33002687], + [0.0, 0.31848389, 0.0], + ] + ) + ), + ] + ) class TestAddExtremePointsChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChanneld( keys="img", label_key="label", sigma=1.0, rescale_min=0.0, rescale_max=1.0 ) result = add_extreme_points_channel(input_data) - np.testing.assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4) + assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_adjust_contrast.py b/tests/test_adjust_contrast.py index 8e78698360..2f6c4e2259 100644 --- a/tests/test_adjust_contrast.py +++ b/tests/test_adjust_contrast.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import AdjustContrast -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_1 = [1.0] @@ -28,15 +28,16 @@ class TestAdjustContrast(NumpyImageTestCase2D): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_correct_results(self, gamma): adjuster = AdjustContrast(gamma=gamma) - result = adjuster(self.imt) - if gamma == 1.0: - expected = self.imt - else: - epsilon = 1e-7 - img_min = self.imt.min() - img_range = self.imt.max() - img_min - expected = np.power(((self.imt - img_min) / float(img_range + epsilon)), gamma) * img_range + img_min - np.testing.assert_allclose(expected, result, rtol=1e-05) + for p in TEST_NDARRAYS: + result = adjuster(p(self.imt)) + if gamma == 1.0: + expected = self.imt + else: + epsilon = 1e-7 + img_min = self.imt.min() + img_range = self.imt.max() - img_min + expected = np.power(((self.imt - img_min) / float(img_range + epsilon)), gamma) * img_range + img_min + assert_allclose(expected, result, rtol=1e-05, type_test=False) if __name__ == "__main__": diff --git a/tests/test_adjust_contrastd.py b/tests/test_adjust_contrastd.py index 65647607e4..a7224b643b 100644 --- a/tests/test_adjust_contrastd.py +++ b/tests/test_adjust_contrastd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import AdjustContrastd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_1 = [1.0] @@ -28,15 +28,16 @@ class TestAdjustContrastd(NumpyImageTestCase2D): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_correct_results(self, gamma): adjuster = AdjustContrastd("img", gamma=gamma) - result = adjuster({"img": self.imt}) - if gamma == 1.0: - expected = self.imt - else: - epsilon = 1e-7 - img_min = self.imt.min() - img_range = self.imt.max() - img_min - expected = np.power(((self.imt - img_min) / float(img_range + epsilon)), gamma) * img_range + img_min - np.testing.assert_allclose(expected, result["img"], rtol=1e-05) + for p in TEST_NDARRAYS: + result = adjuster({"img": p(self.imt)}) + if gamma == 1.0: + expected = self.imt + else: + epsilon = 1e-7 + img_min = self.imt.min() + img_range = self.imt.max() - img_min + expected = np.power(((self.imt - img_min) / float(img_range + epsilon)), gamma) * img_range + img_min + assert_allclose(expected, result["img"], rtol=1e-05, type_test=False) if __name__ == "__main__": diff --git a/tests/test_adn.py b/tests/test_adn.py index 2130ebc005..2352f5c1e2 100644 --- a/tests/test_adn.py +++ b/tests/test_adn.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_affine.py b/tests/test_affine.py index dd82d72e23..d681d2941b 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,78 +16,150 @@ from parameterized import parameterized from monai.transforms import Affine +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None, image_only=True), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (-1, 0, 0)}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device, image_only=True), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict( + affine=p(torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])), + padding_mode="zeros", + device=device, + ), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (-1, 0, 0)}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affine(**input_param) result = g(**input_data) if isinstance(result, tuple): result = result[0] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 24772b9a21..6f6364feda 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,88 +16,130 @@ from parameterized import parameterized from monai.transforms import AffineGrid +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [ - {"as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (2, 2)}, - np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [ - {"as_tensor_output": True, "device": None}, - {"spatial_size": (2, 2)}, - torch.tensor([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [{"as_tensor_output": False, "device": None}, {"grid": np.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [{"as_tensor_output": True, "device": torch.device("cpu:0")}, {"grid": np.ones((3, 3, 3))}, torch.ones((3, 3, 3))], - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"as_tensor_output": True, "device": torch.device("cpu:0")}, - {"grid": torch.ones((3, 3, 3))}, - torch.ones((3, 3, 3)), - ], - [ - { - "rotate_params": (1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((3, 3, 3))}, - torch.tensor( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208]], - [[-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + {"device": device}, + {"spatial_size": (2, 2)}, + np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), ] - ), - ], - [ - { - "rotate_params": (1.0, 1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((4, 3, 3, 3))}, - torch.tensor( + ) + + TESTS.append([{"device": device}, {"grid": p(np.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( + [ + {"rotate_params": (1.0, 1.0), "scale_params": (-20, 10), "device": device}, + {"grid": p(torch.ones((3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + ], + [ + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + ], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ) + ), + ] + ) + TESTS.append( [ - [ - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - ], - [ - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - ], - [ - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - ], - [ - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - ], + { + "affine": p( + torch.tensor( + [[-10.8060, -8.4147, 0.0000], [-16.8294, 5.4030, 0.0000], [0.0000, 0.0000, 1.0000]] + ) + ) + }, + {"grid": p(torch.ones((3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + ], + [ + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + ], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ) + ), ] - ), - ], -] + ) + TESTS.append( + [ + {"rotate_params": (1.0, 1.0, 1.0), "scale_params": (-20, 10), "device": device}, + {"grid": p(torch.ones((4, 3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + ], + [ + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + ], + [ + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + ], + [ + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + ], + ] + ) + ), + ] + ) + + +_rtol = 5e-2 if is_tf32_env() else 1e-4 class TestAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine_grid(self, input_param, input_data, expected_val): g = AffineGrid(**input_param) result, _ = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data[device]) + assert_allclose(result, expected_val, type_test=False, rtol=_rtol) if __name__ == "__main__": diff --git a/tests/test_affine_transform.py b/tests/test_affine_transform.py index 42af58be73..5170ab4260 100644 --- a/tests/test_affine_transform.py +++ b/tests/test_affine_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,6 +17,9 @@ from monai.networks import normalize_transform, to_norm_affine from monai.networks.layers import AffineTransform +from tests.utils import is_tf32_env + +_rtol = 1e-4 if not is_tf32_env() else 5e-3 TEST_NORM_CASES = [ [(4, 5), True, [[[0.666667, 0, -1], [0, 0.5, -1], [0, 0, 1]]]], @@ -95,7 +98,7 @@ def test_to_norm_affine(self, affine, src_size, dst_size, align_corners, expecte affine = torch.as_tensor(affine, device=torch.device("cuda:0"), dtype=torch.float32) new_affine = to_norm_affine(affine, src_size, dst_size, align_corners) new_affine = new_affine.detach().cpu().numpy() - np.testing.assert_allclose(new_affine, expected, atol=1e-4) + np.testing.assert_allclose(new_affine, expected, atol=1e-5, rtol=_rtol) @parameterized.expand(TEST_ILL_TO_NORM_AFFINE_CASES) def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners): @@ -113,7 +116,7 @@ def test_affine_shift(self): out = AffineTransform()(image, affine) out = out.detach().cpu().numpy() expected = [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_affine_shift_1(self): affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]) @@ -121,7 +124,7 @@ def test_affine_shift_1(self): out = AffineTransform()(image, affine) out = out.detach().cpu().numpy() expected = [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_affine_shift_2(self): affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]) @@ -129,28 +132,28 @@ def test_affine_shift_2(self): out = AffineTransform()(image, affine) out = out.detach().cpu().numpy() expected = [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_zoom(self): affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]) image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0")) out = AffineTransform((3, 2))(image, affine) expected = [[[[1, 3], [5, 7], [9, 11]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_zoom_1(self): affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0")) out = AffineTransform()(image, affine, (1, 4)) expected = [[[[1, 2, 3, 4]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=_rtol) def test_zoom_2(self): affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32) image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0")) out = AffineTransform((1, 2))(image, affine) expected = [[[[1, 3]]]] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_affine_transform_minimum(self): t = np.pi / 3 @@ -169,7 +172,7 @@ def test_affine_transform_minimum(self): ] ] ] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-3, rtol=_rtol) def test_affine_transform_2d(self): t = np.pi / 3 @@ -188,7 +191,7 @@ def test_affine_transform_2d(self): ] ] ] - np.testing.assert_allclose(out, expected, atol=1e-5) + np.testing.assert_allclose(out, expected, atol=1e-3, rtol=_rtol) if torch.cuda.is_available(): affine = torch.as_tensor(affine, device=torch.device("cuda:0"), dtype=torch.float32) @@ -205,7 +208,7 @@ def test_affine_transform_2d(self): ] ] ] - np.testing.assert_allclose(out, expected, atol=1e-4) + np.testing.assert_allclose(out, expected, atol=5e-3) def test_affine_transform_3d(self): t = np.pi / 3 @@ -231,7 +234,7 @@ def test_affine_transform_3d(self): ] ], ] - np.testing.assert_allclose(out, expected, atol=1e-4) + np.testing.assert_allclose(out, expected, atol=1e-4, rtol=_rtol) if torch.cuda.is_available(): affine = torch.as_tensor(affine, device=torch.device("cuda:0"), dtype=torch.float32) @@ -255,7 +258,7 @@ def test_affine_transform_3d(self): ] ], ] - np.testing.assert_allclose(out, expected, atol=1e-4) + np.testing.assert_allclose(out, expected, atol=5e-3) def test_ill_affine_transform(self): with self.assertRaises(ValueError): # image too small diff --git a/tests/test_affined.py b/tests/test_affined.py index 850f12905d..355833b858 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,85 +16,145 @@ from parameterized import parameterized from monai.transforms import Affined +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, spatial_size=(-1, 0), device=None), - {"img": np.arange(9).reshape((1, 3, 3))}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3))}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0), device=device), + {"img": p(np.arange(9).reshape((1, 3, 3)))}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(keys="img", rotate_params=[np.pi / 2], padding_mode="zeros", spatial_size=(4, 4), device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict( + keys="img", + affine=p(torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])), + padding_mode="zeros", + spatial_size=(4, 4), + device=device, + ), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3)))}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict( + keys="img", rotate_params=[np.pi / 2], padding_mode="zeros", spatial_size=(4, 4, 4), device=device + ), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affined(**input_param) result = g(input_data)["img"] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 777e2637a7..dba8eaf72b 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -162,26 +162,14 @@ def test_mcfcn_shape(self, input_param, input_shape, expected_shape): class TestAHNET(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_AHNET_2D_1, - TEST_CASE_AHNET_2D_2, - TEST_CASE_AHNET_2D_3, - ] - ) + @parameterized.expand([TEST_CASE_AHNET_2D_1, TEST_CASE_AHNET_2D_2, TEST_CASE_AHNET_2D_3]) def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape): net = AHNet(**input_param).to(device) with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - @parameterized.expand( - [ - TEST_CASE_AHNET_3D_1, - TEST_CASE_AHNET_3D_2, - TEST_CASE_AHNET_3D_3, - ] - ) + @parameterized.expand([TEST_CASE_AHNET_3D_1, TEST_CASE_AHNET_3D_2, TEST_CASE_AHNET_3D_3]) @skip_if_quick def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape): net = AHNet(**input_param).to(device) @@ -203,11 +191,7 @@ def test_script(self): class TestAHNETWithPretrain(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_AHNET_3D_WITH_PRETRAIN_1, - TEST_CASE_AHNET_3D_WITH_PRETRAIN_2, - TEST_CASE_AHNET_3D_WITH_PRETRAIN_3, - ] + [TEST_CASE_AHNET_3D_WITH_PRETRAIN_1, TEST_CASE_AHNET_3D_WITH_PRETRAIN_2, TEST_CASE_AHNET_3D_WITH_PRETRAIN_3] ) def test_ahnet_shape(self, input_param, input_shape, expected_shape, fcn_input_param): net = AHNet(**input_param).to(device) diff --git a/tests/test_alias.py b/tests/test_alias.py new file mode 100644 index 0000000000..49f9fa56fe --- /dev/null +++ b/tests/test_alias.py @@ -0,0 +1,39 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import inspect +import os +import unittest + +from monai.utils import optional_import + + +class TestModuleAlias(unittest.TestCase): + """check that 'import monai.xx.file_name' returns a module""" + + def test_files(self): + src_dir = os.path.dirname(os.path.dirname(__file__)) + monai_dir = os.path.join(src_dir, "monai") + py_files = glob.glob(os.path.join(monai_dir, "**", "*.py"), recursive=True) + for x in py_files: + if os.path.basename(x).startswith("_"): + continue + mod_name = x[len(src_dir) : -3] # create relative path + mod_name = mod_name[1:].replace(mod_name[0], ".") + mod, cls = mod_name.rsplit(".", 1) + obj, exist = optional_import(mod, name=cls) + if exist: + self.assertTrue(inspect.ismodule(obj), msg=mod_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_apply_filter.py b/tests/test_apply_filter.py new file mode 100644 index 0000000000..ff8480a02f --- /dev/null +++ b/tests/test_apply_filter.py @@ -0,0 +1,87 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from monai.networks.layers import apply_filter + + +class ApplyFilterTestCase(unittest.TestCase): + def test_1d(self): + a = torch.tensor([[list(range(10))]], dtype=torch.float) + out = apply_filter(a, torch.tensor([-1, 0, 1]), stride=1) + expected = np.array([[[1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, -8.0]]]) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + if torch.cuda.is_available(): + out = apply_filter(a.cuda(), torch.tensor([-1, 0, 1]).cuda()) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + + def test_2d(self): + a = torch.tensor([[[list(range(7)), list(range(7, 0, -1)), list(range(7))]]], dtype=torch.float) + expected = np.array( + [ + [14.0, 21.0, 21.0, 21.0, 21.0, 21.0, 14.0], + [15.0, 24.0, 27.0, 30.0, 33.0, 36.0, 25.0], + [14.0, 21.0, 21.0, 21.0, 21.0, 21.0, 14.0], + ] + ) + expected = expected[None][None] + out = apply_filter(a, torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]])) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + if torch.cuda.is_available(): + out = apply_filter(a.cuda(), torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).cuda()) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + + def test_3d(self): + a = torch.tensor( + [[list(range(7)), list(range(7)), list(range(7))], [list(range(7)), list(range(7)), list(range(7))]], + dtype=torch.float, + ) + a = a[None][None] + a = a.expand(2, 3, -1, -1, -1) + expected = np.array( + [ + [ + [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0], + [3.0, 9.0, 18.0, 27.0, 36.0, 45.0, 33.0], + [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0], + ], + [ + [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0], + [3.0, 9.0, 18.0, 27.0, 36.0, 45.0, 33.0], + [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0], + ], + ] + ) + expected = expected + # testing shapes + k = torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) + for kernel in (k, k[None], k[None][None]): + out = apply_filter(a, kernel) + np.testing.assert_allclose(out.cpu().numpy()[1][2], expected, rtol=1e-4) + if torch.cuda.is_available(): + out = apply_filter(a.cuda(), kernel.cuda()) + np.testing.assert_allclose(out.cpu().numpy()[0][1], expected, rtol=1e-4) + + def test_wrong_args(self): + with self.assertRaisesRegex(ValueError, ""): + apply_filter(torch.ones((1, 2, 3, 2)), torch.ones((2,))) + with self.assertRaisesRegex(NotImplementedError, ""): + apply_filter(torch.ones((1, 1, 1, 2, 3, 2)), torch.ones((2,))) + with self.assertRaisesRegex(TypeError, ""): + apply_filter(((1, 1, 1, 2, 3, 2)), torch.ones((2,))) # type: ignore + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index f6459cc88c..ee1a92cf97 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index 0d1b1c7d3a..a2d56295b8 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,7 +34,7 @@ def test_value(self, in_type, input_param, expected_shape): if isinstance(test_data, torch.Tensor): test_data = test_data.cpu().numpy() expected = np.moveaxis(test_data, input_param["channel_dim"], 0) - assert_allclose(expected, result) + assert_allclose(result, expected, type_test=False) if __name__ == "__main__": diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index 68d33434c1..91086f9299 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index 55a7a08676..e6446ab7a6 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index 350f639f3f..a6d94d216a 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index bb9457a357..a68e6431ec 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,52 +11,62 @@ import unittest -import torch from parameterized import parameterized from monai.transforms import AsDiscrete +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"argmax": True, "to_onehot": False, "num_classes": None, "threshold_values": False, "logit_thresh": 0.5}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[1.0, 1.0]]]), - (1, 1, 2), -] +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"argmax": True, "to_onehot": None, "threshold": 0.5}, + p([[[0.0, 1.0]], [[2.0, 3.0]]]), + p([[[1.0, 1.0]]]), + (1, 1, 2), + ] + ) -TEST_CASE_2 = [ - {"argmax": True, "to_onehot": True, "num_classes": 2, "threshold_values": False, "logit_thresh": 0.5}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), - (2, 1, 2), -] + TEST_CASES.append( + [ + {"argmax": True, "to_onehot": 2, "threshold": 0.5}, + p([[[0.0, 1.0]], [[2.0, 3.0]]]), + p([[[0.0, 0.0]], [[1.0, 1.0]]]), + (2, 1, 2), + ] + ) -TEST_CASE_3 = [ - {"argmax": False, "to_onehot": False, "num_classes": None, "threshold_values": True, "logit_thresh": 0.6}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), - (1, 2, 2), -] + TEST_CASES.append( + [ + {"argmax": False, "to_onehot": None, "threshold": 0.6}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + p([[[0.0, 1.0], [1.0, 1.0]]]), + (1, 2, 2), + ] + ) -TEST_CASE_4 = [ - {"argmax": False, "to_onehot": True, "num_classes": 3}, - torch.tensor(1), - torch.tensor([0.0, 1.0, 0.0]), - (3,), -] + # test threshold = 0.0 + TEST_CASES.append( + [ + {"argmax": False, "to_onehot": None, "threshold": 0.0}, + p([[[0.0, -1.0], [-2.0, 3.0]]]), + p([[[1.0, 0.0], [0.0, 1.0]]]), + (1, 2, 2), + ] + ) -TEST_CASE_5 = [ - {"rounding": "torchrounding"}, - torch.tensor([[[0.123, 1.345], [2.567, 3.789]]]), - torch.tensor([[[0.0, 1.0], [3.0, 4.0]]]), - (1, 2, 2), -] + TEST_CASES.append([{"argmax": False, "to_onehot": 3}, p(1), p([0.0, 1.0, 0.0]), (3,)]) + + TEST_CASES.append( + [{"rounding": "torchrounding"}, p([[[0.123, 1.345], [2.567, 3.789]]]), p([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2)] + ) class TestAsDiscrete(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TEST_CASES) def test_value_shape(self, input_param, img, out, expected_shape): result = AsDiscrete(**input_param)(img) - torch.testing.assert_allclose(result, out) + assert_allclose(result, out, rtol=1e-3) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index 90e98b297b..21825c2d6c 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,69 +11,84 @@ import unittest -import torch from parameterized import parameterized from monai.transforms import AsDiscreted +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - { - "keys": ["pred", "label"], - "argmax": [True, False], - "to_onehot": True, - "num_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": torch.tensor([[[0, 1]]])}, - {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])}, - (2, 1, 2), -] +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2, "threshold": 0.5}, + {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": p([[[0, 1]]])}, + {"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": p([[[1.0, 0.0]], [[0.0, 1.0]]])}, + (2, 1, 2), + ] + ) -TEST_CASE_2 = [ - { - "keys": ["pred", "label"], - "argmax": False, - "to_onehot": False, - "num_classes": None, - "threshold_values": [True, False], - "logit_thresh": 0.6, - }, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), "label": torch.tensor([[[0, 1], [1, 1]]])}, - {"pred": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), "label": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]])}, - (1, 2, 2), -] + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "argmax": False, "to_onehot": None, "threshold": [0.6, None]}, + {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, + {"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) -TEST_CASE_3 = [ - { - "keys": ["pred"], - "argmax": True, - "to_onehot": True, - "num_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])}, - {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]])}, - (2, 1, 2), -] + TEST_CASES.append( + [ + {"keys": ["pred"], "argmax": True, "to_onehot": 2, "threshold": 0.5}, + {"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]])}, + {"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]])}, + (2, 1, 2), + ] + ) -TEST_CASE_4 = [ - {"keys": "pred", "rounding": "torchrounding"}, - {"pred": torch.tensor([[[0.123, 1.345], [2.567, 3.789]]])}, - {"pred": torch.tensor([[[0.0, 1.0], [3.0, 4.0]]])}, - (1, 2, 2), -] + TEST_CASES.append( + [ + {"keys": "pred", "rounding": "torchrounding"}, + {"pred": p([[[0.123, 1.345], [2.567, 3.789]]])}, + {"pred": p([[[0.0, 1.0], [3.0, 4.0]]])}, + (1, 2, 2), + ] + ) + + # test compatible with previous versions + TEST_CASES.append( + [ + { + "keys": ["pred", "label"], + "argmax": False, + "to_onehot": None, + "threshold": [True, None], + "logit_thresh": 0.6, + }, + {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, + {"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) + + # test threshold = 0.0 + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "argmax": False, "to_onehot": None, "threshold": [0.0, None]}, + {"pred": p([[[0.0, -1.0], [-2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, + {"pred": p([[[1.0, 0.0], [0.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) class TestAsDiscreted(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TEST_CASES) def test_value_shape(self, input_param, test_input, output, expected_shape): result = AsDiscreted(**input_param)(test_input) - torch.testing.assert_allclose(result["pred"], output["pred"]) + assert_allclose(result["pred"], output["pred"], rtol=1e-3) self.assertTupleEqual(result["pred"].shape, expected_shape) if "label" in result: - torch.testing.assert_allclose(result["label"], output["label"]) + assert_allclose(result["label"], output["label"], rtol=1e-3) self.assertTupleEqual(result["label"].shape, expected_shape) diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py index 54d6832c8d..bed5a198ff 100644 --- a/tests/test_autoencoder.py +++ b/tests/test_autoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ TEST_CASE_0 = [ # single channel 2D, batch 4, no residual { - "dimensions": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 1, "channels": (4, 8, 16), @@ -35,20 +35,14 @@ ] TEST_CASE_1 = [ # single channel 2D, batch 4 - { - "dimensions": 2, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 8, 16), - "strides": (2, 2, 2), - }, + {"spatial_dims": 2, "in_channels": 1, "out_channels": 1, "channels": (4, 8, 16), "strides": (2, 2, 2)}, (1, 1, 128, 128), (1, 1, 128, 128), ] TEST_CASE_2 = [ # 3-channel 2D, batch 4, LeakyReLU activation { - "dimensions": 2, + "spatial_dims": 2, "in_channels": 3, "out_channels": 3, "channels": (4, 8, 16), @@ -60,13 +54,7 @@ ] TEST_CASE_3 = [ # 4-channel 3D, batch 4 - { - "dimensions": 3, - "in_channels": 4, - "out_channels": 3, - "channels": (4, 8, 16), - "strides": (2, 2, 2), - }, + {"spatial_dims": 3, "in_channels": 4, "out_channels": 3, "channels": (4, 8, 16), "strides": (2, 2, 2)}, (1, 4, 128, 128, 128), (1, 3, 128, 128, 128), ] @@ -75,7 +63,7 @@ TEST_CASE_FAIL = { # 2-channel 2D, should fail because of stride/channel mismatch. - "dimensions": 2, + "spatial_dims": 2, "in_channels": 2, "out_channels": 2, "channels": (4, 8, 16), @@ -92,7 +80,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_script(self): - net = AutoEncoder(dimensions=2, in_channels=1, out_channels=1, channels=(4, 8), strides=(2, 2)) + net = AutoEncoder(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8), strides=(2, 2)) test_data = torch.randn(2, 1, 32, 32) test_script_save(net, test_data) diff --git a/tests/test_basic_unet.py b/tests/test_basic_unet.py index 09d7f72d0e..a4f88367dd 100644 --- a/tests/test_basic_unet.py +++ b/tests/test_basic_unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,20 +20,10 @@ CASES_1D = [] for mode in ["pixelshuffle", "nontrainable", "deconv", None]: - kwargs = { - "dimensions": 1, - "in_channels": 5, - "out_channels": 8, - } + kwargs = {"spatial_dims": 1, "in_channels": 5, "out_channels": 8} if mode is not None: kwargs["upsample"] = mode # type: ignore - CASES_1D.append( - [ - kwargs, - (10, 5, 33), - (10, 8, 33), - ] - ) + CASES_1D.append([kwargs, (10, 5, 33), (10, 8, 33)]) CASES_2D = [] for mode in ["pixelshuffle", "nontrainable", "deconv"]: @@ -43,7 +33,7 @@ CASES_2D.append( [ { - "dimensions": 2, + "spatial_dims": 2, "in_channels": in_channels, "out_channels": out_channels, "features": (12, 12, 13, 14, 15, 16), @@ -56,7 +46,7 @@ CASES_3D = [ [ # single channel 3D, batch 2 { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 1, "out_channels": 2, "features": (16, 20, 21, 22, 23, 11), @@ -67,7 +57,7 @@ ], [ # 2-channel 3D, batch 3 { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 2, "out_channels": 7, "features": (14, 15, 16, 17, 18, 11), @@ -78,7 +68,7 @@ ], [ # 4-channel 3D, batch 5 { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 4, "out_channels": 2, "features": (14, 15, 16, 17, 18, 10), @@ -101,7 +91,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_script(self): - net = BasicUNet(dimensions=2, in_channels=1, out_channels=3) + net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=3) test_data = torch.randn(16, 1, 32, 32) test_script_save(net, test_data) diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py index 8f1fb43535..318b1905df 100644 --- a/tests/test_bending_energy.py +++ b/tests/test_bending_energy.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,31 +20,30 @@ device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASES = [ + [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], + [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0], [ - {}, - {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, - 0.0, - ], - [ - {}, - {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, - 0.0, - ], - [ - {}, + {"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, 4.0, ], [ - {}, - {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, 4.0, ], + [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 4.0], [ - {}, - {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 3, 5) ** 2}, - 4.0, + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 100.0, + ], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 100.0, ], + [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 100.0], ] @@ -57,18 +56,24 @@ def test_shape(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = BendingEnergyLoss() # not in 3-d, 4-d, 5-d - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): loss.forward(torch.ones((1, 3), device=device)) - with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3, 5, 5, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) # spatial_dim < 5 - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 4, 5, 5), device=device)) - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 5, 4, 5))) - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 5, 5, 4))) + # number of vector components unequal to number of spatial dims + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + def test_ill_opts(self): pred = torch.rand(1, 3, 5, 5, 5).to(device=device) with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_bilateral_approx_cpu.py b/tests/test_bilateral_approx_cpu.py index 7960f76591..b3f4f5c3be 100644 --- a/tests/test_bilateral_approx_cpu.py +++ b/tests/test_bilateral_approx_cpu.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -306,7 +306,7 @@ # Frame 4 [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], ] - ], + ] ], # Expected [ diff --git a/tests/test_bilateral_approx_cuda.py b/tests/test_bilateral_approx_cuda.py index 345a920f3c..db34b0ff71 100644 --- a/tests/test_bilateral_approx_cuda.py +++ b/tests/test_bilateral_approx_cuda.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -306,7 +306,7 @@ # Frame 4 [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], ] - ], + ] ], # Expected [ diff --git a/tests/test_bilateral_precise.py b/tests/test_bilateral_precise.py index dfa3ca107d..b19369d758 100644 --- a/tests/test_bilateral_precise.py +++ b/tests/test_bilateral_precise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -306,7 +306,7 @@ # Frame 4 [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], ] - ], + ] ], # Expected [ diff --git a/tests/test_blend_images.py b/tests/test_blend_images.py new file mode 100644 index 0000000000..6fea53ac30 --- /dev/null +++ b/tests/test_blend_images.py @@ -0,0 +1,53 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.case import skipUnless + +import torch +from parameterized import parameterized + +from monai.data.synthetic import create_test_image_2d, create_test_image_3d +from monai.transforms.utils_pytorch_numpy_unification import moveaxis +from monai.utils.module import optional_import +from monai.visualize.utils import blend_images +from tests.utils import TEST_NDARRAYS + +plt, has_matplotlib = optional_import("matplotlib.pyplot") + +TESTS = [] +for p in TEST_NDARRAYS: + image, label = create_test_image_2d(100, 101) + TESTS.append((p(image), p(label))) + + image, label = create_test_image_3d(100, 101, 102) + TESTS.append((p(image), p(label))) + + +@skipUnless(has_matplotlib, "Matplotlib required") +class TestBlendImages(unittest.TestCase): + @parameterized.expand(TESTS) + def test_blend(self, image, label): + blended = blend_images(image[None], label[None]) + self.assertEqual(type(image), type(blended)) + if isinstance(blended, torch.Tensor): + self.assertEqual(blended.device, image.device) + blended = blended.cpu().numpy() + self.assertEqual((3,) + image.shape, blended.shape) + + blended = moveaxis(blended, 0, -1) # move RGB component to end + if blended.ndim > 3: + blended = blended[blended.shape[0] // 2] + plt.imshow(blended) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py index 9e6a8a6a08..b632ff831f 100644 --- a/tests/test_border_pad.py +++ b/tests/test_border_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,17 +18,9 @@ from monai.utils import NumpyPadMode from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"spatial_border": 2, "mode": "constant"}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 12, 12, 8)), -] +TEST_CASE_1 = [{"spatial_border": 2, "mode": "constant"}, np.zeros((3, 8, 8, 4)), np.zeros((3, 12, 12, 8))] -TEST_CASE_2 = [ - {"spatial_border": [1, 2, 3], "mode": "constant"}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 10, 12, 10)), -] +TEST_CASE_2 = [{"spatial_border": [1, 2, 3], "mode": "constant"}, np.zeros((3, 8, 8, 4)), np.zeros((3, 10, 12, 10))] TEST_CASE_3 = [ {"spatial_border": [1, 2, 3, 4, 5, 6], "mode": "constant"}, diff --git a/tests/test_border_padd.py b/tests/test_border_padd.py index b48629fc98..e4b8dd20ea 100644 --- a/tests/test_border_padd.py +++ b/tests/test_border_padd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py index 38585cba18..a7c2648f1e 100644 --- a/tests/test_bounding_rect.py +++ b/tests/test_bounding_rect.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,6 +16,7 @@ import monai from monai.transforms import BoundingRect +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [(2, 3), [[0, 0], [1, 2]]] @@ -35,14 +36,16 @@ def tearDown(self): def test_shape(self, input_shape, expected): test_data = np.random.randint(0, 8, size=input_shape) test_data = test_data == 7 - result = BoundingRect()(test_data) - np.testing.assert_allclose(result, expected) + for p in TEST_NDARRAYS: + result = BoundingRect()(p(test_data)) + np.testing.assert_allclose(result, expected) def test_select_fn(self): test_data = np.random.randint(0, 8, size=(2, 3)) test_data = test_data == 7 - bbox = BoundingRect(select_fn=lambda x: x < 1)(test_data) - np.testing.assert_allclose(bbox, [[0, 3], [0, 3]]) + for p in TEST_NDARRAYS: + bbox = BoundingRect(select_fn=lambda x: x < 1)(p(test_data)) + np.testing.assert_allclose(bbox, [[0, 3], [0, 3]]) if __name__ == "__main__": diff --git a/tests/test_bounding_rectd.py b/tests/test_bounding_rectd.py index 6e725ff583..47ed854263 100644 --- a/tests/test_bounding_rectd.py +++ b/tests/test_bounding_rectd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,6 +16,7 @@ import monai from monai.transforms import BoundingRectD +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [(2, 3), [[0, 0], [1, 2]]] @@ -35,14 +36,15 @@ def tearDown(self): def test_shape(self, input_shape, expected): test_data = np.random.randint(0, 8, size=input_shape) test_data = test_data == 7 - result = BoundingRectD("image")({"image": test_data}) - np.testing.assert_allclose(result["image_bbox"], expected) + for p in TEST_NDARRAYS: + result = BoundingRectD("image")({"image": p(test_data)}) + np.testing.assert_allclose(result["image_bbox"], expected) - result = BoundingRectD("image", "cc")({"image": test_data}) - np.testing.assert_allclose(result["image_cc"], expected) + result = BoundingRectD("image", "cc")({"image": p(test_data)}) + np.testing.assert_allclose(result["image_cc"], expected) - with self.assertRaises(KeyError): - BoundingRectD("image", "cc")({"image": test_data, "image_cc": None}) + with self.assertRaises(KeyError): + BoundingRectD("image", "cc")({"image": p(test_data), "image_cc": None}) if __name__ == "__main__": diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index bbb8143631..9d85c711fd 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,8 +19,8 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader, PersistentDataset, SmartCacheDataset -from monai.transforms import Compose, Lambda, LoadImaged, ThreadUnsafe, Transform -from monai.utils import get_torch_version_tuple +from monai.transforms import Compose, Lambda, LoadImaged, RandLambda, ThreadUnsafe, Transform +from monai.utils.module import pytorch_after TEST_CASE_1 = [Compose([LoadImaged(keys=["image", "label", "extra"])]), (128, 128, 128)] @@ -84,7 +84,7 @@ def test_shape(self, transform, expected_shape): def test_set_data(self): data_list1 = list(range(10)) - transform = Lambda(func=lambda x: np.array([x * 10])) + transform = Compose([Lambda(func=lambda x: np.array([x * 10])), RandLambda(func=lambda x: x + 1)]) dataset = CacheDataset( data=data_list1, @@ -92,19 +92,23 @@ def test_set_data(self): cache_rate=1.0, num_workers=4, progress=True, + copy_cache=False if sys.platform == "linux" else True, ) num_workers = 2 if sys.platform == "linux" else 0 dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1) for i, d in enumerate(dataloader): - np.testing.assert_allclose([[data_list1[i] * 10]], d) + np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d) + # simulate another epoch, the cache content should not be modified + for i, d in enumerate(dataloader): + np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d) # update the datalist and fill the cache content data_list2 = list(range(-10, 0)) dataset.set_data(data=data_list2) # rerun with updated cache content for i, d in enumerate(dataloader): - np.testing.assert_allclose([[data_list2[i] * 10]], d) + np.testing.assert_allclose([[data_list2[i] * 10 + 1]], d) class _StatefulTransform(Transform, ThreadUnsafe): @@ -130,14 +134,10 @@ class TestCacheThread(unittest.TestCase): @parameterized.expand(TEST_DS) def test_thread_safe(self, persistent_workers, cache_workers, loader_workers): expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002] - _kwg = {"persistent_workers": persistent_workers} if get_torch_version_tuple() > (1, 7) else {} + _kwg = {"persistent_workers": persistent_workers} if pytorch_after(1, 8) else {} data_list = list(range(1, 11)) dataset = CacheDataset( - data=data_list, - transform=_StatefulTransform(), - cache_rate=1.0, - num_workers=cache_workers, - progress=False, + data=data_list, transform=_StatefulTransform(), cache_rate=1.0, num_workers=cache_workers, progress=False ) self.assertListEqual(expected, list(dataset)) loader = DataLoader( diff --git a/tests/test_cachedataset_parallel.py b/tests/test_cachedataset_parallel.py index 0be3ba085b..f409e17787 100644 --- a/tests/test_cachedataset_parallel.py +++ b/tests/test_cachedataset_parallel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,12 +42,7 @@ def test_shape(self, num_workers, dataset_size, transform): "extra": os.path.join(tempdir, "test_extra1.nii.gz"), } ] * dataset_size - dataset = CacheDataset( - data=test_data, - transform=transform, - cache_rate=1, - num_workers=num_workers, - ) + dataset = CacheDataset(data=test_data, transform=transform, cache_rate=1, num_workers=num_workers) self.assertEqual(len(dataset._cache), dataset.cache_num) for i in range(dataset.cache_num): diff --git a/tests/test_cachedataset_persistent_workers.py b/tests/test_cachedataset_persistent_workers.py index 584a053614..4bea0486bc 100644 --- a/tests/test_cachedataset_persistent_workers.py +++ b/tests/test_cachedataset_persistent_workers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,20 +19,16 @@ @SkipIfBeforePyTorchVersion((1, 7)) class TestTransformsWCacheDatasetAndPersistentWorkers(unittest.TestCase): def test_duplicate_transforms(self): - im, _ = create_test_image_2d(128, 128, num_seg_classes=1, channel_dim=0) - data = [{"img": im} for _ in range(2)] + data = [{"img": create_test_image_2d(128, 128, num_seg_classes=1, channel_dim=0)[0]} for _ in range(2)] # at least 1 deterministic followed by at least 1 random - transform = Compose( - [ - Spacingd("img", pixdim=(1, 1)), - RandAffined("img", prob=1.0), - ] - ) + transform = Compose([Spacingd("img", pixdim=(1, 1)), RandAffined("img", prob=1.0)]) # cachedataset and data loader w persistent_workers train_ds = CacheDataset(data, transform, cache_num=1) - train_loader = DataLoader(train_ds, num_workers=2, persistent_workers=True) + # num_workers > 1 may fail randomly with 21.09 on A100 test node + # https://github.com/Project-MONAI/MONAI/issues/3283 + train_loader = DataLoader(train_ds, num_workers=1, persistent_workers=True) b1 = next(iter(train_loader)) b2 = next(iter(train_loader)) diff --git a/tests/test_cachentransdataset.py b/tests/test_cachentransdataset.py index 492db8b16f..99ca0e0c3d 100644 --- a/tests/test_cachentransdataset.py +++ b/tests/test_cachentransdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,19 +42,13 @@ def test_n_trans(self, transform, expected_shape): cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") dataset_precached = CacheNTransDataset( - data=test_data, - transform=transform, - cache_dir=cache_dir, - cache_n_trans=2, + data=test_data, transform=transform, cache_dir=cache_dir, cache_n_trans=2 ) data_precached = dataset_precached[0] self.assertTupleEqual(data_precached["image"].shape, expected_shape) dataset_postcached = CacheNTransDataset( - data=test_data, - transform=transform, - cache_dir=cache_dir, - cache_n_trans=2, + data=test_data, transform=transform, cache_dir=cache_dir, cache_n_trans=2 ) data_postcached = dataset_postcached[0] self.assertTupleEqual(data_postcached["image"].shape, expected_shape) diff --git a/tests/test_distcall.py b/tests/test_call_dist.py similarity index 96% rename from tests/test_distcall.py rename to tests/test_call_dist.py index 1830a85654..bed8289506 100644 --- a/tests/test_distcall.py +++ b/tests/test_call_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_cast_to_type.py b/tests/test_cast_to_type.py index 0ef25cbafa..82daabc4e7 100644 --- a/tests/test_cast_to_type.py +++ b/tests/test_cast_to_type.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,14 +16,23 @@ from parameterized import parameterized from monai.transforms import CastToType +from monai.utils import optional_import from monai.utils.type_conversion import get_equivalent_dtype -from tests.utils import TEST_NDARRAYS +from tests.utils import HAS_CUPY, TEST_NDARRAYS + +cp, _ = optional_import("cupy") TESTS = [] for p in TEST_NDARRAYS: for out_dtype in (np.float64, torch.float64): TESTS.append([out_dtype, p(np.array([[0, 1], [1, 2]], dtype=np.float32)), out_dtype]) +TESTS_CUPY = [ + [np.float32, np.array([[0, 1], [1, 2]], dtype=np.float32), np.float32], + [np.float32, np.array([[0, 1], [1, 2]], dtype=np.uint8), np.float32], + [np.uint8, np.array([[0, 1], [1, 2]], dtype=np.float32), np.uint8], +] + class TestCastToType(unittest.TestCase): @parameterized.expand(TESTS) @@ -35,6 +44,19 @@ def test_type(self, out_dtype, input_data, expected_type): result = CastToType()(input_data, out_dtype) self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result))) + @parameterized.expand(TESTS_CUPY) + @unittest.skipUnless(HAS_CUPY, "Requires CuPy") + def test_type_cupy(self, out_dtype, input_data, expected_type): + input_data = cp.asarray(input_data) + + result = CastToType(dtype=out_dtype)(input_data) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result))) + + result = CastToType()(input_data, out_dtype) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result))) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_cast_to_typed.py b/tests/test_cast_to_typed.py index be495564fb..4c7623a9e0 100644 --- a/tests/test_cast_to_typed.py +++ b/tests/test_cast_to_typed.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,6 +16,10 @@ from parameterized import parameterized from monai.transforms import CastToTyped +from monai.utils import optional_import +from tests.utils import HAS_CUPY + +cp, _ = optional_import("cupy") TEST_CASE_1 = [ {"keys": ["img"], "dtype": np.float64}, @@ -33,6 +37,20 @@ ] +TESTS_CUPY = [ + [ + {"keys": "image", "dtype": np.uint8}, + {"image": np.array([[0, 1], [1, 2]], dtype=np.float32), "label": np.array([[0, 1], [1, 1]], dtype=np.float32)}, + {"image": np.uint8, "label": np.float32}, + ], + [ + {"keys": ["image", "label"], "dtype": np.float32}, + {"image": np.array([[0, 1], [1, 2]], dtype=np.uint8), "label": np.array([[0, 1], [1, 1]], dtype=np.uint8)}, + {"image": np.float32, "label": np.float32}, + ], +] + + class TestCastToTyped(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_type(self, input_param, input_data, expected_type): @@ -40,6 +58,16 @@ def test_type(self, input_param, input_data, expected_type): for k, v in result.items(): self.assertEqual(v.dtype, expected_type[k]) + @parameterized.expand(TESTS_CUPY) + @unittest.skipUnless(HAS_CUPY, "Requires CuPy") + def test_type_cupy(self, input_param, input_data, expected_type): + input_data = {k: cp.asarray(v) for k, v in input_data.items()} + + result = CastToTyped(**input_param)(input_data) + for k, v in result.items(): + self.assertTrue(isinstance(v, cp.ndarray)) + self.assertEqual(v.dtype, expected_type[k]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_center_scale_crop.py b/tests/test_center_scale_crop.py index e28849ce90..f22651e3e0 100644 --- a/tests/test_center_scale_crop.py +++ b/tests/test_center_scale_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,11 +38,13 @@ class TestCenterScaleCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result.shape, expected_shape) @parameterized.expand([TEST_CASE_2]) def test_value(self, input_param, input_data, expected_value): result = CenterScaleCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result, expected_value) diff --git a/tests/test_center_scale_cropd.py b/tests/test_center_scale_cropd.py index 313e8e7f7e..8aef2dbe5b 100644 --- a/tests/test_center_scale_cropd.py +++ b/tests/test_center_scale_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,9 +33,15 @@ (3, 2, 2, 2), ] +TEST_CASE_4 = [ + {"keys": "test", "roi_scale": 0.6, "allow_missing_keys": True}, + np.random.randint(0, 2, size=[3, 3, 3, 3]), + (3, 3, 3, 3), +] + class TestCenterScaleCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCropd(**input_param)({"img": input_data}) np.testing.assert_allclose(result["img"].shape, expected_shape) diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index 3e828176a5..09f61be2f1 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,11 +38,13 @@ class TestCenterSpatialCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) def test_shape(self, input_param, input_data, expected_shape): result = CenterSpatialCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result.shape, expected_shape) @parameterized.expand([TEST_CASE_2]) def test_value(self, input_param, input_data, expected_value): result = CenterSpatialCrop(**input_param)(input_data) + self.assertEqual(isinstance(result, torch.Tensor), isinstance(input_data, torch.Tensor)) np.testing.assert_allclose(result, expected_value) diff --git a/tests/test_center_spatial_cropd.py b/tests/test_center_spatial_cropd.py index 349253ab56..bdbc1a5031 100644 --- a/tests/test_center_spatial_cropd.py +++ b/tests/test_center_spatial_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,36 +15,43 @@ from parameterized import parameterized from monai.transforms import CenterSpatialCropd - -TEST_CASE_0 = [ - {"keys": "img", "roi_size": [2, -1, -1]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 3, 3), -] - -TEST_CASE_1 = [ - {"keys": "img", "roi_size": [2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), -] - -TEST_CASE_2 = [ - {"keys": "img", "roi_size": [2, 2]}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[1, 2], [2, 3]]]), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_SHAPES = [] +for p in TEST_NDARRAYS: + TEST_SHAPES.append( + [{"keys": "img", "roi_size": [2, -1, -1]}, {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, (3, 2, 3, 3)] + ) + + TEST_SHAPES.append( + [{"keys": "img", "roi_size": [2, 2, 2]}, {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, (3, 2, 2, 2)] + ) + +TEST_CASES = [] +for p in TEST_NDARRAYS: + TEST_CASES.append( + [ + {"keys": "img", "roi_size": [2, 2]}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[1, 2], [2, 3]]])), + ] + ) class TestCenterSpatialCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TEST_SHAPES) def test_shape(self, input_param, input_data, expected_shape): result = CenterSpatialCropd(**input_param)(input_data) self.assertTupleEqual(result["img"].shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) + @parameterized.expand(TEST_CASES) def test_value(self, input_param, input_data, expected_value): result = CenterSpatialCropd(**input_param)(input_data) - np.testing.assert_allclose(result["img"], expected_value) + assert_allclose(result["img"], expected_value, type_test=False) if __name__ == "__main__": diff --git a/tests/test_channel_pad.py b/tests/test_channel_pad.py index ebc731c321..bde0f18d83 100644 --- a/tests/test_channel_pad.py +++ b/tests/test_channel_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_check_hash.py b/tests/test_check_hash.py index 0126b3c1a3..5297021540 100644 --- a/tests/test_check_hash.py +++ b/tests/test_check_hash.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -41,7 +41,7 @@ def test_result(self, md5_value, t, expected_result): self.assertTrue(result == expected_result) def test_hash_type_error(self): - with self.assertRaises(NotImplementedError): + with self.assertRaises(ValueError): with tempfile.TemporaryDirectory() as tempdir: check_hash(tempdir, "test_hash", "test_type") diff --git a/tests/test_check_missing_files.py b/tests/test_check_missing_files.py new file mode 100644 index 0000000000..ecd0b52a63 --- /dev/null +++ b/tests/test_check_missing_files.py @@ -0,0 +1,56 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from pathlib import Path + +import nibabel as nib +import numpy as np + +from monai.data import check_missing_files + + +class TestCheckMissingFiles(unittest.TestCase): + def test_content(self): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + + datalist = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": [os.path.join(tempdir, "test_label1.nii.gz"), os.path.join(tempdir, "test_extra1.nii.gz")], + }, + { + "image": Path(os.path.join(tempdir, "test_image2.nii.gz")), + "label": Path(os.path.join(tempdir, "test_label_missing.nii.gz")), + }, + ] + + missings = check_missing_files(datalist=datalist, keys=["image", "label"]) + self.assertEqual(len(missings), 1) + self.assertEqual(str(missings[0]), os.path.join(tempdir, "test_label_missing.nii.gz")) + + # test with missing key and relative path + datalist = [{"image": "test_image1.nii.gz", "label": "test_label_missing.nii.gz"}] + missings = check_missing_files( + datalist=datalist, keys=["image", "label", "test"], root_dir=tempdir, allow_missing_keys=True + ) + self.assertEqual(f"{missings[0]}", os.path.join(tempdir, "test_label_missing.nii.gz")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_classes_to_indices.py b/tests/test_classes_to_indices.py index 0ba3dd094a..1f39e0f480 100644 --- a/tests/test_classes_to_indices.py +++ b/tests/test_classes_to_indices.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,68 +11,80 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import ClassesToIndices +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - # test Argmax data - {"num_classes": 3, "image_threshold": 0.0}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - None, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS_CASES = [] +for p in TEST_NDARRAYS: + TESTS_CASES.append( + [ + # test Argmax data + {"num_classes": 3, "image_threshold": 0.0}, + p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + None, + [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])], + ] + ) -TEST_CASE_2 = [ - {"num_classes": 3, "image_threshold": 60}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS_CASES.append( + [ + {"num_classes": 3, "image_threshold": 60}, + p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + [p([0, 8]), p([1, 5, 6]), p([3])], + ] + ) -TEST_CASE_3 = [ - # test One-Hot data - {"image_threshold": 0.0}, - np.array( + TESTS_CASES.append( [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + # test One-Hot data + {"image_threshold": 0.0}, + p( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + None, + [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])], ] - ), - None, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + ) -TEST_CASE_4 = [ - {"num_classes": None, "image_threshold": 60}, - np.array( + TESTS_CASES.append( [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + {"num_classes": None, "image_threshold": 60}, + p( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + [p([0, 8]), p([1, 5, 6]), p([3])], ] - ), - np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + ) -TEST_CASE_5 = [ - # test output_shape - {"num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - None, - [np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 1], [1, 2], [2, 0]]), np.array([[0, 2], [1, 0], [2, 1]])], -] + TESTS_CASES.append( + [ + # test output_shape + {"num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, + p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + None, + [p([[0, 0], [1, 1], [2, 2]]), p([[0, 1], [1, 2], [2, 0]]), p([[0, 2], [1, 0], [2, 1]])], + ] + ) class TestClassesToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS_CASES) def test_value(self, input_args, label, image, expected_indices): indices = ClassesToIndices(**input_args)(label, image) for i, e in zip(indices, expected_indices): - np.testing.assert_allclose(i, e) + assert_allclose(i, e) if __name__ == "__main__": diff --git a/tests/test_classes_to_indicesd.py b/tests/test_classes_to_indicesd.py index 67fac95c8c..398620d304 100644 --- a/tests/test_classes_to_indicesd.py +++ b/tests/test_classes_to_indicesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,73 +11,91 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import ClassesToIndicesd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - # test Argmax data - {"keys": "label", "num_classes": 3, "image_threshold": 0.0}, - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS_CASES = [] +for p in TEST_NDARRAYS: + TESTS_CASES.append( + [ + # test Argmax data + {"keys": "label", "num_classes": 3, "image_threshold": 0.0}, + {"label": p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, + [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])], + ] + ) -TEST_CASE_2 = [ - {"keys": "label", "image_key": "image", "num_classes": 3, "image_threshold": 60}, - { - "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS_CASES.append( + [ + {"keys": "label", "image_key": "image", "num_classes": 3, "image_threshold": 60}, + { + "label": p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + "image": p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + }, + [p([0, 8]), p([1, 5, 6]), p([3])], + ] + ) -TEST_CASE_3 = [ - # test One-Hot data - {"keys": "label", "image_threshold": 0.0}, - { - "label": np.array( - [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ) - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + TESTS_CASES.append( + [ + # test One-Hot data + {"keys": "label", "image_threshold": 0.0}, + { + "label": p( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + }, + [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])], + ] + ) -TEST_CASE_4 = [ - {"keys": "label", "image_key": "image", "num_classes": None, "image_threshold": 60}, - { - "label": np.array( - [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS_CASES.append( + [ + {"keys": "label", "image_key": "image", "num_classes": None, "image_threshold": 60}, + { + "label": p( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + "image": p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + }, + [p([0, 8]), p([1, 5, 6]), p([3])], + ] + ) -TEST_CASE_5 = [ - # test output_shape - {"keys": "label", "indices_postfix": "cls", "num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, - [np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 1], [1, 2], [2, 0]]), np.array([[0, 2], [1, 0], [2, 1]])], -] + TESTS_CASES.append( + [ + # test output_shape + { + "keys": "label", + "indices_postfix": "cls", + "num_classes": 3, + "image_threshold": 0.0, + "output_shape": [3, 3], + }, + {"label": p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, + [p([[0, 0], [1, 1], [2, 2]]), p([[0, 1], [1, 2], [2, 0]]), p([[0, 2], [1, 0], [2, 1]])], + ] + ) class TestClassesToIndicesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS_CASES) def test_value(self, input_args, input_data, expected_indices): result = ClassesToIndicesd(**input_args)(input_data) key_postfix = input_args.get("indices_postfix") key_postfix = "_cls_indices" if key_postfix is None else key_postfix for i, e in zip(result["label" + key_postfix], expected_indices): - np.testing.assert_allclose(i, e) + assert_allclose(i, e) if __name__ == "__main__": diff --git a/tests/test_compose.py b/tests/test_compose.py index 28783cad23..e0913f59e1 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_compose_get_number_conversions.py b/tests/test_compose_get_number_conversions.py index eb10c7d5ef..fca5bc727d 100644 --- a/tests/test_compose_get_number_conversions.py +++ b/tests/test_compose_get_number_conversions.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py index 69a95e0c8b..1212715548 100644 --- a/tests/test_compute_confusion_matrix.py +++ b/tests/test_compute_confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -174,22 +174,10 @@ # 3. test metric with compute_sample, denominator may have zeros TEST_CASES_COMPUTE_SAMPLE_NAN = [] metric_names = ["tpr", "tnr"] -result_sum = [ - torch.tensor([0.5000]), - torch.tensor([4.8333]), -] -not_nans_sum = [ - torch.tensor([6]), - torch.tensor([8]), -] -result_sum_batch = [ - torch.tensor([0.0000, 0.5000, 0.0000]), - torch.tensor([1.6667, 2.5000, 0.6667]), -] -not_nans_sum_batch = [ - torch.tensor([3.0, 2.0, 1.0]), - torch.tensor([2.0, 3.0, 3.0]), -] +result_sum = [torch.tensor([0.5000]), torch.tensor([4.8333])] +not_nans_sum = [torch.tensor([6]), torch.tensor([8])] +result_sum_batch = [torch.tensor([0.0000, 0.5000, 0.0000]), torch.tensor([1.6667, 2.5000, 0.6667])] +not_nans_sum_batch = [torch.tensor([3.0, 2.0, 1.0]), torch.tensor([2.0, 3.0, 3.0])] for idx in range(2): for reduction in ["sum", "sum_batch"]: TEST_CASE = [data_nan.copy()] diff --git a/tests/test_compute_froc.py b/tests/test_compute_froc.py index 70de836dd9..d68f3f7fb4 100644 --- a/tests/test_compute_froc.py +++ b/tests/test_compute_froc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index f9e494efc7..ad66ed672a 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -168,10 +168,7 @@ ] TEST_CASE_10 = [ - { - "y": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))], - "y_pred": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))], - }, + {"y": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))], "y_pred": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))]}, [[1.0000, 1.0000], [1.0000, 1.0000]], ] diff --git a/tests/test_compute_regression_metrics.py b/tests/test_compute_regression_metrics.py index 126eab3f07..65ca73a4ec 100644 --- a/tests/test_compute_regression_metrics.py +++ b/tests/test_compute_regression_metrics.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index 1cec357b93..d06eed8740 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), torch.tensor([[0], [1], [0], [1]]), True, - True, + 2, "macro", 0.75, ] @@ -32,34 +32,20 @@ torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([[0], [1], [0], [1]]), False, - False, + None, "macro", 0.875, ] -TEST_CASE_3 = [ - torch.tensor([[0.5], [0.5], [0.2], [8.3]]), - torch.tensor([0, 1, 0, 1]), - False, - False, - "macro", - 0.875, -] +TEST_CASE_3 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.875] -TEST_CASE_4 = [ - torch.tensor([0.5, 0.5, 0.2, 8.3]), - torch.tensor([0, 1, 0, 1]), - False, - False, - "macro", - 0.875, -] +TEST_CASE_4 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.875] TEST_CASE_5 = [ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), torch.tensor([[0], [1], [0], [1]]), True, - True, + 2, "none", [0.75, 0.75], ] @@ -68,7 +54,7 @@ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), True, - False, + None, "weighted", 0.56667, ] @@ -77,7 +63,7 @@ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), True, - False, + None, "micro", 0.62, ] @@ -87,7 +73,7 @@ class TestComputeROCAUC(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, num_classes=2)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)]) y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0) y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0) result = compute_roc_auc(y_pred=y_pred, y=y, average=average) @@ -96,7 +82,7 @@ def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, num_classes=2)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)]) y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)] y = [y_trans(i) for i in decollate_batch(y)] metric = ROCAUCMetric(average=average) diff --git a/tests/test_concat_itemsd.py b/tests/test_concat_itemsd.py index 9c51e1efea..2f98738233 100644 --- a/tests/test_concat_itemsd.py +++ b/tests/test_concat_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_contrastive_loss.py b/tests/test_contrastive_loss.py new file mode 100644 index 0000000000..5dce860486 --- /dev/null +++ b/tests/test_contrastive_loss.py @@ -0,0 +1,79 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import ContrastiveLoss + +TEST_CASES = [ + [ # shape: (1, 4), (1, 4) + {"temperature": 0.5, "batch_size": 1}, + {"input": torch.tensor([[1.0, 1.0, 0.0, 0.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, + 0.0, + ], + [ # shape: (2, 4), (2, 4) + {"temperature": 0.5, "batch_size": 2}, + { + "input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + }, + 1.0986, + ], + [ # shape: (1, 4), (1, 4) + {"temperature": 0.5, "batch_size": 2}, + { + "input": torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 0.0, 0.0]]), + "target": torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + }, + 0.8719, + ], + [ # shape: (1, 4), (1, 4) + {"temperature": 0.5, "batch_size": 1}, + {"input": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, + 0.0, + ], + [ # shape: (1, 4), (1, 4) + {"temperature": 0.05, "batch_size": 1}, + {"input": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, + 0.0, + ], +] + + +class TestContrastiveLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_result(self, input_param, input_data, expected_val): + contrastiveloss = ContrastiveLoss(**input_param) + result = contrastiveloss(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + + def test_ill_shape(self): + loss = ContrastiveLoss(temperature=0.5, batch_size=1) + with self.assertRaisesRegex(ValueError, ""): + loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_with_cuda(self): + loss = ContrastiveLoss(temperature=0.5, batch_size=1) + i = torch.ones((1, 10)) + j = torch.ones((1, 10)) + if torch.cuda.is_available(): + i = i.cuda() + j = j.cuda() + output = loss(i, j) + np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py index a7fc64f950..1818f500f9 100644 --- a/tests/test_convert_data_type.py +++ b/tests/test_convert_data_type.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,6 +24,24 @@ for out_type in TEST_NDARRAYS: TESTS.append((in_type(np.array(1.0)), out_type(np.array(1.0)))) # type: ignore +TESTS_LIST: List[Tuple] = [] +for in_type in TEST_NDARRAYS + (int, float): + for out_type in TEST_NDARRAYS: + TESTS_LIST.append( + ([in_type(np.array(1.0)), in_type(np.array(1.0))], out_type(np.array([1.0, 1.0])), True) # type: ignore + ) + TESTS_LIST.append( + ( + [in_type(np.array(1.0)), in_type(np.array(1.0))], # type: ignore + [out_type(np.array(1.0)), out_type(np.array(1.0))], + False, + ) + ) + + +class TestTensor(torch.Tensor): + pass + class TestConvertDataType(unittest.TestCase): @parameterized.expand(TESTS) @@ -42,14 +60,23 @@ def test_convert_data_type(self, in_image, im_out): def test_neg_stride(self): _ = convert_data_type(np.array((1, 2))[::-1], torch.Tensor) - def test_ill_arg(self): - with self.assertRaises(ValueError): - convert_data_type(None, torch.Tensor) - convert_data_type(None, np.ndarray) + @parameterized.expand(TESTS_LIST) + def test_convert_list(self, in_image, im_out, wrap): + output_type = type(im_out) if wrap else type(im_out[0]) + converted_im, *_ = convert_data_type(in_image, output_type, wrap_sequence=wrap) + # check output is desired type + if not wrap: + converted_im = converted_im[0] + im_out = im_out[0] + self.assertEqual(type(converted_im), type(im_out)) + # check dtype is unchanged + if isinstance(in_type, (np.ndarray, torch.Tensor)): + self.assertEqual(converted_im.dtype, im_out.dtype) class TestConvertDataSame(unittest.TestCase): - @parameterized.expand(TESTS) + # add test for subclass of Tensor + @parameterized.expand(TESTS + [(np.array(1.0), TestTensor(np.array(1.0)))]) def test_convert_data_type(self, in_image, im_out): converted_im, orig_type, orig_device = convert_to_dst_type(in_image, im_out) # check input is unchanged @@ -57,7 +84,11 @@ def test_convert_data_type(self, in_image, im_out): if isinstance(in_image, torch.Tensor): self.assertEqual(in_image.device, orig_device) # check output is desired type - self.assertEqual(type(converted_im), type(im_out)) + if isinstance(im_out, torch.Tensor): + output_type = torch.Tensor + else: + output_type = np.ndarray + self.assertEqual(type(converted_im), output_type) # check dtype is unchanged if isinstance(in_type, (np.ndarray, torch.Tensor)): self.assertEqual(converted_im.dtype, im_out.dtype) diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py index 2f7a38e6e4..b606fee04f 100644 --- a/tests/test_convert_to_multi_channel.py +++ b/tests/test_convert_to_multi_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,34 +11,46 @@ import unittest -import numpy as np +import torch from parameterized import parameterized from monai.transforms import ConvertToMultiChannelBasedOnBratsClasses +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]]), - np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), -] - -TEST_CASE_2 = [ - np.array([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]), - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( [ - [[[0, 1], [1, 0]], [[0, 1], [1, 1]]], - [[[0, 1], [1, 1]], [[1, 1], [1, 1]]], - [[[0, 0], [0, 0]], [[0, 1], [1, 1]]], + [ + p([[0, 1, 2], [1, 2, 4], [0, 1, 4]]), + p( + [ + [[0, 1, 0], [1, 0, 1], [0, 1, 1]], + [[0, 1, 1], [1, 1, 1], [0, 1, 1]], + [[0, 0, 0], [0, 0, 1], [0, 0, 1]], + ] + ), + ], + [ + p([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]), + p( + [ + [[[0, 1], [1, 0]], [[0, 1], [1, 1]]], + [[[0, 1], [1, 1]], [[1, 1], [1, 1]]], + [[[0, 0], [0, 0]], [[0, 1], [1, 1]]], + ] + ), + ], ] - ), -] + ) class TestConvertToMultiChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_type_shape(self, data, expected_result): result = ConvertToMultiChannelBasedOnBratsClasses()(data) - np.testing.assert_equal(result, expected_result) - self.assertEqual(f"{result.dtype}", "bool") + assert_allclose(result, expected_result) + self.assertTrue(result.dtype in (bool, torch.bool)) if __name__ == "__main__": diff --git a/tests/test_convert_to_multi_channeld.py b/tests/test_convert_to_multi_channeld.py index 945e07e1cd..7525f8d7e2 100644 --- a/tests/test_convert_to_multi_channeld.py +++ b/tests/test_convert_to_multi_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_convert_to_torchscript.py b/tests/test_convert_to_torchscript.py new file mode 100644 index 0000000000..a1c1471463 --- /dev/null +++ b/tests/test_convert_to_torchscript.py @@ -0,0 +1,43 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch + +from monai.networks import convert_to_torchscript +from monai.networks.nets import UNet + + +class TestConvertToTorchScript(unittest.TestCase): + def test_value(self): + model = UNet( + spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 + ) + with tempfile.TemporaryDirectory() as tempdir: + torchscript_model = convert_to_torchscript( + model=model, + filename_or_obj=os.path.join(tempdir, "model.ts"), + extra_files={"foo.txt": b"bar"}, + verify=True, + inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)], + device="cuda" if torch.cuda.is_available() else "cpu", + rtol=1e-3, + atol=1e-4, + optimize=None, + ) + self.assertTrue(isinstance(torchscript_model, torch.nn.Module)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_convolutions.py b/tests/test_convolutions.py index 97c01dd659..5c1681ad71 100644 --- a/tests/test_convolutions.py +++ b/tests/test_convolutions.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_copy_itemsd.py b/tests/test_copy_itemsd.py index a0a1ad412b..375991cbcc 100644 --- a/tests/test_copy_itemsd.py +++ b/tests/test_copy_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_copy_model_state.py b/tests/test_copy_model_state.py index 6330a1918a..bc7b116e1f 100644 --- a/tests/test_copy_model_state.py +++ b/tests/test_copy_model_state.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,7 +21,7 @@ class _TestModelOne(torch.nn.Module): def __init__(self, n_n, n_m, n_class): - super(_TestModelOne, self).__init__() + super().__init__() self.layer = torch.nn.Linear(n_n, n_m) self.class_layer = torch.nn.Linear(n_m, n_class) @@ -33,7 +33,7 @@ def forward(self, x): class _TestModelTwo(torch.nn.Module): def __init__(self, n_n, n_m, n_d, n_class): - super(_TestModelTwo, self).__init__() + super().__init__() self.layer = torch.nn.Linear(n_n, n_m) self.layer_1 = torch.nn.Linear(n_m, n_d) self.class_layer = torch.nn.Linear(n_d, n_class) diff --git a/tests/test_correct_crop_centers.py b/tests/test_correct_crop_centers.py new file mode 100644 index 0000000000..50478c7d5d --- /dev/null +++ b/tests/test_correct_crop_centers.py @@ -0,0 +1,34 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms.utils import correct_crop_centers +from tests.utils import assert_allclose + +TESTS = [[[1, 5, 0], [2, 2, 2], [10, 10, 10]], [[4, 4, 4], [2, 2, 1], [10, 10, 10]]] + + +class TestCorrectCropCenters(unittest.TestCase): + @parameterized.expand(TESTS) + def test_torch(self, spatial_size, centers, label_spatial_shape): + result1 = correct_crop_centers(centers, spatial_size, label_spatial_shape) + centers = [torch.tensor(i) for i in centers] + result2 = correct_crop_centers(centers, spatial_size, label_spatial_shape) + assert_allclose(result1, result2) + self.assertEqual(type(result1[0]), type(result2[0])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_create_cross_validation_datalist.py b/tests/test_create_cross_validation_datalist.py new file mode 100644 index 0000000000..3a3e8481ea --- /dev/null +++ b/tests/test_create_cross_validation_datalist.py @@ -0,0 +1,56 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from pathlib import Path + +from monai.data import create_cross_validation_datalist, load_decathlon_datalist + + +class TestCreateCrossValidationDatalist(unittest.TestCase): + def test_content(self): + with tempfile.TemporaryDirectory() as tempdir: + datalist = [] + for i in range(5): + image = os.path.join(tempdir, f"test_image{i}.nii.gz") + label = os.path.join(tempdir, f"test_label{i}.nii.gz") + Path(image).touch() + Path(label).touch() + datalist.append({"image": image, "label": label}) + + filename = os.path.join(tempdir, "test_datalist.json") + result = create_cross_validation_datalist( + datalist=datalist, + nfolds=5, + train_folds=[0, 1, 2, 3], + val_folds=4, + train_key="test_train", + val_key="test_val", + filename=Path(filename), + shuffle=True, + seed=123, + check_missing=True, + keys=["image", "label"], + root_dir=None, + allow_missing_keys=False, + raise_error=True, + ) + + loaded = load_decathlon_datalist(filename, data_list_key="test_train") + for r, l in zip(result["test_train"], loaded): + self.assertEqual(r["image"], l["image"]) + self.assertEqual(r["label"], l["label"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index 0c0e52e04a..d70db45468 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from monai.transforms import ( create_control_grid, @@ -21,6 +22,7 @@ create_shear, create_translate, ) +from tests.utils import assert_allclose, is_tf32_env class TestCreateGrid(unittest.TestCase): @@ -32,50 +34,47 @@ def test_create_grid(self): with self.assertRaisesRegex(TypeError, ""): create_grid((1, 1), spacing=2.0) - g = create_grid((1, 1)) - expected = np.array([[[0.0]], [[0.0]], [[1.0]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1),), np.array([[[0.0]], [[0.0]], [[1.0]]])) - g = create_grid((1, 1), homogeneous=False) - expected = np.array([[[0.0]], [[0.0]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1), None, False), np.array([[[0.0]], [[0.0]]])) - g = create_grid((1, 1), spacing=(1.2, 1.3)) - expected = np.array([[[0.0]], [[0.0]], [[1.0]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1), (1.2, 1.3)), np.array([[[0.0]], [[0.0]], [[1.0]]])) - g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0)) - expected = np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[1.0]]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1, 1), (1.2, 1.3, 1.0)), np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[1.0]]]])) - g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), homogeneous=False) - expected = np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]]]) - np.testing.assert_allclose(g, expected) + test_assert(create_grid, ((1, 1, 1), (1.2, 1.3, 1.0), False), np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]]])) g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), dtype=np.int32) np.testing.assert_equal(g.dtype, np.int32) - g = create_grid((2, 2, 2)) - expected = np.array( - [ - [[[-0.5, -0.5], [-0.5, -0.5]], [[0.5, 0.5], [0.5, 0.5]]], - [[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, -0.5], [0.5, 0.5]]], - [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], - [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], - ] + g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), dtype=torch.float64, backend="torch") + np.testing.assert_equal(g.dtype, torch.float64) + + test_assert( + create_grid, + ((2, 2, 2),), + np.array( + [ + [[[-0.5, -0.5], [-0.5, -0.5]], [[0.5, 0.5], [0.5, 0.5]]], + [[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, -0.5], [0.5, 0.5]]], + [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], + [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_grid((2, 2, 2), spacing=(1.2, 1.3, 1.0)) - expected = np.array( - [ - [[[-0.6, -0.6], [-0.6, -0.6]], [[0.6, 0.6], [0.6, 0.6]]], - [[[-0.65, -0.65], [0.65, 0.65]], [[-0.65, -0.65], [0.65, 0.65]]], - [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], - [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], - ] + test_assert( + create_grid, + ((2, 2, 2), (1.2, 1.3, 1.0)), + np.array( + [ + [[[-0.6, -0.6], [-0.6, -0.6]], [[0.6, 0.6], [0.6, 0.6]]], + [[[-0.65, -0.65], [0.65, 0.65]], [[-0.65, -0.65], [0.65, 0.65]]], + [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]], + [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + ] + ), ) - np.testing.assert_allclose(g, expected) def test_create_control_grid(self): with self.assertRaisesRegex(TypeError, ""): @@ -83,72 +82,87 @@ def test_create_control_grid(self): with self.assertRaisesRegex(TypeError, ""): create_control_grid((1, 1), 2.0) - g = create_control_grid((1.0, 1.0), (1.0, 1.0)) - expected = np.array( - [ - [[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], - [[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((1.0, 1.0), (1.0, 1.0)), + np.array( + [ + [[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], + [[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((1.0, 1.0), (2.0, 2.0)) - expected = np.array( - [ - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((1.0, 1.0), (2.0, 2.0)), + np.array( + [ + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((2.0, 2.0), (1.0, 1.0)) - expected = np.array( - [ - [[-1.5, -1.5, -1.5, -1.5], [-0.5, -0.5, -0.5, -0.5], [0.5, 0.5, 0.5, 0.5], [1.5, 1.5, 1.5, 1.5]], - [[-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5]], - [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((2.0, 2.0), (1.0, 1.0)), + np.array( + [ + [[-1.5, -1.5, -1.5, -1.5], [-0.5, -0.5, -0.5, -0.5], [0.5, 0.5, 0.5, 0.5], [1.5, 1.5, 1.5, 1.5]], + [[-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5]], + [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((2.0, 2.0), (2.0, 2.0)) - expected = np.array( - [ - [[-3.0, -3.0, -3.0, -3.0], [-1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]], - [[-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0]], - [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], - ] + test_assert( + create_control_grid, + ((2.0, 2.0), (2.0, 2.0)), + np.array( + [ + [[-3.0, -3.0, -3.0, -3.0], [-1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]], + [[-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0]], + [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], + ] + ), ) - np.testing.assert_allclose(g, expected) - g = create_control_grid((1.0, 1.0, 1.0), (2.0, 2.0, 2.0), homogeneous=False) - expected = np.array( - [ - [ - [[-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], - ], - [ - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], - ], + test_assert( + create_control_grid, + ((1.0, 1.0, 1.0), (2.0, 2.0, 2.0), False), + np.array( [ - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], - ], - ] + [ + [[-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], + ], + [ + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], + ], + [ + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]], + ], + ] + ), ) - np.testing.assert_allclose(g, expected) def test_assert(func, params, expected): - m = func(*params) - np.testing.assert_allclose(m, expected, atol=1e-7) + gpu_test = ("torch_gpu",) if torch.cuda.is_available() else () + for b in ("torch", "numpy") + gpu_test: + if b == "torch_gpu": + m = func(*params, device="cuda:0", backend="torch") + else: + m = func(*params, backend=b) + assert_allclose(m, expected, type_test=False, rtol=1e-2 if is_tf32_env() else 1e-5, atol=1e-5) class TestCreateAffine(unittest.TestCase): diff --git a/tests/test_crf_cpu.py b/tests/test_crf_cpu.py index ed1860943f..46da3298bc 100644 --- a/tests/test_crf_cpu.py +++ b/tests/test_crf_cpu.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -55,12 +55,12 @@ # Batch 0 [ # Channel 0 - [1, 1, 1, 0.5, 0], + [1, 1, 1, 0.5, 0] ], # Batch 1 [ # Channel 0 - [1, 1, 0.5, 0, 0], + [1, 1, 0.5, 0, 0] ], ], # Expected @@ -117,12 +117,12 @@ # Batch 0 [ # Channel 0 - [1, 1, 1, 0.5, 0], + [1, 1, 1, 0.5, 0] ], # Batch 1 [ # Channel 0 - [1, 1, 0.5, 0, 0], + [1, 1, 0.5, 0, 0] ], ], # Expected @@ -185,7 +185,7 @@ [1.0, 1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0], ], - ], + ] ], # Features [ @@ -207,7 +207,7 @@ [0.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0], ], - ], + ] ], # Expected [ @@ -237,7 +237,7 @@ [0.688815, 0.687855, 0.687076, 0.228579, 0.227552], [0.687434, 0.686453, 0.445019, 0.229047, 0.227588], ], - ], + ] ], ], [ @@ -344,7 +344,7 @@ [0.0, 0.0, 0.0, 1.0, 1.0], ], ], - ], + ] ], # Features [ @@ -392,8 +392,8 @@ [0.0, 0.0, 1.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0, 1.0], ], - ], - ], + ] + ] ], # Expected [ @@ -485,7 +485,7 @@ [0.500533, 0.500745, 0.553344, 0.771576, 0.772222], ], ], - ], + ] ], ], ] diff --git a/tests/test_crf_cuda.py b/tests/test_crf_cuda.py index adf8c440c0..ca25fe2de9 100644 --- a/tests/test_crf_cuda.py +++ b/tests/test_crf_cuda.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -55,12 +55,12 @@ # Batch 0 [ # Channel 0 - [1, 1, 1, 0.5, 0], + [1, 1, 1, 0.5, 0] ], # Batch 1 [ # Channel 0 - [1, 1, 0.5, 0, 0], + [1, 1, 0.5, 0, 0] ], ], # Expected @@ -117,12 +117,12 @@ # Batch 0 [ # Channel 0 - [1, 1, 1, 0.5, 0], + [1, 1, 1, 0.5, 0] ], # Batch 1 [ # Channel 0 - [1, 1, 0.5, 0, 0], + [1, 1, 0.5, 0, 0] ], ], # Expected @@ -185,7 +185,7 @@ [0.5, 1.0, 0.5, 0.0, 0.0], [1.0, 0.5, 0.0, 0.0, 0.0], ], - ], + ] ], # Features [ @@ -207,7 +207,7 @@ [0.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0], ], - ], + ] ], # Expected [ @@ -237,7 +237,7 @@ [0.492602, 0.609557, 0.480947, 0.161909, 0.161476], [0.610678, 0.480516, 0.352479, 0.159380, 0.158274], ], - ], + ] ], ], [ @@ -344,7 +344,7 @@ [0.0, 0.0, 0.0, 1.0, 1.0], ], ], - ], + ] ], # Features [ @@ -392,8 +392,8 @@ [0.0, 0.0, 1.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0, 1.0], ], - ], - ], + ] + ] ], # Expected [ @@ -485,7 +485,7 @@ [0.500663, 0.500887, 0.556332, 0.773597, 0.775210], ], ], - ], + ] ], ], ] diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 71e488cac8..a9c473a4c6 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,60 +12,79 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForeground +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_2 = [ - {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[3]]]), -] - -TEST_CASE_3 = [ - {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_4 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_5 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_6 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), -] - -TEST_CASE_7 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10, "constant_values": 2}, - np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.zeros((1, 0, 0)), -] +TEST_COORDS, TESTS = [], [] + +for p in TEST_NDARRAYS: + TEST_COORDS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[3]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, + p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p(np.zeros((1, 0, 0), dtype=np.int64)), + ] + ) class TestCropForeground(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand(TEST_COORDS + TESTS) def test_value(self, argments, image, expected_data): result = CropForeground(**argments)(image) - np.testing.assert_allclose(result, expected_data) + torch.testing.assert_allclose(result, expected_data, rtol=1e-7, atol=0) - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TEST_COORDS) def test_return_coords(self, argments, image, _): argments["return_coords"] = True _, start_coord, end_coord = CropForeground(**argments)(image) diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index efe6b65b4b..7f1842197c 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,85 +12,128 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForegroundd -from monai.utils import NumpyPadMode +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - { - "keys": ["img", "label"], - "source_key": "label", - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - "mode": "constant", - "constant_values": 2, - }, - { - "img": np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]), - "label": np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]), - }, - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] +TEST_POSITION, TESTS = [], [] +for p in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[3]]]), -] - -TEST_CASE_3 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_4 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_5 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_6 = [ - { - "keys": ["img", "seg"], - "source_key": "img", - "select_fn": lambda x: x > 0, - "channel_indices": 0, - "margin": 0, - "k_divisible": [4, 6], - "mode": ["edge", NumpyPadMode.CONSTANT], - }, - { - "img": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]]), - "seg": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]]), - }, - np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]]), -] + TEST_POSITION.append( + [ + { + "keys": ["img", "label"], + "source_key": "label", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + }, + { + "img": p( + np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]) + ), + "label": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]) + ), + }, + p(np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]])), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[3]]])), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]])), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]])), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 1], + }, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 0, + "channel_indices": 0, + "margin": 0, + "k_divisible": [4, 6], + "mode": "edge", + }, + { + "img": p( + np.array( + [[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]], + dtype=np.float32, + ) + ) + }, + p(np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]])), + ] + ) class TestCropForegroundd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_value(self, argments, image, expected_data): - result = CropForegroundd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data) + @parameterized.expand(TEST_POSITION + TESTS) + def test_value(self, argments, input_data, expected_data): + result = CropForegroundd(**argments)(input_data) + r, i = result["img"], input_data["img"] + self.assertEqual(type(r), type(i)) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, i.device) + assert_allclose(r, expected_data) - @parameterized.expand([TEST_CASE_1]) - def test_foreground_position(self, argments, image, _): - result = CropForegroundd(**argments)(image) + @parameterized.expand(TEST_POSITION) + def test_foreground_position(self, argments, input_data, _): + result = CropForegroundd(**argments)(input_data) np.testing.assert_allclose(result["foreground_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["foreground_end_coord"], np.array([4, 4])) argments["start_coord_key"] = "test_start_coord" argments["end_coord_key"] = "test_end_coord" - result = CropForegroundd(**argments)(image) + result = CropForegroundd(**argments)(input_data) np.testing.assert_allclose(result["test_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["test_end_coord"], np.array([4, 4])) diff --git a/tests/test_cross_validation.py b/tests/test_cross_validation.py index 33e10a6a40..017fd0243c 100644 --- a/tests/test_cross_validation.py +++ b/tests/test_cross_validation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,7 +22,7 @@ class TestCrossValidation(unittest.TestCase): @skip_if_quick def test_values(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") - transform = Compose( + train_transform = Compose( [ LoadImaged(keys=["image", "label"]), AddChanneld(keys=["image", "label"]), @@ -30,6 +30,7 @@ def test_values(self): ToTensord(keys=["image", "label"]), ] ) + val_transform = LoadImaged(keys=["image", "label"]) def _test_dataset(dataset): self.assertEqual(len(dataset), 52) @@ -45,7 +46,7 @@ def _test_dataset(dataset): root_dir=testing_dir, task="Task04_Hippocampus", section="validation", - transform=transform, + transform=train_transform, download=True, ) @@ -60,13 +61,14 @@ def _test_dataset(dataset): _test_dataset(data) - # test training data for fold 0 of 5 splits + # test training data for fold [1, 2, 3, 4] of 5 splits data = cvdataset.get_dataset(folds=[1, 2, 3, 4]) self.assertTupleEqual(data[0]["image"].shape, (1, 35, 52, 33)) self.assertEqual(len(data), 208) # test train / validation for fold 4 of 5 splits - data = cvdataset.get_dataset(folds=[4]) - self.assertTupleEqual(data[0]["image"].shape, (1, 38, 53, 30)) + data = cvdataset.get_dataset(folds=[4], transform=val_transform, download=False) + # val_transform doesn't add the channel dim to shape + self.assertTupleEqual(data[0]["image"].shape, (38, 53, 30)) self.assertEqual(len(data), 52) data = cvdataset.get_dataset(folds=[0, 1, 2, 3]) self.assertTupleEqual(data[0]["image"].shape, (1, 34, 49, 41)) diff --git a/tests/test_csv_dataset.py b/tests/test_csv_dataset.py index d187f4e64d..f288ac4b95 100644 --- a/tests/test_csv_dataset.py +++ b/tests/test_csv_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import unittest import numpy as np +import pandas as pd from monai.data import CSVDataset from monai.transforms import ToNumpyd @@ -57,6 +58,7 @@ def prepare_csv_file(data, filepath): filepath1 = os.path.join(tempdir, "test_data1.csv") filepath2 = os.path.join(tempdir, "test_data2.csv") filepath3 = os.path.join(tempdir, "test_data3.csv") + filepaths = [filepath1, filepath2, filepath3] prepare_csv_file(test_data1, filepath1) prepare_csv_file(test_data2, filepath2) prepare_csv_file(test_data3, filepath3) @@ -76,7 +78,7 @@ def prepare_csv_file(data, filepath): ) # test multiple CSV files, join tables with kwargs - dataset = CSVDataset([filepath1, filepath2, filepath3], on="subject_id") + dataset = CSVDataset(filepaths, on="subject_id") self.assertDictEqual( {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in dataset[3].items()}, { @@ -102,7 +104,7 @@ def prepare_csv_file(data, filepath): # test selected rows and columns dataset = CSVDataset( - filename=[filepath1, filepath2, filepath3], + src=filepaths, row_indices=[[0, 2], 3], # load row: 0, 1, 3 col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"], ) @@ -120,7 +122,7 @@ def prepare_csv_file(data, filepath): # test group columns dataset = CSVDataset( - filename=[filepath1, filepath2, filepath3], + src=filepaths, row_indices=[1, 3], # load row: 1, 3 col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"], col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]}, @@ -133,9 +135,7 @@ def prepare_csv_file(data, filepath): # test transform dataset = CSVDataset( - filename=[filepath1, filepath2, filepath3], - col_groups={"ehr": [f"ehr_{i}" for i in range(5)]}, - transform=ToNumpyd(keys="ehr"), + src=filepaths, col_groups={"ehr": [f"ehr_{i}" for i in range(5)]}, transform=ToNumpyd(keys="ehr") ) self.assertEqual(len(dataset), 5) expected = [ @@ -151,7 +151,7 @@ def prepare_csv_file(data, filepath): # test default values and dtype dataset = CSVDataset( - filename=[filepath1, filepath2, filepath3], + src=filepaths, col_names=["subject_id", "image", "ehr_1", "ehr_9", "meta_1"], col_types={"image": {"type": str, "default": "No image"}, "ehr_1": {"type": int, "default": 0}}, how="outer", # generate NaN values in this merge mode @@ -161,6 +161,29 @@ def prepare_csv_file(data, filepath): self.assertEqual(type(dataset[-1]["ehr_1"]), int) np.testing.assert_allclose(dataset[-1]["ehr_9"], 3.3537, rtol=1e-2) + # test pre-loaded DataFrame + df = pd.read_csv(filepath1) + dataset = CSVDataset(src=df) + self.assertDictEqual( + {k: round(v, 4) if not isinstance(v, str) else v for k, v in dataset[2].items()}, + { + "subject_id": "s000002", + "label": 4, + "image": "./imgs/s000002.png", + "ehr_0": 3.7725, + "ehr_1": 4.2118, + "ehr_2": 4.6353, + }, + ) + + # test pre-loaded multiple DataFrames, join tables with kwargs + dfs = [pd.read_csv(i) for i in filepaths] + dataset = CSVDataset(src=dfs, on="subject_id") + self.assertEqual(dataset[3]["subject_id"], "s000003") + self.assertEqual(dataset[3]["label"], 1) + self.assertEqual(round(dataset[3]["ehr_0"], 4), 3.3333) + self.assertEqual(dataset[3]["meta_0"], False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_csv_iterable_dataset.py b/tests/test_csv_iterable_dataset.py index c7a3f31dc6..d6b84074ba 100644 --- a/tests/test_csv_iterable_dataset.py +++ b/tests/test_csv_iterable_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ import unittest import numpy as np +import pandas as pd from monai.data import CSVIterableDataset, DataLoader from monai.transforms import ToNumpyd @@ -58,14 +59,17 @@ def prepare_csv_file(data, filepath): filepath1 = os.path.join(tempdir, "test_data1.csv") filepath2 = os.path.join(tempdir, "test_data2.csv") filepath3 = os.path.join(tempdir, "test_data3.csv") + filepaths = [filepath1, filepath2, filepath3] prepare_csv_file(test_data1, filepath1) prepare_csv_file(test_data2, filepath2) prepare_csv_file(test_data3, filepath3) # test single CSV file - dataset = CSVIterableDataset(filepath1) - for i, item in enumerate(dataset): - if i == 2: + dataset = CSVIterableDataset(filepath1, shuffle=False) + count = 0 + for item in dataset: + count += 1 + if count == 3: self.assertDictEqual( {k: round(v, 4) if not isinstance(v, str) else v for k, v in item.items()}, { @@ -78,16 +82,25 @@ def prepare_csv_file(data, filepath): }, ) break + self.assertEqual(count, 3) + dataset.close() + # test reset iterables - dataset.reset(filename=filepath3) + dataset.reset(src=filepath3) + count = 0 for i, item in enumerate(dataset): - if i == 3: + count += 1 + if i == 4: self.assertEqual(item["meta_0"], False) + self.assertEqual(count, 5) + dataset.close() # test multiple CSV files, join tables with kwargs - dataset = CSVIterableDataset([filepath1, filepath2, filepath3], on="subject_id") - for i, item in enumerate(dataset): - if i == 3: + dataset = CSVIterableDataset(filepaths, on="subject_id", shuffle=False) + count = 0 + for item in dataset: + count += 1 + if count == 4: self.assertDictEqual( {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in item.items()}, { @@ -110,15 +123,17 @@ def prepare_csv_file(data, filepath): "meta_2": True, }, ) + self.assertEqual(count, 5) + dataset.close() # test selected columns and chunk size dataset = CSVIterableDataset( - filename=[filepath1, filepath2, filepath3], - chunksize=2, - col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"], + src=filepaths, chunksize=2, col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"], shuffle=False ) - for i, item in enumerate(dataset): - if i == 3: + count = 0 + for item in dataset: + count += 1 + if count == 4: self.assertDictEqual( {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in item.items()}, { @@ -129,49 +144,106 @@ def prepare_csv_file(data, filepath): "meta_1": False, }, ) + self.assertEqual(count, 5) + dataset.close() # test group columns dataset = CSVIterableDataset( - filename=[filepath1, filepath2, filepath3], + src=filepaths, col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"], col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]}, + shuffle=False, ) - for i, item in enumerate(dataset): - if i == 3: + count = 0 + for item in dataset: + count += 1 + if count == 4: np.testing.assert_allclose( [round(i, 4) for i in item["ehr"]], [3.3333, 3.2353, 3.4000, 3.1647, 3.0863, 3.7255, 3.6980, 3.6980, 3.7020, 3.3098, 3.7294], ) np.testing.assert_allclose(item["meta12"], [False, True]) + self.assertEqual(count, 5) + dataset.close() # test transform dataset = CSVIterableDataset( - filename=[filepath1, filepath2, filepath3], + chunksize=2, + buffer_size=4, + src=filepaths, col_groups={"ehr": [f"ehr_{i}" for i in range(5)]}, transform=ToNumpyd(keys="ehr"), + shuffle=True, + seed=123, ) expected = [ - [2.0078, 2.2902, 2.0549, 3.0196, 3.8078], [6.8392, 6.4745, 5.8627, 5.1922, 5.2745], - [3.7725, 4.2118, 4.6353, 5.2980, 9.5451], [3.3333, 3.2353, 3.4000, 3.1647, 3.0863], + [3.7725, 4.2118, 4.6353, 5.298, 9.5451], [6.4275, 6.2549, 5.9765, 6.2627, 7.7176], + [2.0078, 2.2902, 2.0549, 3.0196, 3.8078], ] + count = 0 for item, exp in zip(dataset, expected): + count += 1 self.assertTrue(isinstance(item["ehr"], np.ndarray)) np.testing.assert_allclose(np.around(item["ehr"], 4), exp) + self.assertEqual(count, 5) + dataset.close() # test multiple processes loading - dataset = CSVIterableDataset(filepath1, transform=ToNumpyd(keys="label")) + dataset = CSVIterableDataset(filepath1, transform=ToNumpyd(keys="label"), shuffle=False) # set num workers = 0 for mac / win num_workers = 2 if sys.platform == "linux" else 0 dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=2) + count = 0 for item in dataloader: + count += 1 # test the last item which only has 1 data if len(item) == 1: self.assertListEqual(item["subject_id"], ["s000002"]) np.testing.assert_allclose(item["label"], [4]) self.assertListEqual(item["image"], ["./imgs/s000002.png"]) + self.assertEqual(count, 3) + dataset.close() + + # test iterable stream + iters = pd.read_csv(filepath1, chunksize=1000) + dataset = CSVIterableDataset(src=iters, shuffle=False) + count = 0 + for item in dataset: + count += 1 + if count == 3: + self.assertDictEqual( + {k: round(v, 4) if not isinstance(v, str) else v for k, v in item.items()}, + { + "subject_id": "s000002", + "label": 4, + "image": "./imgs/s000002.png", + "ehr_0": 3.7725, + "ehr_1": 4.2118, + "ehr_2": 4.6353, + }, + ) + break + self.assertEqual(count, 3) + dataset.close() + + # test multiple iterable streams, join tables with kwargs + iters = [pd.read_csv(i, chunksize=1000) for i in filepaths] + dataset = CSVIterableDataset(src=iters, on="subject_id", shuffle=False) + count = 0 + for item in dataset: + count += 1 + if count == 4: + self.assertEqual(item["subject_id"], "s000003") + self.assertEqual(item["label"], 1) + self.assertEqual(round(item["ehr_0"], 4), 3.3333) + self.assertEqual(item["meta_0"], False) + self.assertEqual(count, 5) + # manually close the pre-loaded iterables instead of `dataset.close()` + for i in iters: + i.close() if __name__ == "__main__": diff --git a/tests/test_csv_saver.py b/tests/test_csv_saver.py index 6dd0159322..6b60de4b0c 100644 --- a/tests/test_csv_saver.py +++ b/tests/test_csv_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,7 +29,7 @@ def test_saved_content(self): saver.finalize() filepath = os.path.join(tempdir, "predictions.csv") self.assertTrue(os.path.exists(filepath)) - with open(filepath, "r") as f: + with open(filepath) as f: reader = csv.reader(f) i = 0 for row in reader: diff --git a/tests/test_cucim_dict_transform.py b/tests/test_cucim_dict_transform.py new file mode 100644 index 0000000000..f8b54c3147 --- /dev/null +++ b/tests/test_cucim_dict_transform.py @@ -0,0 +1,141 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import CuCIMd +from monai.utils import optional_import, set_determinism +from tests.utils import HAS_CUPY, skip_if_no_cuda + +_, has_cut = optional_import("cucim.core.operations.expose.transform") +cp, _ = optional_import("cupy") + +set_determinism(seed=0) + +TEST_CASE_COLOR_JITTER_1 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32), + np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_COLOR_JITTER_2 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8), + np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8), +] + +TEST_CASE_FLIP_1 = [ + {"name": "image_flip", "spatial_axis": -1}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32), +] + + +TEST_CASE_ROTATE_1 = [ + {"name": "image_rotate_90", "k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_SCALE_INTENSITY_1 = [ + {"name": "scale_intensity_range", "a_min": 0.0, "a_max": 4.0, "b_min": 0.0, "b_max": 1.0, "clip": False}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32), +] + +TEST_CASE_ZOOM_1 = [ + {"name": "zoom", "zoom_factor": (0.5, 0.5)}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + + +@skip_if_no_cuda +@unittest.skipUnless(HAS_CUPY, "CuPy is required.") +@unittest.skipUnless(has_cut, "cuCIM transforms are required.") +class TestCuCIMDict(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_numpy_single(self, params, input, expected): + input = {"image": input} + output = CuCIMd(keys="image", **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_numpy_batch(self, params, input, expected): + input = {"image": input[cp.newaxis, ...]} + expected = expected[cp.newaxis, ...] + output = CuCIMd(keys="image", **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_cupy_single(self, params, input, expected): + input = {"image": cp.asarray(input)} + expected = cp.asarray(expected) + output = CuCIMd(keys="image", **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_cupy_batch(self, params, input, expected): + input = {"image": cp.asarray(input)[cp.newaxis, ...]} + expected = cp.asarray(expected)[cp.newaxis, ...] + output = CuCIMd(keys="image", **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cucim_transform.py b/tests/test_cucim_transform.py new file mode 100644 index 0000000000..2bf9791bce --- /dev/null +++ b/tests/test_cucim_transform.py @@ -0,0 +1,140 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import CuCIM +from monai.utils import optional_import, set_determinism +from tests.utils import HAS_CUPY, skip_if_no_cuda + +_, has_cut = optional_import("cucim.core.operations.expose.transform") +cp, _ = optional_import("cupy") + +set_determinism(seed=0) + +TEST_CASE_COLOR_JITTER_1 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32), + np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_COLOR_JITTER_2 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8), + np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8), +] + +TEST_CASE_FLIP_1 = [ + {"name": "image_flip", "spatial_axis": -1}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32), +] + + +TEST_CASE_ROTATE_1 = [ + {"name": "image_rotate_90", "k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_SCALE_INTENSITY_1 = [ + {"name": "scale_intensity_range", "a_min": 0.0, "a_max": 4.0, "b_min": 0.0, "b_max": 1.0, "clip": False}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32), +] + +TEST_CASE_ZOOM_1 = [ + {"name": "zoom", "zoom_factor": (0.5, 0.5)}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + + +@skip_if_no_cuda +@unittest.skipUnless(HAS_CUPY, "CuPy is required.") +@unittest.skipUnless(has_cut, "cuCIM transforms are required.") +class TestCuCIM(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_numpy_single(self, params, input, expected): + output = CuCIM(**params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_numpy_batch(self, params, input, expected): + input = input[cp.newaxis, ...] + expected = expected[cp.newaxis, ...] + output = CuCIM(**params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_cupy_single(self, params, input, expected): + input = cp.asarray(input) + expected = cp.asarray(expected) + output = CuCIM(**params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_COLOR_JITTER_2, + TEST_CASE_FLIP_1, + TEST_CASE_ROTATE_1, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + ] + ) + def test_tramsforms_cupy_batch(self, params, input, expected): + input = cp.asarray(input)[cp.newaxis, ...] + expected = cp.asarray(expected)[cp.newaxis, ...] + output = CuCIM(**params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cuimage_reader.py b/tests/test_cuimage_reader.py deleted file mode 100644 index 2cbfaec113..0000000000 --- a/tests/test_cuimage_reader.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import unittest -from unittest import skipUnless - -import numpy as np -from numpy.testing import assert_array_equal -from parameterized import parameterized - -from monai.apps.utils import download_url -from monai.data.image_reader import WSIReader -from monai.utils import optional_import - -_, has_cim = optional_import("cucim") -PILImage, has_pil = optional_import("PIL.Image") - -FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" -FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + os.path.basename(FILE_URL)) - -HEIGHT = 32914 -WIDTH = 46000 - -TEST_CASE_0 = [FILE_PATH, (3, HEIGHT, WIDTH)] - -TEST_CASE_1 = [ - FILE_PATH, - {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0}, - np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), -] - -TEST_CASE_2 = [ - FILE_PATH, - {"location": (0, 0), "size": (2, 1), "level": 2}, - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), -] - -TEST_CASE_3 = [ - FILE_PATH, - { - "location": (0, 0), - "size": (8, 8), - "level": 2, - "grid_shape": (2, 1), - "patch_size": 2, - }, - np.array( - [ - [[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]], - [[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]], - ] - ), -] - -TEST_CASE_4 = [ - FILE_PATH, - { - "location": (0, 0), - "size": (8, 8), - "level": 2, - "grid_shape": (2, 1), - "patch_size": 1, - }, - np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), -] - -TEST_CASE_RGB_0 = [ - np.ones((3, 2, 2), dtype=np.uint8), # CHW -] - -TEST_CASE_RGB_1 = [ - np.ones((3, 100, 100), dtype=np.uint8), # CHW -] - - -class TestCuCIMReader(unittest.TestCase): - @skipUnless(has_cim, "Requires CuCIM") - def setUp(self): - download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") - - @parameterized.expand([TEST_CASE_0]) - def test_read_whole_image(self, file_path, expected_shape): - reader = WSIReader("cuCIM") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj)[0] - self.assertTupleEqual(img.shape, expected_shape) - - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_read_region(self, file_path, patch_info, expected_img): - reader = WSIReader("cuCIM") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) - - @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) - def test_read_patches(self, file_path, patch_info, expected_img): - reader = WSIReader("cuCIM") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) - - @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) - @skipUnless(has_pil, "Requires PIL") - def test_read_rgba(self, img_expected): - image = {} - reader = WSIReader("cuCIM") - for mode in ["RGB", "RGBA"]: - file_path = self.create_rgba_image(img_expected, "temp_cu_tiff_image", mode=mode) - img_obj = reader.read(file_path) - image[mode], _ = reader.get_data(img_obj) - - self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) - self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) - - def create_rgba_image(self, array: np.ndarray, filename_prefix: str, mode: str): - file_path = os.path.join(os.path.dirname(__file__), "testing_data", f"{filename_prefix}_{mode}.tiff") - - if mode == "RGBA": - array = np.concatenate([array, 255 * np.ones_like(array[0])[np.newaxis]]).astype(np.uint8) - - img_rgb = array.transpose(1, 2, 0) - - image = PILImage.fromarray(img_rgb, mode=mode) - image.save(file_path) - - return file_path - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_cumulative.py b/tests/test_cumulative.py new file mode 100644 index 0000000000..12a6a5e5e7 --- /dev/null +++ b/tests/test_cumulative.py @@ -0,0 +1,59 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.metrics import Cumulative +from tests.utils import assert_allclose + + +class TestCumulative(unittest.TestCase): + def test_single(self): + c = Cumulative() + c.extend([2, 3]) + c.append(1) + assert_allclose(c.get_buffer(), torch.tensor([2, 3, 1])) + + def test_multi(self): + c = Cumulative() + c.extend([2, 3], [4, 6]) + c.append(1) + assert_allclose(c.get_buffer()[0], torch.tensor([2, 3, 1])) + assert_allclose(c.get_buffer()[1], torch.tensor([4, 6])) + + c.reset() + c.append() + c.extend() + self.assertEqual(c.get_buffer(), []) + + c.reset() + + def test_ill(self): + c = Cumulative() + with self.assertRaises(TypeError): + c.extend(None) + with self.assertRaises(TypeError): + c.extend([]) + with self.assertRaises(TypeError): + c.extend(1) + with self.assertRaises(TypeError): + c.append([]) + c.append([1, 2]) + c.get_buffer() + with self.assertRaises(TypeError): + c.append(None) + c.get_buffer() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cumulative_average.py b/tests/test_cumulative_average.py new file mode 100644 index 0000000000..4e7e4ff5d9 --- /dev/null +++ b/tests/test_cumulative_average.py @@ -0,0 +1,63 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.metrics import CumulativeAverage + +# single class value +TEST_CASE_1 = [[torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[0.3]])], torch.as_tensor([0.2])] + +# multi-class value +TEST_CASE_2 = [ + [torch.as_tensor([[0.1, 0.2]]), torch.as_tensor([[0.2, 0.3]]), torch.as_tensor([[0.3, 0.4]])], + torch.as_tensor([0.2, 0.3]), +] + +# Nan value +TEST_CASE_3 = [ + [torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[float("nan")]])], + torch.as_tensor([0.15]), +] + +# different input shape +TEST_CASE_4 = [[torch.as_tensor(0.1), torch.as_tensor(0.2), torch.as_tensor(0.3)], torch.as_tensor(0.2)] + + +class TestCumulativeAverage(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_value(self, input_data, expected_value): + average = CumulativeAverage() + func = average.append if input_data[0].ndim < 2 else average.extend + func(input_data[0]) + func(input_data[1]) + result = average.aggregate() + # continue to update new data + func(input_data[2]) + result = average.aggregate() + torch.testing.assert_allclose(result, expected_value) + + def test_numpy_array(self): + class TestCumulativeAverage(CumulativeAverage): + def get_buffer(self): + return np.array([[1, 2], [3, np.nan]]) + + average = TestCumulativeAverage() + result = average.aggregate() + np.testing.assert_allclose(result, np.array([2.0, 2.0])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cumulative_average_dist.py b/tests/test_cumulative_average_dist.py new file mode 100644 index 0000000000..5de139e9ac --- /dev/null +++ b/tests/test_cumulative_average_dist.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +import torch.distributed as dist + +from monai.metrics import CumulativeAverage +from tests.utils import DistCall, DistTestCase + + +class DistributedCumulativeAverage(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + def test_value(self): + rank = dist.get_rank() + input_data = [ + [torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[0.3]])], + [torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[float("nan")]])], + [torch.as_tensor([[0.1, 0.2]]), torch.as_tensor([[0.2, 0.3]]), torch.as_tensor([[0.3, 0.4]])], + [torch.as_tensor(0.1), torch.as_tensor(0.2), torch.as_tensor(0.3)], + ] + expected = [torch.as_tensor([0.2]), torch.as_tensor([0.15]), torch.as_tensor([0.2, 0.3]), torch.as_tensor(0.2)] + average = CumulativeAverage() + + for i, e in zip(input_data, expected): + func = average.append if i[0].ndim < 2 else average.extend + if rank == 0: + func(i[0]) + func(i[1]) + else: + func(i[2]) + result = average.aggregate() + torch.testing.assert_allclose(result, e) + average.reset() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 50536f2a5c..65a61ce7ec 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,7 @@ import logging import os +import sys import tempfile import unittest @@ -126,7 +127,7 @@ TEST_CASE_8 = [ np.array([[0, 1], [1, 2]]), - "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\n" + "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] @@ -159,9 +160,10 @@ def test_file(self, input_data, expected_print): for h in _logger.handlers[:]: h.close() _logger.removeHandler(h) - with open(filename, "r") as f: + with open(filename) as f: content = f.read() - self.assertEqual(content, expected_print) + if sys.platform != "win32": + self.assertEqual(content, expected_print) if __name__ == "__main__": diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index aea0f1e721..1f38db2b05 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,7 @@ import logging import os +import sys import tempfile import unittest @@ -147,23 +148,14 @@ TEST_CASE_9 = [ {"img": np.array([[0, 1], [1, 2]])}, - "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\n" + "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] class TestDataStatsd(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - ] + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] ) def test_value(self, input_param, input_data, expected_print): transform = DataStatsd(**input_param) @@ -192,9 +184,10 @@ def test_file(self, input_data, expected_print): h.close() _logger.removeHandler(h) del handler - with open(filename, "r") as f: + with open(filename) as f: content = f.read() - self.assertEqual(content, expected_print) + if sys.platform != "win32": + self.assertEqual(content, expected_print) if __name__ == "__main__": diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 3b159fb5b8..79126b2dbb 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,19 +20,9 @@ from monai.transforms import Compose, DataStatsd, Randomizable, SimulateDelayd from monai.utils import set_determinism -TEST_CASE_1 = [ - [ - {"image": np.asarray([1, 2, 3])}, - {"image": np.asarray([4, 5])}, - ] -] - -TEST_CASE_2 = [ - [ - {"label": torch.as_tensor([[3], [2]])}, - {"label": np.asarray([[1], [2]])}, - ] -] +TEST_CASE_1 = [[{"image": np.asarray([1, 2, 3])}, {"image": np.asarray([4, 5])}]] + +TEST_CASE_2 = [[{"label": torch.as_tensor([[3], [2]])}, {"label": np.asarray([[1], [2]])}]] class TestDataLoader(unittest.TestCase): diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 491b777550..f8d4ed2104 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_dataset_func.py b/tests/test_dataset_func.py new file mode 100644 index 0000000000..b5871d7de1 --- /dev/null +++ b/tests/test_dataset_func.py @@ -0,0 +1,52 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile +import unittest + +from monai.data import Dataset, DatasetFunc, load_decathlon_datalist, partition_dataset + + +class TestDatasetFunc(unittest.TestCase): + def test_seg_values(self): + with tempfile.TemporaryDirectory() as tempdir: + # prepare test datalist file + test_data = { + "name": "Spleen", + "description": "Spleen Segmentation", + "labels": {"0": "background", "1": "spleen"}, + "training": [ + {"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz"}, + {"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz"}, + ], + "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"], + } + json_str = json.dumps(test_data) + file_path = os.path.join(tempdir, "test_data.json") + with open(file_path, "w") as json_file: + json_file.write(json_str) + + data_list = DatasetFunc( + data=file_path, func=load_decathlon_datalist, data_list_key="training", base_dir=tempdir + ) + # partition dataset for train / validation + data_partition = DatasetFunc( + data=data_list, func=lambda x, **kwargs: partition_dataset(x, **kwargs)[0], num_partitions=2 + ) + dataset = Dataset(data=data_partition, transform=None) + self.assertEqual(dataset[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) + self.assertEqual(dataset[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py index 5307bc7e66..5569c51a0c 100644 --- a/tests/test_dataset_summary.py +++ b/tests/test_dataset_summary.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,6 +22,15 @@ from monai.utils import set_determinism +def test_collate(batch): + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, np.ndarray): + return np.stack(batch, 0) + elif isinstance(elem, dict): + return elem_type({key: test_collate([d[key] for d in batch]) for key in elem}) + + class TestDatasetSummary(unittest.TestCase): def test_spacing_intensity(self): set_determinism(seed=0) @@ -40,9 +49,12 @@ def test_spacing_intensity(self): {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ] - dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"])) + dataset = Dataset( + data=data_dicts, transform=LoadImaged(keys=["image", "label"], meta_keys=["test1", "test2"]) + ) - calculator = DatasetSummary(dataset, num_workers=4) + # test **kwargs of `DatasetSummary` for `DataLoader` + calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1", collate_fn=test_collate) target_spacing = calculator.get_target_spacing() self.assertEqual(target_spacing, (1.0, 1.0, 1.0)) @@ -56,13 +68,7 @@ def test_spacing_intensity(self): def test_anisotropic_spacing(self): with tempfile.TemporaryDirectory() as tempdir: - pixdims = [ - [1.0, 1.0, 5.0], - [1.0, 1.0, 4.0], - [1.0, 1.0, 4.5], - [1.0, 1.0, 2.0], - [1.0, 1.0, 1.0], - ] + pixdims = [[1.0, 1.0, 5.0], [1.0, 1.0, 4.0], [1.0, 1.0, 4.5], [1.0, 1.0, 2.0], [1.0, 1.0, 1.0]] for i in range(5): im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0) n = nib.Nifti1Image(im, np.eye(4)) @@ -80,7 +86,7 @@ def test_anisotropic_spacing(self): dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"])) - calculator = DatasetSummary(dataset, num_workers=4) + calculator = DatasetSummary(dataset, num_workers=4, meta_key_postfix="meta_dict") target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0) np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8)) diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index 15dbceb8ad..9a785668ac 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import os import shutil import unittest +from pathlib import Path from urllib.error import ContentTooShortError, HTTPError from monai.apps import DecathlonDataset @@ -46,6 +47,7 @@ def _test_dataset(dataset): transform=transform, section="validation", download=True, + copy_cache=False, ) except (ContentTooShortError, HTTPError, RuntimeError) as e: print(str(e)) @@ -59,6 +61,8 @@ def _test_dataset(dataset): root_dir=testing_dir, task="Task04_Hippocampus", transform=transform, section="validation", download=False ) _test_dataset(data) + self.assertTrue(data[0]["image_meta_dict"]["filename_or_obj"].endswith("hippocampus_163.nii.gz")) + self.assertTrue(data[0]["label_meta_dict"]["filename_or_obj"].endswith("hippocampus_163.nii.gz")) # test validation without transforms data = DecathlonDataset(root_dir=testing_dir, task="Task04_Hippocampus", section="validation", download=False) self.assertTupleEqual(data[0]["image"].shape, (36, 47, 44)) @@ -69,17 +73,14 @@ def _test_dataset(dataset): # test dataset properties data = DecathlonDataset( - root_dir=testing_dir, - task="Task04_Hippocampus", - section="validation", - download=False, + root_dir=Path(testing_dir), task="Task04_Hippocampus", section="validation", download=False ) properties = data.get_properties(keys="labels") self.assertDictEqual(properties["labels"], {"0": "background", "1": "Anterior", "2": "Posterior"}) shutil.rmtree(os.path.join(testing_dir, "Task04_Hippocampus")) try: - data = DecathlonDataset( + DecathlonDataset( root_dir=testing_dir, task="Task04_Hippocampus", transform=transform, diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 521d263663..dc43cfc422 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,7 +38,7 @@ from monai.transforms.inverse_batch_transform import Decollated from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d from monai.utils import optional_import, set_determinism -from monai.utils.enums import InverseKeys +from monai.utils.enums import TraceKeys from tests.utils import make_nifti_image _, has_nib = optional_import("nibabel") @@ -75,6 +75,7 @@ [[None, None], [None, None]], [["test"], ["test"]], [[], []], + [[("ch1", "ch2"), ("ch3",)], [["ch1", "ch3"], ["ch2", None]]], # default pad None ] @@ -105,7 +106,7 @@ def check_match(self, in1, in2): k1, k2 = k1.value, k2.value self.check_match(k1, k2) # Transform ids won't match for windows with multiprocessing, so don't check values - if k1 == InverseKeys.ID and sys.platform in ["darwin", "win32"]: + if k1 == TraceKeys.ID and sys.platform in ["darwin", "win32"]: continue self.check_match(v1, v2) elif isinstance(in1, (list, tuple)): @@ -120,7 +121,7 @@ def check_match(self, in1, in2): def check_decollate(self, dataset): batch_size = 2 - num_workers = 2 + num_workers = 2 if sys.platform == "linux" else 0 loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) @@ -170,10 +171,7 @@ def test_decollation_examples(self, input_val, expected_out): self.assertListEqual(expected_out, out) def test_dict_examples(self): - test_case = { - "meta": {"out": ["test", "test"]}, - "image_meta_dict": {"scl_slope": torch.Tensor((0.0, 0.0))}, - } + test_case = {"meta": {"out": ["test", "test"]}, "image_meta_dict": {"scl_slope": torch.Tensor((0.0, 0.0))}} out = decollate_batch(test_case) self.assertEqual(out[0]["meta"]["out"], "test") self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) @@ -210,6 +208,16 @@ def test_dict_examples(self): out = decollate_batch(test_case, detach=False) self.assertEqual(out[0]["out"], "test") + test_case = { + "image": torch.tensor([[[1, 2, 3]], [[3, 4, 5]]]), + "label": torch.tensor([[[5]], [[7]]]), + "out": ["test"], + } + out = decollate_batch(test_case, detach=False, pad=False) + self.assertEqual(len(out), 1) # no padding + out = decollate_batch(test_case, detach=False, pad=True, fill_value=0) + self.assertEqual(out[1]["out"], 0) # verify padding fill_value + def test_decollated(self): test_case = { "image": torch.tensor([[[1, 2]], [[3, 4]]]), @@ -236,12 +244,16 @@ def test_decollated(self): torch.tensor([[[1, 2]], [[3, 4]]]), {"out": ["test", "test"]}, {"scl_slope": torch.Tensor((0.0, 0.0))}, + {"out2": ["test1"]}, 0.85, + [], ] - transform = Decollated(keys=None, detach=False) + transform = Decollated(keys=None, detach=False, fill_value=-1) out = transform(test_case) - # the 4th item in the list is scalar loss value - self.assertEqual(out[1][3], 0.85) + + self.assertEqual(out[0][-2], 0.85) # scalar replicates + self.assertEqual(out[1][-2], 0.85) # scalar replicates + self.assertEqual(out[1][-3], -1) # fill value for the dictionary item self.assertEqual(out[0][1]["out"], "test") self.assertEqual(out[0][2]["scl_slope"], 0.0) self.assertTrue(isinstance(out[0][2]["scl_slope"], torch.Tensor)) diff --git a/tests/test_deepedit_transforms.py b/tests/test_deepedit_transforms.py index c2b11e8ee7..aa0a73183d 100644 --- a/tests/test_deepedit_transforms.py +++ b/tests/test_deepedit_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,11 +28,7 @@ "background": [0, 0, 0], } -DISCARD_ADD_GUIDANCE_TEST_CASE = [ - {"image": IMAGE, "label": LABEL}, - DATA_1, - (3, 1, 5, 5), -] +DISCARD_ADD_GUIDANCE_TEST_CASE = [{"image": IMAGE, "label": LABEL}, DATA_1, (3, 1, 5, 5)] DATA_2 = { "image": IMAGE, diff --git a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py index 147d8e7099..ff8de87b81 100644 --- a/tests/test_deepgrow_dataset.py +++ b/tests/test_deepgrow_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py index 016ba17251..b040348b62 100644 --- a/tests/test_deepgrow_interaction.py +++ b/tests/test_deepgrow_interaction.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index f50e92d146..5a3f55be60 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -31,12 +31,7 @@ IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]) LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]) -DATA_1 = { - "image": IMAGE, - "label": LABEL, - "image_meta_dict": {}, - "label_meta_dict": {}, -} +DATA_1 = {"image": IMAGE, "label": LABEL, "image_meta_dict": {}, "label_meta_dict": {}} DATA_2 = { "image": np.array( @@ -141,23 +136,11 @@ "pred": np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]), } -DATA_12 = { - "image": np.arange(27).reshape(3, 3, 3), - "image_meta_dict": {}, - "guidance": [[0, 0, 0], [0, 1, 1], 1], -} +DATA_12 = {"image": np.arange(27).reshape(3, 3, 3), "image_meta_dict": {}, "guidance": [[0, 0, 0], [0, 1, 1], 1]} -FIND_SLICE_TEST_CASE_1 = [ - {"label": "label", "sids": "sids"}, - DATA_1, - [0], -] +FIND_SLICE_TEST_CASE_1 = [{"label": "label", "sids": "sids"}, DATA_1, [0]] -FIND_SLICE_TEST_CASE_2 = [ - {"label": "label", "sids": "sids"}, - DATA_2, - [0, 1], -] +FIND_SLICE_TEST_CASE_2 = [{"label": "label", "sids": "sids"}, DATA_2, [0, 1]] CROP_TEST_CASE_1 = [ { @@ -338,14 +321,10 @@ [[1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0]], [[5.0, 5.0, 6.0, 6.0], [5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0], [7.0, 7.0, 8.0, 8.0]], [[5.0, 5.0, 6.0, 6.0], [5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0], [7.0, 7.0, 8.0, 8.0]], - ], + ] ) -RESTORE_LABEL_TEST_CASE_2 = [ - {"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, - DATA_11, - RESULT, -] +RESTORE_LABEL_TEST_CASE_2 = [{"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, DATA_11, RESULT] FETCH_2D_SLICE_TEST_CASE_1 = [ {"keys": ["image"], "guidance": "guidance"}, diff --git a/tests/test_delete_itemsd.py b/tests/test_delete_itemsd.py index 7426e39ff0..18138119b5 100644 --- a/tests/test_delete_itemsd.py +++ b/tests/test_delete_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,19 +19,36 @@ TEST_CASE_1 = [{"keys": [str(i) for i in range(30)]}, 20] +TEST_CASE_2 = [{"keys": ["image/" + str(i) for i in range(30)], "sep": "/"}, 20] + +TEST_CASE_3 = [{"keys": "meta_dict%0008\\|[0-9]", "sep": "%", "use_re": True}] + class TestDeleteItemsd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_memory(self, input_param, expected_key_size): - input_data = {} + input_data = {"image": {}} if "sep" in input_param else {} for i in range(50): - input_data[str(i)] = [time.time()] * 100000 + if "sep" in input_param: + input_data["image"][str(i)] = [time.time()] * 100000 + else: + input_data[str(i)] = [time.time()] * 100000 result = DeleteItemsd(**input_param)(input_data) - self.assertEqual(len(result.keys()), expected_key_size) + if "sep" in input_param: + self.assertEqual(len(result["image"].keys()), expected_key_size) + else: + self.assertEqual(len(result.keys()), expected_key_size) self.assertGreaterEqual( sys.getsizeof(input_data) * float(expected_key_size) / len(input_data), sys.getsizeof(result) ) + @parameterized.expand([TEST_CASE_3]) + def test_re(self, input_param): + input_data = {"image": [1, 2, 3], "meta_dict": {"0008|0005": 1, "0008|1050": 2, "0008test": 3}} + result = DeleteItemsd(**input_param)(input_data) + self.assertEqual(result["meta_dict"]["0008test"], 3) + self.assertTrue(len(result["meta_dict"]), 1) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_densenet.py b/tests/test_densenet.py index ba4b7afcb4..47f584297e 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 429d5ee767..545031321b 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,7 +29,7 @@ def test_warning(self): def foo2(): pass - print(foo2()) + foo2() # should not raise any warnings def test_warning_milestone(self): """Test deprecated decorator with `since` and `removed` set for a milestone version""" @@ -172,7 +172,7 @@ def test_arg_warn2(self): """Test deprecated_arg decorator with just `since` set.""" @deprecated_arg("b", since=self.prev_version, version_val=self.test_version) - def afoo2(a, **kwargs): + def afoo2(a, **kw): pass afoo2(1) # ok when no b provided @@ -222,3 +222,72 @@ def future1(): warnings.warn("fake warning", DeprecationWarning) self.assertEqual(aw.warning.args[0], "fake warning") + + def test_arg_except2_unknown(self): + """ + Test deprecated_arg decorator raises exception with `removed` set in the past. + with unknown version + """ + + @deprecated_arg("b", removed=self.prev_version, version_val="0+untagged.1.g3131155") + def afoo4(a, b=None): + pass + + self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2)) + + def test_arg_except3_unknown(self): + """ + Test deprecated_arg decorator raises exception with `removed` set in the past. + with unknown version and kwargs + """ + + @deprecated_arg("b", removed=self.prev_version, version_val="0+untagged.1.g3131155") + def afoo4(a, b=None, **kwargs): + pass + + self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2)) + self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2, c=3)) + + def test_replacement_arg(self): + """ + Test deprecated arg being replaced. + """ + + @deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version) + def afoo4(a, b=None): + return a + + self.assertEqual(afoo4(b=2), 2) + self.assertEqual(afoo4(1, b=2), 1) # new name is in use + self.assertEqual(afoo4(a=1, b=2), 1) # prefers the new arg + + def test_replacement_arg1(self): + """ + Test deprecated arg being replaced with kwargs. + """ + + @deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version) + def afoo4(a, *args, **kwargs): + return a + + self.assertEqual(afoo4(b=2), 2) + self.assertEqual(afoo4(1, b=2, c=3), 1) # new name is in use + self.assertEqual(afoo4(a=1, b=2, c=3), 1) # prefers the new arg + + def test_replacement_arg2(self): + """ + Test deprecated arg (with a default value) being replaced. + """ + + @deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version) + def afoo4(a, b=None, **kwargs): + return a, kwargs + + self.assertEqual(afoo4(b=2, c=3), (2, {"c": 3})) + self.assertEqual(afoo4(1, b=2, c=3), (1, {"c": 3})) # new name is in use + self.assertEqual(afoo4(a=1, b=2, c=3), (1, {"c": 3})) # prefers the new arg + self.assertEqual(afoo4(1, 2, c=3), (1, {"c": 3})) # prefers the new positional arg + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index ded0290de2..f1b9c7ad1a 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import DetectEnvelope @@ -70,9 +71,9 @@ TEST_CASE_2_CHAN_3D_SINE = [ {}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1) # Create 100 identical windowed sine waves as a (n_samples x 10 x 10) 3D numpy array, twice (2 channels) - np.stack([np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2)] * 2, axis=0), + torch.as_tensor(np.stack([np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2)] * 2, axis=0)), # Expected output: Set of 100 identical Hann windows in (n_samples x 10 x 10) 3D numpy array, twice (2 channels) - np.stack([np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2)] * 2, axis=0), + torch.as_tensor(np.stack([np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2)] * 2, axis=0)), 1e-4, # absolute tolerance ] diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 66cfb36e99..b11165ca9c 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index 920994f8de..c611fe4160 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,11 +24,7 @@ def test_result_onehot_target_include_bg(self): label = torch.randint(low=0, high=2, size=size) pred = torch.randn(size) for reduction in ["sum", "mean", "none"]: - common_params = { - "include_background": True, - "to_onehot_y": False, - "reduction": reduction, - } + common_params = {"include_background": True, "to_onehot_y": False, "reduction": reduction} for focal_weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]: for lambda_focal in [0.5, 1.0, 1.5]: dice_focal = DiceFocalLoss( @@ -46,11 +42,7 @@ def test_result_no_onehot_no_bg(self): label = torch.argmax(label, dim=1, keepdim=True) pred = torch.randn(size) for reduction in ["sum", "mean", "none"]: - common_params = { - "include_background": False, - "to_onehot_y": True, - "reduction": reduction, - } + common_params = {"include_background": False, "to_onehot_y": True, "reduction": reduction} for focal_weight in [2.0, torch.tensor([1.0, 2.0]), (2.0, 1)]: for lambda_focal in [0.5, 1.0, 1.5]: dice_focal = DiceFocalLoss(focal_weight=focal_weight, lambda_focal=lambda_focal, **common_params) diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index ef0a51eb15..4e45393de6 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,10 +21,7 @@ TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) @@ -64,7 +61,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - [[0.296529, 0.415136], [0.599976, 0.428559]], + [[[0.296529], [0.415136]], [[0.599976], [0.428559]]], ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -91,26 +88,17 @@ ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "squared_pred": True}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.178337, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "jaccard": True}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.470451, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) diff --git a/tests/test_dints_cell.py b/tests/test_dints_cell.py new file mode 100644 index 0000000000..d480235b70 --- /dev/null +++ b/tests/test_dints_cell.py @@ -0,0 +1,77 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets.dints import Cell + +TEST_CASES_3D = [ + [ + {"c_prev": 8, "c": 8, "rate": 1, "arch_code_c": None}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([1, 1, 1, 1, 1]), + (2, 8, 32, 16, 8), + (2, 8, 64, 32, 16), + ], + [ + {"c_prev": 8, "c": 4, "rate": 1, "arch_code_c": [1, 1, 0, 0, 1]}, + torch.tensor([1, 1, 0, 0, 1]), + torch.tensor([1, 0.2, 1.3, 0, 1]), + (2, 8, 32, 16, 8), + (2, 4, 64, 32, 16), + ], + [ + {"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([0, 0, 0, 1, 0]), + (2, 8, 32, 16, 8), + (2, 8, 32, 16, 8), + ], + [ + {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": None}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([1, 1, 1, 1, 1]), + (2, 8, 32, 16, 8), + (2, 8, 16, 8, 4), + ], + [ + {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1]}, + torch.tensor([1, 0, 0, 0, 1]), + torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]), + (2, 8, 32, 16, 8), + (2, 8, 16, 8, 4), + ], +] + +TEST_CASES_2D = [ + [ + {"c_prev": 8, "c": 7, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1], "spatial_dims": 2}, + torch.tensor([1, 0]), + torch.tensor([0.2, 0.2]), + (2, 8, 16, 8), + (2, 7, 8, 4), + ] +] + + +class TestCell(unittest.TestCase): + @parameterized.expand(TEST_CASES_2D + TEST_CASES_3D) + def test_cell_3d(self, input_param, ops, weight, input_shape, expected_shape): + net = Cell(**input_param) + result = net(torch.randn(input_shape), weight=weight) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dints_mixop.py b/tests/test_dints_mixop.py new file mode 100644 index 0000000000..b686069173 --- /dev/null +++ b/tests/test_dints_mixop.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets.dints import Cell, MixedOp +from tests.utils import test_script_save + +TEST_CASES_3D = [ + [ + {"c": 8, "arch_code_c": None}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([1, 1, 1, 1, 1]), + (2, 8, 32, 16, 8), + (2, 8, 32, 16, 8), + ], + [ + {"c": 8, "arch_code_c": [1, 1, 0, 0, 1]}, + torch.tensor([1, 1, 0, 0, 1]), + torch.tensor([1, 0.2, 1.3, 0, 1]), + (2, 8, 64, 32, 16), + (2, 8, 64, 32, 16), + ], + [ + {"c": 8, "arch_code_c": None}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([0, 0, 0, 1, 0]), + (2, 8, 32, 16, 8), + (2, 8, 32, 16, 8), + ], + [ + {"c": 8, "arch_code_c": [1, 1, 1, 0, 1]}, + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([0, 0, 0, 1, 0]), + (2, 8, 32, 16, 8), + (2, 8, 32, 16, 8), + ], +] +TEST_CASES_2D = [ + [ + {"c": 32, "arch_code_c": [1, 1, 1, 0, 1]}, + torch.tensor([1, 1]), + torch.tensor([0, 0]), + (2, 32, 16, 8), + (2, 32, 16, 8), + ] +] + + +class TestMixOP(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_mixop_3d(self, input_param, ops, weight, input_shape, expected_shape): + net = MixedOp(ops=Cell.OPS3D, **input_param) + result = net(torch.randn(input_shape), weight=weight) + self.assertEqual(result.shape, expected_shape) + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CASES_2D) + def test_mixop_2d(self, input_param, ops, weight, input_shape, expected_shape): + net = MixedOp(ops=Cell.OPS2D, **input_param) + result = net(torch.randn(input_shape), weight=weight) + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CASES_3D) + def test_script(self, input_param, ops, weight, input_shape, expected_shape): + net = MixedOp(ops=Cell.OPS3D, **input_param) + test_script_save(net, torch.randn(input_shape), weight) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dints_network.py b/tests/test_dints_network.py new file mode 100644 index 0000000000..8be5eb7ccd --- /dev/null +++ b/tests/test_dints_network.py @@ -0,0 +1,165 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.nets import DiNTS, TopologyInstance, TopologySearch +from monai.networks.nets.dints import Cell +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save + +TEST_CASES_3D = [ + [ + { + "channel_mul": 0.2, + "num_blocks": 6, + "num_depths": 3, + "device": "cpu", + "use_downsample": False, + "spatial_dims": 3, + }, + { + "in_channels": 1, + "num_classes": 3, + "act_name": "RELU", + "norm_name": "INSTANCE", + "use_downsample": False, + "spatial_dims": 3, + }, + (3, 1, 32, 32, 16), + (3, 3, 32, 32, 16), + ] +] +if torch.cuda.is_available(): + TEST_CASES_3D += [ + [ + { + "channel_mul": 0.5, + "num_blocks": 7, + "num_depths": 4, + "device": "cuda", + "use_downsample": True, + "spatial_dims": 3, + }, + { + "in_channels": 2, + "num_classes": 2, + "act_name": "PRELU", + "norm_name": "BATCH", + "use_downsample": True, + "spatial_dims": 3, + }, + (3, 2, 32, 32, 16), + (3, 2, 32, 32, 16), + ] + ] +TEST_CASES_2D = [ + [ + { + "channel_mul": 1, + "num_blocks": 7, + "num_depths": 4, + "device": "cpu", + "use_downsample": True, + "spatial_dims": 2, + }, + { + "in_channels": 2, + "num_classes": 2, + "act_name": "PRELU", + "norm_name": "BATCH", + "use_downsample": True, + "spatial_dims": 2, + }, + (2, 2, 32, 16), + (2, 2, 32, 16), + ] +] +if torch.cuda.is_available(): + TEST_CASES_2D += [ + [ + { + "channel_mul": 0.5, + "num_blocks": 8, + "num_depths": 4, + "device": "cuda", + "use_downsample": False, + "spatial_dims": 2, + }, + { + "in_channels": 1, + "num_classes": 4, + "act_name": "RELU", + "norm_name": "INSTANCE", + "use_downsample": False, + "spatial_dims": 2, + }, + (2, 1, 32, 16), + (2, 4, 32, 16), + ] + ] + + +class TestDints(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D) + def test_dints_inference(self, dints_grid_params, dints_params, input_shape, expected_shape): + grid = TopologySearch(**dints_grid_params) + dints_params["dints_space"] = grid + net = DiNTS(**dints_params).to(dints_grid_params["device"]) + result = net(torch.randn(input_shape).to(dints_grid_params["device"])) + self.assertEqual(result.shape, expected_shape) + # test functions + grid.get_ram_cost_usage(in_size=input_shape, full=True) + grid.get_ram_cost_usage(in_size=input_shape, full=False) + probs_a, _ = grid.get_prob_a(child=True) + grid.get_topology_entropy(probs_a) + grid.decode() + grid.gen_mtx(depth=4) + + @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D) + def test_dints_search(self, dints_grid_params, dints_params, input_shape, expected_shape): + num_blocks = dints_grid_params["num_blocks"] + num_depths = dints_grid_params["num_depths"] + # init a Cell to obtain cell operation number + _cell = Cell(1, 1, 0, spatial_dims=dints_grid_params["spatial_dims"]) + num_cell_ops = len(_cell.OPS) + # define archtecture codes + node_a = torch.ones((num_blocks + 1, num_depths)) + arch_code_a = np.ones((num_blocks, 3 * num_depths - 2)) + arch_code_c = np.random.randint(num_cell_ops, size=(num_blocks, 3 * num_depths - 2)) + # initialize with codes + dints_grid_params["arch_code"] = [arch_code_a, arch_code_c] + grid = TopologyInstance(**dints_grid_params) + # set as deploy stage + dints_params["dints_space"] = grid + dints_params["node_a"] = node_a + net = DiNTS(**dints_params).to(dints_grid_params["device"]) + result = net(torch.randn(input_shape).to(dints_grid_params["device"])) + self.assertEqual(result.shape, expected_shape) + self.assertTrue(isinstance(net.weight_parameters(), list)) + + +@SkipIfBeforePyTorchVersion((1, 9)) +class TestDintsTS(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D) + def test_script(self, dints_grid_params, dints_params, input_shape, _): + grid = TopologyInstance(**dints_grid_params) + dints_grid_params["device"] = "cpu" + dints_params["dints_space"] = grid + net = DiNTS(**dints_params).to(dints_grid_params["device"]) + test_script_save(net, torch.randn(input_shape).to(dints_grid_params["device"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py index 52b9a10dd5..aa9b9720c4 100644 --- a/tests/test_discriminator.py +++ b/tests/test_discriminator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index ca15b4b347..f940636fa8 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,21 +22,11 @@ for p in TEST_NDARRAYS: # pad first dim to be divisible by 7, the second unchanged. - TESTS.append( - [ - {"k": (7, -1), "mode": "constant"}, - p(np.zeros((3, 8, 7))), - p(np.zeros((3, 14, 7))), - ] - ) + TESTS.append([{"k": (7, -1), "mode": "constant"}, p(np.zeros((3, 8, 7))), p(np.zeros((3, 14, 7)))]) # pad all dimensions to be divisible by 5 TESTS.append( - [ - {"k": 5, "mode": "constant", "method": "end"}, - p(np.zeros((3, 10, 5, 17))), - p(np.zeros((3, 10, 5, 20))), - ] + [{"k": 5, "mode": "constant", "method": "end"}, p(np.zeros((3, 10, 5, 17))), p(np.zeros((3, 10, 5, 20)))] ) @@ -50,11 +40,13 @@ def test_pad_shape(self, input_param, input_data, expected_val): self.assertAlmostEqual(result.shape, expected_val.shape) def test_pad_kwargs(self): - padder = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) for p in TEST_NDARRAYS: - result = padder(p(np.zeros((3, 8, 4)))) - result = result.cpu() if isinstance(result, torch.Tensor) else result - torch.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)), rtol=1e-7, atol=0) + input_data = p(np.zeros((3, 8, 4))) + if isinstance(input_data, np.ndarray): + result = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2)))(input_data) + np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)), rtol=1e-7, atol=0) + else: + result = DivisiblePad(k=5, mode="constant", value=2)(input_data).cpu() torch.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1, rtol=1e-7, atol=0) diff --git a/tests/test_divisible_padd.py b/tests/test_divisible_padd.py index c834adac6d..61fe917421 100644 --- a/tests/test_divisible_padd.py +++ b/tests/test_divisible_padd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,11 +28,7 @@ np.zeros((3, 14, 7)), ] -TEST_CASE_3 = [ - {"keys": ["img"], "k": 0, "mode": {"constant"}}, - {"img": np.zeros((3, 8))}, - np.zeros((3, 8)), -] +TEST_CASE_3 = [{"keys": ["img"], "k": 0, "mode": {"constant"}}, {"img": np.zeros((3, 8))}, np.zeros((3, 8))] class TestDivisiblePadd(unittest.TestCase): diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index b02e4ff86f..f896d4ae93 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path from urllib.error import ContentTooShortError, HTTPError from monai.apps import download_and_extract, download_url, extractall @@ -23,8 +24,8 @@ class TestDownloadAndExtract(unittest.TestCase): def test_actions(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") url = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE" - filepath = os.path.join(testing_dir, "MedNIST.tar.gz") - output_dir = testing_dir + filepath = Path(testing_dir) / "MedNIST.tar.gz" + output_dir = Path(testing_dir) md5_value = "0bc7306e7427e00ad1c5526a6677552d" try: download_and_extract(url, filepath, output_dir, md5_value) @@ -37,14 +38,15 @@ def test_actions(self): return # skipping this test due the network connection errors wrong_md5 = "0" - try: - download_url(url, filepath, wrong_md5) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors + with self.assertLogs(logger="monai.apps", level="ERROR"): + try: + download_url(url, filepath, wrong_md5) + except (ContentTooShortError, HTTPError, RuntimeError) as e: + print(str(e)) + if isinstance(e, RuntimeError): + # FIXME: skip MD5 check as current downloading method may fail + self.assertTrue(str(e).startswith("md5 check")) + return # skipping this test due the network connection errors try: extractall(filepath, output_dir, wrong_md5) diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index f4ae30198f..ac2acb0845 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,11 +20,7 @@ TEST_CASES = [ [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 [{"spatial_dims": 1, "kernel_size": 4}, (16, 4, 63), (16, 8, 15)], # 4-channel 1D, batch 16 - [ # 4-channel 1D, batch 16 - {"spatial_dims": 1, "kernel_size": 4, "padding": 1}, - (16, 4, 63), - (16, 8, 16), - ], + [{"spatial_dims": 1, "kernel_size": 4, "padding": 1}, (16, 4, 63), (16, 8, 16)], # 4-channel 1D, batch 16 [ # 4-channel 3D, batch 16 {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": True}, (16, 4, 32, 24, 48), diff --git a/tests/test_dvf2ddf.py b/tests/test_dvf2ddf.py index cc3323cf13..d061cca7ff 100644 --- a/tests/test_dvf2ddf.py +++ b/tests/test_dvf2ddf.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 81ed239461..36ac9d0309 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,14 +26,14 @@ expected_shape: Sequence[Any] TEST_CASE_DYNUNET_2D = [] +out_channels = 2 +in_size = 64 +spatial_dims = 2 for kernel_size in [(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))]: for strides in [(1, 1, 1, 1), (2, 2, 2, 1)]: + expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) for in_channels in [2, 3]: for res_block in [True, False]: - out_channels = 2 - in_size = 64 - spatial_dims = 2 - expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) test_case = [ { "spatial_dims": spatial_dims, @@ -43,8 +43,10 @@ "strides": strides, "upsample_kernel_size": strides[1:], "norm_name": "batch", + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.2}), "deep_supervision": False, "res_block": res_block, + "dropout": None, }, (1, in_channels, in_size, in_size), expected_shape, @@ -52,11 +54,11 @@ TEST_CASE_DYNUNET_2D.append(test_case) TEST_CASE_DYNUNET_3D = [] # in 3d cases, also test anisotropic kernel/strides +in_channels = 1 +in_size = 64 for out_channels in [2, 3]: + expected_shape = (1, out_channels, 64, 32, 64) for res_block in [True, False]: - in_channels = 1 - in_size = 64 - expected_shape = (1, out_channels, 64, 32, 64) test_case = [ { "spatial_dims": 3, @@ -65,9 +67,11 @@ "kernel_size": (3, (1, 1, 3), 3, 3), "strides": ((1, 2, 1), 2, 2, 1), "upsample_kernel_size": (2, 2, 1), + "filters": (64, 96, 128, 192), "norm_name": ("INSTANCE", {"affine": True}), - "deep_supervision": False, + "deep_supervision": True, "res_block": res_block, + "dropout": ("alphadropout", {"p": 0.25}), }, (1, in_channels, in_size, in_size, in_size), expected_shape, diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py index 7e832f6d81..1c83552766 100644 --- a/tests/test_dynunet_block.py +++ b/tests/test_dynunet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -35,6 +35,7 @@ "out_channels": 16, "kernel_size": kernel_size, "norm_name": norm_name, + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.1}), "stride": stride, }, (1, 16, *([in_size] * spatial_dims)), @@ -49,22 +50,24 @@ for stride in [1, 2]: for norm_name in ["batch", "instance"]: for in_size in [15, 16]: - out_size = in_size * stride - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - "norm_name": norm_name, - "stride": stride, - "upsample_kernel_size": stride, - }, - (1, in_channels, *([in_size] * spatial_dims)), - (1, out_channels, *([out_size] * spatial_dims)), - (1, out_channels, *([in_size * stride] * spatial_dims)), - ] - TEST_UP_BLOCK.append(test_case) + for trans_bias in [True, False]: + out_size = in_size * stride + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "kernel_size": kernel_size, + "norm_name": norm_name, + "stride": stride, + "upsample_kernel_size": stride, + "trans_bias": trans_bias, + }, + (1, in_channels, *([in_size] * spatial_dims)), + (1, out_channels, *([out_size] * spatial_dims)), + (1, out_channels, *([in_size * stride] * spatial_dims)), + ] + TEST_UP_BLOCK.append(test_case) class TestResBasicBlock(unittest.TestCase): diff --git a/tests/test_dynunet_v1.py b/tests/test_dynunet_v1.py deleted file mode 100644 index fc216c145b..0000000000 --- a/tests/test_dynunet_v1.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from typing import Any, Sequence, Union - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets.dynunet_v1 import DynUNetV1 -from tests.utils import skip_if_quick, test_script_save - -device = "cuda" if torch.cuda.is_available() else "cpu" - -strides: Sequence[Union[Sequence[int], int]] -kernel_size: Sequence[Any] -expected_shape: Sequence[Any] - -TEST_CASE_DYNUNET_2D = [] -for kernel_size in [(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))]: - for strides in [(1, 1, 1, 1), (2, 2, 2, 1)]: - for in_channels in [2, 3]: - for res_block in [True, False]: - out_channels = 2 - in_size = 64 - spatial_dims = 2 - expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - "strides": strides, - "upsample_kernel_size": strides[1:], - "norm_name": "batch", - "deep_supervision": False, - "res_block": res_block, - }, - (1, in_channels, in_size, in_size), - expected_shape, - ] - TEST_CASE_DYNUNET_2D.append(test_case) - -TEST_CASE_DYNUNET_3D = [] # in 3d cases, also test anisotropic kernel/strides -for out_channels in [2, 3]: - for res_block in [True, False]: - in_channels = 1 - in_size = 64 - expected_shape = (1, out_channels, 64, 32, 64) - test_case = [ - { - "spatial_dims": 3, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": (3, (1, 1, 3), 3, 3), - "strides": ((1, 2, 1), 2, 2, 1), - "upsample_kernel_size": (2, 2, 1), - "norm_name": "instance", - "deep_supervision": False, - "res_block": res_block, - }, - (1, in_channels, in_size, in_size, in_size), - expected_shape, - ] - TEST_CASE_DYNUNET_3D.append(test_case) - -TEST_CASE_DEEP_SUPERVISION = [] -for spatial_dims in [2, 3]: - for res_block in [True, False]: - for deep_supr_num in [1, 2]: - for strides in [(1, 2, 1, 2, 1), (2, 2, 2, 1), (2, 1, 1, 2, 2)]: - scale = strides[0] - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": 1, - "out_channels": 2, - "kernel_size": [3] * len(strides), - "strides": strides, - "upsample_kernel_size": strides[1:], - "norm_name": "group", - "deep_supervision": True, - "deep_supr_num": deep_supr_num, - "res_block": res_block, - }, - (1, 1, *[in_size] * spatial_dims), - (1, 1 + deep_supr_num, 2, *[in_size // scale] * spatial_dims), - ] - TEST_CASE_DEEP_SUPERVISION.append(test_case) - - -@skip_if_quick -class TestDynUNet(unittest.TestCase): - @parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D) - def test_shape(self, input_param, input_shape, expected_shape): - net = DynUNetV1(**input_param).to(device) - with eval_mode(net): - result = net(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - - def test_script(self): - input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] - net = DynUNetV1(**input_param) - test_data = torch.randn(input_shape) - test_script_save(net, test_data) - - -class TestDynUNetDeepSupervision(unittest.TestCase): - @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) - def test_shape(self, input_param, input_shape, expected_shape): - net = DynUNetV1(**input_param).to(device) - with torch.no_grad(): - results = net(torch.randn(input_shape).to(device)) - self.assertEqual(results.shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py index 6befba108a..0ab383fd56 100644 --- a/tests/test_efficientnet.py +++ b/tests/test_efficientnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,6 +13,7 @@ import unittest from typing import TYPE_CHECKING from unittest import skipUnless +from urllib.error import ContentTooShortError, HTTPError import torch from parameterized import parameterized @@ -44,7 +45,7 @@ def get_model_names(): - return ["efficientnet-b{}".format(d) for d in range(8)] + return [f"efficientnet-b{d}" for d in range(8)] def get_expected_model_shape(model_name): @@ -107,11 +108,7 @@ def make_shape_cases( ret_tests.append( [ kwargs, - ( - batch, - in_channels, - ) - + (get_expected_model_shape(model),) * spatial_dim, + (batch, in_channels) + (get_expected_model_shape(model),) * spatial_dim, (batch, num_classes), ] ) @@ -245,7 +242,7 @@ def make_shape_cases( }, [1, 2, 224, 224], ([1, 32, 112, 112], [1, 56, 56, 56], [1, 88, 28, 28], [1, 248, 14, 14], [1, 704, 7, 7]), - ), + ) ] @@ -254,8 +251,12 @@ class TestEFFICIENTNET(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" - # initialize model - net = EfficientNetBN(**input_param).to(device) + try: + # initialize model + net = EfficientNetBN(**input_param).to(device) + except (ContentTooShortError, HTTPError, RuntimeError) as e: + print(str(e)) + return # skipping the tests because of http errors # run inference with random tensor with eval_mode(net): @@ -268,8 +269,12 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_non_default_shapes(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" - # initialize model - net = EfficientNetBN(**input_param).to(device) + try: + # initialize model + net = EfficientNetBN(**input_param).to(device) + except (ContentTooShortError, HTTPError, RuntimeError) as e: + print(str(e)) + return # skipping the tests because of http errors # override input shape with different variations num_dims = len(input_shape) - 2 @@ -382,8 +387,12 @@ class TestExtractFeatures(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shapes): device = "cuda" if torch.cuda.is_available() else "cpu" - # initialize model - net = EfficientNetBNFeatures(**input_param).to(device) + try: + # initialize model + net = EfficientNetBNFeatures(**input_param).to(device) + except (ContentTooShortError, HTTPError, RuntimeError) as e: + print(str(e)) + return # skipping the tests because of http errors # run inference with random tensor with eval_mode(net): diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py index 7f63cb6401..c7554e9421 100644 --- a/tests/test_ensemble_evaluator.py +++ b/tests/test_ensemble_evaluator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index 23126d326f..dd6168ec75 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -27,11 +27,7 @@ TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], -1] -TEST_CASE_3 = [ - {"image_only": False}, - ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - None, -] +TEST_CASE_3 = [{"image_only": False}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None] TEST_CASE_4 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], None] @@ -43,11 +39,7 @@ None, ] -TEST_CASE_7 = [ - {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, - "tests/testing_data/CT_DICOM", - None, -] +TEST_CASE_7 = [{"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", None] class TestEnsureChannelFirst(unittest.TestCase): diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index b4cde02a8f..9b9043e4ad 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,11 +25,7 @@ TEST_CASE_2 = [{"keys": "img"}, ["test_image.nii.gz"], -1] -TEST_CASE_3 = [ - {"keys": "img"}, - ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - None, -] +TEST_CASE_3 = [{"keys": "img"}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None] class TestEnsureChannelFirstd(unittest.TestCase): diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 8feb96ed37..f8a6ee30ff 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,9 +25,11 @@ def test_array_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "NUMPY"): - result = EnsureType(data_type=dtype)(test_data) + result = EnsureType(dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu")(test_data) + if dtype == "NUMPY": + self.assertTrue(result.dtype == np.float32) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): @@ -36,12 +38,12 @@ def test_single_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "numpy"): - result = EnsureType(data_type=dtype)(test_data) + result = EnsureType(data_type=dtype, device="cpu")(test_data) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) if isinstance(test_data, bool): self.assertFalse(result) else: - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertEqual(result.ndim, 0) def test_string(self): @@ -57,12 +59,12 @@ def test_string(self): def test_list_tuple(self): for dtype in ("tensor", "numpy"): - result = EnsureType(data_type=dtype)([[1, 2], [3, 4]]) + result = EnsureType(data_type=dtype, wrap_sequence=False)([[1, 2], [3, 4]]) self.assertTrue(isinstance(result, list)) self.assertTrue(isinstance(result[0][1], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result[1][0], torch.as_tensor(3)) # tuple of numpy arrays - result = EnsureType(data_type=dtype)((np.array([1, 2]), np.array([3, 4]))) + result = EnsureType(data_type=dtype, wrap_sequence=False)((np.array([1, 2]), np.array([3, 4]))) self.assertTrue(isinstance(result, tuple)) self.assertTrue(isinstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result[1], torch.as_tensor([3, 4])) diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 96f482afc2..cadab9bd56 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,9 +25,13 @@ def test_array_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "NUMPY"): - result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] + result = EnsureTyped( + keys="data", data_type=dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu" + )({"data": test_data})["data"] + if dtype == "NUMPY": + self.assertTrue(result.dtype == np.float32) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): @@ -41,7 +45,7 @@ def test_single_input(self): if isinstance(test_data, bool): self.assertFalse(result) else: - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) self.assertEqual(result.ndim, 0) def test_string(self): @@ -57,12 +61,14 @@ def test_string(self): def test_list_tuple(self): for dtype in ("tensor", "numpy"): - result = EnsureTyped(keys="data", data_type=dtype)({"data": [[1, 2], [3, 4]]})["data"] + result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False)({"data": [[1, 2], [3, 4]]})["data"] self.assertTrue(isinstance(result, list)) self.assertTrue(isinstance(result[0][1], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result[1][0], torch.as_tensor(3)) # tuple of numpy arrays - result = EnsureTyped(keys="data", data_type=dtype)({"data": (np.array([1, 2]), np.array([3, 4]))})["data"] + result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False)( + {"data": (np.array([1, 2]), np.array([3, 4]))} + )["data"] self.assertTrue(isinstance(result, tuple)) self.assertTrue(isinstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result[1], torch.as_tensor([3, 4])) @@ -75,7 +81,7 @@ def test_dict(self): "extra": None, } for dtype in ("tensor", "numpy"): - result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] + result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"] self.assertTrue(isinstance(result, dict)) self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result["img"], torch.as_tensor([1.0, 2.0])) diff --git a/tests/test_enum_bound_interp.py b/tests/test_enum_bound_interp.py index f788f8ba17..7607619e7a 100644 --- a/tests/test_enum_bound_interp.py +++ b/tests/test_enum_bound_interp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_eval_mode.py b/tests/test_eval_mode.py index 45c551c209..bc9c97d238 100644 --- a/tests/test_eval_mode.py +++ b/tests/test_eval_mode.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py index bf3bd1bacc..1bb3d887a0 100644 --- a/tests/test_evenly_divisible_all_gather_dist.py +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_factorized_increase.py b/tests/test_factorized_increase.py new file mode 100644 index 0000000000..a86f5a2db9 --- /dev/null +++ b/tests/test_factorized_increase.py @@ -0,0 +1,34 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import FactorizedIncreaseBlock + +TEST_CASES_3D = [ + [{"in_channel": 32, "out_channel": 16}, (7, 32, 24, 16, 8), (7, 16, 48, 32, 16)], + [{"in_channel": 1, "out_channel": 2}, (1, 1, 1, 1, 1), (1, 2, 2, 2, 2)], +] + + +class TestFactInc(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_factorized_increase_3d(self, input_param, input_shape, expected_shape): + net = FactorizedIncreaseBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_factorized_reduce.py b/tests/test_factorized_reduce.py new file mode 100644 index 0000000000..d14418233e --- /dev/null +++ b/tests/test_factorized_reduce.py @@ -0,0 +1,34 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import FactorizedReduceBlock + +TEST_CASES_3D = [ + [{"in_channel": 32, "out_channel": 16}, (7, 32, 24, 16, 8), (7, 16, 12, 8, 4)], + [{"in_channel": 16, "out_channel": 32}, (7, 16, 22, 14, 6), (7, 32, 11, 7, 3)], +] + + +class TestFactRed(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_factorized_reduce_3d(self, input_param, input_shape, expected_shape): + net = FactorizedReduceBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_fg_bg_to_indices.py b/tests/test_fg_bg_to_indices.py index 98626c7028..03eb770d6d 100644 --- a/tests/test_fg_bg_to_indices.py +++ b/tests/test_fg_bg_to_indices.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,58 +11,70 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import FgBgToIndices +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"image_threshold": 0.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - None, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] +TESTS_CASES = [] +for p in TEST_NDARRAYS: + TESTS_CASES.append( + [ + {"image_threshold": 0.0, "output_shape": None}, + p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), + None, + p([1, 2, 3, 5, 6, 7]), + p([0, 4, 8]), + ] + ) -TEST_CASE_2 = [ - {"image_threshold": 0.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS_CASES.append( + [ + {"image_threshold": 0.0, "output_shape": None}, + p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), + p([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]), + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_3 = [ - {"image_threshold": 1.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS_CASES.append( + [ + {"image_threshold": 1.0, "output_shape": None}, + p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), + p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_4 = [ - {"image_threshold": 1.0, "output_shape": None}, - np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS_CASES.append( + [ + {"image_threshold": 1.0, "output_shape": None}, + p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), + p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_5 = [ - {"image_threshold": 1.0, "output_shape": [3, 3]}, - np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), - np.array([[0, 0], [2, 2]]), -] + TESTS_CASES.append( + [ + {"image_threshold": 1.0, "output_shape": [3, 3]}, + p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), + p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), + p([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), + p([[0, 0], [2, 2]]), + ] + ) class TestFgBgToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS_CASES) def test_type_shape(self, input_data, label, image, expected_fg, expected_bg): fg_indices, bg_indices = FgBgToIndices(**input_data)(label, image) - np.testing.assert_allclose(fg_indices, expected_fg) - np.testing.assert_allclose(bg_indices, expected_bg) + assert_allclose(fg_indices, expected_fg) + assert_allclose(bg_indices, expected_bg) if __name__ == "__main__": diff --git a/tests/test_fg_bg_to_indicesd.py b/tests/test_fg_bg_to_indicesd.py index ce6ca30f1b..3be795919f 100644 --- a/tests/test_fg_bg_to_indicesd.py +++ b/tests/test_fg_bg_to_indicesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,53 +11,66 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import FgBgToIndicesd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] +TEST_CASES = [] +for p in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, + {"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 4, 8]), + ] + ) -TEST_CASE_3 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None}, + {"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": p([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_4 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, - {"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, + {"label": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) -TEST_CASE_5 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]}, - {"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), - np.array([[0, 0], [2, 2]]), -] + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, + {"label": p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, + p([1, 2, 3, 5, 6, 7]), + p([0, 8]), + ] + ) + + TEST_CASES.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]}, + {"label": p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, + p([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), + p([[0, 0], [2, 2]]), + ] + ) class TestFgBgToIndicesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TEST_CASES) def test_type_shape(self, input_data, data, expected_fg, expected_bg): result = FgBgToIndicesd(**input_data)(data) - np.testing.assert_allclose(result["label_fg_indices"], expected_fg) - np.testing.assert_allclose(result["label_bg_indices"], expected_bg) + assert_allclose(result["label_fg_indices"], expected_fg) + assert_allclose(result["label_bg_indices"], expected_bg) if __name__ == "__main__": diff --git a/tests/test_file_basename.py b/tests/test_file_basename.py index 77e77fabc5..dc8b1316a2 100644 --- a/tests/test_file_basename.py +++ b/tests/test_file_basename.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path from monai.data.utils import create_file_basename @@ -69,6 +70,15 @@ def test_value(self): expected = os.path.join(output_tmp, "test", "test_post_8") self.assertEqual(result, expected) + result = create_file_basename("post", Path("test.tar.gz"), Path(output_tmp), Path("foo"), True, 8) + expected = os.path.join(output_tmp, "test", "test_post_8") + self.assertEqual(result, expected) + + def test_relative_path(self): + output = create_file_basename("", "test.txt", "output", "", makedirs=False) + expected = os.path.join("output", "test", "test") + self.assertEqual(output, expected) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py index 6ea83c239b..9f9dc1fc2e 100644 --- a/tests/test_fill_holes.py +++ b/tests/test_fill_holes.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,31 +16,15 @@ from parameterized import parameterized from monai.transforms import FillHoles -from tests.utils import assert_allclose, clone +from tests.utils import TEST_NDARRAYS, assert_allclose, clone -grid_1_raw = [ - [1, 1, 1], - [1, 0, 1], - [1, 1, 1], -] +grid_1_raw = [[1, 1, 1], [1, 0, 1], [1, 1, 1]] -grid_2_raw = [ - [0, 1, 0], - [1, 0, 1], - [0, 1, 0], -] +grid_2_raw = [[0, 1, 0], [1, 0, 1], [0, 1, 0]] -grid_3_raw = [ - [1, 1, 1], - [1, 1, 1], - [1, 1, 1], -] +grid_3_raw = [[1, 1, 1], [1, 1, 1], [1, 1, 1]] -grid_4_raw = [ - [0, 1, 0], - [1, 1, 1], - [0, 1, 0], -] +grid_4_raw = [[0, 1, 0], [1, 1, 1], [0, 1, 0]] grid_1 = torch.tensor([grid_1_raw]) @@ -50,49 +34,15 @@ grid_4 = torch.tensor([grid_4_raw]) -grid_5 = torch.tensor( - [ - [ - [1, 1, 1], - [1, 0, 0], - [1, 1, 1], - ] - ] -) - -grid_6 = torch.tensor( - [ - [ - [1, 1, 2, 2, 2], - [1, 0, 2, 0, 2], - [1, 1, 2, 2, 2], - ] - ] -) - -grid_7 = torch.tensor( - [ - [ - [1, 1, 2, 2, 2], - [1, 0, 2, 2, 2], - [1, 1, 2, 2, 2], - ] - ] -) - -TEST_CASE_0 = [ - "enclosed_default_full_connectivity_default_applied_labels", - {}, - grid_1, - grid_3, -] +grid_5 = torch.tensor([[[1, 1, 1], [1, 0, 0], [1, 1, 1]]]) -TEST_CASE_1 = [ - "enclosed_full_connectivity_default_applied_labels", - {"connectivity": 2}, - grid_1, - grid_3, -] +grid_6 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 0, 2], [1, 1, 2, 2, 2]]]) + +grid_7 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 2, 2], [1, 1, 2, 2, 2]]]) + +TEST_CASE_0 = ["enclosed_default_full_connectivity_default_applied_labels", {}, grid_1, grid_3] + +TEST_CASE_1 = ["enclosed_full_connectivity_default_applied_labels", {"connectivity": 2}, grid_1, grid_3] TEST_CASE_2 = [ "enclosed_full_connectivity_applied_labels_same_single", @@ -129,40 +79,15 @@ grid_3, ] -TEST_CASE_7 = [ - "enclosed_connectivity_1_default_applied_labels", - {"connectivity": 1}, - grid_1, - grid_3, -] +TEST_CASE_7 = ["enclosed_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_1, grid_3] -TEST_CASE_8 = [ - "enclosed_connectivity_1_default_applied_labels", - {"connectivity": 1}, - grid_2, - grid_4, -] +TEST_CASE_8 = ["enclosed_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_2, grid_4] -TEST_CASE_9 = [ - "open_full_connectivity_default_applied_labels", - {"connectivity": 2}, - grid_2, - grid_2, -] +TEST_CASE_9 = ["open_full_connectivity_default_applied_labels", {"connectivity": 2}, grid_2, grid_2] -TEST_CASE_10 = [ - "open_to_edge_connectivity_1_default_applied_labels", - {"connectivity": 1}, - grid_5, - grid_5, -] +TEST_CASE_10 = ["open_to_edge_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_5, grid_5] -TEST_CASE_11 = [ - "open_to_other_label_connectivity_1_default_applied_labels", - {"connectivity": 1}, - grid_6, - grid_7, -] +TEST_CASE_11 = ["open_to_other_label_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_6, grid_7] TEST_CASE_12 = [ "open_to_other_label_connectivity_1_applied_labels_other", @@ -276,11 +201,9 @@ class TestFillHoles(unittest.TestCase): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, args, input_image, expected): converter = FillHoles(**args) - if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - result = converter(clone(input_image).cuda()) - else: - result = converter(clone(input_image)) - assert_allclose(result, expected) + for p in TEST_NDARRAYS: + result = converter(p(clone(input_image))) + assert_allclose(result, p(expected)) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_fill_holesd.py b/tests/test_fill_holesd.py new file mode 100644 index 0000000000..f7aa9f6108 --- /dev/null +++ b/tests/test_fill_holesd.py @@ -0,0 +1,222 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import FillHolesd +from monai.utils.enums import CommonKeys +from tests.utils import TEST_NDARRAYS, assert_allclose, clone + +grid_1_raw = [[1, 1, 1], [1, 0, 1], [1, 1, 1]] + +grid_2_raw = [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + +grid_3_raw = [[1, 1, 1], [1, 1, 1], [1, 1, 1]] + +grid_4_raw = [[0, 1, 0], [1, 1, 1], [0, 1, 0]] + +grid_1 = torch.tensor([grid_1_raw]) + +grid_2 = torch.tensor([grid_2_raw]) + +grid_3 = torch.tensor([grid_3_raw]) + +grid_4 = torch.tensor([grid_4_raw]) + +grid_5 = torch.tensor([[[1, 1, 1], [1, 0, 0], [1, 1, 1]]]) + +grid_6 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 0, 2], [1, 1, 2, 2, 2]]]) + +grid_7 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 2, 2], [1, 1, 2, 2, 2]]]) + +TEST_CASE_0 = ["enclosed_default_full_connectivity_default_applied_labels", {}, grid_1, grid_3] + +TEST_CASE_1 = ["enclosed_full_connectivity_default_applied_labels", {"connectivity": 2}, grid_1, grid_3] + +TEST_CASE_2 = [ + "enclosed_full_connectivity_applied_labels_same_single", + {"connectivity": 2, "applied_labels": 1}, + grid_1, + grid_3, +] + +TEST_CASE_3 = [ + "enclosed_full_connectivity_applied_labels_same_list", + {"connectivity": 2, "applied_labels": [1]}, + grid_1, + grid_3, +] + +TEST_CASE_4 = [ + "enclosed_full_connectivity_applied_labels_other_single", + {"connectivity": 2, "applied_labels": 2}, + grid_1, + grid_1, +] + +TEST_CASE_5 = [ + "enclosed_full_connectivity_applied_labels_other_list", + {"connectivity": 2, "applied_labels": [2]}, + grid_1, + grid_1, +] + +TEST_CASE_6 = [ + "enclosed_full_connectivity_applied_labels_same_and_other", + {"connectivity": 2, "applied_labels": [1, 2]}, + grid_1, + grid_3, +] + +TEST_CASE_7 = ["enclosed_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_1, grid_3] + +TEST_CASE_8 = ["enclosed_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_2, grid_4] + +TEST_CASE_9 = ["open_full_connectivity_default_applied_labels", {"connectivity": 2}, grid_2, grid_2] + +TEST_CASE_10 = ["open_to_edge_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_5, grid_5] + +TEST_CASE_11 = ["open_to_other_label_connectivity_1_default_applied_labels", {"connectivity": 1}, grid_6, grid_7] + +TEST_CASE_12 = [ + "open_to_other_label_connectivity_1_applied_labels_other", + {"connectivity": 1, "applied_labels": 1}, + grid_6, + grid_6, +] + +TEST_CASE_13 = [ + "numpy_enclosed_default_full_connectivity_default_applied_labels", + {}, + grid_1.cpu().numpy(), + grid_3.cpu().numpy(), +] + +TEST_CASE_14 = [ + "3D_enclosed_full_connectivity_default_applied_labels", + {"connectivity": 3}, + torch.tensor([[grid_3_raw, grid_1_raw, grid_3_raw]]), + torch.tensor([[grid_3_raw, grid_3_raw, grid_3_raw]]), +] + +TEST_CASE_15 = [ + "3D_enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), + torch.tensor([[grid_4_raw, grid_4_raw, grid_4_raw]]), +] + +TEST_CASE_16 = [ + "3D_open_full_connectivity_default_applied_labels", + {"connectivity": 3}, + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), +] + +TEST_CASE_17 = [ + "3D_open_to_edge_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]), + torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]), +] + +TEST_CASE_18 = [ + "enclosed_full_connectivity_applied_labels_with_background", + {"connectivity": 2, "applied_labels": [0, 1]}, + grid_1, + grid_3, +] + +TEST_CASE_19 = [ + "enclosed_full_connectivity_applied_labels_only_background", + {"connectivity": 2, "applied_labels": [0]}, + grid_1, + grid_1, +] + +TEST_CASE_20 = [ + "one-hot_enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_3_raw, grid_4_raw]), +] + +TEST_CASE_21 = [ + "one-hot_enclosed_connectivity_1_applied_labels_2", + {"connectivity": 1, "applied_labels": [2]}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_1_raw, grid_4_raw]), +] + +TEST_CASE_22 = [ + "one-hot_full_connectivity_applied_labels_2", + {"connectivity": 2}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_3_raw, grid_2_raw]), +] + +VALID_CASES = [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + TEST_CASE_11, + TEST_CASE_12, + TEST_CASE_13, + TEST_CASE_14, + TEST_CASE_15, + TEST_CASE_16, + TEST_CASE_17, + TEST_CASE_18, + TEST_CASE_19, + TEST_CASE_20, + TEST_CASE_21, + TEST_CASE_22, +] + +ITEST_CASE_1 = ["invalid_image_data_type", {}, [[[[1, 1, 1]]]], NotImplementedError] + +INVALID_CASES = [ITEST_CASE_1] + + +class TestFillHoles(unittest.TestCase): + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, args, input_image, expected): + key = CommonKeys.IMAGE + converter = FillHolesd(keys=key, **args) + for p in TEST_NDARRAYS: + result = converter({key: p(clone(input_image))})[key] + assert_allclose(result, p(expected)) + + @parameterized.expand(INVALID_CASES) + def test_raise_exception(self, _, args, input_image, expected_error): + key = CommonKeys.IMAGE + with self.assertRaises(expected_error): + converter = FillHolesd(keys=key, **args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + _ = converter({key: clone(input_image).cuda()})[key] + else: + _ = converter({key: clone(input_image)})[key] + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_flip.py b/tests/test_flip.py index 404a3def7d..17cf0d2c39 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,12 +34,10 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: im = p(self.imt[0]) flip = Flip(spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) result = flip(im) - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_flipd.py b/tests/test_flipd.py index 1676723800..900779f4e0 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,12 +33,10 @@ def test_invalid_cases(self, _, spatial_axis, raises): def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: flip = Flipd(keys="img", spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) result = flip({"img": p(self.imt[0])})["img"] - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 1314fe3841..d8a9c8ab5b 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -44,6 +44,35 @@ def test_consistency_with_cross_entropy_2d(self): max_error = abs(a - b) self.assertAlmostEqual(max_error, 0.0, places=3) + def test_consistency_with_cross_entropy_2d_no_reduction(self): + """For gamma=0 the focal loss reduces to the cross entropy loss""" + import numpy as np + + focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="none", weight=1.0) + ce = nn.BCEWithLogitsLoss(reduction="none") + max_error = 0 + class_num = 10 + batch_size = 128 + for _ in range(100): + # Create a random tensor of shape (batch_size, class_num, 8, 4) + x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True) + # Create a random batch of classes + l = torch.randint(low=0, high=2, size=(batch_size, class_num, 8, 4)).float() + if torch.cuda.is_available(): + x = x.cuda() + l = l.cuda() + output0 = focal_loss(x, l) + output1 = ce(x, l) + a = output0.cpu().detach().numpy() + b = output1.cpu().detach().numpy() + error = np.abs(a - b) + max_error = np.maximum(error, max_error) + # if np.all(np.abs(a - b) > max_error): + # max_error = np.abs(a - b) + + assert np.allclose(max_error, 0) + # self.assertAlmostEqual(max_error, 0.0, places=3) + def test_consistency_with_cross_entropy_2d_onehot_label(self): """For gamma=0 the focal loss reduces to the cross entropy loss""" focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean") diff --git a/tests/test_fourier.py b/tests/test_fourier.py index 488bf0cbf9..e1b5f3089d 100644 --- a/tests/test_fourier.py +++ b/tests/test_fourier.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_fullyconnectednet.py b/tests/test_fullyconnectednet.py index ec91a99c3e..6378ec9718 100644 --- a/tests/test_fullyconnectednet.py +++ b/tests/test_fullyconnectednet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_gaussian.py b/tests/test_gaussian.py index e2659abb0c..461b11d076 100644 --- a/tests/test_gaussian.py +++ b/tests/test_gaussian.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -241,16 +241,8 @@ def test_gaussian(self): rtol=1e-4, ) - np.testing.assert_allclose( - gaussian_1d(1, 1), - torch.tensor([0.24173, 0.382925, 0.24173]), - rtol=1e-4, - ) - np.testing.assert_allclose( - gaussian_1d(1, 1, normalize=True), - torch.tensor([0.2790, 0.4420, 0.2790]), - rtol=1e-4, - ) + np.testing.assert_allclose(gaussian_1d(1, 1), torch.tensor([0.24173, 0.382925, 0.24173]), rtol=1e-4) + np.testing.assert_allclose(gaussian_1d(1, 1, normalize=True), torch.tensor([0.2790, 0.4420, 0.2790]), rtol=1e-4) def test_scalespace_gaussian(self): np.testing.assert_allclose( @@ -272,15 +264,11 @@ def test_scalespace_gaussian(self): ) np.testing.assert_allclose( - gaussian_1d(1, 1, "scalespace"), - torch.tensor([0.20791, 0.46576, 0.20791]), - rtol=1e-3, + gaussian_1d(1, 1, "scalespace"), torch.tensor([0.20791, 0.46576, 0.20791]), rtol=1e-3 ) np.testing.assert_allclose( - gaussian_1d(1, 1, "scalespace", normalize=True), - torch.tensor([0.2358, 0.5283, 0.2358]), - rtol=1e-3, + gaussian_1d(1, 1, "scalespace", normalize=True), torch.tensor([0.2358, 0.5283, 0.2358]), rtol=1e-3 ) np.testing.assert_allclose( diff --git a/tests/test_gaussian_filter.py b/tests/test_gaussian_filter.py index 7636aa5459..9d76e44cec 100644 --- a/tests/test_gaussian_filter.py +++ b/tests/test_gaussian_filter.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,10 +19,7 @@ from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick TEST_CASES = [[{"type": "erf", "gt": 2.0}], [{"type": "scalespace", "gt": 3.0}], [{"type": "sampled", "gt": 5.0}]] -TEST_CASES_GPU = [ - [{"type": "erf", "gt": 0.8, "device": "cuda"}], - [{"type": "sampled", "gt": 5.0, "device": "cuda"}], -] +TEST_CASES_GPU = [[{"type": "erf", "gt": 0.8, "device": "cuda"}], [{"type": "sampled", "gt": 5.0, "device": "cuda"}]] TEST_CASES_3d = [ [{"type": "scalespace", "gt": 0.5, "dims": (2, 3, 8, 9, 10), "lr": 0.01, "device": "cuda"}], [{"type": "erf", "gt": 3.8, "dims": (2, 3, 8, 9, 10), "lr": 0.1, "device": "cuda"}], diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py index 9d078e65e5..547febdfaf 100644 --- a/tests/test_gaussian_sharpen.py +++ b/tests/test_gaussian_sharpen.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,50 +11,79 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import GaussianSharpen +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] + +for p in TEST_NDARRAYS: + TESTS.append( [ - [[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]], - [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + {}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.1081963, 3.4950666, 4.1081963], + [3.7239995, 2.8491793, 3.7239995], + [4.569839, 3.9529324, 4.569839], + ], + [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]], - [[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]], + {"sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.513644, 4.869134, 4.513644], + [8.467242, 9.4004135, 8.467242], + [10.416813, 12.0653515, 10.416813], + ], + [ + [15.711488, 17.569994, 15.711488], + [21.16811, 23.501041, 21.16811], + [21.614658, 24.766209, 21.614658], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]], - [[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]], + {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [3.3324685, 3.335536, 3.3324673], + [7.7666636, 8.16056, 7.7666636], + [12.662973, 14.317837, 12.6629715], + ], + [ + [15.329051, 16.57557, 15.329051], + [19.41665, 20.40139, 19.416655], + [24.659554, 27.557873, 24.659554], + ], + ] + ), ] - ), -] + ) class TestGaussianSharpen(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSharpen(**argments)(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + assert_allclose(result, expected_data, atol=0, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_gaussian_sharpend.py b/tests/test_gaussian_sharpend.py index c795b11762..d9ef503532 100644 --- a/tests/test_gaussian_sharpend.py +++ b/tests/test_gaussian_sharpend.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,46 +15,75 @@ from parameterized import parameterized from monai.transforms import GaussianSharpend +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "img"}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]], - [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + {"keys": "img"}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [4.1081963, 3.4950666, 4.1081963], + [3.7239995, 2.8491793, 3.7239995], + [4.569839, 3.9529324, 4.569839], + ], + [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"keys": "img", "sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]], - [[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]], + {"keys": "img", "sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [4.513644, 4.869134, 4.513644], + [8.467242, 9.4004135, 8.467242], + [10.416813, 12.0653515, 10.416813], + ], + [ + [15.711488, 17.569994, 15.711488], + [21.16811, 23.501041, 21.16811], + [21.614658, 24.766209, 21.614658], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"keys": "img", "sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]], - [[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]], + {"keys": "img", "sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [3.3324685, 3.335536, 3.3324673], + [7.7666636, 8.16056, 7.7666636], + [12.662973, 14.317837, 12.6629715], + ], + [ + [15.329051, 16.57557, 15.329051], + [19.41665, 20.40139, 19.416655], + [24.659554, 27.557873, 24.659554], + ], + ] + ), ] - ), -] + ) class TestGaussianSharpend(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSharpend(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data, rtol=1e-4) + assert_allclose(result["img"], expected_data, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index e51977fbee..53f2fc396b 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,54 +11,83 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import GaussianSmooth +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"sigma": 1.5}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] + +for p in TEST_NDARRAYS: + TESTS.append( [ - [ - [0.59167546, 0.69312394, 0.59167546], - [0.7956997, 0.93213004, 0.7956997], - [0.7668002, 0.8982755, 0.7668002], - ], - [[1.6105323, 1.8866735, 1.6105323], [1.9892492, 2.3303251, 1.9892492], [1.7856569, 2.091825, 1.7856569]], + {"sigma": 1.5}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [0.59167546, 0.69312394, 0.59167546], + [0.7956997, 0.93213004, 0.7956997], + [0.7668002, 0.8982755, 0.7668002], + ], + [ + [1.6105323, 1.8866735, 1.6105323], + [1.9892492, 2.3303251, 1.9892492], + [1.7856569, 2.091825, 1.7856569], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"sigma": 0.5}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.8424794, 0.99864554, 0.8424794], [1.678146, 1.9892154, 1.678146], [1.9889624, 2.3576462, 1.9889624]], - [[2.966061, 3.5158648, 2.966061], [4.1953645, 4.973038, 4.1953645], [4.112544, 4.8748655, 4.1125436]], + {"sigma": 0.5}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [0.8424794, 0.99864554, 0.8424794], + [1.678146, 1.9892154, 1.678146], + [1.9889624, 2.3576462, 1.9889624], + ], + [ + [2.966061, 3.5158648, 2.966061], + [4.1953645, 4.973038, 4.1953645], + [4.112544, 4.8748655, 4.1125436], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"sigma": [1.5, 0.5]}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.8542037, 1.0125432, 0.8542037], [1.1487541, 1.3616928, 1.1487541], [1.1070318, 1.3122368, 1.1070318]], - [[2.3251305, 2.756128, 2.3251305], [2.8718853, 3.4042323, 2.8718853], [2.5779586, 3.0558217, 2.5779586]], + {"sigma": [1.5, 0.5]}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [0.8542037, 1.0125432, 0.8542037], + [1.1487541, 1.3616928, 1.1487541], + [1.1070318, 1.3122368, 1.1070318], + ], + [ + [2.3251305, 2.756128, 2.3251305], + [2.8718853, 3.4042323, 2.8718853], + [2.5779586, 3.0558217, 2.5779586], + ], + ] + ), ] - ), -] + ) class TestGaussianSmooth(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSmooth(**argments)(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + assert_allclose(result, expected_data, atol=0, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_gaussian_smoothd.py b/tests/test_gaussian_smoothd.py index 3d7eb6195e..839bac81fe 100644 --- a/tests/test_gaussian_smoothd.py +++ b/tests/test_gaussian_smoothd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,50 +15,79 @@ from parameterized import parameterized from monai.transforms import GaussianSmoothd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "img", "sigma": 1.5}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [ - [0.59167546, 0.69312394, 0.59167546], - [0.7956997, 0.93213004, 0.7956997], - [0.7668002, 0.8982755, 0.7668002], - ], - [[1.6105323, 1.8866735, 1.6105323], [1.9892492, 2.3303251, 1.9892492], [1.7856569, 2.091825, 1.7856569]], + {"keys": "img", "sigma": 1.5}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.59167546, 0.69312394, 0.59167546], + [0.7956997, 0.93213004, 0.7956997], + [0.7668002, 0.8982755, 0.7668002], + ], + [ + [1.6105323, 1.8866735, 1.6105323], + [1.9892492, 2.3303251, 1.9892492], + [1.7856569, 2.091825, 1.7856569], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"keys": "img", "sigma": 0.5}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[0.8424794, 0.99864554, 0.8424794], [1.678146, 1.9892154, 1.678146], [1.9889624, 2.3576462, 1.9889624]], - [[2.966061, 3.5158648, 2.966061], [4.1953645, 4.973038, 4.1953645], [4.112544, 4.8748655, 4.1125436]], + {"keys": "img", "sigma": 0.5}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.8424794, 0.99864554, 0.8424794], + [1.678146, 1.9892154, 1.678146], + [1.9889624, 2.3576462, 1.9889624], + ], + [ + [2.966061, 3.5158648, 2.966061], + [4.1953645, 4.973038, 4.1953645], + [4.112544, 4.8748655, 4.1125436], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"keys": "img", "sigma": [1.5, 0.5]}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[0.8542037, 1.0125432, 0.8542037], [1.1487541, 1.3616928, 1.1487541], [1.1070318, 1.3122368, 1.1070318]], - [[2.3251305, 2.756128, 2.3251305], [2.8718853, 3.4042323, 2.8718853], [2.5779586, 3.0558217, 2.5779586]], + {"keys": "img", "sigma": [1.5, 0.5]}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.8542037, 1.0125432, 0.8542037], + [1.1487541, 1.3616928, 1.1487541], + [1.1070318, 1.3122368, 1.1070318], + ], + [ + [2.3251305, 2.756128, 2.3251305], + [2.8718853, 3.4042323, 2.8718853], + [2.5779586, 3.0558217, 2.5779586], + ], + ] + ), ] - ), -] + ) class TestGaussianSmoothd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSmoothd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data, rtol=1e-4) + assert_allclose(result["img"], expected_data, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 06446204fb..f93878a683 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,10 +21,7 @@ TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) @@ -87,7 +84,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - [0.273476, 0.555539], + [[[0.273476]], [[0.555539]]], ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8}, @@ -99,10 +96,7 @@ ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (1, 2, 4), (1, 1, 4) diff --git a/tests/test_generalized_wasserstein_dice_loss.py b/tests/test_generalized_wasserstein_dice_loss.py index 295a4a6d70..2c33d365f4 100644 --- a/tests/test_generalized_wasserstein_dice_loss.py +++ b/tests/test_generalized_wasserstein_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -159,7 +159,7 @@ def test_convergence(self): # define a model with one layer class OnelayerNet(nn.Module): def __init__(self): - super(OnelayerNet, self).__init__() + super().__init__() self.layer = nn.Linear(num_voxels, num_voxels * num_classes) def forward(self, x): diff --git a/tests/test_generate_label_classes_crop_centers.py b/tests/test_generate_label_classes_crop_centers.py index 38f2a3e0d1..4f64aadc26 100644 --- a/tests/test_generate_label_classes_crop_centers.py +++ b/tests/test_generate_label_classes_crop_centers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,11 +10,13 @@ # limitations under the License. import unittest +from copy import deepcopy -import numpy as np from parameterized import parameterized from monai.transforms import generate_label_classes_crop_centers +from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ { @@ -23,7 +25,6 @@ "ratios": [1, 2], "label_spatial_shape": [3, 3, 3], "indices": [[3, 12, 21], [1, 9, 18]], - "rand_state": np.random.RandomState(), }, list, 2, @@ -37,7 +38,6 @@ "ratios": None, "label_spatial_shape": [3, 3, 3], "indices": [[3, 12, 21], [1, 9, 18]], - "rand_state": np.random.RandomState(), }, list, 1, @@ -48,10 +48,21 @@ class TestGenerateLabelClassesCropCenters(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): - result = generate_label_classes_crop_centers(**input_data) - self.assertIsInstance(result, expected_type) - self.assertEqual(len(result), expected_count) - self.assertEqual(len(result[0]), expected_shape) + results = [] + for p in TEST_NDARRAYS + (None,): + input_data = deepcopy(input_data) + if p is not None: + input_data["indices"] = p(input_data["indices"]) + set_determinism(0) + result = generate_label_classes_crop_centers(**input_data) + self.assertIsInstance(result, expected_type) + self.assertEqual(len(result), expected_count) + self.assertEqual(len(result[0]), expected_shape) + # check for consistency between numpy, torch and torch.cuda + results.append(result) + if len(results) > 1: + for x, y in zip(result[0], result[-1]): + assert_allclose(x, y, type_test=False) if __name__ == "__main__": diff --git a/tests/test_generate_param_groups.py b/tests/test_generate_param_groups.py index ea1fad44f9..0b259442ea 100644 --- a/tests/test_generate_param_groups.py +++ b/tests/test_generate_param_groups.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,15 +18,7 @@ from monai.optimizers import generate_param_groups from monai.utils import ensure_tuple -TEST_CASE_1 = [ - { - "layer_matches": [lambda x: x.model[-1]], - "match_types": "select", - "lr_values": [1], - }, - (1, 100), - [5, 21], -] +TEST_CASE_1 = [{"layer_matches": [lambda x: x.model[-1]], "match_types": "select", "lr_values": [1]}, (1, 100), [5, 21]] TEST_CASE_2 = [ { @@ -39,11 +31,7 @@ ] TEST_CASE_3 = [ - { - "layer_matches": [lambda x: x.model[2][1].conv[0].conv], - "match_types": ["select"], - "lr_values": [1], - }, + {"layer_matches": [lambda x: x.model[2][1].conv[0].conv], "match_types": ["select"], "lr_values": [1]}, (1, 100), [2, 24], ] @@ -59,12 +47,7 @@ ] TEST_CASE_5 = [ - { - "layer_matches": [lambda x: x.model[-1]], - "match_types": ["select"], - "lr_values": [1], - "include_others": False, - }, + {"layer_matches": [lambda x: x.model[-1]], "match_types": ["select"], "lr_values": [1], "include_others": False}, (1), [5], ] @@ -86,12 +69,7 @@ class TestGenerateParamGroups(unittest.TestCase): def test_lr_values(self, input_param, expected_values, expected_groups): device = "cuda" if torch.cuda.is_available() else "cpu" net = Unet( - dimensions=3, - in_channels=1, - out_channels=3, - channels=(16, 32, 64), - strides=(2, 2), - num_res_units=1, + spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1 ).to(device) params = generate_param_groups(network=net, **input_param) @@ -107,12 +85,7 @@ def test_wrong(self): """overlapped""" device = "cuda" if torch.cuda.is_available() else "cpu" net = Unet( - dimensions=3, - in_channels=1, - out_channels=3, - channels=(16, 32, 64), - strides=(2, 2), - num_res_units=1, + spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1 ).to(device) params = generate_param_groups( diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py index 40181aa9ea..e1d9398fe3 100644 --- a/tests/test_generate_pos_neg_label_crop_centers.py +++ b/tests/test_generate_pos_neg_label_crop_centers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,35 +10,52 @@ # limitations under the License. import unittest +from copy import deepcopy -import numpy as np from parameterized import parameterized from monai.transforms import generate_pos_neg_label_crop_centers - -TEST_CASE_1 = [ - { - "spatial_size": [2, 2, 2], - "num_samples": 2, - "pos_ratio": 1.0, - "label_spatial_shape": [3, 3, 3], - "fg_indices": [1, 9, 18], - "bg_indices": [3, 12, 21], - "rand_state": np.random.RandomState(), - }, - list, - 2, - 3, -] +from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +TESTS.append( + [ + { + "spatial_size": [2, 2, 2], + "num_samples": 2, + "pos_ratio": 1.0, + "label_spatial_shape": [3, 3, 3], + "fg_indices": [1, 9, 18], + "bg_indices": [3, 12, 21], + }, + list, + 2, + 3, + ] +) class TestGeneratePosNegLabelCropCenters(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): - result = generate_pos_neg_label_crop_centers(**input_data) - self.assertIsInstance(result, expected_type) - self.assertEqual(len(result), expected_count) - self.assertEqual(len(result[0]), expected_shape) + results = [] + for p in TEST_NDARRAYS + (None,): + input_data = deepcopy(input_data) + if p is not None: + for k in ["fg_indices", "bg_indices"]: + input_data[k] = p(input_data[k]) + set_determinism(0) + result = generate_pos_neg_label_crop_centers(**input_data) + self.assertIsInstance(result, expected_type) + self.assertEqual(len(result), expected_count) + self.assertEqual(len(result[0]), expected_shape) + # check for consistency between numpy, torch and torch.cuda + results.append(result) + if len(results) > 1: + # compare every crop center + for x, y in zip(results[0], results[-1]): + assert_allclose(x, y, type_test=False) if __name__ == "__main__": diff --git a/tests/test_generate_spatial_bounding_box.py b/tests/test_generate_spatial_bounding_box.py index 32a45d8d1c..00c9d49724 100644 --- a/tests/test_generate_spatial_bounding_box.py +++ b/tests/test_generate_spatial_bounding_box.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,60 +15,79 @@ from parameterized import parameterized from monai.transforms import generate_spatial_bounding_box +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - }, - ([1, 1], [4, 4]), -] - -TEST_CASE_2 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 1, - "channel_indices": None, - "margin": 0, - }, - ([2, 2], [3, 3]), -] - -TEST_CASE_3 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": 0, - "margin": 0, - }, - ([1, 1], [4, 4]), -] - -TEST_CASE_4 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 1, - }, - ([0, 0], [4, 5]), -] - -TEST_CASE_5 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": [2, 1], - }, - ([0, 0], [5, 5]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + }, + ([1, 1], [4, 4]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 1, + "channel_indices": None, + "margin": 0, + }, + ([2, 2], [3, 3]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": 0, + "margin": 0, + }, + ([1, 1], [4, 4]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 1, + }, + ([0, 0], [4, 5]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 1], + }, + ([0, 0], [5, 5]), + ] + ) class TestGenerateSpatialBoundingBox(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_data, expected_box): result = generate_spatial_bounding_box(**input_data) self.assertTupleEqual(result, expected_box) diff --git a/tests/test_generator.py b/tests/test_generator.py index b5d846febc..617655f86e 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_get_equivalent_dtype.py b/tests/test_get_equivalent_dtype.py index 04ba5ae5fb..fc0867523d 100644 --- a/tests/test_get_equivalent_dtype.py +++ b/tests/test_get_equivalent_dtype.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,6 +32,14 @@ def test_get_equivalent_dtype(self, im, input_dtype): out_dtype = get_equivalent_dtype(input_dtype, type(im)) self.assertEqual(out_dtype, im.dtype) + def test_native_type(self): + """the get_equivalent_dtype currently doesn't change the build-in type""" + n_type = [float, int, bool] + for n in n_type: + for im_dtype in DTYPES: + out_dtype = get_equivalent_dtype(n, type(im_dtype)) + self.assertEqual(out_dtype, n) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_get_extreme_points.py b/tests/test_get_extreme_points.py index a334c12415..457351b98c 100644 --- a/tests/test_get_extreme_points.py +++ b/tests/test_get_extreme_points.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,30 +15,37 @@ from parameterized import parameterized from monai.transforms import get_extreme_points - -TEST_CASE_1 = [ - { - "img": np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]), - "rand_state": np.random, - "background": 0, - "pert": 0.0, - }, - [(0, 1), (3, 0), (3, 0), (1, 2)], -] - -TEST_CASE_2 = [ - { - "img": np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]), - "rand_state": np.random, - "background": 0, - "pert": 0.0, - }, - [(0, 1), (3, 1), (1, 0), (1, 2)], -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p(np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]])), + "rand_state": np.random, + "background": 0, + "pert": 0.0, + }, + [(0, 1), (3, 0), (3, 0), (1, 2)], + ] + ) + + TESTS.append( + [ + { + "img": p(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]])), + "rand_state": np.random, + "background": 0, + "pert": 0.0, + }, + [(0, 1), (3, 1), (1, 0), (1, 2)], + ] + ) class TestGetExtremePoints(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected): result = get_extreme_points(**input_data) self.assertEqual(result, expected) diff --git a/tests/test_get_layers.py b/tests/test_get_layers.py index e6ea810a6b..6109052d1f 100644 --- a/tests/test_get_layers.py +++ b/tests/test_get_layers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_get_package_version.py b/tests/test_get_package_version.py index beddb340ab..c4e15c9d09 100644 --- a/tests/test_get_package_version.py +++ b/tests/test_get_package_version.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index 264e2e630a..3fbe047944 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,17 +19,17 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestGibbsNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -39,36 +39,39 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return input_type(im) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 0.8 - t = GibbsNoise(alpha, as_tensor_output) + t = GibbsNoise(alpha) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) - np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) + self.assertEqual(type(out1), type(im)) + if isinstance(out1, torch.Tensor): + self.assertEqual(out1.device, im.device) + torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) + self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoise(alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(im, out, atol=1e-2) + torch.testing.assert_allclose(im, out, atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoise(alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(0 * im, out) + torch.testing.assert_allclose(0 * im, out, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index 558556489a..4905300703 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,19 +19,18 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) - + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) KEYS = ["im", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestGibbsNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -41,49 +40,56 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) + return {k: input_type(deepcopy(v)) for k, v in zip(KEYS, ims)} @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 0.8 - t = GibbsNoised(KEYS, alpha, as_tensor_output) + t = GibbsNoised(KEYS, alpha) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0) + self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() np.testing.assert_allclose(data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(0 * data[k], out[k]) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(0.0 * data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_dict_matches(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) + torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index 3373b59621..1e97122d08 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -8,86 +8,98 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os import unittest import numpy as np import torch -from parameterized import parameterized +from monai import transforms +from monai.apps import download_url from monai.losses.image_dissimilarity import GlobalMutualInformationLoss +from tests.utils import SkipIfBeforePyTorchVersion device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASES = [ - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] - .expand(1, 3, 3, 3, 3) - .div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] - .expand(1, 3, 3, 3, 3) - .div(3), - }, - -1.0986018, - ], - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] - .expand(1, 3, 3, 3, 3) - .div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] - .expand(1, 3, 3, 3, 3) - .div(3) - ** 2, - }, - -1.083999, - ], - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None].expand(1, 3, 3, 3).div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None] - .expand(1, 3, 3, 3) - .div(3) - ** 2, - }, - -1.083999, - ], - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None].expand(1, 3, 3).div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None].expand(1, 3, 3).div(3) ** 2, - }, - -1.083999, - ], - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :].div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :].div(3) ** 2, - }, - -1.083999, +FILE_URL = "https://drive.google.com/uc?id=17tsDLvG_GZm7a4fCVMCv-KyDx0hqq1ji" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + "mri.nii") + +EXPECTED_VALUE = { + "xyz_translation": [ + -1.5860259532928467, + -0.5957175493240356, + -0.3855515122413635, + -0.28728482127189636, + -0.23416118323802948, + -0.19534644484519958, + -0.17001715302467346, + -0.15043553709983826, + -0.1366637945175171, + -0.12534910440444946, ], - [ - {}, - { - "pred": torch.arange(0, 3, dtype=torch.float, device=device).div(3), - "target": torch.arange(0, 3, dtype=torch.float, device=device).div(3) ** 2, - }, - -1.1920927e-07, + "xyz_rotation": [ + -1.5860259532928467, + -0.29977330565452576, + -0.18411292135715485, + -0.1582011878490448, + -0.16107326745986938, + -0.165723517537117, + -0.1970357596874237, + -0.1755618453025818, + -0.17100191116333008, + -0.17264796793460846, ], -] +} class TestGlobalMutualInformationLoss(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_shape(self, input_param, input_data, expected_val): - result = GlobalMutualInformationLoss(**input_param).forward(**input_data) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) + def setUp(self): + download_url(FILE_URL, FILE_PATH) + + @SkipIfBeforePyTorchVersion((1, 9)) + def test_bspline(self): + loss_fn = GlobalMutualInformationLoss(kernel_type="b-spline", num_bins=32, sigma_ratio=0.015) + + transform_params_dict = { + "xyz_translation": [(i, i, i) for i in range(10)], + "xyz_rotation": [(np.pi / 100 * i, np.pi / 100 * i, np.pi / 100 * i) for i in range(10)], + } + + def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.0)): + """ + Read and transform Prostate_T2W_AX_1.nii + Args: + translate_params: a tuple of 3 floats, translation is in pixel/voxel relative to the center of the input + image. Defaults to no translation. + rotate_params: a rotation angle in radians, a tuple of 3 floats for 3D. + Defaults to no rotation. + Returns: + numpy array of shape HWD + """ + transform_list = [ + transforms.LoadImaged(keys="img"), + transforms.Affined( + keys="img", translate_params=translate_params, rotate_params=rotate_params, device=None + ), + transforms.NormalizeIntensityd(keys=["img"]), + ] + transformation = transforms.Compose(transform_list) + return transformation({"img": FILE_PATH})["img"] + + a1 = transformation() + a1 = torch.tensor(a1).unsqueeze(0).unsqueeze(0).to(device) + + for mode in transform_params_dict.keys(): + transform_params_list = transform_params_dict[mode] + expected_value_list = EXPECTED_VALUE[mode] + for transform_params, expected_value in zip(transform_params_list, expected_value_list): + a2 = transformation( + translate_params=transform_params if "translation" in mode else (0.0, 0.0, 0.0), + rotate_params=transform_params if "rotation" in mode else (0.0, 0.0, 0.0), + ) + a2 = torch.tensor(a2).unsqueeze(0).unsqueeze(0).to(device) + result = loss_fn(a2, a1).detach().cpu().numpy() + np.testing.assert_allclose(result, expected_value, rtol=1e-3, atol=5e-3) def test_ill_shape(self): loss = GlobalMutualInformationLoss() diff --git a/tests/test_globalnet.py b/tests/test_globalnet.py index 32bc58f610..ef0209e397 100644 --- a/tests/test_globalnet.py +++ b/tests/test_globalnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 0e2401b452..f085dd916c 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -47,12 +47,12 @@ # Batch 0 [ # Channel 0 - [1, -1, 0, -1, 1], + [1, -1, 0, -1, 1] ], # Batch 1 [ # Channel 0 - [1, 1, 0, 0, -1], + [1, 1, 0, 0, -1] ], ], # Expected @@ -94,15 +94,15 @@ [0.7, 0.9, 0.0, 0.0, 0.0], # Channel 4 [0.2, 0.1, 0.2, 0.2, 0.1], - ], + ] ], # Labels [ # Batch 0 [ # Channel 0 - [0, 0, -1, 1, 1], - ], + [0, 0, -1, 1, 1] + ] ], # Expected [ @@ -112,7 +112,7 @@ [1, 1, 0, 0, 0], # Channel 1 [0, 0, 1, 1, 1], - ], + ] ], ], [ @@ -142,21 +142,15 @@ [0.4, 0.5, 0.0, 0.0, 0.0], [0.7, 0.6, 0.0, 0.0, 0.0], ], - ], + ] ], # Labels [ # Batch 0 [ # Channel 0 - [ - [-1, 1, -1, 0, -1], - [1, -1, -1, -1, -1], - [-1, -1, 0, -1, -1], - [2, 2, -1, 3, -1], - [-1, -1, -1, -1, 3], - ], - ], + [[-1, 1, -1, 0, -1], [1, -1, -1, -1, -1], [-1, -1, 0, -1, -1], [2, 2, -1, 3, -1], [-1, -1, -1, -1, 3]] + ] ], # Expected [ @@ -194,7 +188,7 @@ [0.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0], ], - ], + ] ], ], [ @@ -211,25 +205,13 @@ # Channel 0 [ # Slice 0 - [ - [0.7, 0.6, 0.0], - [0.5, 0.4, 0.0], - [0.0, 0.0, 0.0], - ], + [[0.7, 0.6, 0.0], [0.5, 0.4, 0.0], [0.0, 0.0, 0.0]], # Slice 1 - [ - [0.5, 0.6, 0.0], - [0.4, 0.3, 0.0], - [0.0, 0.0, 0.0], - ], + [[0.5, 0.6, 0.0], [0.4, 0.3, 0.0], [0.0, 0.0, 0.0]], # Slice 2 - [ - [0.3, 0.3, 0.0], - [0.2, 0.1, 0.0], - [0.0, 0.0, 0.0], - ], - ], - ], + [[0.3, 0.3, 0.0], [0.2, 0.1, 0.0], [0.0, 0.0, 0.0]], + ] + ] ], # Labels [ @@ -238,25 +220,13 @@ # Channel 0 [ # Slice 0 - [ - [0, -1, -1], - [0, -1, -1], - [-1, -1, 1], - ], + [[0, -1, -1], [0, -1, -1], [-1, -1, 1]], # Slice 1 - [ - [0, 0, -1], - [-1, -1, 1], - [-1, 1, 1], - ], + [[0, 0, -1], [-1, -1, 1], [-1, 1, 1]], # Slice 2 - [ - [0, -1, -1], - [-1, -1, -1], - [-1, -1, -1], - ], - ], - ], + [[0, -1, -1], [-1, -1, -1], [-1, -1, -1]], + ] + ] ], # Expected [ @@ -265,46 +235,22 @@ # Channel 0 [ # Slice 0 - [ - [1.0, 1.0, 0.0], - [1.0, 1.0, 0.0], - [0.0, 0.0, 0.0], - ], + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], # Slice 1 - [ - [1.0, 1.0, 0.0], - [1.0, 1.0, 0.0], - [0.0, 0.0, 0.0], - ], + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], # Slice 2 - [ - [1.0, 1.0, 0.0], - [1.0, 1.0, 0.0], - [0.0, 0.0, 0.0], - ], + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], ], # Channel 1 [ # Slice 0 - [ - [0.0, 0.0, 1.0], - [0.0, 0.0, 1.0], - [1.0, 1.0, 1.0], - ], + [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], # Slice 1 - [ - [0.0, 0.0, 1.0], - [0.0, 0.0, 1.0], - [1.0, 1.0, 1.0], - ], + [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], # Slice 2 - [ - [0.0, 0.0, 1.0], - [0.0, 0.0, 1.0], - [1.0, 1.0, 1.0], - ], + [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], ], - ], + ] ], ], ] diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 6e0aa4023e..9c4bcc52ae 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -56,26 +56,20 @@ def test_loading_array(self): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[1.7413, 2.7413], [5.7413, 6.7413]]], [[[9.1419, 10.1419], [13.1419, 14.1419]]]]), - rtol=1e-5, - ) - np.testing.assert_allclose( - item[1], - np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), + np.array([[[[2.0577, 3.0577], [6.0577, 7.0577]]], [[[10.5540, 11.5540], [14.5540, 15.5540]]]]), rtol=1e-5, ) + np.testing.assert_allclose(item[1], np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) if sys.platform != "win32": for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=2): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[2.3944, 3.3944], [6.3944, 7.3944]]], [[[10.6551, 11.6551], [14.6551, 15.6551]]]]), + np.array([[[[1.6533, 2.6533], [5.6533, 6.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]), rtol=1e-3, ) np.testing.assert_allclose( - item[1], - np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), - rtol=1e-5, + item[1], np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5 ) diff --git a/tests/test_grid_distortion.py b/tests/test_grid_distortion.py new file mode 100644 index 0000000000..5e7ccd7c32 --- /dev/null +++ b/tests/test_grid_distortion.py @@ -0,0 +1,108 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import GridDistortion +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + dict(num_cells=3, distort_steps=[(1.5,) * 4] * 2, mode="nearest", padding_mode="zeros"), + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [3.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [3.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ).astype(np.float32) + ), + ] + ) + num_cells = (2, 2) + distort_steps = [(1.5,) * (1 + num_cells[0]), (1.0,) * (1 + num_cells[1])] + TESTS.append( + [ + dict(num_cells=num_cells, distort_steps=distort_steps, mode="bilinear", padding_mode="reflection"), + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.25, 2.25, 2.25, 2.25, 2.25, 2.25], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [3.25, 3.25, 3.25, 3.25, 3.25, 3.25], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + [ + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + ], + ] + ).astype(np.float32) + ), + ] + ) + TESTS.append( + [ + dict(num_cells=2, distort_steps=[(1.25,) * 3] * 3, mode="nearest", padding_mode="zeros"), + p(np.indices([3, 3, 3])[:1].astype(np.float32)), + p( + np.array( + [ + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ] + ] + ).astype(np.float32) + ), + ] + ) + + +class TestGridDistortion(unittest.TestCase): + @parameterized.expand(TESTS) + def test_grid_distortion(self, input_param, input_data, expected_val): + g = GridDistortion(**input_param) + result = g(input_data) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_grid_distortiond.py b/tests/test_grid_distortiond.py new file mode 100644 index 0000000000..662596f935 --- /dev/null +++ b/tests/test_grid_distortiond.py @@ -0,0 +1,85 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import GridDistortiond +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +num_cells = (2, 2) +distort_steps = [(1.5,) * (1 + n_c) for n_c in num_cells] +for p in TEST_NDARRAYS: + img = np.indices([6, 6]).astype(np.float32) + TESTS.append( + [ + dict( + keys=["img", "mask"], + num_cells=num_cells, + distort_steps=distort_steps, + mode=["bilinear", "nearest"], + padding_mode=["reflection", "zeros"], + ), + {"img": p(img), "mask": p(np.ones_like(img[:1]))}, + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.25, 2.25, 2.25, 2.25, 2.25, 2.25], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [3.25, 3.25, 3.25, 3.25, 3.25, 3.25], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + [ + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + ], + ] + ).astype(np.float32) + ), + p( + np.array( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ] + ).astype(np.float32) + ), + ] + ) + + +class TestGridDistortiond(unittest.TestCase): + @parameterized.expand(TESTS) + def test_grid_distortiond(self, input_param, input_data, expected_val_img, expected_val_mask): + g = GridDistortiond(**input_param) + result = g(input_data) + assert_allclose(result["img"], expected_val_img, rtol=1e-4, atol=1e-4) + assert_allclose(result["mask"], expected_val_mask, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py index 9e4d2e8237..b0f19d1a00 100644 --- a/tests/test_grid_pull.py +++ b/tests/test_grid_pull.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -53,11 +53,7 @@ def make_grid(shape, dtype=None, device=None, requires_grad=True): "interpolation": interp, "bound": bound, }, - { - "val": torch.tensor([[expected_val]]), - "device": device, - "grad": torch.tensor(expected_grad), - }, + {"val": torch.tensor([[expected_val]]), "device": device, "grad": torch.tensor(expected_grad)}, ] TEST_1D_GP.append(test_case) @@ -85,7 +81,7 @@ def test_grid_pull(self, input_param, expected): grads = grads[0] else: grads = torch.cat(grads, dim=0) - self.assertTrue("{}".format(result.device).startswith(expected["device"])) + self.assertTrue(f"{result.device}".startswith(expected["device"])) np.testing.assert_allclose(result.detach().cpu().numpy(), expected["val"].cpu().numpy(), rtol=1e-4, atol=1e-4) np.testing.assert_allclose(grads.detach().cpu().numpy(), expected["grad"].cpu().numpy(), rtol=1e-4, atol=1e-4) diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index 81a3cdc96d..ec43ed357b 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py index bcab49f12b..1b746184a4 100644 --- a/tests/test_handler_checkpoint_saver.py +++ b/tests/test_handler_checkpoint_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -112,16 +112,7 @@ class TestHandlerCheckpointSaver(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - ] + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] ) def test_file( self, diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py index 87ce5ca3f8..a498fa2b5c 100644 --- a/tests/test_handler_classification_saver.py +++ b/tests/test_handler_classification_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -45,7 +45,7 @@ def _train_func(engine, batch): def _test_file(filename): filepath = os.path.join(tempdir, filename) self.assertTrue(os.path.exists(filepath)) - with open(filepath, "r") as f: + with open(filepath) as f: reader = csv.reader(f) i = 0 for row in reader: diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index 70cc0ca42f..e92009d37f 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -61,7 +61,7 @@ def _train_func(engine, batch): filepath = os.path.join(tempdir, "predictions.csv") if rank == 1: self.assertTrue(os.path.exists(filepath)) - with open(filepath, "r") as f: + with open(filepath) as f: reader = csv.reader(f) i = 0 for row in reader: diff --git a/tests/test_handler_confusion_matrix.py b/tests/test_handler_confusion_matrix.py index 0c6e36066b..5bddef26af 100644 --- a/tests/test_handler_confusion_matrix.py +++ b/tests/test_handler_confusion_matrix.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,7 +20,7 @@ TEST_CASE_1 = [{"include_background": True, "save_details": False, "metric_name": "f1"}, 0.75] TEST_CASE_2 = [{"include_background": False, "save_details": False, "metric_name": "ppv"}, 1.0] - +TEST_CASE_3 = [{"save_details": False, "metric_name": "f1", "reduction": "mean_batch"}, torch.tensor([0.6667, 0.8000])] TEST_CASE_SEG_1 = [{"include_background": True, "metric_name": "tpr"}, 0.7] data_1: Dict[Any, Any] = { @@ -39,23 +39,15 @@ } data_2: Dict[Any, Any] = { - "y_pred": torch.tensor( - [ - [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]], - ] - ), - "y": torch.tensor( - [ - [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], - ] - ), + "y_pred": torch.tensor([[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]]), + "y": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]), } class TestHandlerConfusionMatrix(unittest.TestCase): # TODO test multi node averaged confusion matrix - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_compute(self, input_params, expected_avg): metric = ConfusionMatrix(**input_params) # test input a list of channel-first tensor @@ -68,7 +60,7 @@ def test_compute(self, input_params, expected_avg): metric.update([y_pred, y]) avg_metric = metric.compute() - self.assertAlmostEqual(avg_metric, expected_avg, places=4) + torch.testing.assert_allclose(avg_metric, expected_avg) @parameterized.expand([TEST_CASE_SEG_1]) def test_compute_seg(self, input_params, expected_avg): diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py index 40245bce2e..325a799098 100644 --- a/tests/test_handler_confusion_matrix_dist.py +++ b/tests/test_handler_confusion_matrix_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -54,12 +54,10 @@ def _val_func(engine, batch): if dist.get_rank() == 1: y_pred = torch.tensor( - [[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]], - device=device, + [[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]], device=device ) y = torch.tensor( - [[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], - device=device, + [[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=device ) metric.update([y_pred, y]) diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py index 8f0ffb2b5c..1a43ae295b 100644 --- a/tests/test_handler_decollate_batch.py +++ b/tests/test_handler_decollate_batch.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,7 +32,7 @@ def test_compute(self): [ Activationsd(keys="pred", sigmoid=True), CopyItemsd(keys="filename", times=1, names="filename_bak"), - AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2), + AsDiscreted(keys="pred", threshold=0.5, to_onehot=2), ] ) ), diff --git a/tests/test_handler_early_stop.py b/tests/test_handler_early_stop.py index efe8e89825..36604e5735 100644 --- a/tests/test_handler_early_stop.py +++ b/tests/test_handler_early_stop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,10 +23,7 @@ def _train_func(engine, batch): trainer = Engine(_train_func) EarlyStopHandler( - patience=5, - score_function=lambda x: x.state.output["loss"], - trainer=trainer, - epoch_level=False, + patience=5, score_function=lambda x: x.state.output["loss"], trainer=trainer, epoch_level=False ).attach(trainer) trainer.run(range(4), max_epochs=2) diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py index 75ab9ceb99..0350ba62fb 100644 --- a/tests/test_handler_garbage_collector.py +++ b/tests/test_handler_garbage_collector.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,13 +34,7 @@ class TestHandlerGarbageCollector(unittest.TestCase): @skipUnless(has_ignite, "Requires ignite") - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_content(self, data, trigger_event): # set up engine gb_count_dict = {} diff --git a/tests/test_handler_hausdorff_distance.py b/tests/test_handler_hausdorff_distance.py index bbc36cc2b5..7e38f0ad56 100644 --- a/tests/test_handler_hausdorff_distance.py +++ b/tests/test_handler_hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,9 +20,7 @@ def create_spherical_seg_3d( - radius: float = 20.0, - centre: Tuple[int, int, int] = (49, 49, 49), - im_shape: Tuple[int, int, int] = (99, 99, 99), + radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -91,6 +89,21 @@ def test_shape_mismatch(self): y = torch.ones((1, 1, 10, 10, 10)) hd_metric.update([y_pred, y]) + def test_reduction(self): + hd_metric = HausdorffDistance(include_background=True, reduction="mean_channel") + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + hd_metric.attach(engine, "hausdorff_distance") + + y_pred, y = TEST_SAMPLE_1 + hd_metric.update([y_pred, y]) + y_pred, y = TEST_SAMPLE_2 + hd_metric.update([y_pred, y]) + torch.testing.assert_allclose(hd_metric.compute().float(), torch.tensor([10.0, 0.0])) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py index 82a62dce21..cbb752925b 100644 --- a/tests/test_handler_lr_scheduler.py +++ b/tests/test_handler_lr_scheduler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py index ba4fb9d413..f309c7e693 100644 --- a/tests/test_handler_mean_dice.py +++ b/tests/test_handler_mean_dice.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,12 +19,17 @@ TEST_CASE_1 = [{"include_background": True, "output_transform": from_engine(["pred", "label"])}, 0.75, (4, 2)] TEST_CASE_2 = [{"include_background": False, "output_transform": from_engine(["pred", "label"])}, 0.66666, (4, 1)] +TEST_CASE_3 = [ + {"reduction": "mean_channel", "output_transform": from_engine(["pred", "label"])}, + torch.Tensor([1.0, 0.0, 1.0, 1.0]), + (4, 2), +] class TestHandlerMeanDice(unittest.TestCase): # TODO test multi node averaged dice - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_compute(self, input_params, expected_avg, details_shape): dice_metric = MeanDice(**input_params) # set up engine @@ -46,8 +51,7 @@ def _val_func(engine, batch): engine.fire_event(Events.ITERATION_COMPLETED) engine.fire_event(Events.EPOCH_COMPLETED) - - self.assertAlmostEqual(engine.state.metrics["mean_dice"], expected_avg, places=4) + torch.testing.assert_allclose(engine.state.metrics["mean_dice"], expected_avg) self.assertTupleEqual(tuple(engine.state.metric_details["mean_dice"].shape), details_shape) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) diff --git a/tests/test_handler_metric_logger.py b/tests/test_handler_metric_logger.py index 5812605cd7..c3de866c5c 100644 --- a/tests/test_handler_metric_logger.py +++ b/tests/test_handler_metric_logger.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_metrics_saver.py b/tests/test_handler_metrics_saver.py index 17c23be274..1f65eb46dc 100644 --- a/tests/test_handler_metrics_saver.py +++ b/tests/test_handler_metrics_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 0a36a19c66..06dcbafa28 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -52,10 +52,7 @@ def _val_func(engine, batch): @engine.on(Events.EPOCH_COMPLETED) def _save_metrics0(engine): engine.state.metrics = {"metric1": 1, "metric2": 2} - engine.state.metric_details = { - "metric3": torch.tensor([[1, 2]]), - "metric4": torch.tensor([[5, 6]]), - } + engine.state.metric_details = {"metric3": torch.tensor([[1, 2]]), "metric4": torch.tensor([[5, 6]])} if dist.get_rank() == 1: # different ranks have different data length diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py new file mode 100644 index 0000000000..05b99b0053 --- /dev/null +++ b/tests/test_handler_mlflow.py @@ -0,0 +1,51 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import tempfile +import unittest +from pathlib import Path + +from ignite.engine import Engine, Events + +from monai.handlers import MLFlowHandler + + +class TestHandlerMLFlow(unittest.TestCase): + def test_metrics_track(self): + with tempfile.TemporaryDirectory() as tempdir: + + # set up engine + def _train_func(engine, batch): + return [batch + 1.0] + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get("acc", 0.1) + engine.state.metrics["acc"] = current_metric + 0.1 + engine.state.test = current_metric + + # set up testing handler + test_path = os.path.join(tempdir, "mlflow_test") + handler = MLFlowHandler(tracking_uri=Path(test_path).as_uri(), state_attributes=["test"]) + handler.attach(engine) + engine.run(range(3), max_epochs=2) + handler.close() + # check logging output + self.assertTrue(len(glob.glob(test_path)) > 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_nvtx.py b/tests/test_handler_nvtx.py index fee29af344..eeca15ea6f 100644 --- a/tests/test_handler_nvtx.py +++ b/tests/test_handler_nvtx.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,42 +22,18 @@ _, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?") -TENSOR_0 = torch.tensor( - [ - [ - [[1.0], [2.0]], - [[3.0], [4.0]], - ] - ] -) +TENSOR_0 = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]]) -TENSOR_1 = torch.tensor( - [ - [ - [[0.0], [-2.0]], - [[-3.0], [4.0]], - ] - ] -) +TENSOR_1 = torch.tensor([[[[0.0], [-2.0]], [[-3.0], [4.0]]]]) -TENSOR_1_EXPECTED = torch.tensor( - [ - [[1.0], [0.5]], - [[0.25], [5.0]], - ] -) +TENSOR_1_EXPECTED = torch.tensor([[[1.0], [0.5]], [[0.25], [5.0]]]) TEST_CASE_0 = [[{"image": TENSOR_0}], TENSOR_0[0] + 1.0] TEST_CASE_1 = [[{"image": TENSOR_1}], TENSOR_1_EXPECTED] class TestHandlerDecollateBatch(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!") def test_compute(self, data, expected): # Set up handlers diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py index 5b3e845ace..72742f1956 100644 --- a/tests/test_handler_parameter_scheduler.py +++ b/tests/test_handler_parameter_scheduler.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest import torch @@ -9,7 +20,7 @@ class ToyNet(Module): def __init__(self, value): - super(ToyNet, self).__init__() + super().__init__() self.value = value def forward(self, input): diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index e9d57128cb..89087e1765 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,7 +26,7 @@ "transform": Compose( [ CopyItemsd(keys="filename", times=1, names="filename_bak"), - AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2), + AsDiscreted(keys="pred", threshold=0.5, to_onehot=2), ] ), "event": "iteration_completed", diff --git a/tests/test_handler_prob_map_producer.py b/tests/test_handler_prob_map_producer.py index b21cf03171..b3f79cf587 100644 --- a/tests/test_handler_prob_map_producer.py +++ b/tests/test_handler_prob_map_producer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,12 +32,7 @@ class TestDataset(Dataset): def __init__(self, name, size): super().__init__( data=[ - { - "name": name, - "mask_shape": (size, size), - "mask_locations": [[i, i] for i in range(size)], - "level": 0, - } + {"name": name, "mask_shape": (size, size), "mask_locations": [[i, i] for i in range(size)], "level": 0} ] ) self.len = size @@ -46,11 +41,7 @@ def __len__(self): return self.len def __getitem__(self, index): - return { - "name": self.data[0]["name"], - "mask_location": self.data[0]["mask_locations"][index], - "pred": index + 1, - } + return {"name": self.data[0]["name"], "mask_location": self.data[0]["mask_locations"][index], "pred": index + 1} class TestEvaluator(Evaluator): @@ -59,13 +50,7 @@ def _iteration(self, engine, batchdata): class TestHandlerProbMapGenerator(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_prob_map_generator(self, name, size): # set up dataset dataset = TestDataset(name, size) diff --git a/tests/test_handler_regression_metrics.py b/tests/test_handler_regression_metrics.py index 7bb72dd5d5..c6af76a4db 100644 --- a/tests/test_handler_regression_metrics.py +++ b/tests/test_handler_regression_metrics.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_regression_metrics_dist.py b/tests/test_handler_regression_metrics_dist.py index c336ccf28c..a8b644d550 100644 --- a/tests/test_handler_regression_metrics_dist.py +++ b/tests/test_handler_regression_metrics_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py index 5b80bc43eb..6e2d6be27e 100644 --- a/tests/test_handler_rocauc.py +++ b/tests/test_handler_rocauc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,7 +22,7 @@ class TestHandlerROCAUC(unittest.TestCase): def test_compute(self): auc_metric = ROCAUC() act = Activations(softmax=True) - to_onehot = AsDiscrete(to_onehot=True, num_classes=2) + to_onehot = AsDiscrete(to_onehot=2) y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])] y = [torch.Tensor([0]), torch.Tensor([1])] diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py index 8316d4c4b6..5113911d7c 100644 --- a/tests/test_handler_rocauc_dist.py +++ b/tests/test_handler_rocauc_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,7 +26,7 @@ class DistributedROCAUC(DistTestCase): def test_compute(self): auc_metric = ROCAUC() act = Activations(softmax=True) - to_onehot = AsDiscrete(to_onehot=True, num_classes=2) + to_onehot = AsDiscrete(to_onehot=2) device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" if dist.get_rank() == 0: diff --git a/tests/test_handler_segmentation_saver.py b/tests/test_handler_segmentation_saver.py index 78dea0a68b..3632a98cfc 100644 --- a/tests/test_handler_segmentation_saver.py +++ b/tests/test_handler_segmentation_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_smartcache.py b/tests/test_handler_smartcache.py index b67f1226cd..ec96d47e3d 100644 --- a/tests/test_handler_smartcache.py +++ b/tests/test_handler_smartcache.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,13 +24,7 @@ class TestHandlerSmartCache(unittest.TestCase): def test_content(self): data = [0, 1, 2, 3, 4, 5, 6, 7, 8] - expected = [ - [0, 1, 2, 3, 4], - [1, 2, 3, 4, 5], - [2, 3, 4, 5, 6], - [3, 4, 5, 6, 7], - [4, 5, 6, 7, 8], - ] + expected = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7], [4, 5, 6, 7, 8]] # set up engine def _train_func(engine, batch): diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 84cdef59a8..aa6bc427e1 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -140,7 +140,7 @@ def _train_func(engine, batch): engine.run(range(3), max_epochs=2) handler.close() stats_handler.logger.removeHandler(handler) - with open(filename, "r") as f: + with open(filename) as f: output_str = f.read() grep = re.compile(f".*{key_to_handler}.*") has_key_word = re.compile(f".*{key_to_print}.*") @@ -163,6 +163,45 @@ def _train_func(engine, batch): with self.assertRaises(RuntimeError): engine.run(range(3), max_epochs=2) + def test_attributes_print(self): + log_stream = StringIO() + log_handler = logging.StreamHandler(log_stream) + log_handler.setLevel(logging.INFO) + key_to_handler = "test_logging" + + # set up engine + def _train_func(engine, batch): + return [torch.tensor(0.0)] + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + if not hasattr(engine.state, "test1"): + engine.state.test1 = 0.1 + engine.state.test2 = 0.2 + else: + engine.state.test1 += 0.1 + engine.state.test2 += 0.2 + + # set up testing handler + stats_handler = StatsHandler( + name=key_to_handler, state_attributes=["test1", "test2", "test3"], logger_handler=log_handler + ) + stats_handler.attach(engine) + + engine.run(range(3), max_epochs=2) + + # check logging output + output_str = log_stream.getvalue() + log_handler.close() + grep = re.compile(f".*{key_to_handler}.*") + has_key_word = re.compile(".*State values.*") + for idx, line in enumerate(output_str.split("\n")): + if grep.match(line) and idx in [5, 10]: + self.assertTrue(has_key_word.match(line)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_surface_distance.py b/tests/test_handler_surface_distance.py index 82cdb50d90..c990181998 100644 --- a/tests/test_handler_surface_distance.py +++ b/tests/test_handler_surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,9 +20,7 @@ def create_spherical_seg_3d( - radius: float = 20.0, - centre: Tuple[int, int, int] = (49, 49, 49), - im_shape: Tuple[int, int, int] = (99, 99, 99), + radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -91,6 +89,21 @@ def test_shape_mismatch(self): y = torch.ones((1, 1, 10, 10, 10)) sur_metric.update([y_pred, y]) + def test_reduction(self): + sur_metric = SurfaceDistance(include_background=True, reduction="mean_channel") + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + sur_metric.attach(engine, "surface_distance") + + y_pred, y = TEST_SAMPLE_1 + sur_metric.update([y_pred, y]) + y_pred, y = TEST_SAMPLE_2 + sur_metric.update([y_pred, y]) + torch.testing.assert_allclose(sur_metric.compute().float(), torch.tensor([4.1713, 0.0000])) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py index b5d963eedf..d11bbfec59 100644 --- a/tests/test_handler_tb_image.py +++ b/tests/test_handler_tb_image.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py index 1d722e7f66..eae60f9b09 100644 --- a/tests/test_handler_tb_stats.py +++ b/tests/test_handler_tb_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -57,11 +57,15 @@ def _train_func(engine, batch): def _update_metric(engine): current_metric = engine.state.metrics.get("acc", 0.1) engine.state.metrics["acc"] = current_metric + 0.1 + engine.state.test = current_metric # set up testing handler writer = SummaryWriter(log_dir=tempdir) stats_handler = TensorBoardStatsHandler( - writer, output_transform=lambda x: {"loss": x[0] * 2.0}, global_epoch_transform=lambda x: x * 3.0 + summary_writer=writer, + output_transform=lambda x: {"loss": x[0] * 2.0}, + global_epoch_transform=lambda x: x * 3.0, + state_attributes=["test"], ) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) diff --git a/tests/test_handler_transform_inverter.py b/tests/test_handler_transform_inverter.py deleted file mode 100644 index 385311eba7..0000000000 --- a/tests/test_handler_transform_inverter.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import unittest - -import numpy as np -import torch -from ignite.engine import Engine - -from monai.data import CacheDataset, DataLoader, create_test_image_3d, decollate_batch -from monai.engines.utils import IterationEvents -from monai.handlers import TransformInverter -from monai.transforms import ( - AddChanneld, - CastToTyped, - Compose, - CopyItemsd, - LoadImaged, - Orientationd, - RandAffined, - RandAxisFlipd, - RandFlipd, - RandRotate90d, - RandRotated, - RandZoomd, - ResizeWithPadOrCropd, - ScaleIntensityd, - Spacingd, - ToTensord, -) -from monai.utils.misc import set_determinism -from tests.utils import make_nifti_image - -KEYS = ["image", "label"] - - -class TestTransformInverter(unittest.TestCase): - def test_invert(self): - set_determinism(seed=0) - im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)] - transform = Compose( - [ - LoadImaged(KEYS), - AddChanneld(KEYS), - Orientationd(KEYS, "RPS"), - Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32), - ScaleIntensityd("image", minv=1, maxv=10), - RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]), - RandAxisFlipd(KEYS, prob=0.5), - RandRotate90d(KEYS, spatial_axes=(1, 2)), - RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), - RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True), - RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"), - ResizeWithPadOrCropd(KEYS, 100), - ToTensord("image"), # test to support both Tensor and Numpy array when inverting - CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]), - CopyItemsd("label", times=2, names=["label_inverted1", "label_inverted2"]), - CopyItemsd("image", times=2, names=["image_inverted1", "image_inverted2"]), - ] - ) - data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] - - # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 - - dataset = CacheDataset(data, transform=transform, progress=False) - loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) - - # set up engine - def _train_func(engine, batch): - self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100)) - engine.state.output = engine.state.batch = decollate_batch(batch) - engine.fire_event(IterationEvents.MODEL_COMPLETED) - return engine.state.output - - engine = Engine(_train_func) - engine.register_events(*IterationEvents) - - # set up testing handler - TransformInverter( - transform=transform, - output_keys=["image_inverted1", "label_inverted1"], - batch_keys="label", - meta_keys=["image_inverted1_meta_dict", "label_inverted1_meta_dict"], - batch_meta_keys="label_meta_dict", - nearest_interp=True, - to_tensor=[True, False], - device="cpu", - ).attach(engine) - - # test different nearest interpolation values - TransformInverter( - transform=transform, - output_keys=["image_inverted2", "label_inverted2"], - batch_keys="image", - meta_keys=None, - batch_meta_keys="image_meta_dict", - meta_key_postfix="meta_dict", - nearest_interp=[True, False], - post_func=[lambda x: x + 10, lambda x: x], - ).attach(engine) - - engine.run(loader, max_epochs=1) - set_determinism(seed=None) - - for output in engine.state.output: - self.assertTupleEqual(output["image"].shape, (1, 100, 100, 100)) - self.assertTupleEqual(output["label"].shape, (1, 100, 100, 100)) - # check the nearest inerpolation mode - i = output["image_inverted1"] - torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float)) - self.assertTupleEqual(i.shape, (1, 100, 101, 107)) - i = output["label_inverted1"] - np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32)) - self.assertTupleEqual(i.shape, (1, 100, 101, 107)) - - # check the case that different items use different interpolation mode to invert transforms - d = output["image_inverted2"] - # if the interpolation mode is nearest, accumulated diff should be smaller than 1 - self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0) - self.assertTupleEqual(d.shape, (1, 100, 101, 107)) - - d = output["label_inverted2"] - # if the interpolation mode is not nearest, accumulated diff should be greater than 10000 - self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0) - self.assertTupleEqual(d.shape, (1, 100, 101, 107)) - - # check labels match - reverted = engine.state.output[-1]["label_inverted1"].astype(np.int32) - original = LoadImaged(KEYS)(data[-1])["label"] - n_good = np.sum(np.isclose(reverted, original, atol=1e-3)) - reverted_name = engine.state.batch[-1]["label_inverted1_meta_dict"]["filename_or_obj"] - original_name = data[-1]["label"] - self.assertEqual(reverted_name, original_name) - print("invert diff", reverted.size - n_good) - # 25300: 2 workers (cpu, non-macos) - # 1812: 0 workers (gpu or macos) - # 1824: torch 1.5.1 - self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824), "diff. in 3 possible values") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py index 06f400109d..42ffc8b9eb 100644 --- a/tests/test_handler_validation.py +++ b/tests/test_handler_validation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_hashing.py b/tests/test_hashing.py index ca317a72e8..5a1265bd48 100644 --- a/tests/test_hashing.py +++ b/tests/test_hashing.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index 0b313f722f..79a2c84b37 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,9 +20,7 @@ def create_spherical_seg_3d( - radius: float = 20.0, - centre: Tuple[int, int, int] = (49, 49, 49), - im_shape: Tuple[int, int, int] = (99, 99, 99), + radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -49,10 +47,7 @@ def create_spherical_seg_3d( TEST_CASES = [ - [ - [create_spherical_seg_3d(), create_spherical_seg_3d(), 1], - [0, 0, 0, 0, 0, 0], - ], + [[create_spherical_seg_3d(), create_spherical_seg_3d(), 1], [0, 0, 0, 0, 0, 0]], [ [ create_spherical_seg_3d(radius=20, centre=(20, 20, 20)), @@ -106,8 +101,8 @@ def create_spherical_seg_3d( # both pred and gt do not have foreground, metric and not_nans should be 0 np.zeros([99, 99, 99]), np.zeros([99, 99, 99]), - ], - ], + ] + ] ] diff --git a/tests/test_header_correct.py b/tests/test_header_correct.py index 4a8927fa80..aa0a4dde08 100644 --- a/tests/test_header_correct.py +++ b/tests/test_header_correct.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_highresnet.py b/tests/test_highresnet.py index 61af529b63..76c2203431 100644 --- a/tests/test_highresnet.py +++ b/tests/test_highresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 82454c34d0..10aa83293f 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,10 +28,7 @@ def create_expected_numpy_output(input_datum, **kwargs): - x = np.fft.fft( - input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), - **kwargs, - ) + x = np.fft.fft(input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), **kwargs) f = np.fft.fftfreq(x.shape[kwargs["axis"]]) u = np.heaviside(f, 0.5) new_dims_before = kwargs["axis"] diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py index b69fb1d927..95aa37f26e 100644 --- a/tests/test_histogram_normalize.py +++ b/tests/test_histogram_normalize.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,32 +15,42 @@ from parameterized import parameterized from monai.transforms import HistogramNormalize - -TEST_CASE_1 = [ - {"num_bins": 4, "min": 1, "max": 5, "mask": np.array([1, 1, 1, 1, 1, 0])}, - np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), - np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), -] - -TEST_CASE_2 = [ - {"num_bins": 4, "max": 4, "dtype": np.uint8}, - np.array([0.0, 1.0, 2.0, 3.0, 4.0]), - np.array([0, 0, 1, 3, 4]), -] - -TEST_CASE_3 = [ - {"num_bins": 256, "max": 255, "dtype": np.uint8}, - np.array([[[100.0, 200.0], [150.0, 250.0]]]), - np.array([[[0, 170], [70, 255]]]), -] +from monai.utils import get_equivalent_dtype +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"num_bins": 4, "min": 1, "max": 5, "mask": np.array([1, 1, 1, 1, 1, 0])}, + p(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])), + p(np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0])), + ] + ) + + TESTS.append( + [ + {"num_bins": 4, "max": 4, "dtype": np.uint8}, + p(np.array([0.0, 1.0, 2.0, 3.0, 4.0])), + p(np.array([0, 0, 1, 3, 4])), + ] + ) + + TESTS.append( + [ + {"num_bins": 256, "max": 255, "dtype": np.uint8}, + p(np.array([[[100.0, 200.0], [150.0, 250.0]]])), + p(np.array([[[0, 170], [70, 255]]])), + ] + ) class TestHistogramNormalize(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = HistogramNormalize(**argments)(image) - np.testing.assert_allclose(result, expected_data) - self.assertEqual(result.dtype, argments.get("dtype", np.float32)) + assert_allclose(result, expected_data) + self.assertEqual(get_equivalent_dtype(result.dtype, data_type=np.ndarray), argments.get("dtype", np.float32)) if __name__ == "__main__": diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py index 68647e82fb..7b86a9685f 100644 --- a/tests/test_histogram_normalized.py +++ b/tests/test_histogram_normalized.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,32 +15,42 @@ from parameterized import parameterized from monai.transforms import HistogramNormalized - -TEST_CASE_1 = [ - {"keys": "img", "num_bins": 4, "min": 1, "max": 5, "mask_key": "mask"}, - {"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), "mask": np.array([1, 1, 1, 1, 1, 0])}, - np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), -] - -TEST_CASE_2 = [ - {"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8}, - {"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0])}, - np.array([0, 0, 1, 3, 4]), -] - -TEST_CASE_3 = [ - {"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8}, - {"img": np.array([[[100.0, 200.0], [150.0, 250.0]]])}, - np.array([[[0, 170], [70, 255]]]), -] +from monai.utils import get_equivalent_dtype +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "img", "num_bins": 4, "min": 1, "max": 5, "mask_key": "mask"}, + {"img": p(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])), "mask": p(np.array([1, 1, 1, 1, 1, 0]))}, + p(np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0])), + ] + ) + + TESTS.append( + [ + {"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8}, + {"img": p(np.array([0.0, 1.0, 2.0, 3.0, 4.0]))}, + p(np.array([0, 0, 1, 3, 4])), + ] + ) + + TESTS.append( + [ + {"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8}, + {"img": p(np.array([[[100.0, 200.0], [150.0, 250.0]]]))}, + p(np.array([[[0, 170], [70, 255]]])), + ] + ) class TestHistogramNormalized(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = HistogramNormalized(**argments)(image)["img"] - np.testing.assert_allclose(result, expected_data) - self.assertEqual(result.dtype, argments.get("dtype", np.float32)) + assert_allclose(result, expected_data) + self.assertEqual(get_equivalent_dtype(result.dtype, data_type=np.ndarray), argments.get("dtype", np.float32)) if __name__ == "__main__": diff --git a/tests/test_identity.py b/tests/test_identity.py index 172860668c..60134c24a4 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_identityd.py b/tests/test_identityd.py index 665b7d5d1c..2df74ba2c6 100644 --- a/tests/test_identityd.py +++ b/tests/test_identityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index 3b3c06c87c..c478f28d13 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py index bd0369868e..58c4d3cfab 100644 --- a/tests/test_img2tensorboard.py +++ b/tests/test_img2tensorboard.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,29 +22,21 @@ class TestImg2Tensorboard(unittest.TestCase): def test_write_gray(self): nparr = np.ones(shape=(1, 32, 32, 32), dtype=np.float32) summary_object_np = make_animated_gif_summary( - tag="test_summary_nparr.png", - image=nparr, - max_out=1, - animation_axes=(3,), - image_axes=(1, 2), - scale_factor=253.0, + tag="test_summary_nparr.png", image=nparr, max_out=1, scale_factor=253.0 ) - assert isinstance( - summary_object_np, tensorboard.compat.proto.summary_pb2.Summary - ), "make_animated_gif_summary must return a tensorboard.summary object from numpy array" + for s in summary_object_np: + assert isinstance( + s, tensorboard.compat.proto.summary_pb2.Summary + ), "make_animated_gif_summary must return a tensorboard.summary object from numpy array" tensorarr = torch.tensor(nparr) summary_object_tensor = make_animated_gif_summary( - tag="test_summary_tensorarr.png", - image=tensorarr, - max_out=1, - animation_axes=(3,), - image_axes=(1, 2), - scale_factor=253.0, + tag="test_summary_tensorarr.png", image=tensorarr, max_out=1, frame_dim=-1, scale_factor=253.0 ) - assert isinstance( - summary_object_tensor, tensorboard.compat.proto.summary_pb2.Summary - ), "make_animated_gif_summary must return a tensorboard.summary object from tensor input" + for s in summary_object_tensor: + assert isinstance( + s, tensorboard.compat.proto.summary_pb2.Summary + ), "make_animated_gif_summary must return a tensorboard.summary object from tensor input" if __name__ == "__main__": diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index d6737c26ca..03a63cc375 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,6 +13,7 @@ from monai.data import ITKReader, NibabelReader, NumpyReader, PILReader from monai.transforms import LoadImage, LoadImaged +from tests.utils import SkipIfNoModule class TestInitLoadImage(unittest.TestCase): @@ -26,6 +27,9 @@ def test_load_image(self): inst = LoadImaged("image", reader=r) self.assertIsInstance(inst, LoadImaged) + @SkipIfNoModule("itk") + @SkipIfNoModule("nibabel") + @SkipIfNoModule("PIL") def test_readers(self): inst = ITKReader() self.assertIsInstance(inst, ITKReader) diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 03b5571973..2c0c9e1f2e 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -80,7 +80,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", [LoadImage(image_only=True), AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), ToTensor()] ) y_pred_trans = Compose([ToTensor(), Activations(softmax=True)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=True, num_classes=len(np.unique(train_y)))]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=len(np.unique(train_y)))]) auc_metric = ROCAUCMetric() # create train, val data loaders @@ -197,7 +197,7 @@ def setUp(self): assert os.path.exists(data_dir) - class_names = sorted((x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))) + class_names = sorted(x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x))) image_files = [ [os.path.join(data_dir, class_name, x) for x in sorted(os.listdir(os.path.join(data_dir, class_name)))] for class_name in class_names diff --git a/tests/test_integration_determinism.py b/tests/test_integration_determinism.py index e077420420..97e510d03c 100644 --- a/tests/test_integration_determinism.py +++ b/tests/test_integration_determinism.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -41,7 +41,7 @@ def __len__(self): return train_steps net = UNet( - dimensions=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 + spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 ).to(device) loss = DiceLoss(sigmoid=True) @@ -81,7 +81,7 @@ def test_training(self): loss, step = run_test(device=self.device) print(f"Deterministic loss {loss} at training step {step}") np.testing.assert_allclose(step, 4) - np.testing.assert_allclose(loss, 0.535927, rtol=1e-4) + np.testing.assert_allclose(loss, 0.536134, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_integration_fast_train.py b/tests/test_integration_fast_train.py new file mode 100644 index 0000000000..3522a57342 --- /dev/null +++ b/tests/test_integration_fast_train.py @@ -0,0 +1,234 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import shutil +import tempfile +import time +import unittest +from glob import glob + +import nibabel as nib +import numpy as np +import torch + +import monai +from monai.data import CacheDataset, ThreadDataLoader, create_test_image_3d, decollate_batch +from monai.inferers import sliding_window_inference +from monai.losses import DiceCELoss +from monai.metrics import DiceMetric +from monai.networks.layers import Norm +from monai.networks.nets import UNet +from monai.optimizers import Novograd +from monai.transforms import ( + AsDiscrete, + Compose, + CropForegroundd, + EnsureChannelFirstd, + EnsureType, + EnsureTyped, + FgBgToIndicesd, + LoadImaged, + RandAffined, + RandAxisFlipd, + RandCropByPosNegLabeld, + RandFlipd, + RandGaussianNoised, + RandRotate90d, + RandRotated, + RandStdShiftIntensityd, + RandZoomd, + ScaleIntensityd, + Spacingd, + ToDeviced, +) +from monai.utils import set_determinism +from tests.utils import DistTestCase, SkipIfBeforePyTorchVersion, TimedCall, skip_if_no_cuda, skip_if_quick + + +@skip_if_no_cuda +@skip_if_quick +@SkipIfBeforePyTorchVersion((1, 7)) +class IntegrationFastTrain(DistTestCase): + def setUp(self): + set_determinism(seed=0) + monai.config.print_config() + + self.data_dir = tempfile.mkdtemp() + for i in range(41): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz")) + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(self.data_dir, f"seg{i:d}.nii.gz")) + + def tearDown(self): + set_determinism(seed=None) + shutil.rmtree(self.data_dir) + + # test the fast training speed is as expected + @TimedCall(seconds=100, daemon=False, force_quit=False) + def test_train_timing(self): + images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz"))) + segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz"))) + train_files = [{"image": img, "label": seg} for img, seg in zip(images[:32], segs[:32])] + val_files = [{"image": img, "label": seg} for img, seg in zip(images[-9:], segs[-9:])] + + device = torch.device("cuda:0") + # define transforms for train and validation + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")), + ScaleIntensityd(keys="image"), + CropForegroundd(keys=["image", "label"], source_key="image"), + # pre-compute foreground and background indexes + # and cache them to accelerate training + FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"), + # change to execute transforms with Tensor data + EnsureTyped(keys=["image", "label"]), + # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch + ToDeviced(keys=["image", "label"], device=device), + # randomly crop out patch samples from big + # image based on pos / neg ratio + # the image centers of negative samples + # must be in valid image area + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=(64, 64, 64), + pos=1, + neg=1, + num_samples=4, + fg_indices_key="label_fg", + bg_indices_key="label_bg", + ), + RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]), + RandAxisFlipd(keys=["image", "label"], prob=0.5), + RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(1, 2)), + RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True), + RandRotated( + keys=["image", "label"], + prob=0.5, + range_x=np.pi / 4, + mode=("bilinear", "nearest"), + align_corners=True, + ), + RandAffined(keys=["image", "label"], prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")), + RandGaussianNoised(keys="image", prob=0.5), + RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True), + ] + ) + + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")), + ScaleIntensityd(keys="image"), + CropForegroundd(keys=["image", "label"], source_key="image"), + EnsureTyped(keys=["image", "label"]), + # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch + ToDeviced(keys=["image", "label"], device=device), + ] + ) + + max_epochs = 5 + learning_rate = 2e-4 + val_interval = 1 # do validation for every epoch + + # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training + train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8) + val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5) + # disable multi-workers because `ThreadDataLoader` works with multi-threads + train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True) + val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1) + + loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True) + model = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + norm=Norm.BATCH, + ).to(device) + + # Novograd paper suggests to use a bigger LR than Adam, + # because Adam does normalization by element-wise second moments + optimizer = Novograd(model.parameters(), learning_rate * 10) + scaler = torch.cuda.amp.GradScaler() + + post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)]) + post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)]) + + dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) + + best_metric = -1 + total_start = time.time() + for epoch in range(max_epochs): + epoch_start = time.time() + print("-" * 10) + print(f"epoch {epoch + 1}/{max_epochs}") + model.train() + epoch_loss = 0 + step = 0 + for batch_data in train_loader: + step_start = time.time() + step += 1 + optimizer.zero_grad() + # set AMP for training + with torch.cuda.amp.autocast(): + outputs = model(batch_data["image"]) + loss = loss_function(outputs, batch_data["label"]) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + epoch_loss += loss.item() + epoch_len = math.ceil(len(train_ds) / train_loader.batch_size) + print( + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}" f" step time: {(time.time() - step_start):.4f}" + ) + epoch_loss /= step + print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") + + if (epoch + 1) % val_interval == 0: + model.eval() + with torch.no_grad(): + for val_data in val_loader: + roi_size = (96, 96, 96) + sw_batch_size = 4 + # set AMP for validation + with torch.cuda.amp.autocast(): + val_outputs = sliding_window_inference(val_data["image"], roi_size, sw_batch_size, model) + + val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)] + val_labels = [post_label(i) for i in decollate_batch(val_data["label"])] + dice_metric(y_pred=val_outputs, y=val_labels) + + metric = dice_metric.aggregate().item() + dice_metric.reset() + if metric > best_metric: + best_metric = metric + print(f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}") + print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}") + + total_time = time.time() - total_start + print(f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}") + # test expected metrics + self.assertGreater(best_metric, 0.95) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index d5eb69f7af..97e339f5bb 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -95,12 +95,12 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, # create a validation data loader val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) - val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) # create UNet, DiceLoss and Adam optimizer model = monai.networks.nets.UNet( - dimensions=3, + spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), @@ -195,11 +195,11 @@ def run_inference_test(root_dir, device="cuda:0"): val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) # sliding window inference need to input 1 image in every iteration val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) - val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) + val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) model = UNet( - dimensions=3, + spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index b63f331ba6..af49e3db77 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,7 +34,7 @@ def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"): loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available()) net = UNet( - dimensions=3, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 + spatial_dims=3, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 ).to(device) roi_size = (16, 32, 48) sw_batch_size = batch_size diff --git a/tests/test_integration_stn.py b/tests/test_integration_stn.py index c1fcfe7a89..ca067c4d78 100644 --- a/tests/test_integration_stn.py +++ b/tests/test_integration_stn.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function import unittest @@ -104,7 +103,7 @@ def setUp(self): def tearDown(self): set_determinism(seed=None) - @TimedCall(seconds=60) + @TimedCall(seconds=100) def test_training(self): """ check that the quality AffineTransform backpropagation diff --git a/tests/test_integration_unet_2d.py b/tests/test_integration_unet_2d.py index a46a174dc9..e60c91968a 100644 --- a/tests/test_integration_unet_2d.py +++ b/tests/test_integration_unet_2d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,10 +32,10 @@ def __len__(self): return train_steps if net_name == "basicunet": - net = BasicUNet(dimensions=2, in_channels=1, out_channels=1, features=(4, 8, 8, 16, 16, 32)) + net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=1, features=(4, 8, 8, 16, 16, 32)) elif net_name == "unet": net = UNet( - dimensions=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 + spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2 ) net.to(device) diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 7fcc0b4064..adb6acfcdd 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -98,7 +98,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): # create UNet, DiceLoss and Adam optimizer net = monai.networks.nets.UNet( - dimensions=3, + spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), @@ -114,7 +114,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), + AsDiscreted(keys="pred", threshold=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) @@ -155,7 +155,7 @@ def _forward_completed(self, engine): [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), + AsDiscreted(keys="pred", threshold=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), ] ) @@ -230,7 +230,7 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor # create UNet, DiceLoss and Adam optimizer net = monai.networks.nets.UNet( - dimensions=3, + spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), @@ -242,15 +242,10 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), + AsDiscreted(keys="pred", threshold=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch` - SaveImaged( - keys="pred", - meta_keys="image_meta_dict", - output_dir=root_dir, - output_postfix="seg_transform", - ), + SaveImaged(keys="pred", meta_keys="image_meta_dict", output_dir=root_dir, output_postfix="seg_transform"), ] ) val_handlers = [ @@ -351,20 +346,15 @@ def _test_saved_files(postfix): def test_training(self): repeated = [] - test_rounds = 3 if monai.utils.module.get_torch_version_tuple() >= (1, 6) else 2 + test_rounds = 3 for i in range(test_rounds): results = self.train_and_infer(idx=i) repeated.append(results) np.testing.assert_allclose(repeated[0], repeated[1]) - @TimedCall( - seconds=300, - skip_timing=not torch.cuda.is_available(), - daemon=False, - ) + @TimedCall(seconds=300, skip_timing=not torch.cuda.is_available(), daemon=False) def test_timing(self): - if monai.utils.module.get_torch_version_tuple() >= (1, 6): - self.train_and_infer(idx=2) + self.train_and_infer(idx=2) if __name__ == "__main__": diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py index c54e8b01f2..790b222ea0 100644 --- a/tests/test_integration_workflows_gan.py +++ b/tests/test_integration_workflows_gan.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -90,8 +90,7 @@ def generator_loss(gen_images): train_handlers = [ StatsHandler( - name="training_loss", - output_transform=lambda x: {Keys.GLOSS: x[Keys.GLOSS], Keys.DLOSS: x[Keys.DLOSS]}, + name="training_loss", output_transform=lambda x: {Keys.GLOSS: x[Keys.GLOSS], Keys.DLOSS: x[Keys.DLOSS]} ), TensorBoardStatsHandler( log_dir=root_dir, diff --git a/tests/test_intensity_stats.py b/tests/test_intensity_stats.py index 059271e442..3479306180 100644 --- a/tests/test_intensity_stats.py +++ b/tests/test_intensity_stats.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,45 +15,43 @@ from parameterized import parameterized from monai.transforms import IntensityStats +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"ops": ["max", "mean"], "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - {"affine": None}, - {"orig_max": 3.0, "orig_mean": 1.5}, -] - -TEST_CASE_2 = [ - {"ops": "std", "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - None, - {"orig_std": 1.118034}, -] - -TEST_CASE_3 = [ - {"ops": [lambda x: np.mean(x), "max", lambda x: np.min(x)], "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - None, - {"orig_custom_0": 1.5, "orig_max": 3.0, "orig_custom_1": 0.0}, -] - -TEST_CASE_4 = [ - {"ops": ["max", "mean"], "key_prefix": "orig", "channel_wise": True}, - np.array([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]), - {"affine": None}, - {"orig_max": [3.0, 7.0], "orig_mean": [1.5, 5.5]}, -] - -TEST_CASE_5 = [ - {"ops": ["max", "mean"], "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - {"affine": None}, - {"orig_max": 3.0, "orig_mean": 1.5}, -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( + [ + [ + {"ops": ["max", "mean"], "key_prefix": "orig"}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + {"affine": None}, + {"orig_max": 3.0, "orig_mean": 1.5}, + ], + [{"ops": "std", "key_prefix": "orig"}, p([[[0.0, 1.0], [2.0, 3.0]]]), None, {"orig_std": 1.118034}], + [ + {"ops": [np.mean, "max", np.min], "key_prefix": "orig"}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + None, + {"orig_custom_0": 1.5, "orig_max": 3.0, "orig_custom_1": 0.0}, + ], + [ + {"ops": ["max", "mean"], "key_prefix": "orig", "channel_wise": True}, + p([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]), + {"affine": None}, + {"orig_max": [3.0, 7.0], "orig_mean": [1.5, 5.5]}, + ], + [ + {"ops": ["max", "mean"], "key_prefix": "orig"}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + {"affine": None}, + {"orig_max": 3.0, "orig_mean": 1.5}, + ], + ] + ) class TestIntensityStats(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_param, img, meta_dict, expected): _, meta_dict = IntensityStats(**input_param)(img, meta_dict) for k, v in expected.items(): @@ -61,11 +59,12 @@ def test_value(self, input_param, img, meta_dict, expected): np.testing.assert_allclose(v, meta_dict[k], atol=1e-3) def test_mask(self): - img = np.array([[[0.0, 1.0], [2.0, 3.0]]]) - mask = np.array([[[1, 0], [1, 0]]], dtype=bool) - img, meta_dict = IntensityStats(ops=["max", "mean"], key_prefix="orig")(img, mask=mask) - np.testing.assert_allclose(meta_dict["orig_max"], 2.0, atol=1e-3) - np.testing.assert_allclose(meta_dict["orig_mean"], 1.0, atol=1e-3) + for p in TEST_NDARRAYS: + img = p([[[0.0, 1.0], [2.0, 3.0]]]) + mask = np.array([[[1, 0], [1, 0]]], dtype=bool) + img, meta_dict = IntensityStats(ops=["max", "mean"], key_prefix="orig")(img, mask=mask) + np.testing.assert_allclose(meta_dict["orig_max"], 2.0, atol=1e-3) + np.testing.assert_allclose(meta_dict["orig_mean"], 1.0, atol=1e-3) if __name__ == "__main__": diff --git a/tests/test_intensity_statsd.py b/tests/test_intensity_statsd.py index 8c8bc8795a..97d3de80c0 100644 --- a/tests/test_intensity_statsd.py +++ b/tests/test_intensity_statsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,7 +34,7 @@ ] TEST_CASE_3 = [ - {"keys": "img", "ops": [lambda x: np.mean(x), "max", lambda x: np.min(x)], "key_prefix": "orig"}, + {"keys": "img", "ops": [np.mean, "max", np.min], "key_prefix": "orig"}, {"img": np.array([[[0.0, 1.0], [2.0, 3.0]]])}, "img_meta_dict", {"orig_custom_0": 1.5, "orig_max": 3.0, "orig_custom_1": 0.0}, diff --git a/tests/test_inverse.py b/tests/test_inverse.py index f2470d47fd..4455009658 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -59,13 +59,13 @@ Spacingd, SpatialCropd, SpatialPadd, + TraceableTransform, Transposed, Zoomd, allow_missing_keys_mode, convert_inverse_interp_mode, ) from monai.utils import first, get_seed, optional_import, set_determinism -from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine if TYPE_CHECKING: @@ -122,23 +122,9 @@ ) ) -TESTS.append( - ( - "SpatialPadd 3d", - "3D", - 0, - SpatialPadd(KEYS, spatial_size=[112, 113, 116]), - ) -) +TESTS.append(("SpatialPadd 3d", "3D", 0, SpatialPadd(KEYS, spatial_size=[112, 113, 116]))) -TESTS.append( - ( - "SpatialCropd 2d", - "2D", - 0, - SpatialCropd(KEYS, [49, 51], [90, 89]), - ) -) +TESTS.append(("SpatialCropd 2d", "2D", 0, SpatialCropd(KEYS, [49, 51], [90, 89]))) TESTS.append( ( @@ -149,91 +135,28 @@ ) ) -TESTS.append( - ( - "SpatialCropd 2d", - "2D", - 0, - SpatialCropd(KEYS, [49, 51], [390, 89]), - ) -) +TESTS.append(("SpatialCropd 2d", "2D", 0, SpatialCropd(KEYS, [49, 51], [390, 89]))) -TESTS.append( - ( - "SpatialCropd 3d", - "3D", - 0, - SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), - ) -) +TESTS.append(("SpatialCropd 3d", "3D", 0, SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]))) TESTS.append(("RandSpatialCropd 2d", "2D", 0, RandSpatialCropd(KEYS, [96, 93], None, True, False))) TESTS.append(("RandSpatialCropd 3d", "3D", 0, RandSpatialCropd(KEYS, [96, 93, 92], None, False, False))) -TESTS.append( - ( - "BorderPadd 2d", - "2D", - 0, - BorderPadd(KEYS, [3, 7, 2, 5]), - ) -) +TESTS.append(("BorderPadd 2d", "2D", 0, BorderPadd(KEYS, [3, 7, 2, 5]))) -TESTS.append( - ( - "BorderPadd 2d", - "2D", - 0, - BorderPadd(KEYS, [3, 7]), - ) -) +TESTS.append(("BorderPadd 2d", "2D", 0, BorderPadd(KEYS, [3, 7]))) -TESTS.append( - ( - "BorderPadd 3d", - "3D", - 0, - BorderPadd(KEYS, [4]), - ) -) +TESTS.append(("BorderPadd 3d", "3D", 0, BorderPadd(KEYS, [4]))) -TESTS.append( - ( - "DivisiblePadd 2d", - "2D", - 0, - DivisiblePadd(KEYS, k=4), - ) -) +TESTS.append(("DivisiblePadd 2d", "2D", 0, DivisiblePadd(KEYS, k=4))) -TESTS.append( - ( - "DivisiblePadd 3d", - "3D", - 0, - DivisiblePadd(KEYS, k=[4, 8, 11]), - ) -) +TESTS.append(("DivisiblePadd 3d", "3D", 0, DivisiblePadd(KEYS, k=[4, 8, 11]))) -TESTS.append( - ( - "CenterSpatialCropd 2d", - "2D", - 0, - CenterSpatialCropd(KEYS, roi_size=95), - ) -) +TESTS.append(("CenterSpatialCropd 2d", "2D", 0, CenterSpatialCropd(KEYS, roi_size=95))) -TESTS.append( - ( - "CenterSpatialCropd 3d", - "3D", - 0, - CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), - ) -) +TESTS.append(("CenterSpatialCropd 3d", "3D", 0, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]))) TESTS.append(("CropForegroundd 2d", "2D", 0, CropForegroundd(KEYS, source_key="label", margin=2))) @@ -242,69 +165,20 @@ TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) -TESTS.append( - ( - "Flipd 3d", - "3D", - 0, - Flipd(KEYS, [1, 2]), - ) -) +TESTS.append(("Flipd 3d", "3D", 0, Flipd(KEYS, [1, 2]))) -TESTS.append( - ( - "RandFlipd 3d", - "3D", - 0, - RandFlipd(KEYS, 1, [1, 2]), - ) -) +TESTS.append(("RandFlipd 3d", "3D", 0, RandFlipd(KEYS, 1, [1, 2]))) -TESTS.append( - ( - "RandAxisFlipd 3d", - "3D", - 0, - RandAxisFlipd(KEYS, 1), - ) -) +TESTS.append(("RandAxisFlipd 3d", "3D", 0, RandAxisFlipd(KEYS, 1))) for acc in [True, False]: - TESTS.append( - ( - "Orientationd 3d", - "3D", - 0, - Orientationd(KEYS, "RAS", as_closest_canonical=acc), - ) - ) + TESTS.append(("Orientationd 3d", "3D", 0, Orientationd(KEYS, "RAS", as_closest_canonical=acc))) -TESTS.append( - ( - "Rotate90d 2d", - "2D", - 0, - Rotate90d(KEYS), - ) -) +TESTS.append(("Rotate90d 2d", "2D", 0, Rotate90d(KEYS))) -TESTS.append( - ( - "Rotate90d 3d", - "3D", - 0, - Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), - ) -) +TESTS.append(("Rotate90d 3d", "3D", 0, Rotate90d(KEYS, k=2, spatial_axes=(1, 2)))) -TESTS.append( - ( - "RandRotate90d 3d", - "3D", - 0, - RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), - ) -) +TESTS.append(("RandRotate90d 3d", "3D", 0, RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)))) TESTS.append(("Spacingd 3d", "3D", 3e-2, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) @@ -327,51 +201,18 @@ ) ) -TESTS.append( - ( - "Zoomd 1d", - "1D odd", - 0, - Zoomd(KEYS, zoom=2, keep_size=False), - ) -) +TESTS.append(("Zoomd 1d", "1D odd", 0, Zoomd(KEYS, zoom=2, keep_size=False))) -TESTS.append( - ( - "Zoomd 2d", - "2D", - 2e-1, - Zoomd(KEYS, zoom=0.9), - ) -) +TESTS.append(("Zoomd 2d", "2D", 2e-1, Zoomd(KEYS, zoom=0.9))) -TESTS.append( - ( - "Zoomd 3d", - "3D", - 3e-2, - Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), - ) -) +TESTS.append(("Zoomd 3d", "3D", 3e-2, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False))) TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) -TESTS.append( - ( - "RandRotated, prob 0", - "2D", - 0, - RandRotated(KEYS, prob=0), - ) -) +TESTS.append(("RandRotated, prob 0", "2D", 0, RandRotated(KEYS, prob=0))) TESTS.append( - ( - "Rotated 2d", - "2D", - 8e-2, - Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), - ) + ("Rotated 2d", "2D", 8e-2, Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False)) ) TESTS.append( @@ -392,23 +233,9 @@ ) ) -TESTS.append( - ( - "Transposed 2d", - "2D", - 0, - Transposed(KEYS, [0, 2, 1]), # channel=0 - ) -) +TESTS.append(("Transposed 2d", "2D", 0, Transposed(KEYS, [0, 2, 1]))) # channel=0 -TESTS.append( - ( - "Transposed 3d", - "3D", - 0, - Transposed(KEYS, [0, 3, 1, 2]), # channel=0 - ) -) +TESTS.append(("Transposed 3d", "3D", 0, Transposed(KEYS, [0, 3, 1, 2]))) # channel=0 TESTS.append( ( @@ -444,14 +271,7 @@ ) ) -TESTS.append( - ( - "RandAffine 3d", - "3D", - 0, - RandAffined(KEYS, spatial_size=None, prob=0), - ) -) +TESTS.append(("RandAffine 3d", "3D", 0, RandAffined(KEYS, spatial_size=None, prob=0))) TESTS.append( ( @@ -462,32 +282,11 @@ ) ) -TESTS.append( - ( - "RandCropByPosNegLabeld 2d", - "2D", - 1e-7, - RandCropByPosNegLabeld(KEYS, "label", (99, 96), num_samples=10), - ) -) +TESTS.append(("RandCropByPosNegLabeld 2d", "2D", 1e-7, RandCropByPosNegLabeld(KEYS, "label", (99, 96), num_samples=10))) -TESTS.append( - ( - "RandSpatialCropSamplesd 2d", - "2D", - 1e-7, - RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10), - ) -) +TESTS.append(("RandSpatialCropSamplesd 2d", "2D", 1e-7, RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10))) -TESTS.append( - ( - "RandWeightedCropd 2d", - "2D", - 1e-7, - RandWeightedCropd(KEYS, "label", (90, 91), num_samples=10), - ) -) +TESTS.append(("RandWeightedCropd 2d", "2D", 1e-7, RandWeightedCropd(KEYS, "label", (90, 91), num_samples=10))) TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] @@ -566,8 +365,8 @@ def setUp(self): "other": np.array(im_1d, copy=True), } - im_2d_fname, seg_2d_fname = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] - im_3d_fname, seg_3d_fname = [make_nifti_image(i, affine) for i in create_test_image_3d(100, 101, 107)] + im_2d_fname, seg_2d_fname = (make_nifti_image(i) for i in create_test_image_2d(101, 100)) + im_3d_fname, seg_3d_fname = (make_nifti_image(i, affine) for i in create_test_image_3d(100, 101, 107)) load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) self.all_data["2D"] = load_ims({"image": im_2d_fname, "label": seg_2d_fname}) @@ -651,27 +450,15 @@ def test_inverse_inferred_seg(self, extra_transform): batch_size = 10 # num workers = 0 for mac - num_workers = 2 if sys.platform != "darwin" else 0 - transforms = Compose( - [ - AddChanneld(KEYS), - SpatialPadd(KEYS, (150, 153)), - extra_transform, - ] - ) + num_workers = 2 if sys.platform == "linux" else 0 + transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform]) num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform)) dataset = CacheDataset(test_data, transform=transforms, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) device = "cuda" if torch.cuda.is_available() else "cpu" - model = UNet( - dimensions=2, - in_channels=1, - out_channels=1, - channels=(2, 4), - strides=(2,), - ).to(device) + model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device) data = first(loader) self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) @@ -679,7 +466,7 @@ def test_inverse_inferred_seg(self, extra_transform): labels = data["label"].to(device) segs = model(labels).detach().cpu() - label_transform_key = "label" + InverseKeys.KEY_SUFFIX + label_transform_key = TraceableTransform.trace_key("label") segs_dict = {"label": segs, label_transform_key: data[label_transform_key]} segs_dict_decollated = decollate_batch(segs_dict) diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index c302e04017..3c293070bb 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -48,15 +48,11 @@ for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]), RandAxisFlipd(keys=KEYS, prob=0.5), - RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), + Compose([RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), ToTensord(keys=KEYS)]), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), RandAffined( - keys=KEYS, - prob=0.5, - rotate_range=np.pi, - device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, + keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ), ] ] @@ -67,15 +63,11 @@ for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]), RandAxisFlipd(keys=KEYS, prob=0.5), - RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), + Compose([RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), ToTensord(keys=KEYS)]), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), RandAffined( - keys=KEYS, - prob=0.5, - rotate_range=np.pi, - device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, + keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ), ] ] @@ -91,12 +83,12 @@ def setUp(self): set_determinism(seed=0) b_size = 11 - im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)] + im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)) load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) self.data_3d = [load_ims({"image": im_fname, "label": seg_fname}) for _ in range(b_size)] b_size = 8 - im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_2d(62, 37, rad_max=10)] + im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_2d(62, 37, rad_max=10)) load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) self.data_2d = [load_ims({"image": im_fname, "label": seg_fname}) for _ in range(b_size)] @@ -107,17 +99,14 @@ def tearDown(self): @parameterized.expand(TESTS_2D + TESTS_3D) def test_collation(self, _, transform, collate_fn, ndim): - if ndim == 3: - data = self.data_3d - else: - data = self.data_2d + data = self.data_3d if ndim == 3 else self.data_2d if collate_fn: modified_transform = transform else: modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS)]) # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=modified_transform, progress=False) loader = DataLoader(dataset, num_workers, batch_size=self.batch_size, collate_fn=collate_fn) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index 5b98653f0a..1dd4e2eecf 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,8 +34,9 @@ ResizeWithPadOrCropd, ScaleIntensityd, Spacingd, + ToTensord, ) -from monai.utils.misc import set_determinism +from monai.utils import set_determinism from tests.utils import make_nifti_image KEYS = ["image", "label"] @@ -44,7 +45,7 @@ class TestInvertd(unittest.TestCase): def test_invert(self): set_determinism(seed=0) - im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)] + im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)) transform = Compose( [ LoadImaged(KEYS), @@ -63,47 +64,106 @@ def test_invert(self): CopyItemsd("image_meta_dict", times=1, names="test_dict"), # test to support Tensor, Numpy array and dictionary when inverting EnsureTyped(keys=["image", "test_dict"]), + ToTensord("image"), CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]), - CopyItemsd("label", times=1, names="label_inverted"), + CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]), + CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]), ] ) data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=transform, progress=False) loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) inverter = Invertd( # `image` was not copied, invert the original value directly - keys=["image", "label_inverted", "test_dict"], + keys=["image_inverted", "label_inverted", "test_dict"], transform=transform, orig_keys=["label", "label", "test_dict"], - meta_keys=["image_meta_dict", "label_inverted_meta_dict", None], + meta_keys=["image_inverted_meta_dict", "label_inverted_meta_dict", None], orig_meta_keys=["label_meta_dict", "label_meta_dict", None], nearest_interp=True, to_tensor=[True, False, False], device="cpu", ) + inverter_1 = Invertd( + # `image` was not copied, invert the original value directly + keys=["image_inverted1", "label_inverted1"], + transform=transform, + orig_keys=["image", "image"], + meta_keys=["image_inverted1_meta_dict", "label_inverted1_meta_dict"], + orig_meta_keys=["image_meta_dict", "image_meta_dict"], + nearest_interp=[True, False], + to_tensor=[True, True], + device="cpu", + ) + + expected_keys = [ + "image", + "image_inverted", + "image_inverted1", + "image_inverted1_meta_dict", + "image_inverted_meta_dict", + "image_meta_dict", + "image_transforms", + "label", + "label_inverted", + "label_inverted1", + "label_inverted1_meta_dict", + "label_inverted_meta_dict", + "label_meta_dict", + "label_transforms", + "test_dict", + "test_dict_transforms", + ] # execute 1 epoch for d in loader: d = decollate_batch(d) for item in d: item = inverter(item) - # this unit test only covers basic function, test_handler_transform_inverter covers more + item = inverter_1(item) + + self.assertListEqual(sorted(item), expected_keys) + self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100)) self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100)) - # check the nearest inerpolation mode - i = item["image"] + # check the nearest interpolation mode + i = item["image_inverted"] torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float)) self.assertTupleEqual(i.shape[1:], (100, 101, 107)) i = item["label_inverted"] - np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32)) + torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float)) self.assertTupleEqual(i.shape[1:], (100, 101, 107)) # test inverted test_dict self.assertTrue(isinstance(item["test_dict"]["affine"], np.ndarray)) self.assertTrue(isinstance(item["test_dict"]["filename_or_obj"], str)) + # check the case that different items use different interpolation mode to invert transforms + d = item["image_inverted1"] + # if the interpolation mode is nearest, accumulated diff should be smaller than 1 + self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0) + self.assertTupleEqual(d.shape, (1, 100, 101, 107)) + + d = item["label_inverted1"] + # if the interpolation mode is not nearest, accumulated diff should be greater than 10000 + self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0) + self.assertTupleEqual(d.shape, (1, 100, 101, 107)) + + # check labels match + reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32) + original = LoadImaged(KEYS)(data[-1])["label"] + n_good = np.sum(np.isclose(reverted, original, atol=1e-3)) + reverted_name = item["label_inverted_meta_dict"]["filename_or_obj"] + original_name = data[-1]["label"] + self.assertEqual(reverted_name, original_name) + print("invert diff", reverted.size - n_good) + # 25300: 2 workers (cpu, non-macos) + # 1812: 0 workers (gpu or macos) + # 1821: windows torch 1.10.0 + self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821), f"diff. {reverted.size - n_good}") + set_determinism(seed=None) diff --git a/tests/test_is_supported_format.py b/tests/test_is_supported_format.py index c0af8f4395..71a44bd190 100644 --- a/tests/test_is_supported_format.py +++ b/tests/test_is_supported_format.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,35 +15,17 @@ from monai.data import is_supported_format -TEST_CASE_1 = [ - {"filename": "testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, - True, -] - -TEST_CASE_2 = [ - {"filename": "./testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, - True, -] - -TEST_CASE_3 = [ - {"filename": "./test.data/file.nii.gz", "suffixes": ["nii", "nii.gz"]}, - True, -] - -TEST_CASE_4 = [ - {"filename": "./test.data/file.nii", "suffixes": ["nii", "nii.gz"]}, - True, -] - -TEST_CASE_5 = [ - {"filename": "C:\\documents\\testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, - True, -] - -TEST_CASE_6 = [ - {"filename": "1.3.12.2.1107.5.4.4.145.nii.gz", "suffixes": ["nii.gz"]}, - True, -] +TEST_CASE_1 = [{"filename": "testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, True] + +TEST_CASE_2 = [{"filename": "./testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, True] + +TEST_CASE_3 = [{"filename": "./test.data/file.nii.gz", "suffixes": ["nii", "nii.gz"]}, True] + +TEST_CASE_4 = [{"filename": "./test.data/file.nii", "suffixes": ["nii", "nii.gz"]}, True] + +TEST_CASE_5 = [{"filename": "C:\\documents\\testfile.nii.gz", "suffixes": ["nii", "nii.gz"]}, True] + +TEST_CASE_6 = [{"filename": "1.3.12.2.1107.5.4.4.145.nii.gz", "suffixes": ["nii.gz"]}, True] class TestIsSupportedFormat(unittest.TestCase): diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 7b16eaf594..2c47a2181e 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. import os +import sys import tempfile import unittest @@ -36,14 +37,9 @@ def test_shape(self): with tempfile.TemporaryDirectory() as tempdir: for i in range(6): nib.save(test_image, os.path.join(tempdir, f"test_image{str(i)}.nii.gz")) - test_data.append({"image": os.path.join(tempdir, f"test_image{str(i)}.nii.gz")}) + test_data.append({"image": os.path.join(tempdir, f"test_image{i}.nii.gz")}) - test_transform = Compose( - [ - LoadImaged(keys="image"), - SimulateDelayd(keys="image", delay_time=1e-7), - ] - ) + test_transform = Compose([LoadImaged(keys="image"), SimulateDelayd(keys="image", delay_time=1e-7)]) data_iterator = _Stream(test_data) with self.assertRaises(TypeError): # Dataset doesn't work @@ -54,7 +50,8 @@ def test_shape(self): for d in dataset: self.assertTupleEqual(d["image"].shape, expected_shape) - dataloader = DataLoader(dataset=dataset, batch_size=3, num_workers=2) + num_workers = 2 if sys.platform == "linux" else 0 + dataloader = DataLoader(dataset=dataset, batch_size=3, num_workers=num_workers) for d in dataloader: self.assertTupleEqual(d["image"].shape[1:], expected_shape) diff --git a/tests/test_k_space_spike_noise.py b/tests/test_k_space_spike_noise.py index bb6d05e676..43717aa214 100644 --- a/tests/test_k_space_spike_noise.py +++ b/tests/test_k_space_spike_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,17 +20,14 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -40,34 +37,44 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d - im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + im, _ = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) + return im_type(im[None]) - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): - im = self.get_data(im_shape, as_tensor_input) + im = self.get_data(im_shape, im_type) loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0] k_intensity = 10 - t = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoise(loc, k_intensity) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(im.device, out1.device) + out1 = out1.cpu() + out2 = out2.cpu() + np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_highlighted_kspace_pixel(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_highlighted_kspace_pixel(self, im_shape, as_tensor_input): im = self.get_data(im_shape, as_tensor_input) loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0] k_intensity = 10 - t = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoise(loc, k_intensity) out = t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(im.device, out.device) + out = out.cpu() + n_dims = len(im_shape) out_k = fftshift(fftn(out, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) log_mag = np.log(np.absolute(out_k)) diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py index 616662b3cd..0230f40b15 100644 --- a/tests/test_k_space_spike_noised.py +++ b/tests/test_k_space_spike_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,19 +20,16 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) KEYS = ["image", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -42,55 +39,69 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [im[None] for im in ims] - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) + ims = [im_type(im[None]) for im in ims] + return {k: v for k, v in zip(KEYS, ims)} - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_highlighted_kspace_pixel(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_highlighted_kspace_pixel(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out = t(data) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + n_dims = len(im_shape) out_k = fftshift(fftn(out[k], axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) log_mag = np.log(np.absolute(out_k)) np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-1) - @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_dict_matches(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out = t(deepcopy(data)) + for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index 527d986614..5307830cd3 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,331 +15,333 @@ from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponent -from tests.utils import assert_allclose, clone +from tests.utils import TEST_NDARRAYS, assert_allclose -grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]) -grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]) -grid_3 = torch.tensor( +grid_1 = [[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]] +grid_2 = [[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]] +grid_3 = [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], ], -) -grid_4 = torch.tensor( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 1.0, 0.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], ], -) - - -TEST_CASE_1 = [ - "value_1", - {"independent": False, "applied_labels": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), ] - -TEST_CASE_2 = [ - "value_2", - {"independent": False, "applied_labels": [2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), +grid_4 = [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], ] -TEST_CASE_3 = [ - "independent_value_1_2", - {"independent": True, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), -] -TEST_CASE_4 = [ - "dependent_value_1_2", - {"independent": False, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + "value_1", + {"independent": False, "applied_labels": 1}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_5 = [ - "value_1", - {"independent": True, "applied_labels": [1]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + TESTS.append( + [ + "value_2", + {"independent": False, "applied_labels": [2]}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_6 = [ - "independent_value_1_2", - {"independent": True, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + TESTS.append( + [ + "independent_value_1_2", + {"independent": True, "applied_labels": [1, 2]}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_7 = [ - "dependent_value_1_2", - {"independent": False, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), -] + TESTS.append( + [ + "dependent_value_1_2", + {"independent": False, "applied_labels": [1, 2]}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_8 = [ - "value_1_connect_1", - {"independent": False, "applied_labels": [1], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), -] + TESTS.append( + [ + "value_1", + {"independent": True, "applied_labels": [1]}, + p(grid_2), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_9 = [ - "independent_value_1_2_connect_1", - {"independent": True, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + TESTS.append( + [ + "independent_value_1_2", + {"independent": True, "applied_labels": [1, 2]}, + p(grid_2), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_10 = [ - "dependent_value_1_2_connect_1", - {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + TESTS.append( + [ + "dependent_value_1_2", + {"independent": False, "applied_labels": [1, 2]}, + p(grid_2), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), + ] + ) -TEST_CASE_11 = [ - "onehot_independent_batch_2_apply_label_1_connect_1", - {"independent": True, "applied_labels": [1], "connectivity": 1}, - grid_3, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ], - ), -] + "value_1_connect_1", + {"independent": False, "applied_labels": [1], "connectivity": 1}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_12 = [ - "onehot_independent_batch_2_apply_label_1_connect_2", - {"independent": True, "applied_labels": [1], "connectivity": 2}, - grid_3, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ], - ), -] + "independent_value_1_2_connect_1", + {"independent": True, "applied_labels": [1, 2], "connectivity": 1}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_13 = [ - "onehot_independent_batch_2_apply_label_1_2_connect_2", - {"independent": True, "applied_labels": [1, 2], "connectivity": 2}, - grid_3, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "dependent_value_1_2_connect_1", + {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_14 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_2", - {"independent": False, "applied_labels": [1, 2], "connectivity": 2}, - grid_4, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "onehot_independent_batch_2_apply_label_1_connect_1", + {"independent": True, "applied_labels": [1], "connectivity": 1}, + p(grid_3), + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), + ] + ) -TEST_CASE_15 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_1", - {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_4, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "onehot_independent_batch_2_apply_label_1_connect_2", + {"independent": True, "applied_labels": [1], "connectivity": 2}, + p(grid_3), + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), + ] + ) -VALID_CASES = [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - TEST_CASE_9, - TEST_CASE_10, - TEST_CASE_11, - TEST_CASE_12, - TEST_CASE_13, - TEST_CASE_14, - TEST_CASE_15, -] + TESTS.append( + [ + "onehot_independent_batch_2_apply_label_1_2_connect_2", + {"independent": True, "applied_labels": [1, 2], "connectivity": 2}, + p(grid_3), + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_1 = ["no_applied_labels_for_single_channel", {"independent": False}, grid_1, TypeError] + TESTS.append( + [ + "onehot_dependent_batch_2_apply_label_1_2_connect_2", + {"independent": False, "applied_labels": [1, 2], "connectivity": 2}, + p(grid_4), + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_2 = ["no_applied_labels_for_multi_channel", {"independent": False}, grid_3, TypeError] + TESTS.append( + [ + "onehot_dependent_batch_2_apply_label_1_2_connect_1", + {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, + p(grid_4), + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -INVALID_CASES = [ITEST_CASE_1, ITEST_CASE_2] +INVALID_CASES = [] +for p in TEST_NDARRAYS: + INVALID_CASES.append(["no_applied_labels_for_single_channel", {"independent": False}, p(grid_1), TypeError]) + INVALID_CASES.append(["no_applied_labels_for_multi_channel", {"independent": False}, p(grid_3), TypeError]) class TestKeepLargestConnectedComponent(unittest.TestCase): - @parameterized.expand(VALID_CASES) + @parameterized.expand(TESTS) def test_correct_results(self, _, args, input_image, expected): converter = KeepLargestConnectedComponent(**args) - if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - result = converter(clone(input_image).cuda()) - - else: - result = converter(clone(input_image)) - assert_allclose(result, expected) + result = converter(input_image) + assert_allclose(result, expected, type_test=False) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): with self.assertRaises(expected_error): converter = KeepLargestConnectedComponent(**args) - if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - _ = converter(clone(input_image).cuda()) - else: - _ = converter(clone(input_image).clone()) + _ = converter(input_image) if __name__ == "__main__": diff --git a/tests/test_keep_largest_connected_componentd.py b/tests/test_keep_largest_connected_componentd.py index 9478cfb965..94a36feed0 100644 --- a/tests/test_keep_largest_connected_componentd.py +++ b/tests/test_keep_largest_connected_componentd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,332 +15,335 @@ from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponentd +from tests.utils import TEST_NDARRAYS, assert_allclose -grid_1 = {"img": torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]])} -grid_2 = {"img": torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]])} -grid_3 = { - "img": torch.tensor( - [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ], - ) -} -grid_4 = { - "img": torch.tensor( - [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 1.0, 0.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ], - ) -} - -TEST_CASE_1 = [ - "value_1", - {"keys": ["img"], "independent": False, "applied_labels": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), -] - -TEST_CASE_2 = [ - "value_2", - {"keys": ["img"], "independent": False, "applied_labels": [2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), +grid_1 = [[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]] +grid_2 = [[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]] +grid_3 = [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], ] - -TEST_CASE_3 = [ - "independent_value_1_2", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), +grid_4 = [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], ] -TEST_CASE_4 = [ - "dependent_value_1_2", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), -] +VALID_CASES = [] +for p in TEST_NDARRAYS: + VALID_CASES.append( + [ + "value_1", + {"keys": ["img"], "independent": False, "applied_labels": 1}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_5 = [ - "value_1", - {"keys": ["img"], "independent": True, "applied_labels": [1]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "value_2", + {"keys": ["img"], "independent": False, "applied_labels": [2]}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_6 = [ - "independent_value_1_2", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "independent_value_1_2", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_7 = [ - "dependent_value_1_2", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), -] + VALID_CASES.append( + [ + "dependent_value_1_2", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_8 = [ - "value_1_connect_1", - {"keys": ["img"], "independent": False, "applied_labels": [1], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), -] + VALID_CASES.append( + [ + "value_1", + {"keys": ["img"], "independent": True, "applied_labels": [1]}, + {"img": p(grid_2)}, + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_9 = [ - "independent_value_1_2_connect_1", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "independent_value_1_2", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, + {"img": p(grid_2)}, + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_10 = [ - "dependent_value_1_2_connect_1", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "dependent_value_1_2", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, + {"img": p(grid_2)}, + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), + ] + ) -TEST_CASE_11 = [ - "onehot_independent_batch_2_apply_label_1_connect_1", - {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 1}, - grid_3, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ], - ), -] + "value_1_connect_1", + {"keys": ["img"], "independent": False, "applied_labels": [1], "connectivity": 1}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_12 = [ - "onehot_independent_batch_2_apply_label_1_connect_2", - {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 2}, - grid_3, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ], - ), -] + "independent_value_1_2_connect_1", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 1}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_13 = [ - "onehot_independent_batch_2_apply_label_1_2_connect_2", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 2}, - grid_3, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "dependent_value_1_2_connect_1", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_14 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_2", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 2}, - grid_4, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "onehot_independent_batch_2_apply_label_1_connect_1", + {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 1}, + {"img": p(grid_3)}, + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), + ] + ) -TEST_CASE_15 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_1", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_4, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ], - ), -] + "onehot_independent_batch_2_apply_label_1_connect_2", + {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 2}, + {"img": p(grid_3)}, + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), + ] + ) -VALID_CASES = [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - TEST_CASE_9, - TEST_CASE_10, - TEST_CASE_11, - TEST_CASE_12, - TEST_CASE_13, - TEST_CASE_14, - TEST_CASE_15, -] + VALID_CASES.append( + [ + "onehot_independent_batch_2_apply_label_1_2_connect_2", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 2}, + {"img": p(grid_3)}, + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_1 = ["no_applied_labels_for_single_channel", {"keys": ["img"], "independent": False}, grid_1, TypeError] + VALID_CASES.append( + [ + "onehot_dependent_batch_2_apply_label_1_2_connect_2", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 2}, + {"img": p(grid_4)}, + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_2 = ["no_applied_labels_for_multi_channel", {"keys": ["img"], "independent": False}, grid_3, TypeError] + VALID_CASES.append( + [ + "onehot_dependent_batch_2_apply_label_1_2_connect_1", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, + {"img": p(grid_4)}, + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -INVALID_CASES = [ITEST_CASE_1, ITEST_CASE_2] +INVALID_CASES = [] +for p in TEST_NDARRAYS: + INVALID_CASES.append( + ["no_applied_labels_for_single_channel", {"keys": ["img"], "independent": False}, {"img": p(grid_1)}, TypeError] + ) + INVALID_CASES.append( + ["no_applied_labels_for_multi_channel", {"keys": ["img"], "independent": False}, {"img": p(grid_3)}, TypeError] + ) class TestKeepLargestConnectedComponentd(unittest.TestCase): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, args, input_dict, expected): converter = KeepLargestConnectedComponentd(**args) - if torch.cuda.is_available(): - input_dict["img"] = input_dict["img"].cuda() - result = converter(input_dict) - torch.allclose(result["img"], expected.cuda()) - else: - result = converter(input_dict) - torch.allclose(result["img"], expected) + result = converter(input_dict) + assert_allclose(result["img"], expected, type_test=False) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_dict, expected_error): with self.assertRaises(expected_error): converter = KeepLargestConnectedComponentd(**args) - if torch.cuda.is_available(): - input_dict["img"] = input_dict["img"].cuda() _ = converter(input_dict) diff --git a/tests/test_label_filter.py b/tests/test_label_filter.py index c699fb31fd..b782f90441 100644 --- a/tests/test_label_filter.py +++ b/tests/test_label_filter.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,86 +16,41 @@ from parameterized import parameterized from monai.transforms import LabelFilter -from tests.utils import assert_allclose, clone +from tests.utils import TEST_NDARRAYS, assert_allclose -grid_1 = torch.tensor( - [ - [ - [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9], - ] - ] - ] -) +grid_1 = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]) - -TEST_CASE_0 = [ - "filter_single_label", - {"applied_labels": 3}, - grid_1, - torch.tensor( +VALID_TESTS = [] +for p in TEST_NDARRAYS: + VALID_TESTS.append( [ - [ - [ - [0, 0, 3], - [0, 0, 0], - [0, 0, 0], - ] - ] + "filter_single_label", + {"applied_labels": 3}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])), ] - ), -] - + ) -TEST_CASE_1 = [ - "filter_single_label_list", - {"applied_labels": [3]}, - grid_1, - torch.tensor( + VALID_TESTS.append( [ - [ - [ - [0, 0, 3], - [0, 0, 0], - [0, 0, 0], - ] - ] + "filter_single_label_list", + {"applied_labels": [3]}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])), ] - ), -] - -TEST_CASE_2 = [ - "filter_multi_label", - {"applied_labels": [3, 5, 8]}, - grid_1, - torch.tensor( + ) + + VALID_TESTS.append( [ - [ - [ - [0, 0, 3], - [0, 5, 0], - [0, 8, 0], - ] - ] + "filter_multi_label", + {"applied_labels": [3, 5, 8]}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 5, 0], [0, 8, 0]]]])), ] - ), -] - -TEST_CASE_3 = [ - "filter_all", - {"applied_labels": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, - grid_1, - grid_1, -] - -VALID_CASES = [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, -] + ) + + VALID_TESTS.append(["filter_all", {"applied_labels": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, p(grid_1), p(grid_1)]) + ITEST_CASE_1 = ["invalid_image_data_type", {"applied_labels": 1}, [[[[1, 1, 1]]]], NotImplementedError] @@ -103,13 +58,10 @@ class TestLabelFilter(unittest.TestCase): - @parameterized.expand(VALID_CASES) + @parameterized.expand(VALID_TESTS) def test_correct_results(self, _, args, input_image, expected): converter = LabelFilter(**args) - if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - result = converter(clone(input_image).cuda()) - else: - result = converter(clone(input_image)) + result = converter(input_image) assert_allclose(result, expected) @parameterized.expand(INVALID_CASES) @@ -117,9 +69,9 @@ def test_raise_exception(self, _, args, input_image, expected_error): with self.assertRaises(expected_error): converter = LabelFilter(**args) if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - _ = converter(clone(input_image).cuda()) + _ = converter(input_image.cuda()) else: - _ = converter(clone(input_image)) + _ = converter(input_image) if __name__ == "__main__": diff --git a/tests/test_label_filterd.py b/tests/test_label_filterd.py new file mode 100644 index 0000000000..d53dc21faf --- /dev/null +++ b/tests/test_label_filterd.py @@ -0,0 +1,78 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms.post.dictionary import LabelFilterd +from tests.utils import TEST_NDARRAYS, assert_allclose + +grid_1 = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]) + +VALID_TESTS = [] +for p in TEST_NDARRAYS: + VALID_TESTS.append( + [ + "filter_single_label", + {"applied_labels": 3}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])), + ] + ) + + VALID_TESTS.append( + [ + "filter_single_label_list", + {"applied_labels": [3]}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])), + ] + ) + + VALID_TESTS.append( + [ + "filter_multi_label", + {"applied_labels": [3, 5, 8]}, + p(grid_1), + p(torch.tensor([[[[0, 0, 3], [0, 5, 0], [0, 8, 0]]]])), + ] + ) + + VALID_TESTS.append(["filter_all", {"applied_labels": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, p(grid_1), p(grid_1)]) + + +ITEST_CASE_1 = ["invalid_image_data_type", {"applied_labels": 1}, [[[[1, 1, 1]]]], NotImplementedError] + +INVALID_CASES = [ITEST_CASE_1] + + +class TestLabelFilter(unittest.TestCase): + @parameterized.expand(VALID_TESTS) + def test_correct_results(self, _, args, input_image, expected): + converter = LabelFilterd(keys="image", **args) + result = converter({"image": input_image})["image"] + assert_allclose(result, expected) + + @parameterized.expand(INVALID_CASES) + def test_raise_exception(self, _, args, input_image, expected_error): + with self.assertRaises(expected_error): + converter = LabelFilterd(keys="image", **args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + _ = converter({"image": input_image.cuda()}) + else: + _ = converter({"image": input_image}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_label_to_contour.py b/tests/test_label_to_contour.py index 8f8f3cc054..fef40af08d 100644 --- a/tests/test_label_to_contour.py +++ b/tests/test_label_to_contour.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,108 +15,107 @@ import torch from monai.transforms import LabelToContour +from tests.utils import TEST_NDARRAYS, assert_allclose -expected_output_for_cube = np.array( +expected_output_for_cube = [ [ - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - ] -) - - -def gen_fixed_cube(): + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], +] + + +def gen_fixed_cube(array_type): scale, core_start, core_end = 8, 1, 7 - cube = torch.zeros(scale, scale, scale) + cube = np.zeros((scale, scale, scale)) cube[core_start:core_end, core_start:core_end, core_start:core_end] = torch.ones( core_end - core_start, core_end - core_start, core_end - core_start ) - cube = torch.unsqueeze(cube, 0) + cube = cube[None] batch_size, channels = 10, 6 - cube = cube.repeat(batch_size, channels, 1, 1, 1) - return cube, expected_output_for_cube + cube = np.tile(cube, (batch_size, channels, 1, 1, 1)) + return array_type(cube), array_type(expected_output_for_cube) -def gen_fixed_img(): - img = torch.tensor( +def gen_fixed_img(array_type): + img = np.array( [ [0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1], @@ -124,19 +123,18 @@ def gen_fixed_img(): [0, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], ], - dtype=torch.float32, + dtype=np.float32, ) batch_size, channels = 10, 6 - img = img.repeat(batch_size, channels, 1, 1) - expected_output_for_img = torch.tensor( + img = array_type(np.tile(img, (batch_size, channels, 1, 1))) + expected_output_for_img = array_type( [ [0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 0, 0, 1], [0, 0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 0, 0, 1], [1, 1, 1, 1, 1, 1, 1], - ], - dtype=torch.float32, + ] ) return img, expected_output_for_img @@ -145,33 +143,34 @@ class TestContour(unittest.TestCase): def test_contour(self): input_param = {"kernel_type": "Laplace"} - # check 5-dim input data - test_cube, expected_output = gen_fixed_cube() - for cube in test_cube: - test_result_cube = LabelToContour(**input_param)(cube) - self.assertEqual(test_result_cube.shape, cube.shape) + for p in TEST_NDARRAYS: + # check 5-dim input data + test_cube, expected_output = gen_fixed_cube(p) + for cube in test_cube: + test_result_cube = LabelToContour(**input_param)(cube) + self.assertEqual(test_result_cube.shape, cube.shape) - test_result_np = test_result_cube.cpu().numpy() - channels = cube.shape[0] - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) + channels = cube.shape[0] + for channel in range(channels): + assert_allclose(test_result_cube[channel, ...], expected_output) - # check 4-dim input data - test_img, expected_output = gen_fixed_img() - for img in test_img: - channels = img.shape[0] - test_result_img = LabelToContour(**input_param)(img) - self.assertEqual(test_result_img.shape, img.shape) + # check 4-dim input data + test_img, expected_output = gen_fixed_img(p) + for img in test_img: + channels = img.shape[0] + test_result_img = LabelToContour(**input_param)(img) + self.assertEqual(test_result_img.shape, img.shape) - test_result_np = test_result_img.cpu().numpy() - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) + for channel in range(channels): + assert_allclose(test_result_img[channel, ...], expected_output) # check invalid input data error_input = torch.rand(1, 2) self.assertRaises(ValueError, LabelToContour(**input_param), error_input) error_input = torch.rand(1, 2, 3, 4, 5) self.assertRaises(ValueError, LabelToContour(**input_param), error_input) + error_input = np.random.rand(1, 2, 3, 4, 5) + self.assertRaises(ValueError, LabelToContour(**input_param), error_input) if __name__ == "__main__": diff --git a/tests/test_label_to_contourd.py b/tests/test_label_to_contourd.py index d3795755c7..6481e803ba 100644 --- a/tests/test_label_to_contourd.py +++ b/tests/test_label_to_contourd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,108 +15,107 @@ import torch from monai.transforms import LabelToContourd +from tests.utils import TEST_NDARRAYS, assert_allclose -expected_output_for_cube = np.array( +expected_output_for_cube = [ [ - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ], - ] -) - - -def gen_fixed_cube(): + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], +] + + +def gen_fixed_cube(array_type): scale, core_start, core_end = 8, 1, 7 - cube = torch.zeros(scale, scale, scale) + cube = np.zeros((scale, scale, scale)) cube[core_start:core_end, core_start:core_end, core_start:core_end] = torch.ones( core_end - core_start, core_end - core_start, core_end - core_start ) - cube = torch.unsqueeze(cube, 0) + cube = cube[None] batch_size, channels = 10, 6 - cube = cube.repeat(batch_size, channels, 1, 1, 1) - return cube, expected_output_for_cube + cube = np.tile(cube, (batch_size, channels, 1, 1, 1)) + return array_type(cube), array_type(expected_output_for_cube) -def gen_fixed_img(): - img = torch.tensor( +def gen_fixed_img(array_type): + img = np.array( [ [0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1], @@ -124,19 +123,19 @@ def gen_fixed_img(): [0, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], ], - dtype=torch.float32, + dtype=np.float32, ) batch_size, channels = 10, 6 - img = img.repeat(batch_size, channels, 1, 1) - expected_output_for_img = torch.tensor( + img = np.tile(img, (batch_size, channels, 1, 1)) + img = array_type(img) + expected_output_for_img = array_type( [ [0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 0, 0, 1], [0, 0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 0, 0, 1], [1, 1, 1, 1, 1, 1, 1], - ], - dtype=torch.float32, + ] ) return img, expected_output_for_img @@ -145,31 +144,34 @@ class TestContourd(unittest.TestCase): def test_contour(self): input_param = {"keys": "img", "kernel_type": "Laplace"} - # check 5-dim input data - test_cube, expected_output = gen_fixed_cube() - for cube in test_cube: - test_result_cube = LabelToContourd(**input_param)({"img": cube}) - self.assertEqual(test_result_cube["img"].shape, cube.shape) - - test_result_np = test_result_cube["img"].cpu().numpy() - channels = cube.shape[0] - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) - - # check 4-dim input data - test_img, expected_output = gen_fixed_img() - for img in test_img: - channels = img.shape[0] - test_result_img = LabelToContourd(**input_param)({"img": img}) - self.assertEqual(test_result_img["img"].shape, img.shape) - - test_result_np = test_result_img["img"].cpu().numpy() - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) + for p in TEST_NDARRAYS: + # check 5-dim input data + test_cube, expected_output = gen_fixed_cube(p) + for cube in test_cube: + test_result_cube = LabelToContourd(**input_param)({"img": cube}) + self.assertEqual(test_result_cube["img"].shape, cube.shape) + + test_result_np = test_result_cube["img"] + channels = cube.shape[0] + for channel in range(channels): + assert_allclose(test_result_np[channel, ...], expected_output) + + # check 4-dim input data + test_img, expected_output = gen_fixed_img(p) + for img in test_img: + channels = img.shape[0] + test_result_img = LabelToContourd(**input_param)({"img": img}) + self.assertEqual(test_result_img["img"].shape, img.shape) + + test_result_np = test_result_img["img"] + for channel in range(channels): + assert_allclose(test_result_np[channel, ...], expected_output) # check invalid input data error_input = {"img": torch.rand(1, 2)} self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) + error_input = {"img": np.random.rand(1, 2)} + self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) error_input = {"img": torch.rand(1, 2, 3, 4, 5)} self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py index 9caa7252f3..8f81a8da1a 100644 --- a/tests/test_label_to_mask.py +++ b/tests/test_label_to_mask.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -64,7 +64,7 @@ def test_value(self, argments, image, expected_data): self.assertEqual(type(result), type(image)) if isinstance(result, torch.Tensor): self.assertEqual(result.device, image.device) - assert_allclose(result, expected_data) + assert_allclose(result, expected_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py index b8f0d3c171..e67b857502 100644 --- a/tests/test_label_to_maskd.py +++ b/tests/test_label_to_maskd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -65,7 +65,7 @@ def test_value(self, argments, input_data, expected_data): self.assertEqual(type(r), type(i)) if isinstance(r, torch.Tensor): self.assertEqual(r.device, i.device) - assert_allclose(r, expected_data) + assert_allclose(r, expected_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_lambda.py b/tests/test_lambda.py index 738c81130d..c187cc979b 100644 --- a/tests/test_lambda.py +++ b/tests/test_lambda.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_lambdad.py b/tests/test_lambdad.py index 05ba0ff6bc..30d70f40fb 100644 --- a/tests/test_lambdad.py +++ b/tests/test_lambdad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_lesion_froc.py b/tests/test_lesion_froc.py index 2454de88fa..6b4989a9d5 100644 --- a/tests/test_lesion_froc.py +++ b/tests/test_lesion_froc.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,18 +19,19 @@ from monai.apps.pathology.metrics import LesionFROC from monai.utils import optional_import -_, has_cucim = optional_import("cucim") +_cucim, has_cucim = optional_import("cucim") +has_cucim = has_cucim and hasattr(_cucim, "CuImage") _, has_skimage = optional_import("skimage.measure") _, has_sp = optional_import("scipy.ndimage") -PILImage, has_pil = optional_import("PIL.Image") +imwrite, has_tif = optional_import("tifffile", name="imwrite") def save_as_tif(filename, array): array = array[::-1, ...] # Upside-down - img = PILImage.fromarray(array) if not filename.endswith(".tif"): filename += ".tif" - img.save(os.path.join("tests", "testing_data", filename)) + file_path = os.path.join("tests", "testing_data", filename) + imwrite(file_path, array, compress="jpeg", tile=(16, 16)) def around(val, interval=3): @@ -301,7 +302,7 @@ class TestEvaluateTumorFROC(unittest.TestCase): @skipUnless(has_cucim, "Requires cucim") @skipUnless(has_skimage, "Requires skimage") @skipUnless(has_sp, "Requires scipy") - @skipUnless(has_pil, "Requires PIL") + @skipUnless(has_tif, "Requires tifffile") def setUp(self): prepare_test_data() diff --git a/tests/test_list_data_collate.py b/tests/test_list_data_collate.py index eebac69fcf..93b06cc187 100644 --- a/tests/test_list_data_collate.py +++ b/tests/test_list_data_collate.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_list_to_dict.py b/tests/test_list_to_dict.py index 2f026f3e29..ec81310c9f 100644 --- a/tests/test_list_to_dict.py +++ b/tests/test_list_to_dict.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,25 +15,13 @@ from monai.utils import list_to_dict -TEST_CASE_1 = [ - ["a=1", "b=2", "c=3", "d=4"], - {"a": 1, "b": 2, "c": 3, "d": 4}, -] +TEST_CASE_1 = [["a=1", "b=2", "c=3", "d=4"], {"a": 1, "b": 2, "c": 3, "d": 4}] -TEST_CASE_2 = [ - ["a=a", "b=b", "c=c", "d=d"], - {"a": "a", "b": "b", "c": "c", "d": "d"}, -] +TEST_CASE_2 = [["a=a", "b=b", "c=c", "d=d"], {"a": "a", "b": "b", "c": "c", "d": "d"}] -TEST_CASE_3 = [ - ["a=0.1", "b=0.2", "c=0.3", "d=0.4"], - {"a": 0.1, "b": 0.2, "c": 0.3, "d": 0.4}, -] +TEST_CASE_3 = [["a=0.1", "b=0.2", "c=0.3", "d=0.4"], {"a": 0.1, "b": 0.2, "c": 0.3, "d": 0.4}] -TEST_CASE_4 = [ - ["a=True", "b=TRUE", "c=false", "d=FALSE"], - {"a": True, "b": True, "c": False, "d": False}, -] +TEST_CASE_4 = [["a=True", "b=TRUE", "c=false", "d=FALSE"], {"a": True, "b": True, "c": False, "d": False}] TEST_CASE_5 = [ ["a='1'", "b=2 ", " c = 3", "d='test'", "'e'=0", "f", "g=None"], diff --git a/tests/test_lltm.py b/tests/test_lltm.py index f1311379bc..7633c2fe34 100644 --- a/tests/test_lltm.py +++ b/tests/test_lltm.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,9 @@ from parameterized import parameterized from monai.networks.layers import LLTM -from tests.utils import SkipIfNoModule +from tests.utils import SkipIfNoModule, is_tf32_env + +_rtol = 0.001 if is_tf32_env() else 0.0001 TEST_CASE_1 = [ {"input_features": 32, "state_size": 2}, @@ -50,8 +52,8 @@ def test_value_cuda(self, input_param, expected_h, expected_c): new_h, new_c = lltm(x, (h, c)) (new_h.sum() + new_c.sum()).backward() - torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=0.0001, atol=1e-04) - torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=0.0001, atol=1e-04) + torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=_rtol, atol=0.001) + torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=_rtol, atol=0.001) if __name__ == "__main__": diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index fbdb651297..b624e5c4e3 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_lmdbdataset_dist.py b/tests/test_lmdbdataset_dist.py new file mode 100644 index 0000000000..cad2949dde --- /dev/null +++ b/tests/test_lmdbdataset_dist.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +import unittest + +import numpy as np + +from monai.data import LMDBDataset, json_hashing +from monai.transforms import Transform +from tests.utils import DistCall, DistTestCase, skip_if_windows + + +class _InplaceXform(Transform): + def __call__(self, data): + if data: + data[0] = data[0] + np.pi + else: + data.append(1) + return data + + +@skip_if_windows +class TestMPLMDBDataset(DistTestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + @DistCall(nnodes=1, nproc_per_node=1) + def test_mp_cache(self): + items = [[list(range(i))] for i in range(5)] + + ds = LMDBDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir, lmdb_kwargs={"map_size": 10 * 1024}) + self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) + ds1 = LMDBDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir, lmdb_kwargs={"map_size": 10 * 1024}) + self.assertEqual(list(ds1), list(ds)) + self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) + + ds = LMDBDataset( + items, + transform=_InplaceXform(), + cache_dir=self.tempdir, + lmdb_kwargs={"map_size": 10 * 1024}, + hash_func=json_hashing, + ) + self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) + ds1 = LMDBDataset( + items, + transform=_InplaceXform(), + cache_dir=self.tempdir, + lmdb_kwargs={"map_size": 10 * 1024}, + hash_func=json_hashing, + ) + self.assertEqual(list(ds1), list(ds)) + self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) + + self.assertTrue(isinstance(ds1.info(), dict)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_load_decathlon_datalist.py b/tests/test_load_decathlon_datalist.py index fe7ff6f8a2..91d144d84f 100644 --- a/tests/test_load_decathlon_datalist.py +++ b/tests/test_load_decathlon_datalist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,6 +13,7 @@ import os import tempfile import unittest +from pathlib import Path from monai.data import load_decathlon_datalist @@ -115,7 +116,7 @@ def test_additional_items(self): file_path = os.path.join(tempdir, "test_data.json") with open(file_path, "w") as json_file: json_file.write(json_str) - result = load_decathlon_datalist(file_path, True, "training", tempdir) + result = load_decathlon_datalist(file_path, True, "training", Path(tempdir)) self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) self.assertEqual(result[1]["mask"], os.path.join(tempdir, "mask31.txt")) diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 2aa6eced65..f5121c0fd7 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -68,11 +68,7 @@ def get_data(self, _obj): (3, 128, 128, 128), ] -TEST_CASE_5 = [ - {"reader": NibabelReader(mmap=False), "image_only": False}, - ["test_image.nii.gz"], - (128, 128, 128), -] +TEST_CASE_5 = [{"reader": NibabelReader(mmap=False), "image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] TEST_CASE_6 = [{"reader": ITKReader(), "image_only": True}, ["test_image.nii.gz"], (128, 128, 128)] @@ -94,12 +90,21 @@ def get_data(self, _obj): {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", (16, 16, 4), + (16, 16, 4), ] TEST_CASE_11 = [ {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC}, "tests/testing_data/CT_DICOM", (16, 16, 4), + (16, 16, 4), +] + +TEST_CASE_12 = [ + {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC, "reverse_indexing": True}, + "tests/testing_data/CT_DICOM", + (16, 16, 4), + (4, 16, 16), ] @@ -142,11 +147,11 @@ def test_itk_reader(self, input_param, filenames, expected_shape): np.testing.assert_allclose(header["original_affine"], np_diag) self.assertTupleEqual(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_10, TEST_CASE_11]) - def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape): + @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12]) + def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, expected_np_shape): result, header = LoadImage(**input_param)(filenames) self.assertTrue("affine" in header) - self.assertEqual(header["filename_or_obj"], filenames) + self.assertEqual(header["filename_or_obj"], f"{Path(filenames)}") np.testing.assert_allclose( header["affine"], np.array( @@ -158,8 +163,8 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape): ] ), ) - self.assertTupleEqual(result.shape, expected_shape) self.assertTupleEqual(tuple(header["spatial_shape"]), expected_shape) + self.assertTupleEqual(result.shape, expected_np_shape) def test_itk_reader_multichannel(self): test_image = np.random.randint(0, 256, size=(256, 224, 3)).astype("uint8") @@ -167,12 +172,31 @@ def test_itk_reader_multichannel(self): filename = os.path.join(tempdir, "test_image.png") itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) - result, header = LoadImage(reader=ITKReader())(Path(filename)) + for flag in (False, True): + result, header = LoadImage(reader=ITKReader(reverse_indexing=flag))(Path(filename)) + + self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256)) + test_image = test_image.transpose(1, 0, 2) + np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0]) + np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1]) + np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2]) + + def test_load_nifti_multichannel(self): + test_image = np.random.randint(0, 256, size=(31, 64, 16, 2)).astype(np.float32) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.nii.gz") + itk_np_view = itk.image_view_from_array(test_image, is_vector=True) + itk.imwrite(itk_np_view, filename) + + itk_img, itk_header = LoadImage(reader=ITKReader())(Path(filename)) + self.assertTupleEqual(tuple(itk_header["spatial_shape"]), (16, 64, 31)) + self.assertTupleEqual(tuple(itk_img.shape), (16, 64, 31, 2)) + + nib_image, nib_header = LoadImage(reader=NibabelReader(squeeze_non_spatial_dims=True))(Path(filename)) + self.assertTupleEqual(tuple(nib_header["spatial_shape"]), (16, 64, 31)) + self.assertTupleEqual(tuple(nib_image.shape), (16, 64, 31, 2)) - self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256)) - np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0].T) - np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1].T) - np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2].T) + np.testing.assert_allclose(itk_img, nib_image, atol=1e-3, rtol=1e-3) def test_load_png(self): spatial_size = (256, 224) diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index ca5b56a7d9..39885a2cae 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -81,12 +81,7 @@ class TestConsistency(unittest.TestCase): def _cmp(self, filename, shape, ch_shape, reader_1, reader_2, outname, ext): data_dict = {"img": filename} keys = data_dict.keys() - xforms = Compose( - [ - LoadImaged(keys, reader=reader_1), - EnsureChannelFirstD(keys), - ] - ) + xforms = Compose([LoadImaged(keys, reader=reader_1), EnsureChannelFirstD(keys)]) img_dict = xforms(data_dict) # load dicom with itk self.assertTupleEqual(img_dict["img"].shape, ch_shape) self.assertTupleEqual(tuple(img_dict["img_meta_dict"]["spatial_shape"]), shape) @@ -97,12 +92,7 @@ def _cmp(self, filename, shape, ch_shape, reader_1, reader_2, outname, ext): ) save_xform(img_dict) # save to nifti - new_xforms = Compose( - [ - LoadImaged(keys, reader=reader_2), - EnsureChannelFirstD(keys), - ] - ) + new_xforms = Compose([LoadImaged(keys, reader=reader_2), EnsureChannelFirstD(keys)]) out = new_xforms({"img": os.path.join(tempdir, outname)}) # load nifti with itk self.assertTupleEqual(out["img"].shape, ch_shape) self.assertTupleEqual(tuple(out["img_meta_dict"]["spatial_shape"]), shape) diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 48aac7ec56..7690adf284 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_loader_semaphore.py b/tests/test_loader_semaphore.py index 85c6d54f35..bbb2d4eef6 100644 --- a/tests/test_loader_semaphore.py +++ b/tests/test_loader_semaphore.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index 31954e727b..8070c27f90 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_localnet.py b/tests/test_localnet.py index dc680f15f9..1a288fb447 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index f4e857a0fa..d85509344e 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_look_up_option.py b/tests/test_look_up_option.py index 60786f2fc5..89fec1b575 100644 --- a/tests/test_look_up_option.py +++ b/tests/test_look_up_option.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index 5b730c2a77..78c94d4e41 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. import os +import pickle import random import sys import unittest @@ -78,7 +79,14 @@ def test_lr_finder(self): learning_rate = 1e-5 optimizer = torch.optim.Adam(model.parameters(), learning_rate) - lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device) + lr_finder = LearningRateFinder( + model=model, + optimizer=optimizer, + criterion=loss_function, + device=device, + pickle_module=pickle, + pickle_protocol=4, + ) lr_finder.range_test(train_loader, val_loader=train_loader, end_lr=10, num_iter=5) print(lr_finder.get_steepest_gradient(0, 0)[0]) diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index aa126f7848..a3e1ea9dd6 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,7 +19,7 @@ class SchedulerTestNet(torch.nn.Module): def __init__(self): - super(SchedulerTestNet, self).__init__() + super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) self.conv2 = torch.nn.Conv2d(1, 1, 1) @@ -28,13 +28,7 @@ def forward(self, x): TEST_CASE_LRSCHEDULER = [ - [ - { - "warmup_steps": 2, - "t_total": 10, - }, - [0.000, 0.500, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038], - ] + [{"warmup_steps": 2, "t_total": 10}, [0.000, 0.500, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038]] ] @@ -47,11 +41,11 @@ def test_shape(self, input_param, expected_lr): self.assertEqual(len([scheduler.get_last_lr()[0]]), 1) lrs_1 = [] for _ in range(input_param["t_total"]): - lrs_1.append(float("{:.3f}".format(scheduler.get_last_lr()[0]))) + lrs_1.append(float(f"{scheduler.get_last_lr()[0]:.3f}")) optimizer.step() scheduler.step() for a, b in zip(lrs_1, expected_lr): - self.assertEqual(a, b, msg="LR is wrong ! expected {}, got {}".format(b, a)) + self.assertEqual(a, b, msg=f"LR is wrong ! expected {b}, got {a}") if __name__ == "__main__": diff --git a/tests/test_map_binary_to_indices.py b/tests/test_map_binary_to_indices.py index 1fafa6f446..bc96231160 100644 --- a/tests/test_map_binary_to_indices.py +++ b/tests/test_map_binary_to_indices.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,50 +15,58 @@ from parameterized import parameterized from monai.transforms import map_binary_to_indices +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": None, "image_threshold": 0.0}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] - -TEST_CASE_2 = [ - { - "label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - "image": np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]), - "image_threshold": 0.0, - }, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_3 = [ - { - "label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - "image_threshold": 1.0, - }, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_4 = [ - { - "label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - "image_threshold": 1.0, - }, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), "image": None, "image_threshold": 0.0}, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 4, 8]), + ] + ) + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + "image": p(np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])), + "image_threshold": 0.0, + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + "image": p(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + "image_threshold": 1.0, + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + "image": p(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + "image_threshold": 1.0, + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) class TestMapBinaryToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected_fg, expected_bg): fg_indices, bg_indices = map_binary_to_indices(**input_data) - np.testing.assert_allclose(fg_indices, expected_fg) - np.testing.assert_allclose(bg_indices, expected_bg) + assert_allclose(fg_indices, expected_fg, type_test=False) + assert_allclose(bg_indices, expected_bg, type_test=False) if __name__ == "__main__": diff --git a/tests/test_map_classes_to_indices.py b/tests/test_map_classes_to_indices.py index 2320954520..2f32382f6b 100644 --- a/tests/test_map_classes_to_indices.py +++ b/tests/test_map_classes_to_indices.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,86 +15,117 @@ from parameterized import parameterized from monai.transforms import map_classes_to_indices +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - # test Argmax data - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 3, "image": None, "image_threshold": 0.0}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # test Argmax data + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 3, + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], + ] + ) -TEST_CASE_2 = [ - { - "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - "num_classes": 3, - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - "image_threshold": 60, - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 3, + "image": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + "image_threshold": 60, + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], + ] + ) -TEST_CASE_3 = [ - # test One-Hot data - { - "label": np.array( - [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "image": None, - "image_threshold": 0.0, - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + TESTS.append( + [ + # test One-Hot data + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], + ] + ) -TEST_CASE_4 = [ - { - "label": np.array( - [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "num_classes": None, - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - "image_threshold": 60, - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS.append( + [ + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "num_classes": None, + "image": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + "image_threshold": 60, + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], + ] + ) -TEST_CASE_5 = [ - # test empty class - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 5, "image": None, "image_threshold": 0.0}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], -] + TESTS.append( + [ + # test empty class + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 5, + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], + ] + ) -TEST_CASE_6 = [ - # test empty class - { - "label": np.array( - [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - ] - ), - "image": None, - "image_threshold": 0.0, - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], -] + TESTS.append( + [ + # test empty class + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + ] + ) + ), + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], + ] + ) class TestMapClassesToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_data, expected_indices): indices = map_classes_to_indices(**input_data) for i, e in zip(indices, expected_indices): - np.testing.assert_allclose(i, e) + assert_allclose(i, e, type_test=False) if __name__ == "__main__": diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py index ff1d7d1eef..0416858a74 100644 --- a/tests/test_map_label_value.py +++ b/tests/test_map_label_value.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,75 +12,67 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import MapLabelValue +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, - np.array([[3, 1], [1, 2]]), - np.array([[0, 2], [2, 1]]), -] - -TEST_CASE_2 = [ - {"orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, - np.array([[[3], [5], [5], [8]]]), - np.array([[[0], [1], [1], [2]]]), -] - -TEST_CASE_3 = [ - {"orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, - np.array([3, 1, 1, 2]), - np.array([2, 0, 0, 1]), -] - -TEST_CASE_4 = [ - {"orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]}, - np.array([3, 1, 1, 2]), - np.array([2.5, 0.5, 0.5, 1.5]), -] - -TEST_CASE_5 = [ - {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, - np.array([3.5, 1.5, 1.5, 2.5]), - np.array([2, 0, 0, 1]), -] - -TEST_CASE_6 = [ - {"orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, - np.array([["label3", "label1"], ["label1", "label2"]]), - np.array([[0, 2], [2, 1]]), -] - -TEST_CASE_7 = [ - {"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"}, - np.array([[3.5, 1.5], [1.5, 2.5]]), - np.array([["label0", "label2"], ["label2", "label1"]]), -] - -TEST_CASE_8 = [ - {"orig_labels": ["label3", "label2", "label1"], "target_labels": ["label1", "label2", "label3"], "dtype": "str"}, - np.array([["label3", "label1"], ["label1", "label2"]]), - np.array([["label1", "label3"], ["label3", "label2"]]), -] - - -class TestMapLabelValue(unittest.TestCase): - @parameterized.expand( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, + [{"orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, p([[3, 1], [1, 2]]), p([[0.0, 2.0], [2.0, 1.0]])], + [ + {"orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, + p([[[3], [5], [5], [8]]]), + p([[[0.0], [1.0], [1.0], [2.0]]]), + ], + [{"orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, p([3, 1, 1, 2]), p([2.0, 0.0, 0.0, 1.0])], + [{"orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]}, p([3, 1, 1, 2]), p([2.5, 0.5, 0.5, 1.5])], ] ) + # note: PyTorch 1.5.1 doesn't support rich dtypes + TESTS.append( + [ + {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, + p([3.5, 1.5, 1.5, 2.5]), + p([2, 0, 0, 1]), + ] + ) +TESTS.extend( + [ + [ + {"orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, + np.array([["label3", "label1"], ["label1", "label2"]]), + np.array([[0, 2], [2, 1]]), + ], + [ + {"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"}, + np.array([[3.5, 1.5], [1.5, 2.5]]), + np.array([["label0", "label2"], ["label2", "label1"]]), + ], + [ + { + "orig_labels": ["label3", "label2", "label1"], + "target_labels": ["label1", "label2", "label3"], + "dtype": "str", + }, + np.array([["label3", "label1"], ["label1", "label2"]]), + np.array([["label1", "label3"], ["label3", "label2"]]), + ], + ] +) + + +class TestMapLabelValue(unittest.TestCase): + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_value): result = MapLabelValue(**input_param)(input_data) - np.testing.assert_equal(result, expected_value) + if isinstance(expected_value, torch.Tensor): + torch.testing.assert_allclose(result, expected_value) + else: + np.testing.assert_equal(result, expected_value) self.assertTupleEqual(result.shape, expected_value.shape) diff --git a/tests/test_map_label_valued.py b/tests/test_map_label_valued.py index 426ac28836..cf8ca6c8e2 100644 --- a/tests/test_map_label_valued.py +++ b/tests/test_map_label_valued.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_map_transform.py b/tests/test_map_transform.py index 803e699a7d..dd77ccb099 100644 --- a/tests/test_map_transform.py +++ b/tests/test_map_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index a3662eec49..b6cfe0e10c 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import MaskIntensity @@ -43,9 +44,15 @@ np.array([[[0, 0, 0], [2, 2, 2], [0, 0, 0]], [[0, 0, 0], [5, 5, 5], [0, 0, 0]]]), ] +TEST_CASE_5 = [ + {"mask_data": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])}, + torch.as_tensor([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + torch.as_tensor([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), +] + class TestMaskIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_value(self, argments, image, expected_data): result = MaskIntensity(**argments)(image) np.testing.assert_allclose(result, expected_data) diff --git a/tests/test_mask_intensityd.py b/tests/test_mask_intensityd.py index c21e26eba6..fe61e7be04 100644 --- a/tests/test_mask_intensityd.py +++ b/tests/test_mask_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_masked_dice_loss.py b/tests/test_masked_dice_loss.py index b8d69bc8f9..317da3a316 100644 --- a/tests/test_masked_dice_loss.py +++ b/tests/test_masked_dice_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -67,7 +67,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - [[0.296529, 0.415136], [0.599976, 0.428559]], + [[[0.296529], [0.415136]], [[0.599976], [0.428559]]], ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -94,26 +94,17 @@ ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "squared_pred": True, "smooth_nr": 1e-5, "smooth_dr": 1e-5}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.178337, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "jaccard": True, "smooth_nr": 1e-5, "smooth_dr": 1e-5}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.470451, ], ] diff --git a/tests/test_masked_inference_wsi_dataset.py b/tests/test_masked_inference_wsi_dataset.py index 361c17e106..6cec5b4304 100644 --- a/tests/test_masked_inference_wsi_dataset.py +++ b/tests/test_masked_inference_wsi_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,11 +22,11 @@ from monai.utils import optional_import from tests.utils import skip_if_quick -_, has_cim = optional_import("cucim") +_, has_cim = optional_import("cucim", name="CuImage") _, has_osl = optional_import("openslide") -FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" -base_name, extension = os.path.splitext(os.path.basename(FILE_URL)) +FILE_URL = "https://drive.google.com/uc?id=1sGTKZlJBIz53pfqTxoTqiIQzIoEzHLAe" +base_name, extension = FILE_URL.split("id=")[1], ".tiff" FILE_NAME = "temp_" + base_name FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", FILE_NAME + extension) @@ -50,28 +50,12 @@ def prepare_data(): TEST_CASE_0 = [ - { - "data": [ - {"image": FILE_PATH, "mask": MASK1}, - ], - "patch_size": 1, - "image_reader_name": "cuCIM", - }, - [ - { - "image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), - "name": FILE_NAME, - "mask_location": [100, 100], - }, - ], + {"data": [{"image": FILE_PATH, "mask": MASK1}], "patch_size": 1, "image_reader_name": "cuCIM"}, + [{"image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), "name": FILE_NAME, "mask_location": [100, 100]}], ] TEST_CASE_1 = [ - { - "data": [{"image": FILE_PATH, "mask": MASK2}], - "patch_size": 1, - "image_reader_name": "cuCIM", - }, + {"data": [{"image": FILE_PATH, "mask": MASK2}], "patch_size": 1, "image_reader_name": "cuCIM"}, [ { "image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), @@ -87,11 +71,7 @@ def prepare_data(): ] TEST_CASE_2 = [ - { - "data": [{"image": FILE_PATH, "mask": MASK4}], - "patch_size": 1, - "image_reader_name": "cuCIM", - }, + {"data": [{"image": FILE_PATH, "mask": MASK4}], "patch_size": 1, "image_reader_name": "cuCIM"}, [ { "image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), @@ -117,35 +97,21 @@ def prepare_data(): ] TEST_CASE_3 = [ - { - "data": [ - {"image": FILE_PATH, "mask": MASK1}, - ], - "patch_size": 2, - "image_reader_name": "cuCIM", - }, + {"data": [{"image": FILE_PATH, "mask": MASK1}], "patch_size": 2, "image_reader_name": "cuCIM"}, [ { "image": np.array( - [ - [[243, 243], [243, 243]], - [[243, 243], [243, 243]], - [[243, 243], [243, 243]], - ], - dtype=np.uint8, + [[[243, 243], [243, 243]], [[243, 243], [243, 243]], [[243, 243], [243, 243]]], dtype=np.uint8 ), "name": FILE_NAME, "mask_location": [100, 100], - }, + } ], ] TEST_CASE_4 = [ { - "data": [ - {"image": FILE_PATH, "mask": MASK1}, - {"image": FILE_PATH, "mask": MASK2}, - ], + "data": [{"image": FILE_PATH, "mask": MASK1}, {"image": FILE_PATH, "mask": MASK2}], "patch_size": 1, "image_reader_name": "cuCIM", }, @@ -170,28 +136,12 @@ def prepare_data(): TEST_CASE_OPENSLIDE_0 = [ - { - "data": [ - {"image": FILE_PATH, "mask": MASK1}, - ], - "patch_size": 1, - "image_reader_name": "OpenSlide", - }, - [ - { - "image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), - "name": FILE_NAME, - "mask_location": [100, 100], - }, - ], + {"data": [{"image": FILE_PATH, "mask": MASK1}], "patch_size": 1, "image_reader_name": "OpenSlide"}, + [{"image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), "name": FILE_NAME, "mask_location": [100, 100]}], ] TEST_CASE_OPENSLIDE_1 = [ - { - "data": [{"image": FILE_PATH, "mask": MASK2}], - "patch_size": 1, - "image_reader_name": "OpenSlide", - }, + {"data": [{"image": FILE_PATH, "mask": MASK2}], "patch_size": 1, "image_reader_name": "OpenSlide"}, [ { "image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), @@ -212,27 +162,14 @@ def setUp(self): prepare_data() download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) @skipUnless(has_cim, "Requires CuCIM") @skip_if_quick def test_read_patches_cucim(self, input_parameters, expected): dataset = MaskedInferenceWSIDataset(**input_parameters) self.compare_samples_expected(dataset, expected) - @parameterized.expand( - [ - TEST_CASE_OPENSLIDE_0, - TEST_CASE_OPENSLIDE_1, - ] - ) + @parameterized.expand([TEST_CASE_OPENSLIDE_0, TEST_CASE_OPENSLIDE_1]) @skipUnless(has_osl, "Requires OpenSlide") @skip_if_quick def test_read_patches_openslide(self, input_parameters, expected): diff --git a/tests/test_masked_loss.py b/tests/test_masked_loss.py index 225e3d9668..9f28d51aa4 100644 --- a/tests/test_masked_loss.py +++ b/tests/test_masked_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,7 +33,7 @@ "reduction": "sum", }, [(14.538666, 20.191753), (13.17672, 8.251623)], - ], + ] ] diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py new file mode 100644 index 0000000000..83984d1556 --- /dev/null +++ b/tests/test_matshow3d.py @@ -0,0 +1,109 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np + +from monai.transforms import AddChanneld, Compose, LoadImaged, RandSpatialCropSamplesd, RepeatChanneld, ScaleIntensityd +from monai.utils import optional_import +from monai.visualize.utils import matshow3d +from tests.utils import SkipIfNoModule + +compare_images, _ = optional_import("matplotlib.testing.compare", name="compare_images") +pyplot, has_pyplot = optional_import("matplotlib", name="pyplot") + + +@SkipIfNoModule("matplotlib") +class TestMatshow3d(unittest.TestCase): + def test_3d(self): + testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + keys = "image" + xforms = Compose([LoadImaged(keys=keys), AddChanneld(keys=keys), ScaleIntensityd(keys=keys)]) + image_path = os.path.join(testing_dir, "anatomical.nii") + ims = xforms({keys: image_path}) + + fig = pyplot.figure() # external figure + fig, _ = matshow3d(ims[keys], fig=fig, figsize=(2, 2), frames_per_row=5, every_n=2, frame_dim=-1, show=False) + + with tempfile.TemporaryDirectory() as tempdir: + tempimg = f"{tempdir}/matshow3d_test.png" + fig.savefig(tempimg) + comp = compare_images(f"{testing_dir}/matshow3d_test.png", tempimg, 5e-2) + self.assertIsNone(comp, f"value of comp={comp}") # None indicates test passed + + def test_samples(self): + testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + keys = "image" + xforms = Compose( + [ + LoadImaged(keys=keys), + AddChanneld(keys=keys), + ScaleIntensityd(keys=keys), + RandSpatialCropSamplesd(keys=keys, roi_size=(8, 8, 5), random_size=True, num_samples=10), + ] + ) + image_path = os.path.join(testing_dir, "anatomical.nii") + xforms.set_random_state(0) + ims = xforms({keys: image_path}) + fig, mat = matshow3d( + [im[keys] for im in ims], title=f"testing {keys}", figsize=(2, 2), frames_per_row=5, every_n=2, show=False + ) + self.assertTrue(mat.dtype == np.float32) + + with tempfile.TemporaryDirectory() as tempdir: + tempimg = f"{tempdir}/matshow3d_patch_test.png" + fig.savefig(tempimg) + comp = compare_images(f"{testing_dir}/matshow3d_patch_test.png", tempimg, 5e-2, in_decorator=True) + if comp: + print("not none comp: ", comp) # matplotlib 3.2.2 + np.testing.assert_allclose(comp["rms"], 30.786983, atol=1e-3, rtol=1e-3) + else: + self.assertIsNone(comp, f"value of comp={comp}") # None indicates test passed + + def test_3d_rgb(self): + testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + keys = "image" + xforms = Compose( + [ + LoadImaged(keys=keys), + AddChanneld(keys=keys), + ScaleIntensityd(keys=keys), + # change to RGB color image + RepeatChanneld(keys=keys, repeats=3), + ] + ) + image_path = os.path.join(testing_dir, "anatomical.nii") + ims = xforms({keys: image_path}) + + fig = pyplot.figure() # external figure + fig, _ = matshow3d( + volume=ims[keys], + fig=fig, + figsize=(2, 2), + frames_per_row=5, + every_n=2, + frame_dim=-1, + channel_dim=0, + show=False, + ) + + with tempfile.TemporaryDirectory() as tempdir: + tempimg = f"{tempdir}/matshow3d_rgb_test.png" + fig.savefig(tempimg) + comp = compare_images(f"{testing_dir}/matshow3d_rgb_test.png", tempimg, 5e-2) + self.assertIsNone(comp, f"value of comp={comp}") # None indicates test passed + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py index 7e08846beb..b14f6f01d3 100644 --- a/tests/test_mean_ensemble.py +++ b/tests/test_mean_ensemble.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,49 +16,50 @@ from parameterized import parameterized from monai.transforms import MeanEnsemble +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"weights": None}, - [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], - torch.ones(2, 2, 2) + 1, -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([{"weights": None}, [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2], p(torch.ones(2, 2, 2)) + 1]) -TEST_CASE_2 = [ - {"weights": None}, - torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2]), - torch.ones(2, 2, 2) + 1, -] + TESTS.append( + [{"weights": None}, p(torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2])), p(torch.ones(2, 2, 2)) + 1] + ) -TEST_CASE_3 = [ - {"weights": [1, 3]}, - [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], - torch.ones(2, 2, 2) * 2.5, -] + TESTS.append( + [{"weights": [1, 3]}, [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2], p(torch.ones(2, 2, 2)) * 2.5] + ) -TEST_CASE_4 = [ - {"weights": [[1, 3], [3, 1]]}, - [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], - torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), -] + TESTS.append( + [ + {"weights": [[1, 3], [3, 1]]}, + [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2], + p(torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1)), + ] + ) -TEST_CASE_5 = [ - {"weights": np.array([[1, 3], [3, 1]])}, - [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], - torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), -] + TESTS.append( + [ + {"weights": np.array([[1, 3], [3, 1]])}, + [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2], + p(torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1)), + ] + ) -TEST_CASE_6 = [ - {"weights": torch.tensor([[[1, 3]], [[3, 1]]])}, - [torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2], - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), -] + TESTS.append( + [ + {"weights": torch.tensor([[[1, 3]], [[3, 1]]])}, + [p(torch.ones(2, 2, 2, 2)), p(torch.ones(2, 2, 2, 2)) + 2], + p(torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)), + ] + ) class TestMeanEnsemble(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_param, img, expected_value): result = MeanEnsemble(**input_param)(img) - torch.testing.assert_allclose(result, expected_value) + assert_allclose(result, expected_value) def test_cuda_value(self): img = torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2]) diff --git a/tests/test_mean_ensembled.py b/tests/test_mean_ensembled.py index ea77ef18a0..b5e1569d65 100644 --- a/tests/test_mean_ensembled.py +++ b/tests/test_mean_ensembled.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,46 +16,61 @@ from parameterized import parameterized from monai.transforms import MeanEnsembled +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": None}, - {"pred0": torch.ones(2, 2, 2), "pred1": torch.ones(2, 2, 2) + 2}, - torch.ones(2, 2, 2) + 1, -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": None}, + {"pred0": p(torch.ones(2, 2, 2)), "pred1": p(torch.ones(2, 2, 2)) + 2}, + p(torch.ones(2, 2, 2)) + 1, + ] + ) -TEST_CASE_2 = [ - {"keys": "output", "weights": None}, - {"output": torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2])}, - torch.ones(2, 2, 2) + 1, -] + TESTS.append( + [ + {"keys": "output", "weights": None}, + {"output": p(torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2]))}, + p(torch.ones(2, 2, 2)) + 1, + ] + ) -TEST_CASE_3 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [1, 3]}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * 2.5, -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [1, 3]}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2)) + 2}, + p(torch.ones(2, 2, 2, 2)) * 2.5, + ] + ) -TEST_CASE_4 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [[1, 3], [3, 1]]}, - {"pred0": torch.ones(2, 2, 2), "pred1": torch.ones(2, 2, 2) + 2}, - torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [[1, 3], [3, 1]]}, + {"pred0": p(torch.ones(2, 2, 2)), "pred1": p(torch.ones(2, 2, 2)) + 2}, + p(torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1)), + ] + ) -TEST_CASE_5 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": np.array([[[1, 3]], [[3, 1]]])}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": np.array([[[1, 3]], [[3, 1]]])}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2)) + 2}, + p(torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)), + ] + ) -TEST_CASE_6 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": torch.tensor([[[1, 3]], [[3, 1]]])}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": torch.tensor([[[1, 3]], [[3, 1]]])}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2)) + 2}, + p(torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)), + ] + ) class TestMeanEnsembled(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_param, data, expected_value): result = MeanEnsembled(**input_param)(data) torch.testing.assert_allclose(result["output"], expected_value) @@ -67,7 +82,7 @@ def test_cuda_value(self): img = img.to(torch.device("cuda:0")) expected_value = expected_value.to(torch.device("cuda:0")) result = MeanEnsembled(keys="output", weights=torch.tensor([[[1, 3]], [[3, 1]]]))({"output": img}) - torch.testing.assert_allclose(result["output"], expected_value) + assert_allclose(result["output"], expected_value) if __name__ == "__main__": diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 2e27f4ba95..54fb11135a 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import os import shutil import unittest +from pathlib import Path from urllib.error import ContentTooShortError, HTTPError from monai.apps import MedNISTDataset @@ -42,7 +43,9 @@ def _test_dataset(dataset): self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64)) try: # will start downloading if testing_dir doesn't have the MedNIST files - data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=True) + data = MedNISTDataset( + root_dir=testing_dir, transform=transform, section="test", download=True, copy_cache=False + ) except (ContentTooShortError, HTTPError, RuntimeError) as e: print(str(e)) if isinstance(e, RuntimeError): @@ -53,17 +56,19 @@ def _test_dataset(dataset): _test_dataset(data) # testing from - data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) - data.get_num_classes() + data = MedNISTDataset(root_dir=Path(testing_dir), transform=transform, section="test", download=False) + self.assertEqual(data.get_num_classes(), 6) _test_dataset(data) data = MedNISTDataset(root_dir=testing_dir, section="test", download=False) self.assertTupleEqual(data[0]["image"].shape, (64, 64)) # test same dataset length with different random seed data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False, seed=42) _test_dataset(data) + self.assertEqual(data[0]["class_name"], "AbdomenCT") + self.assertEqual(data[0]["label"].cpu().item(), 0) shutil.rmtree(os.path.join(testing_dir, "MedNIST")) try: - data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) + MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) except RuntimeError as e: print(str(e)) self.assertTrue(str(e).startswith("Cannot find dataset directory")) diff --git a/tests/test_milmodel.py b/tests/test_milmodel.py new file mode 100644 index 0000000000..ad04e96c60 --- /dev/null +++ b/tests/test_milmodel.py @@ -0,0 +1,91 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MILModel +from monai.utils.module import optional_import +from tests.utils import test_script_save + +models, _ = optional_import("torchvision.models") + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +TEST_CASE_MILMODEL = [] +for num_classes in [1, 5]: + for mil_mode in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]: + test_case = [ + {"num_classes": num_classes, "mil_mode": mil_mode, "pretrained": False}, + (1, 2, 3, 512, 512), + (1, num_classes), + ] + TEST_CASE_MILMODEL.append(test_case) + + +for trans_blocks in [1, 3]: + test_case = [ + {"num_classes": 5, "pretrained": False, "trans_blocks": trans_blocks, "trans_dropout": 0.5}, + (1, 2, 3, 512, 512), + (1, 5), + ] + TEST_CASE_MILMODEL.append(test_case) + +# torchvision backbone +TEST_CASE_MILMODEL.append( + [{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)] +) +TEST_CASE_MILMODEL.append([{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)]) + +# custom backbone +backbone = models.densenet121(pretrained=False) +backbone_nfeatures = backbone.classifier.in_features +backbone.classifier = torch.nn.Identity() +TEST_CASE_MILMODEL.append( + [ + {"num_classes": 5, "backbone": backbone, "backbone_num_features": backbone_nfeatures, "pretrained": False}, + (2, 2, 3, 512, 512), + (2, 5), + ] +) + + +class TestMilModel(unittest.TestCase): + @parameterized.expand(TEST_CASE_MILMODEL) + def test_shape(self, input_param, input_shape, expected_shape): + net = MILModel(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape, dtype=torch.float).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_args(self): + with self.assertRaises(ValueError): + MILModel( + num_classes=5, + pretrained=False, + backbone="resnet50", + backbone_num_features=2048, + mil_mode="att_trans_pyramid", + ) + + def test_script(self): + input_param, input_shape, expected_shape = TEST_CASE_MILMODEL[0] + net = MILModel(**input_param) + test_data = torch.randn(input_shape, dtype=torch.float) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 7a93f81ec3..6fec5b6854 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,11 +24,7 @@ for mlp_dim in [512, 1028, 2048, 3072]: test_case = [ - { - "hidden_size": hidden_size, - "mlp_dim": mlp_dim, - "dropout_rate": dropout_rate, - }, + {"hidden_size": hidden_size, "mlp_dim": mlp_dim, "dropout_rate": dropout_rate}, (2, 512, hidden_size), (2, 512, hidden_size), ] diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 6952e62c3c..2cae5969db 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path from urllib.error import ContentTooShortError, HTTPError import numpy as np @@ -21,7 +22,7 @@ from monai.apps import RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from monai.apps.mmars import MODEL_DESC from monai.apps.mmars.mmars import _get_val -from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion, skip_if_quick +from tests.utils import skip_if_quick TEST_CASES = [["clara_pt_prostate_mri_segmentation_1"], ["clara_pt_covid19_ct_lesion_segmentation_1"]] TEST_EXTRACT_CASES = [ @@ -103,17 +104,17 @@ class TestMMMARDownload(unittest.TestCase): @parameterized.expand(TEST_CASES) @skip_if_quick - @SkipIfBeforePyTorchVersion((1, 6)) def test_download(self, idx): try: # test model specification cand = get_model_spec(idx) self.assertEqual(cand[RemoteMMARKeys.ID], idx) - download_mmar(idx) + with self.assertLogs(level="INFO", logger="monai.apps"): + download_mmar(idx) download_mmar(idx, progress=False) # repeated to check caching with tempfile.TemporaryDirectory() as tmp_dir: download_mmar(idx, mmar_dir=tmp_dir, progress=False) - download_mmar(idx, mmar_dir=tmp_dir, progress=False, version=1) # repeated to check caching + download_mmar(idx, mmar_dir=Path(tmp_dir), progress=False, version=1) # repeated to check caching self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx))) except (ContentTooShortError, HTTPError, RuntimeError) as e: print(str(e)) @@ -123,7 +124,6 @@ def test_download(self, idx): @parameterized.expand(TEST_EXTRACT_CASES) @skip_if_quick - @SkipIfBeforePyTorchVersion((1, 6)) def test_load_ckpt(self, input_args, expected_name, expected_val): try: output = load_from_mmar(**input_args) @@ -138,14 +138,9 @@ def test_load_ckpt(self, input_args, expected_name, expected_val): def test_unique(self): # model ids are unique - keys = sorted([m["id"] for m in MODEL_DESC]) + keys = sorted(m["id"] for m in MODEL_DESC) self.assertTrue(keys == sorted(set(keys))) - @SkipIfAtLeastPyTorchVersion((1, 6)) - def test_no_default(self): - with self.assertRaises(ValueError): - download_mmar(0) - def test_search(self): self.assertEqual(_get_val({"a": 1, "b": 2}, key="b"), 2) self.assertEqual(_get_val({"a": {"c": {"c": 4}}, "b": {"c": 2}}, key="b"), {"c": 2}) diff --git a/tests/test_module_list.py b/tests/test_module_list.py index 3aefaf5e0c..5ec4aa9ff1 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index 01a760db72..963824f25e 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_net_adapter.py b/tests/test_net_adapter.py index b2d55129a7..0d73499a6d 100644 --- a/tests/test_net_adapter.py +++ b/tests/test_net_adapter.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,23 +19,11 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASE_0 = [ - {"num_classes": 1, "use_conv": True, "dim": 2}, - (2, 3, 224, 224), - (2, 1, 8, 1), -] +TEST_CASE_0 = [{"num_classes": 1, "use_conv": True, "dim": 2}, (2, 3, 224, 224), (2, 1, 8, 1)] -TEST_CASE_1 = [ - {"num_classes": 1, "use_conv": True, "dim": 3, "pool": None}, - (2, 3, 32, 32, 32), - (2, 1, 1, 1, 1), -] +TEST_CASE_1 = [{"num_classes": 1, "use_conv": True, "dim": 3, "pool": None}, (2, 3, 32, 32, 32), (2, 1, 1, 1, 1)] -TEST_CASE_2 = [ - {"num_classes": 5, "use_conv": True, "dim": 3, "pool": None}, - (2, 3, 32, 32, 32), - (2, 5, 1, 1, 1), -] +TEST_CASE_2 = [{"num_classes": 5, "use_conv": True, "dim": 3, "pool": None}, (2, 3, 32, 32, 32), (2, 5, 1, 1, 1)] TEST_CASE_3 = [ {"num_classes": 5, "use_conv": True, "pool": ("avg", {"kernel_size": 4, "stride": 1}), "dim": 3}, diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py index 9698a40116..419e1202d0 100644 --- a/tests/test_network_consistency.py +++ b/tests/test_network_consistency.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,8 +20,9 @@ from parameterized.parameterized import parameterized import monai.networks.nets as nets +from monai.utils import set_determinism -extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA", None) +extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA") TESTS = [] if extra_test_data_dir is not None: @@ -33,6 +34,12 @@ class TestNetworkConsistency(unittest.TestCase): + def setUp(self): + set_determinism(0) + + def tearDown(self): + set_determinism(None) + @skipIf( len(TESTS) == 0, "To run these tests, clone https://github.com/Project-MONAI/MONAI-extra-test-data and set MONAI_EXTRA_TEST_DATA", @@ -53,8 +60,8 @@ def test_network_consistency(self, net_name, data_path, json_path): json_file.close() # Create model - model = nets.__dict__[net_name](**model_params) - model.load_state_dict(loaded_data["model"]) + model = getattr(nets, net_name)(**model_params) + model.load_state_dict(loaded_data["model"], strict=False) model.eval() in_data = loaded_data["in_data"] diff --git a/tests/test_nifti_endianness.py b/tests/test_nifti_endianness.py index bf0f27b9ca..06b2d803d8 100644 --- a/tests/test_nifti_endianness.py +++ b/tests/test_nifti_endianness.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_nifti_header_revise.py b/tests/test_nifti_header_revise.py index 8d9a1d4f3a..7f917cb0e9 100644 --- a/tests/test_nifti_header_revise.py +++ b/tests/test_nifti_header_revise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index f16d80659c..1322fa6a45 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,54 +19,66 @@ from monai.data import write_nifti from monai.transforms import LoadImage, Orientation, Spacing -from tests.utils import make_nifti_image - -TEST_IMAGE = np.arange(24).reshape((2, 4, 3)) -TEST_AFFINE = np.array( - [[-5.3, 0.0, 0.0, 102.01], [0.0, 0.52, 2.17, -7.50], [-0.0, 1.98, -0.26, -23.12], [0.0, 0.0, 0.0, 1.0]] -) - -TEST_CASES = [ - [ - TEST_IMAGE, - TEST_AFFINE, - dict(reader="NibabelReader", image_only=False, as_closest_canonical=True), - np.arange(24).reshape((2, 4, 3)), - ], - [ - TEST_IMAGE, - TEST_AFFINE, - dict(reader="NibabelReader", image_only=True, as_closest_canonical=True), - np.array( +from tests.utils import TEST_NDARRAYS, assert_allclose, make_nifti_image + +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TEST_IMAGE = p(np.arange(24).reshape((2, 4, 3))) + TEST_AFFINE = q( + np.array( + [[-5.3, 0.0, 0.0, 102.01], [0.0, 0.52, 2.17, -7.50], [-0.0, 1.98, -0.26, -23.12], [0.0, 0.0, 0.0, 1.0]] + ) + ) + TESTS.append( + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=False, as_closest_canonical=True), + np.arange(24).reshape((2, 4, 3)), + ] + ) + TESTS.append( + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=True, as_closest_canonical=True), + np.array( + [ + [[12.0, 15.0, 18.0, 21.0], [13.0, 16.0, 19.0, 22.0], [14.0, 17.0, 20.0, 23.0]], + [[0.0, 3.0, 6.0, 9.0], [1.0, 4.0, 7.0, 10.0], [2.0, 5.0, 8.0, 11.0]], + ] + ), + ] + ) + TESTS.append( [ - [[12.0, 15.0, 18.0, 21.0], [13.0, 16.0, 19.0, 22.0], [14.0, 17.0, 20.0, 23.0]], - [[0.0, 3.0, 6.0, 9.0], [1.0, 4.0, 7.0, 10.0], [2.0, 5.0, 8.0, 11.0]], + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=True, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), ] - ), - ], - [ - TEST_IMAGE, - TEST_AFFINE, - dict(reader="NibabelReader", image_only=True, as_closest_canonical=False), - np.arange(24).reshape((2, 4, 3)), - ], - [ - TEST_IMAGE, - TEST_AFFINE, - dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), - np.arange(24).reshape((2, 4, 3)), - ], - [ - TEST_IMAGE, - None, - dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), - np.arange(24).reshape((2, 4, 3)), - ], -] + ) + TESTS.append( + [ + TEST_IMAGE, + TEST_AFFINE, + dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ] + ) + TESTS.append( + [ + TEST_IMAGE, + None, + dict(reader="NibabelReader", image_only=False, as_closest_canonical=False), + np.arange(24).reshape((2, 4, 3)), + ] + ) class TestNiftiLoadRead(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_orientation(self, array, affine, reader_param, expected): test_image = make_nifti_image(array, affine) @@ -93,8 +105,8 @@ def test_orientation(self, array, affine, reader_param, expected): os.remove(test_image) if affine is not None: - np.testing.assert_allclose(saved_affine, affine) - np.testing.assert_allclose(saved_data, expected) + assert_allclose(saved_affine, affine, type_test=False) + assert_allclose(saved_data, expected, type_test=False) def test_consistency(self): np.set_printoptions(suppress=True, precision=3) @@ -140,69 +152,81 @@ def test_consistency(self): def test_write_2d(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(6).reshape((2, 3)) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(5).reshape((1, 5)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 1, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1])) + for p in TEST_NDARRAYS: + img = p(np.arange(6).reshape((2, 3))) + write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = np.arange(5).reshape((1, 5)) + write_nifti( + img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 1, 3, 5]) + ) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1])) def test_write_3d(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(6).reshape((1, 2, 3)) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[0, 1, 2], [3, 4, 5]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(5).reshape((1, 1, 5)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) + for p in TEST_NDARRAYS: + img = p(np.arange(6).reshape((1, 2, 3))) + write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[0, 1, 2], [3, 4, 5]]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = p(np.arange(5).reshape((1, 1, 5))) + write_nifti( + img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5]) + ) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) def test_write_4d(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(6).reshape((1, 1, 3, 2)) - write_nifti(img, image_name, affine=np.diag([1.4, 1]), target_affine=np.diag([1, 1.4, 1])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[[0, 1], [2, 3], [4, 5]]]]) - np.testing.assert_allclose(out.affine, np.diag([1, 1.4, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(5).reshape((1, 1, 5, 1)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) + for p in TEST_NDARRAYS: + img = p(np.arange(6).reshape((1, 1, 3, 2))) + write_nifti(img, image_name, affine=np.diag([1.4, 1]), target_affine=np.diag([1, 1.4, 1])) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[[0, 1], [2, 3], [4, 5]]]]) + np.testing.assert_allclose(out.affine, np.diag([1, 1.4, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = p(np.arange(5).reshape((1, 1, 5, 1))) + write_nifti( + img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5]) + ) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) def test_write_5d(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(12).reshape((1, 1, 3, 2, 2)) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) - out = nib.load(image_name) - np.testing.assert_allclose( - out.get_fdata(), - np.array([[[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]], [[8.0, 9.0], [10.0, 11.0]]]]]), - ) - np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(10).reshape((1, 1, 5, 1, 2)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 1.0]], [[4.0, 5.0]], [[8.0, 9.0]]]]])) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) + for p in TEST_NDARRAYS: + img = p(np.arange(12).reshape((1, 1, 3, 2, 2))) + write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) + out = nib.load(image_name) + np.testing.assert_allclose( + out.get_fdata(), + np.array([[[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]], [[8.0, 9.0], [10.0, 11.0]]]]]), + ) + np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = p(np.arange(10).reshape((1, 1, 5, 1, 2))) + write_nifti( + img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5]) + ) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 1.0]], [[4.0, 5.0]], [[8.0, 9.0]]]]])) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) if __name__ == "__main__": diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py index c07084172f..6855a59041 100644 --- a/tests/test_nifti_saver.py +++ b/tests/test_nifti_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -36,12 +36,7 @@ def test_saved_content(self): def test_saved_resize_content(self): with tempfile.TemporaryDirectory() as tempdir: - saver = NiftiSaver( - output_dir=tempdir, - output_postfix="seg", - output_ext=".nii.gz", - dtype=np.float32, - ) + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) meta_data = { "filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)], @@ -56,12 +51,7 @@ def test_saved_resize_content(self): def test_saved_3d_resize_content(self): with tempfile.TemporaryDirectory() as tempdir: - saver = NiftiSaver( - output_dir=tempdir, - output_postfix="seg", - output_ext=".nii.gz", - dtype=np.float32, - ) + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) meta_data = { "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], @@ -74,6 +64,25 @@ def test_saved_3d_resize_content(self): filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + def test_saved_3d_no_resize_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver( + output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32, resample=False + ) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(10, 10, 2)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + img, _ = LoadImage("nibabelreader")(filepath) + self.assertEqual(img.shape, (1, 2, 2, 8)) + def test_squeeze_end_dims(self): with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index 2755eb4c25..5bcee1263b 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -31,51 +31,51 @@ "divisor": u(np.array([0.5, 0.5, 0.5, 0.5])), "nonzero": True, }, - np.array([0.0, 3.0, 0.0, 4.0]), - np.array([0.0, -1.0, 0.0, 1.0]), + p(np.array([0.0, 3.0, 0.0, 4.0])), + p(np.array([0.0, -1.0, 0.0, 1.0])), ] ) - TESTS.append([p, {"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) - TESTS.append([p, {"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) - TESTS.append([p, {"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append([p, {"nonzero": True}, p(np.array([0.0, 0.0, 0.0, 0.0])), p(np.array([0.0, 0.0, 0.0, 0.0]))]) + TESTS.append([p, {"nonzero": False}, p(np.array([0.0, 0.0, 0.0, 0.0])), p(np.array([0.0, 0.0, 0.0, 0.0]))]) + TESTS.append([p, {"nonzero": False}, p(np.array([1, 1, 1, 1])), p(np.array([0.0, 0.0, 0.0, 0.0]))]) TESTS.append( [ p, - {"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]), + {"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3], "dtype": np.float32}, + p(np.ones((3, 2, 2))), + p(np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]])), ] ) TESTS.append( [ p, - {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]), + {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2], "dtype": "float32"}, + p(np.ones((3, 2, 2))), + p(np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]])), ] ) TESTS.append( [ p, - {"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * -1.0, + {"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0, "dtype": torch.float32}, + p(np.ones((3, 2, 2))), + p(np.ones((3, 2, 2)) * -1.0), ] ) TESTS.append( [ p, {"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, + p(np.ones((3, 2, 2))), + p(np.ones((3, 2, 2)) * 0.5), ] ) TESTS.append( [ p, {"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, + p(np.ones((3, 2, 2))), + p(np.ones((3, 2, 2)) * 0.5), ] ) @@ -91,17 +91,14 @@ def test_default(self, im_type): self.assertEqual(im.device, normalized.device) self.assertTrue(normalized.dtype in (np.float32, torch.float32)) expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) - assert_allclose(expected, normalized, rtol=1e-3) + assert_allclose(normalized, expected, type_test=False, rtol=1e-3) @parameterized.expand(TESTS) def test_nonzero(self, in_type, input_param, input_data, expected_data): normalizer = NormalizeIntensity(**input_param) im = in_type(input_data) normalized = normalizer(im) - self.assertEqual(type(im), type(normalized)) - if isinstance(normalized, torch.Tensor): - self.assertEqual(im.device, normalized.device) - assert_allclose(expected_data, normalized) + assert_allclose(normalized, in_type(expected_data)) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_channel_wise(self, im_type): @@ -109,10 +106,7 @@ def test_channel_wise(self, im_type): input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])) expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) normalized = normalizer(input_data) - self.assertEqual(type(input_data), type(normalized)) - if isinstance(normalized, torch.Tensor): - self.assertEqual(input_data.device, normalized.device) - assert_allclose(expected, normalized) + assert_allclose(normalized, im_type(expected)) @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_value_errors(self, im_type): diff --git a/tests/test_normalize_intensityd.py b/tests/test_normalize_intensityd.py index e2cec5407a..12a39b1b5b 100644 --- a/tests/test_normalize_intensityd.py +++ b/tests/test_normalize_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,7 +25,7 @@ [ {"keys": ["img"], "nonzero": True}, {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, - np.array([0.0, -1.0, 0.0, 1.0]), + p(np.array([0.0, -1.0, 0.0, 1.0])), ] ) TESTS.append( @@ -37,14 +37,14 @@ "nonzero": True, }, {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, - np.array([0.0, -1.0, 0.0, 1.0]), + p(np.array([0.0, -1.0, 0.0, 1.0])), ] ) TESTS.append( [ {"keys": ["img"], "nonzero": True}, {"img": p(np.array([0.0, 0.0, 0.0, 0.0]))}, - np.array([0.0, 0.0, 0.0, 0.0]), + p(np.array([0.0, 0.0, 0.0, 0.0])), ] ) @@ -60,7 +60,7 @@ def test_image_normalize_intensityd(self, im_type): self.assertEqual(type(im), type(normalized)) if isinstance(normalized, torch.Tensor): self.assertEqual(im.device, normalized.device) - assert_allclose(normalized, expected, rtol=1e-3) + assert_allclose(normalized, im_type(expected), rtol=1e-3) @parameterized.expand(TESTS) def test_nonzero(self, input_param, input_data, expected_data): @@ -82,7 +82,7 @@ def test_channel_wise(self, im_type): if isinstance(normalized, torch.Tensor): self.assertEqual(input_data[key].device, normalized.device) expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) - assert_allclose(normalized, expected) + assert_allclose(normalized, im_type(expected)) if __name__ == "__main__": diff --git a/tests/test_npzdictitemdataset.py b/tests/test_npzdictitemdataset.py index 2e86ef29d0..e24a2cfc1f 100644 --- a/tests/test_npzdictitemdataset.py +++ b/tests/test_npzdictitemdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py index a57a036905..662d22afde 100644 --- a/tests/test_numpy_reader.py +++ b/tests/test_numpy_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,12 +10,15 @@ # limitations under the License. import os +import sys import tempfile import unittest import numpy as np +import torch -from monai.data import NumpyReader +from monai.data import DataLoader, Dataset, NumpyReader +from monai.transforms import LoadImaged class TestNumpyReader(unittest.TestCase): @@ -27,8 +30,8 @@ def test_npy(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data.shape) - self.assertTupleEqual(result[0].shape, test_data.shape) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data.shape) + np.testing.assert_allclose(result[0].shape, test_data.shape) np.testing.assert_allclose(result[0], test_data) def test_npz1(self): @@ -39,8 +42,8 @@ def test_npz1(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, test_data1.shape) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape) + np.testing.assert_allclose(result[0].shape, test_data1.shape) np.testing.assert_allclose(result[0], test_data1) def test_npz2(self): @@ -52,8 +55,8 @@ def test_npz2(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape) + np.testing.assert_allclose(result[0].shape, (2, 3, 4, 4)) np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) def test_npz3(self): @@ -65,8 +68,8 @@ def test_npz3(self): reader = NumpyReader(npz_keys=["test1", "test2"]) result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape) + np.testing.assert_allclose(result[0].shape, (2, 3, 4, 4)) np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) def test_npy_pickle(self): @@ -77,7 +80,7 @@ def test_npy_pickle(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath))[0].item() - self.assertTupleEqual(result["test"].shape, test_data["test"].shape) + np.testing.assert_allclose(result["test"].shape, test_data["test"].shape) np.testing.assert_allclose(result["test"], test_data["test"]) def test_kwargs(self): @@ -88,7 +91,39 @@ def test_kwargs(self): reader = NumpyReader(mmap_mode="r") result = reader.get_data(reader.read(filepath, mmap_mode=None))[0].item() - self.assertTupleEqual(result["test"].shape, test_data["test"].shape) + np.testing.assert_allclose(result["test"].shape, test_data["test"].shape) + + def test_dataloader(self): + test_data = np.random.randint(0, 256, size=[3, 4, 5]) + datalist = [] + with tempfile.TemporaryDirectory() as tempdir: + for i in range(4): + filepath = os.path.join(tempdir, f"test_data{i}.npz") + np.savez(filepath, test_data) + datalist.append({"image": filepath}) + + num_workers = 2 if sys.platform == "linux" else 0 + loader = DataLoader( + Dataset(data=datalist, transform=LoadImaged(keys="image", reader=NumpyReader())), + batch_size=2, + num_workers=num_workers, + ) + for d in loader: + for s in d["image_meta_dict"]["spatial_shape"]: + torch.testing.assert_allclose(s, torch.as_tensor([3, 4, 5])) + for c in d["image"]: + torch.testing.assert_allclose(c, test_data) + + def test_channel_dim(self): + test_data = np.random.randint(0, 256, size=[3, 4, 5, 2]) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npy") + np.save(filepath, test_data) + + reader = NumpyReader(channel_dim=-1) + result = reader.get_data(reader.read(filepath)) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data.shape[:-1]) + self.assertEqual(result[1]["original_channel_dim"], -1) if __name__ == "__main__": diff --git a/tests/test_nvtx_decorator.py b/tests/test_nvtx_decorator.py index e2a9ad67b8..e81c72efcf 100644 --- a/tests/test_nvtx_decorator.py +++ b/tests/test_nvtx_decorator.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,52 +17,44 @@ from monai.transforms import ( Compose, + CuCIM, Flip, FlipD, RandAdjustContrast, + RandCuCIM, RandFlip, Randomizable, Rotate90, + ToCupy, + TorchVision, ToTensor, ToTensorD, ) from monai.utils import Range, optional_import +from tests.utils import HAS_CUPY _, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?") +_, has_tvt = optional_import("torchvision.transforms") +_, has_cut = optional_import("cucim.core.operations.expose.transform") -TEST_CASE_ARRAY_0 = [ - np.random.randn(3, 3), -] -TEST_CASE_ARRAY_1 = [ - np.random.randn(3, 10, 10), -] +TEST_CASE_ARRAY_0 = [np.random.randn(3, 3)] +TEST_CASE_ARRAY_1 = [np.random.randn(3, 10, 10)] -TEST_CASE_DICT_0 = [ - {"image": np.random.randn(3, 3)}, -] -TEST_CASE_DICT_1 = [ - {"image": np.random.randn(3, 10, 10)}, -] +TEST_CASE_DICT_0 = [{"image": np.random.randn(3, 3)}] +TEST_CASE_DICT_1 = [{"image": np.random.randn(3, 10, 10)}] -TEST_CASE_TORCH_0 = [ - torch.randn(3, 3), -] -TEST_CASE_TORCH_1 = [ - torch.randn(3, 10, 10), -] +TEST_CASE_TORCH_0 = [torch.randn(3, 3)] +TEST_CASE_TORCH_1 = [torch.randn(3, 10, 10)] +TEST_CASE_WRAPPER = [np.random.randn(3, 10, 10)] + +@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") class TestNVTXRangeDecorator(unittest.TestCase): @parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1]) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_tranform_array(self, input): - transforms = Compose( - [ - Range("random flip")(Flip()), - Range()(ToTensor()), - ] - ) + transforms = Compose([Range("random flip")(Flip()), Range()(ToTensor())]) # Apply transforms output = transforms(input) @@ -82,18 +74,12 @@ def test_tranform_array(self, input): self.assertIsInstance(output2, torch.Tensor) self.assertIsInstance(output3, torch.Tensor) np.testing.assert_equal(output.numpy(), output1.numpy()) - np.testing.assert_equal(output.numpy(), output1.numpy()) + np.testing.assert_equal(output.numpy(), output2.numpy()) np.testing.assert_equal(output.numpy(), output3.numpy()) @parameterized.expand([TEST_CASE_DICT_0, TEST_CASE_DICT_1]) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_tranform_dict(self, input): - transforms = Compose( - [ - Range("random flip dict")(FlipD(keys="image")), - Range()(ToTensorD("image")), - ] - ) + transforms = Compose([Range("random flip dict")(FlipD(keys="image")), Range()(ToTensorD("image"))]) # Apply transforms output = transforms(input)["image"] @@ -116,8 +102,32 @@ def test_tranform_dict(self, input): np.testing.assert_equal(output.numpy(), output2.numpy()) np.testing.assert_equal(output.numpy(), output3.numpy()) + @parameterized.expand([TEST_CASE_WRAPPER]) + @unittest.skipUnless(HAS_CUPY, "Requires CuPy.") + @unittest.skipUnless(has_cut, "Requires cuCIM transforms.") + @unittest.skipUnless(has_tvt, "Requires torchvision transforms.") + def test_wrapper_tranforms(self, input): + transform_list = [ + ToTensor(), + TorchVision(name="RandomHorizontalFlip", p=1.0), + ToCupy(), + CuCIM(name="image_flip", spatial_axis=-1), + RandCuCIM(name="rand_image_rotate_90", prob=1.0, max_k=1, spatial_axis=(-2, -1)), + ] + + transforms = Compose(transform_list) + transforms_range = Compose([Range()(t) for t in transform_list]) + + # Apply transforms + output = transforms(input) + + # Apply transforms with Range + output_r = transforms_range(input) + + # Check the outputs + np.testing.assert_equal(output.get(), output_r.get()) + @parameterized.expand([TEST_CASE_ARRAY_1]) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_tranform_randomized(self, input): # Compose deterministic and randomized transforms transforms = Compose( @@ -158,13 +168,9 @@ def test_tranform_randomized(self, input): break @parameterized.expand([TEST_CASE_TORCH_0, TEST_CASE_TORCH_1]) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_network(self, input): # Create a network - model = torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Sigmoid(), - ) + model = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Sigmoid()) # Forward output = model(input) @@ -189,7 +195,6 @@ def test_network(self, input): np.testing.assert_equal(output.numpy(), output3.numpy()) @parameterized.expand([TEST_CASE_TORCH_0, TEST_CASE_TORCH_1]) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_loss(self, input): # Create a network and loss model = torch.nn.Sigmoid() @@ -219,7 +224,6 @@ def test_loss(self, input): np.testing.assert_equal(output.numpy(), output2.numpy()) np.testing.assert_equal(output.numpy(), output3.numpy()) - @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX Range!") def test_context_manager(self): model = torch.nn.Sigmoid() loss = torch.nn.BCELoss() diff --git a/tests/test_nvtx_transform.py b/tests/test_nvtx_transform.py index 6bcfe00078..01a069ed8a 100644 --- a/tests/test_nvtx_transform.py +++ b/tests/test_nvtx_transform.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -35,29 +35,14 @@ _, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?") -TEST_CASE_ARRAY_0 = [ - np.random.randn(3, 3), -] -TEST_CASE_ARRAY_1 = [ - np.random.randn(3, 10, 10), -] -TEST_CASE_DICT_0 = [ - {"image": np.random.randn(3, 3)}, -] -TEST_CASE_DICT_1 = [ - {"image": np.random.randn(3, 10, 10)}, -] +TEST_CASE_ARRAY_0 = [np.random.randn(3, 3)] +TEST_CASE_ARRAY_1 = [np.random.randn(3, 10, 10)] +TEST_CASE_DICT_0 = [{"image": np.random.randn(3, 3)}] +TEST_CASE_DICT_1 = [{"image": np.random.randn(3, 10, 10)}] class TestNVTXTransforms(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_ARRAY_0, - TEST_CASE_ARRAY_1, - TEST_CASE_DICT_0, - TEST_CASE_DICT_1, - ] - ) + @parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1, TEST_CASE_DICT_0, TEST_CASE_DICT_1]) @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!") def test_nvtx_transfroms_alone(self, input): transforms = Compose( diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py index d58359a598..ff32f747d4 100644 --- a/tests/test_occlusion_sensitivity.py +++ b/tests/test_occlusion_sensitivity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,47 +29,27 @@ # 2D w/ bounding box TEST_CASE_0 = [ - { - "nn_module": model_2d, - }, - { - "x": torch.rand(1, 1, 48, 64).to(device), - "b_box": [-1, -1, 2, 40, 1, 62], - }, + {"nn_module": model_2d}, + {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [-1, -1, 2, 40, 1, 62]}, (1, 1, 39, 62, out_channels_2d), (1, 1, 39, 62), ] # 3D w/ bounding box and stride TEST_CASE_1 = [ {"nn_module": model_3d, "n_batch": 10, "stride": (2, 1, 2), "mask_size": (16, 15, 14)}, - { - "x": torch.rand(1, 1, 6, 6, 6).to(device), - "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], - }, + {"x": torch.rand(1, 1, 6, 6, 6).to(device), "b_box": [-1, -1, 2, 3, -1, -1, -1, -1]}, (1, 1, 2, 6, 6, out_channels_3d), (1, 1, 2, 6, 6), ] TEST_CASE_FAIL_0 = [ # 2D should fail, since 3 stride values given - { - "nn_module": model_2d, - "n_batch": 10, - "stride": (2, 2, 2), - }, - { - "x": torch.rand(1, 1, 48, 64).to(device), - "b_box": [-1, -1, 2, 3, -1, -1], - }, + {"nn_module": model_2d, "n_batch": 10, "stride": (2, 2, 2)}, + {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [-1, -1, 2, 3, -1, -1]}, ] TEST_CASE_FAIL_1 = [ # 2D should fail, since stride is not a factor of image size - { - "nn_module": model_2d, - "stride": 3, - }, - { - "x": torch.rand(1, 1, 48, 64).to(device), - }, + {"nn_module": model_2d, "stride": 3}, + {"x": torch.rand(1, 1, 48, 64).to(device)}, ] diff --git a/tests/test_one_of.py b/tests/test_one_of.py index d45d0f3f61..29d13d7d0c 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,12 +12,21 @@ import unittest from copy import deepcopy +import numpy as np from parameterized import parameterized -from monai.transforms import InvertibleTransform, OneOf, Transform +from monai.transforms import ( + InvertibleTransform, + OneOf, + RandScaleIntensityd, + RandShiftIntensityd, + Resized, + TraceableTransform, + Transform, +) from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform -from monai.utils.enums import InverseKeys +from monai.utils.enums import TraceKeys class X(Transform): @@ -93,9 +102,7 @@ def __init__(self, keys): self.inv_fn = lambda x: x - 100 -TESTS = [ - ((X(), Y(), X()), (1, 2, 1), (0.25, 0.5, 0.25)), -] +TESTS = [((X(), Y(), X()), (1, 2, 1), (0.25, 0.5, 0.25))] KEYS = ["x", "y"] TEST_INVERSES = [ @@ -141,30 +148,52 @@ def _match(a, b): _match(p, f) @parameterized.expand(TEST_INVERSES) - def test_inverse(self, transform, should_be_ok): + def test_inverse(self, transform, invertible): data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)} fwd_data = transform(data) - if not should_be_ok: - with self.assertRaises(RuntimeError): - transform.inverse(fwd_data) - return - - for k in KEYS: - t = fwd_data[k + InverseKeys.KEY_SUFFIX][-1] - # make sure the OneOf index was stored - self.assertEqual(t[InverseKeys.CLASS_NAME], OneOf.__name__) - # make sure index exists and is in bounds - self.assertTrue(0 <= t[InverseKeys.EXTRA_INFO]["index"] < len(transform)) + + if invertible: + for k in KEYS: + t = fwd_data[TraceableTransform.trace_key(k)][-1] + # make sure the OneOf index was stored + self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) + # make sure index exists and is in bounds + self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) # call the inverse fwd_inv_data = transform.inverse(fwd_data) - for k in KEYS: - # check transform was removed - self.assertTrue(len(fwd_inv_data[k + InverseKeys.KEY_SUFFIX]) < len(fwd_data[k + InverseKeys.KEY_SUFFIX])) - # check data is same as original (and different from forward) - self.assertEqual(fwd_inv_data[k], data[k]) - self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) + if invertible: + for k in KEYS: + # check transform was removed + self.assertTrue( + len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)]) + ) + # check data is same as original (and different from forward) + self.assertEqual(fwd_inv_data[k], data[k]) + self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) + else: + # if not invertible, should not change the data + self.assertDictEqual(fwd_data, fwd_inv_data) + + def test_inverse_compose(self): + transform = Compose( + [ + Resized(keys="img", spatial_size=[100, 100, 100]), + OneOf( + [ + RandScaleIntensityd(keys="img", factors=0.5, prob=1.0), + RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0), + ] + ), + ] + ) + transform.set_random_state(seed=0) + result = transform({"img": np.ones((1, 101, 102, 103))}) + + result = transform.inverse(result) + # invert to the original spatial shape + self.assertTupleEqual(result["img"].shape, (1, 101, 102, 103)) def test_one_of(self): p = OneOf((A(), B(), C()), (1, 2, 1)) diff --git a/tests/test_openslide_reader.py b/tests/test_openslide_reader.py deleted file mode 100644 index c0b395fd02..0000000000 --- a/tests/test_openslide_reader.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import unittest -from unittest import skipUnless - -import numpy as np -from numpy.testing import assert_array_equal -from parameterized import parameterized - -from monai.apps.utils import download_url -from monai.data.image_reader import WSIReader -from monai.utils import optional_import - -_, has_osl = optional_import("openslide") - - -FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" -FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + os.path.basename(FILE_URL)) - -HEIGHT = 32914 -WIDTH = 46000 - -TEST_CASE_0 = [FILE_PATH, (3, HEIGHT, WIDTH)] - -TEST_CASE_1 = [ - FILE_PATH, - {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0}, - np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), -] - -TEST_CASE_2 = [ - FILE_PATH, - {"location": (0, 0), "size": (2, 1), "level": 2}, - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), -] - -TEST_CASE_3 = [ - FILE_PATH, - { - "location": (0, 0), - "size": (8, 8), - "level": 2, - "grid_shape": (2, 1), - "patch_size": 2, - }, - np.array( - [ - [[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]], - [[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]], - ] - ), -] - -TEST_CASE_4 = [ - FILE_PATH, - { - "location": (0, 0), - "size": (8, 8), - "level": 2, - "grid_shape": (2, 1), - "patch_size": 1, - }, - np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), -] - - -class TestOpenSlideReader(unittest.TestCase): - @skipUnless(has_osl, "Requires OpenSlide") - def setUp(self): - download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") - - @parameterized.expand([TEST_CASE_0]) - def test_read_whole_image(self, file_path, expected_shape): - reader = WSIReader("OpenSlide") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj)[0] - self.assertTupleEqual(img.shape, expected_shape) - - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_read_region(self, file_path, patch_info, expected_img): - reader = WSIReader("OpenSlide") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) - - @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) - def test_read_patches(self, file_path, patch_info, expected_img): - reader = WSIReader("OpenSlide") - img_obj = reader.read(file_path) - img = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_optim_novograd.py b/tests/test_optim_novograd.py index c76501cd6f..0cf4c35cb6 100644 --- a/tests/test_optim_novograd.py +++ b/tests/test_optim_novograd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -38,37 +38,19 @@ def build_test_cases(data): return test_cases -TEST_CASES_ALL = build_test_cases( # normal parameters - [ - torch.randn(10, 5), - torch.randn(10), - torch.randn(5), - ] -) +TEST_CASES_ALL = build_test_cases([torch.randn(10, 5), torch.randn(10), torch.randn(5)]) # normal parameters TEST_CASES_ALL += build_test_cases( # non-contiguous parameters - [ - torch.randn(10, 5, 2)[..., 0], - torch.randn(10, 2)[..., 0], - torch.randn(5), - ] + [torch.randn(10, 5, 2)[..., 0], torch.randn(10, 2)[..., 0], torch.randn(5)] ) if torch.cuda.is_available(): TEST_CASES_ALL += build_test_cases( # gpu parameters - [ - torch.randn(10, 5).cuda(), - torch.randn(10).cuda(), - torch.randn(5).cuda(), - ] + [torch.randn(10, 5).cuda(), torch.randn(10).cuda(), torch.randn(5).cuda()] ) if torch.cuda.device_count() > 1: TEST_CASES_ALL += build_test_cases( # multi-gpu parameters - [ - torch.randn(10, 5).cuda(0), - torch.randn(10).cuda(1), - torch.randn(5).cuda(0), - ] + [torch.randn(10, 5).cuda(0), torch.randn(10).cuda(1), torch.randn(5).cuda(0)] ) diff --git a/tests/test_optional_import.py b/tests/test_optional_import.py index 05584f9f9c..b87ebf8909 100644 --- a/tests/test_optional_import.py +++ b/tests/test_optional_import.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_orientation.py b/tests/test_orientation.py index aa7f33a469..685c977f36 100644 --- a/tests/test_orientation.py +++ b/tests/test_orientation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py index 452172ce9b..89ecd07b0a 100644 --- a/tests/test_orientationd.py +++ b/tests/test_orientationd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_p3d_block.py b/tests/test_p3d_block.py new file mode 100644 index 0000000000..62b7098dcd --- /dev/null +++ b/tests/test_p3d_block.py @@ -0,0 +1,75 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import P3DActiConvNormBlock + +TEST_CASES_3D = [ + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 0, "mode": 0}, + (7, 32, 16, 32, 8), + (7, 16, 14, 30, 6), + ], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1, "mode": 0}, # check padding + (7, 32, 16, 32, 8), + (7, 16, 16, 32, 8), + ], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 0, "mode": 1}, + (7, 32, 16, 32, 8), + (7, 16, 14, 30, 6), + ], + [ + { + "in_channel": 32, + "out_channel": 16, + "kernel_size": 3, + "padding": 0, + "mode": 2, + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.2}), + }, + (7, 32, 16, 32, 8), + (7, 16, 14, 30, 6), + ], + [ + { + "in_channel": 32, + "out_channel": 16, + "kernel_size": 4, + "padding": 0, + "mode": 0, + "norm_name": ("INSTANCE", {"affine": True}), + }, + (7, 32, 16, 32, 8), + (7, 16, 13, 29, 5), + ], +] + + +class TestP3D(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_3d(self, input_param, input_shape, expected_shape): + net = P3DActiConvNormBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill(self): + with self.assertRaises(ValueError): + P3DActiConvNormBlock(in_channel=32, out_channel=16, kernel_size=3, padding=0, mode=3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index a8c544558f..a070fc760f 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,6 +20,7 @@ from monai.data import CacheDataset, DataLoader from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms import ( + Compose, PadListDataCollate, RandRotate, RandRotate90, @@ -29,24 +30,26 @@ RandSpatialCropd, RandZoom, RandZoomd, + ToTensor, + ToTensord, ) from monai.utils import set_determinism TESTS: List[Tuple] = [] for pad_collate in [ - lambda x: pad_list_data_collate(batch=x, method="end", mode="constant", constant_values=1), - PadListDataCollate(method="end", mode="constant", constant_values=1), + lambda x: pad_list_data_collate(batch=x, method="end", mode="constant"), + PadListDataCollate(method="end", mode="constant"), ]: TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) - TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2))) + TESTS.append((dict, pad_collate, Compose([RandRotate90d("image", prob=1, max_k=2), ToTensord("image")]))) TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True))) TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False))) TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) - TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2))) + TESTS.append((list, pad_collate, Compose([RandRotate90(prob=1, max_k=2), ToTensor()]))) class _Dataset(torch.utils.data.Dataset): diff --git a/tests/test_parallel_execution.py b/tests/test_parallel_execution.py index c4115d21ef..6186e73f68 100644 --- a/tests/test_parallel_execution.py +++ b/tests/test_parallel_execution.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_partition_dataset.py b/tests/test_partition_dataset.py index a954bfae91..687cf8df34 100644 --- a/tests/test_partition_dataset.py +++ b/tests/test_partition_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -117,16 +117,7 @@ class TestPartitionDataset(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - ] + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] ) def test_value(self, input_param, result): self.assertListEqual(partition_dataset(**input_param), result) diff --git a/tests/test_partition_dataset_classes.py b/tests/test_partition_dataset_classes.py index 3aef47107a..4ed283bdd7 100644 --- a/tests/test_partition_dataset_classes.py +++ b/tests/test_partition_dataset_classes.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 4f6e9a25fd..1796ad4f23 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -59,7 +59,7 @@ def test_loading_array(self): np.testing.assert_allclose( item[0], np.array( - [[[1.779992, 2.779992, 3.779992], [5.779992, 6.779992, 7.779992], [9.779992, 10.779992, 11.779992]]] + [[[1.338681, 2.338681, 3.338681], [5.338681, 6.338681, 7.338681], [9.338681, 10.338681, 11.338681]]] ), rtol=1e-5, ) @@ -71,9 +71,9 @@ def test_loading_array(self): np.array( [ [ - [5.025618, 6.025618, 7.025618], - [9.025618, 10.025618, 11.025618], - [13.025618, 14.025618, 15.025618], + [4.957847, 5.957847, 6.957847], + [8.957847, 9.957847, 10.957847], + [12.957847, 13.957847, 14.957847], ] ] ), diff --git a/tests/test_patch_wsi_dataset.py b/tests/test_patch_wsi_dataset.py index f775f28376..79165df36d 100644 --- a/tests/test_patch_wsi_dataset.py +++ b/tests/test_patch_wsi_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,25 +21,23 @@ from monai.apps.utils import download_url from monai.utils import optional_import -_, has_cim = optional_import("cucim") +_cucim, has_cim = optional_import("cucim") +has_cim = has_cim and hasattr(_cucim, "CuImage") _, has_osl = optional_import("openslide") -FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" -FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + os.path.basename(FILE_URL)) +FILE_URL = "https://drive.google.com/uc?id=1sGTKZlJBIz53pfqTxoTqiIQzIoEzHLAe" +base_name, extension = FILE_URL.split("id=")[1], ".tiff" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension) TEST_CASE_0 = [ { - "data": [ - {"image": FILE_PATH, "location": [0, 0], "label": [1]}, - ], + "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}], "region_size": (1, 1), "grid_shape": (1, 1), "patch_size": 1, "image_reader_name": "cuCIM", }, - [ - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}, - ], + [{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}], ] TEST_CASE_1 = [ @@ -60,47 +58,35 @@ TEST_CASE_2 = [ { - "data": [ - {"image": FILE_PATH, "location": [0, 0], "label": [1]}, - ], + "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}], "region_size": 1, "grid_shape": 1, "patch_size": 1, "image_reader_name": "cuCIM", }, - [ - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}, - ], + [{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}], ] TEST_CASE_3 = [ { - "data": [ - {"image": FILE_PATH, "location": [0, 0], "label": [[[0, 1], [1, 0]]]}, - ], + "data": [{"image": FILE_PATH, "location": [0, 0], "label": [[[0, 1], [1, 0]]]}], "region_size": 1, "grid_shape": 1, "patch_size": 1, "image_reader_name": "cuCIM", }, - [ - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])}, - ], + [{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])}], ] TEST_CASE_OPENSLIDE_0 = [ { - "data": [ - {"image": FILE_PATH, "location": [0, 0], "label": [1]}, - ], + "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}], "region_size": (1, 1), "grid_shape": (1, 1), "patch_size": 1, "image_reader_name": "OpenSlide", }, - [ - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}, - ], + [{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}], ] TEST_CASE_OPENSLIDE_1 = [ @@ -124,14 +110,7 @@ class TestPatchWSIDataset(unittest.TestCase): def setUp(self): download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) @skipUnless(has_cim, "Requires CuCIM") def test_read_patches_cucim(self, input_parameters, expected): dataset = PatchWSIDataset(**input_parameters) @@ -142,12 +121,7 @@ def test_read_patches_cucim(self, input_parameters, expected): self.assertIsNone(assert_array_equal(samples[i]["label"], expected[i]["label"])) self.assertIsNone(assert_array_equal(samples[i]["image"], expected[i]["image"])) - @parameterized.expand( - [ - TEST_CASE_OPENSLIDE_0, - TEST_CASE_OPENSLIDE_1, - ] - ) + @parameterized.expand([TEST_CASE_OPENSLIDE_0, TEST_CASE_OPENSLIDE_1]) @skipUnless(has_osl, "Requires OpenSlide") def test_read_patches_openslide(self, input_parameters, expected): dataset = PatchWSIDataset(**input_parameters) diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 6c9ac78a99..4af2b47ba5 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_pathology_he_stain.py b/tests/test_pathology_he_stain.py index 1d74f485e9..7b884315fc 100644 --- a/tests/test_pathology_he_stain.py +++ b/tests/test_pathology_he_stain.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -73,12 +73,7 @@ class TestExtractHEStains(unittest.TestCase): @parameterized.expand( - [ - NEGATIVE_VALUE_TEST_CASE, - INVALID_VALUE_TEST_CASE, - EXTRACT_STAINS_TEST_CASE_0, - EXTRACT_STAINS_TEST_CASE_1, - ] + [NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1] ) def test_transparent_image(self, image): """ @@ -112,13 +107,7 @@ def test_identical_result_vectors(self, image): result = ExtractHEStains()(image) np.testing.assert_array_equal(result[:, 0], result[:, 1]) - @parameterized.expand( - [ - EXTRACT_STAINS_TEST_CASE_00, - EXTRACT_STAINS_TEST_CASE_4, - EXTRACT_STAINS_TEST_CASE_5, - ] - ) + @parameterized.expand([EXTRACT_STAINS_TEST_CASE_00, EXTRACT_STAINS_TEST_CASE_4, EXTRACT_STAINS_TEST_CASE_5]) def test_result_value(self, image, expected_data): """ Test that an input image returns an expected stain matrix. @@ -156,12 +145,7 @@ def test_result_value(self, image, expected_data): class TestNormalizeHEStains(unittest.TestCase): @parameterized.expand( - [ - NEGATIVE_VALUE_TEST_CASE, - INVALID_VALUE_TEST_CASE, - NORMALIZE_STAINS_TEST_CASE_0, - NORMALIZE_STAINS_TEST_CASE_1, - ] + [NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1] ) def test_transparent_image(self, image): """ diff --git a/tests/test_pathology_he_stain_dict.py b/tests/test_pathology_he_stain_dict.py index 8d51579cb2..7d6c3ffb75 100644 --- a/tests/test_pathology_he_stain_dict.py +++ b/tests/test_pathology_he_stain_dict.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -100,13 +100,7 @@ def test_identical_result_vectors(self, image): result = ExtractHEStainsD([key])({key: image}) np.testing.assert_array_equal(result[key][:, 0], result[key][:, 1]) - @parameterized.expand( - [ - EXTRACT_STAINS_TEST_CASE_00, - EXTRACT_STAINS_TEST_CASE_4, - EXTRACT_STAINS_TEST_CASE_5, - ] - ) + @parameterized.expand([EXTRACT_STAINS_TEST_CASE_00, EXTRACT_STAINS_TEST_CASE_4, EXTRACT_STAINS_TEST_CASE_5]) def test_result_value(self, image, expected_data): """ Test that an input image returns an expected stain matrix. diff --git a/tests/test_pathology_prob_nms.py b/tests/test_pathology_prob_nms.py index 223b136ea7..3399e33afa 100644 --- a/tests/test_pathology_prob_nms.py +++ b/tests/test_pathology_prob_nms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -41,12 +41,7 @@ class TestPathologyProbNMS(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASES_2D, - TEST_CASES_3D, - ] - ) + @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D]) def test_output(self, class_args, call_args, probs_map, expected): nms = PathologyProbNMS(**class_args) output = nms(probs_map, **call_args) diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index 8446f566ef..17575c79f7 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,6 +10,7 @@ # limitations under the License. import os +import pickle import tempfile import unittest @@ -56,7 +57,13 @@ def test_cache(self): items = [[list(range(i))] for i in range(5)] with tempfile.TemporaryDirectory() as tempdir: - ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir) + ds = PersistentDataset( + data=items, + transform=_InplaceXform(), + cache_dir=tempdir, + pickle_module="pickle", + pickle_protocol=pickle.HIGHEST_PROTOCOL, + ) self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir) self.assertEqual(list(ds1), list(ds)) diff --git a/tests/test_persistentdataset_dist.py b/tests/test_persistentdataset_dist.py index d45bba03e5..20dcb2c264 100644 --- a/tests/test_persistentdataset_dist.py +++ b/tests/test_persistentdataset_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_phl_cpu.py b/tests/test_phl_cpu.py index 31e28bd39d..d479f554b4 100644 --- a/tests/test_phl_cpu.py +++ b/tests/test_phl_cpu.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,7 +42,7 @@ # Batch 0 [ # Channel 0 - [1, 0.2, 0.5, 0, 1], + [1, 0.2, 0.5, 0, 1] ], # Batch 1 [ @@ -79,15 +79,15 @@ [0, 0, 0, 0, 1], # Channel 2 [0, 0, 1, 0, 0], - ], + ] ], # Features [ # Batch 0 [ # Channel 0 - [1, 0.2, 0.5, 0.2, 1], - ], + [1, 0.2, 0.5, 0.2, 1] + ] ], # Expected [ @@ -99,7 +99,7 @@ [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], # Channel 2 [0.201235, 0.208194, 0.205409, 0.208194, 0.201235], - ], + ] ], ], [ @@ -113,7 +113,7 @@ [ # Channel 0 [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]] - ], + ] ], # Features [ @@ -125,7 +125,7 @@ [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], # Channel 2 [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]], - ], + ] ], # Expected [ @@ -139,7 +139,7 @@ [7.613517, 7.359183, 5.846500, 5.638952, 5.350098], [7.598255, 7.458446, 5.912375, 5.583625, 5.233126], ] - ], + ] ], ], [ @@ -164,7 +164,7 @@ # Frame 4 [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], ] - ], + ] ], # Features [ @@ -183,7 +183,7 @@ # Frame 4 [[0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], ] - ], + ] ], # Expected [ @@ -232,7 +232,7 @@ [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], ], ] - ], + ] ], ], ] diff --git a/tests/test_phl_cuda.py b/tests/test_phl_cuda.py index 8f7fc6fc3d..d49a60ecd9 100644 --- a/tests/test_phl_cuda.py +++ b/tests/test_phl_cuda.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,7 +42,7 @@ # Batch 0 [ # Channel 0 - [1, 0.2, 0.5, 0, 1], + [1, 0.2, 0.5, 0, 1] ], # Batch 1 [ @@ -79,15 +79,15 @@ [0, 0, 0, 0, 1], # Channel 2 [0, 0, 1, 0, 0], - ], + ] ], # Features [ # Batch 0 [ # Channel 0 - [1, 0.2, 0.5, 0.2, 1], - ], + [1, 0.2, 0.5, 0.2, 1] + ] ], # Expected [ @@ -99,7 +99,7 @@ [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], # Channel 2 [0.201235, 0.208194, 0.205409, 0.208194, 0.201235], - ], + ] ], ], [ @@ -113,7 +113,7 @@ [ # Channel 0 [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]] - ], + ] ], # Features [ @@ -125,7 +125,7 @@ [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], # Channel 2 [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]], - ], + ] ], # Expected [ @@ -139,7 +139,7 @@ [7.712976, 7.429060, 5.789552, 5.594258, 5.371737], [7.701185, 7.492719, 5.860026, 5.538241, 5.281656], ] - ], + ] ], ], ] diff --git a/tests/test_pil_reader.py b/tests/test_pil_reader.py index 0a076b581e..0f7792a56c 100644 --- a/tests/test_pil_reader.py +++ b/tests/test_pil_reader.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_plot_2d_or_3d_image.py b/tests/test_plot_2d_or_3d_image.py index 645658e311..2e4adb93e3 100644 --- a/tests/test_plot_2d_or_3d_image.py +++ b/tests/test_plot_2d_or_3d_image.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,11 @@ from parameterized import parameterized from torch.utils.tensorboard import SummaryWriter +from monai.utils import optional_import from monai.visualize import plot_2d_or_3d_image +from tests.utils import SkipIfNoModule + +SummaryWriterX, _ = optional_import("tensorboardX", name="SummaryWriter") TEST_CASE_1 = [(1, 1, 10, 10)] @@ -32,10 +36,30 @@ class TestPlot2dOr3dImage(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_tb_image_shape(self, shape): + def test_tb_image(self, shape): with tempfile.TemporaryDirectory() as tempdir: writer = SummaryWriter(log_dir=tempdir) - plot_2d_or_3d_image(torch.zeros(shape), 0, writer) + plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=3, frame_dim=-1) + writer.flush() + writer.close() + self.assertTrue(len(glob.glob(tempdir)) > 0) + + @SkipIfNoModule("tensorboardX") + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_tbx_image(self, shape): + with tempfile.TemporaryDirectory() as tempdir: + writer = SummaryWriterX(log_dir=tempdir) + plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=2) + writer.flush() + writer.close() + self.assertTrue(len(glob.glob(tempdir)) > 0) + + @SkipIfNoModule("tensorboardX") + @parameterized.expand([TEST_CASE_5]) + def test_tbx_video(self, shape): + with tempfile.TemporaryDirectory() as tempdir: + writer = SummaryWriterX(log_dir=tempdir) + plot_2d_or_3d_image(torch.rand(shape), 0, writer, max_channels=3) writer.flush() writer.close() self.assertTrue(len(glob.glob(tempdir)) > 0) diff --git a/tests/test_png_rw.py b/tests/test_png_rw.py index 265b31b83b..84251b391f 100644 --- a/tests/test_png_rw.py +++ b/tests/test_png_rw.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_png_saver.py b/tests/test_png_saver.py index f8ea1df54b..d832718643 100644 --- a/tests/test_png_saver.py +++ b/tests/test_png_saver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -60,11 +60,7 @@ def test_saved_specified_root(self): with tempfile.TemporaryDirectory() as tempdir: saver = PNGSaver( - output_dir=tempdir, - output_postfix="seg", - output_ext=".png", - scale=255, - data_root_dir="test", + output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255, data_root_dir="test" ) meta_data = { diff --git a/tests/test_polyval.py b/tests/test_polyval.py index 4ff05bc817..db3bcaca53 100644 --- a/tests/test_polyval.py +++ b/tests/test_polyval.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_prepare_batch_default.py b/tests/test_prepare_batch_default.py new file mode 100644 index 0000000000..8aabb4e9ce --- /dev/null +++ b/tests/test_prepare_batch_default.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.engines import PrepareBatchDefault, SupervisedEvaluator +from tests.utils import assert_allclose + + +class TestNet(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x + + +class TestPrepareBatchDefault(unittest.TestCase): + def test_content(self): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [ + { + "image": torch.tensor([1, 2]), + "label": torch.tensor([3, 4]), + "extra1": torch.tensor([5, 6]), + "extra2": 16, + "extra3": "test", + } + ] + # set up engine + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=TestNet(), + non_blocking=False, + prepare_batch=PrepareBatchDefault(), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + assert_allclose(output["image"], torch.tensor([1, 2], device=device)) + assert_allclose(output["label"], torch.tensor([3, 4], device=device)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prepare_batch_extra_input.py b/tests/test_prepare_batch_extra_input.py new file mode 100644 index 0000000000..79c9a13679 --- /dev/null +++ b/tests/test_prepare_batch_extra_input.py @@ -0,0 +1,76 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.engines import PrepareBatchExtraInput, SupervisedEvaluator +from tests.utils import assert_allclose + +TEST_CASE_0 = [ + {"extra_keys": "extra1"}, + {"x": torch.tensor([1, 2]), "t1": torch.tensor([5, 6]), "t2": None, "t3": None}, +] + +TEST_CASE_1 = [ + {"extra_keys": ["extra1", "extra3"]}, + {"x": torch.tensor([1, 2]), "t1": torch.tensor([5, 6]), "t2": "test", "t3": None}, +] + +TEST_CASE_2 = [ + {"extra_keys": {"t1": "extra2", "t2": "extra3", "t3": "extra1"}}, + {"x": torch.tensor([1, 2]), "t1": 16, "t2": "test", "t3": torch.tensor([5, 6])}, +] + + +class TestNet(torch.nn.Module): + def forward(self, x: torch.Tensor, t1=None, t2=None, t3=None): + return {"x": x, "t1": t1, "t2": t2, "t3": t3} + + +class TestPrepareBatchExtraInput(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + def test_content(self, input_args, expected_value): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [ + { + "image": torch.tensor([1, 2]), + "label": torch.tensor([3, 4]), + "extra1": torch.tensor([5, 6]), + "extra2": 16, + "extra3": "test", + } + ] + # set up engine + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=TestNet(), + non_blocking=True, + prepare_batch=PrepareBatchExtraInput(**input_args), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + assert_allclose(output["image"], torch.tensor([1, 2], device=device)) + assert_allclose(output["label"], torch.tensor([3, 4], device=device)) + for k, v in output["pred"].items(): + if isinstance(v, torch.Tensor): + assert_allclose(v, expected_value[k].to(device)) + else: + self.assertEqual(v, expected_value[k]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_print_info.py b/tests/test_print_info.py index 64f0b66949..591316884c 100644 --- a/tests/test_print_info.py +++ b/tests/test_print_info.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_print_transform_backends.py b/tests/test_print_transform_backends.py index 4164687f01..2db00fea39 100644 --- a/tests/test_print_transform_backends.py +++ b/tests/test_print_transform_backends.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_probnms.py b/tests/test_probnms.py index e51d1017d8..aab312c1db 100644 --- a/tests/test_probnms.py +++ b/tests/test_probnms.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,87 +16,54 @@ from parameterized import parameterized from monai.transforms.post.array import ProbNMS +from tests.utils import TEST_NDARRAYS, assert_allclose -probs_map_1 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, probs_map_1, []] +TESTS = [] +for p in TEST_NDARRAYS: + probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, probs_map_1, []]) -probs_map_2 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_2[33, 33] = 0.7 -probs_map_2[66, 66] = 0.9 -expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_2 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, - probs_map_2, - expected_2, -] + probs_map_2 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_2[33, 33] = 0.7 + probs_map_2[66, 66] = 0.9 + expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, probs_map_2, expected_2]) -probs_map_3 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_3[56, 58] = 0.7 -probs_map_3[60, 66] = 0.8 -probs_map_3[66, 66] = 0.9 -expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] -TEST_CASES_2D_3 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, - probs_map_3, - expected_3, -] + probs_map_3 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_3[56, 58] = 0.7 + probs_map_3[60, 66] = 0.8 + probs_map_3[66, 66] = 0.9 + expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, probs_map_3, expected_3]) -probs_map_4 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_4[33, 33] = 0.7 -probs_map_4[66, 66] = 0.9 -expected_4 = [[0.9, 66, 66]] -TEST_CASES_2D_4 = [ - {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, - probs_map_4, - expected_4, -] + probs_map_4 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_4[33, 33] = 0.7 + probs_map_4[66, 66] = 0.9 + expected_4 = [[0.9, 66, 66]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, probs_map_4, expected_4]) -probs_map_5 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_5 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_5, []] + probs_map_5 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_5, []]) -probs_map_6 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -TEST_CASES_2D_6 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_6, []] + probs_map_6 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_6[33, 33] = 0.7 + probs_map_6[66, 66] = 0.9 + expected_6 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_6, expected_6]) -probs_map_7 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -probs_map_7[33, 33] = 0.7 -probs_map_7[66, 66] = 0.9 -if torch.cuda.is_available(): - probs_map_7 = probs_map_7.cuda() -expected_7 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_7 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, - probs_map_7, - expected_7, -] - -probs_map_3d = torch.rand([50, 50, 50]).uniform_(0, 0.5) -probs_map_3d[25, 25, 25] = 0.7 -probs_map_3d[45, 45, 45] = 0.9 -expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] -TEST_CASES_3D = [ - {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, - probs_map_3d, - expected_3d, -] + probs_map_3d = p(torch.rand([50, 50, 50]).uniform_(0, 0.5)) + probs_map_3d[25, 25, 25] = 0.7 + probs_map_3d[45, 45, 45] = 0.9 + expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] + TESTS.append([{"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, probs_map_3d, expected_3d]) class TestProbNMS(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASES_2D_1, - TEST_CASES_2D_2, - TEST_CASES_2D_3, - TEST_CASES_2D_4, - TEST_CASES_2D_5, - TEST_CASES_2D_6, - TEST_CASES_2D_7, - TEST_CASES_3D, - ] - ) + @parameterized.expand(TESTS) def test_output(self, class_args, probs_map, expected): nms = ProbNMS(**class_args) output = nms(probs_map) - np.testing.assert_allclose(output, expected) + assert_allclose(output, expected) if __name__ == "__main__": diff --git a/tests/test_probnmsd.py b/tests/test_probnmsd.py index 5b75d4310f..bb2315487b 100644 --- a/tests/test_probnmsd.py +++ b/tests/test_probnmsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,89 +10,63 @@ # limitations under the License. import unittest +from typing import Any, List import numpy as np import torch from parameterized import parameterized from monai.transforms.post.dictionary import ProbNMSD +from tests.utils import TEST_NDARRAYS -probs_map_1 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, {"prob_map": probs_map_1}, []] +TESTS: List[Any] = [] +for p in TEST_NDARRAYS: + probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, {"prob_map": probs_map_1}, []]) -probs_map_2 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_2[33, 33] = 0.7 -probs_map_2[66, 66] = 0.9 -expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_2 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, - {"prob_map": probs_map_2}, - expected_2, -] - -probs_map_3 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_3[56, 58] = 0.7 -probs_map_3[60, 66] = 0.8 -probs_map_3[66, 66] = 0.9 -expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] -TEST_CASES_2D_3 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, - {"prob_map": probs_map_3}, - expected_3, -] + probs_map_2 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_2[33, 33] = 0.7 + probs_map_2[66, 66] = 0.9 + expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append( + [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, {"prob_map": probs_map_2}, expected_2] + ) -probs_map_4 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_4[33, 33] = 0.7 -probs_map_4[66, 66] = 0.9 -expected_4 = [[0.9, 66, 66]] -TEST_CASES_2D_4 = [ - {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, - {"prob_map": probs_map_4}, - expected_4, -] + probs_map_3 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_3[56, 58] = 0.7 + probs_map_3[60, 66] = 0.8 + probs_map_3[66, 66] = 0.9 + expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] + TESTS.append( + [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, {"prob_map": probs_map_3}, expected_3] + ) -probs_map_5 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_5 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_5}, []] + probs_map_4 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_4[33, 33] = 0.7 + probs_map_4[66, 66] = 0.9 + expected_4 = [[0.9, 66, 66]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, {"prob_map": probs_map_4}, expected_4]) -probs_map_6 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -TEST_CASES_2D_6 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_6}, []] + probs_map_5 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_5}, []]) -probs_map_7 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -probs_map_7[33, 33] = 0.7 -probs_map_7[66, 66] = 0.9 -if torch.cuda.is_available(): - probs_map_7 = probs_map_7.cuda() -expected_7 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_7 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, - {"prob_map": probs_map_7}, - expected_7, -] + probs_map_6 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_6[33, 33] = 0.7 + probs_map_6[66, 66] = 0.9 + expected_6 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_6}, expected_6]) -probs_map_3d = torch.rand([50, 50, 50]).uniform_(0, 0.5) -probs_map_3d[25, 25, 25] = 0.7 -probs_map_3d[45, 45, 45] = 0.9 -expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] -TEST_CASES_3D = [ - {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, - {"prob_map": probs_map_3d}, - expected_3d, -] + probs_map_3d = p(torch.rand([50, 50, 50]).uniform_(0, 0.5)) + probs_map_3d[25, 25, 25] = 0.7 + probs_map_3d[45, 45, 45] = 0.9 + expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] + TESTS.append( + [{"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, {"prob_map": probs_map_3d}, expected_3d] + ) class TestProbNMS(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASES_2D_1, - TEST_CASES_2D_2, - TEST_CASES_2D_3, - TEST_CASES_2D_4, - TEST_CASES_2D_5, - TEST_CASES_2D_6, - TEST_CASES_2D_7, - TEST_CASES_3D, - ] - ) + @parameterized.expand(TESTS) def test_output(self, class_args, probs_map, expected): nms = ProbNMSD(keys="prob_map", **class_args) output = nms(probs_map) diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py new file mode 100644 index 0000000000..68abb9571f --- /dev/null +++ b/tests/test_pytorch_version_after.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from parameterized import parameterized + +from monai.utils import pytorch_after + +TEST_CASES = ( + (1, 5, 9, "1.6.0"), + (1, 6, 0, "1.6.0"), + (1, 6, 1, "1.6.0", False), + (1, 7, 0, "1.6.0", False), + (2, 6, 0, "1.6.0", False), + (0, 6, 0, "1.6.0a0+3fd9dcf"), + (1, 5, 9, "1.6.0a0+3fd9dcf"), + (1, 6, 0, "1.6.0a0+3fd9dcf", False), + (1, 6, 1, "1.6.0a0+3fd9dcf", False), + (2, 6, 0, "1.6.0a0+3fd9dcf", False), + (1, 6, 0, "1.6.0-rc0+3fd9dcf", False), # defaults to prerelease + (1, 6, 0, "1.6.0rc0", False), + (1, 6, 0, "1.6", True), + (1, 6, 0, "1", False), + (1, 6, 0, "1.6.0+cpu", True), + (1, 6, 1, "1.6.0+cpu", False), +) + + +class TestPytorchVersionCompare(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_compare(self, a, b, p, current, expected=True): + """Test pytorch_after with a and b""" + self.assertEqual(pytorch_after(a, b, p, current), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_query_memory.py b/tests/test_query_memory.py index 22c29598fc..cdb44d3eb1 100644 --- a/tests/test_query_memory.py +++ b/tests/test_query_memory.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_rand_adjust_contrast.py b/tests/test_rand_adjust_contrast.py index d7d750957d..eaeff70d51 100644 --- a/tests/test_rand_adjust_contrast.py +++ b/tests/test_rand_adjust_contrast.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandAdjustContrast -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_1 = [(0.5, 4.5)] @@ -26,14 +26,16 @@ class TestRandAdjustContrast(NumpyImageTestCase2D): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_correct_results(self, gamma): adjuster = RandAdjustContrast(prob=1.0, gamma=gamma) - result = adjuster(self.imt) - epsilon = 1e-7 - img_min = self.imt.min() - img_range = self.imt.max() - img_min - expected = ( - np.power(((self.imt - img_min) / float(img_range + epsilon)), adjuster.gamma_value) * img_range + img_min - ) - np.testing.assert_allclose(expected, result, rtol=1e-05) + for p in TEST_NDARRAYS: + result = adjuster(p(self.imt)) + epsilon = 1e-7 + img_min = self.imt.min() + img_range = self.imt.max() - img_min + expected = ( + np.power(((self.imt - img_min) / float(img_range + epsilon)), adjuster.gamma_value) * img_range + + img_min + ) + assert_allclose(expected, result, rtol=1e-05, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_adjust_contrastd.py b/tests/test_rand_adjust_contrastd.py index e4b61293bb..e5f1f6099a 100644 --- a/tests/test_rand_adjust_contrastd.py +++ b/tests/test_rand_adjust_contrastd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandAdjustContrastd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_1 = [(0.5, 4.5)] @@ -26,14 +26,16 @@ class TestRandAdjustContrastd(NumpyImageTestCase2D): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_correct_results(self, gamma): adjuster = RandAdjustContrastd("img", prob=1.0, gamma=gamma) - result = adjuster({"img": self.imt}) - epsilon = 1e-7 - img_min = self.imt.min() - img_range = self.imt.max() - img_min - expected = ( - np.power(((self.imt - img_min) / float(img_range + epsilon)), adjuster.gamma_value) * img_range + img_min - ) - np.testing.assert_allclose(expected, result["img"], rtol=1e-05) + for p in TEST_NDARRAYS: + result = adjuster({"img": p(self.imt)}) + epsilon = 1e-7 + img_min = self.imt.min() + img_range = self.imt.max() - img_min + expected = ( + np.power(((self.imt - img_min) / float(img_range + epsilon)), adjuster.adjuster.gamma_value) * img_range + + img_min + ) + assert_allclose(expected, result["img"], rtol=1e-05, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py index 1e1a23bc09..8de408ab84 100644 --- a/tests/test_rand_affine.py +++ b/tests/test_rand_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,114 +16,130 @@ from parameterized import parameterized from monai.transforms import RandAffine +from monai.utils.type_conversion import convert_data_type +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=-1), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3)), "spatial_size": (2, 2)}, - np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]]), - ], - [ - dict(as_tensor_output=True, device=None), - {"img": torch.ones((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), cache_grid=True), - {"img": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - cache_grid=True, - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "spatial_size": (3, 3)}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - spatial_size=(3, 3), - cache_grid=True, - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], -] +_rtol = 1e-3 if is_tf32_env() else 1e-4 + +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [dict(device=device), {"img": p(torch.arange(27).reshape((3, 3, 3)))}, p(np.arange(27).reshape((3, 3, 3)))] + ) + TESTS.append( + [ + dict(device=device, spatial_size=-1), + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]])), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.ones((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), cache_grid=True), + {"img": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + cache_grid=True, + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "spatial_size": (3, 3)}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) -ARR_NUMPY = np.arange(9 * 10).reshape(1, 9, 10) -ARR_TORCH = torch.Tensor(ARR_NUMPY) TEST_CASES_SKIPPED_CONSISTENCY = [] -for im in (ARR_NUMPY, ARR_TORCH): - for as_tensor_output in (True, False): - for in_dtype_is_int in (True, False): - TEST_CASES_SKIPPED_CONSISTENCY.append((im, as_tensor_output, in_dtype_is_int)) +for p in TEST_NDARRAYS: + for in_dtype in (np.int32, np.float32): + TEST_CASES_SKIPPED_CONSISTENCY.append((p(np.arange(9 * 10).reshape(1, 9, 10)), in_dtype)) class TestRandAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine(self, input_param, input_data, expected_val): g = RandAffine(**input_param) g.set_random_state(123) result = g(**input_data) if input_param.get("cache_grid", False): self.assertTrue(g._cached_grid is not None) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4) def test_ill_cache(self): with self.assertWarns(UserWarning): @@ -132,15 +148,11 @@ def test_ill_cache(self): RandAffine(cache_grid=True, spatial_size=(1, 1, -1)) @parameterized.expand(TEST_CASES_SKIPPED_CONSISTENCY) - def test_skipped_transform_consistency(self, im, as_tensor_output, in_dtype_is_int): - t1 = RandAffine(prob=0, as_tensor_output=as_tensor_output) - t2 = RandAffine(prob=1, spatial_size=(10, 11), as_tensor_output=as_tensor_output) + def test_skipped_transform_consistency(self, im, in_dtype): + t1 = RandAffine(prob=0) + t2 = RandAffine(prob=1, spatial_size=(10, 11)) - # change dtype to int32 or float32 - if in_dtype_is_int: - im = im.astype("int32") if isinstance(im, np.ndarray) else im.int() - else: - im = im.astype("float32") if isinstance(im, np.ndarray) else im.float() + im, *_ = convert_data_type(im, dtype=in_dtype) out1 = t1(im) out2 = t2(im) diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 605d0a30ba..60ac40f468 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,182 +16,194 @@ from parameterized import parameterized from monai.transforms import RandAffineGrid +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, - {"grid": torch.arange(0, 27).reshape((3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [-32.81998, -33.910976, -35.001972], - [-36.092968, -37.183964, -38.27496], - [-39.36596, -40.456955, -41.54795], - ], - [[2.1380205, 3.1015975, 4.0651755], [5.028752, 5.9923296, 6.955907], [7.919484, 8.883063, 9.84664]], - [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], - ] - ) - ), - ], - [ - {"translate_range": (3, 3, 3), "as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (3, 3, 3)}, - np.array( +_rtol = 1e-1 if is_tf32_env else 1e-4 + +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( [ - [ - [ - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - ], - [ - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - ], - [ - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - ], - ], - [ - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - ], - [ - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - ], - [ - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ], - ] - ), - ], - [ - {"rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, - {"grid": torch.arange(0, 108).reshape((4, 3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [ - [-9.4201e00, -8.1672e00, -6.9143e00], - [-5.6614e00, -4.4085e00, -3.1556e00], - [-1.9027e00, -6.4980e-01, 6.0310e-01], - ], - [ - [1.8560e00, 3.1089e00, 4.3618e00], - [5.6147e00, 6.8676e00, 8.1205e00], - [9.3734e00, 1.0626e01, 1.1879e01], - ], + {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, + {"grid": p(torch.arange(0, 27).reshape((3, 3, 3)))}, + p( + np.array( [ - [1.3132e01, 1.4385e01, 1.5638e01], - [1.6891e01, 1.8144e01, 1.9397e01], - [2.0650e01, 2.1902e01, 2.3155e01], - ], - ], - [ - [ - [9.9383e-02, -4.8845e-01, -1.0763e00], - [-1.6641e00, -2.2519e00, -2.8398e00], - [-3.4276e00, -4.0154e00, -4.6032e00], - ], - [ - [-5.1911e00, -5.7789e00, -6.3667e00], - [-6.9546e00, -7.5424e00, -8.1302e00], - [-8.7180e00, -9.3059e00, -9.8937e00], - ], - [ - [-1.0482e01, -1.1069e01, -1.1657e01], - [-1.2245e01, -1.2833e01, -1.3421e01], - [-1.4009e01, -1.4596e01, -1.5184e01], - ], - ], + [ + [-32.81998, -33.910976, -35.001972], + [-36.092968, -37.183964, -38.27496], + [-39.36596, -40.456955, -41.54795], + ], + [ + [2.1380205, 3.1015975, 4.0651755], + [5.028752, 5.9923296, 6.955907], + [7.919484, 8.883063, 9.84664], + ], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], + ] + ) + ), + ] + ) + TESTS.append( + [ + {"translate_range": (3, 3, 3), "device": device}, + {"spatial_size": (3, 3, 3)}, + np.array( [ [ - [5.9635e01, 6.1199e01, 6.2764e01], - [6.4328e01, 6.5892e01, 6.7456e01], - [6.9021e01, 7.0585e01, 7.2149e01], - ], - [ - [7.3714e01, 7.5278e01, 7.6842e01], - [7.8407e01, 7.9971e01, 8.1535e01], - [8.3099e01, 8.4664e01, 8.6228e01], + [ + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + ], + [ + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + ], + [ + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + ], ], [ - [8.7792e01, 8.9357e01, 9.0921e01], - [9.2485e01, 9.4049e01, 9.5614e01], - [9.7178e01, 9.8742e01, 1.0031e02], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], ], - ], - [ [ - [8.1000e01, 8.2000e01, 8.3000e01], - [8.4000e01, 8.5000e01, 8.6000e01], - [8.7000e01, 8.8000e01, 8.9000e01], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], ], [ - [9.0000e01, 9.1000e01, 9.2000e01], - [9.3000e01, 9.4000e01, 9.5000e01], - [9.6000e01, 9.7000e01, 9.8000e01], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], ], + ] + ), + ] + ) + TESTS.append( + [ + {"device": device, "rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, + {"grid": p(torch.arange(0, 108).reshape((4, 3, 3, 3)))}, + p( + np.array( [ - [9.9000e01, 1.0000e02, 1.0100e02], - [1.0200e02, 1.0300e02, 1.0400e02], - [1.0500e02, 1.0600e02, 1.0700e02], - ], - ], - ] - ) - ), - ], -] + [ + [ + [-9.4201e00, -8.1672e00, -6.9143e00], + [-5.6614e00, -4.4085e00, -3.1556e00], + [-1.9027e00, -6.4980e-01, 6.0310e-01], + ], + [ + [1.8560e00, 3.1089e00, 4.3618e00], + [5.6147e00, 6.8676e00, 8.1205e00], + [9.3734e00, 1.0626e01, 1.1879e01], + ], + [ + [1.3132e01, 1.4385e01, 1.5638e01], + [1.6891e01, 1.8144e01, 1.9397e01], + [2.0650e01, 2.1902e01, 2.3155e01], + ], + ], + [ + [ + [9.9383e-02, -4.8845e-01, -1.0763e00], + [-1.6641e00, -2.2519e00, -2.8398e00], + [-3.4276e00, -4.0154e00, -4.6032e00], + ], + [ + [-5.1911e00, -5.7789e00, -6.3667e00], + [-6.9546e00, -7.5424e00, -8.1302e00], + [-8.7180e00, -9.3059e00, -9.8937e00], + ], + [ + [-1.0482e01, -1.1069e01, -1.1657e01], + [-1.2245e01, -1.2833e01, -1.3421e01], + [-1.4009e01, -1.4596e01, -1.5184e01], + ], + ], + [ + [ + [5.9635e01, 6.1199e01, 6.2764e01], + [6.4328e01, 6.5892e01, 6.7456e01], + [6.9021e01, 7.0585e01, 7.2149e01], + ], + [ + [7.3714e01, 7.5278e01, 7.6842e01], + [7.8407e01, 7.9971e01, 8.1535e01], + [8.3099e01, 8.4664e01, 8.6228e01], + ], + [ + [8.7792e01, 8.9357e01, 9.0921e01], + [9.2485e01, 9.4049e01, 9.5614e01], + [9.7178e01, 9.8742e01, 1.0031e02], + ], + ], + [ + [ + [8.1000e01, 8.2000e01, 8.3000e01], + [8.4000e01, 8.5000e01, 8.6000e01], + [8.7000e01, 8.8000e01, 8.9000e01], + ], + [ + [9.0000e01, 9.1000e01, 9.2000e01], + [9.3000e01, 9.4000e01, 9.5000e01], + [9.6000e01, 9.7000e01, 9.8000e01], + ], + [ + [9.9000e01, 1.0000e02, 1.0100e02], + [1.0200e02, 1.0300e02, 1.0400e02], + [1.0500e02, 1.0600e02, 1.0700e02], + ], + ], + ] + ) + ), + ] + ) class TestRandAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine_grid(self, input_param, input_data, expected_val): g = RandAffineGrid(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data[device]) + assert_allclose(result, expected_val, type_test=False, rtol=_rtol, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index d2f8a60665..882b5554e6 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,179 +17,190 @@ from monai.transforms import RandAffined from monai.utils import GridSampleMode +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None, spatial_size=None, keys=("img", "seg")), - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), keys=("img", "seg")), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=False, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - cache_grid=True, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - np.array([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - spatial_size=(3, 3), - keys=("img", "seg"), - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - mode=("bilinear", "nearest"), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode=GridSampleMode.BILINEAR, - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - cache_grid=True, - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], -] +_rtol = 1e-3 if is_tf32_env() else 1e-4 + +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + dict(device=device, spatial_size=None, keys=("img", "seg")), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), keys=("img", "seg")), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode="bilinear", + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=("bilinear", "nearest"), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode=GridSampleMode.BILINEAR, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) class TestRandAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affined(self, input_param, input_data, expected_val): g = RandAffined(**input_param).set_random_state(123) res = g(input_data) @@ -200,28 +211,20 @@ def test_rand_affined(self, input_param, input_data, expected_val): if "_transforms" in key: continue expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=_rtol, atol=1e-3) + + g.set_random_state(4) + res = g(input_data) + # affine should be tensor because the resampler only supports pytorch backend + self.assertTrue(isinstance(res["img_transforms"][0]["extra_info"]["affine"], torch.Tensor)) def test_ill_cache(self): with self.assertWarns(UserWarning): # spatial size is None - RandAffined( - as_tensor_output=False, device=None, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg") - ) + RandAffined(device=device, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg")) with self.assertWarns(UserWarning): # spatial size is dynamic - RandAffined( - as_tensor_output=False, - device=None, - spatial_size=(2, -1), - prob=1.0, - cache_grid=True, - keys=("img", "seg"), - ) + RandAffined(device=device, spatial_size=(2, -1), prob=1.0, cache_grid=True, keys=("img", "seg")) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index c05c3a1e0d..b7c504557f 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,10 +22,8 @@ def test_correct_results(self): for p in TEST_NDARRAYS: flip = RandAxisFlip(prob=1.0) result = flip(p(self.imt[0])) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, flip._axis)) - assert_allclose(np.stack(expected), result) + expected = [np.flip(channel, flip._axis) for channel in self.imt[0]] + assert_allclose(result, p(np.stack(expected))) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index 7bef0baa63..ff97d5dc1e 100644 --- a/tests/test_rand_axis_flipd.py +++ b/tests/test_rand_axis_flipd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,10 +23,8 @@ def test_correct_results(self): flip = RandAxisFlipd(keys="img", prob=1.0) result = flip({"img": p(self.imt[0])})["img"] - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, flip._axis)) - assert_allclose(np.stack(expected), result) + expected = [np.flip(channel, flip.flipper._axis) for channel in self.imt[0]] + assert_allclose(result, p(np.stack(expected))) if __name__ == "__main__": diff --git a/tests/test_random_bias_field.py b/tests/test_rand_bias_field.py similarity index 65% rename from tests/test_random_bias_field.py rename to tests/test_rand_bias_field.py index 5aeeb79874..b3aa8e9174 100644 --- a/tests/test_random_bias_field.py +++ b/tests/test_rand_bias_field.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,29 +12,34 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandBiasField -TEST_CASES_2D = [{}, (3, 32, 32)] -TEST_CASES_3D = [{}, (3, 32, 32, 32)] -TEST_CASES_2D_ZERO_RANGE = [{"coeff_range": (0.0, 0.0)}, (2, 3, 3)] -TEST_CASES_2D_ONES = [{"coeff_range": (1.0, 1.0)}, np.asarray([[[7.389056, 0.1353353], [7.389056, 22026.46]]])] +TEST_CASES_2D = [{"prob": 1.0}, (3, 32, 32)] +TEST_CASES_3D = [{"prob": 1.0}, (3, 32, 32, 32)] +TEST_CASES_2D_ZERO_RANGE = [{"prob": 1.0, "coeff_range": (0.0, 0.0)}, (2, 3, 3)] +TEST_CASES_2D_ONES = [ + {"prob": 1.0, "coeff_range": (1.0, 1.0)}, + np.asarray([[[7.389056, 0.1353353], [7.389056, 22026.46]]]), +] class TestRandBiasField(unittest.TestCase): @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D]) def test_output_shape(self, class_args, img_shape): - for degree in [1, 2, 3]: - bias_field = RandBiasField(degree=degree, **class_args) - img = np.random.rand(*img_shape) - output = bias_field(img) - np.testing.assert_equal(output.shape, img_shape) - np.testing.assert_equal(output.dtype, bias_field.dtype) - - img_zero = np.zeros([*img_shape]) - output_zero = bias_field(img_zero) - np.testing.assert_equal(output_zero, img_zero) + for fn in (np.random, torch): + for degree in [1, 2, 3]: + bias_field = RandBiasField(degree=degree, **class_args) + img = fn.rand(*img_shape) + output = bias_field(img) + np.testing.assert_equal(output.shape, img_shape) + self.assertTrue(output.dtype in (np.float32, torch.float32)) + + img_zero = np.zeros([*img_shape]) + output_zero = bias_field(img_zero) + np.testing.assert_equal(output_zero, img_zero) @parameterized.expand([TEST_CASES_2D_ZERO_RANGE]) def test_zero_range(self, class_args, img_shape): diff --git a/tests/test_random_bias_fieldd.py b/tests/test_rand_bias_fieldd.py similarity index 89% rename from tests/test_random_bias_fieldd.py rename to tests/test_rand_bias_fieldd.py index aa2e206de9..da08cfe053 100644 --- a/tests/test_random_bias_fieldd.py +++ b/tests/test_rand_bias_fieldd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,11 +16,11 @@ from monai.transforms import RandBiasFieldd -TEST_CASES_2D = [{}, (3, 32, 32)] -TEST_CASES_3D = [{}, (3, 32, 32, 32)] -TEST_CASES_2D_ZERO_RANGE = [{"coeff_range": (0.0, 0.0)}, (3, 32, 32)] +TEST_CASES_2D = [{"prob": 1.0}, (3, 32, 32)] +TEST_CASES_3D = [{"prob": 1.0}, (3, 32, 32, 32)] +TEST_CASES_2D_ZERO_RANGE = [{"prob": 1.0, "coeff_range": (0.0, 0.0)}, (3, 32, 32)] TEST_CASES_2D_ONES = [ - {"coeff_range": (1.0, 1.0)}, + {"prob": 1.0, "coeff_range": (1.0, 1.0)}, np.asarray([[[7.3890562e00, 1.3533528e-01], [7.3890562e00, 2.2026465e04]]]), ] diff --git a/tests/test_rand_coarse_dropout.py b/tests/test_rand_coarse_dropout.py index 830832c2a5..a05d323277 100644 --- a/tests/test_rand_coarse_dropout.py +++ b/tests/test_rand_coarse_dropout.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandCoarseDropout @@ -52,12 +53,20 @@ np.random.randint(0, 2, size=[3, 3, 3, 4]), ] +TEST_CASE_7 = [ + {"holes": 2, "spatial_size": [2, 2, 2], "dropout_holes": False, "fill_value": (3, 6), "prob": 1.0}, + torch.randint(0, 2, size=[3, 3, 3, 4]), +] + class TestRandCoarseDropout(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand( + [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] + ) def test_value(self, input_param, input_data): dropout = RandCoarseDropout(**input_param) result = dropout(input_data) + self.assertEqual(type(result), type(input_data)) holes = input_param.get("holes") max_holes = input_param.get("max_holes") spatial_size = fall_back_tuple(input_param.get("spatial_size"), input_data.shape[1:]) diff --git a/tests/test_rand_coarse_dropoutd.py b/tests/test_rand_coarse_dropoutd.py index fc898a9fca..e54db130a5 100644 --- a/tests/test_rand_coarse_dropoutd.py +++ b/tests/test_rand_coarse_dropoutd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,14 +28,7 @@ ] TEST_CASE_2 = [ - { - "keys": "img", - "holes": 2, - "spatial_size": [2, 2, 2], - "fill_value": 5, - "max_spatial_size": [4, 4, 3], - "prob": 1.0, - }, + {"keys": "img", "holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "max_spatial_size": [4, 4, 3], "prob": 1.0}, {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, ] diff --git a/tests/test_rand_coarse_shuffle.py b/tests/test_rand_coarse_shuffle.py new file mode 100644 index 0000000000..fb7311e5a3 --- /dev/null +++ b/tests/test_rand_coarse_shuffle.py @@ -0,0 +1,62 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandCoarseShuffle + +TEST_CASES = [ + [ + {"holes": 5, "spatial_size": 1, "max_spatial_size": -1, "prob": 0.0}, + {"img": np.arange(8).reshape((1, 2, 2, 2))}, + np.arange(8).reshape((1, 2, 2, 2)), + ], + [ + {"holes": 10, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(27).reshape((1, 3, 3, 3))}, + np.asarray( + [ + [ + [[8, 19, 26], [24, 6, 15], [0, 13, 25]], + [[17, 3, 5], [10, 1, 12], [22, 4, 11]], + [[21, 20, 23], [14, 2, 16], [18, 9, 7]], + ] + ] + ), + ], + [ + {"holes": 2, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(16).reshape((2, 2, 2, 2))}, + np.asarray([[[[6, 1], [4, 3]], [[0, 2], [7, 5]]], [[[14, 10], [9, 8]], [[12, 15], [13, 11]]]]), + ], + [ + {"holes": 2, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": torch.arange(16).reshape((2, 2, 2, 2))}, + torch.as_tensor([[[[6, 1], [4, 3]], [[0, 2], [7, 5]]], [[[14, 10], [9, 8]], [[12, 15], [13, 11]]]]), + ], +] + + +class TestRandCoarseShuffle(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shuffle(self, input_param, input_data, expected_val): + g = RandCoarseShuffle(**input_param) + g.set_random_state(seed=12) + result = g(**input_data) + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_coarse_shuffled.py b/tests/test_rand_coarse_shuffled.py new file mode 100644 index 0000000000..fa9c17286d --- /dev/null +++ b/tests/test_rand_coarse_shuffled.py @@ -0,0 +1,56 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandCoarseShuffled + +TEST_CASES = [ + [ + {"keys": "img", "holes": 5, "spatial_size": 1, "max_spatial_size": -1, "prob": 0.0}, + {"img": np.arange(8).reshape((1, 2, 2, 2))}, + np.arange(8).reshape((1, 2, 2, 2)), + ], + [ + {"keys": "img", "holes": 10, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(27).reshape((1, 3, 3, 3))}, + np.asarray( + [ + [ + [[8, 19, 26], [24, 6, 15], [0, 13, 25]], + [[17, 3, 5], [10, 1, 12], [22, 4, 11]], + [[21, 20, 23], [14, 2, 16], [18, 9, 7]], + ] + ] + ), + ], + [ + {"keys": "img", "holes": 2, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(16).reshape((2, 2, 2, 2))}, + np.asarray([[[[6, 1], [4, 3]], [[0, 2], [7, 5]]], [[[14, 10], [9, 8]], [[12, 15], [13, 11]]]]), + ], +] + + +class TestRandCoarseShuffled(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shuffle(self, input_param, input_data, expected_val): + g = RandCoarseShuffled(**input_param) + g.set_random_state(seed=12) + result = g(input_data) + np.testing.assert_allclose(result["img"], expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index b21f971042..11d73df74e 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,68 +15,121 @@ from parameterized import parameterized from monai.transforms import ClassesToIndices, RandCropByLabelClasses +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ +TESTS_INDICES, TESTS_SHAPE = [], [] +for p in TEST_NDARRAYS: # One-Hot label - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "num_classes": None, - "spatial_size": [2, 2, -1], - "ratios": [1, 1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 3), -] + TESTS_INDICES.append( + [ + { + "label": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + list, + (3, 2, 2, 3), + ] + ) -TEST_CASE_1 = [ - # Argmax label - { - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 2), -] + TESTS_INDICES.append( + [ + # Argmax label + { + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + list, + (3, 2, 2, 2), + ] + ) -TEST_CASE_2 = [ - # provide label at runtime - { - "label": None, - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] + TESTS_SHAPE.append( + [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 2, 2, 2), + ] + ) + TESTS_SHAPE.append( + [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [4, 4, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 3, 3, 2), + ] + ) + TESTS_SHAPE.append( + [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [4, 4, 4], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 3, 3, 3), + ] + ) class TestRandCropByLabelClasses(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS_INDICES + TESTS_SHAPE) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByLabelClasses(**input_param)(**input_data) self.assertIsInstance(result, expected_type) self.assertTupleEqual(result[0].shape, expected_shape) - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TESTS_INDICES) def test_indices(self, input_param, input_data, expected_type, expected_shape): input_param["indices"] = ClassesToIndices(num_classes=input_param["num_classes"])(input_param["label"]) result = RandCropByLabelClasses(**input_param)(**input_data) diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 829096953b..92780458e0 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,52 +15,107 @@ from parameterized import parameterized from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - # One-Hot label - { - "keys": "img", - "label_key": "label", - "num_classes": None, - "spatial_size": [2, 2, -1], - "ratios": [1, 1, 1], - "num_samples": 2, - "image_key": "image", - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 3), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # One-Hot label + { + "keys": "img", + "label_key": "label", + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 2, 2, 3), + ] + ) -TEST_CASE_1 = [ - # Argmax label - { - "keys": "img", - "label_key": "label", - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image_key": "image", - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] + TESTS.append( + [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + }, + list, + (3, 2, 2, 2), + ] + ) + + TESTS.append( + [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [4, 4, 2], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + }, + list, + (3, 3, 3, 2), + ] + ) + + TESTS.append( + [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [4, 4, 4], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + }, + list, + (3, 3, 3, 3), + ] + ) class TestRandCropByLabelClassesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TESTS) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByLabelClassesd(**input_param)(input_data) self.assertIsInstance(result, expected_type) diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index e0f669ab3f..f8b8a77a45 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,68 +10,123 @@ # limitations under the License. import unittest +from copy import deepcopy import numpy as np from parameterized import parameterized from monai.transforms import RandCropByPosNegLabel +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "spatial_size": [2, 2, -1], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 3), -] +TESTS = [] +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [2, 2, -1], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 2, 2, 3), + ] +) +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 2, 2, 2), + ] +) +TESTS.append( + [ + { + "label": None, + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + { + "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + }, + (3, 2, 2, 2), + ] +) +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [4, 4, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "allow_smaller": True, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 3, 3, 2), + ] +) +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [4, 4, 4], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "allow_smaller": True, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 3, 3, 3), + ] +) -TEST_CASE_1 = [ - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 2), -] -TEST_CASE_2 = [ - { - "label": None, - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] +class TestRandCropByPosNegLabel(unittest.TestCase): + @staticmethod + def convert_data_type(im_type, d, keys=("img", "image", "label")): + out = deepcopy(d) + for k, v in out.items(): + if k in keys and isinstance(v, np.ndarray): + out[k] = im_type(v) + return out + @parameterized.expand(TESTS) + def test_type_shape(self, input_param, input_data, expected_shape): + results = [] + for p in TEST_NDARRAYS: + input_param_mod = self.convert_data_type(p, input_param) + input_data_mod = self.convert_data_type(p, input_data) + cropper = RandCropByPosNegLabel(**input_param_mod) + cropper.set_random_state(0) + result = cropper(**input_data_mod) -class TestRandCropByPosNegLabel(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_type_shape(self, input_param, input_data, expected_type, expected_shape): - result = RandCropByPosNegLabel(**input_param)(**input_data) - self.assertIsInstance(result, expected_type) - self.assertTupleEqual(result[0].shape, expected_shape) + self.assertIsInstance(result, list) + self.assertTupleEqual(result[0].shape, expected_shape) + + # check for same results across numpy, torch.Tensor and torch.cuda.Tensor + result = np.asarray([i if isinstance(i, np.ndarray) else i.cpu().numpy() for i in result]) + results.append(np.asarray(result)) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 17a3e117bb..bff5c0e7fd 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,90 +10,141 @@ # limitations under the License. import unittest +from copy import deepcopy import numpy as np from parameterized import parameterized from monai.transforms import RandCropByPosNegLabeld +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - { - "keys": ["image", "extra", "label"], - "label_key": "label", - "spatial_size": [-1, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image_key": None, - "image_threshold": 0, - }, - { - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, - }, - list, - (3, 3, 2, 2), +TESTS = [ + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [-1, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + }, + { + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 3, 2, 2), + ], + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + }, + { + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 2, 2, 2), + ], + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + }, + { + "image": np.zeros([3, 3, 3, 3]) - 1, + "extra": np.zeros([3, 3, 3, 3]), + "label": np.ones([3, 3, 3, 3]), + "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 2, 2, 2), + ], + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [4, 4, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + "allow_smaller": True, + }, + { + "image": np.zeros([3, 3, 3, 3]) - 1, + "extra": np.zeros([3, 3, 3, 3]), + "label": np.ones([3, 3, 3, 3]), + "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 3, 3, 2), + ], + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [4, 4, 4], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + "allow_smaller": True, + }, + { + "image": np.zeros([3, 3, 3, 3]) - 1, + "extra": np.zeros([3, 3, 3, 3]), + "label": np.ones([3, 3, 3, 3]), + "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 3, 3, 3), + ], ] -TEST_CASE_1 = [ - { - "keys": ["image", "extra", "label"], - "label_key": "label", - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image_key": None, - "image_threshold": 0, - }, - { - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, - }, - list, - (3, 2, 2, 2), -] -TEST_CASE_2 = [ - { - "keys": ["image", "extra", "label"], - "label_key": "label", - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image_key": None, - "image_threshold": 0, - }, - { - "image": np.zeros([3, 3, 3, 3]) - 1, - "extra": np.zeros([3, 3, 3, 3]), - "label": np.ones([3, 3, 3, 3]), - "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, - }, - list, - (3, 2, 2, 2), -] +class TestRandCropByPosNegLabeld(unittest.TestCase): + @staticmethod + def convert_data_type(im_type, d, keys=("img", "image", "label")): + out = deepcopy(d) + for k, v in out.items(): + if k in keys and isinstance(v, np.ndarray): + out[k] = im_type(v) + return out + @parameterized.expand(TESTS) + def test_type_shape(self, input_param, input_data, expected_shape): + for p in TEST_NDARRAYS: + input_param_mod = self.convert_data_type(p, input_param) + input_data_mod = self.convert_data_type(p, input_data) + cropper = RandCropByPosNegLabeld(**input_param_mod) + cropper.set_random_state(0) + result = cropper(input_data_mod) -class TestRandCropByPosNegLabeld(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_type_shape(self, input_param, input_data, expected_type, expected_shape): - result = RandCropByPosNegLabeld(**input_param)(input_data) - self.assertIsInstance(result, expected_type) - self.assertTupleEqual(result[0]["image"].shape, expected_shape) - self.assertTupleEqual(result[0]["extra"].shape, expected_shape) - self.assertTupleEqual(result[0]["label"].shape, expected_shape) - _len = len(tuple(input_data.keys())) - self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) - for i, item in enumerate(result): - self.assertEqual(item["image_meta_dict"]["patch_index"], i) - self.assertEqual(item["label_meta_dict"]["patch_index"], i) - self.assertEqual(item["extra_meta_dict"]["patch_index"], i) + self.assertIsInstance(result, list) + + _len = len(tuple(input_data.keys())) + self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) + for k in ("image", "extra", "label"): + self.assertTupleEqual(result[0][k].shape, expected_shape) + for i, item in enumerate(result): + self.assertEqual(item[k + "_meta_dict"]["patch_index"], i) if __name__ == "__main__": diff --git a/tests/test_rand_cucim_dict_transform.py b/tests/test_rand_cucim_dict_transform.py new file mode 100644 index 0000000000..cd41b7f49a --- /dev/null +++ b/tests/test_rand_cucim_dict_transform.py @@ -0,0 +1,185 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandCuCIMd +from monai.utils import optional_import, set_determinism +from tests.utils import HAS_CUPY, skip_if_no_cuda + +_, has_cut = optional_import("cucim.core.operations.expose.transform") +cp, _ = optional_import("cupy") + +set_determinism(seed=0) + +TEST_CASE_COLOR_JITTER_1 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8), + np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8), +] + +TEST_CASE_FLIP_1 = [ + {"name": "image_flip", "spatial_axis": -1}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_RAND_ROTATE_1 = [ + {"name": "rand_image_rotate_90", "prob": 1.0, "max_k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32), +] + + +TEST_CASE_RAND_ROTATE_2 = [ + {"name": "rand_image_rotate_90", "prob": 0.0, "max_k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), +] + +TEST_CASE_SCALE_INTENSITY_1 = [ + {"name": "scale_intensity_range", "a_min": 0.0, "a_max": 4.0, "b_min": 0.0, "b_max": 1.0, "clip": False}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32), +] + +TEST_CASE_ZOOM_1 = [ + {"name": "zoom", "zoom_factor": (0.5, 0.5)}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + +TEST_CASE_RAND_ZOOM_1 = [ + {"name": "rand_zoom", "prob": 1.0, "min_zoom": 0.5, "max_zoom": 0.5}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + +TEST_CASE_RAND_ZOOM_2 = [ + {"name": "rand_zoom", "prob": 0.0, "min_zoom": 0.5, "max_zoom": 0.5}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.mgrid[:3, 1:4].astype(dtype=np.float32), +] + + +@skip_if_no_cuda +@unittest.skipUnless(HAS_CUPY, "CuPy is required.") +@unittest.skipUnless(has_cut, "cuCIM transforms are required.") +class TestRandCuCIMDict(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_numpy_single(self, params, input, expected): + input = {"image": input} + # apply_prob=1.0 + output = RandCuCIMd(keys="image", apply_prob=1.0, **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIMd(keys="image", apply_prob=0.0, **params)(input)["image"] + self.assertTrue(output.dtype == input["image"].dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, input["image"]) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_numpy_batch(self, params, input, expected): + input = {"image": input[cp.newaxis, ...]} + expected = expected[cp.newaxis, ...] + # apply_prob=1.0 + output = RandCuCIMd(keys="image", apply_prob=1.0, **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIMd(keys="image", apply_prob=0.0, **params)(input)["image"] + self.assertTrue(output.dtype == input["image"].dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, input["image"]) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_cupy_single(self, params, input, expected): + input = {"image": cp.asarray(input)} + expected = cp.asarray(expected) + # apply_prob=1.0 + output = RandCuCIMd(keys="image", apply_prob=1.0, **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIMd(keys="image", apply_prob=0.0, **params)(input)["image"] + self.assertTrue(output.dtype == input["image"].dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, input["image"]) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_cupy_batch(self, params, input, expected): + input = {"image": cp.asarray(input)[cp.newaxis, ...]} + expected = cp.asarray(expected)[cp.newaxis, ...] + # apply_prob=1.0 + output = RandCuCIMd(keys="image", **params)(input)["image"] + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIMd(keys="image", apply_prob=0.0, **params)(input)["image"] + self.assertTrue(output.dtype == input["image"].dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, input["image"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_cucim_transform.py b/tests/test_rand_cucim_transform.py new file mode 100644 index 0000000000..0950329833 --- /dev/null +++ b/tests/test_rand_cucim_transform.py @@ -0,0 +1,184 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandCuCIM +from monai.utils import optional_import, set_determinism +from tests.utils import HAS_CUPY, skip_if_no_cuda + +_, has_cut = optional_import("cucim.core.operations.expose.transform") +cp, _ = optional_import("cupy") + +set_determinism(seed=0) + +TEST_CASE_COLOR_JITTER_1 = [ + {"name": "color_jitter", "brightness": 0.0, "contrast": 0.0, "saturation": 0.0, "hue": 0.0}, + np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8), + np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8), +] + +TEST_CASE_FLIP_1 = [ + {"name": "image_flip", "spatial_axis": -1}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32), +] + +TEST_CASE_RAND_ROTATE_1 = [ + {"name": "rand_image_rotate_90", "prob": 1.0, "max_k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32), +] + + +TEST_CASE_RAND_ROTATE_2 = [ + {"name": "rand_image_rotate_90", "prob": 0.0, "max_k": 1, "spatial_axis": (-2, -1)}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), +] + +TEST_CASE_SCALE_INTENSITY_1 = [ + {"name": "scale_intensity_range", "a_min": 0.0, "a_max": 4.0, "b_min": 0.0, "b_max": 1.0, "clip": False}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32), + np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32), +] + +TEST_CASE_ZOOM_1 = [ + {"name": "zoom", "zoom_factor": (0.5, 0.5)}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + +TEST_CASE_RAND_ZOOM_1 = [ + {"name": "rand_zoom", "prob": 1.0, "min_zoom": 0.5, "max_zoom": 0.5}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]), +] + +TEST_CASE_RAND_ZOOM_2 = [ + {"name": "rand_zoom", "prob": 0.0, "min_zoom": 0.5, "max_zoom": 0.5}, + np.mgrid[:3, 1:4].astype(dtype=np.float32), + np.mgrid[:3, 1:4].astype(dtype=np.float32), +] + + +@skip_if_no_cuda +@unittest.skipUnless(HAS_CUPY, "CuPy is required.") +@unittest.skipUnless(has_cut, "cuCIM transforms are required.") +class TestRandCuCIM(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_numpy_single(self, params, input, expected): + # apply_prob=1.0 + output = RandCuCIM(apply_prob=1.0, **params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIM(apply_prob=0.0, **params)(input) + self.assertTrue(output.dtype == input.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, input) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_numpy_batch(self, params, input, expected): + input = input[cp.newaxis, ...] + expected = expected[cp.newaxis, ...] + # apply_prob=1.0 + output = RandCuCIM(apply_prob=1.0, **params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIM(apply_prob=0.0, **params)(input) + self.assertTrue(output.dtype == input.dtype) + self.assertTrue(isinstance(output, np.ndarray)) + cp.testing.assert_allclose(output, input) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_cupy_single(self, params, input, expected): + input = cp.asarray(input) + expected = cp.asarray(expected) + # apply_prob=1.0 + output = RandCuCIM(apply_prob=1.0, **params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIM(apply_prob=0.0, **params)(input) + self.assertTrue(output.dtype == input.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, input) + + @parameterized.expand( + [ + TEST_CASE_COLOR_JITTER_1, + TEST_CASE_FLIP_1, + TEST_CASE_RAND_ROTATE_1, + TEST_CASE_RAND_ROTATE_2, + TEST_CASE_SCALE_INTENSITY_1, + TEST_CASE_ZOOM_1, + TEST_CASE_RAND_ZOOM_1, + TEST_CASE_RAND_ZOOM_2, + ] + ) + def test_tramsforms_cupy_batch(self, params, input, expected): + input = cp.asarray(input)[cp.newaxis, ...] + expected = cp.asarray(expected)[cp.newaxis, ...] + # apply_prob=1.0 + output = RandCuCIM(**params)(input) + self.assertTrue(output.dtype == expected.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, expected) + # apply_prob=0.0 + output = RandCuCIM(apply_prob=0.0, **params)(input) + self.assertTrue(output.dtype == input.dtype) + self.assertTrue(isinstance(output, cp.ndarray)) + cp.testing.assert_allclose(output, input) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_deform_grid.py b/tests/test_rand_deform_grid.py index 7c12c263d2..8a2c8bf6eb 100644 --- a/tests/test_rand_deform_grid.py +++ b/tests/test_rand_deform_grid.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,10 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import RandDeformGrid +from tests.utils import assert_allclose TEST_CASES = [ [ @@ -129,11 +129,7 @@ def test_rand_deform_grid(self, input_param, input_data, expected_val): g = RandDeformGrid(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, type_test=False, rtol=1e-3, atol=1e-3) if __name__ == "__main__": diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py index fbfb7d5761..bc23a6c5cb 100644 --- a/tests/test_rand_elastic_2d.py +++ b/tests/test_rand_elastic_2d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,90 +16,103 @@ from parameterized import parameterized from monai.transforms import Rand2DElastic +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2)}, - np.ones((3, 2, 2)), - ], - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "padding_mode": "zeros", - }, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2), "mode": "bilinear"}, - np.array( +_rtol = 5e-3 if is_tf32_env() else 1e-4 + +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + { + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "padding_mode": "zeros", + }, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2), "mode": "bilinear"}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), ] - ), - ], - [ - { - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + { + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), ] - ), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), ] - ), - ], -] + ) class TestRand2DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elastic(self, input_param, input_data, expected_val): g = Rand2DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index c63282d571..d4eed3753f 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,69 +16,79 @@ from parameterized import parameterized from monai.transforms import Rand3DElastic +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(72).reshape((2, 3, 3, 4))}, - np.arange(72).reshape((2, 3, 3, 4)), - ], - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.ones((2, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "mode": "bilinear"}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(72).reshape((2, 3, 3, 4)))}, + p(np.arange(72).reshape((2, 3, 3, 4))), + ] + ) + TESTS.append( + [ + {"magnitude_range": (0.3, 2.3), "sigma_range": (1.0, 20.0), "prob": 0.0, "device": device}, + {"img": p(torch.ones((2, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + {"magnitude_range": (0.3, 0.3), "sigma_range": (1.0, 2.0), "prob": 0.9, "device": device}, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "mode": "bilinear"}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) class TestRand3DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elastic(self, input_param, input_data, expected_val): g = Rand3DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index f8eb026088..ead39e5731 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,127 +16,149 @@ from parameterized import parameterized from monai.transforms import Rand2DElasticd +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.3, 0.3), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(4).reshape((1, 2, 2)), "seg": torch.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape((1, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "padding_mode": "zeros", - "device": None, - "spatial_size": (2, 2), - "mode": "bilinear", - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.array( +_rtol = 5e-3 if is_tf32_env() else 1e-4 + +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.3, 0.3), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(4).reshape((1, 2, 2))), "seg": p(torch.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape((1, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "padding_mode": "zeros", + "device": device, + "spatial_size": (2, 2), + "mode": "bilinear", + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), + ] + ) + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + { + "keys": ("img", "seg"), + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + { + "img": p( + torch.tensor( + [ + [[1.3584, 1.9251], [5.6266, 6.6427]], + [[10.3584, 10.9251], [14.6266, 15.6427]], + [[19.3584, 19.9251], [23.6266, 24.6427]], + ] + ) + ), + "seg": p( + torch.tensor( + [[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]] + ) + ), + }, ] - ), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - { - "img": torch.tensor( - [ - [[1.3584, 1.9251], [5.6266, 6.6427]], - [[10.3584, 10.9251], [14.6266, 15.6427]], - [[19.3584, 19.9251], [23.6266, 24.6427]], - ] - ), - "seg": torch.tensor([[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]]), - }, - ], -] + ) class TestRand2DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elasticd(self, input_param, input_data, expected_val): g = Rand2DElasticd(**input_param) g.set_random_state(123) @@ -144,11 +166,7 @@ def test_rand_2d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=_rtol, atol=5e-3) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 47ab814882..84bc765ab0 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,98 +16,128 @@ from parameterized import parameterized from monai.transforms import Rand3DElasticd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, -1, -1), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 3, 3)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(8).reshape((1, 2, 2, 2)), "seg": torch.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - "mode": "bilinear", - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": True, - "device": torch.device("cpu:0"), - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - { - "img": torch.tensor([[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]]), - "seg": torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]]), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, -1, -1), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 3, 3))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(8).reshape((1, 2, 2, 2))), "seg": p(torch.arange(8).reshape((1, 2, 2, 2)))}, + p(np.arange(8).reshape((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + "mode": "bilinear", + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + { + "img": p( + torch.tensor( + [[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]] + ) + ), + "seg": p(torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]])), + }, + ] + ) class TestRand3DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elasticd(self, input_param, input_data, expected_val): g = Rand3DElasticd(**input_param) g.set_random_state(123) @@ -115,11 +145,7 @@ def test_rand_3d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index b3c514cb1f..b9e9a8c4d6 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -34,12 +34,10 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: im = p(self.imt[0]) flip = RandFlip(prob=1.0, spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) result = flip(im) - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index 8972024fd8..9a92661c59 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,11 +26,9 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: flip = RandFlipd(keys="img", prob=1.0, spatial_axis=spatial_axis) result = flip({"img": p(self.imt[0])})["img"] - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(expected, result) + assert_allclose(result, p(expected)) if __name__ == "__main__": diff --git a/tests/test_rand_gaussian_noise.py b/tests/test_rand_gaussian_noise.py index d376add460..1f2adfb9e7 100644 --- a/tests/test_rand_gaussian_noise.py +++ b/tests/test_rand_gaussian_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_rand_gaussian_noised.py b/tests/test_rand_gaussian_noised.py index 4b0d2a311a..be1df0f2e6 100644 --- a/tests/test_rand_gaussian_noised.py +++ b/tests/test_rand_gaussian_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,14 +29,16 @@ class TestRandGaussianNoised(NumpyImageTestCase2D): @parameterized.expand(TESTS) def test_correct_results(self, _, im_type, keys, mean, std): - gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std) + gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64) gaussian_fn.set_random_state(seed) im = im_type(self.imt) noised = gaussian_fn({k: im for k in keys}) np.random.seed(seed) + # simulate the randomize() of transform np.random.random() + noise = np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) for k in keys: - expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) + expected = self.imt + noise self.assertEqual(type(im), type(noised[k])) if isinstance(noised[k], torch.Tensor): noised[k] = noised[k].cpu() diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py index 909f96f56b..06563a35b6 100644 --- a/tests/test_rand_gaussian_sharpen.py +++ b/tests/test_rand_gaussian_sharpen.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,88 +11,127 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import RandGaussianSharpen +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"prob": 1.0}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] + +for p in TEST_NDARRAYS: + TESTS.append( [ - [[5.2919216, 5.5854445, 5.29192], [11.3982, 12.62332, 11.398202], [14.870525, 17.323769, 14.870527]], - [[20.413757, 22.767355, 20.413757], [28.495504, 31.558315, 28.495499], [29.99236, 34.505676, 29.992361]], + {"prob": 1.0}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [5.2919216, 5.5854445, 5.29192], + [11.3982, 12.62332, 11.398202], + [14.870525, 17.323769, 14.870527], + ], + [ + [20.413757, 22.767355, 20.413757], + [28.495504, 31.558315, 28.495499], + [29.99236, 34.505676, 29.992361], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": 0.4, - "sigma2_y": 0.4, - "sigma2_z": 0.4, - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.1071496, 3.597953, 4.1071477], [10.062014, 9.825114, 10.0620165], [14.698058, 15.818766, 14.698058]], - [[18.211048, 18.16049, 18.211048], [25.155039, 24.56279, 25.155039], [28.801964, 30.381308, 28.801964]], + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": 0.4, + "sigma2_y": 0.4, + "sigma2_z": 0.4, + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.1071496, 3.597953, 4.1071477], + [10.062014, 9.825114, 10.0620165], + [14.698058, 15.818766, 14.698058], + ], + [ + [18.211048, 18.16049, 18.211048], + [25.155039, 24.56279, 25.155039], + [28.801964, 30.381308, 28.801964], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.81077, 4.4237204, 4.81077], [12.061236, 12.298177, 12.061236], [17.362553, 19.201174, 17.362553]], - [[21.440754, 22.142393, 21.440754], [30.15308, 30.745445, 30.153086], [33.99255, 36.919838, 33.99255]], + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.81077, 4.4237204, 4.81077], + [12.061236, 12.298177, 12.061236], + [17.362553, 19.201174, 17.362553], + ], + [ + [21.440754, 22.142393, 21.440754], + [30.15308, 30.745445, 30.153086], + [33.99255, 36.919838, 33.99255], + ], + ] + ), ] - ), -] + ) -TEST_CASE_4 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "approx": "scalespace", - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.430213, 3.2278745, 4.4302144], [10.325399, 8.507457, 10.325399], [17.494898, 16.5609, 17.494894]], - [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "approx": "scalespace", + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.430213, 3.2278745, 4.4302144], + [10.325399, 8.507457, 10.325399], + [17.494898, 16.5609, 17.494894], + ], + [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + ] + ), ] - ), -] + ) class TestRandGaussianSharpen(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): converter = RandGaussianSharpen(**argments) converter.set_random_state(seed=0) result = converter(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + assert_allclose(result, expected_data, atol=0, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_gaussian_sharpend.py b/tests/test_rand_gaussian_sharpend.py index 9ba29ee71b..ecffa547e0 100644 --- a/tests/test_rand_gaussian_sharpend.py +++ b/tests/test_rand_gaussian_sharpend.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,87 +15,126 @@ from parameterized import parameterized from monai.transforms import RandGaussianSharpend +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "img", "prob": 1.0}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [[5.2919216, 5.5854445, 5.29192], [11.3982, 12.62332, 11.398202], [14.870525, 17.323769, 14.870527]], - [[20.413757, 22.767355, 20.413757], [28.495504, 31.558315, 28.495499], [29.99236, 34.505676, 29.992361]], + {"keys": "img", "prob": 1.0}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [5.2919216, 5.5854445, 5.29192], + [11.3982, 12.62332, 11.398202], + [14.870525, 17.323769, 14.870527], + ], + [ + [20.413757, 22.767355, 20.413757], + [28.495504, 31.558315, 28.495499], + [29.99236, 34.505676, 29.992361], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - { - "keys": "img", - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": 0.4, - "sigma2_y": 0.4, - "sigma2_z": 0.4, - "prob": 1.0, - }, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[4.1071496, 3.597953, 4.1071477], [10.062014, 9.825114, 10.0620165], [14.698058, 15.818766, 14.698058]], - [[18.211048, 18.16049, 18.211048], [25.155039, 24.56279, 25.155039], [28.801964, 30.381308, 28.801964]], + { + "keys": "img", + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": 0.4, + "sigma2_y": 0.4, + "sigma2_z": 0.4, + "prob": 1.0, + }, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [4.1071496, 3.597953, 4.1071477], + [10.062014, 9.825114, 10.0620165], + [14.698058, 15.818766, 14.698058], + ], + [ + [18.211048, 18.16049, 18.211048], + [25.155039, 24.56279, 25.155039], + [28.801964, 30.381308, 28.801964], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - { - "keys": "img", - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "prob": 1.0, - }, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[4.81077, 4.4237204, 4.81077], [12.061236, 12.298177, 12.061236], [17.362553, 19.201174, 17.362553]], - [[21.440754, 22.142393, 21.440754], [30.15308, 30.745445, 30.153086], [33.99255, 36.919838, 33.99255]], + { + "keys": "img", + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "prob": 1.0, + }, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [4.81077, 4.4237204, 4.81077], + [12.061236, 12.298177, 12.061236], + [17.362553, 19.201174, 17.362553], + ], + [ + [21.440754, 22.142393, 21.440754], + [30.15308, 30.745445, 30.153086], + [33.99255, 36.919838, 33.99255], + ], + ] + ), ] - ), -] + ) -TEST_CASE_4 = [ - { - "keys": "img", - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "approx": "scalespace", - "prob": 1.0, - }, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[4.430213, 3.2278745, 4.4302144], [10.325399, 8.507457, 10.325399], [17.494898, 16.5609, 17.494894]], - [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + { + "keys": "img", + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "approx": "scalespace", + "prob": 1.0, + }, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [4.430213, 3.2278745, 4.4302144], + [10.325399, 8.507457, 10.325399], + [17.494898, 16.5609, 17.494894], + ], + [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + ] + ), ] - ), -] + ) class TestRandGaussianSharpend(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): converter = RandGaussianSharpend(**argments) converter.set_random_state(seed=0) result = converter(image) - np.testing.assert_allclose(result["img"], expected_data, rtol=1e-4) + assert_allclose(result["img"], expected_data, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_gaussian_smooth.py b/tests/test_rand_gaussian_smooth.py index 889ed7d6d5..d51618be95 100644 --- a/tests/test_rand_gaussian_smooth.py +++ b/tests/test_rand_gaussian_smooth.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,48 +15,81 @@ from parameterized import parameterized from monai.transforms import RandGaussianSmooth +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"sigma_x": (0.5, 1.5), "prob": 1.0}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [[0.71806467, 0.9074683, 0.71806467], [1.0718315, 1.3545481, 1.0718315], [1.0337002, 1.306359, 1.0337002]], - [[2.0318885, 2.5678391, 2.0318885], [2.6795788, 3.3863702, 2.6795788], [2.3475242, 2.9667296, 2.3475242]], + {"sigma_x": (0.5, 1.5), "prob": 1.0}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array( + [ + [ + [0.71806467, 0.9074683, 0.71806467], + [1.0718315, 1.3545481, 1.0718315], + [1.0337002, 1.306359, 1.0337002], + ], + [ + [2.0318885, 2.5678391, 2.0318885], + [2.6795788, 3.3863702, 2.6795788], + [2.3475242, 2.9667296, 2.3475242], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "prob": 1.0}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.7686928, 0.9848021, 0.7686928], [1.1474025, 1.4699818, 1.1474024], [1.1065826, 1.4176859, 1.1065826]], - [[2.1751494, 2.7866683, 2.1751497], [2.8685062, 3.6749542, 2.8685062], [2.5130394, 3.219552, 2.5130394]], + {"sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "prob": 1.0}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array( + [ + [ + [0.7686928, 0.9848021, 0.7686928], + [1.1474025, 1.4699818, 1.1474024], + [1.1065826, 1.4176859, 1.1065826], + ], + [ + [2.1751494, 2.7866683, 2.1751497], + [2.8685062, 3.6749542, 2.8685062], + [2.5130394, 3.219552, 2.5130394], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "approx": "scalespace", "prob": 1.0}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.8128456, 0.96736777, 0.8128456], [1.2742369, 1.5164697, 1.2742369], [1.2800367, 1.5233722, 1.2800368]], - [[2.3825073, 2.8354228, 2.3825073], [3.1855922, 3.7911744, 3.1855922], [2.8496985, 3.391427, 2.8496985]], + {"sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "approx": "scalespace", "prob": 1.0}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array( + [ + [ + [0.8128456, 0.96736777, 0.8128456], + [1.2742369, 1.5164697, 1.2742369], + [1.2800367, 1.5233722, 1.2800368], + ], + [ + [2.3825073, 2.8354228, 2.3825073], + [3.1855922, 3.7911744, 3.1855922], + [2.8496985, 3.391427, 2.8496985], + ], + ] + ), ] - ), -] + ) class TestRandGaussianSmooth(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): converter = RandGaussianSmooth(**argments) converter.set_random_state(seed=0) result = converter(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + assert_allclose(result, expected_data, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_gaussian_smoothd.py b/tests/test_rand_gaussian_smoothd.py index 2eedc9071c..e0ef0a8bb5 100644 --- a/tests/test_rand_gaussian_smoothd.py +++ b/tests/test_rand_gaussian_smoothd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,48 +15,81 @@ from parameterized import parameterized from monai.transforms import RandGaussianSmoothd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "img", "sigma_x": (0.5, 1.5), "prob": 1.0}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [[0.71806467, 0.9074683, 0.71806467], [1.0718315, 1.3545481, 1.0718315], [1.0337002, 1.306359, 1.0337002]], - [[2.0318885, 2.5678391, 2.0318885], [2.6795788, 3.3863702, 2.6795788], [2.3475242, 2.9667296, 2.3475242]], + {"keys": "img", "sigma_x": (0.5, 1.5), "prob": 1.0}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.71806467, 0.9074683, 0.71806467], + [1.0718315, 1.3545481, 1.0718315], + [1.0337002, 1.306359, 1.0337002], + ], + [ + [2.0318885, 2.5678391, 2.0318885], + [2.6795788, 3.3863702, 2.6795788], + [2.3475242, 2.9667296, 2.3475242], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"keys": "img", "sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "prob": 1.0}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[0.7686928, 0.9848021, 0.7686928], [1.1474025, 1.4699818, 1.1474024], [1.1065826, 1.4176859, 1.1065826]], - [[2.1751494, 2.7866683, 2.1751497], [2.8685062, 3.6749542, 2.8685062], [2.5130394, 3.219552, 2.5130394]], + {"keys": "img", "sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "prob": 1.0}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.7686928, 0.9848021, 0.7686928], + [1.1474025, 1.4699818, 1.1474024], + [1.1065826, 1.4176859, 1.1065826], + ], + [ + [2.1751494, 2.7866683, 2.1751497], + [2.8685062, 3.6749542, 2.8685062], + [2.5130394, 3.219552, 2.5130394], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"keys": "img", "sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "approx": "scalespace", "prob": 1.0}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array( + TESTS.append( [ - [[0.8128456, 0.96736777, 0.8128456], [1.2742369, 1.5164697, 1.2742369], [1.2800367, 1.5233722, 1.2800368]], - [[2.3825073, 2.8354228, 2.3825073], [3.1855922, 3.7911744, 3.1855922], [2.8496985, 3.391427, 2.8496985]], + {"keys": "img", "sigma_x": (0.5, 1.5), "sigma_y": (0.5, 1.0), "approx": "scalespace", "prob": 1.0}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array( + [ + [ + [0.8128456, 0.96736777, 0.8128456], + [1.2742369, 1.5164697, 1.2742369], + [1.2800367, 1.5233722, 1.2800368], + ], + [ + [2.3825073, 2.8354228, 2.3825073], + [3.1855922, 3.7911744, 3.1855922], + [2.8496985, 3.391427, 2.8496985], + ], + ] + ), ] - ), -] + ) class TestRandGaussianSmoothd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): converter = RandGaussianSmoothd(**argments) converter.set_random_state(seed=0) result = converter(image) - np.testing.assert_allclose(result["img"], expected_data, rtol=1e-4) + assert_allclose(result["img"], expected_data, rtol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index a0701d09c3..fe928038da 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,17 +19,17 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestRandGibbsNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -39,50 +39,50 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return input_type(im) @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_0_prob(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] - t = RandGibbsNoise(0.0, alpha, as_tensor_output) + t = RandGibbsNoise(0.0, alpha) out = t(im) - np.testing.assert_allclose(im, out) + torch.testing.assert_allclose(im, out, rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] - t = RandGibbsNoise(1.0, alpha, as_tensor_output) + t = RandGibbsNoise(1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) - np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) + self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = RandGibbsNoise(1.0, alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(im, out, atol=1e-2) + torch.testing.assert_allclose(im, out, atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoise(1.0, alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(0 * im, out) + torch.testing.assert_allclose(0 * im, out, rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoise(1.0, alpha) _ = t(deepcopy(im)) diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index b778bffdda..8c5e045b90 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,19 +19,19 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) KEYS = ["im", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestRandGibbsNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -41,70 +41,76 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) + return {k: input_type(v) for k, v in zip(KEYS, ims)} @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_0_prob(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] - t = RandGibbsNoised(KEYS, 0.0, alpha, as_tensor_output) + t = RandGibbsNoised(KEYS, 0.0, alpha) out = t(data) for k in KEYS: - np.testing.assert_allclose(data[k], out[k]) + torch.testing.assert_allclose(data[k], out[k], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] - t = RandGibbsNoised(KEYS, 1.0, alpha, as_tensor_output) + t = RandGibbsNoised(KEYS, 1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(data)) t.set_random_state(42) out2 = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0) + self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() np.testing.assert_allclose(data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(0 * data[k], out[k]) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(0.0 * data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_dict_matches(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = [0.5, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) + torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoised(KEYS, 1.0, alpha) _ = t(deepcopy(data)) - self.assertGreaterEqual(t.sampled_alpha, 0.5) - self.assertLessEqual(t.sampled_alpha, 0.51) + self.assertTrue(0.5 <= t.rand_gibbs_noise.sampled_alpha <= 0.51) if __name__ == "__main__": diff --git a/tests/test_rand_grid_distortion.py b/tests/test_rand_grid_distortion.py new file mode 100644 index 0000000000..80f19df0db --- /dev/null +++ b/tests/test_rand_grid_distortion.py @@ -0,0 +1,94 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandGridDistortion +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + seed = 0 + TESTS.append( + [ + dict(num_cells=2, prob=1.0, distort_limit=0.5, mode="nearest", padding_mode="zeros"), + seed, + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 0.0], + [4.0, 4.0, 4.0, 4.0, 4.0, 0.0], + [4.0, 4.0, 4.0, 4.0, 4.0, 0.0], + [5.0, 5.0, 5.0, 5.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ).astype(np.float32) + ), + ] + ) + seed = 1 + TESTS.append( + [ + dict(num_cells=(2, 2), prob=1.0, distort_limit=0.1, mode="bilinear", padding_mode="reflection"), + seed, + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.5660975, 1.5660975, 1.5660975, 1.5660975, 1.5660974, 1.5660975], + [3.132195, 3.132195, 3.132195, 3.132195, 3.132195, 3.132195], + [3.132195, 3.132195, 3.132195, 3.132195, 3.132195, 3.132195], + [4.482229, 4.482229, 4.482229, 4.482229, 4.482229, 4.482229], + [4.167737, 4.167737, 4.167737, 4.167737, 4.167737, 4.167737], + ], + [ + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940266, 2.7880538, 2.7880538, 4.1657557, 4.456543], + ], + ] + ).astype(np.float32) + ), + ] + ) + + +class TestRandGridDistortion(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_grid_distortion(self, input_param, seed, input_data, expected_val): + g = RandGridDistortion(**input_param) + g.set_random_state(seed=seed) + result = g(input_data) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_grid_distortiond.py b/tests/test_rand_grid_distortiond.py new file mode 100644 index 0000000000..323848dc0b --- /dev/null +++ b/tests/test_rand_grid_distortiond.py @@ -0,0 +1,88 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandGridDistortiond +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +num_cells = 2 +seed = 0 +for p in TEST_NDARRAYS: + img = np.indices([6, 6]).astype(np.float32) + TESTS.append( + [ + dict( + keys=["img", "mask"], + num_cells=num_cells, + prob=1.0, + distort_limit=(-0.1, 0.1), + mode=["bilinear", "nearest"], + padding_mode="zeros", + ), + seed, + {"img": p(img), "mask": p(np.ones_like(img[:1]))}, + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.5645568, 1.5645568, 1.5645568, 1.5645568, 1.5645568, 0.0], + [3.1291137, 3.1291137, 3.1291137, 3.1291137, 3.1291137, 0.0], + [3.1291137, 3.1291137, 3.1291137, 3.1291137, 3.1291137, 0.0], + [4.6599426, 4.6599426, 4.6599426, 4.6599426, 4.6599426, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ).astype(np.float32) + ), + p( + np.array( + [ + [ + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ] + ) + ), + ] + ) + + +class TestRandGridDistortiond(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_grid_distortiond(self, input_param, seed, input_data, expected_val_img, expected_val_mask): + g = RandGridDistortiond(**input_param) + g.set_random_state(seed=seed) + result = g(input_data) + assert_allclose(result["img"], expected_val_img, rtol=1e-4, atol=1e-4) + assert_allclose(result["mask"], expected_val_mask, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py index b258cc5a7e..c66f7859c6 100644 --- a/tests/test_rand_histogram_shift.py +++ b/tests/test_rand_histogram_shift.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,33 +15,40 @@ from parameterized import parameterized from monai.transforms import RandHistogramShift - -TEST_CASES = [ - [ - {"num_control_points": 5, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - {"num_control_points": 5, "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)}, - np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]), - ], - [ - {"num_control_points": (5, 20), "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)}, - np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]), - ], -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"num_control_points": 5, "prob": 0.0}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + np.arange(8).reshape((1, 2, 2, 2)), + ] + ) + TESTS.append( + [ + {"num_control_points": 5, "prob": 0.9}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))}, + np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]), + ] + ) + TESTS.append( + [ + {"num_control_points": (5, 20), "prob": 0.9}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))}, + np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]), + ] + ) class TestRandHistogramShift(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_histogram_shift(self, input_param, input_data, expected_val): g = RandHistogramShift(**input_param) g.set_random_state(123) result = g(**input_data) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_histogram_shiftd.py b/tests/test_rand_histogram_shiftd.py index 806e4f5cf2..fe8ddf9ffd 100644 --- a/tests/test_rand_histogram_shiftd.py +++ b/tests/test_rand_histogram_shiftd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,47 +14,60 @@ import numpy as np from parameterized import parameterized -from monai.transforms import RandHistogramShiftD - -TEST_CASES = [ - [ - {"keys": ("img",), "num_control_points": 5, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - ], - [ - {"keys": ("img",), "num_control_points": 5, "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - { - "img": np.array( - [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]] - ), - "seg": np.ones(8).reshape((1, 2, 2, 2)), - }, - ], - [ - {"keys": ("img",), "num_control_points": (5, 20), "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - { - "img": np.array( - [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]] - ), - "seg": np.ones(8).reshape((1, 2, 2, 2)), - }, - ], -] +from monai.transforms.intensity.dictionary import RandHistogramShiftd +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ("img",), "num_control_points": 5, "prob": 0.0}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "seg": p(np.ones(8).reshape((1, 2, 2, 2)))}, + {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, + ] + ) + TESTS.append( + [ + {"keys": ("img",), "num_control_points": 5, "prob": 0.9}, + { + "img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)), + "seg": p(np.ones(8).reshape((1, 2, 2, 2))), + }, + { + "img": np.array( + [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]] + ), + "seg": np.ones(8).reshape((1, 2, 2, 2)), + }, + ] + ) + TESTS.append( + [ + {"keys": ("img",), "num_control_points": (5, 20), "prob": 0.9}, + { + "img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)), + "seg": p(np.ones(8).reshape((1, 2, 2, 2))), + }, + { + "img": np.array( + [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]] + ), + "seg": np.ones(8).reshape((1, 2, 2, 2)), + }, + ] + ) class TestRandHistogramShiftD(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_histogram_shiftd(self, input_param, input_data, expected_val): - g = RandHistogramShiftD(**input_param) + g = RandHistogramShiftd(**input_param) g.set_random_state(123) res = g(input_data) for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-4, atol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py index 71f7e36d9b..645e6ae1ce 100644 --- a/tests/test_rand_k_space_spike_noise.py +++ b/tests/test_rand_k_space_spike_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,18 +19,15 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise, RandKSpaceSpikeNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - for channel_wise in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input, channel_wise)) + for p in TEST_NDARRAYS: + for channel_wise in (True, False): + TESTS.append((shape, p, channel_wise)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestRandKSpaceSpikeNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -40,44 +37,55 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return im_type(im) - @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_0_prob(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 15] - t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise) out = t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(out.device, im.device) + im, out = im.cpu(), out.cpu() np.testing.assert_allclose(im, out) - @parameterized.expand(TEST_CASES) - def test_1_prob(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_1_prob(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 14] - t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise) out = t(im) - base_t = KSpaceSpikeNoise(t.sampled_locs, [14], as_tensor_output) + base_t = KSpaceSpikeNoise(t.sampled_locs, [14]) out = out - base_t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(out.device, im.device) + im, out = im.cpu(), out.cpu() np.testing.assert_allclose(out, im * 0) - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 15] - t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(out1.device, im.device) + out1, out2 = out1.cpu(), out2.cpu() np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_intensity(self, im_shape, _, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_intensity(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 14.1] t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise) _ = t(deepcopy(im)) diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 1056ebf163..7a6a73b215 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,19 +19,16 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandKSpaceSpikeNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) KEYS = ["image", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -41,107 +38,53 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [im[None] for im in ims] - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) - - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - - data = self.get_data(im_shape, as_tensor_input) - - intensity_ranges = {"image": (13, 15), "label": (13, 15)} - t = RandKSpaceSpikeNoised( - KEYS, - global_prob=1.0, - prob=1.0, - intensity_ranges=intensity_ranges, - channel_wise=True, - as_tensor_output=as_tensor_output, - ) - t.set_rand_state(42) + ims = [im_type(im[None]) for im in ims] + return {k: v for k, v in zip(KEYS, ims)} + + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): + + data = self.get_data(im_shape, im_type) + + t = RandKSpaceSpikeNoised(KEYS, prob=1.0, intensity_range=(13, 15), channel_wise=True) + t.set_random_state(42) out1 = t(deepcopy(data)) - t.set_rand_state(42) + t.set_random_state(42) out2 = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() np.testing.assert_allclose(out1[k], out2[k], atol=1e-10) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) - - @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) - intensity_ranges = {"image": (13, 15), "label": (13, 15)} - t1 = RandKSpaceSpikeNoised( - KEYS, - global_prob=0.0, - prob=1.0, - intensity_ranges=intensity_ranges, - channel_wise=True, - as_tensor_output=as_tensor_output, - ) - - t2 = RandKSpaceSpikeNoised( - KEYS, - global_prob=0.0, - prob=1.0, - intensity_ranges=intensity_ranges, - channel_wise=True, - as_tensor_output=as_tensor_output, - ) + + @parameterized.expand(TESTS) + def test_0_prob(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) + + t1 = RandKSpaceSpikeNoised(KEYS, prob=0.0, intensity_range=(13, 15), channel_wise=True) + + t2 = RandKSpaceSpikeNoised(KEYS, prob=0.0, intensity_range=(13, 15), channel_wise=True) out1 = t1(data) out2 = t2(data) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() + data[k] = data[k].cpu() + np.testing.assert_allclose(data[k], out1[k]) np.testing.assert_allclose(data[k], out2[k]) - @parameterized.expand(TEST_CASES) - def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): - - data = self.get_data(im_shape, as_tensor_input) - intensity_ranges = {"image": (13, 13.1), "label": (13, 13.1)} - t = RandKSpaceSpikeNoised( - KEYS, - global_prob=1.0, - prob=1.0, - intensity_ranges=intensity_ranges, - channel_wise=True, - as_tensor_output=True, - ) - - _ = t(data) - self.assertGreaterEqual(t.transforms["image"].sampled_k_intensity[0], 13) - self.assertLessEqual(t.transforms["image"].sampled_k_intensity[0], 13.1) - self.assertGreaterEqual(t.transforms["label"].sampled_k_intensity[0], 13) - self.assertLessEqual(t.transforms["label"].sampled_k_intensity[0], 13.1) - - @parameterized.expand(TEST_CASES) - def test_same_transformation(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) - # use same image for both dictionary entries to check same trans is applied to them - data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} - - intensity_ranges = {"image": (13, 15), "label": (13, 15)} - # use common_sampling = True to ask for the same transformation - t = RandKSpaceSpikeNoised( - KEYS, - global_prob=1.0, - prob=1.0, - intensity_ranges=intensity_ranges, - channel_wise=True, - common_sampling=True, - as_tensor_output=True, - ) - - out = t(deepcopy(data)) - - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_lambda.py b/tests/test_rand_lambda.py index bf537883cf..043f44aec4 100644 --- a/tests/test_rand_lambda.py +++ b/tests/test_rand_lambda.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index 0a127839b8..854fef8879 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_rand_local_patch_shuffle.py b/tests/test_rand_local_patch_shuffle.py deleted file mode 100644 index 8e2eefb5d1..0000000000 --- a/tests/test_rand_local_patch_shuffle.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from parameterized import parameterized - -from monai.transforms import LocalPatchShuffling - -TEST_CASES = [ - [ - {"number_blocks": 10, "blocksize_ratio": 1, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - {"number_blocks": 10, "blocksize_ratio": 1, "prob": 1.0}, - {"img": np.arange(27).reshape((1, 3, 3, 3))}, - [ - [ - [[9, 1, 2], [3, 4, 5], [6, 7, 8]], - [[0, 10, 11], [12, 4, 14], [15, 16, 17]], - [[18, 19, 20], [21, 22, 23], [24, 25, 26]], - ] - ], - ], -] - - -class TestLocalPatchShuffle(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_local_patch_shuffle(self, input_param, input_data, expected_val): - g = LocalPatchShuffling(**input_param) - g.set_random_state(seed=12) - result = g(**input_data) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py index 7ec5fc4dc4..8e2ea1ee3a 100644 --- a/tests/test_rand_rician_noise.py +++ b/tests/test_rand_rician_noise.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_rand_rician_noised.py b/tests/test_rand_rician_noised.py index 010bbcb310..05707059bc 100644 --- a/tests/test_rand_rician_noised.py +++ b/tests/test_rand_rician_noised.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,11 +29,12 @@ class TestRandRicianNoisedNumpy(NumpyImageTestCase2D): @parameterized.expand(TESTS) def test_correct_results(self, _, in_type, keys, mean, std): - rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, mean=mean, std=std) + rician_fn = RandRicianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64) rician_fn.set_random_state(seed) noised = rician_fn({k: in_type(self.imt) for k in keys}) np.random.seed(seed) for k in keys: + # simulate the `randomize` function of transform np.random.random() _std = np.random.uniform(0, std) expected = np.sqrt( diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 0ff8508a0f..b453b01884 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,25 +10,60 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotate -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotate2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + + +class TestRandRotate2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotate( range_x=degrees, prob=1.0, @@ -38,7 +73,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) _order = 0 if mode == "nearest" else 1 if mode == "border": @@ -52,38 +87,14 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated[0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRandRotate3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ] - ) - def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotate( range_x=x, range_y=y, @@ -95,8 +106,8 @@ def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt[0]) - np.testing.assert_allclose(rotated.shape, expected) + rotated = rotate_fn(im_type(self.imt[0])) + torch.testing.assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 50a1b28e53..b845944062 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,49 +14,45 @@ import numpy as np from monai.transforms import RandRotate90 -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandRotate90(NumpyImageTestCase2D): def test_default(self): rotate = RandRotate90() - rotate.set_random_state(123) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(123) + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_k(self): rotate = RandRotate90(max_k=2) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(123) + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(123) + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_prob_k_spatial_axes(self): rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) if __name__ == "__main__": diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index a487b695f5..ded18e430a 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,53 +14,49 @@ import numpy as np from monai.transforms import RandRotate90d -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandRotate90d(NumpyImageTestCase2D): def test_default(self): key = None rotate = RandRotate90d(keys=key) - rotate.set_random_state(123) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(123) + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_k(self): key = "test" rotate = RandRotate90d(keys=key, max_k=2) - rotate.set_random_state(234) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_spatial_axes(self): key = "test" rotate = RandRotate90d(keys=key, spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_prob_k_spatial_axes(self): key = "test" rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_no_key(self): key = "unknown" diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py index 47b4b7107e..23314720c1 100644 --- a/tests/test_rand_rotated.py +++ b/tests/test_rand_rotated.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,26 +10,104 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotated from monai.utils import GridSampleMode, GridSamplePadMode -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotated2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append( + ( + p, + np.pi / 2, + -np.pi / 6, + (0.0, np.pi), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + False, + (1, 87, 104, 109), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + GridSampleMode.NEAREST, + GridSamplePadMode.ZEROS, + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + TEST_CASES_3D.append( + (p, (-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)) + ) + + +class TestRandRotated2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotated( "img", range_x=degrees, @@ -40,7 +118,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) _order = 0 if mode == "nearest" else 1 if padding_mode == "border": @@ -49,74 +127,20 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor _mode = "reflect" else: _mode = "constant" - angle = rotate_fn.x + angle = rotate_fn.rand_rotate.x expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v expected = np.stack(expected).astype(np.float32) good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRandRotated3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 2, - -np.pi / 6, - (0.0, np.pi), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - False, - (1, 87, 104, 109), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - GridSampleMode.NEAREST, - GridSamplePadMode.ZEROS, - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ((-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)), - ] - ) - def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotated( "img", range_x=x, @@ -129,7 +153,7 @@ def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corn align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) np.testing.assert_allclose(rotated["img"].shape, expected) diff --git a/tests/test_rand_scale_crop.py b/tests/test_rand_scale_crop.py index db5487ebff..5d6312002f 100644 --- a/tests/test_rand_scale_crop.py +++ b/tests/test_rand_scale_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandScaleCrop +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -55,22 +56,25 @@ class TestRandScaleCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape): - result = RandScaleCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + for p in TEST_NDARRAYS: + result = RandScaleCrop(**input_param)(p(input_data)) + self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandScaleCrop(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + cropper = RandScaleCrop(**input_param) + result = cropper(p(input_data)) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandScaleCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result.shape, expected_shape) + for p in TEST_NDARRAYS: + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(seed=123) + result = cropper(p(input_data)) + self.assertTupleEqual(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_scale_cropd.py b/tests/test_rand_scale_cropd.py index 265c6c467d..5e833fef98 100644 --- a/tests/test_rand_scale_cropd.py +++ b/tests/test_rand_scale_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandScaleCropd +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"keys": "img", "roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -23,7 +24,8 @@ ] TEST_CASE_2 = [ - {"keys": "img", "roi_scale": [1.0, 1.0, 1.0], "random_center": False}, + # test `allow_missing_keys` with key "label" + {"keys": ["label", "img"], "roi_scale": [1.0, 1.0, 1.0], "random_center": False, "allow_missing_keys": True}, {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, (3, 3, 3, 3), ] @@ -66,10 +68,14 @@ def test_shape(self, input_param, input_data, expected_shape): @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandScaleCropd(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + cropper = RandScaleCropd(**input_param) + input_data["img"] = p(input_data["img"]) + result = cropper(input_data) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose( + result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False + ) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index 750d88bfad..5aa5c7b964 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,8 +24,10 @@ def test_value(self): scaler.set_random_state(seed=0) result = scaler(p(self.imt)) np.random.seed(0) + # simulate the randomize() of transform + np.random.random() expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)) - assert_allclose(result, expected, rtol=1e-7, atol=0) + assert_allclose(result, p(expected), rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_scale_intensityd.py b/tests/test_rand_scale_intensityd.py index a8d2e63f65..655bd88ee0 100644 --- a/tests/test_rand_scale_intensityd.py +++ b/tests/test_rand_scale_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,14 +19,16 @@ class TestRandScaleIntensityd(NumpyImageTestCase2D): def test_value(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0) scaler.set_random_state(seed=0) result = scaler({key: p(self.imt)}) np.random.seed(0) + # simulate the randomize function of transform + np.random.random() expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) if __name__ == "__main__": diff --git a/tests/test_rand_shift_intensity.py b/tests/test_rand_shift_intensity.py index 4c4dd87dfe..b4f32a385a 100644 --- a/tests/test_rand_shift_intensity.py +++ b/tests/test_rand_shift_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,6 +23,8 @@ def test_value(self): shifter.set_random_state(seed=0) result = shifter(self.imt, factor=1.0) np.random.seed(0) + # simulate the randomize() of transform + np.random.random() expected = self.imt + np.random.uniform(low=-1.0, high=1.0) np.testing.assert_allclose(result, expected) diff --git a/tests/test_rand_shift_intensityd.py b/tests/test_rand_shift_intensityd.py index 6766236146..10bb5cd71e 100644 --- a/tests/test_rand_shift_intensityd.py +++ b/tests/test_rand_shift_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,14 +19,16 @@ class TestRandShiftIntensityd(NumpyImageTestCase2D): def test_value(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" shifter = RandShiftIntensityd(keys=[key], offsets=1.0, prob=1.0) shifter.set_random_state(seed=0) result = shifter({key: p(self.imt)}) np.random.seed(0) + # simulate the randomize() of transform + np.random.random() expected = self.imt + np.random.uniform(low=-1.0, high=1.0) - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) def test_factor(self): key = "img" @@ -36,6 +38,8 @@ def test_factor(self): shifter.set_random_state(seed=0) result = shifter(stats(data)) np.random.seed(0) + # simulate the randomize() of transform + np.random.random() expected = self.imt + np.random.uniform(low=-1.0, high=1.0) * np.nanmax(self.imt) np.testing.assert_allclose(result[key], expected) diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index 01e057e589..8f4bb0fffa 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCrop +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_0 = [ {"roi_size": [3, 3, -1], "random_center": True}, @@ -56,10 +57,11 @@ def test_shape(self, input_param, input_data, expected_shape): @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandSpatialCrop(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + cropper = RandSpatialCrop(**input_param) + result = cropper(p(input_data)) + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=False) @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_random_shape(self, input_param, input_data, expected_shape): diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index 0ade9bbbba..18fdf38773 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropSamples +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False}, @@ -70,14 +71,15 @@ class TestRandSpatialCropSamples(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape, expected_last_item): - xform = RandSpatialCropSamples(**input_param) - xform.set_random_state(1234) - result = xform(input_data) + for p in TEST_NDARRAYS: + xform = RandSpatialCropSamples(**input_param) + xform.set_random_state(1234) + result = xform(p(input_data)) - np.testing.assert_equal(len(result), input_param["num_samples"]) - for item, expected in zip(result, expected_shape): - self.assertTupleEqual(item.shape, expected) - np.testing.assert_allclose(result[-1], expected_last_item) + np.testing.assert_equal(len(result), input_param["num_samples"]) + for item, expected in zip(result, expected_shape): + self.assertTupleEqual(item.shape, expected) + assert_allclose(result[-1], expected_last_item, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index 3f5eee7b27..b7ceefe542 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import Compose, RandSpatialCropSamplesd, ToTensord +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ {"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True}, @@ -38,31 +39,48 @@ }, ] -TEST_CASE_2 = [ - {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, - {"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)}, - [(3, 3, 3, 3), (3, 2, 3, 3), (3, 2, 2, 3), (3, 2, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 2, 2, 3), (3, 3, 2, 3)], - { - "img": np.array( +TEST_CASE_2 = [] +for p in TEST_NDARRAYS: + TEST_CASE_2.append( + [ + {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, + {"img": p(np.arange(81).reshape(3, 3, 3, 3)), "seg": p(np.arange(81, 0, -1).reshape(3, 3, 3, 3))}, [ - [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]], - [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]], - [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]], - ] - ), - "seg": np.array( - [ - [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]], - [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]], - [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]], - ] - ), - }, -] + (3, 3, 3, 3), + (3, 2, 3, 3), + (3, 2, 2, 3), + (3, 2, 3, 3), + (3, 3, 3, 3), + (3, 3, 3, 3), + (3, 2, 2, 3), + (3, 3, 2, 3), + ], + { + "img": p( + np.array( + [ + [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]], + [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]], + [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]], + ] + ) + ), + "seg": p( + np.array( + [ + [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]], + [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]], + [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]], + ] + ) + ), + }, + ] + ) class TestRandSpatialCropSamplesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, *TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape, expected_last): xform = RandSpatialCropSamplesd(**input_param) xform.set_random_state(1234) @@ -73,18 +91,14 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last): for i, item in enumerate(result): self.assertEqual(item["img_meta_dict"]["patch_index"], i) self.assertEqual(item["seg_meta_dict"]["patch_index"], i) - np.testing.assert_allclose(item["img"], expected_last["img"]) - np.testing.assert_allclose(item["seg"], expected_last["seg"]) + assert_allclose(item["img"], expected_last["img"], type_test=True) + assert_allclose(item["seg"], expected_last["seg"], type_test=True) def test_deep_copy(self): data = {"img": np.ones((1, 10, 11, 12))} num_samples = 3 sampler = RandSpatialCropSamplesd( - keys=["img"], - roi_size=(3, 3, 3), - num_samples=num_samples, - random_center=True, - random_size=False, + keys=["img"], roi_size=(3, 3, 3), num_samples=num_samples, random_center=True, random_size=False ) transform = Compose([ToTensord(keys="img"), sampler]) samples = transform(data) diff --git a/tests/test_rand_spatial_cropd.py b/tests/test_rand_spatial_cropd.py index 610c1974aa..9e6e86eea2 100644 --- a/tests/test_rand_spatial_cropd.py +++ b/tests/test_rand_spatial_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import RandSpatialCropd +from tests.utils import TEST_NDARRAYS TEST_CASE_0 = [ {"keys": "img", "roi_size": [3, 3, -1], "random_center": True}, @@ -67,10 +68,12 @@ def test_value(self, input_param, input_data): @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandSpatialCropd(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + for p in TEST_NDARRAYS: + cropper = RandSpatialCropd(**input_param) + cropper.set_random_state(seed=123) + input_data["img"] = p(input_data["img"]) + result = cropper(input_data) + self.assertTupleEqual(result["img"].shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_rand_std_shift_intensity.py b/tests/test_rand_std_shift_intensity.py index 0c6382555e..fdf386fee4 100644 --- a/tests/test_rand_std_shift_intensity.py +++ b/tests/test_rand_std_shift_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,6 +22,8 @@ class TestRandStdShiftIntensity(NumpyImageTestCase2D): def test_value(self): for p in TEST_NDARRAYS: np.random.seed(0) + # simulate the randomize() of transform + np.random.random() factor = np.random.uniform(low=-1.0, high=1.0) offset = factor * np.std(self.imt) expected = p(self.imt + offset) diff --git a/tests/test_rand_std_shift_intensityd.py b/tests/test_rand_std_shift_intensityd.py index 0ab017a42d..e98d1e3ad3 100644 --- a/tests/test_rand_std_shift_intensityd.py +++ b/tests/test_rand_std_shift_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,6 +23,8 @@ def test_value(self): for p in TEST_NDARRAYS: key = "img" np.random.seed(0) + # simulate the randomize() of transform + np.random.random() factor = np.random.uniform(low=-1.0, high=1.0) expected = self.imt + factor * np.std(self.imt) shifter = RandStdShiftIntensityd(keys=[key], factors=1.0, prob=1.0) diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index 39a9439122..dae7f05016 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,127 +12,159 @@ import unittest import numpy as np +import torch +from parameterized.parameterized import parameterized from monai.transforms.croppad.array import RandWeightedCrop -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose -class TestRandWeightedCrop2D(NumpyImageTestCase2D): - def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCrop((10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[80, 21], [30, 17], [40, 31]]) +def get_data(ndim): + im_gen = NumpyImageTestCase2D() if ndim == 2 else NumpyImageTestCase3D() + im_gen.setUp() + return im_gen.imt[0], im_gen.seg1[0], im_gen.segn[0] + + +IMT_2D, SEG1_2D, SEGN_2D = get_data(ndim=2) +IMT_3D, SEG1_3D, SEGN_3D = get_data(ndim=3) + - def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((10, -1), n_samples) - weight = np.zeros_like(img) +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + im = SEG1_2D + weight = np.zeros_like(im) weight[0, 30, 17] = 1.1 weight[0, 40, 31] = 1 weight[0, 80, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32], [105, 32], [20, 32]]) - - def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCrop((10000, 400), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "small roi 2d", + dict(spatial_size=(10, 12), num_samples=3), + p(im), + q(weight), + (1, 10, 12), + [[80, 21], [30, 17], [40, 31]], + ] + ) + im = IMT_2D + TESTS.append( + [ + "default roi 2d", + dict(spatial_size=(10, -1), num_samples=3), + p(im), + q(weight), + (1, 10, 64), + [[14, 32], [105, 32], [20, 32]], + ] + ) + im = SEGN_2D + weight = np.zeros_like(im) weight[0, 30, 17] = 1.1 weight[0, 10, 1] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 128, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[64, 32], [64, 32], [64, 32]]) - for res in result: - np.testing.assert_allclose(res, self.segn[0]) - - def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((20, 40), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "large roi 2d", + dict(spatial_size=(10000, 400), num_samples=3), + p(im), + q(weight), + (1, 128, 64), + [[64, 32], [64, 32], [64, 32]], + ] + ) + im = IMT_2D + weight = np.zeros_like(im) weight[0, 30, 17] = np.inf weight[0, 10, 1] = -np.inf weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 20, 40)) - np.testing.assert_allclose(np.asarray(crop.centers), [[63, 37], [31, 43], [66, 20]]) - - -class TestRandWeightedCrop(NumpyImageTestCase3D): - def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCrop((8, 10, 12), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "bad w 2d", + dict(spatial_size=(20, 40), num_samples=3), + p(im), + q(weight), + (1, 20, 40), + [[63, 37], [31, 43], [66, 20]], + ] + ) + im = SEG1_3D + weight = np.zeros_like(im) weight[0, 5, 30, 17] = 1.1 weight[0, 8, 40, 31] = 1 weight[0, 11, 23, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 8, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[11, 23, 21], [5, 30, 17], [8, 40, 31]]) - - def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((10, -1, -1), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "small roi 3d", + dict(spatial_size=(8, 10, 12), num_samples=3), + p(im), + q(weight), + (1, 8, 10, 12), + [[11, 23, 21], [5, 30, 17], [8, 40, 31]], + ] + ) + im = IMT_3D + weight = np.zeros_like(im) weight[0, 7, 17] = 1.1 weight[0, 13, 31] = 1.1 weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) - - def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCrop((10000, 400, 80), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "default roi 3d", + dict(spatial_size=(10, -1, -1), num_samples=3), + p(im), + q(weight), + (1, 10, 64, 80), + [[14, 32, 40], [41, 32, 40], [20, 32, 40]], + ] + ) + im = SEGN_3D + weight = np.zeros_like(im) weight[0, 30, 17, 20] = 1.1 weight[0, 10, 1, 17] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) - for res in result: - np.testing.assert_allclose(res, self.segn[0]) - - def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((48, 64, 80), n_samples) - weight = np.zeros_like(img) + TESTS.append( + [ + "large roi 3d", + dict(spatial_size=(10000, 400, 80), num_samples=3), + p(im), + q(weight), + (1, 48, 64, 80), + [[24, 32, 40], [24, 32, 40], [24, 32, 40]], + ] + ) + im = IMT_3D + weight = np.zeros_like(im) weight[0, 30, 17] = np.inf weight[0, 10, 1] = -np.inf weight[0, 10, 20] = -np.nan + TESTS.append( + [ + "bad w 3d", + dict(spatial_size=(48, 64, 80), num_samples=3), + p(im), + q(weight), + (1, 48, 64, 80), + [[24, 32, 40], [24, 32, 40], [24, 32, 40]], + ] + ) + + +class TestRandWeightedCrop(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, expected_vals): + crop = RandWeightedCrop(**input_params) crop.set_random_state(10) result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + self.assertTrue(len(result) == input_params["num_samples"]) + assert_allclose(result[0].shape, expected_shape) + for c, e in zip(crop.centers, expected_vals): + assert_allclose(c, e, type_test=False) + # if desired ROI is larger than image, check image is unchanged + if all(s >= i for i, s in zip(img.shape[1:], input_params["spatial_size"])): + for res in result: + self.assertEqual(type(img), type(res)) + if isinstance(img, torch.Tensor): + self.assertEqual(res.device, img.device) + assert_allclose(res, img) if __name__ == "__main__": diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 367ce3beb9..93726f55bb 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,148 +14,177 @@ import numpy as np from monai.transforms.croppad.dictionary import RandWeightedCropd -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose class TestRandWeightedCrop(NumpyImageTestCase2D): def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - d = {"img": img, "w": weight} - result = crop(d) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[80, 21], [30, 17], [40, 31]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.seg1[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (10, 12), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + crop.set_random_state(10) + d = {"img": p(img), "w": q(weight)} + result = crop(d) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) + for c, e in zip(crop.centers, [[80, 21], [30, 17], [40, 31]]): + assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - data = {"im": img, "weight": weight, "others": np.nan} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32], [105, 32], [20, 32]]) - np.testing.assert_allclose(result[1]["coords"], [105, 32]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + crop.set_random_state(10) + data = {"im": p(img), "weight": q(weight), "others": np.nan} + result = crop(data) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) + for c, e in zip(crop.centers, [[14, 32], [105, 32], [20, 32]]): + assert_allclose(c, e, type_test=False) + assert_allclose(result[1]["coords"], [105, 32], type_test=False) def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 10, 1] = 1 - crop.set_random_state(10) - data = {"img": img, "seg": self.imt[0], "weight": weight} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[64, 32], [64, 32], [64, 32]]) - np.testing.assert_allclose(result[1]["location"], [64, 32]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.segn[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 10, 1] = 1 + crop.set_random_state(10) + data = {"img": p(img), "seg": p(self.imt[0]), "weight": q(weight)} + result = crop(data) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) + for c, e in zip(crop.centers, [[64, 32], [64, 32], [64, 32]]): + assert_allclose(c, e, type_test=False) + assert_allclose(result[1]["location"], [64, 32], type_test=False) def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) - np.testing.assert_allclose(np.asarray(crop.centers), [[63, 37], [31, 43], [66, 20]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) + for c, e in zip(crop.centers, [[63, 37], [31, 43], [66, 20]]): + assert_allclose(c, e, type_test=False) class TestRandWeightedCrop3D(NumpyImageTestCase3D): def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 5, 30, 17] = 1.1 - weight[0, 8, 40, 31] = 1 - weight[0, 11, 23, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[11, 23, 21], [5, 30, 17], [8, 40, 31]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.seg1[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) + weight = np.zeros_like(img) + weight[0, 5, 30, 17] = 1.1 + weight[0, 8, 40, 31] = 1 + weight[0, 11, 23, 21] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) + for c, e in zip(crop.centers, [[11, 23, 21], [5, 30, 17], [8, 40, 31]]): + assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) + weight = np.zeros_like(img) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) + for c, e in zip(crop.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): + assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17, 20] = 1.1 - weight[0, 10, 1, 17] = 1 - crop.set_random_state(10) - result = crop({"img": img, "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.segn[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17, 20] = 1.1 + weight[0, 10, 1, 17] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) + for c, e in zip(crop.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): + assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) + for c, e in zip(crop.centers, [[24, 32, 40], [24, 32, 40], [24, 32, 40]]): + assert_allclose(c, e, type_test=False) def test_rand_weighted_crop_patch_index(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight, "img_meta_dict": {"affine": None}}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) - for i in range(n_samples): - np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["img_meta_dict"]["patch_index"], i) - np.testing.assert_allclose(result[i]["seg_meta_dict"]["patch_index"], i) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) + weight = np.zeros_like(img) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + crop.set_random_state(10) + result = crop( + {"img": p(img), "seg": p(self.segn[0]), "w": q(weight), "img_meta_dict": {"affine": None}} + ) + self.assertTrue(len(result) == n_samples) + for c, e in zip(crop.centers, [[14, 32, 40], [41, 32, 40], [20, 32, 40]]): + assert_allclose(c, e, type_test=False) + for i in range(n_samples): + np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[i]["img_meta_dict"]["patch_index"], i) + np.testing.assert_allclose(result[i]["seg_meta_dict"]["patch_index"], i) if __name__ == "__main__": diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py index c21bc8b9e9..35472024ef 100644 --- a/tests/test_rand_zoom.py +++ b/tests/test_rand_zoom.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,7 @@ from monai.transforms import RandZoom from monai.utils import GridSampleMode, InterpolateMode -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(0.8, 1.2, "nearest", False), (0.8, 1.2, InterpolateMode.NEAREST, False)] @@ -25,36 +25,28 @@ class TestRandZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, mode, keep_size): - random_zoom = RandZoom( - prob=1.0, - min_zoom=min_zoom, - max_zoom=max_zoom, - mode=mode, - keep_size=keep_size, - ) - random_zoom.set_random_state(1234) - zoomed = random_zoom(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(zoomed, expected, atol=1.0) + for p in TEST_NDARRAYS: + random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode, keep_size=keep_size) + random_zoom.set_random_state(1234) + zoomed = random_zoom(p(self.imt[0])) + expected = [ + zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False) + for channel in self.imt[0] + ] + + expected = np.stack(expected).astype(np.float32) + assert_allclose(zoomed, p(expected), atol=1.0) def test_keep_size(self): - random_zoom = RandZoom( - prob=1.0, - min_zoom=0.6, - max_zoom=0.7, - keep_size=True, - padding_mode="constant", - constant_values=2, - ) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + im = p(self.imt[0]) + random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) @parameterized.expand( [ @@ -64,23 +56,19 @@ def test_keep_size(self): ] ) def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises): - with self.assertRaises(raises): - random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) - random_zoom(self.imt[0]) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) + random_zoom(p(self.imt[0])) def test_auto_expand_3d(self): - random_zoom = RandZoom( - prob=1.0, - min_zoom=[0.8, 0.7], - max_zoom=[1.2, 1.3], - mode="nearest", - keep_size=False, - ) - random_zoom.set_random_state(1234) - test_data = np.random.randint(0, 2, size=[2, 2, 3, 4]) - zoomed = random_zoom(test_data) - np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) - np.testing.assert_allclose(zoomed.shape, (2, 2, 3, 3)) + for p in TEST_NDARRAYS: + random_zoom = RandZoom(prob=1.0, min_zoom=[0.8, 0.7], max_zoom=[1.2, 1.3], mode="nearest", keep_size=False) + random_zoom.set_random_state(1234) + test_data = p(np.random.randint(0, 2, size=[2, 2, 3, 4])) + zoomed = random_zoom(test_data) + assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) + assert_allclose(zoomed.shape, (2, 2, 3, 3)) if __name__ == "__main__": diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py index 4ccb1aad64..a22f2f36f1 100644 --- a/tests/test_rand_zoomd.py +++ b/tests/test_rand_zoomd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from scipy.ndimage import zoom as zoom_scipy from monai.transforms import RandZoomd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(0.8, 1.2, "nearest", None, False)] @@ -34,52 +34,47 @@ def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_siz align_corners=align_corners, keep_size=keep_size, ) - random_zoom.set_random_state(1234) + for p in TEST_NDARRAYS: + random_zoom.set_random_state(1234) - zoomed = random_zoom({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(expected, zoomed[key], atol=1.0) + zoomed = random_zoom({key: p(self.imt[0])}) + expected = [ + zoom_scipy(channel, zoom=random_zoom.rand_zoom._zoom, mode="nearest", order=0, prefilter=False) + for channel in self.imt[0] + ] + + expected = np.stack(expected).astype(np.float32) + assert_allclose(zoomed[key], p(expected), atol=1.0) def test_keep_size(self): key = "img" random_zoom = RandZoomd( - keys=key, - prob=1.0, - min_zoom=0.6, - max_zoom=0.7, - keep_size=True, - padding_mode="constant", - constant_values=2, + keys=key, prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True, padding_mode="constant", constant_values=2 ) - zoomed = random_zoom({key: self.imt[0]}) - self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + zoomed = random_zoom({key: p(self.imt[0])}) + np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:]) @parameterized.expand( [("no_min_zoom", None, 1.1, "bilinear", TypeError), ("invalid_order", 0.9, 1.1, "s", ValueError)] ) def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises): key = "img" - with self.assertRaises(raises): - random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) - random_zoom({key: self.imt[0]}) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) + random_zoom({key: p(self.imt[0])}) def test_auto_expand_3d(self): random_zoom = RandZoomd( - keys="img", - prob=1.0, - min_zoom=[0.8, 0.7], - max_zoom=[1.2, 1.3], - mode="nearest", - keep_size=False, + keys="img", prob=1.0, min_zoom=[0.8, 0.7], max_zoom=[1.2, 1.3], mode="nearest", keep_size=False ) - random_zoom.set_random_state(1234) - test_data = {"img": np.random.randint(0, 2, size=[2, 2, 3, 4])} - zoomed = random_zoom(test_data) - np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) - np.testing.assert_allclose(zoomed["img"].shape, (2, 2, 3, 3)) + for p in TEST_NDARRAYS: + random_zoom.set_random_state(1234) + test_data = {"img": p(np.random.randint(0, 2, size=[2, 2, 3, 4]))} + zoomed = random_zoom(test_data) + assert_allclose(random_zoom.rand_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) + assert_allclose(zoomed["img"].shape, (2, 2, 3, 3)) if __name__ == "__main__": diff --git a/tests/test_randomizable.py b/tests/test_randomizable.py index 9972bded0f..7445287a12 100644 --- a/tests/test_randomizable.py +++ b/tests/test_randomizable.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_randtorchvisiond.py b/tests/test_randtorchvisiond.py index d0485ce405..2e96d723ee 100644 --- a/tests/test_randtorchvisiond.py +++ b/tests/test_randtorchvisiond.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,19 +29,10 @@ {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])}, torch.tensor( [ - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - ], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + ] ), ] @@ -50,24 +41,9 @@ {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])}, torch.tensor( [ - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], ] ), ] diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index b864a64647..822a056879 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,25 +17,15 @@ from parameterized import parameterized from monai.losses import BendingEnergyLoss, GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss +from tests.utils import SkipIfBeforePyTorchVersion TEST_CASES = [ - [BendingEnergyLoss, {}, ["pred"]], - [ - LocalNormalizedCrossCorrelationLoss, - {"kernel_size": 7, "kernel_type": "rectangular"}, - ["pred", "target"], - ], - [ - LocalNormalizedCrossCorrelationLoss, - {"kernel_size": 5, "kernel_type": "triangular"}, - ["pred", "target"], - ], - [ - LocalNormalizedCrossCorrelationLoss, - {"kernel_size": 3, "kernel_type": "gaussian"}, - ["pred", "target"], - ], + [BendingEnergyLoss, {}, ["pred"], 3], + [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 7, "kernel_type": "rectangular"}, ["pred", "target"]], + [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 5, "kernel_type": "triangular"}, ["pred", "target"]], + [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 3, "kernel_type": "gaussian"}, ["pred", "target"]], [GlobalMutualInformationLoss, {"num_bins": 10}, ["pred", "target"]], + [GlobalMutualInformationLoss, {"kernel_type": "b-spline", "num_bins": 10}, ["pred", "target"]], ] @@ -51,7 +41,8 @@ def tearDown(self): torch.backends.cudnn.benchmark = True @parameterized.expand(TEST_CASES) - def test_convergence(self, loss_type, loss_args, forward_args): + @SkipIfBeforePyTorchVersion((1, 9)) + def test_convergence(self, loss_type, loss_args, forward_args, pred_channels=1): """ The goal of this test is to assess if the gradient of the loss function is correct by testing if we can train a one layer neural network @@ -69,11 +60,11 @@ def test_convergence(self, loss_type, loss_args, forward_args): # define a one layer model class OnelayerNet(nn.Module): def __init__(self): - super(OnelayerNet, self).__init__() + super().__init__() self.layer = nn.Sequential( nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), nn.ReLU(), - nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), + nn.Conv3d(in_channels=1, out_channels=pred_channels, kernel_size=3, padding=1), ) def forward(self, x): diff --git a/tests/test_regunet.py b/tests/test_regunet.py index 4dd968a1cf..e37ca49538 100644 --- a/tests/test_regunet.py +++ b/tests/test_regunet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_regunet_block.py b/tests/test_regunet_block.py index 9b96875432..3be02ea377 100644 --- a/tests/test_regunet_block.py +++ b/tests/test_regunet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py index ebbe6c730c..39b42cc4b0 100644 --- a/tests/test_remove_repeated_channel.py +++ b/tests/test_remove_repeated_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_remove_repeated_channeld.py b/tests/test_remove_repeated_channeld.py index 9d4812791e..9db66a6aa0 100644 --- a/tests/test_remove_repeated_channeld.py +++ b/tests/test_remove_repeated_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py index e246dd1212..3d74b6479c 100644 --- a/tests/test_repeat_channel.py +++ b/tests/test_repeat_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_repeat_channeld.py b/tests/test_repeat_channeld.py index 3b73962bb9..a348f3eea9 100644 --- a/tests/test_repeat_channeld.py +++ b/tests/test_repeat_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_require_pkg.py b/tests/test_require_pkg.py new file mode 100644 index 0000000000..dbe63fff3f --- /dev/null +++ b/tests/test_require_pkg.py @@ -0,0 +1,77 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.utils import OptionalImportError, min_version, require_pkg + + +class TestRequirePkg(unittest.TestCase): + def test_class(self): + @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version) + class TestClass: + pass + + TestClass() + + def test_function(self): + @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version) + def test_func(x): + return x + + test_func(x=None) + + def test_warning(self): + @require_pkg(pkg_name="test123", raise_error=False) + def test_func(x): + return x + + test_func(x=None) + + def test_class_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="test123") + class TestClass: + pass + + TestClass() + + def test_class_version_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="torch", version="10000", version_checker=min_version) + class TestClass: + pass + + TestClass() + + def test_func_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="test123") + def test_func(x): + return x + + test_func(x=None) + + def test_func_versions_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="torch", version="10000", version_checker=min_version) + def test_func(x): + return x + + test_func(x=None) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_resample_datalist.py b/tests/test_resample_datalist.py new file mode 100644 index 0000000000..fa120b261e --- /dev/null +++ b/tests/test_resample_datalist.py @@ -0,0 +1,40 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data import resample_datalist + +TEST_CASE_1 = [ + {"data": [1, 2, 3, 4, 5], "factor": 2.5, "random_pick": True, "seed": 123}, + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 2, 4, 5], +] + +TEST_CASE_2 = [ + {"data": [1, 2, 3, 4, 5], "factor": 2.5, "random_pick": False, "seed": 0}, + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3], +] + +TEST_CASE_3 = [{"data": [1, 2, 3, 4, 5], "factor": 0.6, "random_pick": True, "seed": 123}, [2, 4, 5]] + + +class TestResampleDatalist(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_value_shape(self, input_param, expected): + result = resample_datalist(**input_param) + np.testing.assert_allclose(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_resampler.py b/tests/test_resampler.py index 2be94acebd..7dfb86a7a9 100644 --- a/tests/test_resampler.py +++ b/tests/test_resampler.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,69 +17,146 @@ from monai.transforms import Resample from monai.transforms.utils import create_grid +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((2, 2)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="border", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]]), - ], - [ - dict(padding_mode="reflection", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2)), "mode": "nearest"}, - np.array([[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"}, - np.array( - [ +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + dict(padding_mode="zeros", device=device), + {"grid": p(create_grid((2, 2))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q(np.array([[[0.0, 1.0], [2.0, 3.0]]])), ] - ] - ), - ], - [ - dict(padding_mode="border", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"}, - np.array( - [ + ) + TESTS.append( [ - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], + dict(padding_mode="zeros", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q( + np.array( + [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]] + ) + ), ] - ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="border", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q( + np.array( + [[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="reflection", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2))), "mode": "nearest"}, + q( + np.array( + [[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + { + "grid": p(create_grid((4, 4, 4))), + "img": q(np.arange(8).reshape((1, 2, 2, 2))), + "mode": "bilinear", + }, + q( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="border", device=device), + { + "grid": p(create_grid((4, 4, 4))), + "img": q(np.arange(8).reshape((1, 2, 2, 2))), + "mode": "bilinear", + }, + q( + np.array( + [ + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [2.0, 2.0, 3.0, 3.0], + [2.0, 2.0, 3.0, 3.0], + ], + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [2.0, 2.0, 3.0, 3.0], + [2.0, 2.0, 3.0, 3.0], + ], + [ + [4.0, 4.0, 5.0, 5.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + [6.0, 6.0, 7.0, 7.0], + ], + [ + [4.0, 4.0, 5.0, 5.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + [6.0, 6.0, 7.0, 7.0], + ], + ] + ] + ) + ), + ] + ) class TestResample(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_resample(self, input_param, input_data, expected_val): g = Resample(**input_param) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data["device"]) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_resize.py b/tests/test_resize.py index e5ec5dd1a9..06246b2358 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import Resize -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_0 = [{"spatial_size": 15}, (6, 10, 15)] @@ -45,16 +45,17 @@ def test_correct_results(self, spatial_size, mode): _order = 1 if spatial_size == (32, -1): spatial_size = (32, 64) - expected = [] - for channel in self.imt[0]: - expected.append( - skimage.transform.resize( - channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False - ) + expected = [ + skimage.transform.resize( + channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False ) + for channel in self.imt[0] + ] + expected = np.stack(expected).astype(np.float32) - out = resize(self.imt[0]) - np.testing.assert_allclose(out, expected, atol=0.9) + for p in TEST_NDARRAYS: + out = resize(p(self.imt[0])) + assert_allclose(out, expected, type_test=False, atol=0.9) @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_longest_shape(self, input_param, expected_shape): diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 46f1fc86cc..f81e1d4b08 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,47 +12,38 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ResizeWithPadOrCrop +from tests.utils import TEST_NDARRAYS TEST_CASES = [ - [ - {"spatial_size": [15, 8, 8], "mode": "constant"}, - (3, 8, 8, 4), - (3, 15, 8, 8), - ], + [{"spatial_size": [15, 8, 8], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 8, 8)], [ {"spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, (3, 8, 8, 4), (3, 15, 4, 8), ], - [ - {"spatial_size": [15, 4, -1], "mode": "constant"}, - (3, 8, 8, 4), - (3, 15, 4, 4), - ], - [ - {"spatial_size": [15, 4, -1], "mode": "reflect"}, - (3, 8, 8, 4), - (3, 15, 4, 4), - ], - [ - {"spatial_size": [-1, -1, -1], "mode": "reflect"}, - (3, 8, 8, 4), - (3, 8, 8, 4), - ], + [{"spatial_size": [15, 4, -1], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 4, 4)], + [{"spatial_size": [15, 4, -1], "mode": "reflect"}, (3, 8, 8, 4), (3, 15, 4, 4)], + [{"spatial_size": [-1, -1, -1], "mode": "reflect"}, (3, 8, 8, 4), (3, 8, 8, 4)], ] class TestResizeWithPadOrCrop(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_shape, expected_shape): - paddcroper = ResizeWithPadOrCrop(**input_param) - result = paddcroper(np.zeros(input_shape)) - np.testing.assert_allclose(result.shape, expected_shape) - result = paddcroper(np.zeros(input_shape), mode="constant") - np.testing.assert_allclose(result.shape, expected_shape) + for p in TEST_NDARRAYS: + if isinstance(p(0), torch.Tensor) and ( + "constant_values" in input_param or input_param["mode"] == "reflect" + ): + continue + paddcroper = ResizeWithPadOrCrop(**input_param) + result = paddcroper(p(np.zeros(input_shape))) + np.testing.assert_allclose(result.shape, expected_shape) + result = paddcroper(p(np.zeros(input_shape)), mode="constant") + np.testing.assert_allclose(result.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 32a62a9e16..28993a2bf4 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,45 +12,37 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ResizeWithPadOrCropd +from tests.utils import TEST_NDARRAYS TEST_CASES = [ - [ - {"keys": "img", "spatial_size": [15, 8, 8], "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - (3, 15, 8, 8), - ], + [{"keys": "img", "spatial_size": [15, 8, 8], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 8, 8)], [ {"keys": "img", "spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 8), ], - [ - {"keys": "img", "spatial_size": [15, 4, -1], "mode": "constant"}, - {"img": np.zeros((3, 8, 8, 4))}, - (3, 15, 4, 4), - ], - [ - {"keys": "img", "spatial_size": [15, 4, -1], "mode": "reflect"}, - {"img": np.zeros((3, 8, 8, 4))}, - (3, 15, 4, 4), - ], - [ - {"keys": "img", "spatial_size": [-1, -1, -1], "mode": "reflect"}, - {"img": np.zeros((3, 8, 8, 4))}, - (3, 8, 8, 4), - ], + [{"keys": "img", "spatial_size": [15, 4, -1], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)], + [{"keys": "img", "spatial_size": [15, 4, -1], "mode": "reflect"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)], + [{"keys": "img", "spatial_size": [-1, -1, -1], "mode": "reflect"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 8, 8, 4)], ] class TestResizeWithPadOrCropd(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_data, expected_val): - paddcroper = ResizeWithPadOrCropd(**input_param) - result = paddcroper(input_data) - np.testing.assert_allclose(result["img"].shape, expected_val) + for p in TEST_NDARRAYS: + if isinstance(p(0), torch.Tensor) and ( + "constant_values" in input_param or input_param["mode"] == "reflect" + ): + continue + paddcroper = ResizeWithPadOrCropd(**input_param) + input_data["img"] = p(input_data["img"]) + result = paddcroper(input_data) + np.testing.assert_allclose(result["img"].shape, expected_val) if __name__ == "__main__": diff --git a/tests/test_resized.py b/tests/test_resized.py index 930faf00eb..d7374ea930 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import Resized -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 10, 15)] @@ -48,16 +48,17 @@ def test_correct_results(self, spatial_size, mode): _order = 1 if spatial_size == (32, -1): spatial_size = (32, 64) - expected = [] - for channel in self.imt[0]: - expected.append( - skimage.transform.resize( - channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False - ) + expected = [ + skimage.transform.resize( + channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False ) + for channel in self.imt[0] + ] + expected = np.stack(expected).astype(np.float32) - out = resize({"img": self.imt[0]})["img"] - np.testing.assert_allclose(out, expected, atol=0.9) + for p in TEST_NDARRAYS: + out = resize({"img": p(self.imt[0])})["img"] + assert_allclose(out, expected, type_test=False, atol=0.9) @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_longest_shape(self, input_param, expected_shape): diff --git a/tests/test_resnet.py b/tests/test_resnet.py index c4ba5c2e16..ffb48125a4 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -42,14 +42,32 @@ (2, 3), ] +TEST_CASE_2_A = [ # 2D, batch 2, 1 input channel, shortcut type A + {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3, "shortcut_type": "A"}, + (2, 1, 32, 64), + (2, 3), +] + TEST_CASE_3 = [ # 1D, batch 1, 2 input channels {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3}, (1, 2, 32), (1, 3), ] +TEST_CASE_3_A = [ # 1D, batch 1, 2 input channels + {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3, "shortcut_type": "A"}, + (1, 2, 32), + (1, 3), +] + +TEST_CASE_4 = [ # 2D, batch 2, 1 input channel + {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3, "feed_forward": False}, + (2, 1, 32, 64), + ((2, 512), (2, 2048)), +] + TEST_CASES = [] -for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]: +for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case]) @@ -64,7 +82,10 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape): net = model(**input_param).to(device) with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) + if input_param.get("feed_forward", True): + self.assertEqual(result.shape, expected_shape) + else: + self.assertTrue(result.shape in expected_shape) @parameterized.expand(TEST_SCRIPT_CASES) def test_script(self, model, input_param, input_shape, expected_shape): diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 436c952d4b..42947e7f72 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,42 +10,44 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import Rotate -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D - -TEST_CASES_2D = [ - (np.pi / 6, False, "bilinear", "border", False), - (np.pi / 4, True, "bilinear", "border", False), - (-np.pi / 4.5, True, "nearest", "reflection", False), - (np.pi, False, "nearest", "zeros", False), - (-np.pi / 2, False, "bilinear", "zeros", True), -] - -TEST_CASES_3D = [ - (-np.pi / 2, True, "nearest", "border", False), - (np.pi / 4, True, "bilinear", "border", False), - (-np.pi / 4.5, True, "nearest", "reflection", False), - (np.pi, False, "nearest", "zeros", False), - (-np.pi / 2, False, "bilinear", "zeros", False), -] - -TEST_CASES_SHAPE_3D = [ - ([-np.pi / 2, 1.0, 2.0], "nearest", "border", False), - ([np.pi / 4, 0, 0], "bilinear", "border", False), - ([-np.pi / 4.5, -20, 20], "nearest", "reflection", False), -] +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D + +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", False)) + TEST_CASES_2D.append((p, -np.pi / 2, False, "bilinear", "zeros", True)) + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append((p, -np.pi / 2, True, "nearest", "border", False)) + TEST_CASES_3D.append((p, np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_3D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_3D.append((p, np.pi, False, "nearest", "zeros", False)) + TEST_CASES_3D.append((p, -np.pi / 2, False, "bilinear", "zeros", False)) + +TEST_CASES_SHAPE_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_SHAPE_3D.append((p, [-np.pi / 2, 1.0, 2.0], "nearest", "border", False)) + TEST_CASES_SHAPE_3D.append((p, [np.pi / 4, 0, 0], "bilinear", "border", False)) + TEST_CASES_SHAPE_3D.append((p, [-np.pi / 4.5, -20, 20], "nearest", "reflection", False)) class TestRotate2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotate(angle, keep_size, mode, padding_mode, align_corners) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -60,25 +62,20 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne for channel in self.imt[0]: expected.append( scipy.ndimage.rotate( - channel, - -np.rad2deg(angle), - (0, 1), - not keep_size, - order=_order, - mode=_mode, - prefilter=False, + channel, -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRotate3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotate([angle, 0, 0], keep_size, mode, padding_mode, align_corners) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -93,33 +90,29 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne for channel in self.imt[0]: expected.append( scipy.ndimage.rotate( - channel, - -np.rad2deg(angle), - (1, 2), - not keep_size, - order=_order, - mode=_mode, - prefilter=False, + channel, -np.rad2deg(angle), (1, 2), not keep_size, order=_order, mode=_mode, prefilter=False ) ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated n_good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(expected.size - n_good, 5, "diff at most 5 pixels") @parameterized.expand(TEST_CASES_SHAPE_3D) - def test_correct_shape(self, angle, mode, padding_mode, align_corners): + def test_correct_shape(self, im_type, angle, mode, padding_mode, align_corners): rotate_fn = Rotate(angle, True, align_corners=align_corners) - rotated = rotate_fn(self.imt[0], mode=mode, padding_mode=padding_mode) + rotated = rotate_fn(im_type(self.imt[0]), mode=mode, padding_mode=padding_mode) np.testing.assert_allclose(self.imt[0].shape, rotated.shape) def test_ill_case(self): - rotate_fn = Rotate(10, True) - with self.assertRaises(ValueError): # wrong shape - rotate_fn(self.imt) - - rotate_fn = Rotate(10, keep_size=False) - with self.assertRaises(ValueError): # wrong mode - rotate_fn(self.imt[0], mode="trilinear") + for p in TEST_NDARRAYS: + rotate_fn = Rotate(10, True) + with self.assertRaises(ValueError): # wrong shape + rotate_fn(p(self.imt)) + + rotate_fn = Rotate(10, keep_size=False) + with self.assertRaises(ValueError): # wrong mode + rotate_fn(p(self.imt[0]), mode="trilinear") if __name__ == "__main__": diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index 4ab39d5cf6..9865120688 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,45 +14,41 @@ import numpy as np from monai.transforms import Rotate90 -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRotate90(NumpyImageTestCase2D): def test_rotate90_default(self): rotate = Rotate90() - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_k(self): rotate = Rotate90(k=2) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, -1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8) if __name__ == "__main__": diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py index 3d71ead82a..ef4bad9419 100644 --- a/tests/test_rotate90d.py +++ b/tests/test_rotate90d.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,49 +14,45 @@ import numpy as np from monai.transforms import Rotate90d -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRotate90d(NumpyImageTestCase2D): def test_rotate90_default(self): key = "test" rotate = Rotate90d(keys=key) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_k(self): key = None rotate = Rotate90d(keys=key, k=2) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_spatial_axes(self): key = "test" rotate = Rotate90d(keys=key, spatial_axes=(0, 1)) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_prob_k_spatial_axes(self): key = "test" rotate = Rotate90d(keys=key, k=2, spatial_axes=(0, 1)) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] + expected = np.stack(expected) + assert_allclose(rotated[key], p(expected)) def test_no_key(self): key = "unknown" diff --git a/tests/test_rotated.py b/tests/test_rotated.py index 2ea421101b..1b759cfef5 100644 --- a/tests/test_rotated.py +++ b/tests/test_rotated.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,36 +10,38 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import Rotated -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D -TEST_CASES_2D = [ - (-np.pi / 6, False, "bilinear", "border", False), - (-np.pi / 4, True, "bilinear", "border", False), - (np.pi / 4.5, True, "nearest", "reflection", False), - (-np.pi, False, "nearest", "zeros", False), - (np.pi / 2, False, "bilinear", "zeros", True), -] +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, -np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_2D.append((p, -np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_2D.append((p, -np.pi, False, "nearest", "zeros", False)) + TEST_CASES_2D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) -TEST_CASES_3D = [ - (-np.pi / 6, False, "bilinear", "border", False), - (-np.pi / 4, True, "bilinear", "border", False), - (np.pi / 4.5, True, "nearest", "reflection", False), - (-np.pi, False, "nearest", "zeros", False), - (np.pi / 2, False, "bilinear", "zeros", True), -] +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append((p, -np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_3D.append((p, -np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_3D.append((p, np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_3D.append((p, -np.pi, False, "nearest", "zeros", False)) + TEST_CASES_3D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) class TestRotated2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -52,6 +54,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") @@ -64,9 +68,9 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne class TestRotated3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), [0, angle, 0], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -79,6 +83,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected.astype(np.float32), rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels.") @@ -86,14 +92,14 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne self.segn[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=0, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(int) - self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 130) + self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 160) class TestRotated3DXY(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), [0, 0, angle], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -106,6 +112,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels") @@ -113,7 +121,7 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne self.segn[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=0, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(int) - self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 130) + self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 160) if __name__ == "__main__": diff --git a/tests/test_saliency_inferer.py b/tests/test_saliency_inferer.py index 416b7170ae..c97bcb7811 100644 --- a/tests/test_saliency_inferer.py +++ b/tests/test_saliency_inferer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_distributed_sampler.py b/tests/test_sampler_dist.py similarity index 97% rename from tests/test_distributed_sampler.py rename to tests/test_sampler_dist.py index 0a439874bd..8b140f3ff8 100644 --- a/tests/test_distributed_sampler.py +++ b/tests/test_sampler_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_save_classificationd.py b/tests/test_save_classificationd.py index 67dc0320a6..c5a8b4705d 100644 --- a/tests/test_save_classificationd.py +++ b/tests/test_save_classificationd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,6 +13,7 @@ import os import tempfile import unittest +from pathlib import Path import numpy as np import torch @@ -39,7 +40,7 @@ def test_saved_content(self): }, ] - saver = CSVSaver(output_dir=tempdir, filename="predictions2.csv", overwrite=False, flush=False) + saver = CSVSaver(output_dir=Path(tempdir), filename="predictions2.csv", overwrite=False, flush=False) # set up test transforms post_trans = Compose( [ @@ -49,7 +50,7 @@ def test_saved_content(self): keys="pred", saver=None, meta_keys=None, - output_dir=tempdir, + output_dir=Path(tempdir), filename="predictions1.csv", overwrite=True, ), @@ -83,7 +84,7 @@ def test_saved_content(self): def _test_file(filename, count): filepath = os.path.join(tempdir, filename) self.assertTrue(os.path.exists(filepath)) - with open(filepath, "r") as f: + with open(filepath) as f: reader = csv.reader(f) i = 0 for row in reader: diff --git a/tests/test_save_image.py b/tests/test_save_image.py index f7c8e07f06..d3671cf830 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,19 +18,9 @@ from monai.transforms import SaveImage -TEST_CASE_1 = [ - torch.randint(0, 255, (1, 2, 3, 4)), - {"filename_or_obj": "testfile0.nii.gz"}, - ".nii.gz", - False, -] - -TEST_CASE_2 = [ - torch.randint(0, 255, (1, 2, 3, 4)), - None, - ".nii.gz", - False, -] +TEST_CASE_1 = [torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nii.gz"}, ".nii.gz", False] + +TEST_CASE_2 = [torch.randint(0, 255, (1, 2, 3, 4)), None, ".nii.gz", False] class TestSaveImage(unittest.TestCase): diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py index 35bbea9628..d84e582621 100644 --- a/tests/test_save_imaged.py +++ b/tests/test_save_imaged.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,10 +19,7 @@ from monai.transforms import SaveImaged TEST_CASE_1 = [ - { - "img": torch.randint(0, 255, (1, 2, 3, 4)), - "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"}, - }, + {"img": torch.randint(0, 255, (1, 2, 3, 4)), "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"}}, ".nii.gz", False, ] diff --git a/tests/test_savitzky_golay_filter.py b/tests/test_savitzky_golay_filter.py index c9bcd9687e..b410a641ea 100644 --- a/tests/test_savitzky_golay_filter.py +++ b/tests/test_savitzky_golay_filter.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -99,13 +99,7 @@ class TestSavitzkyGolayCPU(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE, - TEST_CASE_1D, - TEST_CASE_2D_AXIS_2, - TEST_CASE_2D_AXIS_3, - TEST_CASE_SINE_SMOOTH, - ] + [TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH] ) def test_value(self, arguments, image, expected_data, atol): result = SavitzkyGolayFilter(**arguments)(image) @@ -124,13 +118,7 @@ def test_value(self, arguments, image, expected_data, atol): @skip_if_no_cuda class TestSavitzkyGolayGPU(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE, - TEST_CASE_1D, - TEST_CASE_2D_AXIS_2, - TEST_CASE_2D_AXIS_3, - TEST_CASE_SINE_SMOOTH, - ] + [TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH] ) def test_value(self, arguments, image, expected_data, atol): result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) @@ -140,12 +128,7 @@ def test_value(self, arguments, image, expected_data, atol): @skip_if_no_cuda class TestSavitzkyGolayGPUREP(unittest.TestCase): @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE_REP, - TEST_CASE_1D_REP, - TEST_CASE_2D_AXIS_2_REP, - TEST_CASE_2D_AXIS_3_REP, - ] + [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] ) def test_value(self, arguments, image, expected_data, atol): result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 45d0ea3e4d..ac42cf806e 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -1,4 +1,4 @@ -# Copyright 2020 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,14 +25,14 @@ np.expand_dims(np.array([1.0]), 0), # Input data: Single value np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) - 1e-15, # absolute tolerance + 1e-5, # absolute tolerance ] TEST_CASE_2D_AXIS_2 = [ {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) np.expand_dims(np.ones((2, 3)), 0), np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), - 1e-15, # absolute tolerance + 1e-5, # absolute tolerance ] # Replicated-padding trivial tests @@ -42,7 +42,7 @@ np.expand_dims(np.array([1.0]), 0), # Input data: Single value np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) - 1e-15, # absolute tolerance + 1e-5, # absolute tolerance ] # Sine smoothing @@ -59,19 +59,13 @@ class TestSavitzkyGolaySmooth(unittest.TestCase): - @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH, TEST_CASE_SINGLE_VALUE_REP] + ) def test_value(self, arguments, image, expected_data, atol): for p in TEST_NDARRAYS: - result = SavitzkyGolaySmooth(**arguments)(p(image)) - torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol) - - -class TestSavitzkyGolaySmoothREP(unittest.TestCase): - @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) - def test_value(self, arguments, image, expected_data, atol): - for p in TEST_NDARRAYS: - result = SavitzkyGolaySmooth(**arguments)(p(image)) - torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol) + result = SavitzkyGolaySmooth(**arguments)(p(image.astype(np.float32))) + torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-4, atol=atol) if __name__ == "__main__": diff --git a/tests/test_savitzky_golay_smoothd.py b/tests/test_savitzky_golay_smoothd.py new file mode 100644 index 0000000000..6f0b33f533 --- /dev/null +++ b/tests/test_savitzky_golay_smoothd.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import SavitzkyGolaySmoothd +from tests.utils import TEST_NDARRAYS + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"keys": "img", "window_length": 3, "order": 1}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-5, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"keys": "img", "window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) + np.expand_dims(np.ones((2, 3)), 0), + np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), + 1e-5, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"keys": "img", "window_length": 3, "order": 1, "mode": "replicate"}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-5, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"keys": "img", "window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), + # Should be smoothed out to zeros + np.expand_dims(np.zeros(100), 0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolaySmoothd(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH, TEST_CASE_SINGLE_VALUE_REP] + ) + def test_value(self, arguments, image, expected_data, atol): + for p in TEST_NDARRAYS: + result = SavitzkyGolaySmoothd(**arguments)({"img": p(image.astype(np.float32))})["img"] + torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-4, atol=atol) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index c2485af616..b81351ed2e 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import numpy as np from monai.transforms import ScaleIntensity +from monai.transforms.utils import rescale_array from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose @@ -26,14 +27,45 @@ def test_range_scale(self): maxa = self.imt.max() norm = (self.imt - mina) / (maxa - mina) expected = p((norm * (2.0 - 1.0)) + 1.0) - assert_allclose(result, expected, rtol=1e-7, atol=0) + assert_allclose(result, expected, type_test=False, rtol=1e-7, atol=0) def test_factor_scale(self): for p in TEST_NDARRAYS: scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1) result = scaler(p(self.imt)) expected = p((self.imt * (1 + 0.1)).astype(np.float32)) - assert_allclose(result, expected, rtol=1e-7, atol=0) + assert_allclose(result, p(expected), rtol=1e-7, atol=0) + + def test_max_none(self): + for p in TEST_NDARRAYS: + scaler = ScaleIntensity(minv=0.0, maxv=None, factor=0.1) + result = scaler(p(self.imt)) + expected = rescale_array(p(self.imt), minv=0.0, maxv=None) + assert_allclose(result, expected, rtol=1e-3, atol=1e-3) + + def test_int(self): + """integers should be handled by converting them to floats first.""" + for p in TEST_NDARRAYS: + scaler = ScaleIntensity(minv=1.0, maxv=2.0) + result = scaler(p(self.imt.astype(int))) + _imt = self.imt.astype(int).astype(np.float32) + mina = _imt.min() + maxa = _imt.max() + norm = (_imt - mina) / (maxa - mina) + expected = p((norm * (2.0 - 1.0)) + 1.0) + assert_allclose(result, expected, type_test=False, rtol=1e-7, atol=0) + + def test_channel_wise(self): + for p in TEST_NDARRAYS: + scaler = ScaleIntensity(minv=1.0, maxv=2.0, channel_wise=True) + data = p(self.imt) + result = scaler(data) + mina = self.imt.min() + maxa = self.imt.max() + for i, c in enumerate(data): + norm = (c - mina) / (maxa - mina) + expected = p((norm * (2.0 - 1.0)) + 1.0) + assert_allclose(result[i], expected, type_test=False, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_scale_intensity_range.py b/tests/test_scale_intensity_range.py index cba07d9157..faddf9001b 100644 --- a/tests/test_scale_intensity_range.py +++ b/tests/test_scale_intensity_range.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,16 +14,25 @@ import numpy as np from monai.transforms import ScaleIntensityRange -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class IntensityScaleIntensityRange(NumpyImageTestCase2D): def test_image_scale_intensity_range(self): - scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80) - scaled = scaler(self.imt) - expected = (self.imt - 20) / 88 - expected = expected * 30 + 50 - self.assertTrue(np.allclose(scaled, expected)) + scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80, dtype=np.uint8) + for p in TEST_NDARRAYS: + scaled = scaler(p(self.imt)) + self.assertTrue(scaled.dtype, np.uint8) + expected = (((self.imt - 20) / 88) * 30 + 50).astype(np.uint8) + assert_allclose(scaled, p(expected)) + + def test_image_scale_intensity_range_none_clip(self): + scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=None, b_max=80, clip=True, dtype=np.uint8) + for p in TEST_NDARRAYS: + scaled = scaler(p(self.imt)) + self.assertTrue(scaled.dtype, np.uint8) + expected = (np.clip((self.imt - 20) / 88, None, 80)).astype(np.uint8) + assert_allclose(scaled, p(expected)) if __name__ == "__main__": diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 015162c8de..6ccfaf7ba6 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,12 +14,12 @@ import numpy as np from monai.transforms.intensity.array import ScaleIntensityRangePercentiles -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D): def test_scaling(self): - img = self.imt + img = self.imt[0] lower = 10 upper = 99 b_min = 0 @@ -27,13 +27,14 @@ def test_scaling(self): a_min = np.percentile(img, lower) a_max = np.percentile(img, upper) - expected = (img - a_min) / (a_max - a_min) - expected = (expected * (b_max - b_min)) + b_min - scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max) - self.assertTrue(np.allclose(expected, scaler(img))) + expected = (((img - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8) + scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8) + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(expected), rtol=1e-4) def test_relative_scaling(self): - img = self.imt + img = self.imt[0] lower = 10 upper = 99 b_min = 100 @@ -47,7 +48,16 @@ def test_relative_scaling(self): expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min) expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min - self.assertTrue(np.allclose(expected_img, scaler(img))) + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(expected_img), rtol=1e-3) + + scaler = ScaleIntensityRangePercentiles( + lower=lower, upper=upper, b_min=b_min, b_max=b_max, relative=True, clip=True + ) + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(np.clip(expected_img, expected_b_min, expected_b_max)), rtol=1e-4) def test_invalid_instantiation(self): self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=-10, upper=99, b_min=0, b_max=255) @@ -55,6 +65,26 @@ def test_invalid_instantiation(self): self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=-20, b_min=0, b_max=255) self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=900, b_min=0, b_max=255) + def test_channel_wise(self): + img = self.imt[0] + lower = 10 + upper = 99 + b_min = 0 + b_max = 255 + scaler = ScaleIntensityRangePercentiles( + lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8 + ) + expected = [] + for c in img: + a_min = np.percentile(c, lower) + a_max = np.percentile(c, upper) + expected.append(((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min) + expected = np.stack(expected).astype(np.uint8) + + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(expected), rtol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index 9d0fe8284a..cd0a5f8c35 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,14 +14,12 @@ import numpy as np from monai.transforms.intensity.dictionary import ScaleIntensityRangePercentilesd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestScaleIntensityRangePercentilesd(NumpyImageTestCase2D): def test_scaling(self): img = self.imt - data = {} - data["img"] = img lower = 10 upper = 99 b_min = 0 @@ -29,12 +27,14 @@ def test_scaling(self): a_min = np.percentile(img, lower) a_max = np.percentile(img, upper) - expected = (img - a_min) / (a_max - a_min) - expected = (expected * (b_max - b_min)) + b_min - - scaler = ScaleIntensityRangePercentilesd(keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max) + expected = (((img - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8) - self.assertTrue(np.allclose(expected, scaler(data)["img"])) + for p in TEST_NDARRAYS: + data = {"img": p(img)} + scaler = ScaleIntensityRangePercentilesd( + keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8 + ) + assert_allclose(p(expected), scaler(data)["img"], rtol=1e-4) def test_relative_scaling(self): img = self.imt @@ -55,7 +55,7 @@ def test_relative_scaling(self): expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min) expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min - self.assertTrue(np.allclose(expected_img, scaler(data)["img"])) + np.testing.assert_allclose(expected_img, scaler(data)["img"]) def test_invalid_instantiation(self): self.assertRaises( @@ -70,6 +70,29 @@ def test_invalid_instantiation(self): self.assertRaises( ValueError, ScaleIntensityRangePercentilesd, keys=["img"], lower=30, upper=1000, b_min=0, b_max=255 ) + with self.assertRaises(ValueError): + s = ScaleIntensityRangePercentilesd(keys=["img"], lower=30, upper=90, b_min=None, b_max=20, relative=True) + s(self.imt) + + def test_channel_wise(self): + img = self.imt + lower = 10 + upper = 99 + b_min = 0 + b_max = 255 + scaler = ScaleIntensityRangePercentilesd( + keys="img", lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8 + ) + expected = [] + for c in img: + a_min = np.percentile(c, lower) + a_max = np.percentile(c, upper) + expected.append((((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8)) + expected = np.stack(expected) + + for p in TEST_NDARRAYS: + data = {"img": p(img)} + assert_allclose(scaler(data)["img"], p(expected), rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_scale_intensity_ranged.py b/tests/test_scale_intensity_ranged.py index a8cac414e8..ffbd3e44c4 100644 --- a/tests/test_scale_intensity_ranged.py +++ b/tests/test_scale_intensity_ranged.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,20 +11,27 @@ import unittest -import numpy as np - from monai.transforms import ScaleIntensityRanged -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class IntensityScaleIntensityRanged(NumpyImageTestCase2D): def test_image_scale_intensity_ranged(self): key = "img" scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=50, b_max=80) - scaled = scaler({key: self.imt}) - expected = (self.imt - 20) / 88 - expected = expected * 30 + 50 - self.assertTrue(np.allclose(scaled[key], expected)) + for p in TEST_NDARRAYS: + scaled = scaler({key: p(self.imt)}) + expected = (self.imt - 20) / 88 + expected = expected * 30 + 50 + assert_allclose(scaled[key], p(expected)) + + def test_image_scale_intensity_ranged_none(self): + key = "img" + scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=None, b_max=None) + for p in TEST_NDARRAYS: + scaled = scaler({key: p(self.imt)}) + expected = (self.imt - 20) / 88 + assert_allclose(scaled[key], p(expected)) if __name__ == "__main__": diff --git a/tests/test_scale_intensityd.py b/tests/test_scale_intensityd.py index 6e13dbc272..42f1527490 100644 --- a/tests/test_scale_intensityd.py +++ b/tests/test_scale_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,23 +19,36 @@ class TestScaleIntensityd(NumpyImageTestCase2D): def test_range_scale(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0) result = scaler({key: p(self.imt)}) mina = np.min(self.imt) maxa = np.max(self.imt) norm = (self.imt - mina) / (maxa - mina) expected = (norm * (2.0 - 1.0)) + 1.0 - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) def test_factor_scale(self): + key = "img" for p in TEST_NDARRAYS: - key = "img" scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1) result = scaler({key: p(self.imt)}) expected = (self.imt * (1 + 0.1)).astype(np.float32) - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) + + def test_channel_wise(self): + key = "img" + for p in TEST_NDARRAYS: + scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0, channel_wise=True) + data = p(self.imt) + result = scaler({key: data}) + mina = self.imt.min() + maxa = self.imt.max() + for i, c in enumerate(data): + norm = (c - mina) / (maxa - mina) + expected = p((norm * (2.0 - 1.0)) + 1.0) + assert_allclose(result[key][i], expected, type_test=False, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_se_block.py b/tests/test_se_block.py index 1f515a7fb4..88983a7746 100644 --- a/tests/test_se_block.py +++ b/tests/test_se_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_se_blocks.py b/tests/test_se_blocks.py index e9aed7d9d9..400ee85e7f 100644 --- a/tests/test_se_blocks.py +++ b/tests/test_se_blocks.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_seg_loss_integration.py b/tests/test_seg_loss_integration.py index d2f991f160..f4a0f25267 100644 --- a/tests/test_seg_loss_integration.py +++ b/tests/test_seg_loss_integration.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -91,7 +91,7 @@ def test_convergence(self, loss_type, loss_args, forward_args): # define a one layer model class OnelayerNet(nn.Module): def __init__(self): - super(OnelayerNet, self).__init__() + super().__init__() self.layer_1 = nn.Linear(num_voxels, 200) self.acti = nn.ReLU() self.layer_2 = nn.Linear(200, num_voxels * num_classes) diff --git a/tests/test_segresnet.py b/tests/test_segresnet.py index ea6ca5b5dd..b7c37f87b9 100644 --- a/tests/test_segresnet.py +++ b/tests/test_segresnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -70,6 +70,7 @@ "init_filters": init_filters, "out_channels": out_channels, "upsample_mode": upsample_mode, + "act": ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), "input_image_size": ([16] * spatial_dims), "vae_estimate_std": vae_estimate_std, }, @@ -88,7 +89,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): SegResNet(spatial_dims=4) def test_script(self): diff --git a/tests/test_segresnet_block.py b/tests/test_segresnet_block.py index eb8cc9676b..9bb435ac1d 100644 --- a/tests/test_segresnet_block.py +++ b/tests/test_segresnet_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_select_cross_validation_folds.py b/tests/test_select_cross_validation_folds.py index 6dbd004e71..7693baca80 100644 --- a/tests/test_select_cross_validation_folds.py +++ b/tests/test_select_cross_validation_folds.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_select_itemsd.py b/tests/test_select_itemsd.py index bf63864eb0..ba75a27cff 100644 --- a/tests/test_select_itemsd.py +++ b/tests/test_select_itemsd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 3d561aac2f..407bee341c 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -28,11 +28,7 @@ for num_heads in [4, 6, 8, 12]: test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - }, + {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate}, (2, 512, hidden_size), (2, 512, hidden_size), ] diff --git a/tests/test_senet.py b/tests/test_senet.py index 1c6222d6a0..be7f571e0b 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest from typing import TYPE_CHECKING from unittest import skipUnless @@ -16,10 +17,11 @@ import torch from parameterized import parameterized +import monai.networks.nets.senet as se_mod from monai.networks import eval_mode from monai.networks.nets import SENet154, SEResNet50, SEResNet101, SEResNet152, SEResNext50, SEResNext101 from monai.utils import optional_import -from tests.utils import test_pretrained_networks, test_script_save +from tests.utils import test_is_quick, test_pretrained_networks, test_script_save if TYPE_CHECKING: import pretrainedmodels @@ -31,6 +33,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu" + NET_ARGS = {"spatial_dims": 3, "in_channels": 2, "num_classes": 2} TEST_CASE_1 = [SENet154, NET_ARGS] TEST_CASE_2 = [SEResNet50, NET_ARGS] @@ -60,6 +63,43 @@ def test_script(self, net, net_args): class TestPretrainedSENET(unittest.TestCase): + def setUp(self): + self.original_urls = se_mod.SE_NET_MODELS.copy() + if test_is_quick(): + testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + testing_data_urls = { + "senet154": { + "url": "https://drive.google.com/uc?id=1e10LFGVIV9L8_Q5Fhwi3X5nU6mDRrCDh", + "filename": "senet154-c7b49a05.pth", + }, + "se_resnet50": { + "url": "https://drive.google.com/uc?id=1WCeveS0tvjta4Wcp1wAGRi_uyXRfXAGA", + "filename": "se_resnet50-ce0d4300.pth", + }, + "se_resnet101": { + "url": "https://drive.google.com/uc?id=1Bh0PmLISUltsY8FevtlTbt6vT35clzWg", + "filename": "se_resnet101-7e38fcc6.pth", + }, + "se_resnet152": { + "url": "https://drive.google.com/uc?id=1fcqpP0ITOcALy_TZAcBdkyf7HcH687J-", + "filename": "se_resnet152-d17c99b7.pth", + }, + "se_resnext50_32x4d": { + "url": "https://drive.google.com/uc?id=1kRKW8YjGaEwYdQUyhoCIDg1H9ZAoJ-jI", + "filename": "se_resnext50_32x4d-a260b3a4.pth", + }, + "se_resnext101_32x4d": { + "url": "https://drive.google.com/uc?id=1Tg6Zim1lXgmYgH7FyTXAgihbkq5Jegni", + "filename": "se_resnext101_32x4d-3b2fe3d8.pth", + }, + } + for item in testing_data_urls: + testing_data_urls[item]["filename"] = os.path.join(testing_dir, testing_data_urls[item]["filename"]) + se_mod.SE_NET_MODELS = testing_data_urls + + def tearDown(self): + se_mod.SE_NET_MODELS = self.original_urls.copy() + @parameterized.expand([TEST_CASE_PRETRAINED_1]) def test_senet_shape(self, model, input_param): net = test_pretrained_networks(model, input_param, device) diff --git a/tests/test_separable_filter.py b/tests/test_separable_filter.py new file mode 100644 index 0000000000..167add9556 --- /dev/null +++ b/tests/test_separable_filter.py @@ -0,0 +1,85 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from monai.networks.layers import separable_filtering + + +class SeparableFilterTestCase(unittest.TestCase): + def test_1d(self): + a = torch.tensor([[list(range(10))]], dtype=torch.float) + out = separable_filtering(a, torch.tensor([-1, 0, 1])) + expected = np.array([[[1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, -8.0]]]) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + if torch.cuda.is_available(): + out = separable_filtering(a.cuda(), torch.tensor([-1, 0, 1]).cuda()) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + + def test_2d(self): + a = torch.tensor([[[list(range(7)), list(range(7, 0, -1)), list(range(7))]]], dtype=torch.float) + expected = np.array( + [ + [28.0, 28.0, 28.0, 28.0, 28.0, 28.0], + [30.0, 34.0, 38.0, 42.0, 46.0, 50.0], + [28.0, 28.0, 28.0, 28.0, 28.0, 28.0], + ] + ) + expected = expected[None][None] + out = separable_filtering(a, [torch.tensor([1, 1, 1]), torch.tensor([2, 2])]) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + if torch.cuda.is_available(): + out = separable_filtering(a.cuda(), [torch.tensor([1, 1, 1]).cuda(), torch.tensor([2, 2]).cuda()]) + np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4) + + def test_3d(self): + a = torch.tensor( + [[list(range(7)), list(range(7)), list(range(7))], [list(range(7)), list(range(7)), list(range(7))]], + dtype=torch.float, + ) + a = a[None][None] + a = a.expand(2, 3, -1, -1, -1) + expected = np.array( + [ + [ + [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0], + [6.0, 18.0, 36.0, 54.0, 72.0, 90.0, 66.0], + [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0], + ], + [ + [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0], + [6.0, 18.0, 36.0, 54.0, 72.0, 90.0, 66.0], + [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0], + ], + ] + ) + expected = expected + # testing shapes + k = torch.tensor([1, 1, 1]) + for kernel in (k, [k] * 3): + out = separable_filtering(a, kernel) + np.testing.assert_allclose(out.cpu().numpy()[1][2], expected, rtol=1e-4) + if torch.cuda.is_available(): + out = separable_filtering( + a.cuda(), kernel.cuda() if isinstance(kernel, torch.Tensor) else [k.cuda() for k in kernel] + ) + np.testing.assert_allclose(out.cpu().numpy()[0][1], expected, rtol=1e-4) + + def test_wrong_args(self): + with self.assertRaisesRegex(TypeError, ""): + separable_filtering(((1, 1, 1, 2, 3, 2)), torch.ones((2,))) # type: ignore + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_set_determinism.py b/tests/test_set_determinism.py index 537aa36676..7d6c54909d 100644 --- a/tests/test_set_determinism.py +++ b/tests/test_set_determinism.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -40,6 +40,8 @@ def test_values(self): self.assertEqual(seed, get_seed()) a = np.random.randint(seed) b = torch.randint(seed, (1,)) + # tset when global flag support is disabled + torch.backends.disable_global_flags() set_determinism(seed=seed) c = np.random.randint(seed) d = torch.randint(seed, (1,)) diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py index b6da879f4b..75cbd6fb0d 100644 --- a/tests/test_set_visible_devices.py +++ b/tests/test_set_visible_devices.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_shift_intensity.py b/tests/test_shift_intensity.py index b73c18b6a5..ecded268ab 100644 --- a/tests/test_shift_intensity.py +++ b/tests/test_shift_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_shift_intensityd.py b/tests/test_shift_intensityd.py index 0396857781..a37aaf4a6b 100644 --- a/tests/test_shift_intensityd.py +++ b/tests/test_shift_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,7 +24,7 @@ def test_value(self): shifter = ShiftIntensityd(keys=[key], offset=1.0) result = shifter({key: p(self.imt)}) expected = self.imt + 1.0 - assert_allclose(result[key], expected) + assert_allclose(result[key], p(expected)) def test_factor(self): key = "img" diff --git a/tests/test_simple_aspp.py b/tests/test_simple_aspp.py index fbc8cb37d1..9c952bd791 100644 --- a/tests/test_simple_aspp.py +++ b/tests/test_simple_aspp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_simulatedelay.py b/tests/test_simulatedelay.py index 3a4686218e..abf629f5c5 100644 --- a/tests/test_simulatedelay.py +++ b/tests/test_simulatedelay.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_simulatedelayd.py b/tests/test_simulatedelayd.py index 58bd3eb6b8..cbabb68e0f 100644 --- a/tests/test_simulatedelayd.py +++ b/tests/test_simulatedelayd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_skip_connection.py b/tests/test_skip_connection.py index 2118842ed0..f523891084 100644 --- a/tests/test_skip_connection.py +++ b/tests/test_skip_connection.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,11 +24,7 @@ result_shape = (input_shape[0] * 2, *input_shape[1:]) else: result_shape = input_shape - test_case = [ - {"dim": 0, "mode": type_1}, - input_shape, - result_shape, - ] + test_case = [{"dim": 0, "mode": type_1}, input_shape, result_shape] TEST_CASES_3D.append(test_case) diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index a22e5990bf..3d4ad3151b 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,14 +33,7 @@ [(1, 3, 16, 7), (80, 50), 7, 0.5, "gaussian", torch.device("cpu:0")], # 2D large overlap, gaussian [(1, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, "gaussian", torch.device("cpu:0")], # 3D small roi, gaussian [(3, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, "gaussian", torch.device("cpu:0")], # 3D small roi, gaussian - [ - (1, 3, 16, 15, 7), - (4, 10, 7), - 3, - 0.25, - "gaussian", - torch.device("cuda:0"), - ], # test inference on gpu if availabe + [(1, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, "gaussian", torch.device("cuda:0")], # test inference on gpu if availabe [(1, 3, 16, 15, 7), (4, 1, 7), 3, 0.25, "constant", torch.device("cpu:0")], # 3D small roi [(5, 3, 16, 15, 7), (4, 1, 7), 3, 0.25, "constant", torch.device("cpu:0")], # 3D small roi ] diff --git a/tests/test_smartcache_patch_wsi_dataset.py b/tests/test_smartcache_patch_wsi_dataset.py index c484e5fc69..73583b5eb1 100644 --- a/tests/test_smartcache_patch_wsi_dataset.py +++ b/tests/test_smartcache_patch_wsi_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,10 +21,12 @@ from monai.apps.utils import download_url from monai.utils import optional_import -_, has_cim = optional_import("cucim") +_cucim, has_cim = optional_import("cucim") +has_cim = has_cim and hasattr(_cucim, "CuImage") -FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" -FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + os.path.basename(FILE_URL)) +FILE_URL = "https://drive.google.com/uc?id=1sGTKZlJBIz53pfqTxoTqiIQzIoEzHLAe" +base_name, extension = FILE_URL.split("id=")[1], ".tiff" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension) TEST_CASE_0 = [ { @@ -43,6 +45,7 @@ "cache_num": 2, "num_init_workers": 1, "num_replace_workers": 1, + "copy_cache": False, }, [ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0]]])}, @@ -133,13 +136,7 @@ class TestSmartCachePatchWSIDataset(unittest.TestCase): def setUp(self): download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) @skipUnless(has_cim, "Requires CuCIM") def test_read_patches(self, input_parameters, expected): dataset = SmartCachePatchWSIDataset(**input_parameters) diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index e2675f4d8c..e7d51be63a 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -174,13 +174,7 @@ def test_datalist(self): data_list = [np.array([i]) for i in range(5)] data_list_backup = copy.copy(data_list) - SmartCacheDataset( - data=data_list, - transform=None, - cache_rate=0.5, - replace_rate=0.4, - shuffle=True, - ) + SmartCacheDataset(data=data_list, transform=None, cache_rate=0.5, replace_rate=0.4, shuffle=True) np.testing.assert_allclose(data_list, data_list_backup) diff --git a/tests/test_smooth_field.py b/tests/test_smooth_field.py new file mode 100644 index 0000000000..5849b96167 --- /dev/null +++ b/tests/test_smooth_field.py @@ -0,0 +1,143 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from itertools import product + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandSmoothDeformd, RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd +from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env + +_rtol = 5e-3 if is_tf32_env() else 1e-4 + +INPUT_SHAPES = ((1, 8, 8), (2, 8, 8), (1, 8, 8, 8)) + +TESTS_CONTRAST = [] +TESTS_INTENSITY = [] +TESTS_DEFORM = [] + +KEY = "test" + +for arr_type, shape in product(TEST_NDARRAYS, INPUT_SHAPES): + in_arr = arr_type(np.ones(shape, np.float32)) + exp_arr = arr_type(np.ones(shape, np.float32)) + rand_size = (4,) * (len(shape) - 1) + + device = torch.device("cpu") + + if isinstance(in_arr, torch.Tensor) and in_arr.get_device() >= 0: + device = torch.device(in_arr.get_device()) + + TESTS_CONTRAST.append( + ( + {"keys": (KEY,), "spatial_size": shape[1:], "rand_size": rand_size, "prob": 1.0, "device": device}, + {KEY: in_arr}, + {KEY: exp_arr}, + ) + ) + + TESTS_INTENSITY.append( + ( + { + "keys": (KEY,), + "spatial_size": shape[1:], + "rand_size": rand_size, + "prob": 1.0, + "device": device, + "gamma": (0.9, 1), + }, + {KEY: in_arr}, + {KEY: exp_arr}, + ) + ) + + TESTS_DEFORM.append( + ( + { + "keys": (KEY,), + "spatial_size": shape[1:], + "rand_size": rand_size, + "prob": 1.0, + "device": device, + "def_range": 0.1, + }, + {KEY: in_arr}, + {KEY: exp_arr}, + ) + ) + + +class TestSmoothField(unittest.TestCase): + @parameterized.expand(TESTS_CONTRAST) + def test_rand_smooth_field_adjust_contrastd(self, input_param, input_data, expected_val): + g = RandSmoothFieldAdjustContrastd(**input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + def test_rand_smooth_field_adjust_contrastd_pad(self): + input_param, input_data, expected_val = TESTS_CONTRAST[0] + + g = RandSmoothFieldAdjustContrastd(pad=1, **input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + @parameterized.expand(TESTS_INTENSITY) + def test_rand_smooth_field_adjust_intensityd(self, input_param, input_data, expected_val): + g = RandSmoothFieldAdjustIntensityd(**input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + def test_rand_smooth_field_adjust_intensityd_pad(self): + input_param, input_data, expected_val = TESTS_INTENSITY[0] + + g = RandSmoothFieldAdjustIntensityd(pad=1, **input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + @parameterized.expand(TESTS_DEFORM) + def test_rand_smooth_deformd(self, input_param, input_data, expected_val): + g = RandSmoothDeformd(**input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + def test_rand_smooth_deformd_pad(self): + input_param, input_data, expected_val = TESTS_DEFORM[0] + + g = RandSmoothDeformd(pad=1, **input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 6be6730c5a..ebff25712d 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,155 +12,204 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import Spacing from monai.utils import ensure_tuple, fall_back_tuple +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"pixdim": (1.0, 1.5), "padding_mode": "zeros", "dtype": float}, - np.arange(4).reshape((1, 2, 2)) + 1.0, # data - {"affine": np.eye(4)}, - np.array([[[1.0, 1.0], [3.0, 2.0]]]), - ], - [ - {"pixdim": 1.0, "padding_mode": "zeros", "dtype": float}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.eye(4)}, - np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), - ], - [ - {"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.eye(4)}, - np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), - ], - [ - {"pixdim": (1.0, 0.2, 1.5), "diagonal": False, "padding_mode": "zeros", "align_corners": True}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.array([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]])}, - np.array([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]), - ], - [ - {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, - np.arange(24).reshape((2, 3, 4)), # data - {"affine": np.diag([-3.0, 0.2, 1.5, 1])}, - np.array([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]), - ], - [ - {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, - np.arange(24).reshape((2, 3, 4)), # data - {}, - np.array([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]), - ], - [ - {"pixdim": (1.0, 1.0)}, - np.arange(24).reshape((2, 3, 4)), # data - {}, - np.array( - [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] - ), - ], - [ - {"pixdim": (4.0, 5.0, 6.0)}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]])}, - np.arange(24).reshape((1, 2, 3, 4)), # data - ], - [ - {"pixdim": (4.0, 5.0, 6.0), "diagonal": True}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, - np.array( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - ], - [ - {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, - np.array( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - ], - [ - {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, - np.array( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - ], - [ - {"pixdim": (1.9, 4.0), "padding_mode": "zeros", "diagonal": True}, - np.arange(24).reshape((1, 4, 6)), # data - {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, - np.array( - [ +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"pixdim": (1.0, 1.5), "padding_mode": "zeros", "dtype": float}, + np.arange(4).reshape((1, 2, 2)) + 1.0, # data + {"affine": np.eye(4)}, + np.array([[[1.0, 1.0], [3.0, 2.0]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": 1.0, "padding_mode": "zeros", "dtype": float}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.eye(4)}, + np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.eye(4)}, + np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.0, 0.2, 1.5), "diagonal": False, "padding_mode": "zeros", "align_corners": True}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.array([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]])}, + np.array([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, + np.arange(24).reshape((2, 3, 4)), # data + {"affine": np.diag([-3.0, 0.2, 1.5, 1])}, + np.array([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, + np.arange(24).reshape((2, 3, 4)), # data + {}, + np.array([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.0, 1.0)}, + np.arange(24).reshape((2, 3, 4)), # data + {}, + np.array( + [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0)}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]])}, + np.arange(24).reshape((1, 2, 3, 4)), # data + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0), "diagonal": True}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, + np.array( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, + np.array( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, + np.array( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.9, 4.0), "padding_mode": "zeros", "diagonal": True}, + np.arange(24).reshape((1, 4, 6)), # data + {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, + np.array( [ - [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], - [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], - [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], - [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], - [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], - [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], - [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0], + [ + [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], + [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], + [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], + [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], + [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], + [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], + [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0], + ] ] - ] - ), - ], - [ - {"pixdim": (5.0, 3.0), "padding_mode": "border", "diagonal": True, "dtype": np.float32}, - np.arange(24).reshape((1, 4, 6)), # data - {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, - np.array( - [ + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (5.0, 3.0), "padding_mode": "border", "diagonal": True, "dtype": np.float32}, + np.arange(24).reshape((1, 4, 6)), # data + {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, + np.array( [ - [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8], - [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3], - [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8], + [ + [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8], + [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3], + [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8], + ] ] - ] - ), - ], - [ - {"pixdim": (5.0, 3.0), "padding_mode": "zeros", "diagonal": True, "dtype": np.float32}, - np.arange(24).reshape((1, 4, 6)), # data - {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, - np.array( - [ + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (5.0, 3.0), "padding_mode": "zeros", "diagonal": True, "dtype": np.float32}, + np.arange(24).reshape((1, 4, 6)), # data + {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, + np.array( [ - [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000], - [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000], - [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000], + [ + [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000], + [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000], + [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000], + ] ] - ] - ), - ], - [ - {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.eye(4)}, - np.array([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]), - ], -] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.eye(4)}, + np.array([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]), + ] + ) class TestSpacingCase(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_spacing(self, init_param, img, data_param, expected_output): - res = Spacing(**init_param)(img, **data_param) - if not isinstance(res, tuple): - np.testing.assert_allclose(res, expected_output, atol=1e-6) - return - np.testing.assert_allclose(res[0], expected_output, atol=1e-6) - sr = len(res[0].shape) - 1 + @parameterized.expand(TESTS) + def test_spacing(self, in_type, init_param, img, data_param, expected_output): + _img = in_type(img) + output_data, _, new_affine = Spacing(**init_param)(_img, **data_param) + if isinstance(_img, torch.Tensor): + self.assertEqual(_img.device, output_data.device) + output_data = output_data.cpu() + + np.testing.assert_allclose(output_data, expected_output, atol=1e-3, rtol=1e-3) + sr = len(output_data.shape) - 1 if isinstance(init_param["pixdim"], float): init_param["pixdim"] = [init_param["pixdim"]] * sr init_pixdim = ensure_tuple(init_param["pixdim"]) init_pixdim = init_param["pixdim"][:sr] - norm = np.sqrt(np.sum(np.square(res[2]), axis=0))[:sr] + norm = np.sqrt(np.sum(np.square(new_affine), axis=0))[:sr] np.testing.assert_allclose(fall_back_tuple(init_pixdim, norm), norm) diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index 61a4a4c38b..79e759082a 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,82 +10,88 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np +import torch +from parameterized import parameterized from monai.transforms import Spacingd +from tests.utils import TEST_NDARRAYS, assert_allclose - -class TestSpacingDCase(unittest.TestCase): - def test_spacingd_3d(self): - data = {"image": np.ones((2, 10, 15, 20)), "image_meta_dict": {"affine": np.eye(4)}} - spacing = Spacingd(keys="image", pixdim=(1, 2, 1.4)) - res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) - np.testing.assert_allclose(res["image"].shape, (2, 10, 8, 15)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag([1, 2, 1.4, 1.0])) - - def test_spacingd_2d(self): - data = {"image": np.ones((2, 10, 20)), "image_meta_dict": {"affine": np.eye(3)}} - spacing = Spacingd(keys="image", pixdim=(1, 2)) - res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) - np.testing.assert_allclose(res["image"].shape, (2, 10, 10)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 2, 1))) - - def test_spacingd_2d_no_metadata(self): - data = {"image": np.ones((2, 10, 20))} - spacing = Spacingd(keys="image", pixdim=(1, 2)) - res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) - np.testing.assert_allclose(res["image"].shape, (2, 10, 10)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 2, 1))) - - def test_interp_all(self): - data = { - "image": np.arange(20).reshape((2, 1, 10)), - "seg": np.ones((2, 1, 10)), - "image_meta_dict": {"affine": np.eye(4)}, - "seg_meta_dict": {"affine": np.eye(4)}, - } - spacing = Spacingd( - keys=("image", "seg"), - mode="nearest", - pixdim=( - 1, - 0.2, - ), +TESTS: List[Tuple] = [] +for p in TEST_NDARRAYS: + TESTS.append( + ( + "spacing 3d", + {"image": p(np.ones((2, 10, 15, 20))), "image_meta_dict": {"affine": p(np.eye(4))}}, + dict(keys="image", pixdim=(1, 2, 1.4)), + ("image", "image_meta_dict", "image_transforms"), + (2, 10, 8, 15), + p(np.diag([1, 2, 1.4, 1.0])), ) - res = spacing(data) - self.assertEqual( - ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), - tuple(sorted(res)), + ) + TESTS.append( + ( + "spacing 2d", + {"image": np.ones((2, 10, 20)), "image_meta_dict": {"affine": np.eye(3)}}, + dict(keys="image", pixdim=(1, 2)), + ("image", "image_meta_dict", "image_transforms"), + (2, 10, 10), + np.diag((1, 2, 1)), ) - np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) - - def test_interp_sep(self): - data = { - "image": np.ones((2, 1, 10)), - "seg": np.ones((2, 1, 10)), - "image_meta_dict": {"affine": np.eye(4)}, - "seg_meta_dict": {"affine": np.eye(4)}, - } - spacing = Spacingd( - keys=("image", "seg"), - mode=("bilinear", "nearest"), - pixdim=( - 1, - 0.2, - ), + ) + TESTS.append( + ( + "spacing 2d no metadata", + {"image": np.ones((2, 10, 20))}, + dict(keys="image", pixdim=(1, 2)), + ("image", "image_meta_dict", "image_transforms"), + (2, 10, 10), + np.diag((1, 2, 1)), + ) + ) + TESTS.append( + ( + "interp all", + { + "image": np.arange(20).reshape((2, 1, 10)), + "seg": np.ones((2, 1, 10)), + "image_meta_dict": {"affine": np.eye(4)}, + "seg_meta_dict": {"affine": np.eye(4)}, + }, + dict(keys=("image", "seg"), mode="nearest", pixdim=(1, 0.2)), + ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), + (2, 1, 46), + np.diag((1, 0.2, 1, 1)), ) - res = spacing(data) - self.assertEqual( + ) + TESTS.append( + ( + "interp sep", + { + "image": np.ones((2, 1, 10)), + "seg": np.ones((2, 1, 10)), + "image_meta_dict": {"affine": np.eye(4)}, + "seg_meta_dict": {"affine": np.eye(4)}, + }, + dict(keys=("image", "seg"), mode=("bilinear", "nearest"), pixdim=(1, 0.2)), ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), - tuple(sorted(res)), + (2, 1, 46), + np.diag((1, 0.2, 1, 1)), ) - np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) + ) + + +class TestSpacingDCase(unittest.TestCase): + @parameterized.expand(TESTS) + def test_spacingd(self, _, data, kw_args, expected_keys, expected_shape, expected_affine): + res = Spacingd(**kw_args)(data) + if isinstance(data["image"], torch.Tensor): + self.assertEqual(data["image"].device, res["image"].device) + self.assertEqual(expected_keys, tuple(sorted(res))) + np.testing.assert_allclose(res["image"].shape, expected_shape) + assert_allclose(res["image_meta_dict"]["affine"], expected_affine) if __name__ == "__main__": diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index c76915f0a3..bf1eb11491 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,54 +16,41 @@ from parameterized import parameterized from monai.transforms import SpatialCrop +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - {"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, - (3, 3, 3, 3), - (3, 2, 2, 2), - ], +TESTS = [ + [{"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, (3, 1, 1, 1), (3, 1, 1, 1)], [{"roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], [{"roi_start": [0, 0], "roi_end": [2, 2]}, (3, 3, 3, 3), (3, 2, 2, 3)], - [ - {"roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, - (3, 3, 3, 3), - (3, 2, 2, 2), - ], - [ - {"roi_start": [0, 0, 0, 0, 0], "roi_end": [8, 8, 8, 2, 2]}, - (3, 3, 3, 3), - (3, 3, 3, 3), - ], - [ - {"roi_start": [1, 0, 0], "roi_end": [1, 8, 8]}, - (3, 3, 3, 3), - (3, 0, 3, 3), - ], - [ - {"roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, - (3, 3, 3, 3), - (3, 1, 2, 2), - ], + [{"roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"roi_start": [0, 0, 0, 0, 0], "roi_end": [8, 8, 8, 2, 2]}, (3, 3, 3, 3), (3, 3, 3, 3)], + [{"roi_start": [1, 0, 0], "roi_end": [1, 8, 8]}, (3, 3, 3, 3), (3, 0, 3, 3)], + [{"roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, (3, 3, 3, 3), (3, 1, 2, 2)], ] -TEST_ERRORS = [ - [{"roi_slices": [slice(s, e, 2) for s, e in zip([-1, -2, 0], [None, None, 2])]}], -] +TEST_ERRORS = [[{"roi_slices": [slice(s, e, 2) for s, e in zip([-1, -2, 0], [None, None, 2])]}]] class TestSpatialCrop(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_shape, expected_shape): input_data = np.random.randint(0, 2, size=input_shape) - result = SpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) - - @parameterized.expand(TEST_CASES) - def test_tensor_shape(self, input_param, input_shape, expected_shape): - input_data = torch.randint(0, 2, size=input_shape, device="cuda" if torch.cuda.is_available() else "cpu") - result = SpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + input_param_mod = { + k: q(v) if k != "roi_slices" and q is not None else v for k, v in input_param.items() + } + im = p(input_data) + result = SpatialCrop(**input_param_mod)(im) + self.assertEqual(type(im), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im.device) + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + assert_allclose(results[0], results[-1], type_test=False) @parameterized.expand(TEST_ERRORS) def test_error(self, input_param): diff --git a/tests/test_spatial_cropd.py b/tests/test_spatial_cropd.py index 797c25d34b..5b16f460fd 100644 --- a/tests/test_spatial_cropd.py +++ b/tests/test_spatial_cropd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,38 +15,49 @@ from parameterized import parameterized from monai.transforms import SpatialCropd +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), - ], - [ - {"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), - ], - [ - {"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 3), - ], - [ - {"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), - ], - [ - {"keys": ["img"], "roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 1, 2, 2), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 3), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 1, 2, 2), + ] + ) class TestSpatialCropd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = SpatialCropd(**input_param)(input_data) self.assertTupleEqual(result["img"].shape, expected_shape) diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 86d010bbad..4cdeb6d64e 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,44 +17,39 @@ from parameterized import parameterized from monai.transforms import SpatialPad -from monai.utils.enums import NumpyPadMode +from monai.utils.enums import NumpyPadMode, PytorchPadMode from monai.utils.misc import set_determinism from tests.utils import TEST_NDARRAYS TESTS = [] -# Numpy modes -MODES: List = [ +MODES = [] + +# Test modes +NP_MODES: List = [ "constant", "edge", - "linear_ramp", - "maximum", - "mean", - "median", - "minimum", - "reflect", - "symmetric", + # `reflect` mode is not supported in some PyTorch versions, skip the test + # "reflect", "wrap", - "empty", ] -MODES += [NumpyPadMode(i) for i in MODES] +MODES += NP_MODES +MODES += [NumpyPadMode(i) for i in NP_MODES] + +PT_MODES: list = [ + "constant", + "replicate", + "circular", + # `reflect` mode is not supported in some PyTorch versions, skip the test + # "reflect", +] +MODES += PT_MODES +MODES += [PytorchPadMode(i) for i in PT_MODES] for mode in MODES: - TESTS.append( - [ - {"spatial_size": [50, 50], "method": "end", "mode": mode}, - (1, 2, 2), - (1, 50, 50), - ] - ) - - TESTS.append( - [ - {"spatial_size": [15, 4, -1], "method": "symmetric", "mode": mode}, - (3, 8, 8, 4), - (3, 15, 8, 4), - ] - ) + TESTS.append([{"spatial_size": [3, 4], "method": "end", "mode": mode}, (1, 2, 3), (1, 3, 4)]) + + TESTS.append([{"spatial_size": [15, 4, -1], "method": "symmetric", "mode": mode}, (3, 8, 8, 4), (3, 15, 8, 4)]) class TestSpatialPad(unittest.TestCase): @@ -86,14 +81,19 @@ def test_pad_shape(self, input_param, input_shape, expected_shape): torch.testing.assert_allclose(results[0], results[-1], atol=0, rtol=1e-5) def test_pad_kwargs(self): - padder = SpatialPad( - spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) - ) for p in TEST_NDARRAYS: - result = padder(p(np.zeros((3, 8, 4)))) - if isinstance(result, torch.Tensor): - result = result.cpu().numpy() - torch.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0) + input_data = p(np.zeros((3, 8, 4))) + if isinstance(input_data, torch.Tensor): + result = ( + SpatialPad(spatial_size=[15, 8], method="end", mode="constant", value=2)(img=input_data) + .cpu() + .numpy() + ) + else: + result = SpatialPad( + spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) + )(img=input_data) + torch.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0) torch.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1, rtol=1e-7, atol=0) diff --git a/tests/test_spatial_padd.py b/tests/test_spatial_padd.py index 8400bb82cc..762a1145f5 100644 --- a/tests/test_spatial_padd.py +++ b/tests/test_spatial_padd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_split_channel.py b/tests/test_split_channel.py index 38315a102c..75216227e4 100644 --- a/tests/test_split_channel.py +++ b/tests/test_split_channel.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_split_channeld.py b/tests/test_split_channeld.py index f1df24364d..7a34855676 100644 --- a/tests/test_split_channeld.py +++ b/tests/test_split_channeld.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -51,13 +51,7 @@ ] ) - TESTS.append( - [ - {"keys": "pred", "channel_dim": 1}, - {"pred": p(np.random.randint(2, size=(3, 2, 4)))}, - (3, 1, 4), - ] - ) + TESTS.append([{"keys": "pred", "channel_dim": 1}, {"pred": p(np.random.randint(2, size=(3, 2, 4)))}, (3, 1, 4)]) class TestSplitChanneld(unittest.TestCase): diff --git a/tests/test_split_on_grid.py b/tests/test_split_on_grid.py index a187835e7b..b1d4cd93c5 100644 --- a/tests/test_split_on_grid.py +++ b/tests/test_split_on_grid.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,11 +11,11 @@ import unittest -import numpy as np import torch from parameterized import parameterized from monai.apps.pathology.transforms import SplitOnGrid +from tests.utils import TEST_NDARRAYS, assert_allclose A11 = torch.randn(3, 2, 2) A12 = torch.randn(3, 2, 2) @@ -26,105 +26,58 @@ A2 = torch.cat([A21, A22], 2) A = torch.cat([A1, A2], 1) -TEST_CASE_0 = [ - {"grid_size": (2, 2)}, - A, - torch.stack([A11, A12, A21, A22]), +TEST_CASE_0 = [{"grid_size": (2, 2)}, A, torch.stack([A11, A12, A21, A22])] +TEST_CASE_1 = [{"grid_size": (2, 1)}, A, torch.stack([A1, A2])] +TEST_CASE_2 = [{"grid_size": (1, 2)}, A1, torch.stack([A11, A12])] +TEST_CASE_3 = [{"grid_size": (1, 2)}, A2, torch.stack([A21, A22])] +TEST_CASE_4 = [{"grid_size": (1, 1), "patch_size": (2, 2)}, A, torch.stack([A11])] +TEST_CASE_5 = [{"grid_size": 1, "patch_size": 4}, A, torch.stack([A])] +TEST_CASE_6 = [{"grid_size": 2, "patch_size": 2}, A, torch.stack([A11, A12, A21, A22])] +TEST_CASE_7 = [{"grid_size": 1}, A, torch.stack([A])] +TEST_CASE_8 = [ + {"grid_size": (2, 2), "patch_size": 2}, + torch.arange(12).reshape(1, 3, 4).to(torch.float32), + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), ] -TEST_CASE_1 = [ - {"grid_size": (2, 1)}, - A, - torch.stack([A1, A2]), -] - -TEST_CASE_2 = [ - {"grid_size": (1, 2)}, - A1, - torch.stack([A11, A12]), -] - -TEST_CASE_3 = [ - {"grid_size": (1, 2)}, - A2, - torch.stack([A21, A22]), -] - -TEST_CASE_4 = [ - {"grid_size": (1, 1), "patch_size": (2, 2)}, - A, - torch.stack([A11]), -] - -TEST_CASE_5 = [ - {"grid_size": 1, "patch_size": 4}, - A, - torch.stack([A]), -] - -TEST_CASE_6 = [ - {"grid_size": 2, "patch_size": 2}, - A, - torch.stack([A11, A12, A21, A22]), -] - -TEST_CASE_7 = [ - {"grid_size": 1}, - A, - torch.stack([A]), -] - -TEST_CASE_MC_0 = [ - {"grid_size": (2, 2)}, - [A, A], - [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], -] - - -TEST_CASE_MC_1 = [ - {"grid_size": (2, 1)}, - [A] * 5, - [torch.stack([A1, A2])] * 5, -] - - -TEST_CASE_MC_2 = [ - {"grid_size": (1, 2)}, - [A1, A2], - [torch.stack([A11, A12]), torch.stack([A21, A22])], -] +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) + +TEST_CASE_MC_0 = [{"grid_size": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]] +TEST_CASE_MC_1 = [{"grid_size": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5] +TEST_CASE_MC_2 = [{"grid_size": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]] + +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) class TestSplitOnGrid(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - ] - ) - def test_split_pathce_single_call(self, input_parameters, img, expected): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, image, expected): + input_image = in_type(image) splitter = SplitOnGrid(**input_parameters) - output = splitter(img) - np.testing.assert_equal(output.numpy(), expected.numpy()) + output = splitter(input_image) + assert_allclose(output, expected, type_test=False) - @parameterized.expand( - [ - TEST_CASE_MC_0, - TEST_CASE_MC_1, - TEST_CASE_MC_2, - ] - ) - def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list): + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): splitter = SplitOnGrid(**input_parameters) - for img, expected in zip(img_list, expected_list): - output = splitter(img) - np.testing.assert_equal(output.numpy(), expected.numpy()) + for image, expected in zip(img_list, expected_list): + input_image = in_type(image) + output = splitter(input_image) + assert_allclose(output, expected, type_test=False) if __name__ == "__main__": diff --git a/tests/test_split_on_grid_dict.py b/tests/test_split_on_grid_dict.py index 96ec095423..778a38da34 100644 --- a/tests/test_split_on_grid_dict.py +++ b/tests/test_split_on_grid_dict.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,11 +11,11 @@ import unittest -import numpy as np import torch from parameterized import parameterized from monai.apps.pathology.transforms import SplitOnGridDict +from tests.utils import TEST_NDARRAYS, assert_allclose A11 = torch.randn(3, 2, 2) A12 = torch.randn(3, 2, 2) @@ -26,105 +26,74 @@ A2 = torch.cat([A21, A22], 2) A = torch.cat([A1, A2], 1) -TEST_CASE_0 = [ - {"keys": "image", "grid_size": (2, 2)}, - {"image": A}, - torch.stack([A11, A12, A21, A22]), -] - -TEST_CASE_1 = [ - {"keys": "image", "grid_size": (2, 1)}, - {"image": A}, - torch.stack([A1, A2]), -] - -TEST_CASE_2 = [ - {"keys": "image", "grid_size": (1, 2)}, - {"image": A1}, - torch.stack([A11, A12]), -] - -TEST_CASE_3 = [ - {"keys": "image", "grid_size": (1, 2)}, - {"image": A2}, - torch.stack([A21, A22]), -] - -TEST_CASE_4 = [ - {"keys": "image", "grid_size": (1, 1), "patch_size": (2, 2)}, - {"image": A}, - torch.stack([A11]), -] - -TEST_CASE_5 = [ - {"keys": "image", "grid_size": 1, "patch_size": 4}, - {"image": A}, - torch.stack([A]), +TEST_CASE_0 = [{"keys": "image", "grid_size": (2, 2)}, {"image": A}, torch.stack([A11, A12, A21, A22])] +TEST_CASE_1 = [{"keys": "image", "grid_size": (2, 1)}, {"image": A}, torch.stack([A1, A2])] +TEST_CASE_2 = [{"keys": "image", "grid_size": (1, 2)}, {"image": A1}, torch.stack([A11, A12])] +TEST_CASE_3 = [{"keys": "image", "grid_size": (1, 2)}, {"image": A2}, torch.stack([A21, A22])] +TEST_CASE_4 = [{"keys": "image", "grid_size": (1, 1), "patch_size": (2, 2)}, {"image": A}, torch.stack([A11])] +TEST_CASE_5 = [{"keys": "image", "grid_size": 1, "patch_size": 4}, {"image": A}, torch.stack([A])] +TEST_CASE_6 = [{"keys": "image", "grid_size": 2, "patch_size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])] +TEST_CASE_7 = [{"keys": "image", "grid_size": 1}, {"image": A}, torch.stack([A])] +TEST_CASE_8 = [ + {"keys": "image", "grid_size": (2, 2), "patch_size": 2}, + {"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)}, + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), ] -TEST_CASE_6 = [ - {"keys": "image", "grid_size": 2, "patch_size": 2}, - {"image": A}, - torch.stack([A11, A12, A21, A22]), -] - -TEST_CASE_7 = [ - {"keys": "image", "grid_size": 1}, - {"image": A}, - torch.stack([A]), -] +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) TEST_CASE_MC_0 = [ {"keys": "image", "grid_size": (2, 2)}, [{"image": A}, {"image": A}], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], ] - - TEST_CASE_MC_1 = [ {"keys": "image", "grid_size": (2, 1)}, - [{"image": A}] * 5, - [torch.stack([A1, A2])] * 5, + [{"image": A}, {"image": A}, {"image": A}], + [torch.stack([A1, A2])] * 3, ] - - TEST_CASE_MC_2 = [ {"keys": "image", "grid_size": (1, 2)}, [{"image": A1}, {"image": A2}], [torch.stack([A11, A12]), torch.stack([A21, A22])], ] +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) + class TestSplitOnGridDict(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - ] - ) - def test_split_pathce_single_call(self, input_parameters, img_dict, expected): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected): + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) splitter = SplitOnGridDict(**input_parameters) - output = splitter(img_dict)[input_parameters["keys"]] - np.testing.assert_equal(output.numpy(), expected.numpy()) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) - @parameterized.expand( - [ - TEST_CASE_MC_0, - TEST_CASE_MC_1, - TEST_CASE_MC_2, - ] - ) - def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list): + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): splitter = SplitOnGridDict(**input_parameters) for img_dict, expected in zip(img_list, expected_list): - output = splitter(img_dict)[input_parameters["keys"]] - np.testing.assert_equal(output.numpy(), expected.numpy()) + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) if __name__ == "__main__": diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py index 15ff7e94d6..8403efe836 100644 --- a/tests/test_squeezedim.py +++ b/tests/test_squeezedim.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_squeezedimd.py b/tests/test_squeezedimd.py index 35e7cd5d74..6baf4696a5 100644 --- a/tests/test_squeezedimd.py +++ b/tests/test_squeezedimd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_state_cacher.py b/tests/test_state_cacher.py index 139e7b8374..e4164be272 100644 --- a/tests/test_state_cacher.py +++ b/tests/test_state_cacher.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pickle import unittest from os.path import exists, join +from pathlib import Path from tempfile import gettempdir import torch @@ -20,20 +22,15 @@ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" -TEST_CASE_0 = [ - torch.Tensor([1]).to(DEVICE), - {"in_memory": True}, -] +TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}] TEST_CASE_1 = [ torch.Tensor([1]).to(DEVICE), - {"in_memory": False, "cache_dir": gettempdir()}, -] -TEST_CASE_2 = [ - torch.Tensor([1]).to(DEVICE), - {"in_memory": False, "allow_overwrite": False}, + {"in_memory": False, "cache_dir": gettempdir(), "pickle_module": None, "pickle_protocol": pickle.HIGHEST_PROTOCOL}, ] +TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}] +TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2] +TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] class TestStateCacher(unittest.TestCase): @@ -44,7 +41,7 @@ def test_state_cacher(self, data_obj, params): state_cacher = StateCacher(**params) # store it - state_cacher.store(key, data_obj) + state_cacher.store(key, data_obj, pickle_module=pickle) # create clone then modify original data_obj_orig = data_obj.clone() data_obj += 1 diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py index 5c16e14c45..55750161ec 100644 --- a/tests/test_std_shift_intensity.py +++ b/tests/test_std_shift_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_std_shift_intensityd.py b/tests/test_std_shift_intensityd.py index 4eb256f1e5..595da5cbc2 100644 --- a/tests/test_std_shift_intensityd.py +++ b/tests/test_std_shift_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_subpixel_upsample.py b/tests/test_subpixel_upsample.py index 07e110d7a7..0216f164c3 100644 --- a/tests/test_subpixel_upsample.py +++ b/tests/test_subpixel_upsample.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,37 +24,30 @@ for dim in range(1, 4): for factor in range(1, 3): test_case = [ - {"dimensions": dim, "in_channels": inch, "scale_factor": factor}, + {"spatial_dims": dim, "in_channels": inch, "scale_factor": factor}, (2, inch, *([8] * dim)), (2, inch, *([8 * factor] * dim)), ] TEST_CASE_SUBPIXEL.append(test_case) TEST_CASE_SUBPIXEL_2D_EXTRA = [ - {"dimensions": 2, "in_channels": 2, "scale_factor": 3}, + {"spatial_dims": 2, "in_channels": 2, "scale_factor": 3}, (2, 2, 8, 4), # different size for H and W (2, 2, 24, 12), ] TEST_CASE_SUBPIXEL_3D_EXTRA = [ - {"dimensions": 3, "in_channels": 1, "scale_factor": 2}, + {"spatial_dims": 3, "in_channels": 1, "scale_factor": 2}, (2, 1, 16, 8, 4), # different size for H, W and D (2, 1, 32, 16, 8), ] conv_block = nn.Sequential( - Conv[Conv.CONV, 3](1, 4, kernel_size=1), - Conv[Conv.CONV, 3]( - 4, - 8, - kernel_size=3, - stride=1, - padding=1, - ), + Conv[Conv.CONV, 3](1, 4, kernel_size=1), Conv[Conv.CONV, 3](4, 8, kernel_size=3, stride=1, padding=1) ) TEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA = [ - {"dimensions": 3, "in_channels": 1, "scale_factor": 2, "conv_block": conv_block}, + {"spatial_dims": 3, "in_channels": 1, "scale_factor": 2, "conv_block": conv_block}, (2, 1, 16, 8, 4), # different size for H, W and D (2, 1, 32, 16, 8), ] diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index e5d2145a1f..781c71c23a 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,9 +20,7 @@ def create_spherical_seg_3d( - radius: float = 20.0, - centre: Tuple[int, int, int] = (49, 49, 49), - im_shape: Tuple[int, int, int] = (99, 99, 99), + radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -49,10 +47,7 @@ def create_spherical_seg_3d( TEST_CASES = [ - [ - [create_spherical_seg_3d(), create_spherical_seg_3d()], - [0, 0], - ], + [[create_spherical_seg_3d(), create_spherical_seg_3d()], [0, 0]], [ [ create_spherical_seg_3d(radius=20, centre=(20, 20, 20)), @@ -91,21 +86,8 @@ def create_spherical_seg_3d( ], [17.32691760951026, 12.432687531048186], ], - [ - [ - np.zeros([99, 99, 99]), - create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - ], - [np.inf, np.inf], - ], - [ - [ - create_spherical_seg_3d(), - np.zeros([99, 99, 99]), - "taxicab", - ], - [np.inf, np.inf], - ], + [[np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22))], [np.inf, np.inf]], + [[create_spherical_seg_3d(), np.zeros([99, 99, 99]), "taxicab"], [np.inf, np.inf]], ] TEST_CASES_NANS = [ @@ -114,8 +96,8 @@ def create_spherical_seg_3d( # both pred and gt do not have foreground, metric and not_nans should be 0 np.zeros([99, 99, 99]), np.zeros([99, 99, 99]), - ], - ], + ] + ] ] diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py index 97ab12a588..fadf6255ff 100644 --- a/tests/test_synthetic.py +++ b/tests/test_synthetic.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,29 +18,10 @@ from monai.utils import set_determinism TEST_CASES = [ + [2, {"width": 64, "height": 64, "rad_max": 10, "rad_min": 4}, 0.1479004, 0.739502, (64, 64), 5], [ 2, - { - "width": 64, - "height": 64, - "rad_max": 10, - "rad_min": 4, - }, - 0.1479004, - 0.739502, - (64, 64), - 5, - ], - [ - 2, - { - "width": 32, - "height": 28, - "num_objs": 3, - "rad_max": 5, - "rad_min": 1, - "noise_max": 0.2, - }, + {"width": 32, "height": 28, "num_objs": 3, "rad_max": 5, "rad_min": 1, "noise_max": 0.2}, 0.1709315, 0.4040179, (32, 28), @@ -48,15 +29,7 @@ ], [ 3, - { - "width": 64, - "height": 64, - "depth": 45, - "num_seg_classes": 3, - "channel_dim": -1, - "rad_max": 10, - "rad_min": 4, - }, + {"width": 64, "height": 64, "depth": 45, "num_seg_classes": 3, "channel_dim": -1, "rad_max": 10, "rad_min": 4}, 0.025132, 0.0753961, (64, 64, 45, 1), diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index a07d59703d..cf4772e8bc 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,9 +21,18 @@ from monai.data.utils import pad_list_data_collate from monai.losses import DiceLoss from monai.networks.nets import UNet -from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, CropForegroundd, DivisiblePadd, RandAffined +from monai.transforms import ( + Activations, + AddChanneld, + AsDiscrete, + Compose, + CropForegroundd, + DivisiblePadd, + RandAffined, + RandScaleIntensityd, +) from monai.transforms.croppad.dictionary import SpatialPadd -from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd, Spacingd +from monai.transforms.spatial.dictionary import RandFlipd, Spacingd from monai.utils import optional_import, set_determinism from tests.utils import TEST_NDARRAYS @@ -46,13 +55,13 @@ def get_data(num_examples, input_size, data_type=np.asarray, include_label=True) create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1 ) data = [] - for _ in range(num_examples): + for i in range(num_examples): im, label = custom_create_test_image_2d() d = {} - d["image"] = data_type(im) + d["image"] = data_type(im[:, i:]) d["image_meta_dict"] = {"affine": np.eye(4)} if include_label: - d["label"] = data_type(label) + d["label"] = data_type(label[:, i:]) d["label_meta_dict"] = {"affine": np.eye(4)} data.append(d) return data[0] if num_examples == 1 else data @@ -113,12 +122,7 @@ def test_test_time_augmentation(self): epoch_loss /= len(train_loader) - post_trans = Compose( - [ - Activations(sigmoid=True), - AsDiscrete(threshold_values=True), - ] - ) + post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) def inferrer_fn(x): return post_trans(model(x)) @@ -137,10 +141,20 @@ def test_fail_non_random(self): with self.assertRaises(RuntimeError): TestTimeAugmentation(transforms, None, None, None) - def test_fail_random_but_not_invertible(self): - transforms = Compose([AddChanneld("im"), Rand2DElasticd("im", None, None)]) - with self.assertRaises(RuntimeError): - TestTimeAugmentation(transforms, None, None, None) + def test_warn_random_but_has_no_invertible(self): + transforms = Compose( + [AddChanneld("image"), RandFlipd("image", prob=1.0), RandScaleIntensityd("image", 0.1, prob=1.0)] + ) + with self.assertWarns(UserWarning): + tta = TestTimeAugmentation(transforms, 5, 0, orig_key="image") + tta(self.get_data(1, (20, 20), data_type=np.float32)) + + def test_warn_random_but_all_not_invertible(self): + """test with no invertible stack""" + transforms = Compose([AddChanneld("image"), RandScaleIntensityd("image", 0.1, prob=1.0)]) + with self.assertWarns(UserWarning): + tta = TestTimeAugmentation(transforms, 1, 0, orig_key="image") + tta(self.get_data(1, (20, 20), data_type=np.float32)) def test_single_transform(self): for p in TEST_NDARRAYS: @@ -155,7 +169,7 @@ def test_image_no_label(self): @unittest.skipUnless(has_nib, "Requires nibabel") def test_requires_meta_dict(self): - transforms = Compose([RandFlipd("image"), Spacingd("image", pixdim=1.0)]) + transforms = Compose([AddChanneld("image"), RandFlipd("image"), Spacingd("image", pixdim=1.1)]) tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image") tta(self.get_data(1, (20, 20), include_label=False)) diff --git a/tests/test_thread_buffer.py b/tests/test_thread_buffer.py index 507b6909be..04511220f8 100644 --- a/tests/test_thread_buffer.py +++ b/tests/test_thread_buffer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -78,6 +78,22 @@ def test_time(self): f"Buffered time {buffered_time} should be less than unbuffered time {unbuffered_time}", ) + def test_dataloader_repeats(self): + dataset = Dataset(data=self.datalist, transform=self.transform) + dataloader = ThreadDataLoader(dataset=dataset, batch_size=2, num_workers=0, repeats=2) + + previous_batch = None + + for d in dataloader: + self.assertEqual(d["image"][0], "spleen_19.nii.gz") + self.assertEqual(d["image"][1], "spleen_31.nii.gz") + + if previous_batch is None: + previous_batch = d + else: + self.assertTrue(previous_batch is d, "Batch object was not repeated") + previous_batch = None + if __name__ == "__main__": unittest.main() diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 543dab4d0c..2419b390fd 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -79,7 +79,7 @@ def test_plot(self): # a third non-image key is added to test that this is correctly ignored when plotting data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img, "Not Image Data": ["This isn't an image"]} - loader = DataLoader([data] * 10) + loader = DataLoader([data] * 20, batch_size=2) trainer = SupervisedTrainer( device=torch.device("cpu"), @@ -102,7 +102,7 @@ def test_plot(self): with tempfile.TemporaryDirectory() as tempdir: tempimg = f"{tempdir}/threadcontainer_plot_test.png" fig.savefig(tempimg) - comp = compare_images(f"{testing_dir}/threadcontainer_plot_test.png", tempimg, 1e-2) + comp = compare_images(f"{testing_dir}/threadcontainer_plot_test.png", tempimg, 5e-2) self.assertIsNone(comp, comp) # None indicates test passed diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py index a6d3895709..01321f1b0b 100644 --- a/tests/test_threshold_intensity.py +++ b/tests/test_threshold_intensity.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,20 +15,21 @@ from parameterized import parameterized from monai.transforms import ThresholdIntensity +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [{"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)] - -TEST_CASE_2 = [{"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)] - -TEST_CASE_3 = [{"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)]) + TESTS.append([p, {"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)]) + TESTS.append([p, {"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)]) class TestThresholdIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = np.arange(10) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = in_type(np.arange(10)) result = ThresholdIntensity(**input_param)(test_data) - np.testing.assert_allclose(result, expected_value) + assert_allclose(result, in_type(expected_value)) if __name__ == "__main__": diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py index efcfcfe604..e0610ebb5b 100644 --- a/tests/test_threshold_intensityd.py +++ b/tests/test_threshold_intensityd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,31 +15,41 @@ from parameterized import parameterized from monai.transforms import ThresholdIntensityd - -TEST_CASE_1 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, - (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), -] - -TEST_CASE_2 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, - (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), -] - -TEST_CASE_3 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, - (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, + (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, + (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, + (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), + ] + ) class TestThresholdIntensityd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = {"image": np.arange(10), "label": np.arange(10), "extra": np.arange(10)} + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))} result = ThresholdIntensityd(**input_param)(test_data) - np.testing.assert_allclose(result["image"], expected_value) - np.testing.assert_allclose(result["label"], expected_value) - np.testing.assert_allclose(result["extra"], expected_value) + assert_allclose(result["image"], in_type(expected_value)) + assert_allclose(result["label"], in_type(expected_value)) + assert_allclose(result["extra"], in_type(expected_value)) if __name__ == "__main__": diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py new file mode 100644 index 0000000000..08fb5d96fe --- /dev/null +++ b/tests/test_tile_on_grid.py @@ -0,0 +1,141 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms import TileOnGrid +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + } + ] + ) + +for tile_size in [8, 16]: + for step in [4, 8]: + TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}]) + +TESTS = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES: + TESTS.append([p, *tc]) + +TEST_CASES2 = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + } + ] + ) + +TESTS2 = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES2: + TESTS2.append([p, *tc]) + + +def make_image( + tile_count: int, + tile_size: int, + step: int = 0, + random_offset: bool = False, + filter_mode: Optional[str] = None, + seed=123, + **kwargs, +): + + tile_count = int(np.sqrt(tile_count)) + pad = 0 + if random_offset: + pad = 3 + + if step == 0: + step = tile_size + + image = np.random.randint( + 200, + size=[3, (tile_count - 1) * step + tile_size + pad, (tile_count - 1) * step + tile_size + pad], + dtype=np.uint8, + ) + imlarge = image + + random_state = np.random.RandomState(seed) + + if random_offset: + pad_h = image.shape[1] % tile_size + pad_w = image.shape[2] % tile_size + offset = (random_state.randint(pad_h) if pad_h > 0 else 0, random_state.randint(pad_w) if pad_w > 0 else 0) + image = image[:, offset[0] :, offset[1] :] + + tiles_list = [] + for x in range(tile_count): + for y in range(tile_count): + tiles_list.append(image[:, x * step : x * step + tile_size, y * step : y * step + tile_size]) + + tiles = np.stack(tiles_list, axis=0) # type: ignore + + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count ** 2: + tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] + + return imlarge, tiles + + +class TestTileOnGrid(unittest.TestCase): + @parameterized.expand(TESTS) + def test_tile_patch_single_call(self, in_type, input_parameters): + + img, tiles = make_image(**input_parameters) + input_img = in_type(img) + + tiler = TileOnGrid(**input_parameters) + output = tiler(input_img) + assert_allclose(output, tiles, type_test=False) + + @parameterized.expand(TESTS2) + def test_tile_patch_random_call(self, in_type, input_parameters): + + img, tiles = make_image(**input_parameters, seed=123) + input_img = in_type(img) + + tiler = TileOnGrid(**input_parameters) + tiler.set_random_state(seed=123) + + output = tiler(input_img) + assert_allclose(output, tiles, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py new file mode 100644 index 0000000000..9f1d67ac29 --- /dev/null +++ b/tests/test_tile_on_grid_dict.py @@ -0,0 +1,176 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import numpy as np +import torch +from parameterized import parameterized + +from monai.apps.pathology.transforms import TileOnGridDict +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + for return_list_of_dicts in [False, True]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + "return_list_of_dicts": return_list_of_dicts, + } + ] + ) + +for tile_size in [8, 16]: + for step in [4, 8]: + TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}]) + +TESTS = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES: + TESTS.append([p, *tc]) + +TEST_CASES2 = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + for return_list_of_dicts in [False, True]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + "return_list_of_dicts": return_list_of_dicts, + } + ] + ) + +TESTS2 = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES2: + TESTS2.append([p, *tc]) + +for tile_size in [8, 16]: + for step in [4, 8]: + TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}]) + + +def make_image( + tile_count: int, + tile_size: int, + step: int = 0, + random_offset: bool = False, + filter_mode: Optional[str] = None, + seed=123, + **kwargs, +): + + tile_count = int(np.sqrt(tile_count)) + pad = 0 + if random_offset: + pad = 3 + + if step == 0: + step = tile_size + + image = np.random.randint( + 200, + size=[3, (tile_count - 1) * step + tile_size + pad, (tile_count - 1) * step + tile_size + pad], + dtype=np.uint8, + ) + imlarge = image + + random_state = np.random.RandomState(seed) + + if random_offset: + pad_h = image.shape[1] % tile_size + pad_w = image.shape[2] % tile_size + offset = (random_state.randint(pad_h) if pad_h > 0 else 0, random_state.randint(pad_w) if pad_w > 0 else 0) + image = image[:, offset[0] :, offset[1] :] + + tiles_list = [] + for x in range(tile_count): + for y in range(tile_count): + tiles_list.append(image[:, x * step : x * step + tile_size, y * step : y * step + tile_size]) + + tiles = np.stack(tiles_list, axis=0) # type: ignore + + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count ** 2: + tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] + + return imlarge, tiles + + +class TestTileOnGridDict(unittest.TestCase): + @parameterized.expand(TESTS) + def test_tile_patch_single_call(self, in_type, input_parameters): + + key = "image" + input_parameters["keys"] = key + + img, tiles = make_image(**input_parameters) + input_img = in_type(img) + + splitter = TileOnGridDict(**input_parameters) + + output = splitter({key: input_img}) + + if input_parameters.get("return_list_of_dicts", False): + if isinstance(input_img, torch.Tensor): + output = torch.stack([ix[key] for ix in output], axis=0) + else: + output = np.stack([ix[key] for ix in output], axis=0) + else: + output = output[key] + + assert_allclose(output, tiles, type_test=False) + + @parameterized.expand(TESTS2) + def test_tile_patch_random_call(self, in_type, input_parameters): + + key = "image" + input_parameters["keys"] = key + + random_state = np.random.RandomState(123) + seed = random_state.randint(10000) + img, tiles = make_image(**input_parameters, seed=seed) + input_img = in_type(img) + + splitter = TileOnGridDict(**input_parameters) + splitter.set_random_state(seed=123) + + output = splitter({key: input_img}) + + if input_parameters.get("return_list_of_dicts", False): + if isinstance(input_img, torch.Tensor): + output = torch.stack([ix[key] for ix in output], axis=0) + else: + output = np.stack([ix[key] for ix in output], axis=0) + else: + output = output[key] + assert_allclose(output, tiles, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_timedcall.py b/tests/test_timedcall_dist.py similarity index 98% rename from tests/test_timedcall.py rename to tests/test_timedcall_dist.py index de10abb8f7..70b5e0f56a 100644 --- a/tests/test_timedcall.py +++ b/tests/test_timedcall_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index 8b00e12539..36edf24f3f 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,60 +17,92 @@ from monai.transforms import ToCupy from monai.utils import optional_import -from tests.utils import skip_if_no_cuda +from tests.utils import HAS_CUPY, skip_if_no_cuda -cp, has_cp = optional_import("cupy") +cp, _ = optional_import("cupy") +@skipUnless(HAS_CUPY, "CuPy is required.") class TestToCupy(unittest.TestCase): - @skipUnless(has_cp, "CuPy is required.") def test_cupy_input(self): - test_data = cp.array([[1, 2], [3, 4]]) + test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data) + + def test_cupy_input_dtype(self): + test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32) + test_data = cp.rot90(test_data) + self.assertFalse(test_data.flags["C_CONTIGUOUS"]) + result = ToCupy(cp.uint8)(test_data) + self.assertTrue(result.dtype == cp.uint8) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_numpy_input(self): - test_data = np.array([[1, 2], [3, 4]]) + test_data = np.array([[1, 2], [3, 4]], dtype=np.float32) test_data = np.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data) + + def test_numpy_input_dtype(self): + test_data = np.array([[1, 2], [3, 4]], dtype=np.float32) + test_data = np.rot90(test_data) + self.assertFalse(test_data.flags["C_CONTIGUOUS"]) + result = ToCupy(np.uint8)(test_data) + self.assertTrue(result.dtype == cp.uint8) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_tensor_input(self): - test_data = torch.tensor([[1, 2], [3, 4]]) + test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - cp.testing.assert_allclose(result, test_data.numpy()) + cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") @skip_if_no_cuda def test_tensor_cuda_input(self): - test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).cuda() test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToCupy()(test_data) + self.assertTrue(result.dtype == cp.float32) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data) + + @skip_if_no_cuda + def test_tensor_cuda_input_dtype(self): + test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.uint8).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + + result = ToCupy(dtype="float32")(test_data) + self.assertTrue(result.dtype == cp.float32) self.assertTrue(isinstance(result, cp.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - cp.testing.assert_allclose(result, test_data.cpu().numpy()) + cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_list_tuple(self): test_data = [[1, 2], [3, 4]] - result = ToCupy()(test_data) + result = ToCupy(wrap_sequence=True)(test_data) cp.testing.assert_allclose(result, cp.asarray(test_data)) test_data = ((1, 2), (3, 4)) - result = ToCupy()(test_data) + result = ToCupy(wrap_sequence=True)(test_data) cp.testing.assert_allclose(result, cp.asarray(test_data)) diff --git a/tests/test_to_cupyd.py b/tests/test_to_cupyd.py index 6f40bafe1c..3e778ae269 100644 --- a/tests/test_to_cupyd.py +++ b/tests/test_to_cupyd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,13 +17,13 @@ from monai.transforms import ToCupyd from monai.utils import optional_import -from tests.utils import skip_if_no_cuda +from tests.utils import HAS_CUPY, skip_if_no_cuda -cp, has_cp = optional_import("cupy") +cp, _ = optional_import("cupy") +@skipUnless(HAS_CUPY, "CuPy is required.") class TestToCupyd(unittest.TestCase): - @skipUnless(has_cp, "CuPy is required.") def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) @@ -33,7 +33,6 @@ def test_cupy_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) test_data = np.rot90(test_data) @@ -43,7 +42,6 @@ def test_numpy_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) - @skipUnless(has_cp, "CuPy is required.") def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) test_data = test_data.rot90() @@ -53,7 +51,6 @@ def test_tensor_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data.numpy()) - @skipUnless(has_cp, "CuPy is required.") @skip_if_no_cuda def test_tensor_cuda_input(self): test_data = torch.tensor([[1, 2], [3, 4]]).cuda() @@ -64,13 +61,12 @@ def test_tensor_cuda_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data.cpu().numpy()) - @skipUnless(has_cp, "CuPy is required.") def test_list_tuple(self): test_data = [[1, 2], [3, 4]] - result = ToCupyd(keys="img")({"img": test_data})["img"] + result = ToCupyd(keys="img", wrap_sequence=True)({"img": test_data})["img"] cp.testing.assert_allclose(result, cp.asarray(test_data)) test_data = ((1, 2), (3, 4)) - result = ToCupyd(keys="img")({"img": test_data})["img"] + result = ToCupyd(keys="img", wrap_sequence=True)({"img": test_data})["img"] cp.testing.assert_allclose(result, cp.asarray(test_data)) diff --git a/tests/test_to_device.py b/tests/test_to_device.py index 9855a353f0..70f1ea8828 100644 --- a/tests/test_to_device.py +++ b/tests/test_to_device.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_to_deviced.py b/tests/test_to_deviced.py index 0d5d1d1cdc..7d075ad365 100644 --- a/tests/test_to_deviced.py +++ b/tests/test_to_deviced.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -24,9 +24,7 @@ def test_value(self): device = "cuda:0" data = [{"img": torch.tensor(i)} for i in range(4)] dataset = CacheDataset( - data=data, - transform=ToDeviced(keys="img", device=device, non_blocking=True), - cache_rate=1.0, + data=data, transform=ToDeviced(keys="img", device=device, non_blocking=True), cache_rate=1.0 ) dataloader = ThreadDataLoader(dataset=dataset, num_workers=0, batch_size=1) for i, d in enumerate(dataloader): diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index b48727c01d..e1f135a289 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,13 +17,13 @@ from monai.transforms import ToNumpy from monai.utils import optional_import -from tests.utils import assert_allclose, skip_if_no_cuda +from tests.utils import HAS_CUPY, assert_allclose, skip_if_no_cuda -cp, has_cp = optional_import("cupy") +cp, _ = optional_import("cupy") class TestToNumpy(unittest.TestCase): - @skipUnless(has_cp, "CuPy is required.") + @skipUnless(HAS_CUPY, "CuPy is required.") def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) @@ -31,25 +31,26 @@ def test_cupy_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data.get()) + assert_allclose(result, test_data.get(), type_test=False) def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) test_data = np.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) - result = ToNumpy()(test_data) + result = ToNumpy(dtype="float32")(test_data) self.assertTrue(isinstance(result, np.ndarray)) + self.assertTrue(result.dtype == np.float32) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) - result = ToNumpy()(test_data) + result = ToNumpy(dtype=torch.uint8)(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) @skip_if_no_cuda def test_tensor_cuda_input(self): @@ -59,21 +60,22 @@ def test_tensor_cuda_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) def test_list_tuple(self): test_data = [[1, 2], [3, 4]] result = ToNumpy()(test_data) - assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data), type_test=False) test_data = ((1, 2), (3, 4)) - result = ToNumpy()(test_data) - assert_allclose(result, np.asarray(test_data)) + result = ToNumpy(wrap_sequence=False)(test_data) + self.assertTrue(type(result), tuple) + assert_allclose(result, ((np.asarray(1), np.asarray(2)), (np.asarray(3), np.asarray(4)))) def test_single_value(self): for test_data in [5, np.array(5), torch.tensor(5)]: - result = ToNumpy()(test_data) + result = ToNumpy(dtype=np.uint8)(test_data) self.assertTrue(isinstance(result, np.ndarray)) - assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data), type_test=False) self.assertEqual(result.ndim, 0) diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py index 5acaef39c7..ba7cf798ef 100644 --- a/tests/test_to_numpyd.py +++ b/tests/test_to_numpyd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,13 +17,13 @@ from monai.transforms import ToNumpyd from monai.utils import optional_import -from tests.utils import assert_allclose, skip_if_no_cuda +from tests.utils import HAS_CUPY, assert_allclose, skip_if_no_cuda -cp, has_cp = optional_import("cupy") +cp, _ = optional_import("cupy") class TestToNumpyd(unittest.TestCase): - @skipUnless(has_cp, "CuPy is required.") + @skipUnless(HAS_CUPY, "CuPy is required.") def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) @@ -31,7 +31,7 @@ def test_cupy_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data.get()) + assert_allclose(result, test_data.get(), type_test=False) def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) @@ -40,7 +40,7 @@ def test_numpy_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) @@ -49,7 +49,7 @@ def test_tensor_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) @skip_if_no_cuda def test_tensor_cuda_input(self): @@ -59,7 +59,7 @@ def test_tensor_cuda_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - assert_allclose(result, test_data) + assert_allclose(result, test_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_to_onehot.py b/tests/test_to_onehot.py index c3e373955d..c08672bfb2 100644 --- a/tests/test_to_onehot.py +++ b/tests/test_to_onehot.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_to_pil.py b/tests/test_to_pil.py index 5690645dd8..0a1351028c 100644 --- a/tests/test_to_pil.py +++ b/tests/test_to_pil.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -43,7 +43,7 @@ class TestToPIL(unittest.TestCase): def test_value(self, test_data): result = ToPIL()(test_data) self.assertTrue(isinstance(result, PILImageImage)) - assert_allclose(np.array(result), test_data) + assert_allclose(np.array(result), test_data, type_test=False) if __name__ == "__main__": diff --git a/tests/test_to_pild.py b/tests/test_to_pild.py index 3a15b1e507..d00ecf13d4 100644 --- a/tests/test_to_pild.py +++ b/tests/test_to_pild.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -30,9 +30,7 @@ PILImageImage, _ = optional_import("PIL.Image", name="Image") im = [[1.0, 2.0], [3.0, 4.0]] -TESTS = [] -for p in TEST_NDARRAYS: - TESTS.append([{"keys": "image"}, {"image": p(im)}]) +TESTS = [[{"keys": "image"}, {"image": p(im)}] for p in TEST_NDARRAYS] if has_pil: TESTS.append([{"keys": "image"}, {"image": pil_image_fromarray(np.array(im))}]) @@ -43,7 +41,7 @@ class TestToPIL(unittest.TestCase): def test_values(self, input_param, test_data): result = ToPILd(**input_param)(test_data)[input_param["keys"]] self.assertTrue(isinstance(result, PILImageImage)) - assert_allclose(np.array(result), test_data[input_param["keys"]]) + assert_allclose(np.array(result), test_data[input_param["keys"]], type_test=False) if __name__ == "__main__": diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index 6ac06983f6..bfc61cdb19 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,10 +11,13 @@ import unittest +import torch from parameterized import parameterized from monai.transforms import ToTensor -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import HAS_CUPY, TEST_NDARRAYS, assert_allclose, optional_import + +cp, _ = optional_import("cupy") im = [[1, 2], [3, 4]] @@ -32,16 +35,26 @@ class TestToTensor(unittest.TestCase): @parameterized.expand(TESTS) def test_array_input(self, test_data, expected_shape): - result = ToTensor()(test_data) - assert_allclose(result, test_data) + result = ToTensor(dtype=torch.float32, device="cpu", wrap_sequence=True)(test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand(TESTS_SINGLE) def test_single_input(self, test_data): result = ToTensor()(test_data) - assert_allclose(result, test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) self.assertEqual(result.ndim, 0) + @unittest.skipUnless(HAS_CUPY, "CuPy is required.") + def test_cupy(self): + test_data = [[1, 2], [3, 4]] + cupy_array = cp.ascontiguousarray(cp.asarray(test_data)) + result = ToTensor()(cupy_array) + self.assertTrue(isinstance(result, torch.Tensor)) + assert_allclose(result, test_data, type_test=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py new file mode 100644 index 0000000000..b26d41345a --- /dev/null +++ b/tests/test_torchscript_utils.py @@ -0,0 +1,112 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch + +from monai.config import get_config_values +from monai.data import load_net_with_metadata, save_net_with_metadata +from monai.utils import JITMetadataKeys +from monai.utils.module import pytorch_after + + +class TestModule(torch.nn.Module): + def forward(self, x): + return x + 10 + + +class TestTorchscript(unittest.TestCase): + def test_save_net_with_metadata(self): + """Save a network without metadata to a file.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test") + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + def test_save_net_with_metadata_ext(self): + """Save a network without metadata to a file.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test.zip") + + self.assertTrue(os.path.isfile(f"{tempdir}/test.zip")) + + def test_save_net_with_metadata_with_extra(self): + """Save a network with simple metadata to a file.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + def test_load_net_with_metadata(self): + """Save then load a network with no metadata or other extra files.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test") + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + + del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be + + self.assertEqual(meta, get_config_values()) + self.assertEqual(extra_files, {}) + + def test_load_net_with_metadata_with_extra(self): + """Save then load a network with basic metadata.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + + del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be + + test_compare = get_config_values() + test_compare.update(test_metadata) + + self.assertEqual(meta, test_compare) + self.assertEqual(extra_files, {}) + + def test_save_load_more_extra_files(self): + """Save then load extra file data from a torchscript file.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + more_extra_files = {"test.txt": b"This is test data"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata, more_extra_files=more_extra_files) + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + _, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.pt", more_extra_files=("test.txt",)) + + if pytorch_after(1, 7): + self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"]) + else: + self.assertEqual(more_extra_files["test.txt"].decode(), loaded_extra_files["test.txt"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_torchvision.py b/tests/test_torchvision.py index 0846b7f6b6..e0844eb4b9 100644 --- a/tests/test_torchvision.py +++ b/tests/test_torchvision.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,75 +11,54 @@ import unittest -import torch from parameterized import parameterized from monai.transforms import TorchVision from monai.utils import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose -TEST_CASE_1 = [ - {"name": "ColorJitter"}, - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), -] - -TEST_CASE_2 = [ - {"name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5}, - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), - torch.tensor( - [ - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - ], - ), -] - -TEST_CASE_3 = [ - {"name": "Pad", "padding": [1, 1, 1, 1]}, - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), - torch.tensor( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( [ [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], + {"name": "ColorJitter"}, + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), ], [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], + {"name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5}, + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + p( + [ + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + ] + ), ], [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], + {"name": "Pad", "padding": [1, 1, 1, 1]}, + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + p( + [ + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + ] + ), ], ] - ), -] + ) @SkipIfBeforePyTorchVersion((1, 7)) class TestTorchVision(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, input_param, input_data, expected_value): set_determinism(seed=0) result = TorchVision(**input_param)(input_data) - torch.testing.assert_allclose(result, expected_value) + assert_allclose(result, expected_value, rtol=1e-3) if __name__ == "__main__": diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py index d6d3ea69c9..98b300eeac 100644 --- a/tests/test_torchvision_fc_model.py +++ b/tests/test_torchvision_fc_model.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -115,17 +115,7 @@ class TestTorchVisionFCModel(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) @skipUnless(has_tv, "Requires TorchVision.") def test_without_pretrained(self, input_param, input_shape, expected_shape): net = TorchVisionFCModel(**input_param).to(device) diff --git a/tests/test_torchvision_fully_conv_model.py b/tests/test_torchvision_fully_conv_model.py index af2c1458d3..34a61ce9fa 100644 --- a/tests/test_torchvision_fully_conv_model.py +++ b/tests/test_torchvision_fully_conv_model.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,23 +23,11 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASE_0 = [ - {"model_name": "resnet18", "num_classes": 1, "pretrained": False}, - (2, 3, 224, 224), - (2, 1, 1, 1), -] +TEST_CASE_0 = [{"model_name": "resnet18", "num_classes": 1, "pretrained": False}, (2, 3, 224, 224), (2, 1, 1, 1)] -TEST_CASE_1 = [ - {"model_name": "resnet18", "num_classes": 1, "pretrained": False}, - (2, 3, 256, 256), - (2, 1, 2, 2), -] +TEST_CASE_1 = [{"model_name": "resnet18", "num_classes": 1, "pretrained": False}, (2, 3, 256, 256), (2, 1, 2, 2)] -TEST_CASE_2 = [ - {"model_name": "resnet101", "num_classes": 5, "pretrained": False}, - (2, 3, 256, 256), - (2, 5, 2, 2), -] +TEST_CASE_2 = [{"model_name": "resnet101", "num_classes": 5, "pretrained": False}, (2, 3, 256, 256), (2, 5, 2, 2)] TEST_CASE_3 = [ {"model_name": "resnet101", "num_classes": 5, "pool_size": 6, "pretrained": False}, @@ -70,14 +58,7 @@ class TestTorchVisionFullyConvModel(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - ] - ) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) @skipUnless(has_tv, "Requires TorchVision.") def test_without_pretrained(self, input_param, input_shape, expected_shape): net = TorchVisionFullyConvModel(**input_param).to(device) @@ -85,13 +66,7 @@ def test_without_pretrained(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - @parameterized.expand( - [ - TEST_CASE_PRETRAINED_0, - TEST_CASE_PRETRAINED_1, - TEST_CASE_PRETRAINED_2, - ] - ) + @parameterized.expand([TEST_CASE_PRETRAINED_0, TEST_CASE_PRETRAINED_1, TEST_CASE_PRETRAINED_2]) @skipUnless(has_tv, "Requires TorchVision.") def test_with_pretrained(self, input_param, input_shape, expected_shape, expected_value): net = TorchVisionFullyConvModel(**input_param).to(device) diff --git a/tests/test_torchvisiond.py b/tests/test_torchvisiond.py index 4f42bc95f7..4c62c6e41a 100644 --- a/tests/test_torchvisiond.py +++ b/tests/test_torchvisiond.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -29,19 +29,10 @@ {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])}, torch.tensor( [ - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - [ - [0.1090, 0.6193], - [0.6193, 0.9164], - ], - ], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + ] ), ] @@ -50,24 +41,9 @@ {"img": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])}, torch.tensor( [ - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], ] ), ] diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py new file mode 100644 index 0000000000..bc6aad3a62 --- /dev/null +++ b/tests/test_traceable_transform.py @@ -0,0 +1,53 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.transforms.inverse import TraceableTransform + + +class _TraceTest(TraceableTransform): + def __call__(self, data): + self.push_transform(data) + return data + + def pop(self, data): + self.pop_transform(data) + return data + + +class TestTraceable(unittest.TestCase): + def test_default(self): + expected_key = "_transforms" + a = _TraceTest() + self.assertEqual(a.trace_key(), expected_key) + + data = {"image": "test"} + data = a(data) # adds to the stack + self.assertTrue(isinstance(data[expected_key], list)) + self.assertEqual(data[expected_key][0]["class"], "_TraceTest") + + data = a(data) # adds to the stack + self.assertEqual(len(data[expected_key]), 2) + self.assertEqual(data[expected_key][-1]["class"], "_TraceTest") + + with self.assertRaises(IndexError): + a.pop({"test": "test"}) # no stack in the data + data = a.pop(data) + data = a.pop(data) + self.assertEqual(data[expected_key], []) + + with self.assertRaises(IndexError): # no more items + a.pop(data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_train_mode.py b/tests/test_train_mode.py index 1acb443041..231e3854f0 100644 --- a/tests/test_train_mode.py +++ b/tests/test_train_mode.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_transchex.py b/tests/test_transchex.py new file mode 100644 index 0000000000..462ce64fd6 --- /dev/null +++ b/tests/test_transchex.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.transchex import Transchex +from tests.utils import skip_if_quick + +TEST_CASE_TRANSCHEX = [] +for drop_out in [0.4]: + for in_channels in [3]: + for img_size in [224]: + for patch_size in [16, 32]: + for num_language_layers in [2]: + for num_vision_layers in [4]: + for num_mixed_layers in [3]: + for num_classes in [8]: + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * 2, + "patch_size": (patch_size,) * 2, + "num_vision_layers": num_vision_layers, + "num_mixed_layers": num_mixed_layers, + "num_language_layers": num_language_layers, + "num_classes": num_classes, + "drop_out": drop_out, + }, + (2, num_classes), # type: ignore + ] + TEST_CASE_TRANSCHEX.append(test_case) + + +@skip_if_quick +class TestTranschex(unittest.TestCase): + @parameterized.expand(TEST_CASE_TRANSCHEX) + def test_shape(self, input_param, expected_shape): + net = Transchex(**input_param) + with eval_mode(net): + result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224))) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + Transchex( + in_channels=3, + img_size=(128, 128), + patch_size=(16, 16), + num_language_layers=2, + num_mixed_layers=4, + num_vision_layers=2, + num_classes=2, + drop_out=5.0, + ) + + with self.assertRaises(ValueError): + Transchex( + in_channels=1, + img_size=(97, 97), + patch_size=(16, 16), + num_language_layers=6, + num_mixed_layers=6, + num_vision_layers=8, + num_classes=8, + drop_out=0.4, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index 616e3e7ec9..d6131d010c 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 10882c9dd8..94a5b49c3a 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,18 +20,8 @@ TESTS = [] for p in TEST_NDARRAYS: - TESTS.append( - [ - p(np.arange(5 * 4).reshape(5, 4)), - None, - ] - ) - TESTS.append( - [ - p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), - [2, 0, 1], - ] - ) + TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), None]) + TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), [2, 0, 1]]) class TestTranspose(unittest.TestCase): @@ -42,7 +32,7 @@ def test_transpose(self, im, indices): if isinstance(im, torch.Tensor): im = im.cpu().numpy() out2 = np.transpose(im, indices) - assert_allclose(out1, out2) + assert_allclose(out1, out2, type_test=False) if __name__ == "__main__": diff --git a/tests/test_transposed.py b/tests/test_transposed.py index 88ecd0c872..14e62eb9da 100644 --- a/tests/test_transposed.py +++ b/tests/test_transposed.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,30 +21,10 @@ TESTS = [] for p in TEST_NDARRAYS: - TESTS.append( - [ - p(np.arange(5 * 4).reshape(5, 4)), - [1, 0], - ] - ) - TESTS.append( - [ - p(np.arange(5 * 4).reshape(5, 4)), - None, - ] - ) - TESTS.append( - [ - p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), - [2, 0, 1], - ] - ) - TESTS.append( - [ - p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), - None, - ] - ) + TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), [1, 0]]) + TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), None]) + TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), [2, 0, 1]]) + TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), None]) class TestTranspose(unittest.TestCase): @@ -57,13 +37,13 @@ def test_transpose(self, im, indices): if isinstance(im, torch.Tensor): im = im.cpu().numpy() out_gt = np.transpose(im, indices) - assert_allclose(out_im1, out_gt) - assert_allclose(out_im2, out_gt) + assert_allclose(out_im1, out_gt, type_test=False) + assert_allclose(out_im2, out_gt, type_test=False) # test inverse fwd_inv_data = tr.inverse(out_data) for i, j in zip(data.values(), fwd_inv_data.values()): - assert_allclose(i, j) + assert_allclose(i, j, type_test=False) if __name__ == "__main__": diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index 0bc2ca2e70..2bb2409360 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,10 +21,7 @@ TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) @@ -107,18 +104,12 @@ ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "alpha": 0.3, "beta": 0.7, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.3589, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "alpha": 0.7, "beta": 0.3, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - }, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.247366, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) diff --git a/tests/test_unet.py b/tests/test_unet.py index 4091c4e9d7..1ae0094d55 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ TEST_CASE_0 = [ # single channel 2D, batch 16, no residual { - "dimensions": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (16, 32, 64), @@ -36,7 +36,7 @@ TEST_CASE_1 = [ # single channel 2D, batch 16 { - "dimensions": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (16, 32, 64), @@ -49,7 +49,7 @@ TEST_CASE_2 = [ # single channel 3D, batch 16 { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 1, "out_channels": 3, "channels": (16, 32, 64), @@ -62,7 +62,7 @@ TEST_CASE_3 = [ # 4-channel 3D, batch 16 { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 4, "out_channels": 3, "channels": (16, 32, 64), @@ -75,7 +75,7 @@ TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 4, "out_channels": 3, "channels": (16, 32, 64), @@ -89,7 +89,7 @@ TEST_CASE_5 = [ # 4-channel 3D, batch 16, LeakyReLU activation { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 4, "out_channels": 3, "channels": (16, 32, 64), @@ -103,7 +103,7 @@ TEST_CASE_6 = [ # 4-channel 3D, batch 16, LeakyReLU activation explicit { - "dimensions": 3, + "spatial_dims": 3, "in_channels": 4, "out_channels": 3, "channels": (16, 32, 64), @@ -120,7 +120,7 @@ ILL_CASES = [ [ { # len(channels) < 2 - "dimensions": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (16,), @@ -130,7 +130,7 @@ ], [ { # len(strides) < len(channels) - 1 - "dimensions": 2, + "spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (8, 8, 8), @@ -139,8 +139,8 @@ } ], [ - { # len(kernel_size) = 3, dimensions = 2 - "dimensions": 2, + { # len(kernel_size) = 3, spatial_dims = 2 + "spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (8, 8, 8), @@ -149,8 +149,8 @@ } ], [ - { # len(up_kernel_size) = 2, dimensions = 3 - "dimensions": 3, + { # len(up_kernel_size) = 2, spatial_dims = 3 + "spatial_dims": 3, "in_channels": 1, "out_channels": 3, "channels": (8, 8, 8), @@ -170,13 +170,15 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_script(self): - net = UNet(dimensions=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0) + net = UNet( + spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 + ) test_data = torch.randn(16, 1, 32, 32) test_script_save(net, test_data) def test_script_without_running_stats(self): net = UNet( - dimensions=2, + spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), @@ -188,13 +190,7 @@ def test_script_without_running_stats(self): test_script_save(net, test_data) def test_ill_input_shape(self): - net = UNet( - dimensions=2, - in_channels=1, - out_channels=3, - channels=(16, 32, 64), - strides=(2, 2), - ) + net = UNet(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2)) with eval_mode(net): with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match"): net.forward(torch.randn(2, 1, 16, 5)) diff --git a/tests/test_unetr.py b/tests/test_unetr.py index d19ed2ca59..a8c7b7bf88 100644 --- a/tests/test_unetr.py +++ b/tests/test_unetr.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_unetr_block.py b/tests/test_unetr_block.py index 7546918a2c..c0f14c829d 100644 --- a/tests/test_unetr_block.py +++ b/tests/test_unetr_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_upsample_block.py b/tests/test_upsample_block.py index 7b8ada399c..aa4141aabc 100644 --- a/tests/test_upsample_block.py +++ b/tests/test_upsample_block.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,11 +20,7 @@ TEST_CASES = [ [{"dimensions": 2, "in_channels": 4}, (7, 4, 32, 48), (7, 4, 64, 96)], # 4-channel 2D, batch 7 - [ - {"dimensions": 1, "in_channels": 4, "out_channels": 3}, - (16, 4, 63), - (16, 3, 126), - ], # 4-channel 1D, batch 16 + [{"dimensions": 1, "in_channels": 4, "out_channels": 3}, (16, 4, 63), (16, 3, 126)], # 4-channel 1D, batch 16 [ {"dimensions": 1, "in_channels": 4, "out_channels": 8, "mode": "deconv", "align_corners": False}, (16, 4, 20), @@ -78,14 +74,7 @@ expected_shape = (16, 5, 4 * s, 5 * s, 6 * s) for t in UpsampleMode: test_case = [ - { - "dimensions": 3, - "in_channels": 3, - "out_channels": 5, - "mode": t, - "scale_factor": s, - "align_corners": True, - }, + {"dimensions": 3, "in_channels": 3, "out_channels": 5, "mode": t, "scale_factor": s, "align_corners": True}, (16, 3, 4, 5, 6), ] test_case.append(expected_shape) diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py new file mode 100644 index 0000000000..4db3056b7b --- /dev/null +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -0,0 +1,59 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from monai.transforms.utils_pytorch_numpy_unification import percentile +from monai.utils import set_determinism +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose + + +class TestPytorchNumpyUnification(unittest.TestCase): + def setUp(self) -> None: + set_determinism(0) + + def test_percentile(self): + for size in (1, 100): + q = np.random.randint(0, 100, size=size) + results = [] + for p in TEST_NDARRAYS: + arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32)) + results.append(percentile(arr, q)) + # pre torch 1.7, no `quantile`. Our own method doesn't interpolate, + # so we can only be accurate to 0.5 + atol = 0.5 if not hasattr(torch, "quantile") else 1e-4 + assert_allclose(results[0], results[-1], type_test=False, atol=atol) + + def test_fails(self): + for p in TEST_NDARRAYS: + for q in (-1, 101): + arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32)) + with self.assertRaises(ValueError): + percentile(arr, q) + + @SkipIfBeforePyTorchVersion((1, 7)) + def test_dim(self): + q = np.random.randint(0, 100, size=50) + results = [] + for p in TEST_NDARRAYS: + arr = p(np.arange(6).reshape(1, 2, 3).astype(np.float32)) + results.append(percentile(arr, q, dim=1)) + # pre torch 1.7, no `quantile`. Our own method doesn't interpolate, + # so we can only be accurate to 0.5 + atol = 0.5 if not hasattr(torch, "quantile") else 1e-4 + assert_allclose(results[0], results[-1], type_test=False, atol=atol) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_varautoencoder.py b/tests/test_varautoencoder.py index 7a4a546d87..95fea8afcb 100644 --- a/tests/test_varautoencoder.py +++ b/tests/test_varautoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ TEST_CASE_0 = [ # single channel 2D, batch 4, no residual { - "dimensions": 2, + "spatial_dims": 2, "in_shape": (1, 128, 128), "out_channels": 1, "latent_size": 2, @@ -37,7 +37,7 @@ TEST_CASE_1 = [ # single channel 2D, batch 4 { - "dimensions": 2, + "spatial_dims": 2, "in_shape": (1, 128, 128), "out_channels": 1, "latent_size": 2, @@ -50,7 +50,7 @@ TEST_CASE_2 = [ # 3-channel 2D, batch 4, LeakyReLU activation { - "dimensions": 2, + "spatial_dims": 2, "in_shape": (3, 128, 128), "out_channels": 3, "latent_size": 2, @@ -64,7 +64,7 @@ TEST_CASE_3 = [ # 4-channel 3D, batch 4 { - "dimensions": 3, + "spatial_dims": 3, "in_shape": (4, 128, 128, 128), "out_channels": 3, "latent_size": 2, @@ -88,7 +88,7 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_script(self): net = VarAutoEncoder( - dimensions=2, in_shape=(1, 32, 32), out_channels=1, latent_size=2, channels=(4, 8), strides=(2, 2) + spatial_dims=2, in_shape=(1, 32, 32), out_channels=1, latent_size=2, channels=(4, 8), strides=(2, 2) ) test_data = torch.randn(2, 1, 32, 32) test_script_save(net, test_data) diff --git a/tests/test_version_leq.py b/tests/test_version_leq.py index a1913069d3..86fccca9fb 100644 --- a/tests/test_version_leq.py +++ b/tests/test_version_leq.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -67,6 +67,9 @@ def _pairwise(iterable): ("0post1", "0.4post1"), ("2.1.0-rc1", "2.1.0"), ("2.1dev", "2.1a0"), + (1.6, "1.6.0"), + ("1.6.0", 1.6), + (1.6, 1.7), ) + tuple(_pairwise(reversed(torture.split()))) diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index 47c116cd5d..2137926424 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index eebf32d70b..acca06d405 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -40,23 +40,13 @@ ] # 2D TEST_CASE_2 = [ - { - "model": "senet2d", - "shape": (2, 3, 64, 64), - "feature_shape": (2, 1, 2, 2), - "target_layers": "layer4", - }, + {"model": "senet2d", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"}, (2, 1, 64, 64), ] # 3D TEST_CASE_3 = [ - { - "model": "senet3d", - "shape": (2, 3, 8, 8, 48), - "feature_shape": (2, 1, 1, 1, 2), - "target_layers": "layer4", - }, + {"model": "senet3d", "shape": (2, 3, 8, 8, 48), "feature_shape": (2, 1, 1, 1, 2), "target_layers": "layer4"}, (2, 1, 8, 8, 48), ] @@ -88,6 +78,16 @@ def test_shape(self, input_data, expected_shape): result2 = cam(x=image, layer_idx=-1, class_idx=model(image).max(1)[-1].cpu()) torch.testing.assert_allclose(result, result2) + def test_ill(self): + model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) + for name, x in model.named_parameters(): + if "features" in name: + x.requires_grad = False + cam = GradCAM(nn_module=model, target_layers="class_layers.relu") + image = torch.rand((2, 1, 48, 64)) + with self.assertRaises(RuntimeError): + cam(x=image) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_vis_gradcampp.py b/tests/test_vis_gradcampp.py index 92a4b2ac7b..a261b6055b 100644 --- a/tests/test_vis_gradcampp.py +++ b/tests/test_vis_gradcampp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -39,23 +39,13 @@ ] # 2D TEST_CASE_2 = [ - { - "model": "senet2d", - "shape": (2, 3, 64, 64), - "feature_shape": (2, 1, 2, 2), - "target_layers": "layer4", - }, + {"model": "senet2d", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"}, (2, 1, 64, 64), ] # 3D TEST_CASE_3 = [ - { - "model": "senet3d", - "shape": (2, 3, 8, 8, 48), - "feature_shape": (2, 1, 1, 1, 2), - "target_layers": "layer4", - }, + {"model": "senet3d", "shape": (2, 3, 8, 8, 48), "feature_shape": (2, 1, 1, 1, 2), "target_layers": "layer4"}, (2, 1, 8, 8, 48), ] diff --git a/tests/test_vit.py b/tests/test_vit.py index cdf0888222..870e4010ec 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py new file mode 100644 index 0000000000..c45cde68c2 --- /dev/null +++ b/tests/test_vitautoenc.py @@ -0,0 +1,121 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vitautoenc import ViTAutoEnc + +TEST_CASE_Vitautoenc = [] +for in_channels in [1, 4]: + for img_size in [64, 96, 128]: + for patch_size in [16]: + for pos_embed in ["conv", "perceptron"]: + for nd in [2, 3]: + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * nd, + "patch_size": (patch_size,) * nd, + "hidden_size": 768, + "mlp_dim": 3072, + "num_layers": 4, + "num_heads": 12, + "pos_embed": pos_embed, + "dropout_rate": 0.6, + "spatial_dims": nd, + }, + (2, in_channels, *([img_size] * nd)), + (2, 1, *([img_size] * nd)), + ] + + TEST_CASE_Vitautoenc.append(test_case) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_Vitautoenc) + def test_shape(self, input_param, input_shape, expected_shape): + net = ViTAutoEnc(**input_param) + with eval_mode(net): + result, _ = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=1, + img_size=(128, 128, 128), + patch_size=(16, 16, 16), + hidden_size=128, + mlp_dim=3072, + num_layers=12, + num_heads=12, + pos_embed="conv", + dropout_rate=5.0, + ) + + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(64, 64, 64), + hidden_size=512, + mlp_dim=3072, + num_layers=12, + num_heads=8, + pos_embed="perceptron", + dropout_rate=0.3, + ) + + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=1, + img_size=(96, 96, 96), + patch_size=(8, 8, 8), + hidden_size=512, + mlp_dim=3072, + num_layers=12, + num_heads=14, + pos_embed="conv", + dropout_rate=0.3, + ) + + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=1, + img_size=(97, 97, 97), + patch_size=(4, 4, 4), + hidden_size=768, + mlp_dim=3072, + num_layers=12, + num_heads=8, + pos_embed="perceptron", + dropout_rate=0.3, + ) + + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=4, + img_size=(96, 96, 96), + patch_size=(16, 16, 16), + hidden_size=768, + mlp_dim=3072, + num_layers=12, + num_heads=12, + pos_embed="perc", + dropout_rate=0.3, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vnet.py b/tests/test_vnet.py index 4eba5396b2..add0396bd8 100644 --- a/tests/test_vnet.py +++ b/tests/test_vnet.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_vote_ensemble.py b/tests/test_vote_ensemble.py index 74c19d5f48..79868d4706 100644 --- a/tests/test_vote_ensemble.py +++ b/tests/test_vote_ensemble.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,66 +15,67 @@ from parameterized import parameterized from monai.transforms import VoteEnsemble +from tests.utils import TEST_NDARRAYS, assert_allclose -# shape: [2, 1, 1] -TEST_CASE_1 = [ - {"num_classes": None}, - [torch.tensor([[[1]], [[0]]]), torch.tensor([[[1]], [[0]]]), torch.tensor([[[0]], [[1]]])], - torch.tensor([[[1.0]], [[0.0]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + # shape: [2, 1, 1] + TESTS.append( + [ + {"num_classes": None}, + [p(torch.tensor([[[1]], [[0]]])), p(torch.tensor([[[1]], [[0]]])), p(torch.tensor([[[0]], [[1]]]))], + p(torch.tensor([[[1.0]], [[0.0]]])), + ] + ) -# shape: [1, 2, 1, 1] -TEST_CASE_2 = [ - {"num_classes": None}, - torch.stack([torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])]), - torch.tensor([[[[1.0]], [[0.0]]]]), -] + # shape: [1, 2, 1, 1] + TESTS.append( + [ + {"num_classes": None}, + p( + torch.stack( + [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] + ) + ), + p(torch.tensor([[[[1.0]], [[0.0]]]])), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_3 = [ - {"num_classes": 3}, - [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + {"num_classes": 3}, + [p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[1], [1]]]))], + p(torch.tensor([[[0], [2]]])), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_4 = [ - {"num_classes": 5}, - [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + {"num_classes": 5}, + [p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[1], [1]]]))], + p(torch.tensor([[[0], [2]]])), + ] + ) -# shape: [1] -TEST_CASE_5 = [ - {"num_classes": 3}, - [torch.tensor([2]), torch.tensor([2]), torch.tensor([1])], - torch.tensor([2]), -] + # shape: [1] + TESTS.append( + [{"num_classes": 3}, [p(torch.tensor([2])), p(torch.tensor([2])), p(torch.tensor([1]))], p(torch.tensor([2]))] + ) -# shape: 1 -TEST_CASE_6 = [ - {"num_classes": 3}, - [torch.tensor(2), torch.tensor(2), torch.tensor(1)], - torch.tensor(2), -] + # shape: 1 + TESTS.append([{"num_classes": 3}, [p(torch.tensor(2)), p(torch.tensor(2)), p(torch.tensor(1))], p(torch.tensor(2))]) class TestVoteEnsemble(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_param, img, expected_value): result = VoteEnsemble(**input_param)(img) - torch.testing.assert_allclose(result, expected_value) - - def test_cuda_value(self): - img = torch.stack( - [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] - ) - expected_value = torch.tensor([[[[1.0]], [[0.0]]]]) - if torch.cuda.is_available(): - img = img.to(torch.device("cuda:0")) - expected_value = expected_value.to(torch.device("cuda:0")) - result = VoteEnsemble(num_classes=None)(img) - torch.testing.assert_allclose(result, expected_value) + if isinstance(img, torch.Tensor): + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(result.device, img.device) + assert_allclose(result, expected_value) if __name__ == "__main__": diff --git a/tests/test_vote_ensembled.py b/tests/test_vote_ensembled.py index e94213733f..e42a57f3b7 100644 --- a/tests/test_vote_ensembled.py +++ b/tests/test_vote_ensembled.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,64 +15,79 @@ from parameterized import parameterized from monai.transforms import VoteEnsembled +from tests.utils import TEST_NDARRAYS, assert_allclose -# shape: [1, 2, 1, 1] -TEST_CASE_1 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": None}, - { - "pred0": torch.tensor([[[[1]], [[0]]]]), - "pred1": torch.tensor([[[[1]], [[0]]]]), - "pred2": torch.tensor([[[[0]], [[1]]]]), - }, - torch.tensor([[[[1.0]], [[0.0]]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + # shape: [1, 2, 1, 1] + TESTS.append( + [ + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": None}, + { + "pred0": p(torch.tensor([[[[1]], [[0]]]])), + "pred1": p(torch.tensor([[[[1]], [[0]]]])), + "pred2": p(torch.tensor([[[[0]], [[1]]]])), + }, + p(torch.tensor([[[[1.0]], [[0.0]]]])), + ] + ) -# shape: [1, 2, 1, 1] -TEST_CASE_2 = [ - {"keys": "output", "output_key": "output", "num_classes": None}, - { - "output": torch.stack( - [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] - ) - }, - torch.tensor([[[[1.0]], [[0.0]]]]), -] + # shape: [1, 2, 1, 1] + TESTS.append( + [ + {"keys": "output", "output_key": "output", "num_classes": None}, + { + "output": p( + torch.stack( + [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] + ) + ) + }, + p(torch.tensor([[[[1.0]], [[0.0]]]])), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_3 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, - { - "pred0": torch.tensor([[[0], [2]]]), - "pred1": torch.tensor([[[0], [2]]]), - "pred2": torch.tensor([[[1], [1]]]), - }, - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, + { + "pred0": p(torch.tensor([[[0], [2]]])), + "pred1": p(torch.tensor([[[0], [2]]])), + "pred2": p(torch.tensor([[[1], [1]]])), + }, + p(torch.tensor([[[0], [2]]])), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_4 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 5}, - { - "pred0": torch.tensor([[[0], [2]]]), - "pred1": torch.tensor([[[0], [2]]]), - "pred2": torch.tensor([[[1], [1]]]), - }, - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 5}, + { + "pred0": p(torch.tensor([[[0], [2]]])), + "pred1": p(torch.tensor([[[0], [2]]])), + "pred2": p(torch.tensor([[[1], [1]]])), + }, + p(torch.tensor([[[0], [2]]])), + ] + ) -# shape: [1] -TEST_CASE_5 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, - {"pred0": torch.tensor([2]), "pred1": torch.tensor([2]), "pred2": torch.tensor([1])}, - torch.tensor([2]), -] + # shape: [1] + TESTS.append( + [ + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, + {"pred0": p(torch.tensor([2])), "pred1": p(torch.tensor([2])), "pred2": p(torch.tensor([1]))}, + p(torch.tensor([2])), + ] + ) class TestVoteEnsembled(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_param, img, expected_value): result = VoteEnsembled(**input_param)(img) - torch.testing.assert_allclose(result["output"], expected_value) + assert_allclose(result["output"], expected_value) def test_cuda_value(self): img = torch.stack( @@ -83,7 +98,7 @@ def test_cuda_value(self): img = img.to(torch.device("cuda:0")) expected_value = expected_value.to(torch.device("cuda:0")) result = VoteEnsembled(keys="output", num_classes=None)({"output": img}) - torch.testing.assert_allclose(result["output"], expected_value) + assert_allclose(result["output"], expected_value) if __name__ == "__main__": diff --git a/tests/test_warp.py b/tests/test_warp.py index c6c79a369a..fde1048bf7 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_distributed_weighted_random_sampler.py b/tests/test_weighted_random_sampler_dist.py similarity index 91% rename from tests/test_distributed_weighted_random_sampler.py rename to tests/test_weighted_random_sampler_dist.py index b8e088fdcf..13404a8acb 100644 --- a/tests/test_distributed_weighted_random_sampler.py +++ b/tests/test_weighted_random_sampler_dist.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,10 +25,7 @@ def test_sampling(self): data = [1, 2, 3, 4, 5] weights = [1, 2, 3, 4, 5] sampler = DistributedWeightedRandomSampler( - weights=weights, - dataset=data, - shuffle=False, - generator=torch.Generator().manual_seed(0), + weights=weights, dataset=data, shuffle=False, generator=torch.Generator().manual_seed(0) ) samples = np.array([data[i] for i in list(sampler)]) diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py index 68c5ad30c4..36d5c0c843 100644 --- a/tests/test_with_allow_missing_keys.py +++ b/tests/test_with_allow_missing_keys.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_write_metrics_reports.py b/tests/test_write_metrics_reports.py index f736db9961..101e1137b6 100644 --- a/tests/test_write_metrics_reports.py +++ b/tests/test_write_metrics_reports.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,6 +13,7 @@ import os import tempfile import unittest +from pathlib import Path import torch @@ -23,7 +24,7 @@ class TestWriteMetricsReports(unittest.TestCase): def test_content(self): with tempfile.TemporaryDirectory() as tempdir: write_metrics_reports( - save_dir=tempdir, + save_dir=Path(tempdir), images=["filepath1", "filepath2"], metrics={"metric1": 1, "metric2": 2}, metric_details={"metric3": torch.tensor([[1, 2], [2, 3]]), "metric4": torch.tensor([[5, 6], [7, 8]])}, diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py new file mode 100644 index 0000000000..416a0c11a1 --- /dev/null +++ b/tests/test_wsireader.py @@ -0,0 +1,206 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from numpy.testing import assert_array_equal +from parameterized import parameterized + +from monai.apps.utils import download_url +from monai.data import DataLoader, Dataset +from monai.data.image_reader import WSIReader +from monai.transforms import Compose, LoadImaged, ToTensord +from monai.utils import first, optional_import + +cucim, has_cucim = optional_import("cucim") +has_cucim = has_cucim and hasattr(cucim, "CuImage") +_, has_osl = optional_import("openslide") +imsave, has_tiff = optional_import("tifffile", name="imsave") +_, has_codec = optional_import("imagecodecs") +has_tiff = has_tiff and has_codec + +FILE_URL = "https://drive.google.com/uc?id=1sGTKZlJBIz53pfqTxoTqiIQzIoEzHLAe" +base_name, extension = FILE_URL.split("id=")[1], ".tiff" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension) + +HEIGHT = 32914 +WIDTH = 46000 + +TEST_CASE_0 = [FILE_PATH, 2, (3, HEIGHT // 4, WIDTH // 4)] + +TEST_CASE_TRANSFORM_0 = [FILE_PATH, 4, (HEIGHT // 16, WIDTH // 16), (1, 3, HEIGHT // 16, WIDTH // 16)] + +TEST_CASE_1 = [ + FILE_PATH, + {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0}, + np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), +] + +TEST_CASE_2 = [ + FILE_PATH, + {"location": (0, 0), "size": (2, 1), "level": 2}, + np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), +] + +TEST_CASE_3 = [ + FILE_PATH, + {"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 2}, + np.array( + [ + [[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]], + [[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]], + ] + ), +] + +TEST_CASE_4 = [ + FILE_PATH, + {"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 1}, + np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), +] + +TEST_CASE_5 = [ + FILE_PATH, + {"location": (HEIGHT - 2, WIDTH - 2), "level": 0, "grid_shape": (1, 1)}, + np.array([[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[237, 237], [237, 237]]]), +] + + +TEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)] # CHW + +TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW + + +def save_rgba_tiff(array: np.ndarray, filename: str, mode: str): + """ + Save numpy array into a TIFF RGB/RGBA file + + Args: + array: numpy ndarray with the shape of CxHxW and C==3 representing a RGB image + file_prefix: the filename to be used for the tiff file. '_RGB.tiff' or '_RGBA.tiff' will be appended to this filename. + mode: RGB or RGBA + """ + if mode == "RGBA": + array = np.concatenate([array, 255 * np.ones_like(array[0])[np.newaxis]]).astype(np.uint8) + + img_rgb = array.transpose(1, 2, 0) + imsave(filename, img_rgb, shape=img_rgb.shape, tile=(16, 16)) + + return filename + + +@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!") +def setUpModule(): # noqa: N802 + download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") + + +class WSIReaderTests: + class Tests(unittest.TestCase): + backend = None + + @parameterized.expand([TEST_CASE_0]) + def test_read_whole_image(self, file_path, level, expected_shape): + reader = WSIReader(self.backend, level=level) + with reader.read(file_path) as img_obj: + img = reader.get_data(img_obj)[0] + self.assertTupleEqual(img.shape, expected_shape) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_5]) + def test_read_region(self, file_path, patch_info, expected_img): + reader = WSIReader(self.backend) + with reader.read(file_path) as img_obj: + if self.backend == "tifffile": + with self.assertRaises(ValueError): + reader.get_data(img_obj, **patch_info)[0] + else: + # Read twice to check multiple calls + img = reader.get_data(img_obj, **patch_info)[0] + img2 = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, img2.shape) + self.assertIsNone(assert_array_equal(img, img2)) + self.assertTupleEqual(img.shape, expected_img.shape) + self.assertIsNone(assert_array_equal(img, expected_img)) + + @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) + def test_read_patches(self, file_path, patch_info, expected_img): + reader = WSIReader(self.backend) + with reader.read(file_path) as img_obj: + if self.backend == "tifffile": + with self.assertRaises(ValueError): + reader.get_data(img_obj, **patch_info)[0] + else: + img = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, expected_img.shape) + self.assertIsNone(assert_array_equal(img, expected_img)) + + @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) + @skipUnless(has_tiff, "Requires tifffile.") + def test_read_rgba(self, img_expected): + # skip for OpenSlide since not working with images without tiles + if self.backend == "openslide": + return + image = {} + reader = WSIReader(self.backend) + for mode in ["RGB", "RGBA"]: + file_path = save_rgba_tiff( + img_expected, + os.path.join(os.path.dirname(__file__), "testing_data", f"temp_tiff_image_{mode}.tiff"), + mode=mode, + ) + with reader.read(file_path) as img_obj: + image[mode], _ = reader.get_data(img_obj) + + self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) + self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) + + @parameterized.expand([TEST_CASE_TRANSFORM_0]) + def test_with_dataloader(self, file_path, level, expected_spatial_shape, expected_shape): + train_transform = Compose( + [ + LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level), + ToTensord(keys=["image"]), + ] + ) + dataset = Dataset([{"image": file_path}], transform=train_transform) + data_loader = DataLoader(dataset) + data: dict = first(data_loader) + for s in data["image_meta_dict"]["spatial_shape"]: + torch.testing.assert_allclose(s, expected_spatial_shape) + self.assertTupleEqual(data["image"].shape, expected_shape) + + +@skipUnless(has_cucim, "Requires cucim") +class TestCuCIM(WSIReaderTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "cucim" + + +@skipUnless(has_osl, "Requires OpenSlide") +class TestOpenSlide(WSIReaderTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "openslide" + + +@skipUnless(has_tiff, "Requires TiffFile") +class TestTiffFile(WSIReaderTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "tifffile" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_zipdataset.py b/tests/test_zipdataset.py index 710ca71fc2..f381e0a453 100644 --- a/tests/test_zipdataset.py +++ b/tests/test_zipdataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_zoom.py b/tests/test_zoom.py index e6710ede29..1a7694072e 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,11 +12,12 @@ import unittest import numpy as np +import torch from parameterized import parameterized from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoom -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(1.5, "nearest"), (1.5, "nearest"), (0.8, "bilinear"), (0.8, "area")] @@ -26,38 +27,42 @@ class TestZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, mode): - zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) - zoomed = zoom_fn(self.imt[0]) - _order = 0 - if mode.endswith("linear"): - _order = 1 - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(zoomed, expected, atol=1.0) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) + zoomed = zoom_fn(p(self.imt[0])) + _order = 0 + if mode.endswith("linear"): + _order = 1 + expected = [] + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) + expected = np.stack(expected).astype(np.float32) + assert_allclose(zoomed, p(expected), atol=1.0) def test_keep_size(self): - zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True, padding_mode="constant", constant_values=2) - zoomed = zoom_fn(self.imt[0], mode="bilinear") - np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True) + zoomed = zoom_fn(p(self.imt[0]), mode="bilinear") + assert_allclose(zoomed.shape, self.imt.shape[1:]) - zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) - zoomed = zoom_fn(self.imt[0]) - np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) + zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) + zoomed = zoom_fn(p(self.imt[0])) + assert_allclose(zoomed.shape, self.imt.shape[1:]) @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, zoom, mode, raises): - with self.assertRaises(raises): - zoom_fn = Zoom(zoom=zoom, mode=mode) - zoom_fn(self.imt[0]) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + zoom_fn = Zoom(zoom=zoom, mode=mode) + zoom_fn(p(self.imt[0])) def test_padding_mode(self): - zoom_fn = Zoom(zoom=0.5, mode="nearest", padding_mode="constant", keep_size=True) - test_data = np.array([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]) - zoomed = zoom_fn(test_data) - expected = np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) - np.testing.assert_allclose(zoomed, expected) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=0.5, mode="nearest", padding_mode="constant", keep_size=True) + test_data = p([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]) + zoomed = zoom_fn(test_data) + expected = p([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) + torch.testing.assert_allclose(zoomed, expected) if __name__ == "__main__": diff --git a/tests/test_zoom_affine.py b/tests/test_zoom_affine.py index 49c3c0dcac..3c4bcd302c 100644 --- a/tests/test_zoom_affine.py +++ b/tests/test_zoom_affine.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py index 1a1a905d80..87a5cec22b 100644 --- a/tests/test_zoomd.py +++ b/tests/test_zoomd.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,7 +16,7 @@ from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoomd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [(1.5, "nearest", False), (0.3, "bilinear", False), (0.8, "bilinear", False)] @@ -27,38 +27,37 @@ class TestZoomd(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, mode, keep_size): key = "img" - zoom_fn = Zoomd( - key, - zoom=zoom, - mode=mode, - keep_size=keep_size, - ) - zoomed = zoom_fn({key: self.imt[0]}) - _order = 0 - if mode.endswith("linear"): - _order = 1 - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(expected, zoomed[key], atol=1.0) + zoom_fn = Zoomd(key, zoom=zoom, mode=mode, keep_size=keep_size) + for p in TEST_NDARRAYS: + zoomed = zoom_fn({key: p(self.imt[0])}) + _order = 0 + if mode.endswith("linear"): + _order = 1 + expected = [ + zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False) for channel in self.imt[0] + ] + + expected = np.stack(expected).astype(np.float32) + assert_allclose(zoomed[key], p(expected), atol=1.0) def test_keep_size(self): key = "img" zoom_fn = Zoomd(key, zoom=0.6, keep_size=True, padding_mode="constant", constant_values=2) - zoomed = zoom_fn({key: self.imt[0]}) - self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + zoomed = zoom_fn({key: p(self.imt[0])}) + np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:]) - zoom_fn = Zoomd(key, zoom=1.3, keep_size=True) - zoomed = zoom_fn({key: self.imt[0]}) - self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) + zoom_fn = Zoomd(key, zoom=1.3, keep_size=True) + zoomed = zoom_fn({key: self.imt[0]}) + self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:])) @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, zoom, mode, raises): key = "img" - with self.assertRaises(raises): - zoom_fn = Zoomd(key, zoom=zoom, mode=mode) - zoom_fn({key: self.imt[0]}) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + zoom_fn = Zoomd(key, zoom=zoom, mode=mode) + zoom_fn({key: p(self.imt[0])}) if __name__ == "__main__": diff --git a/tests/testing_data/cpp_resample_answers.py b/tests/testing_data/cpp_resample_answers.py index 51ac6ccda9..93f596619e 100644 --- a/tests/testing_data/cpp_resample_answers.py +++ b/tests/testing_data/cpp_resample_answers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,7 +23,7 @@ def _read_testing_data_answers(fname: Optional[str] = None, delimiter=",") -> Li pwd = os.path.dirname(os.path.abspath(__file__)) filename = os.path.join(pwd, fname) if not os.path.isfile(filename): - warnings.warn("test data {} not found.".format(filename)) + warnings.warn(f"test data {filename} not found.") return answers with open(filename) as f: res_reader = csv.reader(f, delimiter=delimiter) diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index ccb4293a40..99765a2b33 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -418,7 +418,7 @@ 0.17794132232666016, 0.18584394454956055, 0.03577899932861328, - ], + ] }, "integration_segmentation_3d": { # for the mixed readers "losses": [ @@ -433,6 +433,153 @@ "infer_metric": 0.9326590299606323, }, }, + { # test answers for PyTorch 21.10 + "integration_classification_2d": { + "losses": [0.7806222991199251, 0.16259610306495315, 0.07529311385124353, 0.04640352608529246], + "best_metric": 0.9999369155431564, + "infer_prop": [1030, 898, 981, 1033, 960, 1046], + }, + "integration_segmentation_3d": { + "losses": [ + 0.5462362408638001, + 0.4913381844758987, + 0.4526856362819672, + 0.43404580652713776, + 0.42532919645309447, + 0.4160102754831314, + ], + "best_metric": 0.9357608556747437, + "infer_metric": 0.9359462857246399, + "output_sums": [ + 0.14133183650702907, + 0.15129517085134564, + 0.15039408698301698, + 0.1388800895551786, + 0.18765019147239637, + 0.16847158867677473, + 0.14567945622102715, + 0.16728557092807228, + 0.15601444057659314, + 0.17816339678760573, + 0.1616256801482474, + 0.16733042976922818, + 0.14342795433701588, + 0.1122946416901734, + 0.16105778942392063, + 0.20017543167070598, + 0.17512204704647916, + 0.09592956823274325, + 0.19316383411238341, + 0.2022308530579937, + 0.19527218778022315, + 0.2075871950564991, + 0.16083565516485876, + 0.13111518931029637, + 0.1473909261474288, + 0.14161210629657228, + 0.23102446985179093, + 0.15980667305916593, + 0.14760356792082058, + 0.1018092235719272, + 0.11792260857122504, + 0.1285278390386459, + 0.11275165891441473, + 0.15101653432548032, + 0.16236351926994622, + 0.1932631773335222, + 0.2221395787381994, + 0.18003549292918666, + 0.18940543270178078, + 0.07430261166443994, + ], + }, + "integration_workflows": { + "output_sums": [ + 0.14211511611938477, + 0.1516571044921875, + 0.1381092071533203, + 0.13403034210205078, + 0.18480682373046875, + 0.16382598876953125, + 0.14140796661376953, + 0.1665945053100586, + 0.15700864791870117, + 0.17697620391845703, + 0.16163396835327148, + 0.16488313674926758, + 0.1442713737487793, + 0.11060476303100586, + 0.16111087799072266, + 0.19617986679077148, + 0.1744403839111328, + 0.052786827087402344, + 0.19046974182128906, + 0.19913578033447266, + 0.19527721405029297, + 0.2032318115234375, + 0.16050148010253906, + 0.13228464126586914, + 0.1512293815612793, + 0.1372208595275879, + 0.22692251205444336, + 0.16164922714233398, + 0.14729642868041992, + 0.10398292541503906, + 0.1195836067199707, + 0.13096046447753906, + 0.11221647262573242, + 0.1521167755126953, + 0.1599421501159668, + 0.1898345947265625, + 0.21675777435302734, + 0.1777491569519043, + 0.18526840209960938, + 0.035144805908203125, + ], + "output_sums_2": [ + 0.14200592041015625, + 0.15146303176879883, + 0.13796186447143555, + 0.1339101791381836, + 0.18489742279052734, + 0.1637406349182129, + 0.14113903045654297, + 0.16657161712646484, + 0.15676355361938477, + 0.17683839797973633, + 0.1614980697631836, + 0.16493558883666992, + 0.14408016204833984, + 0.11035394668579102, + 0.1610560417175293, + 0.1962742805480957, + 0.17439842224121094, + 0.05285835266113281, + 0.19057941436767578, + 0.19914865493774414, + 0.19533538818359375, + 0.20333576202392578, + 0.16032838821411133, + 0.13197898864746094, + 0.1510462760925293, + 0.13703680038452148, + 0.2270984649658203, + 0.16144943237304688, + 0.1472611427307129, + 0.10393238067626953, + 0.11940813064575195, + 0.1307811737060547, + 0.11203241348266602, + 0.15186500549316406, + 0.15992307662963867, + 0.18991422653198242, + 0.21689796447753906, + 0.1777033805847168, + 0.18547868728637695, + 0.035192012786865234, + ], + }, + }, ] @@ -440,6 +587,8 @@ def test_integration_value(test_name, key, data, rtol=1e-2): for (idx, expected) in enumerate(EXPECTED_ANSWERS): if test_name not in expected: continue + if key not in expected[test_name]: + continue value = expected[test_name][key] if np.allclose(data, value, rtol=rtol): print(f"matched {idx} result of {test_name}, {key}, {rtol}.") diff --git a/tests/testing_data/matshow3d_patch_test.png b/tests/testing_data/matshow3d_patch_test.png new file mode 100644 index 0000000000..a4d89e3446 Binary files /dev/null and b/tests/testing_data/matshow3d_patch_test.png differ diff --git a/tests/testing_data/matshow3d_rgb_test.png b/tests/testing_data/matshow3d_rgb_test.png new file mode 100644 index 0000000000..7c8e224c0e Binary files /dev/null and b/tests/testing_data/matshow3d_rgb_test.png differ diff --git a/tests/testing_data/matshow3d_test.png b/tests/testing_data/matshow3d_test.png new file mode 100644 index 0000000000..d720a0c407 Binary files /dev/null and b/tests/testing_data/matshow3d_test.png differ diff --git a/tests/utils.py b/tests/utils.py index 1375cd2d72..7e36b289e6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,10 +22,9 @@ import unittest import warnings from functools import partial -from io import BytesIO from subprocess import PIPE, Popen from typing import Callable, Optional, Tuple -from urllib.error import ContentTooShortError, HTTPError, URLError +from urllib.error import HTTPError, URLError import numpy as np import torch @@ -35,13 +34,15 @@ from monai.config.deviceconfig import USE_COMPILED from monai.config.type_definitions import NdarrayOrTensor from monai.data import create_test_image_2d, create_test_image_3d -from monai.utils import ensure_tuple, optional_import, set_determinism -from monai.utils.misc import is_module_ver_at_least -from monai.utils.module import version_leq +from monai.networks import convert_to_torchscript +from monai.utils import optional_import +from monai.utils.module import pytorch_after, version_leq +from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") quick_test_var = "QUICKTEST" +_tf32_enabled = None def clone(data: NdarrayTensor) -> NdarrayTensor: @@ -57,31 +58,83 @@ def clone(data: NdarrayTensor) -> NdarrayTensor: return copy.deepcopy(data) -def assert_allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, *args, **kwargs): +def assert_allclose( + actual: NdarrayOrTensor, + desired: NdarrayOrTensor, + type_test: bool = True, + device_test: bool = False, + *args, + **kwargs, +): """ - Assert that all values of two data objects are close. + Assert that types and all values of two data objects are close. Args: - a (NdarrayOrTensor): Pytorch Tensor or numpy array for comparison - b (NdarrayOrTensor): Pytorch Tensor or numpy array to compare against + actual: Pytorch Tensor or numpy array for comparison. + desired: Pytorch Tensor or numpy array to compare against. + type_test: whether to test that `actual` and `desired` are both numpy arrays or torch tensors. + device_test: whether to test the device property. + args: extra arguments to pass on to `np.testing.assert_allclose`. + kwargs: extra arguments to pass on to `np.testing.assert_allclose`. + + """ - a = a.cpu() if isinstance(a, torch.Tensor) else a - b = b.cpu() if isinstance(b, torch.Tensor) else b - np.testing.assert_allclose(a, b, *args, **kwargs) + if type_test: + # check both actual and desired are of the same type + np.testing.assert_equal(isinstance(actual, np.ndarray), isinstance(desired, np.ndarray), "numpy type") + np.testing.assert_equal(isinstance(actual, torch.Tensor), isinstance(desired, torch.Tensor), "torch type") + + if isinstance(desired, torch.Tensor) or isinstance(actual, torch.Tensor): + if device_test: + np.testing.assert_equal(str(actual.device), str(desired.device), "torch device check") # type: ignore + actual = actual.cpu().numpy() if isinstance(actual, torch.Tensor) else actual + desired = desired.cpu().numpy() if isinstance(desired, torch.Tensor) else desired + np.testing.assert_allclose(actual, desired, *args, **kwargs) def test_pretrained_networks(network, input_param, device): try: - net = network(**input_param).to(device) - except (URLError, HTTPError, ContentTooShortError) as e: - raise unittest.SkipTest(e) - return net + return network(**input_param).to(device) + except (URLError, HTTPError) as e: + raise unittest.SkipTest(e) from e + except RuntimeError as r_error: + if "unexpected EOF" in f"{r_error}": # The file might be corrupted. + raise unittest.SkipTest(f"{r_error}") from r_error + raise def test_is_quick(): return os.environ.get(quick_test_var, "").lower() == "true" +def is_tf32_env(): + """ + The environment variable NVIDIA_TF32_OVERRIDE=0 will override any defaults + or programmatic configuration of NVIDIA libraries, and consequently, + cuBLAS will not accelerate FP32 computations with TF32 tensor cores. + """ + global _tf32_enabled + if _tf32_enabled is None: + _tf32_enabled = False + if ( + torch.cuda.is_available() + and not version_leq(f"{torch.version.cuda}", "10.100") + and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0" + and torch.cuda.device_count() > 0 # at least 11.0 + ): + try: + # with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result + g_gpu = torch.Generator(device="cuda") + g_gpu.manual_seed(2147483647) + a_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu) + b_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu) + _tf32_enabled = (a_full.float() @ b_full.float() - a_full @ b_full).abs().max().item() > 0.001 # 0.1713 + except BaseException: + pass + print(f"tf32 enabled: {_tf32_enabled}") + return _tf32_enabled + + def skip_if_quick(obj): """ Skip the unit tests if environment variable `quick_test_var=true`. @@ -143,7 +196,7 @@ class SkipIfBeforePyTorchVersion: def __init__(self, pytorch_version_tuple): self.min_version = pytorch_version_tuple - self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple) + self.version_too_old = not pytorch_after(*pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf( @@ -157,8 +210,7 @@ class SkipIfAtLeastPyTorchVersion: def __init__(self, pytorch_version_tuple): self.max_version = pytorch_version_tuple - test_ver = ".".join(map(str, self.max_version)) - self.version_too_new = version_leq(test_ver, torch.__version__) + self.version_too_new = pytorch_after(*pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf( @@ -166,11 +218,36 @@ def __call__(self, obj): )(obj) -def make_nifti_image(array, affine=None): +def has_cupy(): + """ + Returns True if the user has installed a version of cupy. + """ + cp, has_cp = optional_import("cupy") + if not has_cp: + return False + try: # test cupy installation with a basic example + x = cp.arange(6, dtype="f").reshape(2, 3) + y = cp.arange(3, dtype="f") + kernel = cp.ElementwiseKernel( + "float32 x, float32 y", "float32 z", """ if (x - 2 > y) { z = x * y; } else { z = x + y; } """, "my_kernel" + ) + return kernel(x, y)[0, 0] == 0 + except Exception: + return False + + +HAS_CUPY = has_cupy() + + +def make_nifti_image(array: NdarrayOrTensor, affine=None): """ Create a temporary nifti image on the disk and return the image name. User is responsible for deleting the temporary file when done with it. """ + if isinstance(array, torch.Tensor): + array, *_ = convert_data_type(array, np.ndarray) + if isinstance(affine, torch.Tensor): + affine, *_ = convert_data_type(affine, np.ndarray) if affine is None: affine = np.eye(4) test_image = nib.Nifti1Image(array, affine) @@ -298,8 +375,7 @@ def run_process(self, func, local_rank, args, kwargs, results): os.environ["RANK"] = str(self.nproc_per_node * self.node_rank + local_rank) if torch.cuda.is_available(): - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - torch.cuda.set_device(int(local_rank)) + torch.cuda.set_device(int(local_rank)) # using device ids from CUDA_VISIBILE_DEVICES dist.init_process_group( backend=self.backend, @@ -354,6 +430,7 @@ def _wrapper(*args, **kwargs): for p in processes: p.join() assert results.get(), "Distributed call failed." + _del_original_func(obj) return _wrapper @@ -435,6 +512,7 @@ def _wrapper(*args, **kwargs): finally: p.join() + _del_original_func(obj) res = None try: res = results.get(block=False) @@ -460,6 +538,15 @@ def _cache_original_func(obj) -> None: _original_funcs[obj.__name__] = obj +def _del_original_func(obj): + """pop the original function from cache.""" + global _original_funcs + _original_funcs.pop(obj.__name__, None) + if torch.cuda.is_available(): # clean up the cached function + torch.cuda.synchronize() + torch.cuda.empty_cache() + + def _call_original_func(name, module, *args, **kwargs): if name not in _original_funcs: _original_module = importlib.import_module(module) # reimport, refresh _original_funcs @@ -524,59 +611,31 @@ def setUp(self): self.segn = torch.tensor(self.segn) -def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4): +def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): """ Test the ability to save `net` as a Torchscript object, reload it, and apply inference. The value `inputs` is - forward-passed through the original and loaded copy of the network and their results returned. Both `net` and its - reloaded copy are set to evaluation mode if `eval_nets` is True. The forward pass for both is done without - gradient accumulation. + forward-passed through the original and loaded copy of the network and their results returned. + The forward pass for both is done without gradient accumulation. The test will be performed with CUDA if available, else CPU. """ - if True: - device = "cpu" - else: - # TODO: It would be nice to be able to use GPU if - # available, but this currently causes CI failures. - if not device: - device = "cuda" if torch.cuda.is_available() else "cpu" - - # Convert to device - inputs = [i.to(device) for i in inputs] - - scripted = torch.jit.script(net.cpu()) - buffer = scripted.save_to_buffer() - reloaded_net = torch.jit.load(BytesIO(buffer)).to(device) - net.to(device) - - if eval_nets: - net.eval() - reloaded_net.eval() - - with torch.no_grad(): - set_determinism(seed=0) - result1 = net(*inputs) - result2 = reloaded_net(*inputs) - set_determinism(seed=None) - - # convert results to tuples if needed to allow iterating over pairs of outputs - result1 = ensure_tuple(result1) - result2 = ensure_tuple(result2) - - for i, (r1, r2) in enumerate(zip(result1, result2)): - if None not in (r1, r2): # might be None - np.testing.assert_allclose( - r1.detach().cpu().numpy(), - r2.detach().cpu().numpy(), - rtol=rtol, - atol=0, - err_msg=f"failed on comparison number: {i}", - ) + # TODO: would be nice to use GPU if available, but it currently causes CI failures. + device = "cpu" + with tempfile.TemporaryDirectory() as tempdir: + convert_to_torchscript( + model=net, + filename_or_obj=os.path.join(tempdir, "model.ts"), + verify=True, + inputs=inputs, + device=device, + rtol=rtol, + atol=atol, + ) def query_memory(n=2): """ - Find best n idle devices and return a string of device ids. + Find best n idle devices and return a string of device ids using the `nvidia-smi` command. """ bash_string = "nvidia-smi --query-gpu=power.draw,temperature.gpu,memory.used --format=csv,noheader,nounits" @@ -587,7 +646,7 @@ def query_memory(n=2): free_memory = np.asarray(free_memory, dtype=float).T free_memory[1] += free_memory[0] # combine 0/1 column measures ids = np.lexsort(free_memory)[:n] - except (FileNotFoundError, TypeError, IndexError): + except (TypeError, IndexError, OSError): ids = range(n) if isinstance(n, int) else [] return ",".join(f"{int(x)}" for x in ids)