diff --git a/.dockerignore b/.dockerignore
index 549e63bad5..4e1161bfb2 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -4,8 +4,10 @@ __pycache__/
docs/
.coverage
+.coverage.*
+.coverage/
+coverage.xml
.readthedocs.yml
-*.md
*.toml
!README.md
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 61b8814857..f7024f1a08 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -11,7 +11,7 @@ A few sentences describing the changes proposed in this pull request.
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
-- [ ] Integration tests passed locally by running `./runtests.sh --codeformat --coverage`.
-- [ ] Quick tests passed locally by running `./runtests.sh --quick`.
+- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
+- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml
new file mode 100644
index 0000000000..52bd4fddae
--- /dev/null
+++ b/.github/workflows/blossom-ci.yml
@@ -0,0 +1,103 @@
+# A workflow to trigger ci on hybrid infra (github + self hosted runner)
+name: Blossom-CI
+on:
+ issue_comment:
+ types: [created]
+ workflow_dispatch:
+ inputs:
+ platform:
+ description: 'runs-on argument'
+ required: false
+ args:
+ description: 'argument'
+ required: false
+
+permissions:
+ actions: write
+ checks: write
+ contents: write
+ issues: write
+ pull-requests: write
+ repository-projects: write
+ statuses: write
+
+jobs:
+ Authorization:
+ name: Authorization
+ runs-on: blossom
+ outputs:
+ args: ${{ env.args }}
+
+ # This job only runs for pull request comments
+ if: |
+ contains( 'madil90,Nic-Ma,wyli,', format('{0},', github.actor)) &&
+ github.event.comment.body == '/build'
+ steps:
+ - name: Check if comment is issued by authorized person
+ run: blossom-ci
+ env:
+ OPERATION: 'AUTH'
+ REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}
+
+ Vulnerability-scan:
+ name: Vulnerability scan
+ needs: [Authorization]
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+ with:
+ repository: ${{ fromJson(needs.Authorization.outputs.args).repo }}
+ ref: ${{ fromJson(needs.Authorization.outputs.args).ref }}
+ lfs: 'true'
+
+ # repo specific steps
+ #- name: Setup java
+ # uses: actions/setup-java@v1
+ # with:
+ # java-version: 1.8
+
+ # add blackduck properties https://synopsys.atlassian.net/wiki/spaces/INTDOCS/pages/631308372/Methods+for+Configuring+Analysis#Using-a-configuration-file
+ #- name: Setup blackduck properties
+ # run: |
+ # PROJECTS=$(mvn -am dependency:tree | grep maven-dependency-plugin | awk '{ out="com.nvidia:"$(NF-1);print out }' | grep rapids | xargs | sed -e 's/ /,/g')
+ # echo detect.maven.build.command="-pl=$PROJECTS -am" >> application.properties
+ # echo detect.maven.included.scopes=compile >> application.properties
+ - name: Setup blackduck properties
+ run: |
+ echo detect.excluded.detector.types=PIP >> application.properties
+
+ - name: Run blossom action
+ uses: NVIDIA/blossom-action@main
+ env:
+ REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}
+ with:
+ args1: ${{ fromJson(needs.Authorization.outputs.args).args1 }}
+ args2: ${{ fromJson(needs.Authorization.outputs.args).args2 }}
+ args3: ${{ fromJson(needs.Authorization.outputs.args).args3 }}
+
+ Job-trigger:
+ name: Start ci job
+ needs: [Vulnerability-scan]
+ runs-on: blossom
+ steps:
+ - name: Start ci job
+ run: blossom-ci
+ env:
+ OPERATION: 'START-CI-JOB'
+ CI_SERVER: ${{ secrets.CI_SERVER }}
+ REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+ Post-processing:
+ name: Post processing
+ runs-on: blossom
+    if: github.event_name == 'workflow_dispatch'
+ steps:
+ - name: Start post processing
+ run: blossom-ci
+ env:
+ OPERATION: 'POST-PROCESSING'
+ CI_SERVER: ${{ secrets.CI_SERVER }}
+ REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/cleanup.yml b/.github/workflows/cleanup.yml
deleted file mode 100644
index f3d297286e..0000000000
--- a/.github/workflows/cleanup.yml
+++ /dev/null
@@ -1,20 +0,0 @@
-name: cleanup-workflow
-
-on:
- workflow_run:
- workflows:
- - "build"
- types: ["requested"]
-
-jobs:
- cancel-duplicated-workflow:
- name: "Cancel duplicated workflow"
- runs-on: ubuntu-latest
- steps:
- - uses: potiuk/cancel-workflow-runs@953e057dc81d3458935a18d1184c386b0f6b5738 # tested
- name: "Cancel duplicate workflows"
- with:
- cancelMode: allDuplicates
- token: ${{ secrets.GITHUB_TOKEN }}
- sourceRunId: ${{ github.event.workflow_run.id }}
- skipEventTypes: '["schedule"]'
diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml
index e568ba9e15..c3d8a4c3b1 100644
--- a/.github/workflows/cron.yml
+++ b/.github/workflows/cron.yml
@@ -3,6 +3,8 @@ name: crons
on:
schedule:
- cron: "0 2 * * *" # at 02:00 UTC
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
jobs:
cron-gpu:
@@ -13,7 +15,7 @@ jobs:
runs-on: [self-hosted, linux, x64, common]
strategy:
matrix:
- pytorch-version: [1.5.0, 1.5.1, 1.6.0, latest]
+ pytorch-version: [1.5.1, 1.6.0, 1.7.1, 1.8.1, latest]
steps:
- uses: actions/checkout@v2
- name: Install the dependencies
@@ -23,15 +25,14 @@ jobs:
python -m pip uninstall -y torch torchvision
if [ ${{ matrix.pytorch-version }} == "latest" ]; then
python -m pip install torch torchvision
- elif [ ${{ matrix.pytorch-version }} == "1.5.0" ]; then
- python -m pip install torch==1.5.0
- python -m pip install torchvision==0.6.0
elif [ ${{ matrix.pytorch-version }} == "1.5.1" ]; then
- python -m pip install torch==1.5.1
- python -m pip install torchvision==0.6.1
+ python -m pip install torch==1.5.1 torchvision==0.6.1
elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then
- python -m pip install torch==1.6.0
- python -m pip install torchvision==0.7.0
+ python -m pip install torch==1.6.0 torchvision==0.7.0
+ elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then
+ python -m pip install torch==1.7.1 torchvision==0.8.2
+ elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then
+ python -m pip install torch==1.8.1 torchvision==0.9.1
fi
python -m pip install -r requirements-dev.txt
python -m pip list
@@ -43,10 +44,14 @@ jobs:
nvidia-smi
export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
- python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
- BUILD_MONAI=1 ./runtests.sh --coverage
+ python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))'
+ BUILD_MONAI=1 ./runtests.sh --coverage --unittests # unit tests with coverage report
+ BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report
coverage xml
+ if pgrep python; then pkill python; fi
- name: Upload coverage
uses: codecov/codecov-action@v1
with:
@@ -55,13 +60,20 @@ jobs:
cron-pt-image:
if: github.repository == 'Project-MONAI/MONAI'
+ strategy:
+ matrix:
+ container: ["pytorch:21.02", "pytorch:21.06"] # 21.02 for backward comp.
container:
- image: nvcr.io/nvidia/pytorch:20.12-py3 # testing with the latest pytorch base image
+ image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image
options: "--gpus all"
runs-on: [self-hosted, linux, x64, common]
steps:
- uses: actions/checkout@v2
- - name: Install the dependencies
+ - name: Install APT dependencies
+ run: |
+ apt-get update
+ DEBIAN_FRONTEND="noninteractive" apt-get install -y libopenslide0
+ - name: Install Python dependencies
run: |
which python
python -m pip install --upgrade pip wheel
@@ -75,16 +87,89 @@ jobs:
nvidia-smi
export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
- python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
- BUILD_MONAI=1 ./runtests.sh --coverage
+ python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))'
+ BUILD_MONAI=1 ./runtests.sh --coverage --unittests # unit tests with coverage report
+ BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report
coverage xml
+ if pgrep python; then pkill python; fi
- name: Upload coverage
uses: codecov/codecov-action@v1
with:
fail_ci_if_error: false
file: ./coverage.xml
+ cron-pip:
+ # pip install monai[all] and use it to run unit tests
+ if: github.repository == 'Project-MONAI/MONAI'
+ strategy:
+ matrix:
+ container: ["pytorch:21.02", "pytorch:21.06"] # 21.02 for backward comp.
+ container:
+ image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image
+ options: "--gpus all"
+ runs-on: [self-hosted, linux, x64, common]
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+ - name: Install the dependencies
+ run: |
+ which python
+ python -m pip install --upgrade pip wheel twine
+ python -m pip list
+ - name: Run tests report coverage
+ run: |
+          pip uninstall -y monai
+ pip list | grep -iv monai
+ git fetch --depth=1 origin +refs/tags/*:refs/tags/*
+ root_dir=$PWD
+ echo "$root_dir"
+ set -e
+
+ # build tar.gz and wheel
+ bash runtests.sh --clean # clear any existing dev temp files
+ python -m pip uninstall -y torch torchvision
+ python setup.py check -m -s
+ python setup.py sdist bdist_wheel
+ python -m twine check dist/*
+
+ # move packages to a temp dir
+ tmp_dir=$(mktemp -d)
+ cp dist/monai* "$tmp_dir"
+ rm -r build dist monai.egg-info
+ cd "$tmp_dir"
+ ls -al
+
+ # install from tar.gz
+ name=$(ls *.tar.gz | head -n1)
+ echo $name
+ python -m pip install $name[all]
+ python -c 'import monai; monai.config.print_config()' 2>&1 | grep -iv "unknown"
+ python -c 'import monai; print(monai.__file__)'
+
+ # run tests
+ cp $root_dir/requirements*.txt "$tmp_dir"
+ cp -r $root_dir/tests "$tmp_dir"
+ pwd
+ ls -al
+
+ export LAUNCH_DELAY=$[ $RANDOM % 16 * 60 ]
+ echo "Sleep $LAUNCH_DELAY"
+ sleep $LAUNCH_DELAY
+ nvidia-smi
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
+ python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
+
+ python -m pip install -r requirements-dev.txt
+ PYTHONPATH="$tmp_dir":$PYTHONPATH BUILD_MONAI=1 python ./tests/runner.py -p 'test_((?!integration).)' # unit tests
+ if pgrep python; then pkill python; fi
+
cron-docker:
if: github.repository == 'Project-MONAI/MONAI'
container:
@@ -100,13 +185,54 @@ jobs:
nvidia-smi
export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
ngc --version
- BUILD_MONAI=1 ./runtests.sh --coverage --pytype
+ BUILD_MONAI=1 ./runtests.sh --coverage --pytype --unittests # unit tests with pytype checks, coverage report
+ BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report
coverage xml
+ if pgrep python; then pkill python; fi
- name: Upload coverage
uses: codecov/codecov-action@v1
with:
fail_ci_if_error: false
file: ./coverage.xml
+
+ cron-tutorial-notebooks:
+ if: github.repository == 'Project-MONAI/MONAI'
+ needs: cron-gpu # so that monai itself is verified first
+ container:
+ image: nvcr.io/nvidia/pytorch:21.06-py3 # testing with the latest pytorch base image
+ options: "--gpus all --ipc=host"
+ runs-on: [self-hosted, linux, x64, common]
+ steps:
+ - uses: actions/checkout@v2
+ - name: Install MONAI
+ id: monai-install
+ run: |
+ which python
+ python -m pip install --upgrade pip wheel
+ python -m pip install -r requirements-dev.txt
+ BUILD_MONAI=0 python setup.py develop # install monai
+ nvidia-smi
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ echo $CUDA_VISIBLE_DEVICES
+ echo "::set-output name=devices::$CUDA_VISIBLE_DEVICES"
+ - name: Checkout tutorials and install their requirements
+ run: |
+ cd /opt
+ git clone --depth 1 --branch master --single-branch https://github.com/Project-MONAI/tutorials.git # latest commit of master branch
+ cd tutorials
+ python -m pip install -r requirements.txt
+ - name: Run tutorial notebooks
+ timeout-minutes: 150
+ run: |
+ export CUDA_VISIBLE_DEVICES=${{ steps.monai-install.outputs.devices }}
+ echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
+ cd /opt/tutorials
+ $(pwd)/runner.sh
+ if pgrep python; then pkill python; fi
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
new file mode 100644
index 0000000000..3104224e2b
--- /dev/null
+++ b/.github/workflows/docker.yml
@@ -0,0 +1,124 @@
+name: docker
+# versioning: compute a static version file
+# local_docker: use the version file to build docker images
+# docker_test_latest: test the latest internal docker image (has flake)
+# docker_test_dockerhub: test the latest dockerhub release (no flake)
+on:
+ # dev only docker deployment and quick tests
+ push:
+ branches:
+ - dev
+ # Allows you to run this workflow manually from the Actions tab
+ # This is to trigger building/testing docker image from dev only.
+ workflow_dispatch:
+
+jobs:
+ versioning:
+ # compute versioning file from python setup.py
+ # upload as artifact
+ # (also used in release.yml)
+ if: github.repository == 'Project-MONAI/MONAI'
+ container:
+ image: localhost:5000/local_monai:latest
+ runs-on: [self-hosted, linux, x64, build_only]
+ steps:
+ - uses: actions/checkout@v2
+ # full history so that we can git describe
+ with:
+ ref: dev
+ fetch-depth: 0
+ - shell: bash
+ run: |
+ git describe
+ python setup.py build
+ cat build/lib/monai/_version.py
+ - name: Upload version
+ uses: actions/upload-artifact@v2
+ with:
+ name: _version.py
+ path: build/lib/monai/_version.py
+ - name: Clean up directory
+ shell: bash
+ run: |
+ ls -al
+ rm -rf {*,.[^.]*}
+
+ local_docker:
+ # builds two versions: local_monai:latest and local_monai:dockerhub
+ # latest: used for local tests
+ # dockerhub: release, no flake package
+ if: github.repository == 'Project-MONAI/MONAI'
+ needs: versioning
+ runs-on: [self-hosted, linux, x64, build_only]
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ ref: dev
+ - name: Download version
+ uses: actions/download-artifact@v2
+ with:
+ name: _version.py
+ - name: docker_build
+ shell: bash
+ run: |
+ # get tag info for versioning
+ cat _version.py
+ mv _version.py monai/
+ # build and run original docker image for local registry
+ docker build -t localhost:5000/local_monai:latest -f Dockerfile .
+ docker push localhost:5000/local_monai:latest
+ # build once more w/ tag "latest": remove flake package as it is not needed on hub.docker.com
+ sed -i '/flake/d' requirements-dev.txt
+ docker build -t projectmonai/monai:latest -f Dockerfile .
+ # also push as tag "dockerhub" to local registry
+ docker image tag projectmonai/monai:latest localhost:5000/local_monai:dockerhub
+ docker push localhost:5000/local_monai:dockerhub
+ # distribute as always w/ tag "latest" to hub.docker.com
+ echo "${{ secrets.DOCKER_PW }}" | docker login -u projectmonai --password-stdin
+ docker push projectmonai/monai:latest
+ docker logout
+ docker image prune -f
+
+ docker_test_latest:
+ if: github.repository == 'Project-MONAI/MONAI'
+ needs: local_docker
+ container:
+ image: localhost:5000/local_monai:latest
+ runs-on: [self-hosted, linux, x64, common]
+ steps:
+ - name: Import
+ run: |
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
+ python -c 'import monai; monai.config.print_config()'
+ cd /opt/monai
+ ls -al
+ ngc --version
+ python -m tests.min_tests
+ if pgrep python; then pkill python; fi
+ env:
+ QUICKTEST: True
+
+ docker_test_dockerhub:
+ if: github.repository == 'Project-MONAI/MONAI'
+ needs: local_docker
+ container:
+ image: localhost:5000/local_monai:dockerhub
+ runs-on: [self-hosted, linux, x64, common]
+ steps:
+ - name: Import
+ run: |
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
+ python -c 'import monai; monai.config.print_config()'
+ cd /opt/monai
+ ls -al
+ ngc --version
+ python -m tests.min_tests
+ if pgrep python; then pkill python; fi
+ env:
+ QUICKTEST: True
diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index 003a746de4..ed025e98fe 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -7,7 +7,7 @@ on:
jobs:
integration-py3:
container:
- image: nvcr.io/nvidia/pytorch:20.03-py3 # CUDA 10.2
+ image: nvcr.io/nvidia/pytorch:20.12-py3 # CUDA 11.1
options: --gpus all
runs-on: [self-hosted, linux, x64, common]
steps:
@@ -28,13 +28,13 @@ jobs:
path: |
~/.cache/pip
~/.cache/torch
- key: docker-20-03-py3-pip-${{ steps.pip-cache.outputs.datew }}
+ key: docker-py3-pip-${{ steps.pip-cache.outputs.datew }}
- name: Install the dependencies
run: |
which python
python -m pip install --upgrade pip wheel
python -m pip uninstall -y torch torchvision
- python -m pip install torch==1.7.1 torchvision==0.8.2
+ python -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install -r requirements-dev.txt
- name: Run integration tests
run: |
@@ -42,9 +42,14 @@ jobs:
nvidia-smi
export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
BUILD_MONAI=1 ./runtests.sh --net
+ BUILD_MONAI=1 ./runtests.sh --unittests
+ if pgrep python; then pkill python; fi
+ shell: bash
- name: Add reaction
uses: peter-evans/create-or-update-comment@v1
with:
diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index 8e92ea0ed7..b2ddb74d34 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -1,15 +1,22 @@
name: build
on:
- # quick tests for every pull request
+ # quick tests for pull requests and the releasing branches
push:
branches:
- - master
+ - dev
+ - main
+ - releasing/*
pull_request:
+concurrency:
+ # automatically cancel the previously triggered workflows when there's a newer version
+ group: build-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: true
+
jobs:
# caching of these jobs:
- # - docker-20-03-py3-pip- (shared)
+ # - docker-py3-pip- (shared)
# - ubuntu py37 pip-
# - os-latest-pip- (shared)
flake8-py3:
@@ -39,9 +46,9 @@ jobs:
# clean up temporary files
$(pwd)/runtests.sh --clean
# Git hub actions have 2 cores, so parallize pytype
- $(pwd)/runtests.sh --nounittests --codeformat -j 2
+ $(pwd)/runtests.sh --codeformat -j 2
- quick-py3: # full dependencies installed
+ quick-py3: # full dependencies installed tests for different OS
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
@@ -80,13 +87,10 @@ jobs:
- if: runner.os == 'windows'
name: Install torch cpu from pytorch.org (Windows only)
run: |
- python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
- # min. requirements for windows instances
- python -c "f=open('requirements-dev.txt', 'r'); txt=f.readlines(); f.close(); print(txt); f=open('requirements-dev.txt', 'w'); f.writelines(txt[1:12]); f.close()"
+ python -m pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install the dependencies
run: |
- python -m pip install torch==1.7.1
- python -m pip install torchvision==0.8.2
+ python -m pip install torch==1.9.0 torchvision==0.10.0
cat "requirements-dev.txt"
python -m pip install -r requirements-dev.txt
python -m pip list
@@ -102,7 +106,7 @@ jobs:
env:
QUICKTEST: True
- min-dep-py3: # min dependencies installed
+ min-dep-os: # min dependencies installed tests for different OS
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
@@ -134,11 +138,11 @@ jobs:
- if: runner.os == 'windows'
name: Install torch cpu from pytorch.org (Windows only)
run: |
- python -m pip install torch==1.7.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ python -m pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install the dependencies
run: |
# min. requirements
- python -m pip install torch==1.7.1
+ python -m pip install torch==1.9.0
python -m pip install -r requirements-min.txt
python -m pip list
BUILD_MONAI=0 python setup.py develop # no compile of extensions
@@ -147,7 +151,52 @@ jobs:
run: |
python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
python -c "import monai; monai.config.print_config()"
- python -m tests.min_tests
+ ./runtests.sh --min
+ env:
+ QUICKTEST: True
+
+ min-dep-py3: # min dependencies installed tests for different python
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: [3.6, 3.7, 3.8, 3.9]
+ timeout-minutes: 40
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Prepare pip wheel
+ run: |
+ which python
+ python -m pip install --user --upgrade pip setuptools wheel
+ - name: cache weekly timestamp
+ id: pip-cache
+ run: |
+ echo "::set-output name=datew::$(date '+%Y-%V')"
+ echo "::set-output name=dir::$(pip cache dir)"
+ shell: bash
+ - name: cache for pip
+ uses: actions/cache@v2
+ id: cache
+ with:
+ path: ${{ steps.pip-cache.outputs.dir }}
+ key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }}
+ - name: Install the dependencies
+ run: |
+ # min. requirements
+ python -m pip install torch==1.9.0
+ python -m pip install -r requirements-min.txt
+ python -m pip list
+ BUILD_MONAI=0 python setup.py develop # no compile of extensions
+ shell: bash
+ - name: Run quick tests (CPU ${{ runner.os }})
+ run: |
+ python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
+ python -c "import monai; monai.config.print_config()"
+ ./runtests.sh --min
env:
QUICKTEST: True
@@ -156,18 +205,13 @@ jobs:
strategy:
matrix:
environment:
- - "PT15+CUDA101"
- - "PT16+CUDA102"
- "PT16+CUDA110"
- "PT17+CUDA102"
- "PT17+CUDA110"
+ - "PT18+CUDA102"
+ - "PT19+CUDA113"
+ - "PT19+CUDA102"
include:
- - environment: PT15+CUDA101
- pytorch: "torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html"
- base: "nvcr.io/nvidia/cuda:10.1-devel-ubuntu18.04"
- - environment: PT16+CUDA102
- pytorch: "torch==1.6.0 torchvision==0.7.0"
- base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
- environment: PT16+CUDA110
# we explicitly set pytorch to -h to avoid pip install error
pytorch: "-h"
@@ -179,6 +223,16 @@ jobs:
# we explicitly set pytorch to -h to avoid pip install error
pytorch: "-h"
base: "nvcr.io/nvidia/pytorch:20.09-py3"
+ - environment: PT18+CUDA102
+ pytorch: "torch==1.8.1 torchvision==0.9.1"
+ base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
+ - environment: PT19+CUDA113
+ # we explicitly set pytorch to -h to avoid pip install error
+ pytorch: "-h"
+ base: "nvcr.io/nvidia/pytorch:21.06-py3"
+ - environment: PT19+CUDA102
+ pytorch: "torch==1.9.0 torchvision==0.10.0"
+ base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
container:
image: ${{ matrix.base }}
options: --gpus all
@@ -187,7 +241,10 @@ jobs:
- uses: actions/checkout@v2
- name: apt install
run: |
- if [ ${{ matrix.environment }} != "PT16+CUDA110" ]; then \
+ if [ ${{ matrix.environment }} = "PT17+CUDA102" ] || \
+ [ ${{ matrix.environment }} = "PT18+CUDA102" ] || \
+ [ ${{ matrix.environment }} = "PT19+CUDA102" ]
+ then
PYVER=3.6 PYSFX=3 DISTUTILS=python3-distutils && \
apt-get update && apt-get install -y --no-install-recommends \
curl \
@@ -208,7 +265,8 @@ jobs:
libboost-test-dev \
libgoogle-glog-dev \
libjsoncpp-dev \
- cmake && \
+ cmake \
+ git && \
rm -rf /var/lib/apt/lists/* && \
export PYTHONIOENCODING=utf-8 LC_ALL=C.UTF-8 && \
rm -f /usr/bin/python && \
@@ -217,28 +275,40 @@ jobs:
ln -s /usr/bin/python$PYVER /usr/bin/python`echo $PYVER | cut -c1-1` &&
curl -O https://bootstrap.pypa.io/get-pip.py && \
python get-pip.py && \
- rm get-pip.py ; fi
+ rm get-pip.py;
+ fi
- name: Install dependencies
run: |
which python
python -m pip install --upgrade pip wheel
python -m pip install ${{ matrix.pytorch }}
python -m pip install -r requirements-dev.txt
+ python -m pip list
- name: Run quick tests (GPU)
run: |
- python -m pip list
+ git clone --depth 1 \
+ https://github.com/Project-MONAI/MONAI-extra-test-data.git /MONAI-extra-test-data
+ export MONAI_EXTRA_TEST_DATA="/MONAI-extra-test-data"
nvidia-smi
+ export LAUNCH_DELAY=$(python -c "import numpy; print(numpy.random.randint(30) * 10)")
+ echo "Sleep $LAUNCH_DELAY"
+ sleep $LAUNCH_DELAY
export CUDA_VISIBLE_DEVICES=$(coverage run -m tests.utils)
echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
- python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
+ python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))'
python -c "import monai; monai.config.print_config()"
- BUILD_MONAI=1 ./runtests.sh --quick
- if [ ${{ matrix.environment }} == "PT16+CUDA110" ]; then
+ # build for the current self-hosted CI Tesla V100
+ BUILD_MONAI=1 TORCH_CUDA_ARCH_LIST="7.0" ./runtests.sh --quick --unittests
+ if [ ${{ matrix.environment }} = "PT19+CUDA102" ]; then
# test the clang-format tool downloading once
coverage run -m tests.clang_format_utils
fi
coverage xml
+ if pgrep python; then pkill python; fi
+ shell: bash
- name: Upload coverage
uses: codecov/codecov-action@v1
with:
@@ -275,6 +345,8 @@ jobs:
python -m pip install torch>=1.5 torchvision
- name: Test source archive and wheel file
run: |
+          pip uninstall -y monai
+ pip list | grep -iv monai
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
root_dir=$PWD
echo "$root_dir"
@@ -300,15 +372,22 @@ jobs:
rm monai*.whl
# install from tar.gz
- python -m pip install monai*.tar.gz
+ name=$(ls *.tar.gz | head -n1)
+ echo $name
+ python -m pip install $name[all]
python -c 'import monai; monai.config.print_config()' 2>&1 | grep -iv "unknown"
python -c 'import monai; print(monai.__file__)'
- python -m pip uninstall -y monai
- rm monai*.tar.gz
- # clean up
- cd "$root_dir"
- rm -r "$tmp_dir"
+ # run min tests
+ cp $root_dir/requirements*.txt "$tmp_dir"
+ cp -r $root_dir/tests "$tmp_dir"
+ pwd
+ ls -al
+ python -m pip install -r requirements-dev.txt
+ python -m unittest -v
+ env:
+ QUICKTEST: True
+ shell: bash
build-docs:
runs-on: ubuntu-latest
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 840194b1da..bfdc639788 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -1,9 +1,10 @@
name: release
+# generating and testing package artefacts from the main branch
on:
push:
branches:
- - 'releases/*'
+ - main
tags:
- '*'
@@ -12,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [3.6, 3.7, 3.8]
+ python-version: [3.6, 3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
with:
@@ -83,3 +84,70 @@ jobs:
password: ${{ secrets.TEST_PYPI }}
repository_url: https://test.pypi.org/legacy/
+ versioning:
+ # compute versioning file from python setup.py
+ # upload as artifact
+ # (also used in docker.yml)
+ if: github.repository == 'Project-MONAI/MONAI'
+ needs: packaging
+ container:
+ image: localhost:5000/local_monai:latest
+ runs-on: [self-hosted, linux, x64, build_only]
+ steps:
+ - uses: actions/checkout@v2
+ # full history so that we can git describe
+ with:
+ ref: main
+ fetch-depth: 0
+ - shell: bash
+ run: |
+ git describe
+ python setup.py build
+ cat build/lib/monai/_version.py
+ - name: Upload version
+ uses: actions/upload-artifact@v2
+ with:
+ name: _version.py
+ path: build/lib/monai/_version.py
+ - name: Clean up directory
+ shell: bash
+ run: |
+ ls -al
+ rm -rf {*,.[^.]*}
+
+ release_tag_docker:
+ if: github.repository == 'Project-MONAI/MONAI'
+ needs: versioning
+ runs-on: [self-hosted, linux, x64, build_only]
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ ref: main
+ - name: Download version
+ uses: actions/download-artifact@v2
+ with:
+ name: _version.py
+ - name: Set tag
+ id: versioning
+ run: echo ::set-output name=tag::${GITHUB_REF#refs/*/}
+ - name: Check tag
+ env:
+ RELEASE_VERSION: ${{ steps.versioning.outputs.tag }}
+ run: |
+ echo "$RELEASE_VERSION"
+ cat _version.py
+ - if: startsWith(github.ref, 'refs/tags/')
+ name: build with the tag
+ env:
+ RELEASE_VERSION: ${{ steps.versioning.outputs.tag }}
+ shell: bash
+ run: |
+ # get tag info for versioning
+ mv _version.py monai/
+ # remove flake package as it is not needed on hub.docker.com
+ sed -i '/flake/d' requirements-dev.txt
+ docker build -t projectmonai/monai:"$RELEASE_VERSION" -f Dockerfile .
+ # distribute with a tag to hub.docker.com
+ echo "${{ secrets.DOCKER_PW }}" | docker login -u projectmonai --password-stdin
+ docker push projectmonai/monai:"$RELEASE_VERSION"
+ docker logout
diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml
index 7656eb4828..295d4814c8 100644
--- a/.github/workflows/setupapp.yml
+++ b/.github/workflows/setupapp.yml
@@ -1,14 +1,22 @@
name: deploy
on:
- # master only tests
+ # full tests for all the important branches
push:
branches:
- - master
+ - dev
+ - main
+ - releasing/*
+ - feature/*
+
+concurrency:
+ # automatically cancel the previously triggered workflows when there's a newer version
+ group: deploy-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: true
jobs:
# caching of these jobs:
- # - docker-20-03-py3-pip- (shared)
+ # - docker-py3-pip- (shared)
# - ubuntu py36 37 38-pip-
# - os-latest-pip (shared)
coverage-py3:
@@ -30,35 +38,43 @@ jobs:
path: |
~/.cache/pip
~/.cache/torch
- key: docker-20-03-py3-pip-${{ steps.pip-cache.outputs.datew }}
+ key: docker-py3-pip-${{ steps.pip-cache.outputs.datew }}
- name: Install the dependencies
run: |
which python
python -m pip install --upgrade pip wheel
python -m pip uninstall -y torch torchvision
- python -m pip install torch==1.7.1 torchvision==0.8.2
+ python -m pip install torch==1.9.0 torchvision==0.10.0
python -m pip install -r requirements-dev.txt
- name: Run unit tests report coverage
run: |
python -m pip list
+ export LAUNCH_DELAY=$[ $RANDOM % 16 * 60 ]
+ echo "Sleep $LAUNCH_DELAY"
+ sleep $LAUNCH_DELAY
nvidia-smi
export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
- python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
- BUILD_MONAI=1 ./runtests.sh --coverage
+ python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))'
+ BUILD_MONAI=1 ./runtests.sh --coverage --unittests # unit tests with coverage report
+ BUILD_MONAI=1 ./runtests.sh --coverage --net # integration tests with coverage report
coverage xml
+ if pgrep python; then pkill python; fi
+ shell: bash
- name: Upload coverage
uses: codecov/codecov-action@v1
with:
- fail_ci_if_error: true
+ fail_ci_if_error: false
file: ./coverage.xml
test-py3x:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [3.6, 3.7, 3.8]
+ python-version: [3.6, 3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
with:
@@ -82,13 +98,13 @@ jobs:
- name: Install the dependencies
run: |
python -m pip install --upgrade pip wheel
- python -m pip install torch==1.7.1 torchvision==0.8.2
+ python -m pip install torch==1.9.0 torchvision==0.10.0
python -m pip install -r requirements-dev.txt
- name: Run quick tests CPU ubuntu
run: |
python -m pip list
python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
- BUILD_MONAI=1 ./runtests.sh --quick
+ BUILD_MONAI=1 ./runtests.sh --quick --unittests
coverage xml
- name: Upload coverage
uses: codecov/codecov-action@v1
@@ -96,7 +112,7 @@ jobs:
fail_ci_if_error: false
file: ./coverage.xml
- install: # pip install from github url
+ install: # pip install from github url, the default branch is dev
runs-on: ubuntu-latest
steps:
- name: Set up Python 3.8
@@ -115,21 +131,26 @@ jobs:
~/.cache/pip
~/.cache/torch
key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
- - name: Install the default branch no build
+ - name: Install the default branch no build (dev branch only)
+ if: github.ref == 'refs/heads/dev'
run: |
BUILD_MONAI=0 pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI
python -c 'import monai; monai.config.print_config()'
cd $(python -c 'import monai; import os; print(os.path.dirname(monai.__file__))')
ls .
pip uninstall -y monai
- - name: Install the default branch with build
+ - name: Install the default branch with build (dev branch only)
+ if: github.ref == 'refs/heads/dev'
run: |
BUILD_MONAI=1 pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI
python -c 'import monai; monai.config.print_config()'
- - uses: actions/checkout@v2
+ - name: Get the test cases (dev branch only)
+ if: github.ref == 'refs/heads/dev'
+ uses: actions/checkout@v2
with:
- ref: master
- - name: Quick test installed
+ ref: dev
+ - name: Quick test installed (dev branch only)
+ if: github.ref == 'refs/heads/dev'
run: |
cd $GITHUB_WORKSPACE
rm -rf monai/
@@ -138,44 +159,3 @@ jobs:
python -m tests.min_tests
env:
QUICKTEST: True
-
- local_docker:
- if: github.repository == 'Project-MONAI/MONAI'
- runs-on: [self-hosted, linux, x64, build_only]
- # we only push built container if it is built from master branch
- steps:
- - uses: actions/checkout@v2
- with:
- ref: master
- - name: docker_build
- run: |
- # build and run original docker image for local registry
- docker build -t localhost:5000/local_monai:latest -f Dockerfile .
- docker push localhost:5000/local_monai:latest
- # build once more w/ tag "latest": remove flake package as it is not needed on hub.docker.com
- sed -i '/flake/d' requirements-dev.txt
- docker build -t projectmonai/monai:latest -f Dockerfile .
- # also push as tag "dockerhub" to local registry
- docker image tag projectmonai/monai:latest localhost:5000/local_monai:dockerhub
- docker push localhost:5000/local_monai:dockerhub
- # distribute as always w/ tag "latest" to hub.docker.com
- echo "${{ secrets.DOCKER_PW }}" | docker login -u projectmonai --password-stdin
- docker push projectmonai/monai:latest
- docker logout
-
- docker:
- if: github.repository == 'Project-MONAI/MONAI'
- needs: local_docker
- container:
- image: localhost:5000/local_monai:latest
- runs-on: [self-hosted, linux, x64, common]
- steps:
- - name: Import
- run: |
- python -c 'import monai; monai.config.print_config()'
- cd /opt/monai
- ls -al
- ngc --version
- python -m tests.min_tests
- env:
- QUICKTEST: True
diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml
index bb68a0801d..981ca5cdaf 100644
--- a/.github/workflows/weekly-preview.yml
+++ b/.github/workflows/weekly-preview.yml
@@ -11,6 +11,7 @@ jobs:
steps:
- uses: actions/checkout@v2
with:
+ ref: dev
fetch-depth: 0
- name: Set up Python 3.8
uses: actions/setup-python@v2
@@ -32,7 +33,7 @@ jobs:
export YEAR_WEEK=$(date +'%y%U')
echo "Year week for tag is ${YEAR_WEEK}"
if ! [[ $YEAR_WEEK =~ ^[0-9]{4}$ ]] ; then echo "Wrong 'year week' format. Should be 4 digits."; exit 1 ; fi
- git tag "0.5.dev${YEAR_WEEK}"
+ git tag "0.7.dev${YEAR_WEEK}"
git log -1
git tag --list
python setup.py sdist bdist_wheel
diff --git a/.gitignore b/.gitignore
index 0d1455d70d..7444d7f2f9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,6 +40,7 @@ htmlcov/
.tox/
.coverage
.coverage.*
+.coverage/
.cache
nosetests.xml
coverage.xml
@@ -47,6 +48,9 @@ coverage.xml
.hypothesis/
.pytest_cache/
+# temporary unittest artifacts
+tests/testing_data/temp_*
+
# Translations
*.mo
*.pot
@@ -124,6 +128,7 @@ temp/
# temporary testing data MedNIST
tests/testing_data/MedNIST*
tests/testing_data/*Hippocampus*
+tests/testing_data/*.tiff
# clang format tool
.clang-format-bin/
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 56e65a7d92..55a0ca11e9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,167 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]
+## [0.6.0] - 2021-07-08
+### Added
+* 10 new transforms, a masked loss wrapper, and a `NetAdapter` for transfer learning
+* APIs to load networks and pre-trained weights from Clara Train [Medical Model ARchives (MMARs)](https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html)
+* Base metric and cumulative metric APIs, 4 new regression metrics
+* Initial CSV dataset support
+* Decollating mini-batch as the default first postprocessing step, [Migrating your v0.5 code to v0.6](https://github.com/Project-MONAI/MONAI/wiki/v0.5-to-v0.6-migration-guide) wiki shows how to adapt to the breaking changes
+* Initial backward compatibility support via `monai.utils.deprecated`
+* Attention-based vision modules and `UNETR` for segmentation
+* Generic module loaders and Gaussian mixture models using the PyTorch JIT compilation
+* Inverse of image patch sampling transforms
+* Network block utilities `get_[norm, act, dropout, pool]_layer`
+* `unpack_items` mode for `apply_transform` and `Compose`
+* New event `INNER_ITERATION_STARTED` in the deepgrow interactive workflow
+* `set_data` API for cache-based datasets to dynamically update the dataset content
+* Fully compatible with PyTorch 1.9
+* `--disttests` and `--min` options for `runtests.sh`
+* Initial support of pre-merge tests with Nvidia Blossom system
+### Changed
+* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.06-py3` from
+ `nvcr.io/nvidia/pytorch:21.04-py3`
+* Optionally depend on PyTorch-Ignite v0.4.5 instead of v0.4.4
+* Unified the demo, tutorial, testing data to the project shared drive, and
+ [`Project-MONAI/MONAI-extra-test-data`](https://github.com/Project-MONAI/MONAI-extra-test-data)
+* Unified the terms: `post_transform` is renamed to `postprocessing`, `pre_transform` is renamed to `preprocessing`
+* Unified the postprocessing transforms and event handlers to accept the "channel-first" data format
+* `evenly_divisible_all_gather` and `string_list_all_gather` moved to `monai.utils.dist`
+### Removed
+* Support of 'batched' input for postprocessing transforms and event handlers
+* `TorchVisionFullyConvModel`
+* `set_visible_devices` utility function
+* `SegmentationSaver` and `TransformsInverter` handlers
+### Fixed
+* Issue of handling big-endian image headers
+* Multi-thread issue for non-random transforms in the cache-based datasets
+* Persistent dataset issue when multiple processes sharing a non-exist cache location
+* Typing issue with Numpy 1.21.0
+* Loading checkpoint with both `model` and `optimizer` using `CheckpointLoader` when `strict_shape=False`
+* `SplitChannel` has different behaviour depending on numpy/torch inputs
+* Transform pickling issue caused by the Lambda functions
+* Issue of filtering by name in `generate_param_groups`
+* Inconsistencies in the return value types of `class_activation_maps`
+* Various docstring typos
+* Various usability enhancements in `monai.transforms`
+
+## [0.5.3] - 2021-05-28
+### Changed
+* Project default branch renamed to `dev` from `master`
+* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.04-py3` from `nvcr.io/nvidia/pytorch:21.02-py3`
+* Enhanced type checks for the `iteration_metric` handler
+* Enhanced `PersistentDataset` to use `tempfile` during caching computation
+* Enhanced various info/error messages
+* Enhanced performance of `RandAffine`
+* Enhanced performance of `SmartCacheDataset`
+* Optionally requires `cucim` when the platform is `Linux`
+* Default `device` of `TestTimeAugmentation` changed to `cpu`
+
+### Fixed
+* Download utilities now provide better default parameters
+* Duplicated `key_transforms` in the patch-based transforms
+* A multi-GPU issue in `ClassificationSaver`
+* A default `meta_data` issue in `SpacingD`
+* Dataset caching issue with the persistent data loader workers
+* A memory issue in `permutohedral_cuda`
+* Dictionary key issue in `CopyItemsd`
+* `box_start` and `box_end` parameters for deepgrow `SpatialCropForegroundd`
+* Tissue mask array transpose issue in `MaskedInferenceWSIDataset`
+* Various type hint errors
+* Various docstring typos
+
+### Added
+* Support of `to_tensor` and `device` arguments for `TransformInverter`
+* Slicing options with SpatialCrop
+* Class name alias for the networks for backward compatibility
+* `k_divisible` option for CropForeground
+* `map_items` option for `Compose`
+* Warnings of `inf` and `nan` for surface distance computation
+* A `print_log` flag to the image savers
+* Basic testing pipelines for Python 3.9
+
+## [0.5.0] - 2021-04-09
+### Added
+* Overview document for [feature highlights in v0.5.0](https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md)
+* Invertible spatial transforms
+ * `InvertibleTransform` base APIs
+ * Batch inverse and decollating APIs
+ * Inverse of `Compose`
+ * Batch inverse event handling
+ * Test-time augmentation as an application
+* Initial support of learning-based image registration:
+ * Bending energy, LNCC, and global mutual information loss
+ * Fully convolutional architectures
+ * Dense displacement field, dense velocity field computation
+ * Warping with high-order interpolation with C++/CUDA implementations
+* Deepgrow modules for interactive segmentation:
+ * Workflows with simulations of clicks
+ * Distance-based transforms for guidance signals
+* Digital pathology support:
+ * Efficient whole slide imaging IO and sampling with Nvidia cuCIM and SmartCache
+ * FROC measurements for lesion
+ * Probabilistic post-processing for lesion detection
+ * TorchVision classification model adaptor for fully convolutional analysis
+* 12 new transforms, grid patch dataset, `ThreadDataLoader`, EfficientNets B0-B7
+* 4 iteration events for the engine for finer control of workflows
+* New C++/CUDA extensions:
+ * Conditional random field
+ * Fast bilateral filtering using the permutohedral lattice
+* Metrics summary reporting and saving APIs
+* DiceCELoss, DiceFocalLoss, a multi-scale wrapper for segmentation loss computation
+* Data loading utilities:
+ * `decollate_batch`
+ * `PadListDataCollate` with inverse support
+* Support of slicing syntax for `Dataset`
+* Initial Torchscript support for the loss modules
+* Learning rate finder
+* Allow for missing keys in the dictionary-based transforms
+* Support of checkpoint loading for transfer learning
+* Various summary and plotting utilities for Jupyter notebooks
+* Contributor Covenant Code of Conduct
+* Major CI/CD enhancements covering the tutorial repository
+* Fully compatible with PyTorch 1.8
+* Initial nightly CI/CD pipelines using Nvidia Blossom Infrastructure
+
+### Changed
+* Enhanced `list_data_collate` error handling
+* Unified iteration metric APIs
+* `densenet*` extensions are renamed to `DenseNet*`
+* `se_res*` network extensions are renamed to `SERes*`
+* Transform base APIs are rearranged into `compose`, `inverse`, and `transform`
+* `_do_transform` flag for the random augmentations is unified via `RandomizableTransform`
+* Decoupled post-processing steps, e.g. `softmax`, `to_onehot_y`, from the metrics computations
+* Moved the distributed samplers to `monai.data.samplers` from `monai.data.utils`
+* Engine's data loaders now accept generic iterables as input
+* Workflows now accept additional custom events and state properties
+* Various type hints according to Numpy 1.20
+* Refactored testing utility `runtests.sh` to have `--unittest` and `--net` (integration tests) options
+* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.02-py3` from `nvcr.io/nvidia/pytorch:20.10-py3`
+* Docker images are now built with self-hosted environments
+* Primary contact email updated to `monai.contact@gmail.com`
+* Now using GitHub Discussions as the primary communication forum
+
+### Removed
+* Compatibility tests for PyTorch 1.5.x
+* Format specific loaders, e.g. `LoadNifti`, `NiftiDataset`
+* Assert statements from non-test files
+* `from module import *` statements, addressed flake8 F403
+
+### Fixed
+* Uses American English spelling for code, as per PyTorch
+* Code coverage now takes multiprocessing runs into account
+* SmartCache with initial shuffling
+* `ConvertToMultiChannelBasedOnBratsClasses` now supports channel-first inputs
+* Checkpoint handler to save with non-root permissions
+* Fixed an issue for exiting the distributed unit tests
+* Unified `DynUNet` to have single tensor output w/o deep supervision
+* `SegmentationSaver` now supports user-specified data types and a `squeeze_end_dims` flag
+* Fixed `*Saver` event handlers output filenames with a `data_root_dir` option
+* Load image functions now ensure little-endian
+* Fixed the test runner to support regex-based test case matching
+* Usability issues in the event handlers
+
## [0.4.0] - 2020-12-15
### Added
* Overview document for [feature highlights in v0.4.0](https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md)
@@ -173,7 +334,10 @@ the postprocessing steps should be used before calling the metrics methods
[highlights]: https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md
-[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/0.4.0...HEAD
+[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/0.6.0...HEAD
+[0.6.0]: https://github.com/Project-MONAI/MONAI/compare/0.5.3...0.6.0
+[0.5.3]: https://github.com/Project-MONAI/MONAI/compare/0.5.0...0.5.3
+[0.5.0]: https://github.com/Project-MONAI/MONAI/compare/0.4.0...0.5.0
[0.4.0]: https://github.com/Project-MONAI/MONAI/compare/0.3.0...0.4.0
[0.3.0]: https://github.com/Project-MONAI/MONAI/compare/0.2.0...0.3.0
[0.2.0]: https://github.com/Project-MONAI/MONAI/compare/0.1.0...0.2.0
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 01a4773b5a..ca40261dea 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -16,21 +16,21 @@
## Introduction
-This documentation is intended for individuals and institutions interested in contributing to MONAI. MONAI is an open-source project and, as such, its success relies on its community of contributors willing to keep improving it. Your contribution will be a valued addition to the code base; we simply ask that you read this page and understand our contribution process, whether you are a seasoned open-source contributor or whether you are a first-time contributor.
+Welcome to Project MONAI! We're excited you're here and want to contribute. This documentation is intended for individuals and institutions interested in contributing to MONAI. MONAI is an open-source project and, as such, its success relies on its community of contributors willing to keep improving it. Your contribution will be a valued addition to the code base; we simply ask that you read this page and understand our contribution process, whether you are a seasoned open-source contributor or whether you are a first-time contributor.
### Communicate with us
-We are happy to talk with you about your needs for MONAI and your ideas for contributing to the project. One way to do this is to create an issue discussing your thoughts. It might be that a very similar feature is under development or already exists, so an issue is a great starting point.
+We are happy to talk with you about your needs for MONAI and your ideas for contributing to the project. One way to do this is to create an issue discussing your thoughts. It might be that a very similar feature is under development or already exists, so an issue is a great starting point. If you are looking for an issue to resolve that will help Project MONAI, see the [*good first issue*](https://github.com/Project-MONAI/MONAI/labels/good%20first%20issue) and [*Contribution wanted*](https://github.com/Project-MONAI/MONAI/labels/Contribution%20wanted) labels.
### Does it belong in PyTorch instead of MONAI?
-MONAI is based on the PyTorch and Numpy libraries. These libraries implement what we consider to be best practice for general scientific computing and deep learning functionality. MONAI builds on these with a strong focus on medical applications. As such, it is a good idea to consider whether your functionality is medical-application specific or not. General deep learning functionality may be better off in PyTorch; you can find their contribution guidelines [here](https://pytorch.org/docs/stable/community/contribution_guide.html).
+MONAI is part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/), and mainly based on the PyTorch and Numpy libraries. These libraries implement what we consider to be best practice for general scientific computing and deep learning functionality. MONAI builds on these with a strong focus on medical applications. As such, it is a good idea to consider whether your functionality is medical-application specific or not. General deep learning functionality may be better off in PyTorch; you can find their contribution guidelines [here](https://pytorch.org/docs/stable/community/contribution_guide.html).
## The contribution process
_Pull request early_
-We encourage you to create pull requests early. It helps us track the contributions under development, whether they are ready to be merged or not. Change your pull request's title to begin with `[WIP]` until it is ready for formal review.
+We encourage you to create pull requests early. It helps us track the contributions under development, whether they are ready to be merged or not. Change your pull request's title to begin with `[WIP]` and/or [create a draft pull request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests#draft-pull-requests) until it is ready for formal review.
Please note that, as per PyTorch, MONAI uses American English spelling. This means classes and variables should be: normali**z**e, visuali**z**e, colo~~u~~r, etc.
@@ -51,11 +51,15 @@ Coding style is checked and enforced by flake8, black, and isort, using [a flake
Before submitting a pull request, we recommend that all linting should pass, by running the following command locally:
```bash
-pip install -U -r requirements-dev.txt # install the latest tools
-./runtests.sh --codeformat --nounittests # runs the linting tools only
+# optionally update the dependencies and dev tools
+python -m pip install -U pip
+python -m pip install -U -r requirements-dev.txt
+
+# run the linting and type checking tools
+./runtests.sh --codeformat
# try to fix the coding style errors automatically
-./runtests.sh --autofix --nounittests
+./runtests.sh --autofix
```
License information: all source code files should start with this paragraph:
@@ -83,10 +87,11 @@ If you intend for any variables/functions/classes to be available outside of the
#### Unit testing
MONAI tests are located under `tests/`.
-- The unit test's file name follows `test_[module_name].py`.
+- The unit test's file name currently follows `test_[module_name].py` or `test_[module_name]_dist.py`.
+- The `test_[module_name]_dist.py` subset of unit tests requires a distributed environment to verify the module with distributed GPU-based computation.
- The integration test's file name follows `test_integration_[workflow_name].py`.
-A bash script (`runtests.sh`) is provided to run all tests locally
+A bash script (`runtests.sh`) is provided to run all tests locally.
Please run ``./runtests.sh -h`` to see all options.
To run a particular test, for example `tests/test_dice_loss.py`:
@@ -98,16 +103,16 @@ Before submitting a pull request, we recommend that all linting and unit tests
should pass, by running the following command locally:
```bash
-./runtests.sh --codeformat --coverage
+./runtests.sh -f -u --net --coverage
```
or (for new features that would not break existing functionality):
```bash
-./runtests.sh --quick
+./runtests.sh --quick --unittests
```
It is recommended that the new test `test_[module_name].py` is constructed by using only
-python 3.6+ build-in functions, `torch`, `numpy`, and `parameterized` packages.
+python 3.6+ built-in functions, `torch`, `numpy`, `coverage` (for reporting code coverage) and `parameterized` (for organizing test cases) packages.
If it requires any other external packages, please make sure:
- the packages are listed in [`requirements-dev.txt`](requirements-dev.txt)
- the new test `test_[module_name].py` is added to the `exclude_cases` in [`./tests/min_tests.py`](./tests/min_tests.py) so that
@@ -141,7 +146,7 @@ Before submitting a pull request, it is recommended to:
- check the auto-generated documentation (by browsing `./docs/build/html/index.html` with a web browser)
- type `make clean` in `docs/` folder to remove the current build files.
-Please type `make help` for all supported format options.
+Please type `make help` in `docs/` folder for all supported format options.
#### Automatic code formatting
MONAI provides support of automatic Python code formatting via [a customised GitHub action](https://github.com/Project-MONAI/monai-code-formatter).
@@ -226,19 +231,19 @@ For string definition, [f-string](https://www.python.org/dev/peps/pep-0498/) is
### Submitting pull requests
-All code changes to the master branch must be done via [pull requests](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests).
+All code changes to the dev branch must be done via [pull requests](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests).
1. Create a new ticket or take a known ticket from [the issue list][monai issue list].
1. Check if there's already a branch dedicated to the task.
1. If the task has not been taken, [create a new branch in your fork](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request-from-a-fork)
of the codebase named `[ticket_id]-[task_name]`.
For example, branch name `19-ci-pipeline-setup` corresponds to [issue #19](https://github.com/Project-MONAI/MONAI/issues/19).
-Ideally, the new branch should be based on the latest `master` branch.
+Ideally, the new branch should be based on the latest `dev` branch.
1. Make changes to the branch ([use detailed commit messages if possible](https://chris.beams.io/posts/git-commit/)).
1. Make sure that new tests cover the changes and the changed codebase [passes all tests locally](#unit-testing).
-1. [Create a new pull request](https://help.github.com/en/desktop/contributing-to-projects/creating-a-pull-request) from the task branch to the master branch, with detailed descriptions of the purpose of this pull request.
+1. [Create a new pull request](https://help.github.com/en/desktop/contributing-to-projects/creating-a-pull-request) from the task branch to the dev branch, with detailed descriptions of the purpose of this pull request.
1. Check [the CI/CD status of the pull request][github ci], make sure all CI/CD tests passed.
1. Wait for reviews; if there are reviews, make point-to-point responses, make further code changes if needed.
-1. If there are conflicts between the pull request branch and the master branch, pull the changes from the master and resolve the conflicts locally.
+1. If there are conflicts between the pull request branch and the dev branch, pull the changes from the dev and resolve the conflicts locally.
1. Reviewer and contributor may have discussions back and forth until all comments addressed.
1. Wait for the pull request to be merged.
@@ -251,7 +256,8 @@ All code review comments should be specific, constructive, and actionable.
1. Read carefully the descriptions of the pull request and the files changed, write comments if needed.
1. Make in-line comments to specific code segments, [request for changes](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-request-reviews) if needed.
1. Review any further code changes until all comments addressed by the contributors.
-1. Merge the pull request to the master branch.
+1. Comment to trigger `/black` and/or `/integration-test` for optional auto code formatting and [integration tests](.github/workflows/integration.yml).
+1. Merge the pull request to the dev branch.
1. Close the corresponding task ticket on [the issue list][monai issue list].
[github ci]: https://github.com/Project-MONAI/MONAI/actions
@@ -261,31 +267,34 @@ All code review comments should be specific, constructive, and actionable.
## Admin tasks
### Release a new version
-- Prepare [a release note](https://github.com/Project-MONAI/MONAI/releases).
-- Checkout a new branch `releases/[version number]` locally from the master branch and push to the codebase.
-- Create a tag, for example `git tag -a 0.1a -m "version 0.1a"`.
-- Push the tag to the codebase, for example `git push origin 0.1a`.
+The `dev` branch's `HEAD` always corresponds to MONAI docker image's latest tag: `projectmonai/monai:latest`.
+The `main` branch's `HEAD` always corresponds to the latest MONAI milestone release.
+
+When major features are ready for a milestone, to prepare for a new release:
+- Prepare [a release note](https://github.com/Project-MONAI/MONAI/releases) and release checklist.
+- Check out or cherry-pick a new branch `releasing/[version number]` locally from the `dev` branch and push to the codebase.
+- Create a release candidate tag, for example, `git tag -a 0.1.0rc1 -m "release candidate 1 of version 0.1.0"`.
+- Push the tag to the codebase, for example, `git push origin 0.1.0rc1`.
This step will trigger package building and testing.
The resultant packages are automatically uploaded to
[TestPyPI](https://test.pypi.org/project/monai/). The packages are also available for downloading as
repository's artifacts (e.g. the file at https://github.com/Project-MONAI/MONAI/actions/runs/66570977).
- Check the release test at [TestPyPI](https://test.pypi.org/project/monai/), download the artifacts when the CI finishes.
+- Optionally run [the cron testing jobs](https://github.com/Project-MONAI/MONAI/blob/dev/.github/workflows/cron.yml) on `releasing/[version number]`.
+- Once the release candidate is verified, tag and push a milestone, for example, `git push origin 0.1.0`.
+  The tag must point to the latest commit of `releasing/[version number]`.
+- Rebase `releasing/[version number]` to `main`, make sure all the test pipelines succeed.
- Upload the packages to [PyPI](https://pypi.org/project/monai/).
This could be done manually by ``twine upload dist/*``, given the artifacts are unzipped to the folder ``dist/``.
+- Merge `releasing/[version number]` to `dev`, this step must make sure that the tagging commit remains unchanged on `dev`.
- Publish the release note.
Note that the release should be tagged with a [PEP440](https://www.python.org/dev/peps/pep-0440/) compliant
[semantic versioning](https://semver.org/spec/v2.0.0.html) number.
-If any error occurs during the release process, first checkout a new branch from the master, make PRs to the master
-to fix the bugs via the regular contribution procedure.
-Then rollback the release branch and tag:
- - remove any artifacts (website UI) and tag (`git tag -d` and `git push origin -d`).
- - reset the `releases/[version number]` branch to the latest master:
- ```bash
-git checkout master
-git pull origin master
-git checkout releases/[version number]
-git reset --hard master
-```
-Finally, repeat the tagging and TestPyPI uploading process.
+If any error occurs during the release process, first check out a new hotfix branch from the `releasing/[version number]`,
+then make PRs to the `releasing/[version number]` to fix the bugs via the regular contribution procedure.
+
+If any error occurs after the release process, first check out a new hotfix branch from the `main` branch,
+make a minor version release following the semantic versioning, for example, `releasing/0.1.1`.
+Make sure the `releasing/0.1.1` is merged back into both `dev` and `main` and all the test pipelines succeed.
diff --git a/Dockerfile b/Dockerfile
index 47976b97b1..ac06183768 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -9,8 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:20.12-py3
-
+# To build with a different base image
+# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.
+ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:21.06-py3
FROM ${PYTORCH_IMAGE}
LABEL maintainer="monai.contact@gmail.com"
@@ -29,10 +30,9 @@ RUN cp /tmp/requirements.txt /tmp/req.bak \
# please specify exact files and folders to be copied -- else, basically always, the Docker build process cannot cache
# this or anything below it and always will build from at most here; one file change leads to no caching from here on...
-COPY LICENSE setup.py setup.cfg versioneer.py runtests.sh .gitignore .gitattributes README.md MANIFEST.in ./
+COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./
COPY tests ./tests
COPY monai ./monai
-COPY .git ./.git
RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \
&& rm -rf build __pycache__
@@ -43,6 +43,9 @@ RUN wget -q ${NGC_CLI_URI} && \
unzip ngccli_cat_linux.zip && chmod u+x ngc && \
md5sum -c ngc.md5 && \
rm -rf ngccli_cat_linux.zip ngc.md5
+RUN apt-get update \
+ && DEBIAN_FRONTEND="noninteractive" apt-get install -y libopenslide0 \
+ && rm -rf /var/lib/apt/lists/*
# append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations
ENV PATH=${PATH}:/opt/tools
WORKDIR /opt/monai
diff --git a/README.md b/README.md
index f06a2d146f..e9facef64d 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,16 @@
-
+
**M**edical **O**pen **N**etwork for **AI**
[![License](https://img.shields.io/badge/license-Apache%202.0-green.svg)](https://opensource.org/licenses/Apache-2.0)
-[![CI Build](https://github.com/Project-MONAI/MONAI/workflows/build/badge.svg?branch=master)](https://github.com/Project-MONAI/MONAI/commits/master)
+[![CI Build](https://github.com/Project-MONAI/MONAI/workflows/build/badge.svg?branch=dev)](https://github.com/Project-MONAI/MONAI/commits/dev)
[![Documentation Status](https://readthedocs.org/projects/monai/badge/?version=latest)](https://docs.monai.io/en/latest/?badge=latest)
-[![codecov](https://codecov.io/gh/Project-MONAI/MONAI/branch/master/graph/badge.svg)](https://codecov.io/gh/Project-MONAI/MONAI)
+[![codecov](https://codecov.io/gh/Project-MONAI/MONAI/branch/dev/graph/badge.svg)](https://codecov.io/gh/Project-MONAI/MONAI)
[![PyPI version](https://badge.fury.io/py/monai.svg)](https://badge.fury.io/py/monai)
-MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/master/LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
+MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
Its ambitions are:
- developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
- creating state-of-the-art, end-to-end training workflows for healthcare imaging;
@@ -19,7 +19,7 @@ Its ambitions are:
## Features
> _The codebase is currently under active development._
-> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) of the current milestone release._
+> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New in 0.6](https://docs.monai.io/en/latest/whatsnew_0_6.html) of the current milestone release._
- flexible pre-processing for multi-dimensional medical imaging data;
- compositional & portable APIs for ease of integration in existing workflows;
@@ -36,7 +36,7 @@ To install [the current release](https://pypi.org/project/monai/), you can simpl
pip install monai
```
-For other installation methods (using the master branch, using Docker, etc.), please refer to [the installation guide](https://docs.monai.io/en/latest/installation.html).
+For other installation methods (using the default GitHub branch, using Docker, etc.), please refer to [the installation guide](https://docs.monai.io/en/latest/installation.html).
## Getting Started
@@ -47,7 +47,7 @@ Examples and notebook tutorials are located at [Project-MONAI/tutorials](https:/
Technical documentation is available at [docs.monai.io](https://docs.monai.io).
## Contributing
-For guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/MONAI/blob/master/CONTRIBUTING.md).
+For guidance on making a contribution to MONAI, see the [contributing guidelines](CONTRIBUTING.md).
## Community
Join the conversation on Twitter [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9).
@@ -62,3 +62,6 @@ Ask and answer questions over on [MONAI's GitHub Discussions tab](https://github
- Issue tracker: https://github.com/Project-MONAI/MONAI/issues
- Wiki: https://github.com/Project-MONAI/MONAI/wiki
- Test status: https://github.com/Project-MONAI/MONAI/actions
+- PyPI package: https://pypi.org/project/monai/
+- Weekly previews: https://pypi.org/project/monai-weekly/
+- Docker Hub: https://hub.docker.com/r/projectmonai/monai
diff --git a/docs/images/3d_paired.png b/docs/images/3d_paired.png
new file mode 100644
index 0000000000..dd751c8e16
Binary files /dev/null and b/docs/images/3d_paired.png differ
diff --git a/docs/images/BTCV_organs.png b/docs/images/BTCV_organs.png
new file mode 100644
index 0000000000..e109437c98
Binary files /dev/null and b/docs/images/BTCV_organs.png differ
diff --git a/docs/images/UNETR.png b/docs/images/UNETR.png
new file mode 100644
index 0000000000..d028f26fd6
Binary files /dev/null and b/docs/images/UNETR.png differ
diff --git a/docs/images/decollate_batch.png b/docs/images/decollate_batch.png
new file mode 100644
index 0000000000..2a1c0c832c
Binary files /dev/null and b/docs/images/decollate_batch.png differ
diff --git a/docs/images/deepgrow.png b/docs/images/deepgrow.png
new file mode 100644
index 0000000000..d006bd0d09
Binary files /dev/null and b/docs/images/deepgrow.png differ
diff --git a/docs/images/deepgrow_scheme.png b/docs/images/deepgrow_scheme.png
new file mode 100644
index 0000000000..9b4e400839
Binary files /dev/null and b/docs/images/deepgrow_scheme.png differ
diff --git a/docs/images/gmm_feature_set_comparison_s.png b/docs/images/gmm_feature_set_comparison_s.png
new file mode 100644
index 0000000000..a0161b8194
Binary files /dev/null and b/docs/images/gmm_feature_set_comparison_s.png differ
diff --git a/docs/images/invert_transforms.png b/docs/images/invert_transforms.png
new file mode 100644
index 0000000000..fa3863f373
Binary files /dev/null and b/docs/images/invert_transforms.png differ
diff --git a/docs/images/lr_finder.png b/docs/images/lr_finder.png
new file mode 100644
index 0000000000..ed9ba69770
Binary files /dev/null and b/docs/images/lr_finder.png differ
diff --git a/docs/images/metrics_report.png b/docs/images/metrics_report.png
new file mode 100644
index 0000000000..a317fcdc21
Binary files /dev/null and b/docs/images/metrics_report.png differ
diff --git a/docs/images/pathology.png b/docs/images/pathology.png
new file mode 100644
index 0000000000..da12ad23e7
Binary files /dev/null and b/docs/images/pathology.png differ
diff --git a/docs/images/post_transforms.png b/docs/images/postprocessing_transforms.png
similarity index 100%
rename from docs/images/post_transforms.png
rename to docs/images/postprocessing_transforms.png
diff --git a/docs/images/transfer_mmar.png b/docs/images/transfer_mmar.png
new file mode 100644
index 0000000000..7ae5b876ea
Binary files /dev/null and b/docs/images/transfer_mmar.png differ
diff --git a/docs/images/tta.png b/docs/images/tta.png
new file mode 100644
index 0000000000..6c4e18ffa0
Binary files /dev/null and b/docs/images/tta.png differ
diff --git a/docs/requirements.txt b/docs/requirements.txt
index d046bc53cf..9d9cdebb3e 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,16 +1,16 @@
-f https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
torch>=1.5
-pytorch-ignite==0.4.2
+pytorch-ignite==0.4.5
numpy>=1.17
-itk>=5.0
+itk>=5.2
nibabel
parameterized
scikit-image>=0.14.2
tensorboard
commonmark==0.9.1
recommonmark==0.6.0
-Sphinx==3.3.0
-sphinx-rtd-theme==0.5.0
+Sphinx==3.5.3
+sphinx-rtd-theme==0.5.2
sphinxcontrib-applehelp
sphinxcontrib-devhelp
sphinxcontrib-htmlhelp
@@ -18,3 +18,5 @@ sphinxcontrib-jsmath
sphinxcontrib-qthelp
sphinxcontrib-serializinghtml
sphinx-autodoc-typehints==1.11.1
+pandas
+einops
diff --git a/docs/source/apps.rst b/docs/source/apps.rst
index b8c8b4d341..f9f7a4159c 100644
--- a/docs/source/apps.rst
+++ b/docs/source/apps.rst
@@ -18,6 +18,17 @@ Applications
.. autoclass:: CrossValidation
:members:
+
+Clara MMARs
+-----------
+.. autofunction:: download_mmar
+
+.. autofunction:: load_from_mmar
+
+.. autodata:: monai.apps.MODEL_DESC
+ :annotation:
+
+
`Utilities`
-----------
@@ -46,9 +57,44 @@ Applications
:members:
.. autoclass:: AddRandomGuidanced
:members:
+.. autoclass:: AddGuidanceFromPointsd
+ :members:
.. autoclass:: SpatialCropForegroundd
:members:
+.. autoclass:: SpatialCropGuidanced
+ :members:
+.. autoclass:: RestoreLabeld
+ :members:
+.. autoclass:: ResizeGuidanced
+ :members:
.. autoclass:: FindDiscrepancyRegionsd
:members:
.. autoclass:: FindAllValidSlicesd
:members:
+.. autoclass:: Fetch2DSliced
+ :members:
+
+`Pathology`
+-----------
+
+.. automodule:: monai.apps.pathology.datasets
+.. autoclass:: PatchWSIDataset
+ :members:
+.. autoclass:: SmartCachePatchWSIDataset
+ :members:
+.. autoclass:: MaskedInferenceWSIDataset
+ :members:
+
+.. automodule:: monai.apps.pathology.handlers
+.. autoclass:: ProbMapProducer
+ :members:
+
+.. automodule:: monai.apps.pathology.metrics
+.. autoclass:: LesionFROC
+ :members:
+
+.. automodule:: monai.apps.pathology.utils
+.. autofunction:: compute_multi_instance_mask
+.. autofunction:: compute_isolated_tumor_cells
+.. autoclass:: PathologyProbNMS
+ :members:
diff --git a/docs/source/conf.py b/docs/source/conf.py
index a2f1b3af5c..780a2d9a6d 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -119,7 +119,7 @@ def setup(app):
"display_github": True,
"github_user": "Project-MONAI",
"github_repo": "MONAI",
- "github_version": "master",
+ "github_version": "dev",
"conf_py_path": "/docs/",
}
html_scaled_image_link = False
diff --git a/docs/source/data.rst b/docs/source/data.rst
index 11609964c3..a5c3509fc9 100644
--- a/docs/source/data.rst
+++ b/docs/source/data.rst
@@ -21,6 +21,12 @@ Generic Interfaces
:members:
:special-members: __next__
+`CSVIterableDataset`
+~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: CSVIterableDataset
+ :members:
+ :special-members: __next__
+
`PersistentDataset`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: PersistentDataset
@@ -68,6 +74,18 @@ Generic Interfaces
.. autoclass:: ImageDataset
:members:
:special-members: __getitem__
+
+`NPZDictItemDataset`
+~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: NPZDictItemDataset
+ :members:
+ :special-members: __getitem__
+
+`CSVDataset`
+~~~~~~~~~~~~
+.. autoclass:: CSVDataset
+ :members:
+ :special-members: __getitem__
Patch-based dataset
-------------------
@@ -77,6 +95,11 @@ Patch-based dataset
.. autoclass:: GridPatchDataset
:members:
+`PatchIter`
+~~~~~~~~~~~
+.. autoclass:: PatchIter
+ :members:
+
`PatchDataset`
~~~~~~~~~~~~~~
.. autoclass:: PatchDataset
@@ -105,6 +128,10 @@ PILReader
.. autoclass:: PILReader
:members:
+WSIReader
+~~~~~~~~~
+.. autoclass:: WSIReader
+ :members:
Nifti format handling
---------------------
@@ -151,6 +178,9 @@ DistributedSampler
~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.data.DistributedSampler
+DistributedWeightedRandomSampler
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: monai.data.DistributedWeightedRandomSampler
Decathlon Datalist
~~~~~~~~~~~~~~~~~~
@@ -165,3 +195,8 @@ DataLoader
ThreadBuffer
~~~~~~~~~~~~
.. autoclass:: monai.data.ThreadBuffer
+
+
+TestTimeAugmentation
+~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: monai.data.TestTimeAugmentation
diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst
index a629b28b27..096777cdef 100644
--- a/docs/source/handlers.rst
+++ b/docs/source/handlers.rst
@@ -29,9 +29,9 @@ CSV saver
:members:
-Iteration Metric
-----------------
-.. autoclass:: IterationMetric
+Ignite Metric
+-------------
+.. autoclass:: IgniteMetric
:members:
@@ -65,6 +65,30 @@ Surface distance metrics handler
:members:
+Mean squared error metrics handler
+----------------------------------
+.. autoclass:: MeanSquaredError
+ :members:
+
+
+Mean absolute error metrics handler
+-----------------------------------
+.. autoclass:: MeanAbsoluteError
+ :members:
+
+
+Root mean squared error metrics handler
+---------------------------------------
+.. autoclass:: RootMeanSquaredError
+ :members:
+
+
+Peak signal to noise ratio metrics handler
+------------------------------------------
+.. autoclass:: PeakSignalToNoiseRatio
+ :members:
+
+
Metric logger
-------------
.. autoclass:: MetricLogger
@@ -110,3 +134,38 @@ SmartCache handler
------------------
.. autoclass:: SmartCacheHandler
:members:
+
+Parameter Scheduler handler
+---------------------------
+.. autoclass:: ParamSchedulerHandler
+ :members:
+
+EarlyStop handler
+-----------------
+.. autoclass:: EarlyStopHandler
+ :members:
+
+GarbageCollector handler
+------------------------
+.. autoclass:: GarbageCollector
+ :members:
+
+Transform inverter
+------------------
+.. autoclass:: TransformInverter
+ :members:
+
+Post processing
+---------------
+.. autoclass:: PostProcessing
+ :members:
+
+Decollate batch
+---------------
+.. autoclass:: DecollateBatch
+ :members:
+
+Utilities
+---------
+.. automodule:: monai.handlers.utils
+ :members:
diff --git a/docs/source/highlights.md b/docs/source/highlights.md
index 29302bda77..61935fd3dc 100644
--- a/docs/source/highlights.md
+++ b/docs/source/highlights.md
@@ -1,18 +1,18 @@
-# Modules in v0.4.0
+# Modules overview
MONAI aims at supporting deep learning in medical image analysis at multiple granularities.
-This figure shows a typical example of the end-to-end workflow in medical deep learning area:
-![image](../images/end_to_end.png)
+This figure shows a typical example of the end-to-end workflow:
+![an end to end workflow](../images/end_to_end.png)
## MONAI architecture
The design principle of MONAI is to provide flexible and light APIs for users with varying expertise.
-1. All the core components are independent modules, which can be easily integrated into any existing PyTorch programs.
+1. All the core components are independent modules, which can be easily integrated into any existing PyTorch program.
2. Users can leverage the workflows in MONAI to quickly set up a robust training or evaluation program for research experiments.
3. Rich examples and demos are provided to demonstrate the key features.
4. Researchers contribute implementations based on the state-of-the-art for the latest research challenges, including COVID-19 image analysis, Model Parallel, etc.
The overall architecture and modules are shown in the following figure:
-![image](../images/arch_modules_v0.4.png)
+![architecture overview](../images/arch_modules_v0.4.png)
The rest of this page provides more details for each module.
* [Data I/O, processing and augmentation](#medical-image-data-i-o-processing-and-augmentation)
@@ -26,6 +26,7 @@ The rest of this page provides more details for each module.
* [Workflows](#workflows)
* [Research](#research)
* [GPU acceleration](#gpu-acceleration)
+* [Applications](#applications)
## Medical image data I/O, processing and augmentation
Medical images require highly specialized methods for I/O, preprocessing, and augmentation. Medical images are often in specialized formats with rich meta-information, and the data volumes are often high-dimensional. These require carefully designed manipulation procedures. The medical imaging focus of MONAI is enabled by powerful and flexible image transformations that facilitate user-friendly, reproducible, optimized medical data pre-processing pipelines.
@@ -36,6 +37,10 @@ Medical images require highly specialized methods for I/O, preprocessing, and au
There is a rich set of transforms in six categories: Crop & Pad, Intensity, IO, Post-processing, Spatial, and Utilities. For more details, please visit [all the transforms in MONAI](https://docs.monai.io/en/latest/transforms.html).
+Almost all the transforms expect the input data to have a channel-first shape format: `[Channel dim, spatial dim 1, spatial dim 2, ...]`.
+Flexible [base APIs](https://github.com/Project-MONAI/MONAI/tree/dev/monai/transforms) are also provided. The `monai.transforms` module is
+easily extensible.
+
### 2. Medical specific transforms
MONAI aims at providing a comprehensive medical image specific
transformations. These currently include, for example:
@@ -49,7 +54,7 @@ transformations. These currently include, for example:
- `Rand3DElastic`: Random elastic deformation and affine in 3D
[2D transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/transforms_demo_2d.ipynb) shows the detailed usage of several MONAI medical image specific transforms.
-![image](../images/medical_transforms.png)
+![2d transform examples](../images/medical_transforms.png)
### 3. Fused spatial transforms and GPU acceleration
As medical image volumes are usually large (in multi-dimensional arrays), pre-processing performance affects the overall pipeline speed. MONAI provides affine transforms to execute fused spatial operations, supports GPU acceleration via native PyTorch for high performance.
@@ -69,8 +74,8 @@ new_img = affine(image, spatial_size=(300, 400), mode='bilinear')
```
Experiments and test results are available at [Fused transforms test](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/transform_speed.ipynb).
-Currently all the geometric image transforms (Spacing, Zoom, Rotate, Resize, etc.) are designed based on the PyTorch native interfaces. [Geometric transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/3d_image_transforms.ipynb) indicates the usage of affine transforms with 3D medical images.
-![image](../images/affine.png)
+Currently, all the geometric image transforms (Spacing, Zoom, Rotate, Resize, etc.) are designed based on the PyTorch native interfaces. [Geometric transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/3d_image_transforms.ipynb) indicates the usage of affine transforms with 3D medical images.
+![3d transform examples](../images/affine.png)
### 4. Randomly crop out batch images based on positive/negative ratio
Medical image data volume may be too large to fit into GPU memory. A widely-used approach is to randomly draw small size data samples during training and run a “sliding window” routine for inference. MONAI currently provides general random sampling strategies including class-balanced fixed ratio sampling which may help stabilize the patch-based training process. A typical example is in [Spleen 3D segmentation tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb), which achieves the class-balanced sampling with `RandCropByPosNegLabel` transform.
@@ -98,21 +103,21 @@ monai.utils.set_determinism(seed=0, additional_settings=None)
To apply different transforms on the same data and concatenate the results, MONAI provides `CopyItems` transform to make copies of specified items in the data dictionary and `ConcatItems` transform to combine specified items on the expected dimension, and also provides `DeleteItems` transform to delete unnecessary items to save memory.
Typical usage is to scale the intensity of the same image into different ranges and concatenate the results together.
-![image](../images/multi_transform_chains.png)
+![multiple transform chains](../images/multi_transform_chains.png)
### 7. Debug transforms with DataStats
-When transforms are combined with the "compose" function, it's not easy to track the output of specific transform. To help debug errors in the composed transforms, MONAI provides utility transforms such as `DataStats` to print out intermediate data properties such as `data shape`, `value range`, `data value`, `Additional information`, etc. It's a self-contained transform and can be integrated into any transform chain.
+When transforms are combined with the "compose" function, it's not easy to track the output of a specific transform. To help debug errors in the composed transforms, MONAI provides utility transforms such as `DataStats` to print out intermediate data properties such as `data shape`, `value range`, `data value`, `Additional information`, etc. It's a self-contained transform and can be integrated into any transform chain.
### 8. Post-processing transforms for model output
MONAI also provides post-processing transforms for handling the model outputs. Currently, the transforms include:
-- Adding activation layer (Sigmoid, Softmax, etc.).
+- Adding an activation layer (Sigmoid, Softmax, etc.).
- Converting to discrete values (Argmax, One-Hot, Threshold value, etc), as below figure (b).
- Splitting multi-channel data into multiple single channels.
- Removing segmentation noise based on Connected Component Analysis, as below figure (c).
- Extracting contour of segmentation result, which can be used to map to original image and evaluate the model, as below figure (d) and (e).
-After applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualize data in the TensorBoard. [Post transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/post_transforms.ipynb) shows an example with several main post transforms.
-![image](../images/post_transforms.png)
+After decollating the batch data of model output and applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualize data in the TensorBoard. [Postprocessing transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/postprocessing_transforms.ipynb) shows an example with several main transforms for post-processing.
+![post-processing transforms](../images/postprocessing_transforms.png)
### 9. Integrate third-party transforms
The design of MONAI transforms emphasis code readability and usability. It works for array data or dictionary-based data. MONAI also provides `Adaptor` tools to accommodate different data format for 3rd party transforms. To convert the data shapes or types, utility transforms such as `ToTensor`, `ToNumpy`, `SqueezeDim` are also provided. So it's easy to enhance the transform chain by seamlessly integrating transforms from external packages, including: `ITK`, `BatchGenerator`, `TorchIO` and `Rising`.
@@ -120,28 +125,47 @@ The design of MONAI transforms emphasis code readability and usability. It works
For more details, please check out the tutorial: [integrate 3rd party transforms into MONAI program](https://github.com/Project-MONAI/tutorials/blob/master/modules/integrate_3rd_party_transforms.ipynb).
### 10. IO factory for medical image formats
-Many popular image formats exist in the medical domain, and they are quite different with rich metadata information. To easily handle different medical image formats in the same pipeline, [MONAI provides `LoadImage` transform](https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb), which can automatically choose image readers based on the supported suffixes and in below priority order:
-- User-specified reader at runtime when call this loader.
-- Registered readers from the latest to the first in list.
+Many popular image formats exist in the medical domain, and they are quite different with rich metadata information. To easily handle different medical image formats in the same pipeline, [MONAI provides `LoadImage` transform](https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb), which can automatically choose image readers based on the supported suffixes and in the following priority order:
+- User-specified reader at runtime when calling this loader.
+- Registered readers from the latest to the first in the list.
- Default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), (npz, npy -> NumpyReader), (others -> ITKReader).
-The `ImageReader` API is quite straight-forward, users can easily extend for their own customized image readers.
+The `ImageReader` API is quite straightforward, users can easily extend it for their customized image readers.
With these pre-defined image readers, MONAI can load images in formats: `NIfTI`, `DICOM`, `PNG`, `JPG`, `BMP`, `NPY/NPZ`, etc.
+### 11. Save transform data into NIfTI or PNG files
+To convert images into files or debug the transform chain, MONAI provides `SaveImage` transform. Users can inject this transform into the transform chain to save the results.
+
+### 12. Automatically ensure `channel-first` data shape
+Medical images have different shape formats. They can be `channel-last`, `channel-first` or even `no-channel`. We may, for example, want to load several `no-channel` images and stack them as `channel-first` data. To improve the user experience, MONAI provides an `EnsureChannelFirst` transform to automatically detect data shape according to the meta information and convert it to the `channel-first` format consistently.
+
+### 13. Invert spatial transforms and test-time augmentations
+It is often desirable to invert the previously applied spatial transforms (resize, flip, rotate, zoom, crop, pad, etc.) within the deep learning workflows, for example, to return to the original imaging space after processing the image data in a normalized data space. Many spatial transforms are enhanced with an `inverse` operation since v0.5. The [model inference tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py) shows a basic example.
+
+If the pipeline includes random transformations, users may want to observe the effect that these transformations have on the output. The typical approach is that we pass the same input through the transforms multiple times with different random realizations. Then use the inverse transforms to move all the results to a common space, and calculate the metrics. MONAI provides `TestTimeAugmentation` for this feature, which by default will calculate the `mode`, `mean`, `standard deviation` and `volume variation coefficient`.
+
+[Invert transforms and TTA tutorials](https://github.com/Project-MONAI/tutorials/blob/master/modules/inverse_transforms_and_test_time_augmentations.ipynb) introduce details about the API with usage examples.
+
+(1) The last column is the inverted data of model output:
+![invert transform](../images/invert_transforms.png)
+
+(2) The TTA results of `mode`, `mean` and `standard deviation`:
+![test time augmentation](../images/tta.png)
+
## Datasets
### 1. Cache IO and transforms data to accelerate training
Users often need to train the model with many (potentially thousands of) epochs over the data to achieve the desired model quality. A native PyTorch implementation may repeatedly load data and run the same preprocessing steps for every epoch during training, which can be time-consuming and unnecessary, especially when the medical image volumes are large.
-MONAI provides a multi-threads `CacheDataset` and `LMDBDataset` to accelerate these transformation steps during training by storing the intermediate outcomes before the first randomized transform in the transform chain. Enabling this feature could potentially give 10x training speedups in the [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb).
-![image](../images/cache_dataset.png)
+MONAI provides a multi-thread `CacheDataset` and `LMDBDataset` to accelerate these transformation steps during training by storing the intermediate outcomes before the first randomized transform in the transform chain. Enabling this feature could potentially give 10x training speedups in the [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb).
+![cache dataset](../images/cache_dataset.png)
### 2. Cache intermediate outcomes into persistent storage
The `PersistentDataset` is similar to the CacheDataset, where the intermediate cache values are persisted to disk storage or LMDB for rapid retrieval between experimental runs (as is the case when tuning hyperparameters), or when the entire data set size exceeds available memory. The `PersistentDataset` could achieve similar performance when comparing to `CacheDataset` in [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb).
-![image](../images/datasets_speed.png)
+![cachedataset speed](../images/datasets_speed.png)
### 3. SmartCache mechanism for big datasets
-During training with very big volume dataset, an efficient approach is to only train with a subset of the dataset in an epoch and dynamically replace part of the subset in every epoch. It's the `SmartCache` mechanism in [NVIDIA Clara-train SDK](https://docs.nvidia.com/clara/tlt-mi/clara-train-sdk-v3.0/nvmidl/additional_features/smart_cache.html#smart-cache).
+During training with large volume dataset, an efficient approach is to only train with a subset of the dataset in an epoch and dynamically replace part of the subset in every epoch. It's the `SmartCache` mechanism in [NVIDIA Clara-train SDK](https://docs.nvidia.com/clara/tlt-mi/clara-train-sdk-v3.0/nvmidl/additional_features/smart_cache.html#smart-cache).
MONAI provides a PyTorch version `SmartCache` as `SmartCacheDataset`. In each epoch, only the items in the cache are used for training, at the same time, another thread is preparing replacement items by applying the transform sequence to items not in the cache. Once one epoch is completed, `SmartCache` replaces the same number of items with replacement items.
@@ -183,26 +207,34 @@ It supports user-specified `image_transforms` and `patch_transforms` with custom
which decouples the two-level computations in a multiprocess context.
### 6. Predefined Datasets for public medical data
-To quickly get started with popular training data in the medical domain, MONAI provides several data-specific Datasets(like: `MedNISTDataset`, `DecathlonDataset`, etc.), which include downloading from our AWS storage, extracting data files and support generation of training/evaluation items with transforms. And they are flexible that users can easily modify the JSON config file to change the default behaviors.
+To quickly get started with popular training data in the medical domain, MONAI provides several data-specific Datasets (like `MedNISTDataset`, `DecathlonDataset`, etc.), which include downloading from our AWS storage, extracting data files and supporting generation of training/evaluation items with transforms. They are flexible in that users can easily modify the JSON config file to change the default behaviors.
MONAI always welcome new contributions of public datasets, please refer to existing Datasets and leverage the download and extracting APIs, etc. [Public datasets tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/public_datasets.ipynb) indicates how to quickly set up training workflows with `MedNISTDataset` and `DecathlonDataset` and how to create a new `Dataset` for public data.
The common workflow of predefined datasets:
-![image](../images/dataset_progress.png)
+![pre-defined dataset](../images/dataset_progress.png)
### 7. Partition dataset for cross validation
-The `partition_dataset` utility in MONAI can perform several kinds of mechanism to partition dataset for training and validation or cross-validation. It supports shuffling based on a specified random seed, and will return a set of datasets, each dataset contains one partition. And it can split the dataset based on specified ratios or evenly split into `num_partitions`. For given class labels, it can also make sure the same ratio of classes in every partition.
+The `partition_dataset` utility in MONAI can perform different types of partitioning for training and validation or cross-validation. It supports shuffling based on a specified random seed, and will return a set of datasets, where each dataset contains one partition. It can split the dataset based on specified ratios or evenly split it into `num_partitions`. For given class labels, it can also ensure the same ratio of classes in every partition.
+
+### 8. CSV `Dataset` and `IterableDataset`
+CSV tables are often used in addition to image data to incorporate adjunct information, such as patient demographics, lab results, image acquisition parameters and other non-image data. MONAI provides `CSVDataset` to load CSV files and `CSVIterableDataset` to load large CSV files with scalable data access.
+In addition to the regular preprocessing transform while loading, it also supports multiple CSV files loading, joining tables, rows and columns selection and grouping. [CSVDatasets tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/csv_datasets.ipynb) shows detailed usage examples.
## Losses
-There are domain-specific loss functions in the medical imaging research which are not typically used in the generic computer vision tasks. As an important module of MONAI, these loss functions are implemented in PyTorch, such as `DiceLoss`, `GeneralizedDiceLoss`, `MaskedDiceLoss`, `TverskyLoss` and `FocalLoss`, etc.
+There are domain-specific loss functions in the medical imaging research which are not typically used in generic computer vision tasks. As an important module of MONAI, these loss functions are implemented in PyTorch, such as `DiceLoss`, `GeneralizedDiceLoss`, `MaskedDiceLoss`, `TverskyLoss`, `FocalLoss`, `DiceCELoss`, and `DiceFocalLoss`, etc.
## Optimizers
-MONAI provides several advanced features in optimizers to help accelerate the training or fine-tuning progress. For example, `Novograd` optimizer can be used to converge obviously faster than traditional optimizers. And users can easily define different learning rates for the model layers based [on the `generate_param_groups` utility API](https://github.com/Project-MONAI/tutorials/blob/master/modules/layer_wise_learning_rate.ipynb).
+MONAI provides several advanced features in optimizers to help accelerate the training or fine-tuning progress. For example, `Novograd` optimizer can be used to converge faster than the traditional optimizers. And users can easily define different learning rates for the model layers based [on the `generate_param_groups` utility API](https://github.com/Project-MONAI/tutorials/blob/master/modules/layer_wise_learning_rate.ipynb).
+
+Another important feature is `LearningRateFinder`. The learning rate range test increases the learning rate in a pre-training run between two boundaries in a linear or exponential manner. It provides valuable information on how well the network can be trained over a range of learning rates and what the optimal learning rates are. [LearningRateFinder tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/learning_rate.ipynb) indicates the API usage examples.
+![learning rate finder plot](../images/lr_finder.png)
## Network architectures
Some deep neural network architectures have shown to be particularly effective for medical imaging analysis tasks. MONAI implements reference networks with the aims of both flexibility and code readability.
-To leverage the common network layers and blocks, MONAI provides several predefined layers and blocks which are compatible with 1D, 2D and 3D networks. Users can easily integrate the layer factories in their own networks.
+### 1. Predefined layers and blocks
+To leverage the common network layers and blocks, MONAI provides several predefined layers and blocks which are compatible with 1D, 2D and 3D networks. Users can easily integrate the layer factories in their customised networks.
For example:
```py
@@ -215,7 +247,12 @@ name, dimension = Conv.CONVTRANS, 3
conv_type = Conv[name, dimension]
add_module('conv1', conv_type(in_channels, out_channels, kernel_size=1, bias=False))
```
-And there are several 1D/2D/3D-compatible implementations of intermediate blocks and generic networks, such as UNet, DynUNet, DenseNet, GAN, AHNet, VNet, SENet(and SEResNet, SEResNeXt), SegResNet, etc. All the networks can support PyTorch serialization pipeline based on `torch.jit.script`.
+
+### 2. Implementation of generic 2D/3D networks
+There are several 1D/2D/3D-compatible implementations of intermediate blocks and generic networks, such as UNet, DynUNet, DenseNet, GAN, AHNet, VNet, SENet (and SEResNet, SEResNeXt), SegResNet, EfficientNet, and Attention-based networks. All the networks support the PyTorch serialization pipeline based on `torch.jit.script`.
+
+### 3. Network adapter to finetune final layers
+Instead of training from scratch, we often leverage the existing models, and finetune the final layers of a network for new learning tasks. MONAI provides a `NetAdapter` to easily replace the last layer of a model by a convolutional layer or a fully-connected layer. A typical usage example is to adapt [Torchvision models trained with ImageNet](https://pytorch.org/vision/stable/models.html) for other learning tasks.
## Evaluation
To run model inferences and evaluate the model quality, MONAI provides reference implementations for the relevant widely-used approaches. Currently, several popular evaluation metrics and inference patterns are included:
@@ -228,7 +265,7 @@ A typical process is:
2. Iteratively run batched window inferences until all windows are analyzed.
3. Aggregate the inference outputs to a single segmentation map.
4. Save the results to file or compute some evaluation metrics.
-![image](../images/sliding_window.png)
+![sliding window scheme](../images/sliding_window.png)
The [Spleen 3D segmentation tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb) leverages `SlidingWindow` inference for validation.
@@ -237,17 +274,27 @@ Various useful evaluation metrics have been used to measure the quality of medic
For example, `Mean Dice` score can be used for segmentation tasks, and the area under the ROC curve(`ROCAUC`) for classification tasks. We continue to integrate more options.
+1. MONAI provides flexible base APIs for metrics
+The base classes of MONAI metrics implement the basic computation logic for both iteration and epoch-based metrics. They are a good starting point for customized metrics.
+2. All the metrics support data parallel computation
+With a `Cumulative` base class, intermediate metric outcomes can be automatically buffered, cumulated, synced across distributed processes, and aggregated for the final results. [Multi-processing computation example](https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py) shows how to compute metrics based on saved predictions and labels in multi-processing environment.
+3. All the metrics modules can handle `batch-first` Tensors and list of `channel-first` Tensors
+
+### 3. Metrics report generation
+During evaluation, users usually save the metrics of every input image, then analyze the bad cases to improve the deep learning pipeline. To save detailed information of metrics, MONAI provides a handler `MetricsSaver`, which can save the final metric values, the raw metric of every model output channel of every input image, and a metrics summary report of operations: `mean`, `median`, `max`, `min`, `percentile`, `std`, etc. The `MeanDice` reports of validation with the prostate dataset are as below:
+![metrics report example](../images/metrics_report.png)
+
## Visualization
Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py).
And to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models:
-![image](../images/cam.png)
+![CAM visualization example](../images/cam.png)
The above example is generated by computing [GradCAM/GradCAM++ from a lung CT lesion classification model](https://github.com/Project-MONAI/tutorials/tree/master/modules/interpretability).
## Result writing
-Currently MONAI supports writing the model outputs as NIfTI files or PNG files for segmentation tasks, and as CSV files for classification tasks. And the writers can restore the data spacing, orientation or shape according to the `original_shape` or `original_affine` information from the input image.
+Currently, MONAI supports writing the model outputs as NIfTI files or PNG files for segmentation tasks, and as CSV files for classification tasks. And the writers can restore the data spacing, orientation or shape according to the `original_shape` or `original_affine` information from the input image.
A rich set of formats will be supported soon, along with relevant statistics and evaluation metrics automatically computed from the outputs.
@@ -255,11 +302,11 @@ A rich set of formats will be supported soon, along with relevant statistics and
To quickly set up training and evaluation experiments, MONAI provides a set of workflows to significantly simplify the modules and allow for fast prototyping.
These features decouple the domain-specific components and the generic machine learning processes. They also provide a set of unify APIs for higher level applications (such as AutoML, Federated Learning).
-The trainers and evaluators of the workflows are compatible with pytorch-ignite `Engine` and `Event-Handler` mechanism. There are rich event handlers in MONAI to independently attach to the trainer or evaluator.
+The trainers and evaluators of the workflows are compatible with pytorch-ignite `Engine` and `Event-Handler` mechanism. There are rich event handlers in MONAI to independently attach to the trainer or evaluator, and users can register additional `custom events` to workflows.
### 1. General workflows pipeline
-The workflow and event handlers are shown as below:
-![image](../images/workflows.png)
+The workflow and some of MONAI event handlers are shown as below:
+![workflow pipeline](../images/workflows.png)
The end-to-end training and evaluation examples are available at [Workflow examples](https://github.com/Project-MONAI/tutorials/tree/master/modules/engines).
@@ -270,9 +317,37 @@ Models ensemble is a popular strategy in machine learning and deep learning area
3. Execute inference on the test data with all the K models.
4. Compute the average values with weights or vote the most common value as the final result.
-![image](../images/models_ensemble.png)
+![model ensemble](../images/models_ensemble.png)
More details of practice is at [Model ensemble tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/models_ensemble.ipynb).
+### 3. Transfer learning for different input / output classes
+`Transfer-learning` is a common and efficient training approach, especially in the medical-specific domain where obtaining large datasets for training can be difficult. So transfer learning from a pre-trained checkpoint can significantly improve the model metrics and shorten training time.
+
+MONAI provides `CheckpointLoader` to load a checkpoint for the workflow before training, and it allows some `layer names` or some `layer shapes` of the current network to differ from the checkpoint, which can be useful if the current task has different input image classes or output classes.
+
+### 4. Transfer learning based on NVIDIA Clara MMAR
+[The MMAR (Medical Model ARchive)](https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html) defines a data structure for organizing all artifacts produced during the model development life cycle. NVIDIA Clara provides rich existing MMARs of medical domain-specific models. And these MMARs include all the information about the model including configurations and scripts to provide a work space to perform all model development tasks. To better leverage the pretrained MMARs released on Nvidia GPU cloud, MONAI provides pythonic APIs to access the MMARs.
+
+The following figure compares the loss curves and validation scores for (1) training from scratch (the green line), (2) applying a pretrained model without training (the magenta line), (3) training from the pretrained model (the blue line), according to the number of training epochs
+(the tutorial is available at [transfer_mmar](https://github.com/Project-MONAI/tutorials/blob/master/modules/transfer_mmar.ipynb)):
+
+![transfer_mmar](../images/transfer_mmar.png)
+
+### 5. Decollate batch data for flexible postprocessing
+`decollate batch` was introduced in MONAI v0.6; it simplifies the postprocessing transforms and enables flexible follow-up operations on a batch of data with various data shapes. It can decollate batched data (e.g. model predictions) into a list of tensors, with benefits such as:
+1. enabling postprocessing transforms for each item independently -- randomised transforms could be applied differently for each predicted item in a batch.
+2. simplifying the transform APIs and reducing the input validation burdens because both the preprocessing and postprocessing transforms now only need to support the "channel-first" input format.
+3. enabling the `Invertd` transform for the predictions and the inverted data with different shapes, as the data items are in a list, not stacked in a single tensor.
+4. allowing for both batch-first tensor and list of channel-first tensors in a flexible metric computation.
+
+A typical process of `decollate batch` is illustrated as follows (with a `batch_size=N` model predictions and labels as an example):
+![decollate_batch](../images/decollate_batch.png)
+
+[decollate batch tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/decollate_batch.ipynb) shows a detailed usage example based on a PyTorch native workflow.
+
+### 6. Easy to integrate into popular workflows
+Except for the pytorch-ignite based `monai.engines`, most of the MONAI modules could be used independently or combined with other software packages. For example, MONAI can be easily integrated into popular frameworks such as PyTorch-Lightning and Catalyst: [Lightning segmentation](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d_lightning.ipynb) and [Lightning + TorchIO](https://github.com/Project-MONAI/tutorials/blob/master/modules/TorchIO_MONAI_PyTorch_Lightning.ipynb) tutorials show the PyTorch Lightning programs with MONAI modules, and [Catalyst segmentation](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_catalyst.ipynb) shows the Catalyst program with MONAI modules.
+
## Research
There are several research prototypes in MONAI corresponding to the recently published papers that address advanced research problems.
We always welcome contributions in forms of comments, suggestions, and code implementations.
@@ -283,13 +358,13 @@ The generic patterns/modules identified from the research prototypes will be int
[A reimplementation](https://monai.io/research/coplenet-pneumonia-lesion-segmentation) of the COPLE-Net originally proposed by:
G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zhang. (2020) "A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions from CT Images." IEEE Transactions on Medical Imaging. 2020. [DOI: 10.1109/TMI.2020.3000314](https://doi.org/10.1109/TMI.2020.3000314)
-![image](../images/coplenet.png)
+![coplenet](../images/coplenet.png)
### 2. LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation
[A reimplementation](https://monai.io/research/lamp-automated-model-parallelism) of the LAMP system originally proposed by:
Wentao Zhu, Can Zhao, Wenqi Li, Holger Roth, Ziyue Xu, and Daguang Xu (2020) "LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation." MICCAI 2020 (Early Accept, paper link: https://arxiv.org/abs/2006.12575)
-![image](../images/unet-pipe.png)
+![LAMP UNet](../images/unet-pipe.png)
## GPU acceleration
NVIDIA GPUs have been widely applied in many areas of deep learning training and evaluation, and the CUDA parallel computation shows obvious acceleration when comparing to traditional computation methods. To fully leverage GPU features, many popular mechanisms raised, like automatic mixed precision (AMP), distributed data parallel, etc. MONAI can support these features and provides rich examples.
@@ -299,19 +374,46 @@ In 2017, NVIDIA researchers developed a methodology for mixed-precision training
For the PyTorch 1.6 release, developers at NVIDIA and Facebook moved mixed precision functionality into PyTorch core as the AMP package, `torch.cuda.amp`.
-MONAI workflows can easily set `amp=True/False` in `SupervisedTrainer` or `SupervisedEvaluator` during training or evaluation to enable/disable AMP. And we tried to compare the training speed if AMP ON/OFF on Tesla V100 GPU with CUDA 11 and PyTorch 1.6, got some benchmark for reference:
-![image](../images/amp_training_v100.png)
-We also executed the same test program on Testa A100 GPU with the same software environment, got much faster benchmark for reference:
-![image](../images/amp_training_a100.png)
+MONAI workflows can easily set `amp=True/False` in `SupervisedTrainer` or `SupervisedEvaluator` during training or evaluation to enable/disable AMP. We compared the training speed with AMP ON/OFF on an NVIDIA V100 GPU with CUDA 11 and PyTorch 1.6, and obtained some benchmark results:
+![amp v100 results](../images/amp_training_v100.png)
+We also executed the same test program on NVIDIA A100 GPU with the same software environment, obtained faster results:
+![amp a100 results](../images/amp_training_a100.png)
More details is available at [AMP training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/automatic_mixed_precision.ipynb).
We also tried to combine AMP with `CacheDataset` and `Novograd` optimizer to achieve the fast training in MONAI, able to obtain approximately 12x speedup compared with a Pytorch native implementation when the training converges at a validation mean dice of 0.93. Benchmark for reference:
-![image](../images/fast_training.png)
+![fast training results](../images/fast_training.png)
More details is available at [Fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb).
### 2. Distributed data parallel
-Distributed data parallel is an important feature of PyTorch to connect multiple GPU devices on single or multiple nodes to train or evaluate models. MONAI provides demos for reference: train/evaluate with PyTorch DDP, train/evaluate with Horovod, train/evaluate with Ignite DDP, partition dataset and train with SmartCacheDataset, as well as a real world training example based on Decathlon challenge Task01 - Brain Tumor segmentation.
-The demo contains distributed caching, training, and validation. We tried to train this example on NVIDIA NGC server, got some performance benchmarks for reference(PyTorch 1.6, CUDA 11, Tesla V100 GPUs):
-![image](../images/distributed_training.png)
+Distributed data parallel is an important feature of PyTorch to connect multiple GPU devices on single or multiple nodes to train or evaluate models. The distributed data parallel APIs of MONAI are compatible with native PyTorch distributed module, pytorch-ignite distributed module, Horovod, XLA, and the SLURM platform. MONAI provides demos for reference: train/evaluate with PyTorch DDP, train/evaluate with Horovod, train/evaluate with Ignite DDP, partition dataset and train with SmartCacheDataset, as well as a real world training example based on Decathlon challenge Task01 - Brain Tumor segmentation. The demo contains distributed caching, training, and validation. We obtained performance benchmarks for reference (based on PyTorch 1.6, CUDA 11, NVIDIA V100 GPUs):
+
+![distributed training results](../images/distributed_training.png)
### 3. C++/CUDA optimized modules
-To accelerate some heavy computation progress, C++/CUDA implementation can be an impressive method, which usually brings even hundreds of times faster performance. MONAI contains some C++/CUDA optimized modules, like `Resampler`,and fully support C++/CUDA programs in CI/CD and building package.
+To further accelerate the domain-specific routines in the workflows, MONAI C++/CUDA implementations are introduced as extensions of the PyTorch native implementations.
+MONAI provides the modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions):
+- via `setuptools`, for modules including `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`.
+- via just-in-time (JIT) compilation, for the `Gaussian mixtures` module. This approach allows for dynamic optimisation according to the user-specified parameters and local system environments.
+The following figure shows results of MONAI's Gaussian mixture models applied to tissue and surgical tools segmentation:
+![Gaussian mixture models as a postprocessing step](../images/gmm_feature_set_comparison_s.png)
+
+## Applications
+The research area of medical image deep learning is expanding fast. To apply the latest achievements into applications, MONAI contains many application components to build end-to-end solutions or prototypes for other similar use cases.
+
+### 1. DeepGrow modules for interactive segmentation
+[A reimplementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/deepgrow) of the DeepGrow components, which is a deep-learning-based semi-automated segmentation approach that aims to be a "smart" interactive tool for region of interest delineation in medical images, originally proposed by:
+
+Sakinis, Tomas, et al. "Interactive segmentation of medical images through fully convolutional neural networks." arXiv preprint arXiv:1903.08205 (2019).
+
+![deepgrow scheme](../images/deepgrow.png)
+
+### 2. Lesion detection in digital pathology
+[Implementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/pathology) of the pathology detection components, which includes efficient whole slide imaging IO and sampling with NVIDIA cuCIM library and SmartCache mechanism, FROC measurements for lesion and probabilistic post-processing for lesion detection.
+
+![digital pathology](../images/pathology.png)
+
+### 3. Learning-based image registration
+Starting from v0.5.0, MONAI provides experimental features for building learning-based 2D/3D registration workflows. These include image similarity measures as loss functions, bending energy as model regularization, network architectures, warping modules. The components can be used to build the major unsupervised and weakly-supervised algorithms.
+
+The following figure shows the registration of CT images acquired at different time points for a single patient using MONAI:
+
+![3d registration](../images/3d_paired.png)
diff --git a/docs/source/index.rst b/docs/source/index.rst
index ea21428e6e..30671427a4 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,6 +1,6 @@
:github_url: https://github.com/Project-MONAI/MONAI
-.. MONAI documentation master file, created by
+.. MONAI documentation main file, created by
sphinx-quickstart on Wed Feb 5 09:40:29 2020.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
@@ -11,7 +11,7 @@ Project MONAI
*Medical Open Network for AI*
-MONAI is a `PyTorch `_-based, `open-source `_ framework
+MONAI is a `PyTorch `_-based, `open-source `_ framework
for deep learning in healthcare imaging, part of `PyTorch Ecosystem `_.
Its ambitions are:
@@ -45,6 +45,8 @@ Technical documentation is available at `docs.monai.io `_
:maxdepth: 1
:caption: Feature highlights
+ whatsnew_0_6.md
+ whatsnew_0_5.md
highlights.md
.. toctree::
@@ -75,7 +77,7 @@ Contributing
------------
For guidance on making a contribution to MONAI, see the `contributing guidelines
-`_.
+`_.
Links
@@ -86,14 +88,13 @@ Links
- Code: https://github.com/Project-MONAI/MONAI
- Project tracker: https://github.com/Project-MONAI/MONAI/projects
- Issue tracker: https://github.com/Project-MONAI/MONAI/issues
-- Changelog: https://github.com/Project-MONAI/MONAI/blob/master/CHANGELOG.md
+- Changelog: https://github.com/Project-MONAI/MONAI/blob/dev/CHANGELOG.md
- Wiki: https://github.com/Project-MONAI/MONAI/wiki
- FAQ: https://github.com/Project-MONAI/MONAI/wiki/Frequently-asked-questions-and-answers
- Test status: https://github.com/Project-MONAI/MONAI/actions
- PyPI package: https://pypi.org/project/monai/
+- Weekly previews: https://pypi.org/project/monai-weekly/
- Docker Hub: https://hub.docker.com/r/projectmonai/monai
-- Google Group: https://groups.google.com/forum/#!forum/project-monai
-- Reddit: https://www.reddit.com/r/projectmonai/
Indices and tables
diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst
index 544e695d2e..e358e603bd 100644
--- a/docs/source/inferers.rst
+++ b/docs/source/inferers.rst
@@ -30,3 +30,9 @@ Inferers
.. autoclass:: SlidingWindowInferer
:members:
:special-members: __call__
+
+`SaliencyInferer`
+~~~~~~~~~~~~~~~~~
+.. autoclass:: SaliencyInferer
+ :members:
+ :special-members: __call__
diff --git a/docs/source/installation.md b/docs/source/installation.md
index cb540b1559..d8dddff205 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -59,13 +59,19 @@ for the latest features:
### Option 1 (as a part of your system-wide module):
```bash
-pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI
+pip install git+https://github.com/Project-MONAI/MONAI#egg=monai
```
or, to build with MONAI Cpp/CUDA extensions:
```bash
-BUILD_MONAI=1 pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI
+BUILD_MONAI=1 pip install git+https://github.com/Project-MONAI/MONAI#egg=monai
```
-this command will download and install the current master branch of [MONAI from
+
+To build the extensions, if the system environment already has a version of Pytorch installed,
+`--no-build-isolation` might be preferred:
+```bash
+BUILD_MONAI=1 pip install --no-build-isolation git+https://github.com/Project-MONAI/MONAI#egg=monai
+```
+this command will download and install the current `dev` branch of [MONAI from
GitHub](https://github.com/Project-MONAI/MONAI).
This documentation website by default shows the information for the latest version.
@@ -128,7 +134,7 @@ Note that you do not need to install the CUDA toolkit on the host, but the drive
Please find out more information on [nvidia-docker](https://github.com/NVIDIA/nvidia-docker).
Assuming that you have the Nvidia driver and Docker 19.03+ installed, running the following command will
-download and start a container with the latest version of MONAI. The latest master branch of MONAI from GitHub
+download and start a container with the latest version of MONAI. The latest `dev` branch of MONAI from GitHub
is included in the image.
```bash
docker run --gpus all --rm -ti --ipc=host projectmonai/monai:latest
@@ -168,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are
```
-[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil]
+[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops]
```
which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb` and `psutil`, respectively.
+`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively.
- `pip install 'monai[all]'` installs all the optional dependencies.
diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index 5e19219fee..fc7c302ea3 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -48,6 +48,11 @@ Segmentation Losses
.. autoclass:: DiceCELoss
:members:
+`DiceFocalLoss`
+~~~~~~~~~~~~~~~
+.. autoclass:: DiceFocalLoss
+ :members:
+
`FocalLoss`
~~~~~~~~~~~
.. autoclass:: FocalLoss
@@ -77,9 +82,14 @@ Registration Losses
:members:
Loss Wrappers
---------------
+-------------
`MultiScaleLoss`
-~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~
.. autoclass:: MultiScaleLoss
- :members:
\ No newline at end of file
+ :members:
+
+`MaskedLoss`
+~~~~~~~~~~~~
+.. autoclass:: MaskedLoss
+ :members:
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 32a3faf380..e6605065c4 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -6,6 +6,30 @@ Metrics
=======
.. currentmodule:: monai.metrics
+`FROC`
+------
+.. autofunction:: compute_froc_score
+
+`Metric`
+--------
+.. autoclass:: Metric
+ :members:
+
+`IterationMetric`
+-----------------
+.. autoclass:: IterationMetric
+ :members:
+
+`Cumulative`
+------------
+.. autoclass:: Cumulative
+ :members:
+
+`CumulativeIterationMetric`
+---------------------------
+.. autoclass:: CumulativeIterationMetric
+ :members:
+
`Mean Dice`
-----------
.. autofunction:: compute_meandice
@@ -17,6 +41,9 @@ Metrics
--------------------------
.. autofunction:: compute_roc_auc
+.. autoclass:: ROCAUCMetric
+ :members:
+
`Confusion matrix`
------------------
.. autofunction:: get_confusion_matrix
@@ -37,3 +64,23 @@ Metrics
.. autoclass:: SurfaceDistanceMetric
:members:
+
+`Mean squared error`
+--------------------
+.. autoclass:: MSEMetric
+ :members:
+
+`Mean absolute error`
+---------------------
+.. autoclass:: MAEMetric
+ :members:
+
+`Root mean squared error`
+-------------------------
+.. autoclass:: RMSEMetric
+ :members:
+
+`Peak signal to noise ratio`
+----------------------------
+.. autoclass:: PSNRMetric
+ :members:
\ No newline at end of file
diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index e0ac0f2d75..a5ce86287a 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -20,6 +20,11 @@ Blocks
.. autoclass:: Convolution
:members:
+`CRF`
+~~~~~~~~~~~~~
+.. autoclass:: CRF
+ :members:
+
`ResidualUnit`
~~~~~~~~~~~~~~
.. autoclass:: ResidualUnit
@@ -30,6 +35,11 @@ Blocks
.. autoclass:: Swish
:members:
+`MemoryEfficientSwish`
+~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: MemoryEfficientSwish
+ :members:
+
`Mish`
~~~~~~
.. autoclass:: Mish
@@ -69,11 +79,30 @@ Blocks
.. autoclass:: ResBlock
:members:
+`SABlock Block`
+~~~~~~~~~~~~~~~
+.. autoclass:: SABlock
+ :members:
+
`Squeeze-and-Excitation`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ChannelSELayer
:members:
+`Transformer Block`
+~~~~~~~~~~~~~~~~~~~
+.. autoclass:: TransformerBlock
+ :members:
+
+`UNETR Block`
+~~~~~~~~~~~~~
+.. autoclass:: UnetrBasicBlock
+ :members:
+.. autoclass:: UnetrUpBlock
+ :members:
+.. autoclass:: UnetrPrUpBlock
+ :members:
+
`Residual Squeeze-and-Excitation`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ResidualSELayer
@@ -94,7 +123,7 @@ Blocks
.. autoclass:: SEResNetBottleneck
:members:
-`Squeeze-and-Excitation ResneXt Bottleneck`
+`Squeeze-and-Excitation ResNeXt Bottleneck`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SEResNeXtBottleneck
:members:
@@ -119,6 +148,21 @@ Blocks
.. autoclass:: Subpixelupsample
.. autoclass:: SubpixelUpSample
+`Registration Residual Conv Block`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: RegistrationResidualConvBlock
+ :members:
+
+`Registration Down Sample Block`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: RegistrationDownSampleBlock
+ :members:
+
+`Registration Extraction Block`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: RegistrationExtractionBlock
+ :members:
+
`LocalNet DownSample Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNetDownSampleBlock
@@ -134,6 +178,16 @@ Blocks
.. autoclass:: LocalNetFeatureExtractorBlock
:members:
+`MLP Block`
+~~~~~~~~~~~
+.. autoclass:: MLPBlock
+ :members:
+
+`Patch Embedding Block`
+~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: PatchEmbeddingBlock
+ :members:
+
`Warp`
~~~~~~
.. autoclass:: Warp
@@ -216,6 +270,10 @@ Layers
~~~~~~~~~~~~~~~~~
.. autoclass:: PHLFilter
+`GaussianMixtureModel`
+~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: GaussianMixtureModel
+
`SavitzkyGolayFilter`
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SavitzkyGolayFilter
@@ -256,6 +314,8 @@ Layers
~~~~~~~~~~~
.. automodule:: monai.networks.layers.convutils
:members:
+.. automodule:: monai.networks.layers.utils
+ :members:
Nets
@@ -271,10 +331,32 @@ Nets
~~~~~~~~~~
.. autoclass:: DenseNet
:members:
-.. autofunction:: densenet121
-.. autofunction:: densenet169
-.. autofunction:: densenet201
-.. autofunction:: densenet264
+
+`DenseNet121`
+~~~~~~~~~~~~~
+.. autoclass:: DenseNet121
+
+`DenseNet169`
+~~~~~~~~~~~~~
+.. autoclass:: DenseNet169
+
+`DenseNet201`
+~~~~~~~~~~~~~
+.. autoclass:: DenseNet201
+
+`DenseNet264`
+~~~~~~~~~~~~~
+.. autoclass:: DenseNet264
+
+`EfficientNet`
+~~~~~~~~~~~~~~
+.. autoclass:: EfficientNet
+ :members:
+
+`EfficientNetBN`
+~~~~~~~~~~~~~~~~
+.. autoclass:: EfficientNetBN
+ :members:
`SegResNet`
~~~~~~~~~~~
@@ -290,12 +372,30 @@ Nets
~~~~~~~
.. autoclass:: SENet
:members:
-.. autofunction:: senet154
-.. autofunction:: se_resnet50
-.. autofunction:: se_resnet101
-.. autofunction:: se_resnet152
-.. autofunction:: se_resnext50_32x4d
-.. autofunction:: se_resnext101_32x4d
+
+`SENet154`
+~~~~~~~~~~
+.. autoclass:: SENet154
+
+`SEResNet50`
+~~~~~~~~~~~~
+.. autoclass:: SEResNet50
+
+`SEResNet101`
+~~~~~~~~~~~~~
+.. autoclass:: SEResNet101
+
+`SEResNet152`
+~~~~~~~~~~~~~
+.. autoclass:: SEResNet152
+
+`SEResNext50`
+~~~~~~~~~~~~~
+.. autoclass:: SEResNext50
+
+`SEResNext101`
+~~~~~~~~~~~~~~
+.. autoclass:: SEResNext101
`HighResNet`
~~~~~~~~~~~~
@@ -318,6 +418,11 @@ Nets
.. autoclass:: Unet
.. autoclass:: unet
+`UNETR`
+~~~~~~~
+.. autoclass:: UNETR
+ :members:
+
`BasicUNet`
~~~~~~~~~~~
.. autoclass:: BasicUNet
@@ -330,6 +435,16 @@ Nets
.. autoclass:: VNet
:members:
+`RegUNet`
+~~~~~~~~~
+.. autoclass:: RegUNet
+ :members:
+
+`GlobalNet`
+~~~~~~~~~~~~
+.. autoclass:: GlobalNet
+ :members:
+
`LocalNet`
~~~~~~~~~~~
.. autoclass:: LocalNet
@@ -345,6 +460,11 @@ Nets
.. autoclass:: VarAutoEncoder
:members:
+`ViT`
+~~~~~
+.. autoclass:: ViT
+ :members:
+
`FullyConnectedNet`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: FullyConnectedNet
@@ -375,6 +495,21 @@ Nets
.. autoclass:: Critic
:members:
+`NetAdapter`
+~~~~~~~~~~~~
+.. autoclass:: NetAdapter
+ :members:
+
+`TorchVisionFCModel`
+~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: TorchVisionFCModel
+ :members:
+
+`TorchVisionFullyConvModel`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: TorchVisionFullyConvModel
+ :members:
+
Utilities
---------
.. automodule:: monai.networks.utils
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 5b550f7885..fcd9adba94 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -27,12 +27,33 @@ Generic Interfaces
.. autoclass:: Randomizable
:members:
+`RandomizableTransform`
+^^^^^^^^^^^^^^^^^^^^^^^
+.. autoclass:: RandomizableTransform
+ :members:
+
`Compose`
^^^^^^^^^
.. autoclass:: Compose
:members:
:special-members: __call__
+`InvertibleTransform`
+^^^^^^^^^^^^^^^^^^^^^
+.. autoclass:: InvertibleTransform
+ :members:
+
+`BatchInverseTransform`
+^^^^^^^^^^^^^^^^^^^^^^^
+.. autoclass:: BatchInverseTransform
+ :members:
+
+`Decollated`
+^^^^^^^^^^^^
+.. autoclass:: Decollated
+ :members:
+
+
Vanilla Transforms
------------------
@@ -99,6 +120,12 @@ Crop and Pad
:members:
:special-members: __call__
+`RandCropByLabelClasses`
+""""""""""""""""""""""""
+.. autoclass:: RandCropByLabelClasses
+ :members:
+ :special-members: __call__
+
`ResizeWithPadOrCrop`
"""""""""""""""""""""
.. autoclass:: ResizeWithPadOrCrop
@@ -111,6 +138,18 @@ Crop and Pad
:members:
:special-members: __call__
+`RandScaleCrop`
+"""""""""""""""
+.. autoclass:: RandScaleCrop
+ :members:
+ :special-members: __call__
+
+`CenterScaleCrop`
+"""""""""""""""""
+.. autoclass:: CenterScaleCrop
+ :members:
+ :special-members: __call__
+
Intensity
^^^^^^^^^
@@ -132,6 +171,24 @@ Intensity
:members:
:special-members: __call__
+`StdShiftIntensity`
+"""""""""""""""""""
+.. autoclass:: StdShiftIntensity
+ :members:
+ :special-members: __call__
+
+`RandStdShiftIntensity`
+"""""""""""""""""""""""
+.. autoclass:: RandStdShiftIntensity
+ :members:
+ :special-members: __call__
+
+`RandBiasField`
+"""""""""""""""
+.. autoclass:: RandBiasField
+ :members:
+ :special-members: __call__
+
`ScaleIntensity`
""""""""""""""""
.. autoclass:: ScaleIntensity
@@ -228,6 +285,31 @@ Intensity
:members:
:special-members: __call__
+`GibbsNoise`
+""""""""""""""
+.. autoclass:: GibbsNoise
+ :members:
+ :special-members: __call__
+
+`RandGibbsNoise`
+"""""""""""""""""
+.. autoclass:: RandGibbsNoise
+ :members:
+ :special-members: __call__
+
+`KSpaceSpikeNoise`
+""""""""""""""""""""
+.. autoclass:: KSpaceSpikeNoise
+ :members:
+ :special-members: __call__
+
+`RandKSpaceSpikeNoise`
+""""""""""""""""""""""""
+.. autoclass:: RandKSpaceSpikeNoise
+    :members:
+    :special-members: __call__
+
+
IO
^^
@@ -276,6 +358,11 @@ Post-processing
:members:
:special-members: __call__
+`Prob NMS`
+""""""""""
+.. autoclass:: ProbNMS
+ :members:
+
`VoteEnsemble`
""""""""""""""
.. autoclass:: VoteEnsemble
@@ -309,6 +396,12 @@ Spatial
:members:
:special-members: __call__
+`RandAxisFlip`
+""""""""""""""
+.. autoclass:: RandAxisFlip
+ :members:
+ :special-members: __call__
+
`RandZoom`
""""""""""
.. autoclass:: RandZoom
@@ -399,6 +492,12 @@ Spatial
:members:
:special-members: __call__
+`AddCoordinateChannels`
+"""""""""""""""""""""""
+.. autoclass:: AddCoordinateChannels
+ :members:
+ :special-members: __call__
+
Utility
^^^^^^^
@@ -426,6 +525,12 @@ Utility
:members:
:special-members: __call__
+`EnsureChannelFirst`
+""""""""""""""""""""
+.. autoclass:: EnsureChannelFirst
+ :members:
+ :special-members: __call__
+
`RepeatChannel`
"""""""""""""""
.. autoclass:: RepeatChannel
@@ -456,6 +561,13 @@ Utility
:members:
:special-members: __call__
+`ToCupy`
+""""""""
+.. autoclass:: ToCupy
+ :members:
+ :special-members: __call__
+
+
`Transpose`
"""""""""""
.. autoclass:: Transpose
@@ -498,6 +610,12 @@ Utility
:members:
:special-members: __call__
+`ClassesToIndices`
+""""""""""""""""""
+.. autoclass:: ClassesToIndices
+ :members:
+ :special-members: __call__
+
`ConvertToMultiChannelBasedOnBratsClasses`
""""""""""""""""""""""""""""""""""""""""""
.. autoclass:: ConvertToMultiChannelBasedOnBratsClasses
@@ -516,6 +634,18 @@ Utility
:members:
:special-members: __call__
+`MapLabelValue`
+"""""""""""""""
+.. autoclass:: MapLabelValue
+ :members:
+ :special-members: __call__
+
+`EnsureType`
+""""""""""""
+.. autoclass:: EnsureType
+ :members:
+ :special-members: __call__
+
Dictionary Transforms
---------------------
@@ -582,6 +712,12 @@ Crop and Pad (Dict)
:members:
:special-members: __call__
+`RandCropByLabelClassesd`
+"""""""""""""""""""""""""
+.. autoclass:: RandCropByLabelClassesd
+ :members:
+ :special-members: __call__
+
`ResizeWithPadOrCropd`
""""""""""""""""""""""
.. autoclass:: ResizeWithPadOrCropd
@@ -594,8 +730,20 @@ Crop and Pad (Dict)
:members:
:special-members: __call__
-Instensity (Dict)
-^^^^^^^^^^^^^^^^^
+`RandScaleCropd`
+""""""""""""""""
+.. autoclass:: RandScaleCropd
+ :members:
+ :special-members: __call__
+
+`CenterScaleCropd`
+""""""""""""""""""
+.. autoclass:: CenterScaleCropd
+ :members:
+ :special-members: __call__
+
+Intensity (Dict)
+^^^^^^^^^^^^^^^^
`RandGaussianNoised`
""""""""""""""""""""
@@ -615,6 +763,24 @@ Instensity (Dict)
:members:
:special-members: __call__
+`StdShiftIntensityd`
+""""""""""""""""""""
+.. autoclass:: StdShiftIntensityd
+ :members:
+ :special-members: __call__
+
+`RandStdShiftIntensityd`
+""""""""""""""""""""""""
+.. autoclass:: RandStdShiftIntensityd
+ :members:
+ :special-members: __call__
+
+`RandBiasFieldd`
+""""""""""""""""
+.. autoclass:: RandBiasFieldd
+ :members:
+ :special-members: __call__
+
`ScaleIntensityd`
"""""""""""""""""
.. autoclass:: ScaleIntensityd
@@ -645,6 +811,30 @@ Instensity (Dict)
:members:
:special-members: __call__
+`GibbsNoised`
+""""""""""""""
+.. autoclass:: GibbsNoised
+ :members:
+ :special-members: __call__
+
+`RandGibbsNoised`
+""""""""""""""""""
+.. autoclass:: RandGibbsNoised
+ :members:
+ :special-members: __call__
+
+`KSpaceSpikeNoised`
+""""""""""""""""""""""
+.. autoclass:: KSpaceSpikeNoised
+ :members:
+ :special-members: __call__
+
+`RandKSpaceSpikeNoised`
+"""""""""""""""""""""""""
+.. autoclass:: RandKSpaceSpikeNoised
+ :members:
+ :special-members: __call__
+
`ScaleIntensityRangePercentilesd`
"""""""""""""""""""""""""""""""""
.. autoclass:: ScaleIntensityRangePercentilesd
@@ -759,6 +949,18 @@ Post-processing (Dict)
:members:
:special-members: __call__
+`Invertd`
+"""""""""
+.. autoclass:: Invertd
+ :members:
+ :special-members: __call__
+
+`SaveClassificationd`
+"""""""""""""""""""""
+.. autoclass:: SaveClassificationd
+ :members:
+ :special-members: __call__
+
Spatial (Dict)
^^^^^^^^^^^^^^
@@ -786,6 +988,12 @@ Spatial (Dict)
:members:
:special-members: __call__
+`RandAxisFlipd`
+"""""""""""""""
+.. autoclass:: RandAxisFlipd
+ :members:
+ :special-members: __call__
+
`Rotated`
"""""""""
.. autoclass:: Rotated
@@ -828,6 +1036,12 @@ Spatial (Dict)
:members:
:special-members: __call__
+`Affined`
+"""""""""
+.. autoclass:: Affined
+ :members:
+ :special-members: __call__
+
`RandAffined`
"""""""""""""
.. autoclass:: RandAffined
@@ -846,6 +1060,12 @@ Spatial (Dict)
:members:
:special-members: __call__
+`AddCoordinateChannelsd`
+""""""""""""""""""""""""
+.. autoclass:: AddCoordinateChannelsd
+ :members:
+ :special-members: __call__
+
Utility (Dict)
^^^^^^^^^^^^^^
@@ -873,6 +1093,12 @@ Utility (Dict)
:members:
:special-members: __call__
+`EnsureChannelFirstd`
+"""""""""""""""""""""
+.. autoclass:: EnsureChannelFirstd
+ :members:
+ :special-members: __call__
+
`RepeatChanneld`
""""""""""""""""
.. autoclass:: RepeatChanneld
@@ -903,6 +1129,12 @@ Utility (Dict)
:members:
:special-members: __call__
+`ToCupyd`
+"""""""""
+.. autoclass:: ToCupyd
+ :members:
+ :special-members: __call__
+
`DeleteItemsd`
""""""""""""""
.. autoclass:: DeleteItemsd
@@ -969,6 +1201,12 @@ Utility (Dict)
:members:
:special-members: __call__
+`ClassesToIndicesd`
+"""""""""""""""""""
+.. autoclass:: ClassesToIndicesd
+ :members:
+ :special-members: __call__
+
`ConvertToMultiChannelBasedOnBratsClassesd`
"""""""""""""""""""""""""""""""""""""""""""
.. autoclass:: ConvertToMultiChannelBasedOnBratsClassesd
@@ -987,6 +1225,24 @@ Utility (Dict)
:members:
:special-members: __call__
+`RandTorchVisiond`
+""""""""""""""""""
+.. autoclass:: RandTorchVisiond
+ :members:
+ :special-members: __call__
+
+`MapLabelValued`
+""""""""""""""""
+.. autoclass:: MapLabelValued
+ :members:
+ :special-members: __call__
+
+`EnsureTyped`
+"""""""""""""
+.. autoclass:: EnsureTyped
+ :members:
+ :special-members: __call__
+
Transform Adaptors
------------------
.. automodule:: monai.transforms.adaptors
diff --git a/docs/source/utils.rst b/docs/source/utils.rst
index e0b993da60..321e6acdfc 100644
--- a/docs/source/utils.rst
+++ b/docs/source/utils.rst
@@ -27,7 +27,13 @@ Misc
.. automodule:: monai.utils.misc
:members:
+
Profiling
---------
.. automodule:: monai.utils.profiling
:members:
+
+Deprecated
+----------
+.. automodule:: monai.utils.deprecated
+ :members:
diff --git a/docs/source/whatsnew_0_5.md b/docs/source/whatsnew_0_5.md
new file mode 100644
index 0000000000..f353084303
--- /dev/null
+++ b/docs/source/whatsnew_0_5.md
@@ -0,0 +1,78 @@
+# What's new in 0.5
+
+- Invert spatial transforms and test-time augmentations
+- Lesion detection in digital pathology
+- DeepGrow modules for interactive segmentation
+- Various usability improvements
+
+## Invert spatial transforms and test-time augmentations
+It is often desirable to invert the previously applied spatial transforms (resize, flip, rotate, zoom, crop, pad, etc.) with the deep learning workflows, for example, to resume to the original imaging space after processing the image data in a normalized data space. We enhance almost all the spatial transforms with an `inverse` operation and release this experimental feature in v0.5. Users can easily invert all the spatial transforms for one transformed data item or a batch of data items. It also can be achieved within the workflows by using the `TransformInverter` handler.
+
+If the pipeline includes random transformations, users may want to observe the effect that these transformations have on the output. The typical approach is that we pass the same input through the transforms multiple times with different random realizations. Then use the inverse transforms to move all the results to a common space, and calculate the metrics. MONAI provided `TestTimeAugmentation` for this feature, which by default will calculate the `mode`, `mean`, `standard deviation` and `volume variation coefficient`.
+
+[Invert transforms and TTA tutorials](https://github.com/Project-MONAI/tutorials/blob/master/modules/inverse_transforms_and_test_time_augmentations.ipynb) introduce details about the API with examples.
+
+(1) The last column is the inverted data of model output:
+![invert transform](../images/invert_transforms.png)
+
+(2) The TTA results of `mode`, `mean` and `standard deviation`:
+![test time augmentation](../images/tta.png)
+
+## Lesion detection in digital pathology
+MONAI starts to support digital pathology deep learning tasks. The initial [implementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/pathology) of the pathology detection components includes:
+- Efficient whole slide imaging IO with NVIDIA cuCIM library
+- Patch-based sampling and training strategies with the SmartCache mechanism
+- FROC measurements for lesion detection
+- Probabilistic post-processing for lesion ROIs.
+
+![digital pathology](../images/pathology.png)
+
+## DeepGrow modules for interactive segmentation
+Towards an interactive workflow with manual input during training and inference,
+[a reimplementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/deepgrow) of the DeepGrow components is included in this release.
+DeepGrow is a deep learning based semi-automated segmentation approach that aims to be a "smart" interactive tool for regions of interest delineation in medical images.
+
+![deepgrow scheme](../images/deepgrow_scheme.png)
+
+An end-to-end example is presented at [`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/tree/master/deepgrow/ignite).
+![deepgrow end-to-end](../images/deepgrow.png)
+
+## Learning-based image registration
+Starting from v0.5, MONAI provides experimental features for building learning-based 2D/3D registration workflows. These include image similarity measures as loss functions, bending energy as model regularization, network architectures, warping modules. The components can be used to build the major unsupervised and weakly-supervised algorithms.
+
+The following figure shows the registration of CT images acquired at different time points for a single patient using MONAI:
+
+![3d registration](../images/3d_paired.png)
+
+## Various usability improvements
+### IO factory for medical image formats
+Many popular image formats exist in the medical domain, and they are quite different with rich metadata information. To easily handle different medical image formats in the same pipeline, [MONAI provides `LoadImage` transform](https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb), which can automatically choose image readers based on the supported suffixes and in the below priority order:
+- User-specified reader at runtime when calling this loader.
+- Registered readers from the latest to the first in list.
+- Default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), (npz, npy -> NumpyReader), (others -> ITKReader).
+
+The `ImageReader` API is quite straight-forward, users can easily extend for their own customized image readers.
+
+With these pre-defined image readers, MONAI can load images in formats: `NIfTI`, `DICOM`, `PNG`, `JPG`, `BMP`, `NPY/NPZ`, etc.
+
+### Save transform data into NIfTI or PNG files
+To convert images into files or debug the transform chain, MONAI provides `SaveImage` transform. Users can inject this transform into the transform chain to save the results.
+
+### Automatically ensure `channel-first` data shape
+Medical images have different shape formats. They can be `channel-last`, `channel-first` or even `no-channel`. We may, for example, want to load several `no-channel` images and stack them as `channel-first` data. To improve the user experience, MONAI provided an `EnsureChannelFirst` transform to automatically detect data shape according to the meta information and convert it to the `channel-first` format consistently.
+
+### Network architectures
+Various ready-to-use architectures with pretrained model weights from `torch.hub`.
+
+### Result writing
+Currently MONAI supports writing the model outputs as NIfTI files or PNG files for segmentation tasks, and as CSV files for classification tasks. And the writers can restore the data spacing, orientation or shape according to the `original_shape` or `original_affine` information from the input image.
+
+A rich set of formats will be supported soon, along with relevant statistics and evaluation metrics automatically computed from the outputs.
+
+### Transfer learning for different input / output classes
+`Transfer-learning` is a very common and efficient training approach, especially in the medical-specific domain where obtaining large datasets for training can be difficult. So transfer-learning from a pre-trained checkpoint can significantly improve the model metrics and shorten training time.
+
+MONAI provided `CheckpointLoader` to load a checkpoint for the workflow before training, and it allows some `layer names` of the current network to not match the checkpoint, or some `layer shapes` to not match the checkpoint, which can be useful if the current task has different input image classes or output classes.
+
+### C++/CUDA optimized modules
+To accelerate some heavy computation progress, C++/CUDA implementation can be an impressive method, which usually brings even hundreds of times faster performance. MONAI contains some C++/CUDA optimized modules, like `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`, and fully support C++/CUDA programs in CI/CD and building package.
diff --git a/docs/source/whatsnew_0_6.md b/docs/source/whatsnew_0_6.md
new file mode 100644
index 0000000000..bdc419df37
--- /dev/null
+++ b/docs/source/whatsnew_0_6.md
@@ -0,0 +1,96 @@
+# What's new in 0.6 🎉🎉
+
+- Decollating mini-batches as an essential post-processing step
+- Pythonic APIs to load the pretrained models from Clara Train MMARs
+- UNETR: Transformers for Medical Image Segmentation
+- Enhancements of the base metric interfaces
+- C++/CUDA extension modules via PyTorch JIT compilation
+- Backward compatibility and enhanced continuous integration/continuous delivery
+- Collaboration with Project-MONAI/MONAILabel for smooth integration
+
+
+## Decollating mini-batches as an essential post-processing step
+`decollate batch` is introduced in MONAI v0.6, to simplify the post-processing transforms and enable flexible operations on a batch of model outputs.
+It can decollate batched data (e.g. model inference results) into a list of tensors -- as an 'inverse' operation of `collate_fn` of the PyTorch data loader. It has the benefits such as:
+- enabling postprocessing transforms for each item independently, for example, randomised transforms could be applied differently for each predicted item in a batch.
+- simplifying the transform APIs and reducing the input validation burdens, because both the preprocessing and postprocessing transforms now only support the "channel-first" input format.
+- enabling the transform inverse operation for data items in different original shapes, as the inverted items are in a list, instead of being stacked in a single tensor.
+- allowing for both a "batch-first" tensor and a list of "channel-first" tensors for flexible metric computation.
+
+A typical process of `decollate batch` is illustrated as follows (with a `batch_size=N` model predictions and labels as an example):
+![decollate_batch](../images/decollate_batch.png)
+
+[decollate batch tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/decollate_batch.ipynb) shows a detailed usage example based on a PyTorch native workflow.
+
+[Migrating your v0.5 code to v0.6](https://github.com/Project-MONAI/MONAI/wiki/v0.5-to-v0.6-migration-guide) wiki shows how to migrate an existing program from v0.5 to v0.6 to adapt to the `decollate batch` logic.
+
+## UNETR: Transformers for Medical Image Segmentation
+[UNETR](https://arxiv.org/abs/2103.10504) is a transformer-based model for volumetric (3D) medical image segmentation and is currently the state-of-the-art on [BTCV dataset](https://www.synapse.org/#!Synapse:syn3193805/wiki/217752) test server for the task of multi-organ semantic segmentation. UNETR is introduced in MONAI v0.6 and its flexible implementation supports various segmentation tasks.
+![UNETR](../images/UNETR.png)
+
+A tutorial for the task of 3D multi-organ semantic segmentation using UNETR is provided within
+[`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unetr_btcv_segmentation_3d.ipynb).
+And it contains the following features:
+- Transforms for dictionary format data,
+- Defining a new transform according to MONAI transform API,
+- Loading Nifti image with metadata, loading a list of images and stacking them,
+- Randomly adjusting the intensity for data augmentation,
+- Optimized cache IO and transforms to accelerate training and validation,
+- 3D UNETR model, DiceCE loss function and Mean Dice metric for multi-organ segmentation task,
+
+The following illustrates target body organs that are segmented in this tutorial:
+![BTCV_organs](../images/BTCV_organs.png)
+
+Please visit UNETR repository for more details:
+https://monai.io/research/unetr-btcv-multi-organ-segmentation
+
+## Pythonic APIs to load the pretrained models from Clara Train MMARs
+[The MMAR (Medical Model ARchive)](https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html)
+defines a data structure for organizing all artifacts produced during the model development life cycle.
+NVIDIA Clara provides [various MMARs of medical domain-specific models](https://ngc.nvidia.com/catalog/models?orderBy=scoreDESC&pageNumber=0&query=clara_pt&quickFilter=&filters=).
+These MMARs include all the information about the model including configurations and scripts to provide a workspace to perform model development tasks. To better leverage the trained MMARs released on Nvidia GPU cloud, MONAI provides pythonic APIs to access them.
+
+To demonstrate this new feature, a medical image segmentation tutorial is created within
+[`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/master/modules/transfer_mmar.ipynb).
+It mainly produces the following figure to compare the loss curves and validation scores for
+- training from scratch (the green line),
+- applying pretrained MMAR weights without training (the magenta line),
+- training from the MMAR model weights (the blue line),
+
+according to the number of training epochs:
+
+![transfer_mmar](../images/transfer_mmar.png)
+
+The tutorial shows the capability of encapsulating the details of MMAR parsing, as well as the potential of using pretrained MMARs for transfer learning.
+These APIs are also being integrated into AI-assisted interactive workflows to accelerate the manual annotating processes (e.g. via [project-MONAI/MONAILabel](https://github.com/Project-MONAI/MONAILabel)).
+
+## Enhancements of the base metric interfaces
+The base API for metrics is now enhanced to support the essential computation logic for both iteration and epoch-based metrics.
+With this update, the MONAI metrics module becomes more extensible, and thus a good starting point for customised metrics.
+The APIs also by default support data parallel computation and consider the computation efficiency: with a `Cumulative` base class, intermediate metric outcomes can be automatically buffered, cumulated, synced across distributed processes, and aggregated for the final results. The [multi-processing computation example](https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py) shows how to compute metrics based on saved predictions and labels in multi-processing environment.
+
+## C++/CUDA extension modules via PyTorch JIT compilation
+To further accelerate the domain-specific routines in the workflows, MONAI C++/CUDA modules are introduced as extensions of the PyTorch native implementation.
+It now provides modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions):
+- via `setuptools` (since MONAI v0.5), for modules including `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`.
+- via just-in-time (JIT) compilation (since MONAI v0.6), for the `Gaussian mixtures` module. This approach allows for dynamic optimisation according to the user-specified parameters and local system environments.
+The following figure shows results of MONAI's Gaussian mixture models applied to a tissue and surgical tools segmentation task:
+![Gaussian mixture models as a postprocessing step](../images/gmm_feature_set_comparison_s.png)
+
+## Backward compatibility and enhanced continuous integration/continuous delivery
+Starting from this version, we experiment with basic policies of backward compatibility.
+New utilities are introduced on top of the existing semantic versioning modules, and the git branching model.
+
+At the same time, we actively analyze efficient, scalable, and secure CI/CD solutions to accommodate fast and collaborative codebase development.
+
+Although a complete mechanism is still under development, these provide another essential step towards API-stable versions of MONAI, sustainable release cycles, and efficient open-source collaborations.
+
+## Collaboration with [`Project-MONAI/MONAILabel`](https://github.com/Project-MONAI/MONAILabel) for smooth integration
+Since MONAI v0.6, we welcome [`MONAILabel`](https://github.com/Project-MONAI/MONAILabel) under [`Project-MONAI`](https://github.com/Project-MONAI).
+
+MONAI Label is an intelligent open source image labeling and learning tool that enables users to create annotated datasets and build AI annotation models for clinical evaluation.
+MONAI Label enables application developers to build labeling apps in a serverless way,
+where custom labeling apps are exposed as a service through the MONAI Label Server.
+
+Please visit MONAILabel documentation website for details:
+https://docs.monai.io/projects/label/en/latest/
diff --git a/monai/README.md b/monai/README.md
index 89c1fa3653..a224996f38 100644
--- a/monai/README.md
+++ b/monai/README.md
@@ -12,19 +12,21 @@
* **handlers**: defines handlers for implementing functionality at various stages in the training process.
-* **inferers**: defines model inference methods.
+* **inferers**: defines model inference methods.
-* **losses**: classes defining loss functions.
+* **losses**: classes defining loss functions, which follow the pattern of `torch.nn.modules.loss`.
* **metrics**: defines metric tracking types.
* **networks**: contains network definitions, component definitions, and Pytorch specific utilities.
-* **optimizers**: classes defining optimizers.
+* **optimizers**: classes defining optimizers, which follow the pattern of `torch.optim`.
* **transforms**: defines data transforms for preprocessing and postprocessing.
* **utils**: generic utilities intended to be implemented in pure Python or using Numpy,
and not with Pytorch, such as namespace aliasing, auto module loading.
-* **visualize**: utilities for data visualization.
\ No newline at end of file
+* **visualize**: utilities for data visualization.
+
+* **_extensions**: C++/CUDA extensions to be loaded in a just-in-time manner using `torch.utils.cpp_extension.load`.
diff --git a/monai/__init__.py b/monai/__init__.py
index 910698ee14..2c7c920162 100644
--- a/monai/__init__.py
+++ b/monai/__init__.py
@@ -18,8 +18,8 @@
PY_REQUIRED_MINOR = 6
version_dict = get_versions()
-__version__ = version_dict.get("version", "0+unknown")
-__revision_id__ = version_dict.get("full-revisionid", None)
+__version__: str = version_dict.get("version", "0+unknown")
+__revision_id__: str = version_dict.get("full-revisionid")
del get_versions, version_dict
__copyright__ = "(c) 2020 - 2021 MONAI Consortium"
@@ -44,3 +44,19 @@
# load all modules, this will trigger all export decorations
load_submodules(sys.modules[__name__], True, exclude_pattern=excludes)
+
+__all__ = [
+ "apps",
+ "config",
+ "data",
+ "engines",
+ "handlers",
+ "inferers",
+ "losses",
+ "metrics",
+ "networks",
+ "optimizers",
+ "transforms",
+ "utils",
+ "visualize",
+]
diff --git a/monai/_extensions/__init__.py b/monai/_extensions/__init__.py
new file mode 100644
index 0000000000..3718894b7c
--- /dev/null
+++ b/monai/_extensions/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .loader import load_module
diff --git a/monai/_extensions/gmm/gmm.cpp b/monai/_extensions/gmm/gmm.cpp
new file mode 100644
index 0000000000..ecb85e252a
--- /dev/null
+++ b/monai/_extensions/gmm/gmm.cpp
@@ -0,0 +1,89 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <torch/extension.h>
+
+#include "gmm.h"
+
+py::tuple init()
+{
+ torch::Tensor gmm_tensor = torch::zeros({GMM_COUNT, GMM_COMPONENT_COUNT}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+ torch::Tensor scratch_tensor = torch::empty({1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+ return py::make_tuple(gmm_tensor, scratch_tensor);
+}
+
+void learn(torch::Tensor gmm_tensor, torch::Tensor scratch_tensor, torch::Tensor input_tensor, torch::Tensor label_tensor)
+{
+ c10::DeviceType device_type = input_tensor.device().type();
+
+ unsigned int batch_count = input_tensor.size(0);
+ unsigned int element_count = input_tensor.stride(1);
+
+ unsigned int scratch_size = batch_count * (element_count + GMM_COMPONENT_COUNT * GMM_COUNT * (element_count / (32 * 32)));
+
+ if (scratch_tensor.size(0) < scratch_size)
+ {
+ scratch_tensor.resize_({scratch_size});
+ }
+
+ float* gmm = gmm_tensor.data_ptr<float>();
+ float* scratch = scratch_tensor.data_ptr<float>();
+ float* input = input_tensor.data_ptr<float>();
+ int* labels = label_tensor.data_ptr<int>();
+
+ if(device_type == torch::kCUDA)
+ {
+ learn_cuda(input, labels, gmm, scratch, batch_count, element_count);
+ }
+ else
+ {
+ learn_cpu(input, labels, gmm, scratch, batch_count, element_count);
+ }
+}
+
+torch::Tensor apply(torch::Tensor gmm_tensor, torch::Tensor input_tensor)
+{
+ c10::DeviceType device_type = input_tensor.device().type();
+
+ unsigned int dim = input_tensor.dim();
+ unsigned int batch_count = input_tensor.size(0);
+ unsigned int element_count = input_tensor.stride(1);
+
+ long int* output_size = new long int[dim];
+ memcpy(output_size, input_tensor.sizes().data(), dim * sizeof(long int));
+ output_size[1] = MIXTURE_COUNT;
+ torch::Tensor output_tensor = torch::empty(c10::IntArrayRef(output_size, dim), torch::dtype(torch::kFloat32).device(device_type));
+ delete[] output_size;
+
+ const float* gmm = gmm_tensor.data_ptr<float>();
+ const float* input = input_tensor.data_ptr<float>();
+ float* output = output_tensor.data_ptr<float>();
+
+ if(device_type == torch::kCUDA)
+ {
+ apply_cuda(gmm, input, output, batch_count, element_count);
+ }
+ else
+ {
+ apply_cpu(gmm, input, output, batch_count, element_count);
+ }
+
+ return output_tensor;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("init", torch::wrap_pybind_function(init));
+ m.def("learn", torch::wrap_pybind_function(learn));
+ m.def("apply", torch::wrap_pybind_function(apply));
+}
diff --git a/monai/_extensions/gmm/gmm.h b/monai/_extensions/gmm/gmm.h
new file mode 100644
index 0000000000..9a43351eb9
--- /dev/null
+++ b/monai/_extensions/gmm/gmm.h
@@ -0,0 +1,32 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#if !defined(CHANNEL_COUNT) || !defined(MIXTURE_COUNT) || !defined(MIXTURE_SIZE)
+#error Definition of CHANNEL_COUNT, MIXTURE_COUNT, and MIXTURE_SIZE required
+#endif
+
+#if CHANNEL_COUNT < 1 || MIXTURE_COUNT < 1 || MIXTURE_SIZE < 1
+#error CHANNEL_COUNT, MIXTURE_COUNT, and MIXTURE_SIZE must be positive
+#endif
+
+#define MATRIX_COMPONENT_COUNT ((CHANNEL_COUNT + 1) * (CHANNEL_COUNT + 2) / 2)
+#define SUB_MATRIX_COMPONENT_COUNT (CHANNEL_COUNT * (CHANNEL_COUNT + 1) / 2)
+#define GMM_COMPONENT_COUNT (MATRIX_COMPONENT_COUNT + 1)
+#define GMM_COUNT (MIXTURE_COUNT * MIXTURE_SIZE)
+
+
+void learn_cpu(const float* input, const int* labels, float* gmm, float* scratch_memory, unsigned int batch_count, unsigned int element_count);
+void apply_cpu(const float* gmm, const float* input, float* output, unsigned int batch_count, unsigned int element_count);
+
+void learn_cuda(const float* input, const int* labels, float* gmm, float* scratch_memory, unsigned int batch_count, unsigned int element_count);
+void apply_cuda(const float* gmm, const float* input, float* output, unsigned int batch_count, unsigned int element_count);
diff --git a/monai/_extensions/gmm/gmm_cpu.cpp b/monai/_extensions/gmm/gmm_cpu.cpp
new file mode 100644
index 0000000000..144e66806c
--- /dev/null
+++ b/monai/_extensions/gmm/gmm_cpu.cpp
@@ -0,0 +1,26 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <stdexcept>
+
+#include "gmm.h"
+
+void learn_cpu(const float* input, const int* labels, float* gmm, float* scratch_memory, unsigned int batch_count, unsigned int element_count)
+{
+ throw std::invalid_argument("GMM received a cpu tensor but is not yet implemented for the cpu");
+}
+
+void apply_cpu(const float* gmm, const float* input, float* output, unsigned int batch_count, unsigned int element_count)
+{
+ throw std::invalid_argument("GMM received a cpu tensor but is not yet implemented for the cpu");
+}
diff --git a/monai/_extensions/gmm/gmm_cuda.cu b/monai/_extensions/gmm/gmm_cuda.cu
new file mode 100644
index 0000000000..36af48b06c
--- /dev/null
+++ b/monai/_extensions/gmm/gmm_cuda.cu
@@ -0,0 +1,521 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include "gmm.h"
+
+#include "gmm_cuda_linalg.cuh"
+
+#define EPSILON 1e-5
+#define BLOCK_SIZE 32
+#define TILE(SIZE, STRIDE) ((((SIZE) - 1)/(STRIDE)) + 1)
+
+template <int warp_count, int load_count>
+__global__ void CovarianceReductionKernel(int gaussian_index, const float* g_image, const int* g_alpha, float* g_matrices, int element_count)
+{
+ constexpr int block_size = warp_count * 32;
+
+ __shared__ float s_matrix_component[warp_count];
+
+ int batch_index = blockIdx.z;
+
+ const float* g_batch_image = g_image + batch_index * element_count * CHANNEL_COUNT;
+ const int* g_batch_alpha = g_alpha + batch_index * element_count;
+ float* g_batch_matrices = g_matrices + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT * gridDim.x;
+
+ int local_index = threadIdx.x;
+ int block_index = blockIdx.x;
+ int warp_index = local_index >> 5;
+ int lane_index = local_index & 31;
+ int global_index = local_index + block_index * block_size * load_count;
+ int matrix_offset = (gaussian_index * gridDim.x + block_index) * GMM_COMPONENT_COUNT;
+
+ float matrix[MATRIX_COMPONENT_COUNT];
+
+ for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++)
+ {
+ matrix[i] = 0;
+ }
+
+ for (int load = 0; load < load_count; load++)
+ {
+ global_index += load * block_size;
+
+ if (global_index < element_count)
+ {
+ int my_alpha = g_batch_alpha[global_index];
+
+ if (my_alpha != -1)
+ {
+ if (gaussian_index == (my_alpha & 15) + (my_alpha >> 4) * MIXTURE_COUNT)
+ {
+ float feature[CHANNEL_COUNT + 1];
+
+ feature[0] = 1;
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ feature[i + 1] = g_batch_image[global_index + i * element_count];
+ }
+
+ for (int index = 0, i = 0; i < CHANNEL_COUNT + 1; i++)
+ {
+ for (int j = i; j < CHANNEL_COUNT + 1; j++, index++)
+ {
+ matrix[index] += feature[i] * feature[j];
+ }
+ }
+ }
+ }
+ }
+ }
+
+ __syncthreads();
+
+ for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++)
+ {
+ float matrix_component = matrix[i];
+
+ matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
+ matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
+ matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
+ matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
+ matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
+
+ if (lane_index == 0)
+ {
+ s_matrix_component[warp_index] = matrix_component;
+ }
+
+ __syncthreads();
+
+ if (warp_index == 0)
+ {
+ matrix_component = s_matrix_component[lane_index];
+
+ if (warp_count >= 32) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); }
+ if (warp_count >= 16) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8); }
+ if (warp_count >= 8) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4); }
+ if (warp_count >= 4) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2); }
+ if (warp_count >= 2) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1); }
+
+ if (lane_index == 0)
+ {
+ g_batch_matrices[matrix_offset + i] = matrix_component;
+ }
+ }
+
+ __syncthreads();
+ }
+}
+
+template <int warp_count, bool invert_matrix>
+__global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_gmm, int matrix_count)
+{
+ constexpr int block_size = warp_count * 32;
+
+ __shared__ float s_matrix_component[warp_count];
+ __shared__ float s_gmm[GMM_COMPONENT_COUNT];
+
+ int batch_index = blockIdx.z;
+
+ const float* g_batch_matrices = g_matrices + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT * matrix_count;
+ float* g_batch_gmm = g_gmm + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT;
+
+ int local_index = threadIdx.x;
+ int warp_index = local_index >> 5;
+ int lane_index = local_index & 31;
+ int gmm_index = blockIdx.x;
+ int matrix_offset = gmm_index * matrix_count;
+
+ int load_count = TILE(matrix_count, block_size);
+
+ float norm_factor = 1.0f;
+
+ for (int index = 0, i = 0; i < CHANNEL_COUNT + 1; i++)
+ {
+ for (int j = i; j < CHANNEL_COUNT + 1; j++, index++)
+ {
+ float matrix_component = 0.0f;
+
+ for(int load = 0; load < load_count; load++)
+ {
+ int matrix_index = local_index + load * block_size;
+
+ if(matrix_index < matrix_count)
+ {
+ matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index];
+ }
+ }
+
+ matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
+ matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
+ matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
+ matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
+ matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
+
+ if (lane_index == 0)
+ {
+ s_matrix_component[warp_index] = matrix_component;
+ }
+
+ __syncthreads();
+
+ if (warp_index == 0)
+ {
+ matrix_component = s_matrix_component[lane_index];
+
+ if (warp_count >= 32) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); }
+ if (warp_count >= 16) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8); }
+ if (warp_count >= 8) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4); }
+ if (warp_count >= 4) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2); }
+ if (warp_count >= 2) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1); }
+
+ if (lane_index == 0)
+ {
+ float constant = i == 0 ? 0.0f : s_gmm[i] * s_gmm[j];
+
+ if (i != 0 && i == j)
+ {
+ constant -= EPSILON;
+ }
+
+ s_gmm[index] = norm_factor * matrix_component - constant;
+
+ if (index == 0 && matrix_component > 0)
+ {
+ norm_factor = 1.0f / matrix_component;
+ }
+ }
+ }
+
+ __syncthreads();
+ }
+ }
+
+ float* matrix = s_gmm + (CHANNEL_COUNT + 1);
+ float* det_ptr = s_gmm + MATRIX_COMPONENT_COUNT;
+
+ if (local_index == 0)
+ {
+ float square_mat[CHANNEL_COUNT][CHANNEL_COUNT];
+ float cholesky_mat[CHANNEL_COUNT][CHANNEL_COUNT];
+
+ for(int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ for(int j = 0; j < CHANNEL_COUNT; j++)
+ {
+ square_mat[i][j] = 0.0f;
+ cholesky_mat[i][j] = 0.0f;
+ }
+ }
+
+ to_square(matrix, square_mat);
+ cholesky(square_mat, cholesky_mat);
+
+ *det_ptr = chol_det(cholesky_mat);
+
+ if (invert_matrix)
+ {
+ chol_inv(cholesky_mat, square_mat);
+ to_triangle(square_mat, matrix);
+ }
+ }
+
+ if (local_index < GMM_COMPONENT_COUNT)
+ {
+ g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + local_index] = s_gmm[local_index];
+ }
+}
+
+struct GMMSplit_t
+{
+ int idx;
+ float threshold;
+ float eigenvector[CHANNEL_COUNT];
+};
+
+// 1 Block, 32xMIXTURE_COUNT
+__global__ void GMMFindSplit(GMMSplit_t *gmmSplit, int gmmK, float *gmm)
+{
+ int batch_index = blockIdx.z;
+
+ float* g_batch_gmm = gmm + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT;
+ GMMSplit_t* g_batch_gmmSplit = gmmSplit + batch_index * MIXTURE_COUNT;
+
+ int gmm_idx = threadIdx.x * MIXTURE_COUNT + threadIdx.y;
+
+ float eigenvalue = 0;
+ float eigenvector[CHANNEL_COUNT];
+
+ if (threadIdx.x < gmmK)
+ {
+ float* matrix = g_batch_gmm + gmm_idx * GMM_COMPONENT_COUNT + (CHANNEL_COUNT + 1);
+ largest_eigenpair(matrix, eigenvector, &eigenvalue);
+ }
+
+ float max_value = eigenvalue;
+
+ max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
+ max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
+ max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
+ max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
+ max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
+
+ if (max_value == eigenvalue)
+ {
+ GMMSplit_t split;
+
+ float* average_feature = gmm + gmm_idx * GMM_COMPONENT_COUNT + 1;
+
+ split.idx = threadIdx.x;
+ split.threshold = scalar_prod(average_feature, eigenvector);
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ split.eigenvector[i] = eigenvector[i];
+ }
+
+ g_batch_gmmSplit[threadIdx.y] = split;
+ }
+}
+
+#define DO_SPLIT_DEGENERACY 4
+
+__global__ void GMMDoSplit(const GMMSplit_t *gmmSplit, int k, const float *image, int *alpha, int element_count)
+{
+ __shared__ GMMSplit_t s_gmmSplit[MIXTURE_COUNT];
+
+ int batch_index = blockIdx.z;
+
+ const GMMSplit_t* g_batch_gmmSplit = gmmSplit + batch_index * MIXTURE_COUNT;
+ const float* g_batch_image = image + batch_index * element_count * CHANNEL_COUNT;
+ int* g_batch_alpha = alpha + batch_index * element_count;
+
+ int *s_linear = (int *) s_gmmSplit;
+ int *g_linear = (int *) g_batch_gmmSplit;
+
+ if (threadIdx.x < MIXTURE_COUNT * sizeof(GMMSplit_t))
+ {
+ s_linear[threadIdx.x] = g_linear[threadIdx.x];
+ }
+
+ __syncthreads();
+
+ int index = threadIdx.x + blockIdx.x * BLOCK_SIZE * DO_SPLIT_DEGENERACY;
+
+ for (int i = 0; i < DO_SPLIT_DEGENERACY; i++)
+ {
+ index += BLOCK_SIZE;
+
+ if (index < element_count)
+ {
+ int my_alpha = g_batch_alpha[index];
+
+ if(my_alpha != -1)
+ {
+ int select = my_alpha & 15;
+ int gmm_idx = my_alpha >> 4;
+
+ if (gmm_idx == s_gmmSplit[select].idx)
+ {
+ // in the split cluster now
+ float feature[CHANNEL_COUNT];
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ feature[i] = g_batch_image[index + i * element_count];
+ }
+
+ float value = scalar_prod(s_gmmSplit[select].eigenvector, feature);
+
+ if (value > s_gmmSplit[select].threshold)
+ {
+ // assign pixel to new cluster
+ g_batch_alpha[index] = k + select;
+ }
+ }
+ }
+ }
+ }
+}
+
+// Single block, 32xMIXTURE_COUNT
+__global__ void GMMcommonTerm(float *g_gmm)
+{
+ int batch_index = blockIdx.z;
+
+ float* g_batch_gmm = g_gmm + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT;
+
+ int gmm_index = (threadIdx.x * MIXTURE_COUNT) + threadIdx.y;
+
+ float gmm_n = threadIdx.x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0.0f;
+
+ float sum = gmm_n;
+
+ sum += __shfl_xor_sync(0xffffffff, sum, 1);
+ sum += __shfl_xor_sync(0xffffffff, sum, 2);
+ sum += __shfl_xor_sync(0xffffffff, sum, 4);
+ sum += __shfl_xor_sync(0xffffffff, sum, 8);
+ sum += __shfl_xor_sync(0xffffffff, sum, 16);
+
+ if (threadIdx.x < MIXTURE_SIZE)
+ {
+ float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON;
+ float commonTerm = det > 0.0f ? gmm_n / (sqrtf(det) * sum) : gmm_n / sum;
+
+ g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] = commonTerm;
+ }
+}
+
+__device__ float GMMTerm(float* feature, const float *gmm)
+{
+ const float* average_feature = gmm + 1;
+ const float* matrix = gmm + CHANNEL_COUNT + 1;
+
+ float diff[CHANNEL_COUNT];
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ diff[i] = feature[i] - average_feature[i];
+ }
+
+ float value = 0.0f;
+
+ for (int index = 0, i = 0; i < CHANNEL_COUNT; i++)
+ {
+ for (int j = i; j < CHANNEL_COUNT; j++, index++)
+ {
+ float term = diff[i] * diff[j] * matrix[index];
+
+ value += i == j ? term : 2 * term;
+ }
+ }
+
+ return gmm[MATRIX_COMPONENT_COUNT] * expf(-0.5 * value);
+}
+
+__global__ void GMMDataTermKernel(const float *image, const float *gmm, float* output, int element_count)
+{
+ int batch_index = blockIdx.z;
+
+ const float* g_batch_image = image + batch_index * element_count * CHANNEL_COUNT;
+ const float* g_batch_gmm = gmm + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT;
+ float* g_batch_output = output + batch_index * element_count * MIXTURE_COUNT;
+
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (index >= element_count) return;
+
+ float feature[CHANNEL_COUNT];
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ feature[i] = g_batch_image[index + i * element_count];
+ }
+
+ float weights[MIXTURE_COUNT];
+ float weight_total = 0.0f;
+
+ for(int i = 0; i < MIXTURE_COUNT; i++)
+ {
+ float mixture_weight = 0.0f;
+
+ for(int j = 0; j < MIXTURE_SIZE; j++)
+ {
+ mixture_weight += GMMTerm(feature, &g_batch_gmm[(MIXTURE_COUNT * j + i) * GMM_COMPONENT_COUNT]);
+ }
+
+ weights[i] = mixture_weight;
+ weight_total += mixture_weight;
+ }
+
+ for(int i = 0; i < MIXTURE_COUNT; i++)
+ {
+ // protecting against pixels with 0 in all mixtures
+ float final_weight = weight_total > 0.0f ? weights[i] / weight_total : 0.0f;
+ g_batch_output[index + i * element_count] = final_weight;
+ }
+}
+
+#define THREADS 512
+#define WARPS 16
+#define BLOCK (WARPS << 5)
+#define LOAD 4
+
+void GMMInitialize(const float *image, int *alpha, float *gmm, float *scratch_mem, unsigned int batch_count, unsigned int element_count)
+{
+ unsigned int block_count = TILE(element_count, BLOCK * LOAD);
+
+ float* block_gmm_scratch = scratch_mem;
+ GMMSplit_t* gmm_split_scratch = (GMMSplit_t*) scratch_mem;
+
+ int gmm_N = MIXTURE_COUNT * MIXTURE_SIZE;
+
+ for (unsigned int k = MIXTURE_COUNT; k < gmm_N; k+=MIXTURE_COUNT)
+ {
+ for (unsigned int i = 0; i < k; ++i)
+ {
+ CovarianceReductionKernel<WARPS, LOAD><<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
+ }
+
+ CovarianceFinalizationKernel<WARPS, false><<<{k, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
+
+ GMMFindSplit<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm);
+ GMMDoSplit<<<{TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count}, BLOCK_SIZE>>>(gmm_split_scratch, (k / MIXTURE_COUNT) << 4, image, alpha, element_count);
+ }
+}
+
+void GMMUpdate(const float *image, int *alpha, float *gmm, float *scratch_mem, unsigned int batch_count, unsigned int element_count)
+{
+ unsigned int block_count = TILE(element_count, BLOCK * LOAD);
+
+ float* block_gmm_scratch = scratch_mem;
+
+ unsigned int gmm_N = MIXTURE_COUNT * MIXTURE_SIZE;
+
+ for (unsigned int i = 0; i < gmm_N; ++i)
+ {
+ CovarianceReductionKernel<WARPS, LOAD><<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
+ }
+
+ CovarianceFinalizationKernel<WARPS, true><<<{gmm_N, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
+
+ GMMcommonTerm<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
+}
+
+void GMMDataTerm(const float *image, const float *gmm, float* output, unsigned int batch_count, unsigned int element_count)
+{
+ dim3 block(BLOCK_SIZE, 1);
+ dim3 grid(TILE(element_count, BLOCK_SIZE), 1, batch_count);
+
+ GMMDataTermKernel<<<grid, block>>>(image, gmm, output, element_count);
+}
+
+void learn_cuda(const float* input, const int* labels, float* gmm, float* scratch_memory, unsigned int batch_count, unsigned int element_count)
+{
+ int* alpha = (int*)scratch_memory;
+ float* scratch_mem = scratch_memory + batch_count * element_count;
+
+ cudaMemcpyAsync(alpha, labels, batch_count * element_count * sizeof(int), cudaMemcpyDeviceToDevice);
+
+ GMMInitialize(input, alpha, gmm, scratch_mem, batch_count, element_count);
+ GMMUpdate(input, alpha, gmm, scratch_mem, batch_count, element_count);
+}
+
+void apply_cuda(const float* gmm, const float* input, float* output, unsigned int batch_count, unsigned int element_count)
+{
+ GMMDataTerm(input, gmm, output, batch_count, element_count);
+}
diff --git a/monai/_extensions/gmm/gmm_cuda_linalg.cuh b/monai/_extensions/gmm/gmm_cuda_linalg.cuh
new file mode 100644
index 0000000000..49e68c8442
--- /dev/null
+++ b/monai/_extensions/gmm/gmm_cuda_linalg.cuh
@@ -0,0 +1,180 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+__device__ void to_square(float in[SUB_MATRIX_COMPONENT_COUNT], float out[CHANNEL_COUNT][CHANNEL_COUNT])
+{
+ for (int index = 0, i = 0; i < CHANNEL_COUNT; i++)
+ {
+ for (int j = i; j < CHANNEL_COUNT; j++, index++)
+ {
+ out[i][j] = in[index];
+ out[j][i] = in[index];
+ }
+ }
+}
+
+__device__ void to_triangle(float in[CHANNEL_COUNT][CHANNEL_COUNT], float out[SUB_MATRIX_COMPONENT_COUNT])
+{
+ for (int index = 0, i = 0; i < CHANNEL_COUNT; i++)
+ {
+ for (int j = i; j < CHANNEL_COUNT; j++, index++)
+ {
+ out[index] = in[j][i];
+ }
+ }
+}
+
+__device__ void cholesky(float in[CHANNEL_COUNT][CHANNEL_COUNT], float out[CHANNEL_COUNT][CHANNEL_COUNT])
+{
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ for (int j = 0; j < i+1; j++)
+ {
+ float sum = 0.0f;
+
+ for (int k = 0; k < j; k++)
+ {
+ sum += out[i][k] * out[j][k];
+ }
+
+ if (i == j)
+ {
+ out[i][j] = sqrtf(in[i][i] - sum);
+ }
+ else
+ {
+ out[i][j] = (in[i][j] - sum) / out[j][j];
+ }
+ }
+ }
+}
+
+__device__ float chol_det(float in[CHANNEL_COUNT][CHANNEL_COUNT])
+{
+ float det = 1.0f;
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ det *= in[i][i];
+ }
+
+ return det * det;
+}
+
+__device__ void chol_inv(float in[CHANNEL_COUNT][CHANNEL_COUNT], float out[CHANNEL_COUNT][CHANNEL_COUNT])
+{
+ // Invert cholesky matrix
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ in[i][i] = 1.0f / (in[i][i] + 0.0001f);
+
+ for (int j = 0; j < i; j++)
+ {
+ float sum = 0.0f;
+
+ for (int k = j; k < i; k++)
+ {
+ sum += in[i][k] * in[k][j];
+ }
+
+ in[i][j] = -in[i][i] * sum;
+ }
+ }
+
+ // Dot with transpose of self
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ for (int j = 0; j < CHANNEL_COUNT; j++)
+ {
+ out[i][j] = 0.0f;
+
+ for (int k = max(i, j); k < CHANNEL_COUNT; k++)
+ {
+ out[i][j] += in[k][i] * in[k][j];
+ }
+ }
+ }
+}
+
+__device__ void normalize(float* v)
+{
+ float norm = 0.0f;
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ norm += v[i] * v[i];
+ }
+
+ norm = 1.0f / sqrtf(norm);
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ v[i] *= norm;
+ }
+}
+
+__device__ float scalar_prod(float* a, float* b)
+{
+ float product = 0.0f;
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ product += a[i] * b[i];
+ }
+
+ return product;
+}
+
+__device__ void largest_eigenpair(const float *M, float* evec, float* eval)
+{
+ float scratch[CHANNEL_COUNT];
+
+ for(int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ scratch[i] = i + 1;
+ }
+
+ for (int itr = 0; itr < 10; itr++)
+ {
+ *eval = 0.0f;
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ int index = i;
+
+ evec[i] = 0.0f;
+
+ for (int j = 0; j < CHANNEL_COUNT; j++)
+ {
+ evec[i] += M[index] * scratch[j];
+
+ if (j < i)
+ {
+ index += CHANNEL_COUNT - (j + 1);
+ }
+ else
+ {
+ index += 1;
+ }
+ }
+
+ *eval = max(*eval, evec[i]);
+ }
+
+ for (int i = 0; i < CHANNEL_COUNT; i++)
+ {
+ evec[i] /= *eval;
+ scratch[i] = evec[i];
+ }
+ }
+}
diff --git a/monai/_extensions/loader.py b/monai/_extensions/loader.py
new file mode 100644
index 0000000000..5f77480ecc
--- /dev/null
+++ b/monai/_extensions/loader.py
@@ -0,0 +1,94 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+from _thread import interrupt_main
+from contextlib import contextmanager
+from glob import glob
+from os import path
+from threading import Timer
+from typing import Optional
+
+import torch
+
+from monai.utils.module import get_torch_version_tuple, optional_import
+
+dir_path = path.dirname(path.realpath(__file__))
+
+
+@contextmanager
+def timeout(time, message):
+ timer = None
+ try:
+ timer = Timer(time, interrupt_main)
+ timer.daemon = True
+ yield timer.start()
+ except KeyboardInterrupt as e:
+ if timer is not None and timer.is_alive():
+ raise e # interrupt from user?
+ raise TimeoutError(message)
+ finally:
+ if timer is not None:
+ try:
+ timer.cancel()
+ finally:
+ pass
+
+
+def load_module(
+ module_name: str, defines: Optional[dict] = None, verbose_build: bool = False, build_timeout: int = 300
+):
+ """
+ Handles the loading of c++ extension modules.
+
+ Args:
+ module_name: Name of the module to load.
+ Must match the name of the relevant source directory in the `_extensions` directory.
+ defines: Dictionary containing names and values of compilation defines.
+ verbose_build: Set to true to enable build logging.
+ build_timeout: Time in seconds before the build will throw an exception to prevent hanging.
+ """
+
+ # Ensuring named module exists in _extensions directory.
+ module_dir = path.join(dir_path, module_name)
+ if not path.exists(module_dir):
+ raise ValueError(f"No extension module named {module_name}")
+
+ platform_str = f"_{platform.system()}_{platform.python_version()}_"
+ platform_str += "".join(f"{v}" for v in get_torch_version_tuple()[:2])
+ # Adding configuration to module name.
+ if defines is not None:
+ module_name = "_".join([module_name] + [f"{v}" for v in defines.values()])
+
+ # Gathering source files.
+ source = glob(path.join(module_dir, "**", "*.cpp"), recursive=True)
+ if torch.cuda.is_available():
+ source += glob(path.join(module_dir, "**", "*.cu"), recursive=True)
+ platform_str += f"_{torch.version.cuda}"
+
+ # Constructing compilation argument list.
+ define_args = [] if not defines else [f"-D {key}={defines[key]}" for key in defines]
+
+ # Ninja may be blocked by something out of our control.
+ # This will error if the build takes longer than expected.
+ with timeout(build_timeout, "Build appears to be blocked. Is there a stopped process building the same extension?"):
+ load, _ = optional_import("torch.utils.cpp_extension", name="load") # main trigger some JIT config in pytorch
+ # This will either run the build or return the existing .so object.
+ name = module_name + platform_str.replace(".", "_")
+ module = load(
+ name=name,
+ sources=source,
+ extra_cflags=define_args,
+ extra_cuda_cflags=define_args,
+ verbose=verbose_build,
+ )
+
+ return module
diff --git a/monai/_version.py b/monai/_version.py
index 1b31d5fd1a..79f569dd79 100644
--- a/monai/_version.py
+++ b/monai/_version.py
@@ -1,3 +1,4 @@
+
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
@@ -5,7 +6,7 @@
# that just contains the computed version number.
# This file is released into the public domain. Generated by
-# versioneer-0.18 (https://github.com/warner/python-versioneer)
+# versioneer-0.19 (https://github.com/python-versioneer/python-versioneer)
"""Git implementation of _version.py."""
@@ -56,7 +57,7 @@ class NotThisMethod(Exception):
def register_vcs_handler(vcs, method): # decorator
- """Decorator to mark a method as the handler for a particular VCS."""
+ """Create decorator to mark a method as the handler of a VCS."""
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
@@ -92,9 +93,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
if verbose:
print("unable to find command, tried %s" % (commands,))
return None, None
- stdout = p.communicate()[0].strip()
- if sys.version_info[0] >= 3:
- stdout = stdout.decode()
+ stdout = p.communicate()[0].strip().decode()
if p.returncode != 0:
if verbose:
print("unable to run %s (error)" % dispcmd)
@@ -164,6 +163,10 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
raise NotThisMethod("no keywords at all, weird")
date = keywords.get("date")
if date is not None:
+ # Use only the last line. Previous lines may contain GPG signature
+ # information.
+ date = date.splitlines()[-1]
+
# git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
# datestamp. However we prefer "%ci" (which expands to an "ISO-8601
# -like" string, which we must then edit to make compliant), because
@@ -299,6 +302,9 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# commit date: see ISO-8601 comment in git_versions_from_keywords()
date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
cwd=root)[0].strip()
+ # Use only the last line. Previous lines may contain GPG signature
+ # information.
+ date = date.splitlines()[-1]
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
return pieces
@@ -337,18 +343,18 @@ def render_pep440(pieces):
def render_pep440_pre(pieces):
- """TAG[.post.devDISTANCE] -- No -dirty.
+ """TAG[.post0.devDISTANCE] -- No -dirty.
Exceptions:
- 1: no tags. 0.post.devDISTANCE
+ 1: no tags. 0.post0.devDISTANCE
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"]:
- rendered += ".post.dev%d" % pieces["distance"]
+ rendered += ".post0.dev%d" % pieces["distance"]
else:
# exception #1
- rendered = "0.post.dev%d" % pieces["distance"]
+ rendered = "0.post0.dev%d" % pieces["distance"]
return rendered
@@ -494,7 +500,7 @@ def get_versions():
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
- for i in cfg.versionfile_source.split('/'): # lgtm[py/unused-loop-variable]
+ for i in cfg.versionfile_source.split('/'):
root = os.path.dirname(root)
except NameError:
return {"version": "0+unknown", "full-revisionid": None,
diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py
index 59f38cbb6f..ef4352cabd 100644
--- a/monai/apps/__init__.py
+++ b/monai/apps/__init__.py
@@ -10,4 +10,5 @@
# limitations under the License.
from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset
+from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar
from .utils import check_hash, download_and_extract, download_url, extractall
diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py
index f0416b8c4f..c766914026 100644
--- a/monai/apps/datasets.py
+++ b/monai/apps/datasets.py
@@ -11,7 +11,7 @@
import os
import sys
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from typing import Callable, Dict, List, Optional, Sequence, Union
import numpy as np
@@ -57,7 +57,7 @@ class MedNISTDataset(Randomizable, CacheDataset):
"""
- resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
+ resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"
compressed_file_name = "MedNIST.tar.gz"
dataset_folder_name = "MedNIST"
@@ -98,8 +98,8 @@ def __init__(
self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
)
- def randomize(self, data: Optional[Any] = None) -> None:
- self.rann = self.R.random()
+ def randomize(self, data: List[int]) -> None:
+ self.R.shuffle(data)
def get_num_classes(self) -> int:
"""Get number of classes."""
@@ -128,27 +128,32 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
image_files_list.extend(image_files[i])
image_class.extend([i] * num_each[i])
class_name.extend([class_names[i]] * num_each[i])
- num_total = len(image_class)
-
- data = []
-
- for i in range(num_total):
- self.randomize()
- if self.section == "training":
- if self.rann < self.val_frac + self.test_frac:
- continue
- elif self.section == "validation":
- if self.rann >= self.val_frac:
- continue
- elif self.section == "test":
- if self.rann < self.val_frac or self.rann >= self.val_frac + self.test_frac:
- continue
- else:
- raise ValueError(
- f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].'
- )
- data.append({"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]})
- return data
+
+ length = len(image_files_list)
+ indices = np.arange(length)
+ self.randomize(indices)
+
+ test_length = int(length * self.test_frac)
+ val_length = int(length * self.val_frac)
+ if self.section == "test":
+ section_indices = indices[:test_length]
+ elif self.section == "validation":
+ section_indices = indices[test_length : test_length + val_length]
+ elif self.section == "training":
+ section_indices = indices[test_length + val_length :]
+ else:
+ raise ValueError(
+ f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].'
+ )
+
+ return [
+ {
+ "image": image_files_list[i],
+ "label": image_class[i],
+ "class_name": class_name[i],
+ }
+ for i in section_indices
+ ]
class DecathlonDataset(Randomizable, CacheDataset):
diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py
index 45cfbde6ea..acaeba0bc3 100644
--- a/monai/apps/deepgrow/dataset.py
+++ b/monai/apps/deepgrow/dataset.py
@@ -22,7 +22,7 @@
def create_dataset(
datalist,
output_dir: str,
- dimension,
+ dimension: int,
pixdim,
image_key: str = "image",
label_key: str = "label",
@@ -138,7 +138,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
if len(vol_image.shape) == 4:
logging.info(
"4D-Image, pick only first series; Image: {}; Label: {}".format(
- vol_image.shape, vol_label.shape if vol_label else None
+ vol_image.shape, vol_label.shape if vol_label is not None else None
)
)
vol_image = vol_image[0]
@@ -216,7 +216,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
if len(vol_image.shape) == 4:
logging.info(
"4D-Image, pick only first series; Image: {}; Label: {}".format(
- vol_image.shape, vol_label.shape if vol_label else None
+ vol_image.shape, vol_label.shape if vol_label is not None else None
)
)
vol_image = vol_image[0]
diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py
index 77e271a9eb..81e82c958d 100644
--- a/monai/apps/deepgrow/interaction.py
+++ b/monai/apps/deepgrow/interaction.py
@@ -12,15 +12,16 @@
import torch
+from monai.data import decollate_batch, list_data_collate
from monai.engines import SupervisedEvaluator, SupervisedTrainer
-from monai.engines.utils import CommonKeys
-from monai.engines.workflow import Events
+from monai.engines.utils import IterationEvents
from monai.transforms import Compose
+from monai.utils.enums import CommonKeys
class Interaction:
"""
- Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.
+ Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.
This implementation is based on:
Sakinis et al., Interactive segmentation of medical images through
@@ -50,10 +51,6 @@ def __init__(
self.train = train
self.key_probability = key_probability
- def attach(self, engine: Union[SupervisedTrainer, SupervisedEvaluator]) -> None:
- if not engine.has_event_handler(self, Events.ITERATION_STARTED):
- engine.add_event_handler(Events.ITERATION_STARTED, self)
-
def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]):
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
@@ -62,6 +59,8 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd
inputs, _ = engine.prepare_batch(batchdata)
inputs = inputs.to(engine.state.device)
+ engine.fire_event(IterationEvents.INNER_ITERATION_STARTED)
+
engine.network.eval()
with torch.no_grad():
if engine.amp:
@@ -70,10 +69,19 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd
else:
predictions = engine.inferer(inputs, engine.network)
+ engine.fire_event(IterationEvents.INNER_ITERATION_COMPLETED)
+
batchdata.update({CommonKeys.PRED: predictions})
- batchdata[self.key_probability] = torch.as_tensor(
- ([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs)
- )
- batchdata = self.transforms(batchdata)
+
+ # decollate batch data to execute click transforms
+ batchdata_list = decollate_batch(batchdata, detach=True)
+ for i in range(len(batchdata_list)):
+ batchdata_list[i][self.key_probability] = (
+ (1.0 - ((1.0 / self.max_interactions) * j)) if self.train else 1.0
+ )
+ batchdata_list[i] = self.transforms(batchdata_list[i])
+
+ # collate list into a batch for next round interaction
+ batchdata = list_data_collate(batchdata_list)
return engine._iteration(engine, batchdata)
diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py
index f178360031..db450792b0 100644
--- a/monai/apps/deepgrow/transforms.py
+++ b/monai/apps/deepgrow/transforms.py
@@ -8,18 +8,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from typing import Callable, Optional, Sequence, Union
+import json
+from typing import Callable, Dict, Optional, Sequence, Union
import numpy as np
import torch
from monai.config import IndexSelection, KeysCollection
from monai.networks.layers import GaussianFilter
-from monai.transforms import SpatialCrop
-from monai.transforms.compose import MapTransform, Randomizable, Transform
+from monai.transforms import Resize, SpatialCrop
+from monai.transforms.transform import MapTransform, Randomizable, Transform
from monai.transforms.utils import generate_spatial_bounding_box
-from monai.utils import min_version, optional_import
+from monai.utils import InterpolateMode, ensure_tuple, ensure_tuple_rep, min_version, optional_import
measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
@@ -48,7 +48,7 @@ def _apply(self, label):
return np.asarray(sids)
def __call__(self, data):
- d = dict(data)
+ d: Dict = dict(data)
label = d[self.label]
if label.shape[0] != 1:
raise ValueError("Only supports single channel labels!")
@@ -67,7 +67,8 @@ class AddInitialSeedPointd(Randomizable, Transform):
Add random guidance as initial seed point for a given label.
Note that the label is of size (C, D, H, W) or (C, H, W)
- The guidance is of size (2, N, # of dims) where N is number of guidance added
+
+ The guidance is of size (2, N, # of dims) where N is number of guidance added.
# of dims = 4 when C, D, H, W; # of dims = 3 when (C, H, W)
Args:
@@ -87,13 +88,21 @@ def __init__(
connected_regions: int = 5,
):
self.label = label
- self.sids = sids
- self.sid = sid
+ self.sids_key = sids
+ self.sid_key = sid
+ self.sid = None
self.guidance = guidance
self.connected_regions = connected_regions
- def randomize(self, data=None):
- pass
+ def randomize(self, data):
+ sid = data.get(self.sid_key, None)
+ sids = data.get(self.sids_key, None)
+ if sids is not None:
+ if sid is None or sid not in sids:
+ sid = self.R.choice(sids, replace=False)
+ else:
+ sid = None
+ self.sid = sid
def _apply(self, label, sid):
dimensions = 3 if len(label.shape) > 3 else 2
@@ -106,7 +115,8 @@ def _apply(self, label, sid):
label = (label > 0.5).astype(np.float32)
blobs_labels = measure.label(label.astype(int), background=0) if dims == 2 else label
- assert np.max(blobs_labels) > 0, "Not a valid Label"
+ if np.max(blobs_labels) <= 0:
+ raise AssertionError("Not a valid Label")
pos_guidance = []
for ridx in range(1, 2 if dims == 3 else self.connected_regions + 1):
@@ -134,14 +144,8 @@ def _apply(self, label, sid):
def __call__(self, data):
d = dict(data)
- sid = d.get(self.sid, None)
- sids = d.get(self.sids, None)
- if sids is not None:
- if sid is None or sid not in sids:
- sid = self.R.choice(sids, replace=False)
- else:
- sid = None
- d[self.guidance] = self._apply(d[self.label], sid)
+ self.randomize(data)
+ d[self.guidance] = json.dumps(self._apply(d[self.label], self.sid).astype(int).tolist())
return d
@@ -156,7 +160,7 @@ class AddGuidanceSignald(Transform):
guidance: key to store guidance.
sigma: standard deviation for Gaussian kernel.
number_intensity_ch: channel index.
- batched: whether input is batched or not.
+
"""
def __init__(
@@ -165,25 +169,24 @@ def __init__(
guidance: str = "guidance",
sigma: int = 2,
number_intensity_ch: int = 1,
- batched: bool = False,
):
self.image = image
self.guidance = guidance
self.sigma = sigma
self.number_intensity_ch = number_intensity_ch
- self.batched = batched
def _get_signal(self, image, guidance):
dimensions = 3 if len(image.shape) > 3 else 2
guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance
+ guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
if dimensions == 3:
signal = np.zeros((len(guidance), image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32)
else:
signal = np.zeros((len(guidance), image.shape[-2], image.shape[-1]), dtype=np.float32)
sshape = signal.shape
- for i in range(len(guidance)):
- for point in guidance[i]:
+ for i, g_i in enumerate(guidance):
+ for point in g_i:
if np.any(np.asarray(point) < 0):
continue
@@ -207,16 +210,9 @@ def _get_signal(self, image, guidance):
return signal
def _apply(self, image, guidance):
- if not self.batched:
- signal = self._get_signal(image, guidance)
- return np.concatenate([image, signal], axis=0)
-
- images = []
- for i, g in zip(image, guidance):
- i = i[0 : 0 + self.number_intensity_ch, ...]
- signal = self._get_signal(i, g)
- images.append(np.concatenate([i, signal], axis=0))
- return images
+ signal = self._get_signal(image, guidance)
+ image = image[0 : 0 + self.number_intensity_ch, ...]
+ return np.concatenate([image, signal], axis=0)
def __call__(self, data):
d = dict(data)
@@ -231,25 +227,17 @@ class FindDiscrepancyRegionsd(Transform):
"""
Find discrepancy between prediction and actual during click interactions during training.
- If batched is true:
- label is in shape (B, C, D, H, W) or (B, C, H, W)
- pred has same shape as label
- discrepancy will have shape (B, 2, C, D, H, W) or (B, 2, C, H, W)
-
Args:
label: key to label source.
pred: key to prediction source.
discrepancy: key to store discrepancies found between label and prediction.
- batched: whether input is batched or not.
+
"""
- def __init__(
- self, label: str = "label", pred: str = "pred", discrepancy: str = "discrepancy", batched: bool = True
- ):
+ def __init__(self, label: str = "label", pred: str = "pred", discrepancy: str = "discrepancy"):
self.label = label
self.pred = pred
self.discrepancy = discrepancy
- self.batched = batched
@staticmethod
def disparity(label, pred):
@@ -262,13 +250,7 @@ def disparity(label, pred):
return [pos_disparity, neg_disparity]
def _apply(self, label, pred):
- if not self.batched:
- return self.disparity(label, pred)
-
- disparity = []
- for la, pr in zip(label, pred):
- disparity.append(self.disparity(la, pr))
- return disparity
+ return self.disparity(label, pred)
def __call__(self, data):
d = dict(data)
@@ -282,22 +264,16 @@ def __call__(self, data):
class AddRandomGuidanced(Randomizable, Transform):
"""
Add random guidance based on discrepancies that were found between label and prediction.
-
- If batched is True:
-
- Guidance is of shape (B, 2, N, # of dim) where B is batch size, 2 means positive and negative,
- N means how many guidance points, # of dim is the total number of dimensions of the image
- (for example if the image is CDHW, then # of dim would be 4).
-
- Discrepancy is of shape (B, 2, C, D, H, W) or (B, 2, C, H, W)
-
- Probability is of shape (B,)
+ input shape is as below:
+ Guidance is of shape (2, N, # of dim)
+ Discrepancy is of shape (2, C, D, H, W) or (2, C, H, W)
+ Probability is of shape (1)
Args:
guidance: key to guidance source.
discrepancy: key that represents discrepancies found between label and prediction.
probability: key that represents click/interaction probability.
- batched: whether input is batched or not.
+
"""
def __init__(
@@ -305,22 +281,15 @@ def __init__(
guidance: str = "guidance",
discrepancy: str = "discrepancy",
probability: str = "probability",
- batched: bool = True,
):
self.guidance = guidance
self.discrepancy = discrepancy
self.probability = probability
- self.batched = batched
self._will_interact = None
def randomize(self, data=None):
probability = data[self.probability]
- if not self.batched:
- self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])
- else:
- self._will_interact = []
- for p in probability:
- self._will_interact.append(self.R.choice([True, False], p=[p, 1.0 - p]))
+ self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])
def find_guidance(self, discrepancy):
distance = distance_transform_cdt(discrepancy).flatten()
@@ -356,24 +325,16 @@ def add_guidance(self, discrepancy, will_interact):
def _apply(self, guidance, discrepancy):
guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance
- if not self.batched:
- pos, neg = self.add_guidance(discrepancy, self._will_interact)
- if pos:
- guidance[0].append(pos)
- guidance[1].append([-1] * len(pos))
- if neg:
- guidance[0].append([-1] * len(neg))
- guidance[1].append(neg)
- else:
- for g, d, w in zip(guidance, discrepancy, self._will_interact):
- pos, neg = self.add_guidance(d, w)
- if pos:
- g[0].append(pos)
- g[1].append([-1] * len(pos))
- if neg:
- g[0].append([-1] * len(neg))
- g[1].append(neg)
- return np.asarray(guidance)
+ guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
+ pos, neg = self.add_guidance(discrepancy, self._will_interact)
+ if pos:
+ guidance[0].append(pos)
+ guidance[1].append([-1] * len(pos))
+ if neg:
+ guidance[0].append([-1] * len(neg))
+ guidance[1].append(neg)
+
+ return json.dumps(np.asarray(guidance).astype(int).tolist())
def __call__(self, data):
d = dict(data)
@@ -389,7 +350,7 @@ class SpatialCropForegroundd(MapTransform):
"""
Crop only the foreground object of the expected images.
- Difference VS CropForegroundd:
+ Difference VS :py:class:`monai.transforms.CropForegroundd`:
1. If the bounding box is smaller than spatial size in all dimensions then this transform will crop the
object using box's center and spatial_size.
@@ -399,9 +360,11 @@ class SpatialCropForegroundd(MapTransform):
The typical usage is to help training and evaluation if the valid part is small in the whole medical image.
The valid part can be determined by any field in the data with `source_key`, for example:
+
- Select values > 0 in image field as the foreground and crop on all fields specified by `keys`.
- Select label = 3 in label field as the foreground to crop on all fields specified by `keys`.
- Select label > 0 in the third channel of a One-Hot label field as the foreground to crop all `keys` fields.
+
Users can define arbitrary function to select expected foreground from the whole source image or specified
channels. And it can also add margin to every dim of the bounding box of foreground object.
@@ -414,14 +377,20 @@ class SpatialCropForegroundd(MapTransform):
channel_indices: if defined, select foreground only on the specified channels
of image. if None, select foreground on the whole image.
margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
- meta_key_postfix: use `{key}_{meta_key_postfix}` to to fetch/store the meta data according to the key data,
- default is `meta_dict`, the meta data is a dictionary object.
+ meta_keys: explicitly indicate the key of the corresponding meta data dictionary.
+ for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
+ the meta data is a dictionary object which contains: filename, original_shape, etc.
+ it can be a sequence of string, map to the `keys`.
+ if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
+ meta_key_postfix: if meta_keys is None, use `{key}_{meta_key_postfix}` to fetch/store the meta data according
+ to the key data, default is `meta_dict`, the meta data is a dictionary object.
For example, to handle key `image`, read/write affine matrices from the
metadata `image_meta_dict` dictionary's `affine` field.
start_coord_key: key to record the start coordinate of spatial bounding box for foreground.
end_coord_key: key to record the end coordinate of spatial bounding box for foreground.
original_shape_key: key to record original shape for foreground.
cropped_shape_key: key to record cropped shape for foreground.
+ allow_missing_keys: don't raise exception if key is missing.
"""
def __init__(
@@ -432,20 +401,25 @@ def __init__(
select_fn: Callable = lambda x: x > 0,
channel_indices: Optional[IndexSelection] = None,
margin: int = 0,
+ meta_keys: Optional[KeysCollection] = None,
meta_key_postfix="meta_dict",
start_coord_key: str = "foreground_start_coord",
end_coord_key: str = "foreground_end_coord",
original_shape_key: str = "foreground_original_shape",
cropped_shape_key: str = "foreground_cropped_shape",
+ allow_missing_keys: bool = False,
) -> None:
- super().__init__(keys)
+ super().__init__(keys, allow_missing_keys)
self.source_key = source_key
self.spatial_size = list(spatial_size)
self.select_fn = select_fn
self.channel_indices = channel_indices
self.margin = margin
- self.meta_key_postfix = meta_key_postfix
+ self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
+ if len(self.keys) != len(self.meta_keys):
+ raise ValueError("meta_keys should have the same length as keys.")
+ self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.start_coord_key = start_coord_key
self.end_coord_key = end_coord_key
self.original_shape_key = original_shape_key
@@ -457,18 +431,254 @@ def __call__(self, data):
d[self.source_key], self.select_fn, self.channel_indices, self.margin
)
- center = np.mean([box_start, box_end], axis=0).astype(int).tolist()
- current_size = np.subtract(box_end, box_start).astype(int).tolist()
+ center = list(np.mean([box_start, box_end], axis=0).astype(int))
+ current_size = list(np.subtract(box_end, box_start).astype(int))
if np.all(np.less(current_size, self.spatial_size)):
cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
- box_start = cropper.roi_start
- box_end = cropper.roi_end
+ box_start = np.array([s.start for s in cropper.slices])
+ box_end = np.array([s.stop for s in cropper.slices])
+ else:
+ cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
+
+ for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
+ meta_key = meta_key or f"{key}_{meta_key_postfix}"
+ d[meta_key][self.start_coord_key] = box_start
+ d[meta_key][self.end_coord_key] = box_end
+ d[meta_key][self.original_shape_key] = d[key].shape
+
+ image = cropper(d[key])
+ d[meta_key][self.cropped_shape_key] = image.shape
+ d[key] = image
+ return d
+
+
+# Transforms to support Inference for Deepgrow models
+class AddGuidanceFromPointsd(Transform):
+ """
+ Add guidance based on user clicks.
+
+ We assume the input is loaded by LoadImaged and has the shape of (H, W, D) originally.
+ Clicks always specify the coordinates in (H, W, D)
+
+ If depth_first is True:
+
+ Input is now of shape (D, H, W), will return guidance that specifies the coordinates in (D, H, W)
+
+ else:
+
+ Input is now of shape (H, W, D), will return guidance that specifies the coordinates in (H, W, D)
+
+ Args:
+ ref_image: key to reference image to fetch current and original image details.
+ guidance: output key to store guidance.
+ foreground: key that represents user foreground (+ve) clicks.
+ background: key that represents user background (-ve) clicks.
+ axis: axis that represents slices in 3D volume. (axis to Depth)
+ depth_first: if depth (slices) is positioned at first dimension.
+ dimensions: dimensions based on model used for deepgrow (2D vs 3D).
+ slice_key: key that represents applicable slice to add guidance.
+ meta_keys: explicitly indicate the key of the meta data dictionary of `ref_image`.
+ for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
+ the meta data is a dictionary object which contains: filename, original_shape, etc.
+ if None, will try to construct meta_keys by `{ref_image}_{meta_key_postfix}`.
+ meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to fetch the meta data according
+ to the key data, default is `meta_dict`, the meta data is a dictionary object.
+ For example, to handle key `image`, read/write affine matrices from the
+ metadata `image_meta_dict` dictionary's `affine` field.
+ """
+
+ def __init__(
+ self,
+ ref_image,
+ guidance: str = "guidance",
+ foreground: str = "foreground",
+ background: str = "background",
+ axis: int = 0,
+ depth_first: bool = True,
+ dimensions: int = 2,
+ slice_key: str = "slice",
+ meta_keys: Optional[str] = None,
+ meta_key_postfix: str = "meta_dict",
+ ):
+ self.ref_image = ref_image
+ self.guidance = guidance
+ self.foreground = foreground
+ self.background = background
+ self.axis = axis
+ self.depth_first = depth_first
+ self.dimensions = dimensions
+ self.slice = slice_key
+ self.meta_keys = meta_keys
+ self.meta_key_postfix = meta_key_postfix
+
+ def _apply(self, pos_clicks, neg_clicks, factor, slice_num):
+ pos = neg = []
+
+ if self.dimensions == 2:
+ points = list(pos_clicks)
+ points.extend(neg_clicks)
+ points = np.array(points)
+
+ slices = list(np.unique(points[:, self.axis]))
+ slice_idx = slices[0] if slice_num is None else next(x for x in slices if x == slice_num)
+
+ if len(pos_clicks):
+ pos_clicks = np.array(pos_clicks)
+ pos = (pos_clicks[np.where(pos_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist()
+ if len(neg_clicks):
+ neg_clicks = np.array(neg_clicks)
+ neg = (neg_clicks[np.where(neg_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist()
+
+ guidance = [pos, neg, slice_idx]
+ else:
+ if len(pos_clicks):
+ pos = np.multiply(pos_clicks, factor).astype(int).tolist()
+ if len(neg_clicks):
+ neg = np.multiply(neg_clicks, factor).astype(int).tolist()
+ guidance = [pos, neg]
+ return guidance
+
+ def __call__(self, data):
+ d = dict(data)
+ meta_dict_key = self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"
+ if meta_dict_key not in d:
+ raise RuntimeError(f"Missing meta_dict {meta_dict_key} in data!")
+ if "spatial_shape" not in d[meta_dict_key]:
+ raise RuntimeError('Missing "spatial_shape" in meta_dict!')
+ original_shape = d[meta_dict_key]["spatial_shape"]
+ current_shape = list(d[self.ref_image].shape)
+
+ if self.depth_first:
+ if self.axis != 0:
+ raise RuntimeError("Depth first means the depth axis should be 0.")
+ # in here we assume the depth dimension was in the last dimension of "original_shape"
+ original_shape = np.roll(original_shape, 1)
+
+ factor = np.array(current_shape) / original_shape
+
+ fg_bg_clicks = []
+ for key in [self.foreground, self.background]:
+ clicks = d[key]
+ clicks = list(np.array(clicks).astype(int))
+ if self.depth_first:
+ for i in range(len(clicks)):
+ clicks[i] = list(np.roll(clicks[i], 1))
+ fg_bg_clicks.append(clicks)
+ d[self.guidance] = self._apply(fg_bg_clicks[0], fg_bg_clicks[1], factor, d.get(self.slice))
+ return d
+
+
+class SpatialCropGuidanced(MapTransform):
+ """
+ Crop image based on guidance with minimal spatial size.
+
+ - If the bounding box is smaller than spatial size in all dimensions then this transform will crop the
+ object using box's center and spatial_size.
+
+ - This transform will set "start_coord_key", "end_coord_key", "original_shape_key" and "cropped_shape_key"
+ in data[{key}_{meta_key_postfix}]
+
+ Input data is of shape (C, spatial_1, [spatial_2, ...])
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ guidance: key to the guidance. It is used to generate the bounding box of foreground
+ spatial_size: minimal spatial size of the image patch e.g. [128, 128, 128] to fit in.
+ margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
+ meta_keys: explicitly indicate the key of the corresponding meta data dictionary.
+ for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
+ the meta data is a dictionary object which contains: filename, original_shape, etc.
+ it can be a sequence of string, map to the `keys`.
+ if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
+ meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according
+ to the key data, default is `meta_dict`, the meta data is a dictionary object.
+ For example, to handle key `image`, read/write affine matrices from the
+ metadata `image_meta_dict` dictionary's `affine` field.
+ start_coord_key: key to record the start coordinate of spatial bounding box for foreground.
+ end_coord_key: key to record the end coordinate of spatial bounding box for foreground.
+ original_shape_key: key to record original shape for foreground.
+ cropped_shape_key: key to record cropped shape for foreground.
+ allow_missing_keys: don't raise exception if key is missing.
+ """
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ guidance: str,
+ spatial_size,
+ margin=20,
+ meta_keys: Optional[KeysCollection] = None,
+ meta_key_postfix="meta_dict",
+ start_coord_key: str = "foreground_start_coord",
+ end_coord_key: str = "foreground_end_coord",
+ original_shape_key: str = "foreground_original_shape",
+ cropped_shape_key: str = "foreground_cropped_shape",
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+
+ self.guidance = guidance
+ self.spatial_size = list(spatial_size)
+ self.margin = margin
+ self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
+ if len(self.keys) != len(self.meta_keys):
+ raise ValueError("meta_keys should have the same length as keys.")
+ self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
+ self.start_coord_key = start_coord_key
+ self.end_coord_key = end_coord_key
+ self.original_shape_key = original_shape_key
+ self.cropped_shape_key = cropped_shape_key
+
+ def bounding_box(self, points, img_shape):
+ ndim = len(img_shape)
+ margin = ensure_tuple_rep(self.margin, ndim)
+ for m in margin:
+ if m < 0:
+ raise ValueError("margin value should not be negative number.")
+
+ box_start = [0] * ndim
+ box_end = [0] * ndim
+
+ for di in range(ndim):
+ dt = points[..., di]
+ min_d = max(min(dt - margin[di]), 0)
+ max_d = min(img_shape[di], max(dt + margin[di] + 1))
+ box_start[di], box_end[di] = min_d, max_d
+ return box_start, box_end
+
+ def __call__(self, data):
+ d: Dict = dict(data)
+ guidance = d[self.guidance]
+ original_spatial_shape = d[self.keys[0]].shape[1:]
+ box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape)
+ center = list(np.mean([box_start, box_end], axis=0).astype(int))
+ spatial_size = self.spatial_size
+
+ box_size = list(np.subtract(box_end, box_start).astype(int))
+ spatial_size = spatial_size[-len(box_size) :]
+
+ if len(spatial_size) < len(box_size):
+ # If the data is in 3D and spatial_size is specified as 2D [256,256]
+ # Then we will get all slices in such case
+ diff = len(box_size) - len(spatial_size)
+ spatial_size = list(original_spatial_shape[1 : (1 + diff)]) + spatial_size
+
+ if np.all(np.less(box_size, spatial_size)):
+ if len(center) == 3:
+ # 3D Deepgrow: set center to be middle of the depth dimension (D)
+ center[0] = spatial_size[0] // 2
+ cropper = SpatialCrop(roi_center=center, roi_size=spatial_size)
else:
cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
- for key in self.keys:
- meta_key = f"{key}_{self.meta_key_postfix}"
+ # update bounding box in case it was corrected by the SpatialCrop constructor
+ box_start = np.array([s.start for s in cropper.slices])
+ box_end = np.array([s.stop for s in cropper.slices])
+ for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
+ if not np.array_equal(d[key].shape[1:], original_spatial_shape):
+ raise RuntimeError("All the image specified in keys should have same spatial shape")
+ meta_key = meta_key or f"{key}_{meta_key_postfix}"
d[meta_key][self.start_coord_key] = box_start
d[meta_key][self.end_coord_key] = box_end
d[meta_key][self.original_shape_key] = d[key].shape
@@ -476,4 +686,255 @@ def __call__(self, data):
image = cropper(d[key])
d[meta_key][self.cropped_shape_key] = image.shape
d[key] = image
+
+ pos_clicks, neg_clicks = guidance[0], guidance[1]
+ pos = np.subtract(pos_clicks, box_start).tolist() if len(pos_clicks) else []
+ neg = np.subtract(neg_clicks, box_start).tolist() if len(neg_clicks) else []
+
+ d[self.guidance] = [pos, neg]
+ return d
+
+
+class ResizeGuidanced(Transform):
+    """
+    Resize the guidance based on cropped vs resized image.
+
+    This transform assumes that the images have been cropped and resized. And the shape after cropping is stored
+    inside the meta dict of ref image.
+
+    Args:
+        guidance: key to guidance
+        ref_image: key to reference image to fetch current and original image details
+        meta_keys: explicitly indicate the key of the meta data dictionary of `ref_image`.
+            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
+            the meta data is a dictionary object which contains: filename, original_shape, etc.
+            if None, will try to construct meta_keys by `{ref_image}_{meta_key_postfix}`.
+        meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to fetch the meta data according
+            to the key data, default is `meta_dict`, the meta data is a dictionary object.
+            For example, to handle key `image`, read/write affine matrices from the
+            metadata `image_meta_dict` dictionary's `affine` field.
+        cropped_shape_key: key that records cropped shape for foreground.
+    """
+
+    def __init__(
+        self,
+        guidance: str,
+        ref_image: str,
+        meta_keys: Optional[str] = None,
+        meta_key_postfix: str = "meta_dict",
+        cropped_shape_key: str = "foreground_cropped_shape",
+    ) -> None:
+        self.guidance = guidance
+        self.ref_image = ref_image
+        self.meta_keys = meta_keys
+        self.meta_key_postfix = meta_key_postfix
+        self.cropped_shape_key = cropped_shape_key
+
+    def __call__(self, data):
+        d = dict(data)
+        guidance = d[self.guidance]
+        meta_dict: Dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"]
+        # spatial shapes only (channel dim dropped): current resized shape vs the recorded post-crop shape
+        current_shape = d[self.ref_image].shape[1:]
+        cropped_shape = meta_dict[self.cropped_shape_key][1:]
+        # per-axis zoom factor applied by the resize; used to rescale the click coordinates
+        factor = np.divide(current_shape, cropped_shape)
+
+        pos_clicks, neg_clicks = guidance[0], guidance[1]
+        # scale each click into the resized coordinate frame; empty click lists pass through unchanged
+        pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else []
+        neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else []
+
+        d[self.guidance] = [pos, neg]
+        return d
+
+
+class RestoreLabeld(MapTransform):
+    """
+    Restores label based on the ref image.
+
+    The ref_image is assumed that it went through the following transforms:
+
+        1. Fetch2DSliced (If 2D)
+        2. Spacingd
+        3. SpatialCropGuidanced
+        4. Resized
+
+    And its shape is assumed to be (C, D, H, W)
+
+    This transform tries to undo these operation so that the result label can be overlapped with original volume.
+    It does the following operation:
+
+        1. Undo Resized
+        2. Undo SpatialCropGuidanced
+        3. Undo Spacingd
+        4. Undo Fetch2DSliced
+
+    The resulting label is of shape (D, H, W)
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+        ref_image: reference image to fetch current and original image details
+        slice_only: apply only to an applicable slice, in case of 2D model/prediction
+        mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
+            ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
+            One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``.
+            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+        align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
+            See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
+            It also can be a sequence of bool, each element corresponds to a key in ``keys``.
+        meta_keys: explicitly indicate the key of the corresponding meta data dictionary.
+            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
+            the meta data is a dictionary object which contains: filename, original_shape, etc.
+            it can be a sequence of string, map to the `keys`.
+            if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
+        meta_key_postfix: if meta_key is None, use `key_{meta_key_postfix}` to fetch the meta data according
+            to the key data, default is `meta_dict`, the meta data is a dictionary object.
+            For example, to handle key `image`, read/write affine matrices from the
+            metadata `image_meta_dict` dictionary's `affine` field.
+        start_coord_key: key that records the start coordinate of spatial bounding box for foreground.
+        end_coord_key: key that records the end coordinate of spatial bounding box for foreground.
+        original_shape_key: key that records original shape for foreground.
+        cropped_shape_key: key that records cropped shape for foreground.
+        allow_missing_keys: don't raise exception if key is missing.
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        ref_image: str,
+        slice_only: bool = False,
+        mode: Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] = InterpolateMode.NEAREST,
+        align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
+        meta_keys: Optional[KeysCollection] = None,
+        meta_key_postfix: str = "meta_dict",
+        start_coord_key: str = "foreground_start_coord",
+        end_coord_key: str = "foreground_end_coord",
+        original_shape_key: str = "foreground_original_shape",
+        cropped_shape_key: str = "foreground_cropped_shape",
+        allow_missing_keys: bool = False,
+    ) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.ref_image = ref_image
+        self.slice_only = slice_only
+        # per-key sequences: one mode/align_corners/meta_key per entry in `keys`
+        self.mode = ensure_tuple_rep(mode, len(self.keys))
+        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
+        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
+        if len(self.keys) != len(self.meta_keys):
+            raise ValueError("meta_keys should have the same length as keys.")
+        self.meta_key_postfix = meta_key_postfix
+        self.start_coord_key = start_coord_key
+        self.end_coord_key = end_coord_key
+        self.original_shape_key = original_shape_key
+        self.cropped_shape_key = cropped_shape_key
+
+    def __call__(self, data):
+        d = dict(data)
+        # NOTE(review): the reference meta dict is always looked up via the postfix convention here,
+        # even when explicit `meta_keys` were given for the output keys — confirm this is intended.
+        meta_dict: Dict = d[f"{self.ref_image}_{self.meta_key_postfix}"]
+
+        for key, mode, align_corners, meta_key in self.key_iterator(d, self.mode, self.align_corners, self.meta_keys):
+            image = d[key]
+
+            # Undo Resize: shrink/grow back to the shape recorded right after cropping
+            current_shape = image.shape
+            cropped_shape = meta_dict[self.cropped_shape_key]
+            if np.any(np.not_equal(current_shape, cropped_shape)):
+                resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
+                image = resizer(image, mode=mode, align_corners=align_corners)
+
+            # Undo Crop: paste the patch back into a zero volume of the pre-crop shape
+            original_shape = meta_dict[self.original_shape_key]
+            result = np.zeros(original_shape, dtype=np.float32)
+            box_start = meta_dict[self.start_coord_key]
+            box_end = meta_dict[self.end_coord_key]
+
+            spatial_dims = min(len(box_start), len(image.shape[1:]))
+            slices = [slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])]
+            slices = tuple(slices)
+            result[slices] = image
+
+            # Undo Spacing: resize to the originally recorded spatial shape
+            current_size = result.shape[1:]
+            # change spatial_shape from HWD to DHW
+            spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1))
+            spatial_size = spatial_shape[-len(current_size) :]
+
+            if np.any(np.not_equal(current_size, spatial_size)):
+                resizer = Resize(spatial_size=spatial_size, mode=mode)
+                result = resizer(result, mode=mode, align_corners=align_corners)
+
+            # Undo Slicing: re-embed a 2D prediction into its source 3D volume (unless slice_only)
+            slice_idx = meta_dict.get("slice_idx")
+            if slice_idx is None or self.slice_only:
+                # drop the channel dim for outputs with more than 3 dims
+                final_result = result if len(result.shape) <= 3 else result[0]
+            else:
+                slice_idx = meta_dict["slice_idx"][0]
+                final_result = np.zeros(tuple(spatial_shape))
+                final_result[slice_idx] = result
+            d[key] = final_result
+
+            # record restoration metadata (slice index and original affine) for downstream writers
+            meta_key = meta_key or f"{key}_{self.meta_key_postfix}"
+            meta = d.get(meta_key)
+            if meta is None:
+                meta = dict()
+                d[meta_key] = meta
+            meta["slice_idx"] = slice_idx
+            meta["affine"] = meta_dict["original_affine"]
+        return d
+
+
+class Fetch2DSliced(MapTransform):
+ """
+ Fetch one slice in case of a 3D volume.
+
+ The volume only contains spatial coordinates.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ guidance: key that represents guidance.
+ axis: axis that represents slice in 3D volume.
+ meta_keys: explicitly indicate the key of the corresponding meta data dictionary.
+ for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
+ the meta data is a dictionary object which contains: filename, original_shape, etc.
+ it can be a sequence of string, map to the `keys`.
+ if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
+ meta_key_postfix: use `key_{meta_key_postfix}` to to fetch the meta data according to the key data,
+ default is `meta_dict`, the meta data is a dictionary object.
+ For example, to handle key `image`, read/write affine matrices from the
+ metadata `image_meta_dict` dictionary's `affine` field.
+ allow_missing_keys: don't raise exception if key is missing.
+ """
+
+ def __init__(
+ self,
+ keys,
+ guidance="guidance",
+ axis: int = 0,
+ meta_keys: Optional[KeysCollection] = None,
+ meta_key_postfix: str = "meta_dict",
+ allow_missing_keys: bool = False,
+ ):
+ super().__init__(keys, allow_missing_keys)
+ self.guidance = guidance
+ self.axis = axis
+ self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
+ if len(self.keys) != len(self.meta_keys):
+ raise ValueError("meta_keys should have the same length as keys.")
+ self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
+
+ def _apply(self, image, guidance):
+ slice_idx = guidance[2] # (pos, neg, slice_idx)
+ idx = []
+ for i, size_i in enumerate(image.shape):
+ idx.append(slice_idx) if i == self.axis else idx.append(slice(0, size_i))
+
+ idx = tuple(idx)
+ return image[idx], idx
+
+ def __call__(self, data):
+ d = dict(data)
+ guidance = d[self.guidance]
+ if len(guidance) < 3:
+ raise RuntimeError("Guidance does not container slice_idx!")
+ for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
+ img_slice, idx = self._apply(d[key], guidance)
+ d[key] = img_slice
+ d[meta_key or f"{key}_{meta_key_postfix}"]["slice_idx"] = idx
return d
diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py
new file mode 100644
index 0000000000..396be2e87d
--- /dev/null
+++ b/monai/apps/mmars/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .mmars import download_mmar, get_model_spec, load_from_mmar
+from .model_desc import MODEL_DESC, RemoteMMARKeys
diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py
new file mode 100644
index 0000000000..e7ff28ce44
--- /dev/null
+++ b/monai/apps/mmars/mmars.py
@@ -0,0 +1,295 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utilities for accessing Nvidia MMARs
+
+See Also:
+ - https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html
+"""
+
+import json
+import os
+import warnings
+from typing import Mapping, Union
+
+import torch
+
+import monai.networks.nets as monai_nets
+from monai.apps.utils import download_and_extract
+from monai.utils.module import optional_import
+
+from .model_desc import MODEL_DESC
+from .model_desc import RemoteMMARKeys as Keys
+
+__all__ = ["get_model_spec", "download_mmar", "load_from_mmar"]
+
+
+def get_model_spec(idx: Union[int, str]):
+ """get model specification by `idx`. `idx` could be index of the constant tuple of dict or the actual model ID."""
+ if isinstance(idx, int):
+ return MODEL_DESC[idx]
+ if isinstance(idx, str):
+ key = idx.strip().lower()
+ for cand in MODEL_DESC:
+ if str(cand[Keys.ID]).strip().lower() == key:
+ return cand
+ print(f"Available specs are: {MODEL_DESC}.")
+ raise ValueError(f"Unknown MODEL_DESC request: {idx}")
+
+
+def _get_all_ngc_models(pattern, page_index=0, page_size=50):
+ url = "https://api.ngc.nvidia.com/v2/search/catalog/resources/MODEL"
+ query_dict = {
+ "query": "",
+ "orderBy": [{"field": "score", "value": "DESC"}],
+ "queryFields": ["all", "description", "displayName", "name", "resourceId"],
+ "fields": [
+ "isPublic",
+ "attributes",
+ "guestAccess",
+ "name",
+ "orgName",
+ "teamName",
+ "displayName",
+ "dateModified",
+ "labels",
+ "description",
+ ],
+ "page": 0,
+ }
+
+ filter = [dict(field="name", value=f"*{pattern}*")]
+ query_dict["page"] = page_index
+ query_dict["pageSize"] = page_size
+ query_dict["filters"] = filter
+ query_str = json.dumps(query_dict)
+ full_url = f"{url}?q={query_str}"
+ requests_get, has_requests = optional_import("requests", name="get")
+ if has_requests:
+ resp = requests_get(full_url)
+ else:
+ raise ValueError("NGC API requires requests package. Please install it.")
+ model_list = json.loads(resp.text)
+ model_dict = {}
+ for result in model_list["results"]:
+ for model in result["resources"]:
+ current_res_id = model["resourceId"]
+ model_dict[current_res_id] = {"name": model["name"]}
+ for attribute in model["attributes"]:
+ if attribute["key"] == "latestVersionIdStr":
+ model_dict[current_res_id]["latest"] = attribute["value"]
+ return model_dict
+
+
+def _get_ngc_url(model_name: str, version: str, model_prefix=""):
+    """Build the NGC download URL for the zip archive of `model_name` at `version`."""
+    return f"https://api.ngc.nvidia.com/v2/models/{model_prefix}{model_name}/versions/{version}/zip"
+
+
+def _get_ngc_doc_url(model_name: str, model_prefix=""):
+    """Build the NGC catalog (documentation) page URL for `model_name`."""
+    return f"https://ngc.nvidia.com/catalog/models/{model_prefix}{model_name}"
+
+
+def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, version: int = -1):
+    """
+    Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train.
+
+    See Also:
+        - https://docs.nvidia.com/clara/
+        - Nvidia NGC Registry CLI
+        - https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html
+
+    Args:
+        item: the corresponding model item from `MODEL_DESC`.
+            Or when api is True, the substring to query NGC's model name field.
+        mmar_dir: target directory to store the MMAR, default is `mmars` subfolder under `torch.hub get_dir()`.
+        progress: whether to display a progress bar.
+        api: whether to query NGC and download via api
+        version: which version of MMAR to download. -1 means the latest from ngc.
+
+    Examples::
+        >>> from monai.apps import download_mmar
+        >>> download_mmar("clara_pt_prostate_mri_segmentation_1", mmar_dir=".")
+        >>> download_mmar("prostate_mri_segmentation", mmar_dir=".", api=True)
+
+
+    Returns:
+        The local directory of the downloaded model.
+        If api is True, a list of local directories of downloaded models.
+    """
+    if not mmar_dir:
+        # default cache location: <torch hub dir>/mmars; torch.hub.get_dir needs PyTorch 1.6+
+        get_dir, has_home = optional_import("torch.hub", name="get_dir")
+        if has_home:
+            mmar_dir = os.path.join(get_dir(), "mmars")
+        else:
+            raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?")
+
+    if api:
+        # query NGC for every model whose name matches `item`, then download each match
+        model_dict = _get_all_ngc_models(item)
+        if len(model_dict) == 0:
+            raise ValueError(f"api query returns no item for pattern {item}. Please change or shorten it.")
+        model_dir_list = []
+        for k, v in model_dict.items():
+            # resolve version: NGC "latest" when -1, otherwise the requested one
+            ver = v["latest"] if version == -1 else str(version)
+            download_url = _get_ngc_url(k, ver)
+            model_dir = os.path.join(mmar_dir, v["name"])
+            download_and_extract(
+                url=download_url,
+                filepath=os.path.join(mmar_dir, f'{v["name"]}_{ver}.zip'),
+                output_dir=model_dir,
+                hash_val=None,
+                hash_type="md5",
+                file_type="zip",
+                has_base=False,
+                progress=progress,
+            )
+            model_dir_list.append(model_dir)
+        return model_dir_list
+
+    # non-api path: `item` is (or resolves to) a MODEL_DESC entry
+    if not isinstance(item, Mapping):
+        item = get_model_spec(item)
+
+    ver = item.get(Keys.VERSION, 1)
+    if version > 0:
+        ver = str(version)
+    model_fullname = f"{item[Keys.NAME]}_{ver}"
+    model_dir = os.path.join(mmar_dir, model_fullname)
+    # prefer an explicit URL from the descriptor; otherwise build the NGC med-model URL
+    model_url = item.get(Keys.URL) or _get_ngc_url(item[Keys.NAME], version=ver, model_prefix="nvidia/med/")
+    download_and_extract(
+        url=model_url,
+        filepath=os.path.join(mmar_dir, f"{model_fullname}.{item[Keys.FILE_TYPE]}"),
+        output_dir=model_dir,
+        hash_val=item[Keys.HASH_VAL],
+        hash_type=item[Keys.HASH_TYPE],
+        file_type=item[Keys.FILE_TYPE],
+        has_base=False,
+        progress=progress,
+    )
+    return model_dir
+
+
+def load_from_mmar(
+    item,
+    mmar_dir=None,
+    progress: bool = True,
+    version: int = -1,
+    map_location=None,
+    pretrained=True,
+    weights_only=False,
+    model_key: str = "model",
+):
+    """
+    Download and extract Medical Model Archive (MMAR) model weights from Nvidia Clara Train.
+
+    Args:
+        item: the corresponding model item from `MODEL_DESC`.
+        mmar_dir: target directory to store the MMAR, default is mmars subfolder under `torch.hub get_dir()`.
+        progress: whether to display a progress bar when downloading the content.
+        version: version number of the MMAR. Set it to `-1` to use `item[Keys.VERSION]`.
+        map_location: pytorch API parameter for `torch.load` or `torch.jit.load`.
+        pretrained: whether to load the pretrained weights after initializing a network module.
+        weights_only: whether to load only the weights instead of initializing the network module and assign weights.
+        model_key: a key to search in the model file or config file for the model dictionary.
+            Currently this function assumes that the model dictionary has
+            `{"[name|path]": "test.module", "args": {'kw': 'test'}}`.
+
+    Examples::
+        >>> from monai.apps import load_from_mmar
+        >>> unet_model = load_from_mmar("clara_pt_prostate_mri_segmentation_1", mmar_dir=".", map_location="cpu")
+        >>> print(unet_model)
+
+    See Also:
+        https://docs.nvidia.com/clara/
+    """
+    if not isinstance(item, Mapping):
+        item = get_model_spec(item)
+    model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version)
+    model_file = os.path.join(model_dir, item[Keys.MODEL_FILE])
+    print(f'\n*** "{item[Keys.ID]}" available at {model_dir}.')
+
+    # loading with `torch.jit.load` (TorchScript archive; pretrained/weights_only do not apply)
+    if f"{model_file}".endswith(".ts"):
+        if not pretrained:
+            warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.")
+        if weights_only:
+            warnings.warn("Loading a ScriptModule, 'weights_only' option ignored.")
+        return torch.jit.load(model_file, map_location=map_location)
+
+    # loading with `torch.load`
+    model_dict = torch.load(model_file, map_location=map_location)
+    if weights_only:
+        return model_dict.get(model_key, model_dict)  # model_dict[model_key] or model_dict directly
+
+    # 1. search `model_dict['train_conf']` for model config spec.
+    model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={})
+    if not model_config:
+        # 2. search json CONFIG_FILE for model config spec.
+        json_path = os.path.join(model_dir, item.get(Keys.CONFIG_FILE, "config_train.json"))
+        with open(json_path) as f:
+            conf_dict = json.load(f)
+        conf_dict = dict(conf_dict)
+        model_config = _get_val(conf_dict, key=model_key, default={})
+    if not model_config:
+        # 3. search `model_dict` for model config spec.
+        model_config = _get_val(dict(model_dict), key=model_key, default={})
+
+    if not (model_config and isinstance(model_config, Mapping)):
+        raise ValueError(
+            f"Could not load model config dictionary from config: {item.get(Keys.CONFIG_FILE)}, "
+            f"or from model file: {item.get(Keys.MODEL_FILE)}."
+        )
+
+    # parse `model_config` for model class and model parameters
+    if model_config.get("name"):  # model config section is a "name": look up a MONAI network class
+        model_name = model_config["name"]
+        model_cls = monai_nets.__dict__[model_name]
+    elif model_config.get("path"):  # model config section is a "path": import a custom (BYOM) class
+        # https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html
+        model_module, model_name = model_config.get("path", ".").rsplit(".", 1)
+        model_cls, has_cls = optional_import(module=model_module, name=model_name)
+        if not has_cls:
+            raise ValueError(
+                f"Could not load MMAR model config {model_config.get('path', '')}, "
+                f"Please make sure MMAR's sub-folders in '{model_dir}' is on the PYTHONPATH."
+                "See also: https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html"
+            )
+    else:
+        raise ValueError(f"Could not load model config {model_config}.")
+
+    # instantiate the network, optionally with the config's keyword arguments
+    print(f"*** Model: {model_cls}")
+    model_kwargs = model_config.get("args", None)
+    if model_kwargs:
+        model_inst = model_cls(**model_kwargs)
+        print(f"*** Model params: {model_kwargs}")
+    else:
+        model_inst = model_cls()
+    if pretrained:
+        model_inst.load_state_dict(model_dict.get(model_key, model_dict))
+    print("\n---")
+    doc_url = item.get(Keys.DOC) or _get_ngc_doc_url(item[Keys.NAME], model_prefix="nvidia:med:")
+    print(f"For more information, please visit {doc_url}\n")
+    return model_inst
+
+
+def _get_val(input_dict: Mapping, key="model", default=None):
+    """
+    Search for the item with `key` in `input_dict`.
+    Returns: the first occurrence of `key` found by a depth first search
+    (the top level is checked first, then each nested mapping is fully recursed in turn).
+    """
+    if key in input_dict:
+        return input_dict[key]
+    # recurse into every nested mapping value until the key is found
+    for sub_dict in input_dict:
+        val = input_dict[sub_dict]
+        if isinstance(val, Mapping):
+            found_val = _get_val(val, key=key, default=None)
+            if found_val is not None:
+                return found_val
+    return default
diff --git a/monai/apps/mmars/model_desc.py b/monai/apps/mmars/model_desc.py
new file mode 100644
index 0000000000..fca6f60da5
--- /dev/null
+++ b/monai/apps/mmars/model_desc.py
@@ -0,0 +1,197 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Collection of the remote MMAR descriptors
+
+See Also:
+ - https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html
+"""
+
+import os
+
+__all__ = ["MODEL_DESC", "RemoteMMARKeys"]
+
+
+class RemoteMMARKeys:
+    """
+    Data keys used for loading MMAR.
+    ID must uniquely define an MMAR.
+
+    The values are the literal dictionary keys used by the `MODEL_DESC` entries.
+    """
+
+    ID = "id"  # unique MMAR
+    NAME = "name"  # MMAR name for readability
+    URL = "url"  # remote location of the MMAR, see also: `monai.apps.mmars.mmars._get_ngc_url`
+    DOC = "doc"  # documentation page of the remote model, see also: `monai.apps.mmars.mmars._get_ngc_doc_url`
+    FILE_TYPE = "file_type"  # type of the compressed MMAR
+    HASH_TYPE = "hash_type"  # hashing method for the compressed MMAR
+    HASH_VAL = "hash_val"  # hashing value for the compressed MMAR
+    MODEL_FILE = "model_file"  # within an MMAR folder, the relative path to the model file
+    CONFIG_FILE = "config_file"  # within an MMAR folder, the relative path to the config file (for model config)
+    VERSION = "version"  # version of the MMAR
+
+
+# Tuple of known remote MMAR descriptors; each entry is a dict keyed by `RemoteMMARKeys` values.
+# Entries without a CONFIG_FILE fall back to the default "config_train.json" lookup in `load_from_mmar`.
+MODEL_DESC = (
+    {
+        RemoteMMARKeys.ID: "clara_pt_spleen_ct_segmentation_1",
+        RemoteMMARKeys.NAME: "clara_pt_spleen_ct_segmentation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_prostate_mri_segmentation_1",
+        RemoteMMARKeys.NAME: "clara_pt_prostate_mri_segmentation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_covid19_ct_lesion_segmentation_1",
+        RemoteMMARKeys.NAME: "clara_pt_covid19_ct_lesion_segmentation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_covid19_3d_ct_classification_1",
+        RemoteMMARKeys.NAME: "clara_pt_covid19_3d_ct_classification",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_covid19_ct_lung_annotation_1",
+        RemoteMMARKeys.NAME: "clara_pt_covid19_ct_lung_annotation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_fed_learning_brain_tumor_mri_segmentation_1",
+        RemoteMMARKeys.NAME: "clara_pt_fed_learning_brain_tumor_mri_segmentation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "server", "best_FL_global_model.pt"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_pathology_metastasis_detection_1",
+        RemoteMMARKeys.NAME: "clara_pt_pathology_metastasis_detection",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_brain_mri_segmentation_1",
+        RemoteMMARKeys.NAME: "clara_pt_brain_mri_segmentation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_brain_mri_segmentation_t1c_1",
+        RemoteMMARKeys.NAME: "clara_pt_brain_mri_segmentation_t1c",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_liver_and_tumor_ct_segmentation_1",
+        RemoteMMARKeys.NAME: "clara_pt_liver_and_tumor_ct_segmentation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_pancreas_and_tumor_ct_segmentation_1",
+        RemoteMMARKeys.NAME: "clara_pt_pancreas_and_tumor_ct_segmentation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_brain_mri_annotation_t1c_1",
+        RemoteMMARKeys.NAME: "clara_pt_brain_mri_annotation_t1c",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_spleen_ct_annotation_1",
+        RemoteMMARKeys.NAME: "clara_pt_spleen_ct_annotation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_deepgrow_3d_annotation_1",
+        RemoteMMARKeys.NAME: "clara_pt_deepgrow_3d_annotation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_deepgrow_2d_annotation_1",
+        RemoteMMARKeys.NAME: "clara_pt_deepgrow_2d_annotation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+    {
+        RemoteMMARKeys.ID: "clara_pt_covid19_ct_lung_segmentation_1",
+        RemoteMMARKeys.NAME: "clara_pt_covid19_ct_lung_segmentation",
+        RemoteMMARKeys.FILE_TYPE: "zip",
+        RemoteMMARKeys.HASH_TYPE: "md5",
+        RemoteMMARKeys.HASH_VAL: None,
+        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
+        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
+        RemoteMMARKeys.VERSION: 1,
+    },
+)
diff --git a/monai/apps/pathology/__init__.py b/monai/apps/pathology/__init__.py
new file mode 100644
index 0000000000..203e1a80d7
--- /dev/null
+++ b/monai/apps/pathology/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .datasets import MaskedInferenceWSIDataset, PatchWSIDataset, SmartCacheDataset
+from .handlers import ProbMapProducer
+from .metrics import LesionFROC
+from .utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask
diff --git a/monai/apps/pathology/datasets.py b/monai/apps/pathology/datasets.py
new file mode 100644
index 0000000000..3694ca4144
--- /dev/null
+++ b/monai/apps/pathology/datasets.py
@@ -0,0 +1,311 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+from monai.data import Dataset, SmartCacheDataset
+from monai.data.image_reader import WSIReader
+from monai.utils import ensure_tuple_rep
+
+__all__ = ["PatchWSIDataset", "SmartCachePatchWSIDataset", "MaskedInferenceWSIDataset"]
+
+
+class PatchWSIDataset(Dataset):
+ """
+ This dataset reads whole slide images, extracts regions, and creates patches.
+ It also reads labels for each patch and provides each patch with its associated class labels.
+
+ Args:
+ data: the list of input samples including image, location, and label (see the note below for more details).
+ region_size: the size of regions to be extracted from the whole slide image.
+ grid_shape: the grid shape on which the patches should be extracted.
+ patch_size: the size of patches extracted from the region on the grid.
+ transform: transforms to be executed on input data.
+ image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide.
+ Defaults to CuCIM.
+
+ Note:
+ The input data has the following form as an example:
+ `[{"image": "path/to/image1.tiff", "location": [200, 500], "label": [0,0,0,1]}]`.
+
+ This means from "image1.tiff" extract a region centered at the given location `location`
+ with the size of `region_size`, and then extract patches with the size of `patch_size`
+ from a grid with the shape of `grid_shape`.
+        Be aware that the `grid_shape` should construct a grid with the same number of elements as `labels`,
+ so for this example the `grid_shape` should be (2, 2).
+
+ """
+
+ def __init__(
+ self,
+ data: List,
+ region_size: Union[int, Tuple[int, int]],
+ grid_shape: Union[int, Tuple[int, int]],
+ patch_size: Union[int, Tuple[int, int]],
+ transform: Optional[Callable] = None,
+ image_reader_name: str = "cuCIM",
+ ):
+ super().__init__(data, transform)
+
+ self.region_size = ensure_tuple_rep(region_size, 2)
+ self.grid_shape = ensure_tuple_rep(grid_shape, 2)
+ self.patch_size = ensure_tuple_rep(patch_size, 2)
+
+ self.image_path_list = list({x["image"] for x in self.data})
+ self.image_reader_name = image_reader_name
+ self.image_reader = WSIReader(image_reader_name)
+ self.wsi_object_dict = None
+ if self.image_reader_name != "openslide":
+ # OpenSlide causes memory issue if we prefetch image objects
+ self._fetch_wsi_objects()
+
+ def _fetch_wsi_objects(self):
+ """Load all the image objects and reuse them when asked for an item."""
+ self.wsi_object_dict = {}
+ for image_path in self.image_path_list:
+ self.wsi_object_dict[image_path] = self.image_reader.read(image_path)
+
+ def __getitem__(self, index):
+ sample = self.data[index]
+ if self.image_reader_name == "openslide":
+ img_obj = self.image_reader.read(sample["image"])
+ else:
+ img_obj = self.wsi_object_dict[sample["image"]]
+ location = [sample["location"][i] - self.region_size[i] // 2 for i in range(len(self.region_size))]
+ images, _ = self.image_reader.get_data(
+ img=img_obj,
+ location=location,
+ size=self.region_size,
+ grid_shape=self.grid_shape,
+ patch_size=self.patch_size,
+ )
+ labels = np.array(sample["label"], dtype=np.float32)
+ # expand dimensions to have 4 dimension as batch, class, height, and width.
+ for _ in range(4 - labels.ndim):
+ labels = np.expand_dims(labels, 1)
+ patches = [{"image": images[i], "label": labels[i]} for i in range(len(sample["label"]))]
+ if self.transform:
+ patches = self.transform(patches)
+ return patches
+
+
+class SmartCachePatchWSIDataset(SmartCacheDataset):
+ """Add SmartCache functionality to `PatchWSIDataset`.
+
+ Args:
+ data: the list of input samples including image, location, and label (see `PatchWSIDataset` for more details)
+ region_size: the region to be extracted from the whole slide image.
+ grid_shape: the grid shape on which the patches should be extracted.
+ patch_size: the size of patches extracted from the region on the grid.
+ image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide.
+ Defaults to CuCIM.
+ transform: transforms to be executed on input data.
+ replace_rate: percentage of the cached items to be replaced in every epoch.
+ cache_num: number of items to be cached. Default is `sys.maxsize`.
+ will take the minimum of (cache_num, data_length x cache_rate, data_length).
+ cache_rate: percentage of cached data in total, default is 1.0 (cache all).
+ will take the minimum of (cache_num, data_length x cache_rate, data_length).
+ num_init_workers: the number of worker threads to initialize the cache for first epoch.
+ If num_init_workers is None then the number returned by os.cpu_count() is used.
+ num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch.
+ If num_replace_workers is None then the number returned by os.cpu_count() is used.
+ progress: whether to display a progress bar when caching for the first epoch.
+
+ """
+
+ def __init__(
+ self,
+ data: List,
+ region_size: Union[int, Tuple[int, int]],
+ grid_shape: Union[int, Tuple[int, int]],
+ patch_size: Union[int, Tuple[int, int]],
+ transform: Union[Sequence[Callable], Callable],
+ image_reader_name: str = "cuCIM",
+ replace_rate: float = 0.5,
+ cache_num: int = sys.maxsize,
+ cache_rate: float = 1.0,
+ num_init_workers: Optional[int] = None,
+ num_replace_workers: Optional[int] = None,
+ progress: bool = True,
+ ):
+ patch_wsi_dataset = PatchWSIDataset(
+ data=data,
+ region_size=region_size,
+ grid_shape=grid_shape,
+ patch_size=patch_size,
+ image_reader_name=image_reader_name,
+ )
+ super().__init__(
+ data=patch_wsi_dataset, # type: ignore
+ transform=transform,
+ replace_rate=replace_rate,
+ cache_num=cache_num,
+ cache_rate=cache_rate,
+ num_init_workers=num_init_workers,
+ num_replace_workers=num_replace_workers,
+ progress=progress,
+ shuffle=False,
+ )
+
+
+class MaskedInferenceWSIDataset(Dataset):
+ """
+    This dataset loads the provided foreground masks at an arbitrary resolution level,
+    and extracts patches based on that mask from the associated whole slide image.
+
+ Args:
+ data: a list of sample including the path to the whole slide image and the path to the mask.
+            Like this: `[{"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"}, ...]`.
+ patch_size: the size of patches to be extracted from the whole slide image for inference.
+ transform: transforms to be executed on extracted patches.
+ image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide.
+ Defaults to CuCIM.
+
+ Note:
+ The resulting output (probability maps) after performing inference using this dataset is
+ supposed to be the same size as the foreground mask and not the original wsi image size.
+ """
+
+ def __init__(
+ self,
+ data: List[Dict["str", "str"]],
+ patch_size: Union[int, Tuple[int, int]],
+ transform: Optional[Callable] = None,
+ image_reader_name: str = "cuCIM",
+ ) -> None:
+ super().__init__(data, transform)
+
+ self.patch_size = ensure_tuple_rep(patch_size, 2)
+
+ # set up whole slide image reader
+ self.image_reader_name = image_reader_name
+ self.image_reader = WSIReader(image_reader_name)
+
+ # process data and create a list of dictionaries containing all required data and metadata
+ self.data = self._prepare_data(data)
+
+ # calculate cumulative number of patches for all the samples
+ self.num_patches_per_sample = [len(d["image_locations"]) for d in self.data]
+ self.num_patches = sum(self.num_patches_per_sample)
+ self.cum_num_patches = np.cumsum([0] + self.num_patches_per_sample[:-1])
+
+ def _prepare_data(self, input_data: List[Dict["str", "str"]]) -> List[Dict]:
+ prepared_data = []
+ for sample in input_data:
+ prepared_sample = self._prepare_a_sample(sample)
+ prepared_data.append(prepared_sample)
+ return prepared_data
+
+ def _prepare_a_sample(self, sample: Dict["str", "str"]) -> Dict:
+ """
+ Preprocess input data to load WSIReader object and the foreground mask,
+ and define the locations where patches need to be extracted.
+
+ Args:
+ sample: one sample, a dictionary containing path to the whole slide image and the foreground mask.
+                For example: `{"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"}`
+
+ Return:
+ A dictionary containing:
+ "name": the base name of the whole slide image,
+ "image": the WSIReader image object,
+ "mask_shape": the size of the foreground mask,
+ "mask_locations": the list of non-zero pixel locations (x, y) on the foreground mask,
+ "image_locations": the list of pixel locations (x, y) on the whole slide image where patches are extracted, and
+ "level": the resolution level of the mask with respect to the whole slide image.
+
+ """
+ image = self.image_reader.read(sample["image"])
+ mask = np.load(sample["mask"])
+ try:
+ level, ratio = self._calculate_mask_level(image, mask)
+ except ValueError as err:
+ err.args = (sample["mask"],) + err.args
+ raise
+
+ # get all indices for non-zero pixels of the foreground mask
+ mask_locations = np.vstack(mask.nonzero()).T
+
+ # convert mask locations to image locations to extract patches
+ image_locations = (mask_locations + 0.5) * ratio - np.array(self.patch_size) // 2
+
+ return {
+ "name": os.path.splitext(os.path.basename(sample["image"]))[0],
+ "image": image,
+ "mask_shape": mask.shape,
+ "mask_locations": mask_locations.astype(int).tolist(),
+ "image_locations": image_locations.astype(int).tolist(),
+ "level": level,
+ }
+
+ def _calculate_mask_level(self, image: np.ndarray, mask: np.ndarray) -> Tuple[int, float]:
+ """
+ Calculate level of the mask and its ratio with respect to the whole slide image
+
+ Args:
+ image: the original whole slide image
+ mask: a mask, that can be down-sampled at an arbitrary level.
+ Note that down-sampling ratio should be 2^N and equal in all dimension.
+
+ Return:
+ tuple: (level, ratio) where ratio is 2^level
+
+ """
+ image_shape = image.shape
+ mask_shape = mask.shape
+ ratios = [image_shape[i] / mask_shape[i] for i in range(2)]
+ level = np.log2(ratios[0])
+
+ if ratios[0] != ratios[1]:
+ raise ValueError(
+ "Image/Mask ratio across dimensions does not match!"
+ f"ratio 0: {ratios[0]} ({image_shape[0]} / {mask_shape[0]}),"
+ f"ratio 1: {ratios[1]} ({image_shape[1]} / {mask_shape[1]}),"
+ )
+ if not level.is_integer():
+ raise ValueError(f"Mask is not at a regular level (ratio not power of 2), image / mask ratio: {ratios[0]}")
+
+ return int(level), ratios[0]
+
+ def _load_a_patch(self, index):
+ """
+ Load sample given the index
+
+        Since the index is sequential and the patches are coming in a stream from different images,
+        this method first finds the whole slide image and the patch that should be extracted,
+        then it loads the patch and provides it with its image name and the corresponding mask location.
+ """
+ sample_num = np.argmax(self.cum_num_patches > index) - 1
+ sample = self.data[sample_num]
+ patch_num = index - self.cum_num_patches[sample_num]
+ location_on_image = sample["image_locations"][patch_num]
+ location_on_mask = sample["mask_locations"][patch_num]
+
+ image, _ = self.image_reader.get_data(
+ img=sample["image"],
+ location=location_on_image,
+ size=self.patch_size,
+ )
+ processed_sample = {"image": image, "name": sample["name"], "mask_location": location_on_mask}
+ return processed_sample
+
+ def __len__(self):
+ return self.num_patches
+
+ def __getitem__(self, index):
+ patch = [self._load_a_patch(index)]
+ if self.transform:
+ patch = self.transform(patch)
+ return patch
diff --git a/monai/apps/pathology/handlers.py b/monai/apps/pathology/handlers.py
new file mode 100644
index 0000000000..7ac4a0e45b
--- /dev/null
+++ b/monai/apps/pathology/handlers.py
@@ -0,0 +1,114 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+from typing import TYPE_CHECKING, Dict, Optional
+
+import numpy as np
+
+from monai.config import DtypeLike, IgniteInfo
+from monai.utils import min_version, optional_import
+
+Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
+if TYPE_CHECKING:
+ from ignite.engine import Engine
+else:
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+
+
+class ProbMapProducer:
+ """
+ Event handler triggered on completing every iteration to save the probability map
+ """
+
+ def __init__(
+ self,
+ output_dir: str = "./",
+ output_postfix: str = "",
+ dtype: DtypeLike = np.float64,
+ name: Optional[str] = None,
+ ) -> None:
+ """
+ Args:
+ output_dir: output directory to save probability maps.
+ output_postfix: a string appended to all output file names.
+ dtype: the data type in which the probability map is stored. Default np.float64.
+ name: identifier of logging.logger to use, defaulting to `engine.logger`.
+
+ """
+ self.logger = logging.getLogger(name)
+ self._name = name
+ self.output_dir = output_dir
+ self.output_postfix = output_postfix
+ self.dtype = dtype
+ self.prob_map: Dict[str, np.ndarray] = {}
+ self.level: Dict[str, int] = {}
+ self.counter: Dict[str, int] = {}
+ self.num_done_images: int = 0
+ self.num_images: int = 0
+
+ def attach(self, engine: Engine) -> None:
+ """
+ Args:
+ engine: Ignite Engine, it can be a trainer, validator or evaluator.
+ """
+
+ self.num_images = len(engine.data_loader.dataset.data)
+
+ for sample in engine.data_loader.dataset.data:
+ name = sample["name"]
+ self.prob_map[name] = np.zeros(sample["mask_shape"], dtype=self.dtype)
+ self.counter[name] = len(sample["mask_locations"])
+ self.level[name] = sample["level"]
+
+ if self._name is None:
+ self.logger = engine.logger
+ if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
+ engine.add_event_handler(Events.ITERATION_COMPLETED, self)
+ if not engine.has_event_handler(self.finalize, Events.COMPLETED):
+ engine.add_event_handler(Events.COMPLETED, self.finalize)
+
+ def __call__(self, engine: Engine) -> None:
+ """
+ This method assumes self.batch_transform will extract metadata from the input batch.
+
+ Args:
+ engine: Ignite Engine, it can be a trainer, validator or evaluator.
+ """
+ names = engine.state.batch["name"]
+ locs = engine.state.batch["mask_location"]
+ pred = engine.state.output["pred"]
+ for i, name in enumerate(names):
+ self.prob_map[name][locs[0][i], locs[1][i]] = pred[i]
+ self.counter[name] -= 1
+ if self.counter[name] == 0:
+ self.save_prob_map(name)
+
+ def save_prob_map(self, name: str) -> None:
+ """
+        This method saves the probability map for an image when its inference is finished,
+        and deletes that probability map from memory.
+
+ Args:
+ name: the name of image to be saved.
+ """
+ file_path = os.path.join(self.output_dir, name)
+ np.save(file_path + self.output_postfix + ".npy", self.prob_map[name])
+
+ self.num_done_images += 1
+ self.logger.info(f"Inference of '{name}' is done [{self.num_done_images}/{self.num_images}]!")
+ del self.prob_map[name]
+ del self.counter[name]
+ del self.level[name]
+
+ def finalize(self, engine: Engine):
+ self.logger.info(f"Probability map is created for {self.num_done_images}/{self.num_images} images!")
diff --git a/monai/apps/pathology/metrics.py b/monai/apps/pathology/metrics.py
new file mode 100644
index 0000000000..2140de0080
--- /dev/null
+++ b/monai/apps/pathology/metrics.py
@@ -0,0 +1,184 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, List, Tuple
+
+import numpy as np
+
+from monai.apps.pathology.utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask
+from monai.data.image_reader import WSIReader
+from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
+from monai.utils import min_version, optional_import
+
+if TYPE_CHECKING:
+ from tqdm import tqdm
+
+ has_tqdm = True
+else:
+ tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")
+
+if not has_tqdm:
+
+ def tqdm(x):
+ return x
+
+
+class LesionFROC:
+ """
+ Evaluate with Free Response Operating Characteristic (FROC) score.
+
+ Args:
+ data: either the list of dictionaries containing probability maps (inference result) and
+ tumor mask (ground truth), as below, or the path to a json file containing such list.
+ `{
+ "prob_map": "path/to/prob_map_1.npy",
+ "tumor_mask": "path/to/ground_truth_1.tiff",
+ "level": 6,
+ "pixel_spacing": 0.243
+ }`
+        grow_distance: Euclidean distance (in micrometer) by which to grow the ground truth's tumor labels.
+ Defaults to 75, which is the equivalent size of 5 tumor cells.
+ itc_diameter: the maximum diameter of a region (in micrometer) to be considered as an isolated tumor cell.
+ Defaults to 200.
+ eval_thresholds: the false positive rates for calculating the average sensitivity.
+ Defaults to (0.25, 0.5, 1, 2, 4, 8) which is the same as the CAMELYON 16 Challenge.
+ nms_sigma: the standard deviation for gaussian filter of non-maximal suppression. Defaults to 0.0.
+ nms_prob_threshold: the probability threshold of non-maximal suppression. Defaults to 0.5.
+        nms_box_size: the box size (in pixel) to be removed around the pixel for non-maximal suppression.
+ image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide.
+ Defaults to CuCIM.
+
+ Note:
+        For more info on `nms_*` parameters look at `monai.utils.prob_nms.ProbNMS`.
+
+ """
+
+ def __init__(
+ self,
+ data: List[Dict],
+ grow_distance: int = 75,
+ itc_diameter: int = 200,
+ eval_thresholds: Tuple = (0.25, 0.5, 1, 2, 4, 8),
+ nms_sigma: float = 0.0,
+ nms_prob_threshold: float = 0.5,
+ nms_box_size: int = 48,
+ image_reader_name: str = "cuCIM",
+ ) -> None:
+
+ self.data = data
+ self.grow_distance = grow_distance
+ self.itc_diameter = itc_diameter
+ self.eval_thresholds = eval_thresholds
+ self.image_reader = WSIReader(image_reader_name)
+ self.nms = PathologyProbNMS(
+ sigma=nms_sigma,
+ prob_threshold=nms_prob_threshold,
+ box_size=nms_box_size,
+ )
+
+ def prepare_inference_result(self, sample: Dict):
+ """
+ Prepare the probability map for detection evaluation.
+
+ """
+ # load the probability map (the result of model inference)
+ prob_map = np.load(sample["prob_map"])
+
+ # apply non-maximal suppression
+ nms_outputs = self.nms(probs_map=prob_map, resolution_level=sample["level"])
+
+ # separate nms outputs
+ if nms_outputs:
+ probs, x_coord, y_coord = zip(*nms_outputs)
+ else:
+ probs, x_coord, y_coord = [], [], []
+
+ return np.array(probs), np.array(x_coord), np.array(y_coord)
+
+ def prepare_ground_truth(self, sample):
+ """
+ Prepare the ground truth for evaluation based on the binary tumor mask
+
+ """
+ # load binary tumor masks
+ img_obj = self.image_reader.read(sample["tumor_mask"])
+ tumor_mask = self.image_reader.get_data(img_obj, level=sample["level"])[0][0]
+
+ # calculate pixel spacing at the mask level
+ mask_pixel_spacing = sample["pixel_spacing"] * pow(2, sample["level"])
+
+ # compute multi-instance mask from a binary mask
+ grow_pixel_threshold = self.grow_distance / (mask_pixel_spacing * 2)
+ tumor_mask = compute_multi_instance_mask(mask=tumor_mask, threshold=grow_pixel_threshold)
+
+ # identify isolated tumor cells
+ itc_threshold = (self.itc_diameter + self.grow_distance) / mask_pixel_spacing
+ itc_labels = compute_isolated_tumor_cells(tumor_mask=tumor_mask, threshold=itc_threshold)
+
+ return tumor_mask, itc_labels
+
+ def compute_fp_tp(self):
+ """
+ Compute false positive and true positive probabilities for tumor detection,
+ by comparing the model outputs with the prepared ground truths for all samples
+
+ """
+ total_fp_probs, total_tp_probs = [], []
+ total_num_targets = 0
+ num_images = len(self.data)
+
+ for sample in tqdm(self.data):
+ probs, y_coord, x_coord = self.prepare_inference_result(sample)
+ ground_truth, itc_labels = self.prepare_ground_truth(sample)
+            # compute FP and TP probabilities for a pair of an image and a ground truth mask
+ fp_probs, tp_probs, num_targets = compute_fp_tp_probs(
+ probs=probs,
+ y_coord=y_coord,
+ x_coord=x_coord,
+ evaluation_mask=ground_truth,
+ labels_to_exclude=itc_labels,
+ resolution_level=sample["level"],
+ )
+ total_fp_probs.extend(fp_probs)
+ total_tp_probs.extend(tp_probs)
+ total_num_targets += num_targets
+
+ return (
+ np.array(total_fp_probs),
+ np.array(total_tp_probs),
+ total_num_targets,
+ num_images,
+ )
+
+ def evaluate(self):
+ """
+ Evaluate the detection performance of a model based on the model probability map output,
+ the ground truth tumor mask, and their associated metadata (e.g., pixel_spacing, level)
+ """
+ # compute false positive (FP) and true positive (TP) probabilities for all images
+ fp_probs, tp_probs, num_targets, num_images = self.compute_fp_tp()
+
+ # compute FROC curve given the evaluation of all images
+ fps_per_image, total_sensitivity = compute_froc_curve_data(
+ fp_probs=fp_probs,
+ tp_probs=tp_probs,
+ num_targets=num_targets,
+ num_images=num_images,
+ )
+
+ # compute FROC score give specific evaluation threshold
+ froc_score = compute_froc_score(
+ fps_per_image=fps_per_image,
+ total_sensitivity=total_sensitivity,
+ eval_thresholds=self.eval_thresholds,
+ )
+
+ return froc_score
diff --git a/monai/apps/pathology/utils.py b/monai/apps/pathology/utils.py
new file mode 100644
index 0000000000..0d1f530bff
--- /dev/null
+++ b/monai/apps/pathology/utils.py
@@ -0,0 +1,85 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Union
+
+import numpy as np
+import torch
+
+from monai.transforms.post.array import ProbNMS
+from monai.utils import optional_import
+
+measure, _ = optional_import("skimage.measure")
+ndimage, _ = optional_import("scipy.ndimage")
+
+
+def compute_multi_instance_mask(mask: np.ndarray, threshold: float):
+ """
+ This method computes the segmentation mask according to the binary tumor mask.
+
+ Args:
+ mask: the binary mask array
+ threshold: the threshold to fill holes
+ """
+
+ neg = 255 - mask * 255
+ distance = ndimage.morphology.distance_transform_edt(neg)
+ binary = distance < threshold
+
+ filled_image = ndimage.morphology.binary_fill_holes(binary)
+ multi_instance_mask = measure.label(filled_image, connectivity=2)
+
+ return multi_instance_mask
+
+
+def compute_isolated_tumor_cells(tumor_mask: np.ndarray, threshold: float) -> List[int]:
+ """
+    This method identifies Isolated Tumor Cells (ITC) and returns their labels.
+
+ Args:
+ tumor_mask: the tumor mask.
+ threshold: the threshold (at the mask level) to define an isolated tumor cell (ITC).
+ A region with the longest diameter less than this threshold is considered as an ITC.
+ """
+ max_label = np.amax(tumor_mask)
+ properties = measure.regionprops(tumor_mask, coordinates="rc")
+ itc_list = []
+ for i in range(max_label): # type: ignore
+ if properties[i].major_axis_length < threshold:
+ itc_list.append(i + 1)
+
+ return itc_list
+
+
+class PathologyProbNMS(ProbNMS):
+ """
+    This class extends monai.utils.ProbNMS and adds the `resolution` option for
+ Pathology.
+ """
+
+ def __call__(
+ self,
+ probs_map: Union[np.ndarray, torch.Tensor],
+ resolution_level: int = 0,
+ ):
+ """
+ probs_map: the input probabilities map, it must have shape (H[, W, ...]).
+ resolution_level: the level at which the probabilities map is made.
+ """
+ resolution = pow(2, resolution_level)
+ org_outputs = ProbNMS.__call__(self, probs_map)
+ outputs = []
+ for org_output in org_outputs:
+ prob = org_output[0]
+ coord = np.asarray(org_output[1:])
+ coord_wsi = ((coord + 0.5) * resolution).astype(int)
+ outputs.append([prob] + list(coord_wsi))
+ return outputs
diff --git a/monai/apps/utils.py b/monai/apps/utils.py
index e2970b4a3d..36fac955fe 100644
--- a/monai/apps/utils.py
+++ b/monai/apps/utils.py
@@ -11,7 +11,9 @@
import hashlib
import os
+import shutil
import tarfile
+import tempfile
import warnings
import zipfile
from typing import TYPE_CHECKING, Optional
@@ -37,6 +39,53 @@
]
+def _basename(p):
+ """get the last part of the path (removing the trailing slash if it exists)"""
+ sep = os.path.sep + (os.path.altsep or "") + "/ "
+ return os.path.basename(p.rstrip(sep))
+
+
+def _download_with_progress(url, filepath, progress: bool = True):
+ """
+ Retrieve file from `url` to `filepath`, optionally showing a progress bar.
+ """
+ try:
+ if has_tqdm and progress:
+
+ class TqdmUpTo(tqdm):
+ """
+ Provides `update_to(n)` which uses `tqdm.update(delta_n)`.
+ Inspired by the example in https://github.com/tqdm/tqdm.
+ """
+
+ def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None):
+ """
+ Args:
+ b: number of blocks transferred so far, default: 1.
+ bsize: size of each block (in tqdm units), default: 1.
+ tsize: total size (in tqdm units). if None, remains unchanged.
+ """
+ if tsize is not None:
+ self.total = tsize
+ self.update(b * bsize - self.n) # will also set self.n = b * bsize
+
+ with TqdmUpTo(
+ unit="B",
+ unit_scale=True,
+ unit_divisor=1024,
+ miniters=1,
+ desc=_basename(filepath),
+ ) as t:
+ urlretrieve(url, filepath, reporthook=t.update_to)
+ else:
+ if not has_tqdm and progress:
+ warnings.warn("tqdm is not installed, will not show the downloading progress bar.")
+ urlretrieve(url, filepath)
+ except (URLError, HTTPError, ContentTooShortError, IOError) as e:
+ print(f"Download failed from {url} to {filepath}.")
+ raise e
+
+
def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") -> bool:
"""
Verify hash signature of specified file.
@@ -64,23 +113,26 @@ def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5")
print(f"Exception in check_hash: {e}")
return False
if val != actual_hash.hexdigest():
- print("check_hash failed.")
+ print(f"check_hash failed {actual_hash.hexdigest()}.")
return False
- print(f"Verified '{os.path.basename(filepath)}', {hash_type}: {val}.")
+ print(f"Verified '{_basename(filepath)}', {hash_type}: {val}.")
return True
-def download_url(url: str, filepath: str, hash_val: Optional[str] = None, hash_type: str = "md5") -> None:
+def download_url(
+ url: str, filepath: str = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True
+) -> None:
"""
Download file from specified URL link, support process bar and hash check.
Args:
url: source URL link to download file.
- filepath: target filepath to save the downloaded file.
+ filepath: target filepath to save the downloaded file. If undefined, `os.path.basename(url)` will be used.
hash_val: expected hash value to validate the downloaded file.
if None, skip hash validation.
hash_type: 'md5' or 'sha1', defaults to 'md5'.
+ progress: whether to display a progress bar.
Raises:
RuntimeError: When the hash validation of the ``filepath`` existing file fails.
@@ -93,64 +145,34 @@ def download_url(url: str, filepath: str, hash_val: Optional[str] = None, hash_t
RuntimeError: When the hash validation of the ``url`` downloaded file fails.
"""
+ if not filepath:
+ filepath = os.path.abspath(os.path.join(".", _basename(url)))
+ print(f"Default downloading to '{filepath}'")
if os.path.exists(filepath):
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}."
)
- print(f"file {filepath} exists, skip downloading.")
+ print(f"File exists: {filepath}, skipped downloading.")
return
- if url.startswith("https://drive.google.com"):
- if not has_gdown:
- raise RuntimeError("To download files from Google Drive, please install the gdown dependency.")
- os.makedirs(os.path.dirname(filepath), exist_ok=True)
- gdown.download(url, filepath, quiet=False)
- if not os.path.exists(filepath):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_name = os.path.join(tmp_dir, f"{_basename(filepath)}")
+ if url.startswith("https://drive.google.com"):
+ if not has_gdown:
+ raise RuntimeError("To download files from Google Drive, please install the gdown dependency.")
+ gdown.download(url, tmp_name, quiet=not progress)
+ else:
+ _download_with_progress(url, tmp_name, progress=progress)
+ if not os.path.exists(tmp_name):
raise RuntimeError(
f"Download of file from {url} to {filepath} failed due to network issue or denied permission."
)
- else:
- path = os.path.dirname(filepath)
- if path:
- os.makedirs(path, exist_ok=True)
- try:
- if has_tqdm:
-
- class TqdmUpTo(tqdm):
- """
- Provides `update_to(n)` which uses `tqdm.update(delta_n)`.
- Inspired by the example in https://github.com/tqdm/tqdm.
-
- """
-
- def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None):
- """
- b: number of blocks transferred so far, default: 1.
- bsize: size of each block (in tqdm units), default: 1.
- tsize: total size (in tqdm units). if None, remains unchanged.
-
- """
- if tsize is not None:
- self.total = tsize
- self.update(b * bsize - self.n) # will also set self.n = b * bsize
-
- with TqdmUpTo(
- unit="B",
- unit_scale=True,
- unit_divisor=1024,
- miniters=1,
- desc=filepath.split(os.sep)[-1],
- ) as t:
- urlretrieve(url, filepath, reporthook=t.update_to)
- else:
- warnings.warn("tqdm is not installed, will not show the downloading progress bar.")
- urlretrieve(url, filepath)
- print(f"\ndownloaded file: {filepath}.")
- except (URLError, HTTPError, ContentTooShortError, IOError) as e:
- print(f"download failed from {url} to {filepath}.")
- raise e
-
+ file_dir = os.path.dirname(filepath)
+ if file_dir:
+ os.makedirs(file_dir, exist_ok=True)
+ shutil.move(tmp_name, filepath) # copy the downloaded to a user-specified cache.
+ print(f"Downloaded: {filepath}")
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of downloaded file failed: URL={url}, "
@@ -158,7 +180,14 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None):
)
-def extractall(filepath: str, output_dir: str, hash_val: Optional[str] = None, hash_type: str = "md5") -> None:
+def extractall(
+ filepath: str,
+ output_dir: str = ".",
+ hash_val: Optional[str] = None,
+ hash_type: str = "md5",
+ file_type: str = "",
+ has_base: bool = True,
+) -> None:
"""
Extract file to the output directory.
Expected file types are: `zip`, `tar.gz` and `tar`.
@@ -169,48 +198,76 @@ def extractall(filepath: str, output_dir: str, hash_val: Optional[str] = None, h
hash_val: expected hash value to validate the compressed file.
if None, skip hash validation.
hash_type: 'md5' or 'sha1', defaults to 'md5'.
+ file_type: string of file type for decompressing. Leave it empty to infer the type from the filepath basename.
+ has_base: whether the extracted files have a base folder. This flag is used when checking if the existing
+ folder is a result of `extractall`, if it is, the extraction is skipped. For example, if A.zip is unzipped
+ to folder structure `A/*.png`, this flag should be True; if B.zip is unzipped to `*.png`, this flag should
+ be False.
Raises:
RuntimeError: When the hash validation of the ``filepath`` compressed file fails.
- ValueError: When the ``filepath`` file extension is not one of [zip", "tar.gz", "tar"].
+ NotImplementedError: When the ``filepath`` file extension is not one of ["zip", "tar.gz", "tar"].
"""
- target_file = os.path.join(output_dir, os.path.basename(filepath).split(".")[0])
- if os.path.exists(target_file):
- print(f"extracted file {target_file} exists, skip extracting.")
+ if has_base:
+ # the extracted files will be in this folder
+ cache_dir = os.path.join(output_dir, _basename(filepath).split(".")[0])
+ else:
+ cache_dir = output_dir
+ if os.path.exists(cache_dir) and len(os.listdir(cache_dir)) > 0:
+ print(f"Non-empty folder exists in {cache_dir}, skipped extracting.")
return
- if not check_hash(filepath, hash_val, hash_type):
+ if hash_val and not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}."
)
-
- if filepath.endswith("zip"):
+ print(f"Writing into directory: {output_dir}.")
+ _file_type = file_type.lower().strip()
+ if filepath.endswith("zip") or _file_type == "zip":
zip_file = zipfile.ZipFile(filepath)
zip_file.extractall(output_dir)
zip_file.close()
- elif filepath.endswith("tar") or filepath.endswith("tar.gz"):
+ return
+ if filepath.endswith("tar") or filepath.endswith("tar.gz") or "tar" in _file_type:
tar_file = tarfile.open(filepath)
tar_file.extractall(output_dir)
tar_file.close()
- else:
- raise ValueError('Unsupported file extension, available options are: ["zip", "tar.gz", "tar"].')
+ return
+ raise NotImplementedError(
+ f'Unsupported file type, available options are: ["zip", "tar.gz", "tar"]. name={filepath} type={file_type}.'
+ )
def download_and_extract(
- url: str, filepath: str, output_dir: str, hash_val: Optional[str] = None, hash_type: str = "md5"
+ url: str,
+ filepath: str = "",
+ output_dir: str = ".",
+ hash_val: Optional[str] = None,
+ hash_type: str = "md5",
+ file_type: str = "",
+ has_base: bool = True,
+ progress: bool = True,
) -> None:
"""
Download file from URL and extract it to the output directory.
Args:
url: source URL link to download file.
- filepath: the file path of compressed file.
+ filepath: the file path of the downloaded compressed file.
+ use this option to keep the directly downloaded compressed file, to avoid further repeated downloads.
output_dir: target directory to save extracted files.
- default is None to save in current directory.
+ default is the current directory.
hash_val: expected hash value to validate the downloaded file.
if None, skip hash validation.
hash_type: 'md5' or 'sha1', defaults to 'md5'.
-
+ file_type: string of file type for decompressing. Leave it empty to infer the type from url's base file name.
+ has_base: whether the extracted files have a base folder. This flag is used when checking if the existing
+ folder is a result of `extractall`, if it is, the extraction is skipped. For example, if A.zip is unzipped
+ to folder structure `A/*.png`, this flag should be True; if B.zip is unzipped to `*.png`, this flag should
+ be False.
+ progress: whether to display progress bar.
"""
- download_url(url=url, filepath=filepath, hash_val=hash_val, hash_type=hash_type)
- extractall(filepath=filepath, output_dir=output_dir, hash_val=hash_val, hash_type=hash_type)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ filename = filepath or os.path.join(tmp_dir, f"{_basename(url)}")
+ download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
+ extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
diff --git a/monai/config/__init__.py b/monai/config/__init__.py
index f1c7707d1f..baa4400467 100644
--- a/monai/config/__init__.py
+++ b/monai/config/__init__.py
@@ -11,6 +11,7 @@
from .deviceconfig import (
USE_COMPILED,
+ IgniteInfo,
get_gpu_info,
get_system_info,
print_config,
@@ -18,4 +19,4 @@
print_gpu_info,
print_system_info,
)
-from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayTensor
+from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayTensor, TensorOrList
diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py
index be77a1d975..273431fc72 100644
--- a/monai/config/deviceconfig.py
+++ b/monai/config/deviceconfig.py
@@ -19,15 +19,7 @@
import torch
import monai
-from monai.utils import OptionalImportError, get_package_version, optional_import
-
-try:
- import itk # type: ignore
-
- itk_version = itk.Version.GetITKVersion()
- del itk
-except (ImportError, AttributeError):
- itk_version = "NOT INSTALLED or UNKNOWN VERSION."
+from monai.utils.module import OptionalImportError, get_package_version, optional_import
try:
_, HAS_EXT = optional_import("monai._C")
@@ -46,6 +38,7 @@
"print_gpu_info",
"print_debug_info",
"USE_COMPILED",
+ "IgniteInfo",
]
@@ -75,10 +68,11 @@ def get_optional_config_values():
output["Tensorboard"] = get_package_version("tensorboard")
output["gdown"] = get_package_version("gdown")
output["TorchVision"] = get_package_version("torchvision")
- output["ITK"] = itk_version
output["tqdm"] = get_package_version("tqdm")
output["lmdb"] = get_package_version("lmdb")
output["psutil"] = psutil_version
+ output["pandas"] = get_package_version("pandas")
+ output["einops"] = get_package_version("einops")
return output
@@ -106,10 +100,6 @@ def print_config(file=sys.stdout):
)
-def set_visible_devices(*dev_inds):
- os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, dev_inds))
-
-
def _dict_append(in_dict, key, fn):
try:
in_dict[key] = fn() if callable(fn) else fn
@@ -131,7 +121,8 @@ def get_system_info() -> OrderedDict:
elif output["System"] == "Darwin":
_dict_append(output, "Mac version", lambda: platform.mac_ver()[0])
else:
- linux_ver = re.search(r'PRETTY_NAME="(.*)"', open("/etc/os-release", "r").read())
+ with open("/etc/os-release", "r") as rel_f:
+ linux_ver = re.search(r'PRETTY_NAME="(.*)"', rel_f.read())
if linux_ver:
_dict_append(output, "Linux version", lambda: linux_ver.group(1))
@@ -217,12 +208,6 @@ def get_gpu_info() -> OrderedDict:
_dict_append(output, f"GPU {gpu} Is multi GPU board", lambda: bool(gpu_info.is_multi_gpu_board))
_dict_append(output, f"GPU {gpu} Multi processor count", lambda: gpu_info.multi_processor_count)
_dict_append(output, f"GPU {gpu} Total memory (GB)", lambda: round(gpu_info.total_memory / 1024 ** 3, 1))
- _dict_append(
- output, f"GPU {gpu} Cached memory (GB)", lambda: round(torch.cuda.memory_reserved(gpu) / 1024 ** 3, 1)
- )
- _dict_append(
- output, f"GPU {gpu} Allocated memory (GB)", lambda: round(torch.cuda.memory_allocated(gpu) / 1024 ** 3, 1)
- )
_dict_append(output, f"GPU {gpu} CUDA capability (maj.min)", lambda: f"{gpu_info.major}.{gpu_info.minor}")
return output
@@ -260,5 +245,14 @@ def print_debug_info(file=sys.stdout) -> None:
print_gpu_info(file)
+class IgniteInfo:
+ """
+ Config information of the PyTorch ignite package.
+
+ """
+
+ OPT_IMPORT_VERSION = "0.4.4"
+
+
if __name__ == "__main__":
print_debug_info()
diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py
index daa9b10052..375ae460b2 100644
--- a/monai/config/type_definitions.py
+++ b/monai/config/type_definitions.py
@@ -9,12 +9,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Collection, Hashable, Iterable, TypeVar, Union
+from typing import Collection, Hashable, Iterable, Sequence, TypeVar, Union
import numpy as np
import torch
-__all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor"]
+__all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor", "TensorOrList"]
"""Commonly used concepts
This module provides naming and type specifications for commonly used concepts
@@ -55,6 +55,7 @@
container must be iterable.
"""
+
DtypeLike = Union[
np.dtype,
type,
@@ -67,3 +68,10 @@
# Generic type which can represent either a numpy.ndarray or a torch.Tensor
# Unlike Union can create a dependence between parameter(s) / return(s)
NdarrayTensor = TypeVar("NdarrayTensor", np.ndarray, torch.Tensor)
+
+
+TensorOrList = Union[torch.Tensor, Sequence[torch.Tensor]]
+"""TensorOrList
+
+The TensorOrList type is used for defining `batch-first Tensor` or `list of channel-first Tensor`.
+"""
diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp
index 2e0644bc78..b4bb0f2c04 100644
--- a/monai/csrc/ext.cpp
+++ b/monai/csrc/ext.cpp
@@ -29,14 +29,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// resample bound mode
py::enum_(m, "BoundType")
- .value("replicate", monai::BoundType::Replicate)
- .value("dct1", monai::BoundType::DCT1)
- .value("dct2", monai::BoundType::DCT2)
- .value("dst1", monai::BoundType::DST1)
- .value("dst2", monai::BoundType::DST2)
- .value("dft", monai::BoundType::DFT)
- .value("sliding", monai::BoundType::Sliding)
- .value("zero", monai::BoundType::Zero)
+ .value("replicate", monai::BoundType::Replicate, "a a a | a b c d | d d d")
+ .value("nearest", monai::BoundType::Replicate, "a a a | a b c d | d d d")
+ .value("dct1", monai::BoundType::DCT1, "d c b | a b c d | c b a")
+ .value("mirror", monai::BoundType::DCT1, "d c b | a b c d | c b a")
+ .value("dct2", monai::BoundType::DCT2, "c b a | a b c d | d c b")
+ .value("reflect", monai::BoundType::DCT2, "c b a | a b c d | d c b")
+ .value("dst1", monai::BoundType::DST1, "-b -a 0 | a b c d | 0 -d -c")
+ .value("antimirror", monai::BoundType::DST1, "-b -a 0 | a b c d | 0 -d -c")
+ .value("dst2", monai::BoundType::DST2, "-c -b -a | a b c d | -d -c -b")
+ .value("antireflect", monai::BoundType::DST2, "-c -b -a | a b c d | -d -c -b")
+ .value("dft", monai::BoundType::DFT, "b c d | a b c d | a b c")
+ .value("wrap", monai::BoundType::DFT, "b c d | a b c d | a b c")
+ // .value("sliding", monai::BoundType::Sliding)
+ .value("zero", monai::BoundType::Zero, "0 0 0 | a b c d | 0 0 0")
.export_values();
// resample interpolation mode
diff --git a/monai/csrc/filtering/bilateral/bilateral.cpp b/monai/csrc/filtering/bilateral/bilateral.cpp
new file mode 100644
index 0000000000..2720d312e2
--- /dev/null
+++ b/monai/csrc/filtering/bilateral/bilateral.cpp
@@ -0,0 +1,49 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include
+#include
+#include
+
+#include "bilateral.h"
+#include "utils/common_utils.h"
+
+torch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL) {
+ torch::Tensor (*filterFunction)(torch::Tensor, float, float);
+
+#ifdef WITH_CUDA
+
+ if (torch::cuda::is_available() && input.is_cuda()) {
+ CHECK_CONTIGUOUS_CUDA(input);
+
+ if (input.size(1) > BF_CUDA_MAX_CHANNELS) {
+ throw std::runtime_error(
+ "Bilateral filtering not implemented for channel count > " + std::to_string(BF_CUDA_MAX_CHANNELS));
+ }
+
+ if (input.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) {
+ throw std::runtime_error(
+ "Bilateral filtering not implemented for spatial dimension > " +
+ std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION));
+ }
+
+ filterFunction = usePHL ? &BilateralFilterPHLCuda : &BilateralFilterCuda;
+ } else {
+ filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;
+ }
+#else
+ filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;
+#endif
+
+ return filterFunction(input, spatial_sigma, color_sigma);
+}
diff --git a/monai/csrc/filtering/bilateral/bilateral.h b/monai/csrc/filtering/bilateral/bilateral.h
index 1c16373fa9..c7a68d7457 100644
--- a/monai/csrc/filtering/bilateral/bilateral.h
+++ b/monai/csrc/filtering/bilateral/bilateral.h
@@ -14,7 +14,9 @@ limitations under the License.
#pragma once
#include
-#include "utils/common_utils.h"
+
+#define BF_CUDA_MAX_CHANNELS 16
+#define BF_CUDA_MAX_SPATIAL_DIMENSION 3
torch::Tensor BilateralFilterCpu(torch::Tensor input, float spatial_sigma, float color_sigma);
torch::Tensor BilateralFilterPHLCpu(torch::Tensor input, float spatial_sigma, float color_sigma);
@@ -24,19 +26,4 @@ torch::Tensor BilateralFilterCuda(torch::Tensor input, float spatial_sigma, floa
torch::Tensor BilateralFilterPHLCuda(torch::Tensor input, float spatial_sigma, float color_sigma);
#endif
-torch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL) {
- torch::Tensor (*filterFunction)(torch::Tensor, float, float);
-
-#ifdef WITH_CUDA
- if (torch::cuda::is_available() && input.is_cuda()) {
- CHECK_CONTIGUOUS_CUDA(input);
- filterFunction = usePHL ? &BilateralFilterPHLCuda : &BilateralFilterCuda;
- } else {
- filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;
- }
-#else
- filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;
-#endif
-
- return filterFunction(input, spatial_sigma, color_sigma);
-}
+torch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL);
diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp
index 1fb48cb6c9..847a452396 100644
--- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp
+++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp
@@ -51,11 +51,11 @@ void BilateralFilterPHLCpu(
}
// Spatial features
- int offsetRemanider = i;
+ int offsetRemainder = i;
for (int d = 0; d < desc.dimensions; d++) {
- int coord = offsetRemanider / desc.strides[d];
- offsetRemanider -= coord * desc.strides[d];
+ int coord = offsetRemainder / desc.strides[d];
+ offsetRemainder -= coord * desc.strides[d];
features[i * featureChannels + desc.channelCount + d] = invSpatialSigma * coord;
}
diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu
index 4477ce5845..f73ae19ac9 100644
--- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu
+++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu
@@ -15,6 +15,7 @@ limitations under the License.
#include
#include
+#include "bilateral.h"
#include "utils/meta_macros.h"
#include "utils/tensor_description.h"
@@ -253,7 +254,7 @@ torch::Tensor BilateralFilterCuda(torch::Tensor inputTensor, float spatialSigma,
torch::Tensor outputTensor = torch::zeros_like(inputTensor);
#define CASE(c, d) BilateralFilterCuda(inputTensor, outputTensor, spatialSigma, colorSigma);
- SWITCH_AB(CASE, 16, 3, inputTensor.size(1), inputTensor.dim() - 2);
+ SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2);
return outputTensor;
}
diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu
index 603ab689cf..719d1643d3 100644
--- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu
+++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu
@@ -15,6 +15,7 @@ limitations under the License.
#include
#include
+#include "bilateral.h"
#include "filtering/permutohedral/permutohedral.h"
#include "utils/meta_macros.h"
#include "utils/tensor_description.h"
@@ -95,7 +96,7 @@ void BilateralFilterPHLCuda(
cudaMalloc(&data, desc.batchCount * desc.channelStride * desc.channelCount * sizeof(scalar_t));
cudaMalloc(&features, desc.batchCount * desc.channelStride * featureChannelCount * sizeof(scalar_t));
- // Prparing constant memory
+ // Preparing constant memory
cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int));
cudaMemcpyToSymbol(cChannelStride, &desc.channelStride, sizeof(int));
cudaMemcpyToSymbol(cSpatialStrides, desc.strides, sizeof(int) * desc.dimensions);
@@ -135,7 +136,7 @@ torch::Tensor BilateralFilterPHLCuda(torch::Tensor inputTensor, float spatialSig
inputTensor, outputTensor, spatialSigma, colorSigma); \
}));
- SWITCH_AB(CASE, 16, 3, inputTensor.size(1), inputTensor.dim() - 2);
+ SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2);
return outputTensor;
}
diff --git a/monai/csrc/filtering/permutohedral/hash_table.cuh b/monai/csrc/filtering/permutohedral/hash_table.cuh
index 7d9d7eb163..f9893dffe2 100644
--- a/monai/csrc/filtering/permutohedral/hash_table.cuh
+++ b/monai/csrc/filtering/permutohedral/hash_table.cuh
@@ -15,7 +15,7 @@ limitations under the License.
//#define USE_ADDITIVE_HASH
-// turn this on if you want to get slighly less memory consumption and slightly longer run times.
+// turn this on if you want to get slightly less memory consumption and slightly longer run times.
//#define LINEAR_D_MEMORY
#define USE_CUSTOM_MODULO
diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp
index 5d6916b8f4..d8fd3eaaeb 100644
--- a/monai/csrc/filtering/permutohedral/permutohedral.cpp
+++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp
@@ -1,3 +1,19 @@
+/*
+Copyright 2020 - 2021 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include
+#include
+
#include "utils/common_utils.h"
#include "utils/meta_macros.h"
@@ -33,6 +49,16 @@ torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) {
if (torch::cuda::is_available() && data.is_cuda()) {
CHECK_CONTIGUOUS_CUDA(data);
+ if (channelCount > PHL_CUDA_MAX_CHANNELS) {
+ throw std::runtime_error(
+ "PHL filtering not implemented for channel count > " + std::to_string(PHL_CUDA_MAX_CHANNELS));
+ }
+
+ if (featureCount > PHL_CUDA_MAX_FEATURES) {
+ throw std::runtime_error(
+ "PHL filtering not implemented for feature count > " + std::to_string(PHL_CUDA_MAX_FEATURES));
+ }
+
#define CASE(dc, fc) \
AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCuda", ([&] { \
for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { \
@@ -42,7 +68,7 @@ torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) {
PermutohedralCuda(offsetData, offsetFeatures, elementCount, true); \
} \
}));
- SWITCH_AB(CASE, 16, 19, channelCount, featureCount);
+ SWITCH_AB(CASE, PHL_CUDA_MAX_CHANNELS, PHL_CUDA_MAX_FEATURES, channelCount, featureCount);
} else {
#endif
diff --git a/monai/csrc/filtering/permutohedral/permutohedral.h b/monai/csrc/filtering/permutohedral/permutohedral.h
index 27b0ff4859..32ffee83e5 100644
--- a/monai/csrc/filtering/permutohedral/permutohedral.h
+++ b/monai/csrc/filtering/permutohedral/permutohedral.h
@@ -11,9 +11,13 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
+#pragma once
+
#include
-#pragma once
+#define PHL_CUDA_MAX_CHANNELS 16
+#define PHL_CUDA_MAX_FEATURES 19
+
template
void PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount);
#ifdef WITH_CUDA
diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu
index b87a88a84f..d1d78eb940 100644
--- a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu
+++ b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu
@@ -38,7 +38,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
-#define BLOCK_SIZE 64
+#define BLOCK_SIZE 32
#include
#include
@@ -47,6 +47,7 @@ SOFTWARE.
#include
#include "hash_table.cuh"
+#include "permutohedral.h"
#include "utils/meta_macros.h"
template
@@ -529,6 +530,8 @@ void PermutohedralCuda(scalar_t* values, scalar_t* positions, int elementCount,
destroyHashTable();
cudaFree(table_values);
+ cudaFree(scaleFactor);
+ cudaFree(matrix);
}
#define DECLARATION(dc, fc) \
diff --git a/monai/csrc/resample/pushpull.h b/monai/csrc/resample/pushpull.h
index 45fd5ce564..1c20cc0114 100644
--- a/monai/csrc/resample/pushpull.h
+++ b/monai/csrc/resample/pushpull.h
@@ -69,8 +69,8 @@ at::Tensor grid_pull(
CHECK_STRIDED(grid_opt)
CHECK_SAME_DEVICE(input_opt, grid_opt)
CHECK_SAME_DTYPE(input_opt, grid_opt)
- CHECK_SPATIAL_2D_OR_3D(input)
- CHECK_SPATIAL_2D_OR_3D(grid)
+ CHECK_SPATIAL_1D_2D_OR_3D(input)
+ CHECK_SPATIAL_1D_2D_OR_3D(grid)
CHECK_GRID_COMPONENT(grid, grid.dim())
CHECK_SPATIAL_NOT_EMPTY(input)
CHECK_SPATIAL_NOT_EMPTY(grid)
@@ -165,8 +165,8 @@ at::Tensor grid_push(
CHECK_STRIDED(grid_opt)
CHECK_SAME_DEVICE(input_opt, grid_opt)
CHECK_SAME_DTYPE(input_opt, grid_opt)
- CHECK_SPATIAL_2D_OR_3D(input)
- CHECK_SPATIAL_2D_OR_3D(grid)
+ CHECK_SPATIAL_1D_2D_OR_3D(input)
+ CHECK_SPATIAL_1D_2D_OR_3D(grid)
CHECK_GRID_COMPONENT(grid, grid.dim())
CHECK_SPATIAL_NOT_EMPTY(input)
CHECK_SPATIAL_NOT_EMPTY(grid)
@@ -175,7 +175,10 @@ at::Tensor grid_push(
CHECK_VEC_NOT_EMPTY(interpolation_mode);
if (source_size.empty()) {
- auto size = c10::IntArrayRef({input.size(2), input.size(3), input.dim() == 5 ? input.size(4) : 1});
+ auto size = c10::IntArrayRef(
+ {input.dim() >= 3 ? input.size(2) : 1,
+ input.dim() >= 4 ? input.size(3) : 1,
+ input.dim() >= 5 ? input.size(4) : 1});
if (input.is_cuda())
#ifdef WITH_CUDA
return cuda::pushpull(
@@ -295,14 +298,15 @@ at::Tensor grid_count(
CHECK_DEFINED(grid)
auto grid_opt = grid.options();
CHECK_STRIDED(grid_opt)
- CHECK_SPATIAL_2D_OR_3D(grid)
+ CHECK_SPATIAL_1D_2D_OR_3D(grid)
CHECK_GRID_COMPONENT(grid, grid.dim())
CHECK_SPATIAL_NOT_EMPTY(grid)
CHECK_VEC_NOT_EMPTY(bound_mode);
CHECK_VEC_NOT_EMPTY(interpolation_mode);
if (source_size.empty()) {
- auto size = c10::IntArrayRef({grid.size(1), grid.size(2), grid.dim() == 5 ? grid.size(3) : 1});
+ auto size = c10::IntArrayRef(
+ {grid.dim() >= 3 ? grid.size(2) : 1, grid.dim() >= 4 ? grid.size(3) : 1, grid.dim() >= 5 ? grid.size(4) : 1});
if (grid.is_cuda())
#ifdef WITH_CUDA
return cuda::pushpull(
@@ -422,8 +426,8 @@ at::Tensor grid_grad(
CHECK_STRIDED(grid_opt)
CHECK_SAME_DEVICE(input_opt, grid_opt)
CHECK_SAME_DTYPE(input_opt, grid_opt)
- CHECK_SPATIAL_2D_OR_3D(input)
- CHECK_SPATIAL_2D_OR_3D(grid)
+ CHECK_SPATIAL_1D_2D_OR_3D(input)
+ CHECK_SPATIAL_1D_2D_OR_3D(grid)
CHECK_GRID_COMPONENT(grid, grid.dim())
CHECK_SPATIAL_NOT_EMPTY(input)
CHECK_SPATIAL_NOT_EMPTY(grid)
diff --git a/monai/csrc/resample/pushpull_cpu.cpp b/monai/csrc/resample/pushpull_cpu.cpp
index 40743a6cf1..dd10dd76ee 100644
--- a/monai/csrc/resample/pushpull_cpu.cpp
+++ b/monai/csrc/resample/pushpull_cpu.cpp
@@ -18,13 +18,14 @@ limitations under the License.
// It handles boundary conditions and interpolation orders defined in
// `utils/resample_utils.h` and `utils/resample_utils.h`.
// These parameters can be specified per dimension.
-// Isotorpic 0-th and 1-st order interpolation have their own (faster)
+// Isotropic 0-th and 1-st order interpolation have their own (faster)
// implementations. Sliding boundary conditions are also implemented
// separately.
// TODO:
// . [DONE] generic 3d
// . [DONE] generic 2d
+// . [DONE] generic 1d
// . sliding nearest 3d
// . sliding nearest 2d
// . sliding linear 3d
@@ -37,6 +38,7 @@ limitations under the License.
// . input bound/inter are always vectors -> clean unused constructors
#include
+#include
#include
#include "bounds_common.h"
#include "interpolation_common.h"
@@ -44,7 +46,7 @@ limitations under the License.
//#include
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// CPU/GPU -specific parameters
+// CPU-specific parameters
#include
namespace {
// This parameter specifies the minimum number of voxels that should be
@@ -74,18 +76,27 @@ MONAI_NAMESPACE_DEVICE { // cpu
namespace { // anonymous namespace > everything inside has internal linkage
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // GENERIC PUSHPULL CLASS
+ // INDEXING UTILS
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // This class implements the bulk of the code.
- // /!\ No type and shape checking is performed here.
- template
- class PushPullImpl {
+ // This class reads and sets all the parameters that will later be used
+ // by the algorithm in PushPullImpl. All of this is done outside of the
+ // implementation class so that we do not depend on generic types. The
+ // point is to pre-allocate all necessary tensors so that we can check
+ // if they're all compatible with 32 bit math. If it's the case, we can
+ // dispatch to a 32b cuda implementation, which might increase
+ // performance. Else, we use 64 bit math to compute offsets.
+ // (On CPU, we always use 64 bit offsets because it doesn't make a huge
+ // difference. It would be different if we had a vectorized
+ // implementation as in PyTorch).
+ class PushPullAllocator {
public:
+ static constexpr int64_t max_int32 = std::numeric_limits::max();
+
// ~~~ CONSTRUCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
MONAI_HOST
- PushPullImpl(
+ PushPullAllocator(
int dim,
BoundVectorRef bound,
InterpolationVectorRef interpolation,
@@ -125,101 +136,418 @@ MONAI_NAMESPACE_DEVICE { // cpu
iso = interpolation0 == interpolation1 && interpolation0 == interpolation2;
}
- MONAI_HOST
- PushPullImpl(
- int dim,
- BoundType bound,
- InterpolationVectorRef interpolation,
- bool extrapolate,
- bool do_pull,
- bool do_push,
- bool do_count,
- bool do_grad,
- bool do_sgrad)
- : dim(dim),
- bound0(bound),
- bound1(bound),
- bound2(bound),
- interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
- interpolation1(
- interpolation.size() > 1 ? interpolation[1]
- : interpolation.size() > 0 ? interpolation[0]
- : InterpolationType::Linear),
- interpolation2(
- interpolation.size() > 2 ? interpolation[2]
- : interpolation.size() > 1 ? interpolation[1]
- : interpolation.size() > 0 ? interpolation[0]
- : InterpolationType::Linear),
- extrapolate(extrapolate),
- do_pull(do_pull),
- do_push(do_push),
- do_count(do_count),
- do_grad(do_grad),
- do_sgrad(do_sgrad) {
- iso = interpolation0 == interpolation1 && interpolation0 == interpolation2;
+ // ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ // Usually used for pull:
+ // - do_pull -> return source[grid]
+ // - do_push -> fails
+ // - do_grad -> return J(source)[grid]
+ // - do_sgrad -> return H(source)[grid]
+ MONAI_HOST void ioset(const Tensor& source, const Tensor& grid) {
+ init_all();
+ init_source(source);
+ init_grid(grid);
+ init_output();
}
- MONAI_HOST
- PushPullImpl(
- int dim,
- BoundVectorRef bound,
- InterpolationType interpolation,
- bool extrapolate,
- bool do_pull,
- bool do_push,
- bool do_count,
- bool do_grad,
- bool do_sgrad)
- : dim(dim),
- bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate),
- bound1(
- bound.size() > 1 ? bound[1]
- : bound.size() > 0 ? bound[0]
- : BoundType::Replicate),
- bound2(
- bound.size() > 2 ? bound[2]
- : bound.size() > 1 ? bound[1]
- : bound.size() > 0 ? bound[0]
- : BoundType::Replicate),
- interpolation0(interpolation),
- interpolation1(interpolation),
- interpolation2(interpolation),
- extrapolate(extrapolate),
- do_pull(do_pull),
- do_push(do_push),
- do_count(do_count),
- do_grad(do_grad),
- do_sgrad(do_sgrad) {
- iso = interpolation0 == interpolation1 && interpolation0 == interpolation2;
+ // Usually used for pull_backward:
+ // - do_pull -> return source[grid]
+ // - do_push -> return push(target, grid, source.shape)
+ // - do_grad -> return J(source)[grid]
+ // - do_sgrad -> return H(source)[grid]
+ MONAI_HOST void ioset(const Tensor& source, const Tensor& grid, const Tensor& target) {
+ init_all();
+ init_source(source);
+ init_grid(grid);
+ init_target(target);
+ init_output();
}
- MONAI_HOST
- PushPullImpl(
- int dim,
- BoundType bound,
- InterpolationType interpolation,
- bool extrapolate,
- bool do_pull,
- bool do_push,
- bool do_count,
- bool do_grad,
- bool do_sgrad)
- : dim(dim),
- bound0(bound),
- bound1(bound),
- bound2(bound),
- interpolation0(interpolation),
- interpolation1(interpolation),
- interpolation2(interpolation),
- extrapolate(extrapolate),
- do_pull(do_pull),
- do_push(do_push),
- do_count(do_count),
- do_grad(do_grad),
- do_sgrad(do_sgrad) {
- iso = interpolation0 == interpolation1 && interpolation0 == interpolation2;
+ // Usually used for push:
+ // - do_pull -> fails
+ // - do_push -> return push(target, grid, source_size)
+ // - do_grad -> fails
+ // - do_sgrad -> fails
+ MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid, const Tensor& target) {
+ init_all();
+ init_source(source_size);
+ init_grid(grid);
+ init_target(target);
+ init_output();
}
+ // Usually used for count:
+ // - do_pull -> fails
+ // - do_push -> return push(ones, grid, source_size)
+ // - do_grad -> fails
+ // - do_sgrad -> fails
+ MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid) {
+ init_all();
+ init_source(source_size);
+ init_grid(grid);
+ init_output();
+ }
+
+ // We just check that all tensors that we own are compatible with 32b math
+ bool canUse32BitIndexMath(int64_t max_elem = max_int32) const {
+ return src_32b_ok && trgt_32b_ok && grid_32b_ok && grad_32b_ok && out_32b_ok;
+ }
+
+ private:
+ // Copied from aten/src/ATen/native/IndexingUtils.cpp in PyTorch 1.6.
+ // It is used to decide to which pointer type we should dispatch to.
+ // Basically, we need to make sure that the "furthest" element we need
+ // to reach is less than max_elem away.
+ static bool tensorCanUse32BitIndexMath(const Tensor& t, int64_t max_elem = max_int32) {
+ int64_t elements = t.numel();
+ if (elements >= max_elem) {
+ return false;
+ }
+ if (elements == 0) {
+ return max_elem > 0;
+ }
+
+ int64_t offset = 0;
+ int64_t linearId = elements - 1;
+
+ // NOTE: Assumes all strides are positive, which is true for now
+ for (int i = t.dim() - 1; i >= 0; --i) {
+ int64_t curDimIndex = linearId % t.size(i);
+ int64_t curDimOffset = curDimIndex * t.stride(i);
+ offset += curDimOffset;
+ linearId /= t.size(i);
+ }
+
+ if (offset >= max_elem) {
+ return false;
+ }
+
+ return true;
+ }
+
+ // ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ MONAI_HOST void init_all();
+ MONAI_HOST void init_source(const Tensor& source);
+ MONAI_HOST void init_source(IntArrayRef source_size);
+ MONAI_HOST void init_grid(const Tensor& grid);
+ MONAI_HOST void init_target(const Tensor& target);
+ MONAI_HOST void init_output();
+
+ // ~~~ OPTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ int dim; // dimensionality (2 or 3)
+ BoundType bound0; // boundary condition // x|W
+ BoundType bound1; // boundary condition // y|H
+ BoundType bound2; // boundary condition // z|D
+ InterpolationType interpolation0; // interpolation order // x|W
+ InterpolationType interpolation1; // interpolation order // y|H
+ InterpolationType interpolation2; // interpolation order // z|D
+ bool iso; // isotropic interpolation?
+ bool extrapolate; // compute out-of-bound values
+ bool do_pull; // sample a volume
+ bool do_push; // splat a volume
+ bool do_count; // splatting weights (= jacobian determinant)
+ bool do_grad; // backprop: gradient of grid // pull
+ bool do_sgrad; // sample spatial gradients
+
+ // ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ std::deque<Tensor> output;
+ TensorOptions src_opt;
+ TensorOptions grid_opt;
+ TensorOptions trgt_opt;
+ int64_t N;
+ int64_t C;
+ int64_t src_X;
+ int64_t src_Y;
+ int64_t src_Z;
+ int64_t trgt_X;
+ int64_t trgt_Y;
+ int64_t trgt_Z;
+ int64_t trgt_K;
+ int64_t src_sN;
+ int64_t src_sC;
+ int64_t src_sX;
+ int64_t src_sY;
+ int64_t src_sZ;
+ bool src_32b_ok;
+ void* src_ptr;
+ int64_t trgt_sN;
+ int64_t trgt_sC;
+ int64_t trgt_sX;
+ int64_t trgt_sY;
+ int64_t trgt_sZ;
+ int64_t trgt_sK;
+ bool trgt_32b_ok;
+ void* trgt_ptr;
+ int64_t grid_sN;
+ int64_t grid_sC;
+ int64_t grid_sX;
+ int64_t grid_sY;
+ int64_t grid_sZ;
+ bool grid_32b_ok;
+ void* grid_ptr;
+ int64_t out_sN;
+ int64_t out_sC;
+ int64_t out_sX;
+ int64_t out_sY;
+ int64_t out_sZ;
+ int64_t out_sK; // gradient dimension
+ bool out_32b_ok;
+ void* out_ptr;
+ int64_t grad_sN;
+ int64_t grad_sC;
+ int64_t grad_sX;
+ int64_t grad_sY;
+ int64_t grad_sZ;
+ bool grad_32b_ok;
+ void* grad_ptr;
+
+ // Allow PushPullImpl's constructor to access PushPullAllocator's
+ // private members.
+ template <typename scalar_t, typename offset_t>
+ friend class PushPullImpl;
+ };
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // INITIALISATION
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ MONAI_HOST
+ void PushPullAllocator::init_all() {
+ src_opt = grid_opt = trgt_opt = TensorOptions();
+ N = C = 1L;
+ src_X = src_Y = src_Z = 1L;
+ trgt_X = trgt_Y = trgt_Z = 1L;
+ trgt_K = 0L;
+ src_sN = src_sC = src_sX = src_sY = src_sZ = 0L;
+ grid_sN = grid_sC = grid_sX = grid_sY = grid_sZ = 0L;
+ grad_sN = grad_sC = grad_sX = grad_sY = grad_sZ = 0L;
+ trgt_sN = trgt_sC = trgt_sX = trgt_sY = trgt_sZ = trgt_sK = 0L;
+ out_sN = out_sC = out_sX = out_sY = out_sZ = out_sK = 0L;
+ src_ptr = trgt_ptr = grid_ptr = out_ptr = grad_ptr = static_cast<void*>(0);
+ src_32b_ok = trgt_32b_ok = grid_32b_ok = out_32b_ok = grad_32b_ok = true;
+ }
+
+ MONAI_HOST
+ void PushPullAllocator::init_source(const Tensor& source) {
+ N = source.size(0);
+ C = source.size(1);
+ src_X = source.size(2);
+ src_Y = dim < 2 ? 1L : source.size(3);
+ src_Z = dim < 3 ? 1L : source.size(4);
+ src_sN = source.stride(0);
+ src_sC = source.stride(1);
+ src_sX = source.stride(2);
+ src_sY = dim < 2 ? 0L : source.stride(3);
+ src_sZ = dim < 3 ? 0L : source.stride(4);
+ src_ptr = source.data_ptr();
+ src_opt = source.options();
+ src_32b_ok = tensorCanUse32BitIndexMath(source);
+ }
+
+ MONAI_HOST
+ void PushPullAllocator::init_source(IntArrayRef source_size) {
+ src_X = source_size[0];
+ src_Y = dim < 2 ? 1L : source_size[1];
+ src_Z = dim < 3 ? 1L : source_size[2];
+ }
+
+ MONAI_HOST
+ void PushPullAllocator::init_grid(const Tensor& grid) {
+ N = grid.size(0);
+ trgt_X = grid.size(1);
+ trgt_Y = dim < 2 ? 1L : grid.size(2);
+ trgt_Z = dim < 3 ? 1L : grid.size(3);
+ grid_sN = grid.stride(0);
+ grid_sX = grid.stride(1);
+ grid_sY = dim < 2 ? 0L : grid.stride(2);
+ grid_sZ = dim < 3 ? 0L : grid.stride(3);
+ grid_sC = grid.stride(dim == 1 ? 2 : dim == 2 ? 3 : 4);
+ grid_ptr = grid.data_ptr();
+ grid_opt = grid.options();
+ grid_32b_ok = tensorCanUse32BitIndexMath(grid);
+ }
+
+ MONAI_HOST
+ void PushPullAllocator::init_target(const Tensor& target) {
+ N = target.size(0);
+ C = target.size(1);
+ trgt_X = target.size(2);
+ trgt_Y = dim < 2 ? 1L : target.size(3);
+ trgt_Z = dim < 3 ? 1L : target.size(4);
+ trgt_K = target.dim() == dim + 3 ? target.size(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L;
+ trgt_sN = target.stride(0);
+ trgt_sC = target.stride(1);
+ trgt_sX = target.stride(2);
+ trgt_sY = dim < 2 ? 0L : target.stride(3);
+ trgt_sZ = dim < 3 ? 0L : target.stride(4);
+ trgt_sK = target.dim() == dim + 3 ? target.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L;
+ trgt_ptr = target.data_ptr();
+ trgt_opt = target.options();
+ trgt_32b_ok = tensorCanUse32BitIndexMath(target);
+ }
+
+ MONAI_HOST
+ void PushPullAllocator::init_output() {
+ output.clear();
+ if (do_pull) {
+ if (dim == 1)
+ output.push_back(at::empty({N, C, trgt_X}, src_opt));
+ else if (dim == 2)
+ output.push_back(at::empty({N, C, trgt_X, trgt_Y}, src_opt));
+ else
+ output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z}, src_opt));
+ auto pull = output.back();
+ out_sN = pull.stride(0);
+ out_sC = pull.stride(1);
+ out_sX = pull.stride(2);
+ out_sY = dim < 2 ? 0L : pull.stride(3);
+ out_sZ = dim < 3 ? 0L : pull.stride(4);
+ out_sK = 0L;
+ out_ptr = pull.data_ptr();
+ out_32b_ok = tensorCanUse32BitIndexMath(pull);
+ } else if (do_sgrad) {
+ if (dim == 1)
+ output.push_back(at::empty({N, C, trgt_X, 1}, src_opt));
+ else if (dim == 2)
+ output.push_back(at::empty({N, C, trgt_X, trgt_Y, 2}, src_opt));
+ else
+ output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z, 3}, src_opt));
+ auto sgrad = output.back();
+ out_sN = sgrad.stride(0);
+ out_sC = sgrad.stride(1);
+ out_sX = sgrad.stride(2);
+ out_sY = dim < 2 ? 0L : sgrad.stride(3);
+ out_sZ = dim < 3 ? 0L : sgrad.stride(4);
+ out_sK = sgrad.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5);
+ out_ptr = sgrad.data_ptr();
+ out_32b_ok = tensorCanUse32BitIndexMath(sgrad);
+
+ if (iso && interpolation0 == InterpolationType::Nearest)
+ sgrad.zero_();
+ if (iso && interpolation0 == InterpolationType::Linear && dim == 1)
+ sgrad.zero_();
+ } else if (do_push) {
+ if (dim == 1)
+ output.push_back(at::zeros({N, C, src_X}, trgt_opt));
+ else if (dim == 2)
+ output.push_back(at::zeros({N, C, src_X, src_Y}, trgt_opt));
+ else
+ output.push_back(at::zeros({N, C, src_X, src_Y, src_Z}, trgt_opt));
+ auto push = output.back();
+ out_sN = push.stride(0);
+ out_sC = push.stride(1);
+ out_sX = push.stride(2);
+ out_sY = dim < 2 ? 0L : push.stride(3);
+ out_sZ = dim < 3 ? 0L : push.stride(4);
+ out_sK = 0L;
+ out_ptr = push.data_ptr();
+ out_32b_ok = tensorCanUse32BitIndexMath(push);
+ } else if (do_count) {
+ if (dim == 1)
+ output.push_back(at::zeros({N, 1, src_X}, grid_opt));
+ else if (dim == 2)
+ output.push_back(at::zeros({N, 1, src_X, src_Y}, grid_opt));
+ else
+ output.push_back(at::zeros({N, 1, src_X, src_Y, src_Z}, grid_opt));
+ auto count = output.back();
+ out_sN = count.stride(0);
+ out_sC = count.stride(1);
+ out_sX = count.stride(2);
+ out_sY = dim < 2 ? 0L : count.stride(3);
+ out_sZ = dim < 3 ? 0L : count.stride(4);
+ out_sK = 0L;
+ out_ptr = count.data_ptr();
+ out_32b_ok = tensorCanUse32BitIndexMath(count);
+ }
+ if (do_grad) {
+ if (dim == 1)
+ output.push_back(at::zeros({N, trgt_X, 1}, grid_opt));
+ else if (dim == 2)
+ output.push_back(at::zeros({N, trgt_X, trgt_Y, 2}, grid_opt));
+ else
+ output.push_back(at::zeros({N, trgt_X, trgt_Y, trgt_Z, 3}, grid_opt));
+ auto grad = output.back();
+ grad_sN = grad.stride(0);
+ grad_sX = grad.stride(1);
+ grad_sY = dim < 2 ? 0L : grad.stride(2);
+ grad_sZ = dim < 3 ? 0L : grad.stride(3);
+ grad_sC = grad.stride(dim == 1 ? 2 : dim == 2 ? 3 : 4);
+ grad_ptr = grad.data_ptr();
+ out_32b_ok = tensorCanUse32BitIndexMath(grad);
+
+ if (iso && interpolation0 == InterpolationType::Nearest)
+ grad.zero_();
+ }
+ }
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // GENERIC PUSHPULL CLASS
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // This class implements the bulk of the code.
+ // /!\ No type and shape checking is performed here.
+
+ template <typename scalar_t, typename offset_t>
+ class PushPullImpl {
+ public:
+ // ~~~ CONSTRUCTOR ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ PushPullImpl(const PushPullAllocator& info)
+ : output(info.output),
+ dim(info.dim),
+ bound0(info.bound0),
+ bound1(info.bound1),
+ bound2(info.bound2),
+ interpolation0(info.interpolation0),
+ interpolation1(info.interpolation1),
+ interpolation2(info.interpolation1),
+ iso(info.iso),
+ extrapolate(info.extrapolate),
+ do_pull(info.do_pull),
+ do_push(info.do_push),
+ do_count(info.do_count),
+ do_grad(info.do_grad),
+ do_sgrad(info.do_sgrad),
+ N(static_cast<offset_t>(info.N)),
+ C(static_cast<offset_t>(info.C)),
+ src_X(static_cast<offset_t>(info.src_X)),
+ src_Y(static_cast<offset_t>(info.src_Y)),
+ src_Z(static_cast<offset_t>(info.src_Z)),
+ trgt_X(static_cast<offset_t>(info.trgt_X)),
+ trgt_Y(static_cast<offset_t>(info.trgt_Y)),
+ trgt_Z(static_cast<offset_t>(info.trgt_Z)),
+ trgt_K(static_cast<offset_t>(info.trgt_K)),
+ src_sN(static_cast<offset_t>(info.src_sN)),
+ src_sC(static_cast<offset_t>(info.src_sC)),
+ src_sX(static_cast<offset_t>(info.src_sX)),
+ src_sY(static_cast<offset_t>(info.src_sY)),
+ src_sZ(static_cast<offset_t>(info.src_sZ)),
+ src_ptr(static_cast<scalar_t*>(info.src_ptr)),
+ trgt_sN(static_cast<offset_t>(info.trgt_sN)),
+ trgt_sC(static_cast<offset_t>(info.trgt_sC)),
+ trgt_sX(static_cast<offset_t>(info.trgt_sX)),
+ trgt_sY(static_cast<offset_t>(info.trgt_sY)),
+ trgt_sZ(static_cast<offset_t>(info.trgt_sZ)),
+ trgt_sK(static_cast<offset_t>(info.trgt_sK)),
+ trgt_ptr(static_cast<scalar_t*>(info.trgt_ptr)),
+ grid_sN(static_cast<offset_t>(info.grid_sN)),
+ grid_sC(static_cast<offset_t>(info.grid_sC)),
+ grid_sX(static_cast<offset_t>(info.grid_sX)),
+ grid_sY(static_cast<offset_t>(info.grid_sY)),
+ grid_sZ(static_cast<offset_t>(info.grid_sZ)),
+ grid_ptr(static_cast<scalar_t*>(info.grid_ptr)),
+ out_sN(static_cast<offset_t>(info.out_sN)),
+ out_sC(static_cast<offset_t>(info.out_sC)),
+ out_sX(static_cast<offset_t>(info.out_sX)),
+ out_sY(static_cast<offset_t>(info.out_sY)),
+ out_sZ(static_cast<offset_t>(info.out_sZ)),
+ out_sK(static_cast<offset_t>(info.out_sK)),
+ out_ptr(static_cast<scalar_t*>(info.out_ptr)),
+ grad_sN(static_cast<offset_t>(info.grad_sN)),
+ grad_sC(static_cast<offset_t>(info.grad_sC)),
+ grad_sX(static_cast<offset_t>(info.grad_sX)),
+ grad_sY(static_cast<offset_t>(info.grad_sY)),
+ grad_sZ(static_cast<offset_t>(info.grad_sZ)),
+ grad_ptr(static_cast<scalar_t*>(info.grad_ptr)) {}
+
// ~~~ PUBLIC VALUE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
std::deque<Tensor> output;
@@ -247,39 +575,8 @@ MONAI_NAMESPACE_DEVICE { // cpu
// }
// ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- MONAI_HOST void ioset // Pull
- (const Tensor& source, const Tensor& grid) {
- init_all();
- init_source(source);
- init_grid(grid);
- init_output();
- }
-
- MONAI_HOST void ioset(const Tensor& source, const Tensor& grid, const Tensor& target) {
- init_all();
- init_source(source);
- init_grid(grid);
- init_target(target);
- init_output();
- }
-
- MONAI_HOST void ioset // Push
- (IntArrayRef source_size, const Tensor& grid, const Tensor& target) {
- init_all();
- init_source(source_size);
- init_grid(grid);
- init_target(target);
- init_output();
- }
-
- MONAI_HOST void ioset // Count
- (IntArrayRef source_size, const Tensor& grid) {
- init_all();
- init_source(source_size);
- init_grid(grid);
- init_output();
- }
+ // Loop over all voxels
void loop() const;
MONAI_HOST MONAI_DEVICE int64_t voxcount() const {
@@ -288,14 +585,18 @@ MONAI_NAMESPACE_DEVICE { // cpu
private:
// ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- MONAI_HOST void init_all();
- MONAI_HOST void init_source(const Tensor& source);
- MONAI_HOST void init_source(IntArrayRef source_size);
- MONAI_HOST void init_grid(const Tensor& grid);
- MONAI_HOST void init_target(const Tensor& target);
- MONAI_HOST void init_output();
+ MONAI_DEVICE void check1d(offset_t w, offset_t n) const;
MONAI_DEVICE void check2d(offset_t w, offset_t h, offset_t n) const;
MONAI_DEVICE void check3d(offset_t w, offset_t h, offset_t d, offset_t n) const;
+ MONAI_DEVICE void interpolate1d(scalar_t x, offset_t w, offset_t n) const;
+ MONAI_DEVICE void interpolate1d_nearest(scalar_t x, offset_t w, offset_t n) const;
+ MONAI_DEVICE void interpolate1d_linear(scalar_t x, offset_t w, offset_t n) const;
+ MONAI_DEVICE void interpolate1d_sliding(scalar_t x, offset_t w, offset_t n) const { /*TODO*/
+ }
+ MONAI_DEVICE void interpolate1d_sliding_nearest(scalar_t x, offset_t w, offset_t n) const { /*TODO*/
+ }
+ MONAI_DEVICE void interpolate1d_sliding_linear(scalar_t x, offset_t w, offset_t n) const { /*TODO*/
+ }
MONAI_DEVICE void interpolate2d(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const;
MONAI_DEVICE void interpolate2d_nearest(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const;
MONAI_DEVICE void interpolate2d_bilinear(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const;
@@ -370,9 +671,6 @@ MONAI_NAMESPACE_DEVICE { // cpu
bool do_sgrad; // sample spatial gradients
// ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- TensorOptions src_opt;
- TensorOptions grid_opt;
- TensorOptions trgt_opt;
offset_t N;
offset_t C;
offset_t src_X;
@@ -402,174 +700,24 @@ MONAI_NAMESPACE_DEVICE { // cpu
offset_t grid_sZ;
scalar_t* grid_ptr;
offset_t out_sN;
- offset_t out_sC;
- offset_t out_sX;
- offset_t out_sY;
- offset_t out_sZ;
- offset_t out_sK; // gradient dimension
- scalar_t* out_ptr;
- offset_t grad_sN;
- offset_t grad_sC;
- offset_t grad_sX;
- offset_t grad_sY;
- offset_t grad_sZ;
- scalar_t* grad_ptr;
- };
-
- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // INITIALISATION
- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
- template
- void PushPullImpl::init_all() {
- src_opt = grid_opt = trgt_opt = TensorOptions();
- N = C = static_cast(1);
- src_X = src_Y = src_Z = static_cast(1);
- trgt_X = trgt_Y = trgt_Z = trgt_K = static_cast(1);
- src_sN = src_sC = src_sX = src_sY = src_sZ = static_cast(0);
- grid_sN = grid_sC = grid_sX = grid_sY = grid_sZ = static_cast(0);
- grad_sN = grad_sC = grad_sX = grad_sY = grad_sZ = static_cast(0);
- trgt_sN = trgt_sC = trgt_sX = trgt_sY = trgt_sZ = trgt_sK = static_cast(0);
- out_sN = out_sC = out_sX = out_sY = out_sZ = out_sK = static_cast(0);
- src_ptr = trgt_ptr = grid_ptr = out_ptr = grad_ptr = static_cast(0);
- }
-
- template
- MONAI_HOST void PushPullImpl::init_source(const Tensor& source) {
- N = source.size(0);
- C = source.size(1);
- src_X = source.size(2);
- src_Y = source.size(3);
- src_Z = dim == 2 ? static_cast(1) : source.size(4);
- src_sN = source.stride(0);
- src_sC = source.stride(1);
- src_sX = source.stride(2);
- src_sY = source.stride(3);
- src_sZ = dim == 2 ? static_cast(0) : source.stride(4);
- src_ptr = source.data_ptr();
- src_opt = source.options();
- }
-
- template
- MONAI_HOST void PushPullImpl::init_source(IntArrayRef source_size) {
- src_X = source_size[0];
- src_Y = source_size[1];
- src_Z = dim == 2 ? static_cast(1) : source_size[2];
- }
-
- template
- MONAI_HOST void PushPullImpl::init_grid(const Tensor& grid) {
- N = grid.size(0);
- trgt_X = grid.size(1);
- trgt_Y = grid.size(2);
- trgt_Z = dim == 2 ? static_cast(1) : grid.size(3);
- grid_sN = grid.stride(0);
- grid_sX = grid.stride(1);
- grid_sY = grid.stride(2);
- grid_sZ = dim == 2 ? static_cast(0) : grid.stride(3);
- grid_sC = grid.stride(dim == 2 ? 3 : 4);
- grid_ptr = grid.data_ptr();
- grid_opt = grid.options();
- }
-
- template
- MONAI_HOST void PushPullImpl::init_target(const Tensor& target) {
- N = target.size(0);
- C = target.size(1);
- trgt_X = target.size(2);
- trgt_Y = target.size(3);
- trgt_Z = dim == 2 ? static_cast(1) : target.size(4);
- trgt_K = target.dim() == dim + 3 ? target.size(dim == 2 ? 4 : 5) : static_cast(1);
- trgt_sN = target.stride(0);
- trgt_sC = target.stride(1);
- trgt_sX = target.stride(2);
- trgt_sY = target.stride(3);
- trgt_sZ = dim == 2 ? static_cast(0) : target.stride(4);
- trgt_sK = target.dim() == dim + 3 ? target.stride(dim == 2 ? 4 : 5) : static_cast(0);
- trgt_ptr = target.data_ptr();
- trgt_opt = target.options();
- }
-
- template
- MONAI_HOST void PushPullImpl::init_output() {
- output.clear();
- if (do_pull) {
- if (dim == 2)
- output.push_back(at::empty({N, C, trgt_X, trgt_Y}, src_opt));
- else
- output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z}, src_opt));
- auto pull = output.back();
- out_sN = pull.stride(0);
- out_sC = pull.stride(1);
- out_sX = pull.stride(2);
- out_sY = pull.stride(3);
- out_sZ = dim == 2 ? static_cast(0) : pull.stride(4);
- out_sK = static_cast(0);
- out_ptr = pull.template data_ptr();
- } else if (do_sgrad) {
- if (dim == 2)
- output.push_back(at::empty({N, C, trgt_X, trgt_Y, 2}, src_opt));
- else
- output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z, 3}, src_opt));
- auto sgrad = output.back();
- out_sN = sgrad.stride(0);
- out_sC = sgrad.stride(1);
- out_sX = sgrad.stride(2);
- out_sY = sgrad.stride(3);
- out_sZ = dim == 2 ? static_cast(0) : sgrad.stride(4);
- out_sK = sgrad.stride(dim == 2 ? 4 : 5);
- out_ptr = sgrad.template data_ptr();
-
- if (iso && interpolation0 == InterpolationType::Nearest)
- sgrad.zero_();
- } else if (do_push) {
- if (dim == 2)
- output.push_back(at::zeros({N, C, src_X, src_Y}, trgt_opt));
- else
- output.push_back(at::zeros({N, C, src_X, src_Y, src_Z}, trgt_opt));
- auto push = output.back();
- out_sN = push.stride(0);
- out_sC = push.stride(1);
- out_sX = push.stride(2);
- out_sY = push.stride(3);
- out_sZ = dim == 2 ? static_cast(0) : push.stride(4);
- out_sK = static_cast(0);
- out_ptr = push.template data_ptr();
- } else if (do_count) {
- if (dim == 2)
- output.push_back(at::zeros({N, 1, src_X, src_Y}, grid_opt));
- else
- output.push_back(at::zeros({N, 1, src_X, src_Y, src_Z}, grid_opt));
- auto count = output.back();
- out_sN = count.stride(0);
- out_sC = count.stride(1);
- out_sX = count.stride(2);
- out_sY = count.stride(3);
- out_sZ = dim == 2 ? static_cast(0) : count.stride(4);
- out_sK = static_cast(0);
- out_ptr = count.template data_ptr();
- }
- if (do_grad) {
- if (dim == 2)
- output.push_back(at::zeros({N, src_X, src_Y, 2}, grid_opt));
- else
- output.push_back(at::zeros({N, src_X, src_Y, src_Z, 3}, grid_opt));
- auto grad = output.back();
- grad_sN = grad.stride(0);
- grad_sX = grad.stride(1);
- grad_sY = grad.stride(2);
- grad_sZ = dim == 2 ? static_cast(0) : grad.stride(3);
- grad_sC = grad.stride(dim == 2 ? 3 : 4);
- grad_ptr = grad.template data_ptr();
-
- if (iso && interpolation0 == InterpolationType::Nearest)
- grad.zero_();
- }
- }
+ offset_t out_sC;
+ offset_t out_sX;
+ offset_t out_sY;
+ offset_t out_sZ;
+ offset_t out_sK; // gradient dimension
+ scalar_t* out_ptr;
+ offset_t grad_sN;
+ offset_t grad_sC;
+ offset_t grad_sX;
+ offset_t grad_sY;
+ offset_t grad_sZ;
+ scalar_t* grad_ptr;
+ };
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// LOOP
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
// This bit loops over all target voxels. We therefore need to
// convert linear indices to multivariate indices. The way I do it
// might not be optimal.
@@ -586,7 +734,10 @@ MONAI_NAMESPACE_DEVICE { // cpu
// parallelize across voxels.
at::parallel_for(0, N, 0, [&](offset_t start, offset_t end) {
for (offset_t n = start; n < end; ++n) {
- if (dim == 2) {
+ if (dim == 1) {
+ for (offset_t w = 0; w < trgt_X; ++w)
+ check1d(w, n);
+ } else if (dim == 2) {
for (offset_t h = 0; h < trgt_Y; ++h)
for (offset_t w = 0; w < trgt_X; ++w)
check2d(w, h, n);
@@ -600,8 +751,8 @@ MONAI_NAMESPACE_DEVICE { // cpu
});
return;
}
-#endif
+#endif
// Parallelize across voxels
offset_t trgt_NXYZ = trgt_Z * trgt_Y * trgt_X * N;
offset_t trgt_XYZ = trgt_Z * trgt_Y * trgt_X;
@@ -615,7 +766,9 @@ MONAI_NAMESPACE_DEVICE { // cpu
h = (i / trgt_Z) % trgt_Y;
d = i % trgt_Z;
- if (dim == 2)
+ if (dim == 1)
+ check1d(w, n);
+ else if (dim == 2)
check2d(w, h, n);
else
check3d(w, h, d, n);
@@ -631,6 +784,59 @@ MONAI_NAMESPACE_DEVICE { // cpu
// 1) read the [x,y,z] source coordinate for the current target voxel
// 3) check if the source coordinate is in bounds
+ template <typename scalar_t, typename offset_t>
+ MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::check3d(offset_t w, offset_t h, offset_t d, offset_t n) const {
+ // get the corresponding input x, y, z co-ordinates from grid
+ scalar_t* grid_ptr_NXYZ = grid_ptr + n * grid_sN + w * grid_sX + h * grid_sY + d * grid_sZ;
+ scalar_t x = *grid_ptr_NXYZ;
+ scalar_t y = grid_ptr_NXYZ[grid_sC];
+ scalar_t z = grid_ptr_NXYZ[grid_sC * 2];
+
+ // Check if out-of-bound
+ if (!(extrapolate ||
+ (inbounds(x, src_X, static_cast<scalar_t>(TINY)) && inbounds(y, src_Y, static_cast<scalar_t>(TINY)) &&
+ inbounds(z, src_Z, static_cast<scalar_t>(TINY))))) {
+ if (do_pull || do_sgrad) {
+ scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;
+ for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC) {
+ *out_ptr_NCXYZ = static_cast<scalar_t>(0);
+ if (do_sgrad) {
+ out_ptr_NCXYZ[out_sK] = static_cast<scalar_t>(0);
+ out_ptr_NCXYZ[out_sK * 2] = static_cast<scalar_t>(0);
+ }
+ }
+ }
+ if (do_grad) {
+ scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ;
+ (*grad_ptr_NXYZ) = static_cast<scalar_t>(0);
+ grad_ptr_NXYZ[grad_sC] = static_cast<scalar_t>(0);
+ grad_ptr_NXYZ[grad_sC * 2] = static_cast<scalar_t>(0);
+ }
+ return;
+ }
+
+ // Next step
+ if (bound0 == BoundType::Sliding) {
+ if (iso)
+ switch (static_cast<int>(interpolation0)) {
+ case 0:
+ return interpolate3d_sliding_nearest(x, y, z, w, h, d, n);
+ case 1:
+ return interpolate3d_sliding_trilinear(x, y, z, w, h, d, n);
+ }
+ return interpolate3d_sliding(x, y, z, w, h, d, n);
+ } else {
+ if (iso)
+ switch (static_cast<int>(interpolation0)) {
+ case 0:
+ return interpolate3d_nearest(x, y, z, w, h, d, n);
+ case 1:
+ return interpolate3d_trilinear(x, y, z, w, h, d, n);
+ }
+ return interpolate3d(x, y, z, w, h, d, n);
+ }
+ }
+
template <typename scalar_t, typename offset_t>
MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::check2d(offset_t w, offset_t h, offset_t n) const {
// get the corresponding input x, y, z co-ordinates from grid
@@ -642,7 +848,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
if (!(extrapolate ||
(inbounds(x, src_X, static_cast(TINY)) && inbounds(y, src_Y, static_cast(TINY))))) {
if (do_pull || do_sgrad) {
- scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sZ + h * out_sY;
+ scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY;
for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC) {
*out_ptr_NCXY = static_cast(0);
if (do_sgrad)
@@ -680,32 +886,25 @@ MONAI_NAMESPACE_DEVICE { // cpu
}
template <typename scalar_t, typename offset_t>
- MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::check3d(offset_t w, offset_t h, offset_t d, offset_t n) const {
+ MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::check1d(offset_t w, offset_t n) const {
// get the corresponding input x, y, z co-ordinates from grid
- scalar_t* grid_ptr_NXYZ = grid_ptr + n * grid_sN + w * grid_sX + h * grid_sY + d * grid_sZ;
- scalar_t x = *grid_ptr_NXYZ;
- scalar_t y = grid_ptr_NXYZ[grid_sC];
- scalar_t z = grid_ptr_NXYZ[grid_sC * 2];
+ scalar_t* grid_ptr_NX = grid_ptr + n * grid_sN + w * grid_sX;
+ scalar_t x = *grid_ptr_NX;
// Check if out-of-bound
- if (!(extrapolate ||
- (inbounds(x, src_X, static_cast(TINY)) && inbounds(y, src_Y, static_cast(TINY)) &&
- inbounds(z, src_Z, static_cast(TINY))))) {
+ if (!(extrapolate || inbounds(x, src_X, static_cast<scalar_t>(TINY)))) {
if (do_pull || do_sgrad) {
- scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;
- for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC) {
- *out_ptr_NCXYZ = static_cast(0);
- if (do_sgrad) {
- out_ptr_NCXYZ[out_sK] = static_cast(0);
- out_ptr_NCXYZ[out_sK * 2] = static_cast(0);
- }
+ scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;
+ for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC) {
+ *out_ptr_NCX = static_cast(0);
+ if (do_sgrad)
+ out_ptr_NCX[out_sK] = static_cast(0);
}
}
if (do_grad) {
- scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ;
- (*grad_ptr_NXYZ) = static_cast(0);
- grad_ptr_NXYZ[grad_sC] = static_cast(0);
- grad_ptr_NXYZ[grad_sC * 2] = static_cast(0);
+ scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX;
+ (*grad_ptr_NX) = static_cast(0);
+ grad_ptr_NX[grad_sC] = static_cast(0);
}
return;
}
@@ -715,20 +914,20 @@ MONAI_NAMESPACE_DEVICE { // cpu
if (iso)
switch (static_cast(interpolation0)) {
case 0:
- return interpolate3d_sliding_nearest(x, y, z, w, h, d, n);
+ return interpolate1d_sliding_nearest(x, w, n);
case 1:
- return interpolate3d_sliding_trilinear(x, y, z, w, h, d, n);
+ return interpolate1d_sliding_linear(x, w, n);
}
- return interpolate3d_sliding(x, y, z, w, h, d, n);
+ return interpolate1d_sliding(x, w, n);
} else {
if (iso)
switch (static_cast(interpolation0)) {
case 0:
- return interpolate3d_nearest(x, y, z, w, h, d, n);
+ return interpolate1d_nearest(x, w, n);
case 1:
- return interpolate3d_trilinear(x, y, z, w, h, d, n);
+ return interpolate1d_linear(x, w, n);
}
- return interpolate3d(x, y, z, w, h, d, n);
+ return interpolate1d(x, w, n);
}
}
@@ -763,7 +962,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
if (trgt_ptr && (do_push || do_grad))
for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC) {
target[c] = *trgt_ptr_NCXYZ;
- if (trgt_K > 1) {
+ if (trgt_K > 0) {
target[c + C] = trgt_ptr_NCXYZ[trgt_sK];
target[c + C * 2] = trgt_ptr_NCXYZ[trgt_sK * 2];
}
@@ -881,7 +1080,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
else if (do_push) {
- if (trgt_K == 1) {
+ if (trgt_K == 0) {
// Diff w.r.t. push/pull
scalar_t* out_ptr_NC = out_ptr_NC0;
for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)
@@ -904,7 +1103,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if (do_grad) {
- if (trgt_K == 1) {
+ if (trgt_K == 0) {
// Diff w.r.t. pull/push
scalar_t* src_ptr_NC = src_ptr_NC0;
scalar_t dot = static_cast(0);
@@ -973,7 +1172,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
if (trgt_ptr && (do_push || do_grad))
for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC) {
target[c] = *trgt_ptr_NCXY;
- if (trgt_K > 1) {
+ if (trgt_K > 0) {
target[c + C] = trgt_ptr_NCXY[trgt_sK];
}
}
@@ -1066,7 +1265,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
else if (do_push) {
- if (trgt_K == 1) {
+ if (trgt_K == 0) {
// Diff w.r.t. push/pull
scalar_t* out_ptr_NC = out_ptr_NC0;
for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)
@@ -1088,7 +1287,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if (do_grad) {
- if (trgt_K == 1) {
+ if (trgt_K == 0) {
// Diff w.r.t. pull/push
scalar_t* src_ptr_NC = src_ptr_NC0;
scalar_t dot = static_cast(0);
@@ -1125,6 +1324,150 @@ MONAI_NAMESPACE_DEVICE { // cpu
}
}
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // GENERIC INTERPOLATION 1D
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ template <typename scalar_t, typename offset_t>
+ MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate1d(scalar_t x, offset_t w, offset_t n) const {
+ // Get corner pixel values from (x, y)
+ offset_t bx0, bx1;
+ interpolation::bounds(interpolation0, x, bx0, bx1);
+ offset_t dbx = bx1 - bx0;
+
+ // Pre-compute offsets and target value
+ scalar_t* src_ptr_NC0 = src_ptr + n * src_sN;
+ scalar_t* out_ptr_NC0 = out_ptr + n * out_sN;
+ scalar_t* out_ptr_NCX0 = out_ptr + n * out_sN + w * out_sX;
+ scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;
+ scalar_t target[2 * MONAI_MAX_NUM_CHANNELS];
+ if (trgt_ptr && (do_push || do_grad))
+ for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC) {
+ target[c] = *trgt_ptr_NCX;
+ if (trgt_K > 0) {
+ target[c + C] = trgt_ptr_NCX[trgt_sK];
+ }
+ }
+
+ // Initialize output
+ scalar_t* out_ptr_NCX = out_ptr_NCX0;
+ if (do_pull || do_sgrad) {
+ for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC) {
+ *out_ptr_NCX = static_cast(0);
+ if (do_sgrad) {
+ out_ptr_NCX[out_sK] = static_cast(0);
+ }
+ }
+ }
+
+ // Pre-compute indices/weights/grad
+ scalar_t wx[8]; // B-spline weights
+ scalar_t gx[8]; // B-spline derivatives
+ scalar_t hx[8]; // B-spline 2nd derivatives
+ offset_t ix[8]; // Warped indices
+ uint8_t sx[8]; // Warped indices
+
+ {
+ scalar_t *owx = static_cast<scalar_t*>(wx), *ogx = static_cast<scalar_t*>(gx), *ohx = static_cast<scalar_t*>(hx);
+ offset_t* oix = static_cast<offset_t*>(ix);
+ uint8_t* osx = static_cast<uint8_t*>(sx);
+ for (offset_t bx = bx0; bx <= bx1; ++bx) {
+ scalar_t dx = x - bx;
+ *(owx++) = interpolation::fastweight(interpolation0, dx);
+ if (do_grad || do_sgrad)
+ *(ogx++) = interpolation::fastgrad(interpolation0, dx);
+ if (do_grad && trgt_sK > 1)
+ *(ohx++) = interpolation::fasthess(interpolation0, dx);
+ *(osx++) = bound::sign(bound0, bx, src_X);
+ *(oix++) = bound::index(bound0, bx, src_X);
+ }
+ }
+
+ // Convolve coefficients with basis functions
+ scalar_t ogx;
+ ogx = static_cast(0);
+ for (offset_t i = 0; i <= dbx; ++i) {
+ offset_t oox = ix[i] * out_sX;
+ offset_t osx = ix[i] * src_sX;
+ uint8_t sxx = sx[i];
+ scalar_t wxx = wx[i];
+ scalar_t gxx = gx[i];
+ scalar_t hxx = hx[i];
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ if (do_pull) {
+ scalar_t* src_ptr_NC = src_ptr_NC0;
+ scalar_t* out_ptr_NCX = out_ptr_NCX0;
+ for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC)
+ *out_ptr_NCX += bound::get(src_ptr_NC, osx, sxx) * wxx;
+ }
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ else if (do_sgrad) {
+ scalar_t* src_ptr_NC = src_ptr_NC0;
+ scalar_t* out_ptr_NCX = out_ptr_NCX0;
+ for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) {
+ scalar_t src = bound::get(src_ptr_NC, osx, sxx);
+ *out_ptr_NCX += src * gxx;
+ }
+ }
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ else if (do_push) {
+ if (trgt_K == 0) {
+ // Diff w.r.t. push/pull
+ scalar_t* out_ptr_NC = out_ptr_NC0;
+ for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)
+ bound::add(out_ptr_NC, oox, wxx * target[c], sxx);
+ } else {
+ // Diff w.r.t. sgrad
+ scalar_t* out_ptr_NC = out_ptr_NC0;
+ for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) {
+ scalar_t val = gxx * target[c];
+ bound::add(out_ptr_NC, oox, val, sxx);
+ }
+ }
+ }
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ else if (do_count) {
+ bound::add(out_ptr_NC0, oox, wxx, sxx);
+ }
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ if (do_grad) {
+ if (trgt_K == 0) {
+ // Diff w.r.t. pull/push
+ scalar_t* src_ptr_NC = src_ptr_NC0;
+ scalar_t dot = static_cast(0);
+ for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {
+ scalar_t src = bound::get(src_ptr_NC, osx, sxx);
+ dot += (trgt_ptr ? src * target[c] : src);
+ // trgt_ptr == 0 in the backward pass of 'count'
+ }
+ ogx += gxx * dot;
+ } else {
+ // Diff w.r.t. sgrad
+ scalar_t* src_ptr_NC = src_ptr_NC0;
+ scalar_t dot;
+ dot = static_cast(0);
+ for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {
+ scalar_t src = bound::get(src_ptr_NC, osx, sxx);
+ dot += src * target[c];
+ }
+ ogx += hxx * dot;
+ }
+ }
+
+ } // x
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ if (do_grad) {
+ scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX;
+ (*grad_ptr_NX) = ogx;
+ }
+ }
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// LINEAR INTERPOLATION 3D
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1214,7 +1557,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;
scalar_t* src_ptr_NC = src_ptr + n * src_sN;
- if (trgt_K == 1) {
+ if (trgt_K == 0) {
// backward w.r.t. push/pull
for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, src_ptr_NC += src_sC) {
scalar_t src;
@@ -1376,7 +1719,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ;
scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;
scalar_t* out_ptr_NC = out_ptr + n * out_sN;
- if (trgt_K == 1) {
+ if (trgt_K == 0) {
// Diff w.r.t. push/pull
for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, out_ptr_NC += out_sC) {
scalar_t trgt = *trgt_ptr_NCXYZ;
@@ -1461,7 +1804,6 @@ MONAI_NAMESPACE_DEVICE { // cpu
scalar_t w10 = dx1 * dy0;
scalar_t w01 = dx0 * dy1;
scalar_t w11 = dx1 * dy1;
- ;
// Sign (/!\ compute sign before warping indices)
int8_t sx1 = bound::sign(bound0, ix0 + 1, src_X);
@@ -1500,7 +1842,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;
scalar_t* src_ptr_NC = src_ptr + n * src_sN;
- if (trgt_K == 1) {
+ if (trgt_K == 0) {
// backward w.r.t. push/pull
for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, src_ptr_NC += src_sC) {
scalar_t src;
@@ -1547,9 +1889,9 @@ MONAI_NAMESPACE_DEVICE { // cpu
}
}
- scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY;
- (*grad_ptr_NXYZ) = gx;
- grad_ptr_NXYZ[grad_sC] = gy;
+ scalar_t* grad_ptr_NXY = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY;
+ (*grad_ptr_NXY) = gx;
+ grad_ptr_NXY[grad_sC] = gy;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if (do_pull) {
@@ -1591,7 +1933,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
o11 = ix1 * out_sX + iy1 * out_sY;
scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;
scalar_t* out_ptr_NC = out_ptr + n * out_sN;
- if (trgt_K == 1) {
+ if (trgt_K == 0) {
// Diff w.r.t. push/pull
for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, out_ptr_NC += out_sC) {
scalar_t trgt = *trgt_ptr_NCXY;
@@ -1632,6 +1974,123 @@ MONAI_NAMESPACE_DEVICE { // cpu
}
}
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // LINEAR INTERPOLATION 1D
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ template
+ MONAI_DEVICE void PushPullImpl::interpolate1d_linear(scalar_t x, offset_t w, offset_t n) const {
+ // Get corner pixel values from (x)
+ offset_t ix0 = static_cast(std::floor(x));
+
+ // Interpolation weights (inversely proportional to distance)
+ scalar_t w1 = x - ix0;
+ scalar_t w0 = 1. - w1;
+
+ // Sign (/!\ compute sign before warping indices)
+ int8_t s1 = bound::sign(bound0, ix0 + 1, src_X);
+ int8_t s0 = bound::sign(bound0, ix0, src_X);
+
+ // Warp indices
+ offset_t ix1;
+ ix1 = bound::index(bound0, ix0 + 1, src_X);
+ ix0 = bound::index(bound0, ix0, src_X);
+
+ // Offsets into source volume
+ offset_t o0, o1;
+ if (do_pull || do_grad || do_sgrad) {
+ o0 = ix0 * src_sX;
+ o1 = ix1 * src_sX;
+ }
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~
+ if (do_grad) {
+ if (trgt_K == 0) {
+ // backward w.r.t. push/pull
+
+ o0 = ix0 * src_sX;
+ o1 = ix1 * src_sX;
+ scalar_t gx = static_cast(0);
+ scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;
+ scalar_t* src_ptr_NC = src_ptr + n * src_sN;
+
+ for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, src_ptr_NC += src_sC) {
+ scalar_t src;
+ scalar_t trgt = trgt_ptr ? *trgt_ptr_NCX : static_cast(1);
+ // ^ trgt_ptr == 0 during the backward pass of count
+ src = bound::get(src_ptr_NC, o0, s0);
+ if (trgt_ptr)
+ src *= trgt;
+ gx -= src;
+ src = bound::get(src_ptr_NC, o1, s1);
+ if (trgt_ptr)
+ src *= trgt;
+ gx += src;
+ }
+
+ scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX;
+ (*grad_ptr_NX) = gx;
+ } else {
+ // backward w.r.t. sgrad
+ // -> zero (make sure this is done at initialization)
+ }
+ }
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ if (do_pull) {
+ o0 = ix0 * src_sX;
+ o1 = ix1 * src_sX;
+ scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;
+ scalar_t* src_ptr_NC = src_ptr + n * src_sN;
+ for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) {
+ *out_ptr_NCX = bound::get(src_ptr_NC, o0, s0) * w0 + bound::get(src_ptr_NC, o1, s1) * w1;
+ }
+ }
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ else if (do_sgrad) {
+ o0 = ix0 * src_sX;
+ o1 = ix1 * src_sX;
+ scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;
+ scalar_t* src_ptr_NC = src_ptr + n * src_sN;
+
+ for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) {
+ *out_ptr_NCX = bound::get(src_ptr_NC, o1, s1) - bound::get(src_ptr_NC, o0, s0);
+ }
+ }
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ else if (do_push) {
+ // Offsets into 'push' volume
+ o0 = ix0 * out_sX;
+ o1 = ix1 * out_sX;
+ scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;
+ scalar_t* out_ptr_NC = out_ptr + n * out_sN;
+ if (trgt_K == 0) {
+ // Diff w.r.t. push/pull
+ for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) {
+ scalar_t trgt = *trgt_ptr_NCX;
+ bound::add(out_ptr_NC, o0, w0 * trgt, s0);
+ bound::add(out_ptr_NC, o1, w1 * trgt, s1);
+ }
+ } else {
+ // Diff w.r.t. sgrad
+ for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) {
+ scalar_t trgt0 = *trgt_ptr_NCX;
+ bound::add(out_ptr_NC, o0, -trgt0, s0);
+ bound::add(out_ptr_NC, o1, trgt0, s1);
+ }
+ }
+ }
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ else if (do_count) {
+ // Offsets into 'push' volume
+ o0 = ix0 * out_sX;
+ o1 = ix1 * out_sX;
+
+ scalar_t* out_ptr_N = out_ptr + n * out_sN;
+ bound::add(out_ptr_N, o0, w0, s0);
+ bound::add(out_ptr_N, o1, w1, s1);
+ }
+ }
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// NEAREST NEIGHBOR INTERPOLATION 3D
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1666,7 +2125,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
scalar_t* src_ptr_NC = src_ptr + n * src_sN;
for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC)
*out_ptr_NCXYZ = bound::get(src_ptr_NC, o, s);
- } else if (do_push && trgt_K == 1) {
+ } else if (do_push && trgt_K == 0) {
offset_t o = iz * out_sZ + iy * out_sY + ix * out_sX;
scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;
scalar_t* out_ptr_NC = out_ptr + n * out_sN;
@@ -1709,7 +2168,7 @@ MONAI_NAMESPACE_DEVICE { // cpu
scalar_t* src_ptr_NC = src_ptr + n * src_sN;
for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC)
*out_ptr_NCXY = bound::get(src_ptr_NC, o, s);
- } else if (do_push && trgt_K == 1) {
+ } else if (do_push && trgt_K == 0) {
offset_t o = iy * out_sY + ix * out_sX;
scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;
scalar_t* out_ptr_NC = out_ptr + n * out_sN;
@@ -1722,10 +2181,48 @@ MONAI_NAMESPACE_DEVICE { // cpu
bound::add(out_ptr_NC, o, static_cast(1), s);
}
}
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // NEAREST NEIGHBOR INTERPOLATION 1D
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ template
+ MONAI_DEVICE void PushPullImpl::interpolate1d_nearest(scalar_t x, offset_t w, offset_t n) const {
+ offset_t i = static_cast(std::round(x));
+
+ // Boundary condition (/!\ compute sign before warping indices)
+ int8_t s = bound::sign(bound0, i, src_X);
+ i = bound::index(bound0, i, src_X);
+
+ if (do_pull) {
+ offset_t o = i * src_sX;
+ scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;
+ scalar_t* src_ptr_NC = src_ptr + n * src_sN;
+ for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC)
+ *out_ptr_NCX = bound::get(src_ptr_NC, o, s);
+ } else if (do_push && trgt_K == 0) {
+ offset_t o = i * out_sX;
+ scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;
+ scalar_t* out_ptr_NC = out_ptr + n * out_sN;
+ for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC)
+ bound::add(out_ptr_NC, o, *trgt_ptr_NCX, s);
+ } else if (do_count) {
+ offset_t o = i * out_sX;
+ scalar_t* out_ptr_NC = out_ptr + n * out_sN;
+ for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)
+ bound::add(out_ptr_NC, o, static_cast(1), s);
+ }
+ }
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// LINEAR INTERPOLATION 3D + SLIDING BOUNDARY
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// TODO
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // CUDA KERNEL (MUST BE OUT OF CLASS)
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
} // namespace
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1757,8 +2254,6 @@ MONAI_NAMESPACE_DEVICE { // cpu
PUSHPULL_INSTANTIATE1(BoundType); \
PUSHPULL_INSTANTIATE1(BoundVectorRef)
- // ~~~ CPU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
// Two arguments (source, grid)
// > `bound` and `interpolation` can be single arguments or vectors.
template
@@ -1773,12 +2268,14 @@ MONAI_NAMESPACE_DEVICE { // cpu
bool do_count,
bool do_grad,
bool do_sgrad) {
+ PushPullAllocator info(
+ grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad);
+ info.ioset(source, grid);
+
return AT_DISPATCH_FLOATING_TYPES(grid.scalar_type(), "pushpull", [&] {
- PushPullImpl f(
- grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad);
- f.ioset(source, grid);
- f.loop();
- return f.output;
+ PushPullImpl algo(info);
+ algo.loop();
+ return algo.output;
});
}
@@ -1798,17 +2295,18 @@ MONAI_NAMESPACE_DEVICE { // cpu
bool do_count,
bool do_grad,
bool do_sgrad) {
+ PushPullAllocator info(
+ grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad);
+ info.ioset(source, grid, target);
+
return AT_DISPATCH_FLOATING_TYPES(grid.scalar_type(), "pushpull", [&] {
- PushPullImpl f(
- grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad);
- f.ioset(source, grid, target);
- f.loop();
- return f.output;
+ PushPullImpl algo(info);
+ algo.loop();
+ return algo.output;
});
}
PUSHPULL_INSTANTIATE;
-} // namespace
-
+} // namespace cpu
} // namespace monai
diff --git a/monai/csrc/resample/pushpull_cuda.cu b/monai/csrc/resample/pushpull_cuda.cu
index ecfeb562ab..38d34ffe98 100644
--- a/monai/csrc/resample/pushpull_cuda.cu
+++ b/monai/csrc/resample/pushpull_cuda.cu
@@ -25,6 +25,7 @@ limitations under the License.
// TODO:
// . [DONE] generic 3d
// . [DONE] generic 2d
+// . [DONE] generic 1d
// . sliding nearest 3d
// . sliding nearest 2d
// . sliding linear 3d
@@ -37,6 +38,7 @@ limitations under the License.
// . input bound/inter are always vectors -> clean unused constructors
#include
+#include
#include
#include "bounds_common.h"
#include "interpolation_common.h"
@@ -71,18 +73,27 @@ MONAI_NAMESPACE_DEVICE { // cuda
namespace { // anonymous namespace > everything inside has internal linkage
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // GENERIC PUSHPULL CLASS
+ // INDEXING UTILS
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // This class implements the bulk of the code.
- // /!\ No type and shape checking is performed here.
- template
- class PushPullImpl {
+ // This class reads and sets all the parameters that will later be used
+ // by the algorithm in PushPullImpl. All of this is done outside of the
+ // implementation class so that we do not depend on generic types. The
+ // point is to pre-allocate all necessary tensors so that we can check
+ // if they're all compatible with 32 bit math. If it's the case, we can
+ // dispatch to a 32b cuda implementation, which might increase
+ // performance. Else, we use 64 bit math to compute offsets.
+ // (On CPU, we always use 64 bit offsets because it doesn't make a huge
+ // difference. It would be different if we had a vectorized
+ // implementation as in PyTorch).
+ class PushPullAllocator {
public:
+ static constexpr int64_t max_int32 = std::numeric_limits::max();
+
// ~~~ CONSTRUCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
MONAI_HOST
- PushPullImpl(
+ PushPullAllocator(
int dim,
BoundVectorRef bound,
InterpolationVectorRef interpolation,
@@ -122,100 +133,417 @@ MONAI_NAMESPACE_DEVICE { // cuda
iso = interpolation0 == interpolation1 && interpolation0 == interpolation2;
}
- MONAI_HOST
- PushPullImpl(
- int dim,
- BoundType bound,
- InterpolationVectorRef interpolation,
- bool extrapolate,
- bool do_pull,
- bool do_push,
- bool do_count,
- bool do_grad,
- bool do_sgrad)
- : dim(dim),
- bound0(bound),
- bound1(bound),
- bound2(bound),
- interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),
- interpolation1(
- interpolation.size() > 1 ? interpolation[1]
- : interpolation.size() > 0 ? interpolation[0]
- : InterpolationType::Linear),
- interpolation2(
- interpolation.size() > 2 ? interpolation[2]
- : interpolation.size() > 1 ? interpolation[1]
- : interpolation.size() > 0 ? interpolation[0]
- : InterpolationType::Linear),
- extrapolate(extrapolate),
- do_pull(do_pull),
- do_push(do_push),
- do_count(do_count),
- do_grad(do_grad),
- do_sgrad(do_sgrad) {
- iso = interpolation0 == interpolation1 && interpolation0 == interpolation2;
+ // ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ // Usually used for pull:
+ // - do_pull -> return source[grid]
+ // - do_push -> fails
+ // - do_grad -> return J(source)[grid]
+ // - do_sgrad -> return H(source)[grid]
+ MONAI_HOST void ioset(const Tensor& source, const Tensor& grid) {
+ init_all();
+ init_source(source);
+ init_grid(grid);
+ init_output();
}
- MONAI_HOST
- PushPullImpl(
- int dim,
- BoundVectorRef bound,
- InterpolationType interpolation,
- bool extrapolate,
- bool do_pull,
- bool do_push,
- bool do_count,
- bool do_grad,
- bool do_sgrad)
- : dim(dim),
- bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate),
- bound1(
- bound.size() > 1 ? bound[1]
- : bound.size() > 0 ? bound[0]
- : BoundType::Replicate),
- bound2(
- bound.size() > 2 ? bound[2]
- : bound.size() > 1 ? bound[1]
- : bound.size() > 0 ? bound[0]
- : BoundType::Replicate),
- interpolation0(interpolation),
- interpolation1(interpolation),
- interpolation2(interpolation),
- extrapolate(extrapolate),
- do_pull(do_pull),
- do_push(do_push),
- do_count(do_count),
- do_grad(do_grad),
- do_sgrad(do_sgrad) {
- iso = interpolation0 == interpolation1 && interpolation0 == interpolation2;
+ // Usually used for pull_backward:
+ // - do_pull -> return source[grid]
+ // - do_push -> return push(target, grid, source.shape)
+ // - do_grad -> return J(source)[grid]
+ // - do_sgrad -> return H(source)[grid]
+ MONAI_HOST void ioset(const Tensor& source, const Tensor& grid, const Tensor& target) {
+ init_all();
+ init_source(source);
+ init_grid(grid);
+ init_target(target);
+ init_output();
}
- MONAI_HOST
- PushPullImpl(
- int dim,
- BoundType bound,
- InterpolationType interpolation,
- bool extrapolate,
- bool do_pull,
- bool do_push,
- bool do_count,
- bool do_grad,
- bool do_sgrad)
- : dim(dim),
- bound0(bound),
- bound1(bound),
- bound2(bound),
- interpolation0(interpolation),
- interpolation1(interpolation),
- interpolation2(interpolation),
- extrapolate(extrapolate),
- do_pull(do_pull),
- do_push(do_push),
- do_count(do_count),
- do_grad(do_grad),
- do_sgrad(do_sgrad) {
- iso = interpolation0 == interpolation1 && interpolation0 == interpolation2;
+ // Usually used for push:
+ // - do_pull -> fails
+ // - do_push -> return push(target, grid, source_size)
+ // - do_grad -> fails
+ // - do_sgrad -> fails
+ MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid, const Tensor& target) {
+ init_all();
+ init_source(source_size);
+ init_grid(grid);
+ init_target(target);
+ init_output();
+ }
+
+ // Usually used for count:
+ // - do_pull -> fails
+ // - do_push -> return push(ones, grid, source_size)
+ // - do_grad -> fails
+ // - do_sgrad -> fails
+ MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid) {
+ init_all();
+ init_source(source_size);
+ init_grid(grid);
+ init_output();
+ }
+
+ // We just check that all tensors that we own are compatible with 32b math
+ bool canUse32BitIndexMath(int64_t max_elem = max_int32) const {
+ return src_32b_ok && trgt_32b_ok && grid_32b_ok && grad_32b_ok && out_32b_ok;
+ }
+
+ private:
+ // Copied from aten/src/ATen/native/IndexingUtils.cpp in PyTorch 1.6.
+ // It is used to decide to which pointer type we should dispatch to.
+ // Basically, we need to make sure that the "furthest" element we need
+ // to reach is less than max_elem away.
+ static bool tensorCanUse32BitIndexMath(const Tensor& t, int64_t max_elem = max_int32) {
+ int64_t elements = t.numel();
+ if (elements >= max_elem) {
+ return false;
+ }
+ if (elements == 0) {
+ return max_elem > 0;
+ }
+
+ int64_t offset = 0;
+ int64_t linearId = elements - 1;
+
+ // NOTE: Assumes all strides are positive, which is true for now
+ for (int i = t.dim() - 1; i >= 0; --i) {
+ int64_t curDimIndex = linearId % t.size(i);
+ int64_t curDimOffset = curDimIndex * t.stride(i);
+ offset += curDimOffset;
+ linearId /= t.size(i);
+ }
+
+ if (offset >= max_elem) {
+ return false;
+ }
+
+ return true;
+ }
+
+ // ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ MONAI_HOST void init_all();
+ MONAI_HOST void init_source(const Tensor& source);
+ MONAI_HOST void init_source(IntArrayRef source_size);
+ MONAI_HOST void init_grid(const Tensor& grid);
+ MONAI_HOST void init_target(const Tensor& target);
+ MONAI_HOST void init_output();
+
+ // ~~~ OPTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ int dim; // dimensionality (2 or 3)
+ BoundType bound0; // boundary condition // x|W
+ BoundType bound1; // boundary condition // y|H
+ BoundType bound2; // boundary condition // z|D
+ InterpolationType interpolation0; // interpolation order // x|W
+ InterpolationType interpolation1; // interpolation order // y|H
+ InterpolationType interpolation2; // interpolation order // z|D
+ bool iso; // isotropic interpolation?
+ bool extrapolate; // compute out-of-bound values
+ bool do_pull; // sample a volume
+ bool do_push; // splat a volume
+ bool do_count; // splatting weights (= jacobian determinant)
+ bool do_grad; // backprop: gradient of grid // pull
+ bool do_sgrad; // sample spatial gradients
+
+ // ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ std::deque output;
+ TensorOptions src_opt;
+ TensorOptions grid_opt;
+ TensorOptions trgt_opt;
+ int64_t N;
+ int64_t C;
+ int64_t src_X;
+ int64_t src_Y;
+ int64_t src_Z;
+ int64_t trgt_X;
+ int64_t trgt_Y;
+ int64_t trgt_Z;
+ int64_t trgt_K;
+ int64_t src_sN;
+ int64_t src_sC;
+ int64_t src_sX;
+ int64_t src_sY;
+ int64_t src_sZ;
+ bool src_32b_ok;
+ void* src_ptr;
+ int64_t trgt_sN;
+ int64_t trgt_sC;
+ int64_t trgt_sX;
+ int64_t trgt_sY;
+ int64_t trgt_sZ;
+ int64_t trgt_sK;
+ bool trgt_32b_ok;
+ void* trgt_ptr;
+ int64_t grid_sN;
+ int64_t grid_sC;
+ int64_t grid_sX;
+ int64_t grid_sY;
+ int64_t grid_sZ;
+ bool grid_32b_ok;
+ void* grid_ptr;
+ int64_t out_sN;
+ int64_t out_sC;
+ int64_t out_sX;
+ int64_t out_sY;
+ int64_t out_sZ;
+ int64_t out_sK; // gradient dimension
+ bool out_32b_ok;
+ void* out_ptr;
+ int64_t grad_sN;
+ int64_t grad_sC;
+ int64_t grad_sX;
+ int64_t grad_sY;
+ int64_t grad_sZ;
+ bool grad_32b_ok;
+ void* grad_ptr;
+
+ // Allow PushPullImpl's constructor to access PushPullAllocator's
+ // private members.
+ template
+ friend class PushPullImpl;
+ };
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // INITIALISATION
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ MONAI_HOST
+ void PushPullAllocator::init_all() {
+ src_opt = grid_opt = trgt_opt = TensorOptions();
+ N = C = 1L;
+ src_X = src_Y = src_Z = 1L;
+ trgt_X = trgt_Y = trgt_Z = 1L;
+ trgt_K = 0L;
+ src_sN = src_sC = src_sX = src_sY = src_sZ = 0L;
+ grid_sN = grid_sC = grid_sX = grid_sY = grid_sZ = 0L;
+ grad_sN = grad_sC = grad_sX = grad_sY = grad_sZ = 0L;
+ trgt_sN = trgt_sC = trgt_sX = trgt_sY = trgt_sZ = trgt_sK = 0L;
+ out_sN = out_sC = out_sX = out_sY = out_sZ = out_sK = 0L;
+ src_ptr = trgt_ptr = grid_ptr = out_ptr = grad_ptr = static_cast(0);
+ src_32b_ok = trgt_32b_ok = grid_32b_ok = out_32b_ok = grad_32b_ok = true;
+ }
+
+ MONAI_HOST
+ void PushPullAllocator::init_source(const Tensor& source) {
+ N = source.size(0);
+ C = source.size(1);
+ src_X = source.size(2);
+ src_Y = dim < 2 ? 1L : source.size(3);
+ src_Z = dim < 3 ? 1L : source.size(4);
+ src_sN = source.stride(0);
+ src_sC = source.stride(1);
+ src_sX = source.stride(2);
+ src_sY = dim < 2 ? 0L : source.stride(3);
+ src_sZ = dim < 3 ? 0L : source.stride(4);
+ src_ptr = source.data_ptr();
+ src_opt = source.options();
+ src_32b_ok = tensorCanUse32BitIndexMath(source);
+ }
+
+ MONAI_HOST
+ void PushPullAllocator::init_source(IntArrayRef source_size) {
+ src_X = source_size[0];
+ src_Y = dim < 2 ? 1L : source_size[1];
+ src_Z = dim < 3 ? 1L : source_size[2];
+ }
+
+ MONAI_HOST
+ void PushPullAllocator::init_grid(const Tensor& grid) {
+ N = grid.size(0);
+ trgt_X = grid.size(1);
+ trgt_Y = dim < 2 ? 1L : grid.size(2);
+ trgt_Z = dim < 3 ? 1L : grid.size(3);
+ grid_sN = grid.stride(0);
+ grid_sX = grid.stride(1);
+ grid_sY = dim < 2 ? 0L : grid.stride(2);
+ grid_sZ = dim < 3 ? 0L : grid.stride(3);
+ grid_sC = grid.stride(dim == 1 ? 2 : dim == 2 ? 3 : 4);
+ grid_ptr = grid.data_ptr();
+ grid_opt = grid.options();
+ grid_32b_ok = tensorCanUse32BitIndexMath(grid);
+ }
+
+ MONAI_HOST
+ void PushPullAllocator::init_target(const Tensor& target) {
+ N = target.size(0);
+ C = target.size(1);
+ trgt_X = target.size(2);
+ trgt_Y = dim < 2 ? 1L : target.size(3);
+ trgt_Z = dim < 3 ? 1L : target.size(4);
+ trgt_K = target.dim() == dim + 3 ? target.size(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L;
+ trgt_sN = target.stride(0);
+ trgt_sC = target.stride(1);
+ trgt_sX = target.stride(2);
+ trgt_sY = dim < 2 ? 0L : target.stride(3);
+ trgt_sZ = dim < 3 ? 0L : target.stride(4);
+ trgt_sK = target.dim() == dim + 3 ? target.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L;
+ trgt_ptr = target.data_ptr();
+ trgt_opt = target.options();
+ trgt_32b_ok = tensorCanUse32BitIndexMath(target);
+ }
+
+ MONAI_HOST
+ void PushPullAllocator::init_output() {
+ output.clear();
+ if (do_pull) {
+ if (dim == 1)
+ output.push_back(at::empty({N, C, trgt_X}, src_opt));
+ else if (dim == 2)
+ output.push_back(at::empty({N, C, trgt_X, trgt_Y}, src_opt));
+ else
+ output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z}, src_opt));
+ auto pull = output.back();
+ out_sN = pull.stride(0);
+ out_sC = pull.stride(1);
+ out_sX = pull.stride(2);
+ out_sY = dim < 2 ? 0L : pull.stride(3);
+ out_sZ = dim < 3 ? 0L : pull.stride(4);
+ out_sK = 0L;
+ out_ptr = pull.data_ptr();
+ out_32b_ok = tensorCanUse32BitIndexMath(pull);
+ } else if (do_sgrad) {
+ if (dim == 1)
+ output.push_back(at::empty({N, C, trgt_X, 1}, src_opt));
+ else if (dim == 2)
+ output.push_back(at::empty({N, C, trgt_X, trgt_Y, 2}, src_opt));
+ else
+ output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z, 3}, src_opt));
+ auto sgrad = output.back();
+ out_sN = sgrad.stride(0);
+ out_sC = sgrad.stride(1);
+ out_sX = sgrad.stride(2);
+ out_sY = dim < 2 ? 0L : sgrad.stride(3);
+ out_sZ = dim < 3 ? 0L : sgrad.stride(4);
+ out_sK = sgrad.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5);
+ out_ptr = sgrad.data_ptr();
+ out_32b_ok = tensorCanUse32BitIndexMath(sgrad);
+
+ if (iso && interpolation0 == InterpolationType::Nearest)
+ sgrad.zero_();
+ if (iso && interpolation0 == InterpolationType::Linear && dim == 1)
+ sgrad.zero_();
+ } else if (do_push) {
+ if (dim == 1)
+ output.push_back(at::zeros({N, C, src_X}, trgt_opt));
+ else if (dim == 2)
+ output.push_back(at::zeros({N, C, src_X, src_Y}, trgt_opt));
+ else
+ output.push_back(at::zeros({N, C, src_X, src_Y, src_Z}, trgt_opt));
+ auto push = output.back();
+ out_sN = push.stride(0);
+ out_sC = push.stride(1);
+ out_sX = push.stride(2);
+ out_sY = dim < 2 ? 0L : push.stride(3);
+ out_sZ = dim < 3 ? 0L : push.stride(4);
+ out_sK = 0L;
+ out_ptr = push.data_ptr();
+ out_32b_ok = tensorCanUse32BitIndexMath(push);
+ } else if (do_count) {
+ if (dim == 1)
+ output.push_back(at::zeros({N, 1, src_X}, grid_opt));
+ else if (dim == 2)
+ output.push_back(at::zeros({N, 1, src_X, src_Y}, grid_opt));
+ else
+ output.push_back(at::zeros({N, 1, src_X, src_Y, src_Z}, grid_opt));
+ auto count = output.back();
+ out_sN = count.stride(0);
+ out_sC = count.stride(1);
+ out_sX = count.stride(2);
+ out_sY = dim < 2 ? 0L : count.stride(3);
+ out_sZ = dim < 3 ? 0L : count.stride(4);
+ out_sK = 0L;
+ out_ptr = count.data_ptr();
+ out_32b_ok = tensorCanUse32BitIndexMath(count);
+ }
+ if (do_grad) {
+ if (dim == 1)
+ output.push_back(at::zeros({N, trgt_X, 1}, grid_opt));
+ else if (dim == 2)
+ output.push_back(at::zeros({N, trgt_X, trgt_Y, 2}, grid_opt));
+ else
+ output.push_back(at::zeros({N, trgt_X, trgt_Y, trgt_Z, 3}, grid_opt));
+ auto grad = output.back();
+ grad_sN = grad.stride(0);
+ grad_sX = grad.stride(1);
+ grad_sY = dim < 2 ? 0L : grad.stride(2);
+ grad_sZ = dim < 3 ? 0L : grad.stride(3);
+ grad_sC = grad.stride(dim == 1 ? 2 : dim == 2 ? 3 : 4);
+ grad_ptr = grad.data_ptr();
+ out_32b_ok = tensorCanUse32BitIndexMath(grad);
+
+ if (iso && interpolation0 == InterpolationType::Nearest)
+ grad.zero_();
}
+ }
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // GENERIC PUSHPULL CLASS
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // This class implements the bulk of the code.
+ // /!\ No type and shape checking is performed here.
+
+ template
+ class PushPullImpl {
+ public:
+ // ~~~ CONSTRUCTOR ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ PushPullImpl(const PushPullAllocator& info)
+ : output(info.output),
+ dim(info.dim),
+ bound0(info.bound0),
+ bound1(info.bound1),
+ bound2(info.bound2),
+ interpolation0(info.interpolation0),
+ interpolation1(info.interpolation1),
+ interpolation2(info.interpolation1),
+ iso(info.iso),
+ extrapolate(info.extrapolate),
+ do_pull(info.do_pull),
+ do_push(info.do_push),
+ do_count(info.do_count),
+ do_grad(info.do_grad),
+ do_sgrad(info.do_sgrad),
+ N(static_cast(info.N)),
+ C(static_cast(info.C)),
+ src_X(static_cast(info.src_X)),
+ src_Y(static_cast(info.src_Y)),
+ src_Z(static_cast(info.src_Z)),
+ trgt_X(static_cast(info.trgt_X)),
+ trgt_Y(static_cast(info.trgt_Y)),
+ trgt_Z(static_cast(info.trgt_Z)),
+ trgt_K(static_cast(info.trgt_K)),
+ src_sN(static_cast(info.src_sN)),
+ src_sC(static_cast(info.src_sC)),
+ src_sX(static_cast(info.src_sX)),
+ src_sY(static_cast(info.src_sY)),
+ src_sZ(static_cast(info.src_sZ)),
+ src_ptr(static_cast(info.src_ptr)),
+ trgt_sN(static_cast(info.trgt_sN)),
+ trgt_sC(static_cast(info.trgt_sC)),
+ trgt_sX(static_cast(info.trgt_sX)),
+ trgt_sY(static_cast(info.trgt_sY)),
+ trgt_sZ(static_cast(info.trgt_sZ)),
+ trgt_sK(static_cast(info.trgt_sK)),
+ trgt_ptr(static_cast(info.trgt_ptr)),
+ grid_sN(static_cast(info.grid_sN)),
+ grid_sC(static_cast(info.grid_sC)),
+ grid_sX(static_cast(info.grid_sX)),
+ grid_sY(static_cast(info.grid_sY)),
+ grid_sZ(static_cast(info.grid_sZ)),
+ grid_ptr(static_cast(info.grid_ptr)),
+ out_sN(static_cast(info.out_sN)),
+ out_sC(static_cast(info.out_sC)),
+ out_sX(static_cast(info.out_sX)),
+ out_sY(static_cast(info.out_sY)),
+ out_sZ(static_cast(info.out_sZ)),
+ out_sK(static_cast(info.out_sK)),
+ out_ptr(static_cast