diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..674f488a32 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +# Set update schedule for GitHub Actions + +version: 2 +updates: + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + # Check for updates to GitHub Actions every week + interval: "monthly" diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 1717618e9c..3c8e5b63eb 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -49,7 +49,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: repository: ${{ fromJson(needs.Authorization.outputs.args).repo }} ref: ${{ fromJson(needs.Authorization.outputs.args).ref }} diff --git a/.github/workflows/chatops.yml b/.github/workflows/chatops.yml index ac6bc1eb34..b4e201a0d9 100644 --- a/.github/workflows/chatops.yml +++ b/.github/workflows/chatops.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: dispatch - uses: peter-evans/slash-command-dispatch@v1.2.0 + uses: peter-evans/slash-command-dispatch@v3.0.1 with: token: ${{ secrets.PR_MAINTAIN }} reaction-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 1596db26b2..6df4857ec3 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -64,6 +64,7 @@ jobs: - name: Build run: | + python -m pip install -U pip wheel python -m pip install -r requirements-dev.txt BUILD_MONAI=1 ./runtests.sh --build diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml index d175c9bdaf..80d3449458 100644 --- a/.github/workflows/conda.yml +++ b/.github/workflows/conda.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, ubuntu-latest] + os: [ubuntu-latest] python-version: ["3.9"] runs-on: ${{ matrix.os }} env: @@ -26,7 +26,7 @@ jobs: steps: - if: runner.os == 'windows' name: Config pagefile (Windows only) - uses: al-cheb/configure-pagefile-action@v1.2 + uses: al-cheb/configure-pagefile-action@v1.3 with: minimum-size: 8 maximum-size: 16 diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 9410d5d58d..d73b901ddb 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -17,7 +17,7 @@ jobs: - "PT191+CUDA113" - "PT110+CUDA113" - "PT112+CUDA113" - - "PTLATEST+CUDA117" + - "PTLATEST+CUDA118" include: # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - environment: PT182+CUDA102 @@ -32,9 +32,9 @@ jobs: - environment: PT112+CUDA113 pytorch: "torch==1.12.1 torchvision==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu113" base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - - environment: PTLATEST+CUDA117 - pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117" - base: "nvcr.io/nvidia/pytorch:22.08-py3" # CUDA 11.7 + - environment: PTLATEST+CUDA118 + pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118" + base: "nvcr.io/nvidia/pytorch:22.12-py3" # CUDA 11.8 container: image: ${{ matrix.base }} options: "--gpus all" @@ -119,7 +119,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.10"] # 21.02, 21.10 for backward comp. + container: ["pytorch:22.09", "pytorch:22.11", "pytorch:22.12"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -164,7 +164,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.10"] # 21.02, 21.10 for backward comp. + container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.12"] # 21.02, 21.10 for backward comp. container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -264,7 +264,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:22.10-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:22.12-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, integration] steps: diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index b88923e43d..30a7714a27 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -25,17 +25,18 @@ jobs: with: ref: dev fetch-depth: 0 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.9' - shell: bash run: | git describe + python -m pip install -U pip wheel setuptools python setup.py build cat build/lib/monai/_version.py - name: Upload version - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: _version.py path: build/lib/monai/_version.py @@ -55,18 +56,13 @@ jobs: with: ref: dev - name: Download version - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: _version.py - - name: Install Latest Docker - run: | - curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - - sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" - sudo apt-get update - sudo apt-get install docker-ce - name: docker_build shell: bash run: | + docker --version # get tag info for versioning cat _version.py mv _version.py monai/ diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 9000268d0a..9894848a53 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -34,13 +34,14 @@ jobs: run: | which python python -m pip install --upgrade pip wheel - python -m pip uninstall -y torch torchvision - python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116 + python -m pip install --upgrade torch torchvision python -m pip install -r requirements-dev.txt rm -rf /github/home/.cache/torch/hub/mmars/ - name: Run integration tests run: | python -m pip list + git config --global --add safe.directory /__w/MONAI/MONAI + git clean -ffdx nvidia-smi export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1) echo $CUDA_VISIBLE_DEVICES @@ -53,7 +54,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Add reaction - uses: peter-evans/create-or-update-comment@v1 + uses: peter-evans/create-or-update-comment@v2 if: github.event.pull_request.number != '' with: token: ${{ secrets.PR_MAINTAIN }} diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index a603deb8a2..9a0b8142e3 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -22,6 +22,9 @@ jobs: # - os-latest-pip- (shared) flake8-py3: runs-on: ubuntu-latest + strategy: + matrix: + opt: ["codeformat", "pytype", "mypy"] steps: - uses: actions/checkout@v3 - name: Set up Python 3.8 @@ -46,8 +49,8 @@ jobs: run: | # clean up temporary files $(pwd)/runtests.sh --build --clean - # Git hub actions have 2 cores, so parallize pytype - $(pwd)/runtests.sh --build --codeformat -j 2 + # Github actions have 2 cores, so parallelize pytype + $(pwd)/runtests.sh --build --${{ matrix.opt }} -j 2 quick-py3: # full dependencies installed tests for different OS runs-on: ${{ matrix.os }} @@ -59,7 +62,7 @@ jobs: steps: - if: runner.os == 'windows' name: Config pagefile (Windows only) - uses: al-cheb/configure-pagefile-action@v1.2 + uses: al-cheb/configure-pagefile-action@v1.3 with: minimum-size: 8 maximum-size: 16 @@ -101,8 +104,7 @@ jobs: python -m pip list python setup.py develop # test no compile installation shell: bash - - if: runner.os != 'windows' - name: Run compiled (${{ runner.os }}) + - name: Run compiled (${{ runner.os }}) run: | python setup.py develop --uninstall BUILD_MONAI=1 python setup.py develop # compile the cpp extensions diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b63fc40a4d..8a874fcfe1 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -65,7 +65,7 @@ jobs: - if: matrix.python-version == '3.8' && startsWith(github.ref, 'refs/tags/') name: Upload artifacts - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v3 with: name: dist path: dist/ @@ -105,7 +105,7 @@ jobs: python setup.py build cat build/lib/monai/_version.py - name: Upload version - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: _version.py path: build/lib/monai/_version.py @@ -122,7 +122,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: Download version - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: _version.py - name: Set tag diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index 238302a018..a8eed74d9a 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -33,13 +33,13 @@ jobs: export YEAR_WEEK=$(date +'%y%U') echo "Year week for tag is ${YEAR_WEEK}" if ! [[ $YEAR_WEEK =~ ^[0-9]{4}$ ]] ; then echo "Wrong 'year week' format. Should be 4 digits."; exit 1 ; fi - git tag "1.1.dev${YEAR_WEEK}" + git tag "1.2.dev${YEAR_WEEK}" git log -1 git tag --list python setup.py sdist bdist_wheel - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@master + uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_TOKEN }} diff --git a/.gitignore b/.gitignore index 3da001d0ce..bd117cc321 100644 --- a/.gitignore +++ b/.gitignore @@ -135,6 +135,8 @@ tests/testing_data/endo.mp4 tests/testing_data/ultrasound.avi tests/testing_data/train_data_stats.yaml tests/testing_data/eval_data_stats.yaml +tests/testing_data/CT_2D_head_fixed.mha +tests/testing_data/CT_2D_head_moving.mha # clang format tool .clang-format-bin/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 62550f51d4..1269e18978 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace @@ -28,7 +28,7 @@ repos: - id: mixed-line-ending - repo: https://github.com/asottile/pyupgrade - rev: v2.38.2 + rev: v3.3.1 hooks: - id: pyupgrade args: [--py37-plus] @@ -54,11 +54,12 @@ repos: exclude: | (?x)^( monai/__init__.py| - docs/source/conf.py + docs/source/conf.py| + tests/utils.py )$ - repo: https://github.com/hadialqattan/pycln - rev: v2.1.1 + rev: v2.1.2 hooks: - id: pycln args: [--config=pyproject.toml] diff --git a/CITATION.cff b/CITATION.cff index d00c8a364a..dafb8578fe 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -6,8 +6,8 @@ title: "MONAI: Medical Open Network for AI" abstract: "AI Toolkit for Healthcare Imaging" authors: - name: "MONAI Consortium" -date-released: 2022-09-16 -version: "1.0.0" +date-released: 2022-12-19 +version: "1.1.0" identifiers: - description: "This DOI represents all versions of MONAI, and will always resolve to the latest one." type: doi diff --git a/Dockerfile b/Dockerfile index c41af38a57..9ce222fd15 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.10-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.12-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" diff --git a/README.md b/README.md index d54aa7e1a7..d769bf66b1 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,10 @@ Examples and notebook tutorials are located at [Project-MONAI/tutorials](https:/ Technical documentation is available at [docs.monai.io](https://docs.monai.io). +## Citation + +If you have used MONAI in your research, please cite us! The citation can be exported from: https://arxiv.org/abs/2211.02701. + ## Model Zoo [The MONAI Model Zoo](https://github.com/Project-MONAI/model-zoo) is a place for researchers and data scientists to share the latest and great models from the community. Utilizing [the MONAI Bundle format](https://docs.monai.io/en/latest/bundle_intro.html) makes it easy to [get started](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo) building workflows with MONAI. diff --git a/docs/requirements.txt b/docs/requirements.txt index cb28412ad9..a89961826f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,11 +1,11 @@ -f https://download.pytorch.org/whl/cpu/torch-1.12.1%2Bcpu-cp37-cp37m-linux_x86_64.whl torch>=1.8 -pytorch-ignite==0.4.10 -numpy>=1.17 +pytorch-ignite==0.4.11 +numpy>=1.20 itk>=5.2 nibabel parameterized -scikit-image>=0.14.2 +scikit-image>=0.19.0 tensorboard commonmark==0.9.1 recommonmark==0.6.0 @@ -21,10 +21,11 @@ sphinx-autodoc-typehints==1.11.1 pandas einops transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 -mlflow +mlflow>=1.28.0 +clearml >=1.10.0rc0 tensorboardX -imagecodecs; platform_system == "Linux" -tifffile; platform_system == "Linux" +imagecodecs; platform_system == "Linux" or platform_system == "Darwin" +tifffile; platform_system == "Linux" or platform_system == "Darwin" pyyaml fire jsonschema diff --git a/docs/source/MONAI-logo-color.png b/docs/source/MONAI-logo-color.png new file mode 100644 index 0000000000..d1e8b6b7be Binary files /dev/null and b/docs/source/MONAI-logo-color.png differ diff --git a/docs/source/applications.md b/docs/source/applications.md index 5317a3d49a..c77cb4065c 100644 --- a/docs/source/applications.md +++ b/docs/source/applications.md @@ -76,4 +76,4 @@ The following figure shows the detection training and inference workflows: ![detection workflow](../images/detection.png) ### Reproducing the state-of-the-art Kaggle competition solutions -[A reimplementation](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) of the 4th place solution of RANZCR CLiP - Catheter and Line Position Challenge in Kaggle: https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification +[A reimplementation](https://github.com/Project-MONAI/tutorials/tree/main/competitions/kaggle/RANZCR/4th_place_solution) of the 4th place solution of RANZCR CLiP - Catheter and Line Position Challenge in Kaggle: https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 0977788ded..a429f446e9 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -32,6 +32,7 @@ Model Bundle --------------- .. autoclass:: ConfigParser :members: + :special-members: `Scripts` --------- diff --git a/docs/source/conf.py b/docs/source/conf.py index ecc1b3ff59..2762027a79 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -116,8 +116,9 @@ def generate_apidocs(*args): "collapse_navigation": True, "navigation_depth": 3, "show_toc_level": 1, - "footer_items": ["copyright"], + "footer_start": ["copyright"], "navbar_align": "content", + "logo": {"image_light": "MONAI-logo-color.png", "image_dark": "MONAI-logo-color.png"}, } html_context = { "github_user": "Project-MONAI", diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index 7cd71b507f..7530e544ff 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -15,7 +15,7 @@ Content: - [`@` to reference Python objects in configurations](#to-reference-python-objects-in-configurations) - [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions) - [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements) - - [`_target_` (`_disabled_`, `_desc_`, and `_requires_`) to instantiate a Python object](#instantiate-a-python-object) + - [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object) - [The command line interface](#the-command-line-interface) - [Recommendations](#recommendations) @@ -143,17 +143,26 @@ This dictionary will be instantiated as a Pytorch object at runtime. "_target_": "my.module.Class", "_desc_": "this is a customized class which also triggers 'cudnn_opt' reference", "_requires_": "@cudnn_opt", - "_disabled_": "true"} + "_disabled_": "true", + "_mode_": "default"} } ``` -_Description:_ `_requires_`, `_disabled_`, and `_desc_` are optional keys. +_Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional keys. - `_requires_` specifies references (string starts with `@`) or Python expression that will be evaluated/instantiated before `_target_` object is instantiated. It is useful when the component does not explicitly depend on the other ConfigItems via its arguments, but requires the dependencies to be instantiated/evaluated beforehand. - `_disabled_` specifies a flag to indicate whether to skip the instantiation. - `_desc_` can be used for providing free text descriptions. +- `_mode_` specifies the operating mode when the component is instantiated or the callable is called. + it currently supports the following values: + - `"default"` (default) -- return the return value of ``_target_(**kwargs)`` + - `"partial"` -- return a partial function of ``functools.partial(_target_, **kwargs)`` (this is often + useful when some portion of the full set of arguments are supplied to the ``_target_``, and the user wants to + call it with additional arguments later). + - `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``, + see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall). ## The command line interface diff --git a/docs/source/data.rst b/docs/source/data.rst index 8cb27cd347..69e694a37b 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -184,28 +184,6 @@ PILWriter .. autoclass:: PILWriter :members: -Nifti format handling ---------------------- - -Writing Nifti -~~~~~~~~~~~~~ -.. autoclass:: monai.data.NiftiSaver - :members: - -.. autofunction:: monai.data.write_nifti - - -PNG format handling -------------------- - -Writing PNG -~~~~~~~~~~~ -.. autoclass:: monai.data.PNGSaver - :members: - -.. autofunction:: monai.data.write_png - - Synthetic --------- .. automodule:: monai.data.synthetic @@ -274,6 +252,11 @@ N-Dim Fourier Transform .. autofunction:: monai.data.fft_utils.fftn_centered .. autofunction:: monai.data.fft_utils.ifftn_centered +ITK Torch Bridge +~~~~~~~~~~~~~~~~ +.. automodule:: monai.data.itk_torch_bridge + :members: + Meta Object ----------- diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 5b408cfa71..7da7f7f50d 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -101,6 +101,18 @@ Peak signal to noise ratio metrics handler :members: +Metrics reloaded binary handler +------------------------------- +.. autoclass:: MetricsReloadedBinaryHandler + :members: + + +Metrics reloaded categorical handler +------------------------------------ +.. autoclass:: MetricsReloadedCategoricalHandler + :members: + + Metric logger ------------- .. autoclass:: MetricLogger @@ -177,6 +189,17 @@ MLFlow handler .. autoclass:: MLFlowHandler :members: +ClearML handlers +---------------- +.. autoclass:: ClearMLHandler + :members: + +.. autoclass:: ClearMLStatsHandler + :members: + +.. autoclass:: ClearMLImageHandler + :members: + NVTX Handlers ------------- .. automodule:: monai.handlers.nvtx_handlers diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index ac638eb38d..85a7e1fb63 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -5,12 +5,6 @@ Inference methods ================= -Sliding Window Inference ------------------------- - -.. autofunction:: monai.inferers.sliding_window_inference - - Inferers -------- @@ -19,6 +13,12 @@ Inferers :members: :special-members: __call__ +`PatchInferer` +~~~~~~~~~~~~~~ +.. autoclass:: PatchInferer + :members: + :special-members: __call__ + `SimpleInferer` ~~~~~~~~~~~~~~~ .. autoclass:: SimpleInferer @@ -42,3 +42,37 @@ Inferers .. autoclass:: SliceInferer :members: :special-members: __call__ + + +Splitters +--------- +.. currentmodule:: monai.inferers +.. autoclass:: Splitter + :members: + :special-members: __call__ + +`SlidingWindowSplitter` +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SlidingWindowSplitter + :members: + :special-members: __call__ + +Mergers +------- +.. currentmodule:: monai.inferers +.. autoclass:: Merger + :members: + :special-members: __call__ + +`AvgMerger` +~~~~~~~~~~~ +.. autoclass:: AvgMerger + :members: + :special-members: __call__ + + + +Sliding Window Inference Function +--------------------------------- + +.. autofunction:: monai.inferers.sliding_window_inference diff --git a/docs/source/installation.md b/docs/source/installation.md index 5d071d0b8b..ffa9cdf091 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -1,18 +1,21 @@ # Installation Guide ## Table of Contents -1. [From PyPI](#from-pypi) - 1. [Milestone release](#milestone-release) - 2. [Weekly preview release](#weekly-preview-release) - 3. [Uninstall the packages](#uninstall-the-packages) -1. [From conda-forge](#from-conda-forge) -2. [From GitHub](#from-github) - 1. [System-wide](#milestone-release) - 2. [Editable](#weekly-preview-release) -3. [Validating the install](#validating-the-install) -4. [MONAI version string](#monai-version-string) -5. [From DockerHub](#from-dockerhub) -6. [Installing the recommended dependencies](#installing-the-recommended-dependencies) + +- [Installation Guide](#installation-guide) + - [Table of Contents](#table-of-contents) + - [From PyPI](#from-pypi) + - [Milestone release](#milestone-release) + - [Weekly preview release](#weekly-preview-release) + - [Uninstall the packages](#uninstall-the-packages) + - [From conda-forge](#from-conda-forge) + - [From GitHub](#from-github) + - [Option 1 (as a part of your system-wide module)](#option-1-as-a-part-of-your-system-wide-module) + - [Option 2 (editable installation)](#option-2-editable-installation) + - [Validating the install](#validating-the-install) + - [MONAI version string](#monai-version-string) + - [From DockerHub](#from-dockerhub) + - [Installing the recommended dependencies](#installing-the-recommended-dependencies) --- @@ -24,32 +27,39 @@ and the Python package index (PyPI). The pre-built Docker images are made availa To install optional features such as handling the NIfTI files using [Nibabel](https://nipy.org/nibabel/), or building workflows using [Pytorch Ignite](https://pytorch.org/ignite/), please follow the instructions: + - [Installing the recommended dependencies](#installing-the-recommended-dependencies) The installation commands bellow usually end up installing CPU variant of PyTorch. To install GPU-enabled PyTorch: + 1. Install the latest NVIDIA driver. 1. Check [PyTorch Official Guide](https://pytorch.org/get-started/locally/) for the recommended CUDA versions. For Pip package, the user needs to download the CUDA manually, install it on the system, and ensure CUDA_PATH is set properly. 1. Continue to follow the guide and install PyTorch. 1. Install MONAI using one the ways described below. ---- +--- ## From PyPI ### Milestone release + To install the [current milestone release](https://pypi.org/project/monai/): + ```bash pip install monai ``` ### Weekly preview release + To install the [weekly preview release](https://pypi.org/project/monai-weekly/): + ```bash pip install monai-weekly ``` The weekly build is released to PyPI every Sunday with a pre-release build number `dev[%y%U]`. To report any issues on the weekly preview, please include the version and commit information: + ```bash python -c "import monai; print(monai.__version__); print(monai.__commit_id__)" ``` @@ -61,7 +71,9 @@ without uninstalling the existing one first. To address this issue, please uninstall both packages, and retry the installation. ### Uninstall the packages + The packages installed using `pip install` could be removed by: + ```bash pip uninstall -y monai pip uninstall -y monai-weekly @@ -70,51 +82,64 @@ pip uninstall -y monai-weekly ## From conda-forge To install the [current milestone release](https://anaconda.org/conda-forge/monai): + ```bash conda install -c conda-forge monai ``` ## From GitHub + (_If you have installed the -PyPI release version using ``pip install monai``, please run ``pip uninstall -monai`` before using the commands from this section. Because ``pip`` by +PyPI release version using `pip install monai`, please run `pip uninstall +monai` before using the commands from this section. Because `pip` by default prefers the milestone release_.) -The milestone versions are currently planned and released every few months. As the +The milestone versions are currently planned and released every few months. As the codebase is under active development, you may want to install MONAI from GitHub for the latest features: ### Option 1 (as a part of your system-wide module): + ```bash pip install git+https://github.com/Project-MONAI/MONAI#egg=monai ``` + or, to build with MONAI C++/CUDA extensions: + ```bash BUILD_MONAI=1 pip install git+https://github.com/Project-MONAI/MONAI#egg=monai ``` To build the extensions, if the system environment already has a version of Pytorch installed, `--no-build-isolation` might be preferred: + ```bash BUILD_MONAI=1 pip install --no-build-isolation git+https://github.com/Project-MONAI/MONAI#egg=monai ``` + this command will download and install the current `dev` branch of [MONAI from GitHub](https://github.com/Project-MONAI/MONAI). This documentation website by default shows the information for the latest version. ### Option 2 (editable installation): + To install an editable version of MONAI, it is recommended to clone the codebase directly: + ```bash git clone https://github.com/Project-MONAI/MONAI.git ``` -This command will create a ``MONAI/`` folder in your current directory. + +This command will create a `MONAI/` folder in your current directory. You can install it by running: + ```bash cd MONAI/ python setup.py develop ``` + or, to build with MONAI C++/CUDA extensions and install: + ```bash cd MONAI/ BUILD_MONAI=1 python setup.py develop @@ -123,6 +148,7 @@ BUILD_MONAI=1 CC=clang CXX=clang++ python setup.py develop ``` To uninstall the package please run: + ```bash cd MONAI/ python setup.py develop --uninstall @@ -131,37 +157,41 @@ python setup.py develop --uninstall ./runtests.sh --clean ``` -Alternatively, simply adding the root directory of the cloned source code (e.g., ``/workspace/Documents/MONAI``) to your ``$PYTHONPATH`` +Alternatively, simply adding the root directory of the cloned source code (e.g., `/workspace/Documents/MONAI`) to your `$PYTHONPATH` and the codebase is ready to use (without the additional features of MONAI C++/CUDA extensions). > The C++/CUDA extension features are currently experimental, a pre-compiled version is made available via > [the recent docker image releases](https://hub.docker.com/r/projectmonai/monai). > Building the extensions from source may require [Ninja](https://ninja-build.org/) and [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit). > By default, CUDA extension is built if `torch.cuda.is_available()`. It's possible to force building by -> setting ``FORCE_CUDA=1`` environment variable. - +> setting `FORCE_CUDA=1` environment variable. ## Validating the install + You can verify the installation by: + ```bash python -c "import monai; monai.config.print_config()" ``` + If the installation is successful, this command will print out the MONAI version information, and this confirms the core modules of MONAI are ready-to-use. - ## MONAI version string + The MONAI version string shows the current status of your local installation. For example: + ``` MONAI version: 0.1.0+144.g52c763d.dirty ``` -- ``0.1.0`` indicates that your installation is based on the ``0.1.0`` milestone release. -- ``+144`` indicates that your installation is 144 git commits ahead of the milestone release. -- ``g52c763d`` indicates that your installation corresponds to the git commit hash ``52c763d``. -- ``dirty`` indicates that you have modified the codebase locally, and the codebase is inconsistent with ``52c763d``. +- `0.1.0` indicates that your installation is based on the `0.1.0` milestone release. +- `+144` indicates that your installation is 144 git commits ahead of the milestone release. +- `g52c763d` indicates that your installation corresponds to the git commit hash `52c763d`. +- `dirty` indicates that you have modified the codebase locally, and the codebase is inconsistent with `52c763d`. ## From DockerHub + Make sure you have installed the NVIDIA driver and Docker 19.03+ for your Linux distribution. Note that you do not need to install the CUDA toolkit on the host, but the driver needs to be installed. Please find out more information on [nvidia-docker](https://github.com/NVIDIA/nvidia-docker). @@ -169,20 +199,24 @@ Please find out more information on [nvidia-docker](https://github.com/NVIDIA/nv Assuming that you have the Nvidia driver and Docker 19.03+ installed, running the following command will download and start a container with the latest version of MONAI. The latest `dev` branch of MONAI from GitHub is included in the image. + ```bash docker run --gpus all --rm -ti --ipc=host projectmonai/monai:latest ``` You can also run a milestone release docker image by specifying the image tag, for example: + ``` docker run --gpus all --rm -ti --ipc=host projectmonai/monai:0.1.0 ``` ## Installing the recommended dependencies + By default, the installation steps will only download and install the minimal requirements of MONAI. Optional dependencies can be installed using [the extras syntax](https://packaging.python.org/tutorials/installing-packages/#installing-setuptools-extras) to support additional features. For example, to install MONAI with Nibabel and Scikit-image support: + ```bash git clone https://github.com/Project-MONAI/MONAI.git cd MONAI/ @@ -190,6 +224,7 @@ pip install -e '.[nibabel,skimage]' ``` Alternatively, to install all optional dependencies: + ```bash git clone https://github.com/Project-MONAI/MONAI.git cd MONAI/ @@ -197,6 +232,7 @@ pip install -e '.[all]' ``` To install all optional dependencies with `pip` based on MONAI development environment settings: + ```bash git clone https://github.com/Project-MONAI/MONAI.git cd MONAI/ @@ -205,6 +241,7 @@ pip install -r requirements-dev.txt To install all optional dependencies with `conda` based on MONAI development environment settings (`environment-dev.yml`; this will install PyTorch as well as `pytorch-cuda`, please follow https://pytorch.org/get-started/locally/#start-locally for more details about installing PyTorch): + ```bash git clone https://github.com/Project-MONAI/MONAI.git cd MONAI/ @@ -215,10 +252,12 @@ conda env update -n -f environment-dev.yml Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is available via PyPI. - The options are + ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna] ``` + which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index c25d29f3e1..94f4dbbe28 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -145,6 +145,16 @@ Metrics .. autoclass:: CumulativeAverage :members: +`Metrics reloaded binary` +------------------------- +.. autoclass:: MetricsReloadedBinary + :members: + +`Metrics reloaded categorical` +------------------------------ +.. autoclass:: MetricsReloadedCategorical + :members: + Utilities --------- .. automodule:: monai.metrics.utils diff --git a/docs/source/modules.md b/docs/source/modules.md index b2cddd95cc..b6dbf190cd 100644 --- a/docs/source/modules.md +++ b/docs/source/modules.md @@ -56,7 +56,7 @@ so that the deep learning models and pipelines can readily incorporate the meta ### GPU-based accelerations -Implementations are provided to ensure optimal usage of the underlying hardware resources. [[fast training guide]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/fast_model_training_guide.md) +Implementations are provided to ensure optimal usage of the underlying hardware resources. [[fast training guide]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration) ### Determinism and reproducibility @@ -83,8 +83,8 @@ domain-specific usability and pipeline performance. ### Cache IO and transforms data to accelerate training Data-driven methods require many (potentially thousands of) epochs of training data reading and preprocessing. MONAI -provides multi-threaded cache-based datasets to accelerate the process [[Datasets experiment]](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb). The -cache can be persistent and dynamic (`SmartCacheDataset`) and reused across different experiments [[SmartCache example]](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/distributed_training/unet_training_smartcache.py). +provides multi-threaded cache-based datasets to accelerate the process [[Datasets experiment]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/dataset_type_performance.ipynb). The +cache can be persistent and dynamic (`SmartCacheDataset`) and reused across different experiments [[SmartCache example]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/distributed_training/unet_training_smartcache.py). The following figure illustrates the training speedup compared with a regular PyTorch program. ![cachedataset speed](../images/datasets_speed.png) @@ -97,13 +97,13 @@ executes the transforms in a separate thread: ![threaddataloader](../images/threaddataloader.png) -a `ThreadDataLoader` example is within the [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb). +a `ThreadDataLoader` example is within the [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/fast_training_tutorial.ipynb). ### Public datasets To quickly get started with popular training data, MONAI provides several ready-to-integrate Dataset classes (such as `MedNISTDataset`, `DecathlonDataset`, [`TciaDataset`](https://github.com/Project-MONAI/tutorials/blob/main/modules/tcia_dataset.ipynb)), which include data downloading, and support training/evaluation splits generation with transforms. -[[Public datasets tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/modules/public_datasets.ipynb) +[[Public datasets tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/modules/public_datasets.ipynb) The common workflow of predefined datasets: ![pre-defined dataset](../images/dataset_progress.png) @@ -123,7 +123,7 @@ MONAI implements reference networks with the aim of both flexibility and code re Network layers and blocks are in general implemented to be compatible with spatial 1D, 2D and 3D inputs. Users can easily integrate the layers, blocks and networks as part of their customised pipelines. -Various utilities are provided to leverage the existing model weights, e.g., finetuning [from MMAR](https://github.com/Project-MONAI/tutorials/blob/master/modules/transfer_mmar.ipynb) +Various utilities are provided to leverage the existing model weights, e.g., finetuning [from MMAR](https://github.com/Project-MONAI/tutorials/blob/main/modules/transfer_mmar.ipynb) or [from a bundle in MONAI model-zoo](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo). ### C++/CUDA optimized modules @@ -152,7 +152,7 @@ widely-used approaches. Currently, several popular evaluation metrics and infere For model inferences on large volumes, the sliding window approach is a popular choice to achieve high performance while having flexible memory requirements (_alternatively, please check out the latest research on [model parallel -training](#lamp-large-deep-nets-with-automated-model-parallelism-for-image-segmentation) using MONAI_). It also supports +training](https://github.com/Project-MONAI/research-contributions/tree/main/lamp-automated-model-parallelism). It also supports `overlap` and `blending_mode` configurations to handle the overlapped windows for better performances. ![sliding window scheme](../images/sliding_window.png) @@ -162,7 +162,7 @@ training](#lamp-large-deep-nets-with-automated-model-parallelism-for-image-segme Various useful evaluation metrics have been implemented to measure the quality of medical image specific models. These include `Mean Dice`, `ROCAUC`, `Confusion Matrices`, `Hausdorff Distance`, `Surface Distance`, `Occlusion Sensitivity`. -The APIs also support [multi-processing computation](https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py). +The APIs also support [multi-processing computation](https://github.com/Project-MONAI/tutorials/blob/main/modules/compute_metric.py). ### Report generation `MetricsSaver` is provided to write the final metric summary report: `mean`, `median`, `max`, `min`, `percentile`, `std`: @@ -170,7 +170,7 @@ The APIs also support [multi-processing computation](https://github.com/Project- ![metrics report example](../images/metrics_report.png) ## Visualization -Beyond the simple point and curve plotting, intuitive interfaces are provided to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). To work with ignite program, MONAI also provides several ignite handlers to visualize training curve and metrics with `TensorBoard` or `MLFlow`, more details is available in [TensorBoard and MLFlow handlers example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb). +Beyond the simple point and curve plotting, intuitive interfaces are provided to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_training_dict.py). To work with ignite program, MONAI also provides several ignite handlers to visualize training curve and metrics with `TensorBoard` or `MLFlow`, more details is available in [TensorBoard and MLFlow handlers example](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unet_segmentation_3d_ignite.ipynb). To easily visualize a 3D image as frames of 2D images, MONAI provides the utility `matshow3d` based on `matplotlib` library. It can plot frames of image for the specified dimension, showing a spleen 3D image as example: `matshow3d(volume=image, figsize=(100, 100), every_n=10, frame_dim=-1 show=True, cmap="gray")` @@ -181,13 +181,13 @@ MONAI also provides the `blend_images` utility to blend the `image` and `label` ![blend example](../images/blend.png) -For more details of `TensorBoard utility`, `matshow3d` and `blend_images`, please check the [visualization tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/transform_visualization.ipynb). +For more details of `TensorBoard utility`, `matshow3d` and `blend_images`, please check the [visualization tutorial](https://github.com/Project-MONAI/tutorials/blob/main/modules/transform_visualization.ipynb). And to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models: ![CAM visualization example](../images/cam.png) -The above example is generated by computing [GradCAM/GradCAM++ from a lung CT lesion classification model](https://github.com/Project-MONAI/tutorials/tree/master/modules/interpretability). +The above example is generated by computing [GradCAM/GradCAM++ from a lung CT lesion classification model](https://github.com/Project-MONAI/tutorials/tree/main/modules/interpretability). ## Workflows @@ -199,14 +199,14 @@ The trainers and evaluators of the workflows are compatible with pytorch-ignite ### General workflows pipeline -The workflow and some of MONAI event handlers are shown as below [[Workflow examples]](https://github.com/Project-MONAI/tutorials/tree/master/modules/engines): +The workflow and some of MONAI event handlers are shown as below [[Workflow examples]](https://github.com/Project-MONAI/tutorials/tree/main/modules/engines): ![workflow pipeline](../images/workflows.png) ### EnsembleEvaluator -A typical ensemble procoess is implemented as a ready-to-use workflow [[Cross validation and model ensemble tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/modules/cross_validation_models_ensemble.ipynb): +A typical ensemble procoess is implemented as a ready-to-use workflow [[Cross validation and model ensemble tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/modules/cross_validation_models_ensemble.ipynb): 1. Split all the training dataset into K folds. 2. Train K models with every K-1 folds data. 3. Execute inference on the test data with all the K models. @@ -221,7 +221,7 @@ A typical ensemble procoess is implemented as a ready-to-use workflow [[Cross va 1. enabling postprocessing transforms for each item independently -- randomised transforms could be applied differently for each predicted item in a batch. 2. simplifying the transform APIs and reducing the input validation burdens because both the preprocessing and postprocessing transforms now only need to support the "channel-first" input format. 3. enabling the `Invertd` transform for the predictions and the inverted data with different shapes, as the data items are in a list, not stacked in a single tensor. -4. allowing for both batch-first tensor and list of channel-first tensors in a flexible metric computation. [[decollate batch tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/modules/decollate_batch.ipynb) +4. allowing for both batch-first tensor and list of channel-first tensors in a flexible metric computation. [[decollate batch tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/modules/decollate_batch.ipynb) A typical process of `decollate batch` is illustrated as follows (with a `batch_size=N` model predictions and labels as an example): @@ -231,8 +231,8 @@ A typical process of `decollate batch` is illustrated as follows (with a `batch_ Except for the pytorch-ignite based `monai.engines`, most of the MONAI modules could be used independently or combined with other software packages. For example, MONAI can be easily integrated into popular frameworks such as -PyTorch-Lightning and Catalyst. [[Lightning segmentation](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d_lightning.ipynb), -[Catalyst segmentation](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_catalyst.ipynb)] +PyTorch-Lightning and Catalyst. [[Lightning segmentation](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d_lightning.ipynb), +[Catalyst segmentation](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unet_segmentation_3d_catalyst.ipynb)] ## Bundle @@ -264,7 +264,7 @@ A typical bundle example can include: ┗━ *license.txt ``` Details about the bundle config definition and syntax & examples are at [config syntax](https://docs.monai.io/en/latest/config_syntax.html). -A step-by-step [get started](https://github.com/Project-MONAI/tutorials/blob/master/bundle/get_started.md) tutorial notebook can help users quickly set up a bundle. [[bundle examples](https://github.com/Project-MONAI/tutorials/tree/main/bundle), [model-zoo](https://github.com/Project-MONAI/model-zoo)] +A step-by-step [get started](https://github.com/Project-MONAI/tutorials/blob/main/bundle/get_started.md) tutorial notebook can help users quickly set up a bundle. [[bundle examples](https://github.com/Project-MONAI/tutorials/tree/main/bundle), [model-zoo](https://github.com/Project-MONAI/model-zoo)] ## Federated Learning @@ -308,7 +308,7 @@ MONAI provides state-of-the-art performance optimization methods including: ### Auto mixed precision (AMP) Simply set `amp=True/False` in `SupervisedTrainer` or `SupervisedEvaluator` during training or evaluation to enable/disable AMP -Example benchmark results are as follows [[AMP training tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/automatic_mixed_precision.ipynb): +Example benchmark results are as follows [[AMP training tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/automatic_mixed_precision.ipynb): training with AMP ON/OFF on a NVIDIA V100 GPU with CUDA 11 and PyTorch 1.6: @@ -318,15 +318,15 @@ training with AMP ON/OFF on a NVIDIA A100 GPU with CUDA 11 and PyTorch 1.6: ![amp a100 results](../images/amp_training_a100.png) -Several tools including `DLProf`, `Nsight`, `NVTX` and `NVML` can be used with MONAI to identify the performance bottleneck. [[profiling tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/performance_profiling/radiology/profiling_train_base_nvtx.md) +Several tools including `DLProf`, `Nsight`, `NVTX` and `NVML` can be used with MONAI to identify the performance bottleneck. [[profiling tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/performance_profiling/radiology/profiling_train_base_nvtx.md) ### Distributed training The distributed data-parallel APIs of MONAI are compatible with the native PyTorch distributed module, pytorch-ignite distributed module, Horovod, XLA, and the SLURM platform. -[[distributed training tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/distributed_training/brats_training_ddp.py) +[[distributed training tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/distributed_training/brats_training_ddp.py) ![distributed training results](../images/brats_distributed.png) -The [fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb) +The [fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/fast_training_tutorial.ipynb) combines `AMP` with `CacheDataset`, `GPU cache`, `GPU transforms`, `ThreadDataLoader`, tuning of networks and optimizers, can achieve substantial speedup compared with a PyTorch regular implementation. diff --git a/docs/source/networks.rst b/docs/source/networks.rst index a4c225de29..7d34619fe6 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -58,6 +58,11 @@ Blocks .. autoclass:: Mish :members: +`GEGLU` +~~~~~~~ +.. autoclass:: GEGLU + :members: + `GCN Module` ~~~~~~~~~~~~ .. autoclass:: GCN @@ -349,6 +354,16 @@ Layers .. autoclass:: BilateralFilter :members: +`TrainableBilateralFilter` +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TrainableBilateralFilter + :members: + +`TrainableJointBilateralFilter` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TrainableJointBilateralFilter + :members: + `PHLFilter` ~~~~~~~~~~~ .. autoclass:: PHLFilter diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index c92d6fe46a..56fe4bc1e7 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -88,6 +88,21 @@ Generic Interfaces .. autoclass:: RandomOrder :members: +Functionals +----------- + +Crop and Pad (functional) +^^^^^^^^^^^^^^^^^^^^^^^^^ +.. automodule:: monai.transforms.croppad.functional + :members: + +Spatial (functional) +^^^^^^^^^^^^^^^^^^^^ +.. automodule:: monai.transforms.spatial.functional + :members: + +.. currentmodule:: monai.transforms + Vanilla Transforms ------------------ @@ -1134,6 +1149,17 @@ Utility :members: :special-members: __call__ +`ImageFilter` +""""""""""""" +.. autoclass:: ImageFilter + :members: + :special-members: __call__ + +`RandImageFilter` +""""""""""""""""" +.. autoclass:: RandImageFilter + :members: + :special-members: __call__ Dictionary Transforms --------------------- @@ -2124,6 +2150,19 @@ Utility (Dict) :members: :special-members: __call__ +`ImageFilterd` +"""""""""""""" +.. autoclass:: ImageFilterd + :members: + :special-members: __call__ + +`RandImageFilterd` +"""""""""""""""""" +.. autoclass:: RandImageFilterd + :members: + :special-members: __call__ + + MetaTensor ^^^^^^^^^^ diff --git a/docs/source/whatsnew_0_7.md b/docs/source/whatsnew_0_7.md index 6df64948b0..6f515e64c3 100644 --- a/docs/source/whatsnew_0_7.md +++ b/docs/source/whatsnew_0_7.md @@ -44,7 +44,7 @@ training. With this release, we actively evaluate and enhance the quality and flexibility of the MONAI core modules, using the public Kaggle challenge as a testbed. [A -reimplementation](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) +reimplementation](https://github.com/Project-MONAI/tutorials/tree/main/competitions/kaggle/RANZCR/4th_place_solution) of a state-of-the-art solution at [Kaggle RANZCR CLiP - Catheter and Line Position Challenge](https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification) diff --git a/docs/source/whatsnew_1_0.md b/docs/source/whatsnew_1_0.md index e8e0b031c1..7e347780bf 100644 --- a/docs/source/whatsnew_1_0.md +++ b/docs/source/whatsnew_1_0.md @@ -39,7 +39,7 @@ With [the new federated learning APIs](https://docs.monai.io/en/latest/fl.html), and executed using single- or multi-GPU training. The MONAI FL client also allows computing summary data statistics (e.g., intensity histograms) on the datasets defined in the bundle configs. These can be shared and visualized on the FL server, for example, using NVIDIA FLARE's federated statistics operators, -see [here](https://github.com/NVIDIA/NVFlare/tree/dev/integration/monai/examples/spleen_ct_segmentation) for an example. +see [here](https://github.com/NVIDIA/NVFlare/tree/dev/integration/monai/examples) for an example. We welcome other federated learning toolkits to integrate with MONAI FL APIs, building a common foundation for collaborative learning in medical imaging. diff --git a/environment-dev.yml b/environment-dev.yml index bf69764e58..400822aaf3 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -5,7 +5,7 @@ channels: - nvidia - conda-forge dependencies: - - numpy>=1.17 + - numpy>=1.20 - pytorch>=1.8 - torchvision - pytorch-cuda=11.6 diff --git a/monai/__init__.py b/monai/__init__.py index 3f6c06d82d..62ca00c5c5 100644 --- a/monai/__init__.py +++ b/monai/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import sys diff --git a/monai/_extensions/__init__.py b/monai/_extensions/__init__.py index fd32d71840..47d0c7021a 100644 --- a/monai/_extensions/__init__.py +++ b/monai/_extensions/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .loader import load_module diff --git a/monai/_extensions/loader.py b/monai/_extensions/loader.py index 35b050ba56..7affd1a3eb 100644 --- a/monai/_extensions/loader.py +++ b/monai/_extensions/loader.py @@ -9,13 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import platform from _thread import interrupt_main from contextlib import contextmanager from glob import glob from os import path from threading import Timer -from typing import Optional +from types import ModuleType import torch @@ -45,8 +47,8 @@ def timeout(time, message): def load_module( - module_name: str, defines: Optional[dict] = None, verbose_build: bool = False, build_timeout: int = 300 -): + module_name: str, defines: dict | None = None, verbose_build: bool = False, build_timeout: int = 300 +) -> ModuleType: """ Handles the loading of c++ extension modules. @@ -88,4 +90,4 @@ def load_module( name=name, sources=source, extra_cflags=define_args, extra_cuda_cflags=define_args, verbose=verbose_build ) - return module + return module # type: ignore[no-any-return] diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 3df0e95a98..9cc7aeb8e0 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset, TciaDataset from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/auto3dseg/__init__.py b/monai/apps/auto3dseg/__init__.py index 7c335f4850..a90c626da9 100644 --- a/monai/apps/auto3dseg/__init__.py +++ b/monai/apps/auto3dseg/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .auto_runner import AutoRunner from .bundle_gen import BundleAlgo, BundleGen from .data_analyzer import DataAnalyzer diff --git a/monai/apps/auto3dseg/__main__.py b/monai/apps/auto3dseg/__main__.py index eec56b7582..d169467ba9 100644 --- a/monai/apps/auto3dseg/__main__.py +++ b/monai/apps/auto3dseg/__main__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from monai.apps.auto3dseg.auto_runner import AutoRunner from monai.apps.auto3dseg.bundle_gen import BundleAlgo, BundleGen from monai.apps.auto3dseg.data_analyzer import DataAnalyzer diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 2ba10b9833..290a79e324 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -9,12 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import subprocess from copy import deepcopy from time import sleep -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, cast import numpy as np import torch @@ -204,24 +206,23 @@ class AutoRunner: """ - analyze_params: Optional[Dict] + analyze_params: dict | None def __init__( self, work_dir: str = "./work_dir", - input: Union[Dict[str, Any], str, None] = None, - algos: Optional[Union[Dict, List, str]] = None, - analyze: Optional[bool] = None, - algo_gen: Optional[bool] = None, - train: Optional[bool] = None, + input: dict[str, Any] | str | None = None, + algos: dict | list | str | None = None, + analyze: bool | None = None, + algo_gen: bool | None = None, + train: bool | None = None, hpo: bool = False, hpo_backend: str = "nni", ensemble: bool = True, not_use_cache: bool = False, - templates_path_or_url: Optional[str] = None, - **kwargs, + templates_path_or_url: str | None = None, + **kwargs: Any, ): - logger.info(f"AutoRunner using work directory {work_dir}") os.makedirs(work_dir, exist_ok=True) @@ -234,7 +235,7 @@ def __init__( input = self.data_src_cfg_name logger.info(f"Input config is not provided, using the default {input}") - if isinstance(input, Dict): + if isinstance(input, dict): self.data_src_cfg = input ConfigParser.export_config_file( config=input, filepath=self.data_src_cfg_name, fmt="yaml", default_flow_style=None, sort_keys=False @@ -279,14 +280,14 @@ def __init__( self.set_num_fold(num_fold=self.num_fold) self.gpu_customization = False - self.gpu_customization_specs: Dict[str, Any] = {} + self.gpu_customization_specs: dict[str, Any] = {} # hpo if hpo_backend.lower() != "nni": raise NotImplementedError("HPOGen backend only supports NNI") self.hpo = hpo and has_nni self.set_hpo_params() - self.search_space: Dict[str, Dict[str, Any]] = {} + self.search_space: dict[str, dict[str, Any]] = {} self.hpo_tasks = 0 def read_cache(self): @@ -336,8 +337,8 @@ def export_cache(self, **kwargs): ) def set_gpu_customization( - self, gpu_customization: bool = False, gpu_customization_specs: Optional[Dict[str, Any]] = None - ): + self, gpu_customization: bool = False, gpu_customization_specs: dict[str, Any] | None = None + ) -> None: """ Set options for GPU-based parameter customization/optimization. @@ -372,7 +373,7 @@ def set_gpu_customization( if gpu_customization_specs is not None: self.gpu_customization_specs = gpu_customization_specs - def set_num_fold(self, num_fold: int = 5): + def set_num_fold(self, num_fold: int = 5) -> None: """ Set the number of cross validation folds for all algos. @@ -389,7 +390,7 @@ def set_num_fold(self, num_fold: int = 5): if self.ensemble_method_name == "AlgoEnsembleBestByFold": self.ensemble_method.n_fold = self.num_fold # type: ignore - def set_training_params(self, params: Optional[Dict[str, Any]] = None): + def set_training_params(self, params: dict[str, Any] | None = None) -> None: """ Set the training params for all algos. @@ -404,7 +405,7 @@ def set_training_params(self, params: Optional[Dict[str, Any]] = None): """ self.train_params = deepcopy(params) if params is not None else {} - def set_prediction_params(self, params: Optional[Dict[str, Any]] = None): + def set_prediction_params(self, params: dict[str, Any] | None = None) -> None: """ Set the prediction params for all algos. @@ -420,7 +421,7 @@ def set_prediction_params(self, params: Optional[Dict[str, Any]] = None): """ self.pred_params = deepcopy(params) if params is not None else {} - def set_analyze_params(self, params: Optional[Dict[str, Any]] = None): + def set_analyze_params(self, params: dict[str, Any] | None = None) -> None: """ Set the data analysis extra params. @@ -438,7 +439,7 @@ def set_analyze_params(self, params: Optional[Dict[str, Any]] = None): else: self.analyze_params = deepcopy(params) - def set_hpo_params(self, params: Optional[Dict[str, Any]] = None): + def set_hpo_params(self, params: dict[str, Any] | None = None) -> None: """ Set parameters for the HPO module and the algos before the training. It will attempt to (1) override bundle templates with the key-value pairs in ``params`` (2) change the config of the HPO module (e.g. NNI) if the @@ -513,7 +514,7 @@ def set_image_save_transform(self, kwargs): output_dir=output_dir, output_postfix=output_postfix, output_dtype=output_dtype, resample=resample, **kwargs ) - def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs): + def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> None: """ Set the bundle ensemble method @@ -536,7 +537,7 @@ def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFol else: raise NotImplementedError(f"Ensemble method {self.ensemble_method_name} is not implemented.") - def _train_algo_in_sequence(self, history: List[Dict[str, Any]]): + def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None: """ Train the Algos in a sequential scheme. The order of training is randomized. @@ -639,7 +640,6 @@ def run(self): # step 2: algorithm generation if self.algo_gen: - if not os.path.isfile(self.datastats_filename): raise ValueError( f"Could not find the datastats file {self.datastats_filename}. " diff --git a/monai/apps/auto3dseg/bundle_gen.py b/monai/apps/auto3dseg/bundle_gen.py index 59e90b795b..7c27284a0b 100644 --- a/monai/apps/auto3dseg/bundle_gen.py +++ b/monai/apps/auto3dseg/bundle_gen.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import importlib import os import shutil @@ -16,10 +18,11 @@ import sys import time import warnings +from collections.abc import Mapping from copy import deepcopy from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any from urllib.parse import urlparse import torch @@ -32,7 +35,7 @@ from monai.utils import ensure_tuple logger = get_logger(module_name=__name__) -ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "1dde7a1") +ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "d0fa876d") __all__ = ["BundleAlgo", "BundleGen"] @@ -47,7 +50,7 @@ class BundleAlgo(Algo): from monai.apps.auto3dseg import BundleAlgo - data_stats_yaml = "/workspace/data_stats.yaml" + data_stats_yaml = "/workspace/datastats.yaml" algo = BundleAlgo(template_path=../algorithms/templates/segresnet2d/configs) algo.set_data_stats(data_stats_yaml) # algo.set_data_src("../data_src.json") @@ -77,7 +80,7 @@ def __init__(self, template_path: str): # track records when filling template config: {"": {"": value, ...}, ...} self.fill_records: dict = {} - def set_data_stats(self, data_stats_files: str): + def set_data_stats(self, data_stats_files: str) -> None: """ Set the data analysis report (generated by DataAnalyzer). @@ -86,7 +89,7 @@ def set_data_stats(self, data_stats_files: str): """ self.data_stats_files = data_stats_files - def set_data_source(self, data_src_cfg: str): + def set_data_source(self, data_src_cfg: str) -> None: """ Set the data source configuration file @@ -97,7 +100,7 @@ def set_data_source(self, data_src_cfg: str): """ self.data_list_file = data_src_cfg - def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs) -> dict: + def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs: Any) -> dict: """ The configuration files defined when constructing this Algo instance might not have a complete training and validation pipelines. Some configuration components and hyperparameters of the pipelines depend on the @@ -114,7 +117,7 @@ def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwarg """ return {} - def export_to_disk(self, output_path: str, algo_name: str, **kwargs): + def export_to_disk(self, output_path: str, algo_name: str, **kwargs: Any) -> None: """ Fill the configuration templates, write the bundle (configs + scripts) to folder `output_path/algo_name`. @@ -175,7 +178,7 @@ def _create_cmd(self, train_params=None): cmd += f" --{k}={v}" return cmd, devices_info - def _run_cmd(self, cmd: str, devices_info: str): + def _run_cmd(self, cmd: str, devices_info: str) -> subprocess.CompletedProcess: """ Execute the training command with target devices information. @@ -251,7 +254,7 @@ def infer(self, image_file): spec.loader.exec_module(infer_class) # type: ignore return infer_class.InferClass(configs_path, *args, **kwargs) - def predict(self, predict_files: list, predict_params=None): + def predict(self, predict_files: list, predict_params: dict | None = None) -> list: """ Use the trained model to predict the outputs with a given input image. @@ -287,7 +290,7 @@ def get_output_path(self): } -def _download_algos_url(url: str, at_path: str): +def _download_algos_url(url: str, at_path: str) -> dict[str, dict[str, str]]: """ Downloads the algorithm templates release archive, and extracts it into a parent directory of the at_path folder. Returns a dictionary of the algorithm templates. @@ -323,7 +326,7 @@ def _download_algos_url(url: str, at_path: str): def _copy_algos_folder(folder, at_path): """ Copies the algorithm templates folder to at_path. - Returns a dictionary of of algorithm templates. + Returns a dictionary of algorithm templates. """ folder = os.path.abspath(folder) at_path = os.path.abspath(at_path) @@ -339,7 +342,8 @@ def _copy_algos_folder(folder, at_path): algos_all[name] = dict( _target_=f"{name}.scripts.algo.{name.capitalize()}Algo", template_path=os.path.join(at_path, name) ) - if len(algos_all) == 0: + logger.info(f"{name} -- {algos_all[name]}") + if not algos_all: raise ValueError(f"Unable to find any algos in {folder}") return algos_all @@ -363,20 +367,18 @@ class BundleGen(AlgoGen): .. code-block:: bash - python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/data_stats.yaml" + python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml" """ def __init__( self, algo_path: str = ".", - algos: Optional[Union[Dict, List, str]] = None, - templates_path_or_url: Optional[str] = None, - data_stats_filename: Optional[str] = None, - data_src_cfg_name: Optional[str] = None, + algos: dict | list | str | None = None, + templates_path_or_url: str | None = None, + data_stats_filename: str | None = None, + data_src_cfg_name: str | None = None, ): - if algos is None or isinstance(algos, (list, tuple, str)): - if templates_path_or_url is None: templates_path_or_url = default_algo_zip @@ -384,9 +386,11 @@ def __init__( if os.path.isdir(templates_path_or_url): # if a local folder, copy if necessary + logger.info(f"BundleGen from directory {templates_path_or_url}") algos_all = _copy_algos_folder(folder=templates_path_or_url, at_path=at_path) elif urlparse(templates_path_or_url).scheme in ("http", "https"): # if url, trigger the download and extract process + logger.info(f"BundleGen from {templates_path_or_url}") algos_all = _download_algos_url(url=templates_path_or_url, at_path=at_path) else: raise ValueError(f"{self.__class__} received invalid templates_path_or_url: {templates_path_or_url}") @@ -401,7 +405,6 @@ def __init__( self.algos: Any = [] if isinstance(algos, dict): for algo_name, algo_params in algos.items(): - template_path = os.path.dirname(algo_params.get("template_path", ".")) if len(template_path) > 0 and template_path not in sys.path: sys.path.append(template_path) @@ -426,9 +429,9 @@ def __init__( self.data_stats_filename = data_stats_filename self.data_src_cfg_filename = data_src_cfg_name - self.history: List[Dict] = [] + self.history: list[dict] = [] - def set_data_stats(self, data_stats_filename: str): + def set_data_stats(self, data_stats_filename: str) -> None: """ Set the data stats filename @@ -454,17 +457,17 @@ def get_data_src(self): """Get the data source filename""" return self.data_src_cfg_filename - def get_history(self) -> List: + def get_history(self) -> list: """get the history of the bundleAlgo object with their names/identifiers""" return self.history def generate( self, - output_folder=".", + output_folder: str = ".", num_fold: int = 5, gpu_customization: bool = False, - gpu_customization_specs: Optional[Dict[str, Any]] = None, - ): + gpu_customization_specs: dict[str, Any] | None = None, + ) -> None: """ Generate the bundle scripts/configs for each bundleAlgo diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 4f9faa4928..3bb67bdbe2 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -9,9 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings from os import path -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, cast import numpy as np import torch @@ -111,20 +113,20 @@ class DataAnalyzer: def __init__( self, - datalist: Union[str, Dict], + datalist: str | dict, dataroot: str = "", - output_path: str = "./data_stats.yaml", + output_path: str = "./datastats.yaml", average: bool = True, do_ccp: bool = False, - device: Union[str, torch.device] = "cpu", - worker: int = 2, + device: str | torch.device = "cpu", + worker: int = 4, image_key: str = "image", - label_key: Optional[str] = "label", - hist_bins: Optional[Union[list, int]] = 0, - hist_range: Optional[list] = None, - fmt: Optional[str] = "yaml", + label_key: str | None = "label", + hist_bins: list | int | None = 0, + hist_range: list | None = None, + fmt: str = "yaml", histogram_only: bool = False, - **extra_params, + **extra_params: Any, ): if path.isfile(output_path): warnings.warn(f"File {output_path} already exists and will be overwritten.") @@ -146,7 +148,7 @@ def __init__( self.extra_params = extra_params @staticmethod - def _check_data_uniformity(keys: List[str], result: Dict): + def _check_data_uniformity(keys: list[str], result: dict) -> bool: """ Check data uniformity since DataAnalyzer provides no support to multi-modal images with different affine matrices/spacings due to monai transforms. @@ -207,12 +209,11 @@ def get_all_case_stats(self, key="training", transform_list=None): keys = list(filter(None, [self.image_key, self.label_key])) if transform_list is None: transform_list = [ - LoadImaged(keys=keys, ensure_channel_first=True), + LoadImaged(keys=keys, ensure_channel_first=True, image_only=True), EnsureTyped(keys=keys, data_type="tensor", dtype=torch.float), Orientationd(keys=keys, axcodes="RAS"), ] if self.label_key is not None: - allowed_shape_difference = self.extra_params.pop("allowed_shape_difference", 5) transform_list.append( EnsureSameShaped( @@ -226,14 +227,21 @@ def get_all_case_stats(self, key="training", transform_list=None): files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key) dataset = Dataset(data=files, transform=transform) - dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.worker, collate_fn=no_collation) - result: Dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=self.worker, + collate_fn=no_collation, + pin_memory=self.device.type == "cuda", + ) + result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} + result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} if not has_tqdm: warnings.warn("tqdm is not installed. not displaying the caching progress.") for batch_data in tqdm(dataloader) if has_tqdm else dataloader: - batch_data = batch_data[0] batch_data[self.image_key] = batch_data[self.image_key].to(self.device) @@ -259,17 +267,29 @@ def get_all_case_stats(self, key="training", transform_list=None): DataStatsKeys.LABEL_STATS: d[DataStatsKeys.LABEL_STATS], } ) - result[DataStatsKeys.BY_CASE].append(stats_by_cases) + result_bycase[DataStatsKeys.BY_CASE].append(stats_by_cases) - result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(List, result[DataStatsKeys.BY_CASE])) + n_cases = len(result_bycase[DataStatsKeys.BY_CASE]) + + result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result_bycase[DataStatsKeys.BY_CASE])) + result[DataStatsKeys.SUMMARY]["n_cases"] = n_cases + result[DataStatsKeys.BY_CASE] = [None] * n_cases if not self._check_data_uniformity([ImageStatsKeys.SPACING], result): print("Data spacing is not completely uniform. MONAI transforms may provide unexpected result") if self.output_path: + # saving summary and by_case as 2 files, to minimize loading time when only the summary is necessary ConfigParser.export_config_file( result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False ) + ConfigParser.export_config_file( + result_bycase, + self.output_path.replace(".yaml", "_by_case.yaml"), + fmt=self.fmt, + default_flow_style=None, + sort_keys=False, + ) # release memory d = None @@ -278,4 +298,6 @@ def get_all_case_stats(self, key="training", transform_list=None): # limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237 torch.cuda.empty_cache() + # return combined + result[DataStatsKeys.BY_CASE] = result_bycase[DataStatsKeys.BY_CASE] return result diff --git a/monai/apps/auto3dseg/ensemble_builder.py b/monai/apps/auto3dseg/ensemble_builder.py index c63658789b..257102e8ac 100644 --- a/monai/apps/auto3dseg/ensemble_builder.py +++ b/monai/apps/auto3dseg/ensemble_builder.py @@ -9,13 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os from abc import ABC, abstractmethod +from collections.abc import Sequence from copy import deepcopy -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, cast from warnings import warn import numpy as np +import torch from monai.apps.auto3dseg.bundle_gen import BundleAlgo from monai.apps.utils import get_logger @@ -67,7 +71,7 @@ def get_algo_ensemble(self): """ return self.algo_ensemble - def set_infer_files(self, dataroot: str, data_list_or_path: Union[str, List], data_key: str = "testing"): + def set_infer_files(self, dataroot: str, data_list_or_path: str | list, data_key: str = "testing") -> None: """ Set the files to perform model inference. @@ -78,7 +82,7 @@ def set_infer_files(self, dataroot: str, data_list_or_path: Union[str, List], da self.infer_files = [] - if isinstance(data_list_or_path, List): + if isinstance(data_list_or_path, list): self.infer_files = data_list_or_path elif isinstance(data_list_or_path, str): datalist = ConfigParser.load_config_file(data_list_or_path) @@ -105,7 +109,7 @@ def ensemble_pred(self, preds, sigmoid=False): if self.mode == "mean": prob = MeanEnsemble()(preds) - return prob2class(prob, dim=0, keepdim=True, sigmoid=sigmoid) + return prob2class(cast(torch.Tensor, prob), dim=0, keepdim=True, sigmoid=sigmoid) elif self.mode == "vote": classes = [prob2class(p, dim=0, keepdim=True, sigmoid=sigmoid) for p in preds] if sigmoid: @@ -113,7 +117,7 @@ def ensemble_pred(self, preds, sigmoid=False): else: return VoteEnsemble(num_classes=preds[0].shape[0])(classes) - def __call__(self, pred_param: Optional[Dict[str, Any]] = None): + def __call__(self, pred_param: dict[str, Any] | None = None) -> list[torch.Tensor]: """ Use the ensembled model to predict result. @@ -176,7 +180,6 @@ class AlgoEnsembleBestN(AlgoEnsemble): """ def __init__(self, n_best: int = 5): - super().__init__() self.n_best = n_best @@ -187,7 +190,7 @@ def sort_score(self): scores = concat_val_to_np(self.algos, [AlgoEnsembleKeys.SCORE]) return np.argsort(scores).tolist() - def collect_algos(self, n_best: int = -1): + def collect_algos(self, n_best: int = -1) -> None: """ Rank the algos by finding the top N (n_best) validation scores. """ @@ -221,7 +224,6 @@ class AlgoEnsembleBestByFold(AlgoEnsemble): """ def __init__(self, n_fold: int = 5): - super().__init__() self.n_fold = n_fold @@ -233,7 +235,7 @@ def collect_algos(self) -> None: self.algo_ensemble = [] for f_idx in range(self.n_fold): best_score = -1.0 - best_model: Optional[BundleAlgo] = None + best_model: BundleAlgo | None = None for algo in self.algos: # algorithm folder: {net}_{fold_index}_{other} identifier = algo[AlgoEnsembleKeys.ID].split("_")[1] @@ -264,8 +266,8 @@ class AlgoEnsembleBuilder: """ - def __init__(self, history: Sequence[Dict], data_src_cfg_filename: Optional[str] = None): - self.infer_algos: List[Dict[AlgoEnsembleKeys, Any]] = [] + def __init__(self, history: Sequence[dict], data_src_cfg_filename: str | None = None): + self.infer_algos: list[dict[AlgoEnsembleKeys, Any]] = [] self.ensemble: AlgoEnsemble self.data_src_cfg = ConfigParser(globals=False) @@ -292,7 +294,7 @@ def __init__(self, history: Sequence[Dict], data_src_cfg_filename: Optional[str] self.add_inferer(name, gen_algo, best_metric) - def add_inferer(self, identifier: str, gen_algo: BundleAlgo, best_metric: Optional[float] = None): + def add_inferer(self, identifier: str, gen_algo: BundleAlgo, best_metric: float | None = None) -> None: """ Add model inferer to the builder. @@ -308,7 +310,7 @@ def add_inferer(self, identifier: str, gen_algo: BundleAlgo, best_metric: Option algo = {AlgoEnsembleKeys.ID: identifier, AlgoEnsembleKeys.ALGO: gen_algo, AlgoEnsembleKeys.SCORE: best_metric} self.infer_algos.append(algo) - def set_ensemble_method(self, ensemble: AlgoEnsemble, *args, **kwargs): + def set_ensemble_method(self, ensemble: AlgoEnsemble, *args: Any, **kwargs: Any) -> None: """ Set the ensemble method. diff --git a/monai/apps/auto3dseg/hpo_gen.py b/monai/apps/auto3dseg/hpo_gen.py index a80890f570..46f7deea72 100644 --- a/monai/apps/auto3dseg/hpo_gen.py +++ b/monai/apps/auto3dseg/hpo_gen.py @@ -9,16 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os from abc import abstractmethod from copy import deepcopy -from typing import Optional, cast +from typing import Any, cast from warnings import warn from monai.apps.auto3dseg.bundle_gen import BundleAlgo from monai.apps.utils import get_logger from monai.auto3dseg import Algo, AlgoGen, algo_from_pickle, algo_to_pickle from monai.bundle.config_parser import ConfigParser +from monai.config import PathLike from monai.utils import optional_import nni, has_nni = optional_import("nni") @@ -106,7 +109,7 @@ class NNIGen(HPOGen): NNI command manually. """ - def __init__(self, algo: Optional[Algo] = None, params=None): + def __init__(self, algo: Algo | None = None, params: dict | None = None): self.algo: Algo self.hint = "" self.obj_filename = "" @@ -164,7 +167,7 @@ def get_hyperparameters(self): warn("NNI is not detected. The code will continue to run without NNI.") return {} - def update_params(self, params: dict): + def update_params(self, params: dict) -> None: """ Translate the parameter from monai bundle to meet NNI requirements. @@ -195,7 +198,6 @@ def generate(self, output_folder: str = ".") -> None: if isinstance(self.algo, BundleAlgo): self.algo.export_to_disk(output_folder, task_prefix + task_id, fill_with_datastats=False) else: - ConfigParser.export_config_file(self.params, write_path) logger.info(write_path) @@ -208,7 +210,7 @@ def set_score(self, acc): else: warn("NNI is not detected. The code will continue to run without NNI.") - def run_algo(self, obj_filename: str, output_folder: str = ".", template_path=None) -> None: + def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: PathLike | None = None) -> None: """ The python interface for NNI to run. @@ -281,7 +283,7 @@ class OptunaGen(HPOGen): """ - def __init__(self, algo: Optional[Algo] = None, params=None) -> None: + def __init__(self, algo: Algo | None = None, params: dict | None = None) -> None: self.algo: Algo self.obj_filename = "" @@ -329,7 +331,9 @@ def set_trial(self, trial): """Set the Optuna trial""" self.trial = trial - def __call__(self, trial, obj_filename: str, output_folder: str = ".", template_path=None): + def __call__( + self, trial: Any, obj_filename: str, output_folder: str = ".", template_path: PathLike | None = None + ) -> Any: """ Callabe that Optuna will use to optimize the hyper-parameters @@ -343,7 +347,7 @@ def __call__(self, trial, obj_filename: str, output_folder: str = ".", template_ self.run_algo(obj_filename, output_folder, template_path) return self.acc - def update_params(self, params: dict): + def update_params(self, params: dict) -> None: """ Translate the parameter from monai bundle. @@ -374,11 +378,10 @@ def generate(self, output_folder: str = ".") -> None: if isinstance(self.algo, BundleAlgo): self.algo.export_to_disk(output_folder, task_prefix + task_id, fill_with_datastats=False) else: - ConfigParser.export_config_file(self.params, write_path) logger.info(write_path) - def run_algo(self, obj_filename: str, output_folder: str = ".", template_path=None) -> None: + def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: PathLike | None = None) -> None: """ The python interface for NNI to run. diff --git a/monai/apps/auto3dseg/transforms.py b/monai/apps/auto3dseg/transforms.py index 2793eb9202..0bb65edd13 100644 --- a/monai/apps/auto3dseg/transforms.py +++ b/monai/apps/auto3dseg/transforms.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Dict, Hashable, Mapping +from collections.abc import Hashable, Mapping import numpy as np import torch @@ -18,6 +20,7 @@ from monai.config import KeysCollection from monai.networks.utils import pytorch_after from monai.transforms import MapTransform +from monai.utils.misc import ImageMetaKey class EnsureSameShaped(MapTransform): @@ -50,17 +53,18 @@ def __init__( self.source_key = source_key self.allowed_shape_difference = allowed_shape_difference - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) image_shape = d[self.source_key].shape[1:] for key in self.key_iterator(d): label_shape = d[key].shape[1:] if label_shape != image_shape: if np.allclose(list(label_shape), list(image_shape), atol=self.allowed_shape_difference): - warnings.warn( - f"The {key} with shape {label_shape} was resized to match the source shape {image_shape}," - f"the meta-data was not updated." - ) + msg = f"The {key} with shape {label_shape} was resized to match the source shape {image_shape}" + if hasattr(d[key], "meta") and isinstance(d[key].meta, Mapping): # type: ignore[attr-defined] + filename = d[key].meta.get(ImageMetaKey.FILENAME_OR_OBJ) # type: ignore[attr-defined] + msg += f", the metadata was not updated: filename={filename}" + warnings.warn(msg) d[key] = torch.nn.functional.interpolate( input=d[key].unsqueeze(0), size=image_shape, diff --git a/monai/apps/auto3dseg/utils.py b/monai/apps/auto3dseg/utils.py index f031bfde35..67cde64a2c 100644 --- a/monai/apps/auto3dseg/utils.py +++ b/monai/apps/auto3dseg/utils.py @@ -9,16 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os -from typing import Dict, List, Optional from monai.apps.auto3dseg.bundle_gen import BundleAlgo from monai.auto3dseg import algo_from_pickle, algo_to_pickle def import_bundle_algo_history( - output_folder: str = ".", template_path: Optional[str] = None, only_trained: bool = True -) -> List: + output_folder: str = ".", template_path: str | None = None, only_trained: bool = True +) -> list: """ import the history of the bundleAlgo object with their names/identifiers @@ -55,7 +56,7 @@ def import_bundle_algo_history( return history -def export_bundle_algo_history(history: List[Dict[str, BundleAlgo]]): +def export_bundle_algo_history(history: list[dict[str, BundleAlgo]]) -> None: """ Save all the BundleAlgo in the history to algo_object.pkl in each individual folder diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index f9867599b3..a3c541ee91 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -9,12 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import sys import warnings +from collections.abc import Callable, Sequence from pathlib import Path -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any import numpy as np @@ -88,18 +91,18 @@ def __init__( self, root_dir: PathLike, section: str, - transform: Union[Sequence[Callable], Callable] = (), + transform: Sequence[Callable] | Callable = (), download: bool = False, seed: int = 0, val_frac: float = 0.1, test_frac: float = 0.1, cache_num: int = sys.maxsize, cache_rate: float = 1.0, - num_workers: Optional[int] = 1, + num_workers: int | None = 1, progress: bool = True, copy_cache: bool = True, as_contiguous: bool = True, - runtime_cache=False, + runtime_cache: bool = False, ) -> None: root_dir = Path(root_dir) if not root_dir.is_dir(): @@ -148,7 +151,7 @@ def get_num_classes(self) -> int: """Get number of classes.""" return self.num_class - def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: + def _generate_data_list(self, dataset_dir: PathLike) -> list[dict]: """ Raises: ValueError: When ``section`` is not one of ["training", "validation", "test"]. @@ -286,7 +289,7 @@ def __init__( root_dir: PathLike, task: str, section: str, - transform: Union[Sequence[Callable], Callable] = (), + transform: Sequence[Callable] | Callable = (), download: bool = False, seed: int = 0, val_frac: float = 0.2, @@ -296,7 +299,7 @@ def __init__( progress: bool = True, copy_cache: bool = True, as_contiguous: bool = True, - runtime_cache=False, + runtime_cache: bool = False, ) -> None: root_dir = Path(root_dir) if not root_dir.is_dir(): @@ -362,7 +365,7 @@ def get_indices(self) -> np.ndarray: def randomize(self, data: np.ndarray) -> None: self.R.shuffle(data) - def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None): + def get_properties(self, keys: Sequence[str] | str | None = None) -> dict: """ Get the loaded properties of dataset with specified keys. If no keys specified, return all the loaded properties. @@ -374,14 +377,14 @@ def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None): return {key: self._properties[key] for key in ensure_tuple(keys)} return {} - def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: + def _generate_data_list(self, dataset_dir: PathLike) -> list[dict]: # the types of the item in data list should be compatible with the dataloader dataset_dir = Path(dataset_dir) section = "training" if self.section in ["training", "validation"] else "test" datalist = load_decathlon_datalist(dataset_dir / "dataset.json", True, section) return self._split_datalist(datalist) - def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: + def _split_datalist(self, datalist: list[dict]) -> list[dict]: if self.section == "test": return datalist length = len(datalist) @@ -489,14 +492,14 @@ def __init__( root_dir: PathLike, collection: str, section: str, - transform: Union[Sequence[Callable], Callable] = (), + transform: Sequence[Callable] | Callable = (), download: bool = False, download_len: int = -1, seg_type: str = "SEG", - modality_tag: Tuple = (0x0008, 0x0060), - ref_series_uid_tag: Tuple = (0x0020, 0x000E), - ref_sop_uid_tag: Tuple = (0x0008, 0x1155), - specific_tags: Tuple = ( + modality_tag: tuple = (0x0008, 0x0060), + ref_series_uid_tag: tuple = (0x0020, 0x000E), + ref_sop_uid_tag: tuple = (0x0008, 0x1155), + specific_tags: tuple = ( (0x0008, 0x1115), # Referenced Series Sequence (0x0008, 0x1140), # Referenced Image Sequence (0x3006, 0x0010), # Referenced Frame of Reference Sequence @@ -514,7 +517,7 @@ def __init__( progress: bool = True, copy_cache: bool = True, as_contiguous: bool = True, - runtime_cache=False, + runtime_cache: bool = False, ) -> None: root_dir = Path(root_dir) if not root_dir.is_dir(): @@ -574,7 +577,7 @@ def get_indices(self) -> np.ndarray: def randomize(self, data: np.ndarray) -> None: self.R.shuffle(data) - def _download_series_reference_data(self, series_uid: str, download_dir: str): + def _download_series_reference_data(self, series_uid: str, download_dir: str) -> None: """ First of all, download a series from TCIA according to `series_uid`. Then find all referenced series and download. @@ -630,7 +633,7 @@ def _download_series_reference_data(self, series_uid: str, download_dir: str): if not os.path.exists(seg_dir): shutil.copytree(seg_first_dir, seg_dir) - def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: + def _generate_data_list(self, dataset_dir: PathLike) -> list[dict]: # the types of the item in data list should be compatible with the dataloader dataset_dir = Path(dataset_dir) datalist = [] @@ -649,7 +652,7 @@ def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: return self._split_datalist(datalist) - def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: + def _split_datalist(self, datalist: list[dict]) -> list[dict]: if self.section == "test": return datalist length = len(datalist) @@ -703,7 +706,7 @@ class CrossValidation: """ - def __init__(self, dataset_cls, nfolds: int = 5, seed: int = 0, **dataset_params) -> None: + def __init__(self, dataset_cls: object, nfolds: int = 5, seed: int = 0, **dataset_params: Any) -> None: if not hasattr(dataset_cls, "_split_datalist"): raise ValueError("dataset class must have _split_datalist API.") self.dataset_cls = dataset_cls @@ -711,7 +714,7 @@ def __init__(self, dataset_cls, nfolds: int = 5, seed: int = 0, **dataset_params self.seed = seed self.dataset_params = dataset_params - def get_dataset(self, folds: Union[Sequence[int], int], **dataset_params): + def get_dataset(self, folds: Sequence[int] | int, **dataset_params: Any) -> object: """ Generate dataset based on the specified fold indices in the cross validation group. @@ -727,7 +730,7 @@ def get_dataset(self, folds: Union[Sequence[int], int], **dataset_params): dataset_params_.update(dataset_params) class _NsplitsDataset(self.dataset_cls): # type: ignore - def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: + def _split_datalist(self, datalist: list[dict]) -> list[dict]: data = partition_dataset(data=datalist, num_partitions=nfolds, shuffle=True, seed=seed) return select_cross_validation_folds(partitions=data, folds=folds) diff --git a/monai/apps/deepedit/interaction.py b/monai/apps/deepedit/interaction.py index dce81f095e..07302575c6 100644 --- a/monai/apps/deepedit/interaction.py +++ b/monai/apps/deepedit/interaction.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Sequence, Union +from __future__ import annotations + +from collections.abc import Callable, Sequence import numpy as np import torch @@ -43,13 +45,12 @@ class Interaction: def __init__( self, deepgrow_probability: float, - transforms: Union[Sequence[Callable], Callable], + transforms: Sequence[Callable] | Callable, train: bool, - label_names: Union[None, Dict[str, int]] = None, + label_names: None | dict[str, int] = None, click_probability_key: str = "probability", max_interactions: int = 1, ) -> None: - self.deepgrow_probability = deepgrow_probability self.transforms = Compose(transforms) if not isinstance(transforms, Compose) else transforms self.train = train @@ -57,7 +58,7 @@ def __init__( self.click_probability_key = click_probability_key self.max_interactions = max_interactions - def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]): + def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict: if batchdata is None: raise ValueError("Must provide batch data for current iteration.") @@ -96,4 +97,4 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd # first item in batch only engine.state.batch = batchdata - return engine._iteration(engine, batchdata) + return engine._iteration(engine, batchdata) # type: ignore[arg-type] diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py index fb9bd8e2e2..880f90baae 100644 --- a/monai/apps/deepedit/transforms.py +++ b/monai/apps/deepedit/transforms.py @@ -9,11 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import logging import random import warnings -from typing import Dict, Hashable, List, Mapping, Optional +from collections.abc import Hashable, Mapping, Sequence, Sized import numpy as np import torch @@ -37,7 +39,7 @@ def __init__( keys: KeysCollection, number_intensity_ch: int = 1, probability: float = 1.0, - label_names=None, + label_names: Sized | None = None, allow_missing_keys: bool = False, ): """ @@ -52,7 +54,7 @@ def __init__( self.number_intensity_ch = number_intensity_ch self.discard_probability = probability - self.label_names = label_names + self.label_names = label_names or [] def _apply(self, image): if self.discard_probability >= 1.0 or np.random.choice( @@ -67,8 +69,8 @@ def _apply(self, image): image = np.concatenate([image, signal], axis=0) return image - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) for key in self.key_iterator(d): if key == "image": tmp_image = self._apply(d[key]) @@ -82,7 +84,9 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda class NormalizeLabelsInDatasetd(MapTransform): - def __init__(self, keys: KeysCollection, label_names=None, allow_missing_keys: bool = False): + def __init__( + self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False + ): """ Normalize label values according to label names dictionary @@ -92,10 +96,10 @@ def __init__(self, keys: KeysCollection, label_names=None, allow_missing_keys: b """ super().__init__(keys, allow_missing_keys) - self.label_names = label_names + self.label_names = label_names or {} - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) for key in self.key_iterator(d): # Dictionary containing new label numbers new_label_names = {} @@ -117,7 +121,9 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda class SingleLabelSelectiond(MapTransform): - def __init__(self, keys: KeysCollection, label_names=None, allow_missing_keys: bool = False): + def __init__( + self, keys: KeysCollection, label_names: Sequence[str] | None = None, allow_missing_keys: bool = False + ): """ Selects one label at a time to train the DeepEdit @@ -127,7 +133,7 @@ def __init__(self, keys: KeysCollection, label_names=None, allow_missing_keys: b """ super().__init__(keys, allow_missing_keys) - self.label_names = label_names + self.label_names: Sequence[str] = label_names or [] self.all_label_values = { "spleen": 1, "right kidney": 2, @@ -145,8 +151,8 @@ def __init__(self, keys: KeysCollection, label_names=None, allow_missing_keys: b "left adrenal gland": 14, } - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) for key in self.key_iterator(d): if key == "label": # Taking one label at a time @@ -232,8 +238,8 @@ def _get_signal(self, image, guidance): signal = np.zeros((1, image.shape[-2], image.shape[-1]), dtype=np.float32) return signal - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) for key in self.key_iterator(d): if key == "image": image = d[key] @@ -262,7 +268,7 @@ class FindAllValidSlicesDeepEditd(MapTransform): sids: key to store slices indices having valid label map. """ - def __init__(self, keys: KeysCollection, sids="sids", allow_missing_keys: bool = False): + def __init__(self, keys: KeysCollection, sids: Hashable = "sids", allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) self.sids = sids @@ -276,8 +282,8 @@ def _apply(self, label, d): sids[key_label] = l_ids return sids - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) for key in self.key_iterator(d): if key == "label": label = d[key] @@ -324,7 +330,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.sids_key = sids self.sid_key = sid - self.sid: Dict[str, int] = dict() + self.sid: dict[str, int] = dict() self.guidance = guidance self.connected_regions = connected_regions @@ -384,8 +390,8 @@ def _randomize(self, d, key_label): sid = None self.sid[key_label] = sid - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) for key in self.key_iterator(d): if key == "label": label_guidances = {} @@ -442,8 +448,8 @@ def disparity(label, pred): def _apply(self, label, pred): return self.disparity(label, pred) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) for key in self.key_iterator(d): if key == "label": all_discrepancies = {} @@ -503,10 +509,10 @@ def __init__( self.discrepancy = discrepancy self.probability = probability self._will_interact = None - self.is_pos: Optional[bool] = None - self.is_other: Optional[bool] = None + self.is_pos: bool | None = None + self.is_other: bool | None = None self.default_guidance = None - self.guidance: Dict[str, List[List[int]]] = {} + self.guidance: dict[str, list[list[int]]] = {} def randomize(self, data=None): probability = data[self.probability] @@ -527,7 +533,6 @@ def find_guidance(self, discrepancy): return None def add_guidance(self, guidance, discrepancy, label_names, labels): - # Positive clicks of the segment in the iteration pos_discr = discrepancy[0] # idx 0 is positive discrepancy and idx 1 is negative discrepancy @@ -566,8 +571,8 @@ def add_guidance(self, guidance, discrepancy, label_names, labels): tmp_label = 1 - tmp_label self.guidance[key_label].append(self.find_guidance(discrepancy[1] * tmp_label)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) guidance = d[self.guidance_key] discrepancy = d[self.discrepancy] self.randomize(data) @@ -634,15 +639,15 @@ class AddGuidanceFromPointsDeepEditd(Transform): def __init__( self, - ref_image, + ref_image: str, guidance: str = "guidance", - label_names=None, - meta_keys: Optional[str] = None, + label_names: dict | None = None, + meta_keys: str | None = None, meta_key_postfix: str = "meta_dict", ): self.ref_image = ref_image self.guidance = guidance - self.label_names = label_names + self.label_names = label_names or {} self.meta_keys = meta_keys self.meta_key_postfix = meta_key_postfix @@ -715,8 +720,8 @@ class SplitPredsLabeld(MapTransform): """ - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) for key in self.key_iterator(d): if key == "pred": for idx, (key_label, _) in enumerate(d["label_names"].items()): @@ -753,7 +758,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.sids_key = sids self.sid_key = sid - self.sid: Dict[str, int] = dict() + self.sid: dict[str, int] = dict() self.guidance = guidance self.connected_regions = connected_regions @@ -816,8 +821,8 @@ def _randomize(self, d, key_label): sid = None self.sid[key_label] = sid - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) for key in self.key_iterator(d): if key == "label": label_guidances = {} @@ -850,7 +855,7 @@ class FindAllValidSlicesMissingLabelsd(MapTransform): sids: key to store slices indices having valid label map. """ - def __init__(self, keys: KeysCollection, sids="sids", allow_missing_keys: bool = False): + def __init__(self, keys: KeysCollection, sids: Hashable = "sids", allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) self.sids = sids @@ -867,8 +872,8 @@ def _apply(self, label, d): sids[key_label] = l_ids return sids - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d: Dict = dict(data) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: + d: dict = dict(data) for key in self.key_iterator(d): if key == "label": label = d[key] diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index dcdba512d6..802d86e0c7 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -9,28 +9,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os -from typing import Dict, List, Union +from collections.abc import Sequence import numpy as np -from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, Orientationd, Spacingd, SqueezeDimd +from monai.config import PathLike +from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, Orientationd, Spacingd, SqueezeDimd, Transform from monai.utils import GridSampleMode def create_dataset( - datalist, + datalist: list[dict], output_dir: str, dimension: int, - pixdim, + pixdim: Sequence[float] | float, image_key: str = "image", label_key: str = "label", - base_dir=None, + base_dir: PathLike | None = None, limit: int = 0, relative_path: bool = False, - transforms=None, -) -> List[Dict]: + transforms: Transform | None = None, +) -> list[dict]: """ Utility to pre-process and create dataset list for Deepgrow training over on existing one. The input data list is normally a list of images and labels (3D volume) that needs pre-processing @@ -144,7 +147,7 @@ def _default_transforms(image_key, label_key, pixdim): def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): - data_list: List[Dict[str, Union[str, int]]] = [] + data_list: list[dict[str, str | int]] = [] image_count = 0 label_count = 0 @@ -211,7 +214,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): - data_list: List[Dict[str, Union[str, int]]] = [] + data_list: list[dict[str, str | int]] = [] image_count = 0 label_count = 0 diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 73bc8e7e0b..88211c31e3 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -8,7 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Sequence, Union + +from __future__ import annotations + +from collections.abc import Callable, Sequence import torch @@ -38,12 +41,11 @@ class Interaction: def __init__( self, - transforms: Union[Sequence[Callable], Callable], + transforms: Sequence[Callable] | Callable, max_interactions: int, train: bool, key_probability: str = "probability", ) -> None: - if not isinstance(transforms, Compose): transforms = Compose(transforms) @@ -52,7 +54,7 @@ def __init__( self.train = train self.key_probability = key_probability - def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]): + def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict: if batchdata is None: raise ValueError("Must provide batch data for current iteration.") @@ -85,4 +87,4 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd # collate list into a batch for next round interaction batchdata = list_data_collate(batchdata_list) - return engine._iteration(engine, batchdata) + return engine._iteration(engine, batchdata) # type: ignore[arg-type] diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index e8b191845a..7078777f92 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -8,8 +8,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + import json -from typing import Callable, Dict, Hashable, List, Optional, Sequence, Union +from collections.abc import Callable, Hashable, Iterable, Sequence +from typing import Any import numpy as np import torch @@ -19,7 +23,7 @@ from monai.transforms import Resize, SpatialCrop from monai.transforms.transform import MapTransform, Randomizable, Transform from monai.transforms.utils import generate_spatial_bounding_box, is_positive -from monai.utils import InterpolateMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, min_version, optional_import +from monai.utils import InterpolateMode, ensure_tuple, ensure_tuple_rep, min_version, optional_import from monai.utils.enums import PostFix measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -50,8 +54,8 @@ def _apply(self, label): sids.append(sid) return np.asarray(sids) - def __call__(self, data) -> Dict: - d: Dict = dict(data) + def __call__(self, data: Any) -> dict: + d: dict = dict(data) label = d[self.label].numpy() if isinstance(data[self.label], torch.Tensor) else data[self.label] if label.shape[0] != 1: raise ValueError(f"Only supports single channel labels, got label shape {label.shape}!") @@ -396,13 +400,13 @@ def __init__( self, keys: KeysCollection, source_key: str, - spatial_size: Union[Sequence[int], np.ndarray], + spatial_size: Sequence[int] | np.ndarray, select_fn: Callable = is_positive, - channel_indices: Optional[IndexSelection] = None, + channel_indices: IndexSelection | None = None, margin: int = 0, allow_smaller: bool = True, - meta_keys: Optional[KeysCollection] = None, - meta_key_postfix=DEFAULT_POST_FIX, + meta_keys: KeysCollection | None = None, + meta_key_postfix: str = DEFAULT_POST_FIX, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", @@ -493,10 +497,9 @@ class AddGuidanceFromPointsd(Transform): """ - @deprecated_arg(name="dimensions", since="0.6", msg_suffix="Please use `spatial_dims` instead.") def __init__( self, - ref_image, + ref_image: str, guidance: str = "guidance", foreground: str = "foreground", background: str = "background", @@ -504,9 +507,8 @@ def __init__( depth_first: bool = True, spatial_dims: int = 2, slice_key: str = "slice", - meta_keys: Optional[str] = None, + meta_keys: str | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, - dimensions: Optional[int] = None, ): self.ref_image = ref_image self.guidance = guidance @@ -514,7 +516,7 @@ def __init__( self.background = background self.axis = axis self.depth_first = depth_first - self.dimensions = spatial_dims if dimensions is None else dimensions + self.dimensions = spatial_dims self.slice = slice_key self.meta_keys = meta_keys self.meta_key_postfix = meta_key_postfix @@ -523,7 +525,7 @@ def _apply(self, pos_clicks, neg_clicks, factor, slice_num): pos = neg = [] if self.dimensions == 2: - points: List = list(pos_clicks) + points: list = list(pos_clicks) points.extend(neg_clicks) slices = list(np.unique(np.array(points)[:, self.axis])) @@ -612,10 +614,10 @@ def __init__( self, keys: KeysCollection, guidance: str, - spatial_size, - margin=20, - meta_keys: Optional[KeysCollection] = None, - meta_key_postfix=DEFAULT_POST_FIX, + spatial_size: Iterable[int], + margin: int = 20, + meta_keys: KeysCollection | None = None, + meta_key_postfix: str = DEFAULT_POST_FIX, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", original_shape_key: str = "foreground_original_shape", @@ -653,8 +655,8 @@ def bounding_box(self, points, img_shape): box_start[di], box_end[di] = min_d, max_d return box_start, box_end - def __call__(self, data) -> Dict: - d: Dict = dict(data) + def __call__(self, data: Any) -> dict: + d: dict = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): return d @@ -730,7 +732,7 @@ def __init__( self, guidance: str, ref_image: str, - meta_keys: Optional[str] = None, + meta_keys: str | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, cropped_shape_key: str = "foreground_cropped_shape", ) -> None: @@ -740,10 +742,10 @@ def __init__( self.meta_key_postfix = meta_key_postfix self.cropped_shape_key = cropped_shape_key - def __call__(self, data) -> Dict: + def __call__(self, data: Any) -> dict: d = dict(data) guidance = d[self.guidance] - meta_dict: Dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"] + meta_dict: dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"] current_shape = d[self.ref_image].shape[1:] cropped_shape = meta_dict[self.cropped_shape_key][1:] factor = np.divide(current_shape, cropped_shape) @@ -811,9 +813,9 @@ def __init__( keys: KeysCollection, ref_image: str, slice_only: bool = False, - mode: Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] = InterpolateMode.NEAREST, - align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, - meta_keys: Optional[str] = None, + mode: Sequence[InterpolateMode | str] | InterpolateMode | str = InterpolateMode.NEAREST, + align_corners: Sequence[bool | None] | bool | None = None, + meta_keys: str | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", @@ -835,9 +837,9 @@ def __init__( self.original_shape_key = original_shape_key self.cropped_shape_key = cropped_shape_key - def __call__(self, data) -> Dict: + def __call__(self, data: Any) -> dict: d = dict(data) - meta_dict: Dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] + meta_dict: dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] for key, mode, align_corners, meta_key in self.key_iterator(d, self.mode, self.align_corners, self.meta_keys): image = d[key] @@ -916,10 +918,10 @@ class Fetch2DSliced(MapTransform): def __init__( self, - keys, - guidance="guidance", + keys: KeysCollection, + guidance: str = "guidance", axis: int = 0, - meta_keys: Optional[KeysCollection] = None, + meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, ): diff --git a/monai/apps/detection/metrics/coco.py b/monai/apps/detection/metrics/coco.py index 8dba3fa7da..033b763be5 100644 --- a/monai/apps/detection/metrics/coco.py +++ b/monai/apps/detection/metrics/coco.py @@ -61,9 +61,12 @@ The changes include 1) code reformatting, 2) docstrings. """ +from __future__ import annotations + import logging as logger import time -from typing import Dict, List, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any import numpy as np @@ -153,7 +156,7 @@ def __init__( self.recall_thresholds = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True) self.max_detections = max_detection - def __call__(self, *args, **kwargs) -> Tuple[Dict[str, float], Union[Dict[str, np.ndarray], None]]: + def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, float], dict[str, np.ndarray] | None]: """ Compute metric. See :func:`compute` for more information. @@ -162,12 +165,12 @@ def __call__(self, *args, **kwargs) -> Tuple[Dict[str, float], Union[Dict[str, n **kwargs: keyword arguments passed to :func:`compute` Returns: - Dict[str, float]: dictionary with scalar values for evaluation - Dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs + dict[str, float]: dictionary with scalar values for evaluation + dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs """ return self.compute(*args, **kwargs) - def check_number_of_iou(self, *args) -> None: + def check_number_of_iou(self, *args: np.ndarray) -> None: """ Check if shape of input in first dimension is consistent with expected IoU values (assumes IoU dimension is the first dimension) @@ -192,13 +195,13 @@ def get_iou_thresholds(self) -> Sequence[float]: """ return list(self.iou_thresholds) - def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple[Dict[str, float], None]: + def compute(self, results_list: list[dict[int, dict[str, np.ndarray]]]) -> tuple[dict[str, float], None]: """ Compute COCO metrics Args: - results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with results per image (in list) - per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`. + results_list (list[dict[int, dict[str, np.ndarray]]]): list with results per image (in list) + per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`. - `dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections @@ -211,13 +214,13 @@ def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple indicate which detections should be ignored Returns: - Dict[str, float], dictionary with coco metrics + dict[str, float], dictionary with coco metrics """ if self.verbose: logger.info("Start COCO metric computation...") tic = time.time() - dataset_statistics = self._compute_statistics(results_list=results_list) # Dict[str, Union[np.ndarray, List]] + dataset_statistics = self._compute_statistics(results_list=results_list) # dict[str, Union[np.ndarray, list]] if self.verbose: toc = time.time() @@ -232,13 +235,13 @@ def compute(self, results_list: List[Dict[int, Dict[str, np.ndarray]]]) -> Tuple logger.info(f"COCO metrics computed in t={(toc - tic):0.2f}s.") return results, None - def _compute_ap(self, dataset_statistics: Dict[str, Union[np.ndarray, List]]) -> Dict[str, float]: + def _compute_ap(self, dataset_statistics: dict[str, np.ndarray | list]) -> dict[str, float]: """ Compute AP metrics Args: - dataset_statistics (List[Dict[int, Dict[str, np.ndarray]]]): list with result s per image (in list) - per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`. + dataset_statistics (list[dict[int, dict[str, np.ndarray]]]): list with result s per image (in list) + per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`. - `dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections @@ -279,13 +282,13 @@ def _compute_ap(self, dataset_statistics: Dict[str, Union[np.ndarray, List]]) -> results[key] = self._select_ap(dataset_statistics, iou_idx=[idx], cls_idx=cls_idx, max_det_idx=-1) return results - def _compute_ar(self, dataset_statistics: Dict[str, Union[np.ndarray, List]]) -> Dict[str, float]: + def _compute_ar(self, dataset_statistics: dict[str, np.ndarray | list]) -> dict[str, float]: """ Compute AR metrics Args: - dataset_statistics (List[Dict[int, Dict[str, np.ndarray]]]): list with result s per image (in list) - per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`. + dataset_statistics (list[dict[int, dict[str, np.ndarray]]]): list with result s per image (in list) + per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`. - `dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections @@ -324,8 +327,8 @@ def _compute_ar(self, dataset_statistics: Dict[str, Union[np.ndarray, List]]) -> @staticmethod def _select_ap( dataset_statistics: dict, - iou_idx: Union[int, List[int], np.ndarray, None] = None, - cls_idx: Union[int, Sequence[int], None] = None, + iou_idx: int | list[int] | np.ndarray | None = None, + cls_idx: int | Sequence[int] | None = None, max_det_idx: int = -1, ) -> float: """ @@ -359,8 +362,8 @@ def _select_ap( @staticmethod def _select_ar( dataset_statistics: dict, - iou_idx: Union[int, Sequence[int], None] = None, - cls_idx: Union[int, Sequence[int], None] = None, + iou_idx: int | Sequence[int] | None = None, + cls_idx: int | Sequence[int] | None = None, max_det_idx: int = -1, ) -> float: """ @@ -395,16 +398,14 @@ def _select_ar( return float(np.mean(rec[rec > -1])) - def _compute_statistics( - self, results_list: List[Dict[int, Dict[str, np.ndarray]]] - ) -> Dict[str, Union[np.ndarray, List]]: + def _compute_statistics(self, results_list: list[dict[int, dict[str, np.ndarray]]]) -> dict[str, np.ndarray | list]: """ Compute statistics needed for COCO metrics (mAP, AP of individual classes, mAP@IoU_Thresholds, AR) Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py Args: - results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with result s per image (in list) - per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`. + results_list (list[dict[int, dict[str, np.ndarray]]]): list with result s per image (in list) + per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`. - `dtMatches`: matched detections [T, D], where T = number of thresholds, D = number of detections @@ -487,9 +488,9 @@ def _compute_stats_single_threshold( tp: np.ndarray, fp: np.ndarray, dt_scores_sorted: np.ndarray, - recall_thresholds: Union[np.ndarray, Sequence[float]], + recall_thresholds: np.ndarray | Sequence[float], num_gt: int, -) -> Tuple[float, np.ndarray, np.ndarray]: +) -> tuple[float, np.ndarray, np.ndarray]: """ Compute recall value, precision curve and scores thresholds Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py diff --git a/monai/apps/detection/metrics/matching.py b/monai/apps/detection/metrics/matching.py index 37e6e2fa06..6efcb701d4 100644 --- a/monai/apps/detection/metrics/matching.py +++ b/monai/apps/detection/metrics/matching.py @@ -62,7 +62,9 @@ 3) allow input args gt_ignore to be optional. (If so, no GT boxes will be ignored.) """ -from typing import Callable, Dict, List, Sequence, Union +from __future__ import annotations + +from collections.abc import Callable, Sequence import numpy as np @@ -77,9 +79,9 @@ def matching_batch( pred_scores: Sequence[np.ndarray], gt_boxes: Sequence[np.ndarray], gt_classes: Sequence[np.ndarray], - gt_ignore: Union[Sequence[Sequence[bool]], Sequence[np.ndarray], None] = None, + gt_ignore: Sequence[Sequence[bool]] | Sequence[np.ndarray] | None = None, max_detections: int = 100, -) -> List[Dict[int, Dict[str, np.ndarray]]]: +) -> list[dict[int, dict[str, np.ndarray]]]: """ Match boxes of a batch to corresponding ground truth for each category independently. @@ -185,7 +187,7 @@ def matching_batch( def _matching_no_gt( iou_thresholds: Sequence[float], pred_scores: np.ndarray, max_detections: int -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Matching result with not ground truth in image @@ -228,7 +230,7 @@ def _matching_no_gt( } -def _matching_no_pred(iou_thresholds: Sequence[float], gt_ignore: np.ndarray) -> Dict[str, np.ndarray]: +def _matching_no_pred(iou_thresholds: Sequence[float], gt_ignore: np.ndarray) -> dict[str, np.ndarray]: """ Matching result with no predictions @@ -275,7 +277,7 @@ def _matching_single_image_single_class( gt_ignore: np.ndarray, max_detections: int, iou_thresholds: Sequence[float], -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py diff --git a/monai/apps/detection/networks/retinanet_detector.py b/monai/apps/detection/networks/retinanet_detector.py index 5064304cbb..a6a4cf4e56 100644 --- a/monai/apps/detection/networks/retinanet_detector.py +++ b/monai/apps/detection/networks/retinanet_detector.py @@ -37,8 +37,11 @@ https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py """ +from __future__ import annotations + import warnings -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +from collections.abc import Callable, Sequence +from typing import Any import torch from torch import Tensor, nn @@ -177,7 +180,11 @@ def forward(self, images: torch.Tensor): """ def __init__( - self, network, anchor_generator: AnchorGenerator, box_overlap_metric: Callable = box_iou, debug: bool = False + self, + network: nn.Module, + anchor_generator: AnchorGenerator, + box_overlap_metric: Callable = box_iou, + debug: bool = False, ): super().__init__() @@ -191,12 +198,12 @@ def __init__( ) self.network = network - self.spatial_dims = self.network.spatial_dims + self.spatial_dims: int = self.network.spatial_dims # type: ignore[assignment] self.num_classes = self.network.num_classes self.size_divisible = ensure_tuple_rep(self.network.size_divisible, self.spatial_dims) # keys for the network output - self.cls_key = self.network.cls_key - self.box_reg_key = self.network.box_reg_key + self.cls_key: str = self.network.cls_key # type: ignore[assignment] + self.box_reg_key: str = self.network.box_reg_key # type: ignore[assignment] # check if anchor_generator matches with network self.anchor_generator = anchor_generator @@ -209,14 +216,14 @@ def __init__( ) # if new coming input images has same shape with # self.previous_image_shape, there is no need to generate new anchors. - self.anchors: Union[List[Tensor], None] = None - self.previous_image_shape: Union[Any, None] = None + self.anchors: list[Tensor] | None = None + self.previous_image_shape: Any | None = None self.box_overlap_metric = box_overlap_metric self.debug = debug # default setting for training - self.fg_bg_sampler: Union[Any, None] = None + self.fg_bg_sampler: Any | None = None self.set_cls_loss(torch.nn.BCEWithLogitsLoss(reduction="mean")) # classification loss self.set_box_regression_loss( torch.nn.SmoothL1Loss(beta=1.0 / 9, reduction="mean"), encode_gt=True, decode_pred=False @@ -234,7 +241,7 @@ def __init__( # default setting for inference, # can be updated by self.set_sliding_window_inferer(*) - self.inferer: Union[SlidingWindowInferer, None] = None + self.inferer: SlidingWindowInferer | None = None # can be updated by self.set_box_selector_parameters(*), self.box_selector = BoxSelector( box_overlap_metric=self.box_overlap_metric, @@ -245,7 +252,7 @@ def __init__( apply_sigmoid=True, ) - def set_box_coder_weights(self, weights: Tuple[float]): + def set_box_coder_weights(self, weights: tuple[float]) -> None: """ Set the weights for box coder. @@ -257,7 +264,7 @@ def set_box_coder_weights(self, weights: Tuple[float]): raise ValueError(f"len(weights) should be {2 * self.spatial_dims}, got weights={weights}.") self.box_coder = BoxCoder(weights=weights) - def set_target_keys(self, box_key: str, label_key: str): + def set_target_keys(self, box_key: str, label_key: str) -> None: """ Set keys for the training targets and inference outputs. During training, both box_key and label_key should be keys in the targets @@ -310,7 +317,9 @@ def set_box_regression_loss(self, box_loss: nn.Module, encode_gt: bool, decode_p self.encode_gt = encode_gt self.decode_pred = decode_pred - def set_regular_matcher(self, fg_iou_thresh: float, bg_iou_thresh: float, allow_low_quality_matches=True) -> None: + def set_regular_matcher( + self, fg_iou_thresh: float, bg_iou_thresh: float, allow_low_quality_matches: bool = True + ) -> None: """ Using for training. Set torchvision matcher that matches anchors with ground truth boxes. @@ -344,7 +353,7 @@ def set_atss_matcher(self, num_candidates: int = 4, center_in_gt: bool = False) def set_hard_negative_sampler( self, batch_size_per_image: int, positive_fraction: float, min_neg: int = 1, pool_size: float = 10 - ): + ) -> None: """ Using for training. Set hard negative sampler that samples part of the anchors for training. @@ -367,7 +376,7 @@ def set_hard_negative_sampler( pool_size=pool_size, ) - def set_balanced_sampler(self, batch_size_per_image: int, positive_fraction: float): + def set_balanced_sampler(self, batch_size_per_image: int, positive_fraction: float) -> None: """ Using for training. Set torchvision balanced sampler that samples part of the anchors for training. @@ -382,18 +391,18 @@ def set_balanced_sampler(self, batch_size_per_image: int, positive_fraction: flo def set_sliding_window_inferer( self, - roi_size: Union[Sequence[int], int], + roi_size: Sequence[int] | int, sw_batch_size: int = 1, overlap: float = 0.5, - mode: Union[BlendMode, str] = BlendMode.CONSTANT, - sigma_scale: Union[Sequence[float], float] = 0.125, - padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + mode: BlendMode | str = BlendMode.CONSTANT, + sigma_scale: Sequence[float] | float = 0.125, + padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT, cval: float = 0.0, - sw_device: Union[torch.device, str, None] = None, - device: Union[torch.device, str, None] = None, + sw_device: torch.device | str | None = None, + device: torch.device | str | None = None, progress: bool = False, cache_roi_weight_map: bool = False, - ): + ) -> None: """ Define sliding window inferer and store it to self.inferer. """ @@ -418,7 +427,7 @@ def set_box_selector_parameters( nms_thresh: float = 0.5, detections_per_img: int = 300, apply_sigmoid: bool = True, - ): + ) -> None: """ Using for inference. Set the parameters that are used for box selection during inference. The box selection is performed with the following steps: @@ -446,10 +455,10 @@ def set_box_selector_parameters( def forward( self, - input_images: Union[List[Tensor], Tensor], - targets: Union[List[Dict[str, Tensor]], None] = None, + input_images: list[Tensor] | Tensor, + targets: list[dict[str, Tensor]] | None = None, use_inferer: bool = False, - ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]: + ) -> dict[str, Tensor] | list[dict[str, Tensor]]: """ Returns a dict of losses during training, or a list predicted dict of boxes and labels during inference. @@ -536,7 +545,7 @@ def _check_detector_training_components(self): "or set classification loss function as Focal loss with self.set_cls_loss(*)" ) - def generate_anchors(self, images: Tensor, head_outputs: Dict[str, List[Tensor]]): + def generate_anchors(self, images: Tensor, head_outputs: dict[str, list[Tensor]]) -> None: """ Generate anchors and store it in self.anchors: List[Tensor]. We generate anchors only when there is no stored anchors, @@ -552,7 +561,7 @@ def generate_anchors(self, images: Tensor, head_outputs: Dict[str, List[Tensor]] self.anchors = self.anchor_generator(images, head_outputs[self.cls_key]) # List[Tensor], len = batchsize self.previous_image_shape = images.shape - def _reshape_maps(self, result_maps: List[Tensor]) -> Tensor: + def _reshape_maps(self, result_maps: list[Tensor]) -> Tensor: """ Concat network output map list to a single Tensor. This function is used in both training and inference. @@ -596,12 +605,12 @@ def _reshape_maps(self, result_maps: List[Tensor]) -> Tensor: def postprocess_detections( self, - head_outputs_reshape: Dict[str, Tensor], - anchors: List[Tensor], - image_sizes: List[List[int]], + head_outputs_reshape: dict[str, Tensor], + anchors: list[Tensor], + image_sizes: list[list[int]], num_anchor_locs_per_level: Sequence[int], need_sigmoid: bool = True, - ) -> List[Dict[str, Tensor]]: + ) -> list[dict[str, Tensor]]: """ Postprocessing to generate detection result from classification logits and box regression. Use self.box_selector to select the final output boxes for each image. @@ -626,7 +635,7 @@ def postprocess_detections( ] # split outputs per level - split_head_outputs: Dict[str, List[Tensor]] = {} + split_head_outputs: dict[str, list[Tensor]] = {} for k in head_outputs_reshape: split_head_outputs[k] = list(head_outputs_reshape[k].split(num_anchors_per_level, dim=1)) split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors] # List[List[Tensor]] @@ -637,7 +646,7 @@ def postprocess_detections( num_images = len(image_sizes) # B - detections: List[Dict[str, Tensor]] = [] + detections: list[dict[str, Tensor]] = [] for index in range(num_images): box_regression_per_image = [ @@ -667,11 +676,11 @@ def postprocess_detections( def compute_loss( self, - head_outputs_reshape: Dict[str, Tensor], - targets: List[Dict[str, Tensor]], - anchors: List[Tensor], + head_outputs_reshape: dict[str, Tensor], + targets: list[dict[str, Tensor]], + anchors: list[Tensor], num_anchor_locs_per_level: Sequence[int], - ) -> Dict[str, Tensor]: + ) -> dict[str, Tensor]: """ Compute losses. @@ -696,8 +705,8 @@ def compute_loss( return {self.cls_key: losses_cls, self.box_reg_key: losses_box_regression} def compute_anchor_matched_idxs( - self, anchors: List[Tensor], targets: List[Dict[str, Tensor]], num_anchor_locs_per_level: Sequence[int] - ) -> List[Tensor]: + self, anchors: list[Tensor], targets: list[dict[str, Tensor]], num_anchor_locs_per_level: Sequence[int] + ) -> list[Tensor]: """ Compute the matched indices between anchors and ground truth (gt) boxes in targets. output[k][i] represents the matched gt index for anchor[i] in image k. @@ -768,7 +777,7 @@ def compute_anchor_matched_idxs( return matched_idxs def compute_cls_loss( - self, cls_logits: Tensor, targets: List[Dict[str, Tensor]], matched_idxs: List[Tensor] + self, cls_logits: Tensor, targets: list[dict[str, Tensor]], matched_idxs: list[Tensor] ) -> Tensor: """ Compute classification losses. @@ -800,9 +809,9 @@ def compute_cls_loss( def compute_box_loss( self, box_regression: Tensor, - targets: List[Dict[str, Tensor]], - anchors: List[Tensor], - matched_idxs: List[Tensor], + targets: list[dict[str, Tensor]], + anchors: list[Tensor], + matched_idxs: list[Tensor], ) -> Tensor: """ Compute box regression losses. @@ -845,8 +854,8 @@ def compute_box_loss( return losses def get_cls_train_sample_per_image( - self, cls_logits_per_image: Tensor, targets_per_image: Dict[str, Tensor], matched_idxs_per_image: Tensor - ) -> Tuple[Tensor, Tensor]: + self, cls_logits_per_image: Tensor, targets_per_image: dict[str, Tensor], matched_idxs_per_image: Tensor + ) -> tuple[Tensor, Tensor]: """ Get samples from one image for classification losses computation. @@ -923,10 +932,10 @@ def get_cls_train_sample_per_image( def get_box_train_sample_per_image( self, box_regression_per_image: Tensor, - targets_per_image: Dict[str, Tensor], + targets_per_image: dict[str, Tensor], anchors_per_image: Tensor, matched_idxs_per_image: Tensor, - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: """ Get samples from one image for box regression losses computation. diff --git a/monai/apps/detection/networks/retinanet_network.py b/monai/apps/detection/networks/retinanet_network.py index 4a0d8dc228..543edcc735 100644 --- a/monai/apps/detection/networks/retinanet_network.py +++ b/monai/apps/detection/networks/retinanet_network.py @@ -37,13 +37,16 @@ https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py """ +from __future__ import annotations + import math -from typing import Callable, Dict, List, Sequence, Union +from collections.abc import Callable, Sequence +from typing import Dict import torch from torch import Tensor, nn -from monai.networks.blocks.backbone_fpn_utils import _resnet_fpn_extractor +from monai.networks.blocks.backbone_fpn_utils import BackboneWithFPN, _resnet_fpn_extractor from monai.networks.layers.factories import Conv from monai.networks.nets import resnet from monai.utils import ensure_tuple_rep, look_up_option, optional_import @@ -94,7 +97,7 @@ def __init__( self.num_classes = num_classes self.num_anchors = num_anchors - def forward(self, x: List[Tensor]) -> List[Tensor]: + def forward(self, x: list[Tensor]) -> list[Tensor]: """ It takes a list of feature maps as inputs, and outputs a list of classification maps. Each output classification map has same spatial size with the corresponding input feature map, @@ -163,7 +166,7 @@ def __init__(self, in_channels: int, num_anchors: int, spatial_dims: int): torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore torch.nn.init.zeros_(layer.bias) # type: ignore - def forward(self, x: List[Tensor]) -> List[Tensor]: + def forward(self, x: list[Tensor]) -> list[Tensor]: """ It takes a list of feature maps as inputs, and outputs a list of box regression maps. Each output box regression map has same spatial size with the corresponding input feature map, @@ -261,8 +264,8 @@ def __init__( spatial_dims: int, num_classes: int, num_anchors: int, - feature_extractor, - size_divisible: Union[Sequence[int], int] = 1, + feature_extractor: nn.Module, + size_divisible: Sequence[int] | int = 1, ): super().__init__() @@ -278,7 +281,7 @@ def __init__( ) self.feature_extractor = feature_extractor - self.feature_map_channels: int = self.feature_extractor.out_channels + self.feature_map_channels: int = self.feature_extractor.out_channels # type: ignore[assignment] self.num_anchors = num_anchors self.classification_head = RetinaNetClassificationHead( self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims @@ -290,7 +293,7 @@ def __init__( self.cls_key: str = "classification" self.box_reg_key: str = "box_regression" - def forward(self, images: Tensor) -> Dict[str, List[Tensor]]: + def forward(self, images: Tensor) -> dict[str, list[Tensor]]: """ It takes an image tensor as inputs, and outputs a dictionary ``head_outputs``. ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor. @@ -320,7 +323,7 @@ def forward(self, images: Tensor) -> Dict[str, List[Tensor]]: # compute classification and box regression maps from the feature maps # expandable for mask prediction in the future - head_outputs: Dict[str, List[Tensor]] = {self.cls_key: self.classification_head(feature_maps)} + head_outputs: dict[str, list[Tensor]] = {self.cls_key: self.classification_head(feature_maps)} head_outputs[self.box_reg_key] = self.regression_head(feature_maps) return head_outputs @@ -331,8 +334,8 @@ def resnet_fpn_feature_extractor( spatial_dims: int, pretrained_backbone: bool = False, returned_layers: Sequence[int] = (1, 2, 3), - trainable_backbone_layers: Union[int, None] = None, -): + trainable_backbone_layers: int | None = None, +) -> BackboneWithFPN: """ Constructs a feature extractor network with a ResNet-FPN backbone, used as feature_extractor in RetinaNet. diff --git a/monai/apps/detection/transforms/array.py b/monai/apps/detection/transforms/array.py index 6d59c8b49b..491af077f0 100644 --- a/monai/apps/detection/transforms/array.py +++ b/monai/apps/detection/transforms/array.py @@ -13,12 +13,14 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -from typing import Callable, Optional, Sequence, Tuple, Type, Union +from __future__ import annotations + +from typing import Any, Sequence import numpy as np import torch -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import DtypeLike, NdarrayOrTensor, NdarrayTensor from monai.data.box_utils import ( BoxMode, clip_boxes_to_image, @@ -108,8 +110,8 @@ class ConvertBoxMode(Transform): def __init__( self, - src_mode: Union[str, BoxMode, Type[BoxMode], None] = None, - dst_mode: Union[str, BoxMode, Type[BoxMode], None] = None, + src_mode: str | BoxMode | type[BoxMode] | None = None, + dst_mode: str | BoxMode | type[BoxMode] | None = None, ) -> None: self.src_mode = src_mode self.dst_mode = dst_mode @@ -148,7 +150,7 @@ class ConvertBoxToStandardMode(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, mode: Union[str, BoxMode, Type[BoxMode], None] = None) -> None: + def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None: self.mode = mode def __call__(self, boxes: NdarrayOrTensor) -> NdarrayOrTensor: @@ -173,7 +175,7 @@ class AffineBox(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, boxes: NdarrayOrTensor, affine: Union[NdarrayOrTensor, None]) -> NdarrayOrTensor: # type: ignore + def __call__(self, boxes: NdarrayOrTensor, affine: NdarrayOrTensor | None) -> NdarrayOrTensor: # type: ignore """ Args: boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` @@ -200,12 +202,12 @@ class ZoomBox(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, zoom: Union[Sequence[float], float], keep_size: bool = False, **kwargs) -> None: + def __init__(self, zoom: Sequence[float] | float, keep_size: bool = False, **kwargs: Any) -> None: self.zoom = zoom self.keep_size = keep_size self.kwargs = kwargs - def __call__(self, boxes: torch.Tensor, src_spatial_size: Union[Sequence[int], int, None] = None): + def __call__(self, boxes: NdarrayTensor, src_spatial_size: Sequence[int] | int | None = None) -> NdarrayTensor: """ Args: boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` @@ -260,11 +262,11 @@ class ResizeBox(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, spatial_size: Union[Sequence[int], int], size_mode: str = "all", **kwargs) -> None: + def __init__(self, spatial_size: Sequence[int] | int, size_mode: str = "all", **kwargs: Any) -> None: self.size_mode = look_up_option(size_mode, ["all", "longest"]) self.spatial_size = spatial_size - def __call__(self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int]): # type: ignore + def __call__(self, boxes: NdarrayOrTensor, src_spatial_size: Sequence[int] | int) -> NdarrayOrTensor: # type: ignore[override] """ Args: boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` @@ -309,10 +311,10 @@ class FlipBox(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: + def __init__(self, spatial_axis: Sequence[int] | int | None = None) -> None: self.spatial_axis = spatial_axis - def __call__(self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]): # type: ignore + def __call__(self, boxes: NdarrayOrTensor, spatial_size: Sequence[int] | int): # type: ignore """ Args: boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` @@ -339,9 +341,9 @@ def __init__(self, remove_empty: bool = False) -> None: def __call__( # type: ignore self, boxes: NdarrayOrTensor, - labels: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor], - spatial_size: Union[Sequence[int], int], - ) -> Tuple[NdarrayOrTensor, Union[Tuple, NdarrayOrTensor]]: + labels: Sequence[NdarrayOrTensor] | NdarrayOrTensor, + spatial_size: Sequence[int] | int, + ) -> tuple[NdarrayOrTensor, tuple | NdarrayOrTensor]: """ Args: boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` @@ -392,7 +394,7 @@ def __init__(self, bg_label: int = -1, ellipse_mask: bool = False) -> None: self.ellipse_mask = ellipse_mask def __call__( # type: ignore - self, boxes: NdarrayOrTensor, labels: NdarrayOrTensor, spatial_size: Union[Sequence[int], int] + self, boxes: NdarrayOrTensor, labels: NdarrayOrTensor, spatial_size: Sequence[int] | int ) -> NdarrayOrTensor: """ Args: @@ -422,12 +424,17 @@ class MaskToBox(Transform): backend = [TransformBackends.NUMPY] - def __init__(self, bg_label: int = -1, box_dtype=torch.float32, label_dtype=torch.long) -> None: + def __init__( + self, + bg_label: int = -1, + box_dtype: DtypeLike | torch.dtype = torch.float32, + label_dtype: DtypeLike | torch.dtype = torch.long, + ) -> None: self.bg_label = bg_label self.box_dtype = box_dtype self.label_dtype = label_dtype - def __call__(self, boxes_mask: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: + def __call__(self, boxes_mask: NdarrayOrTensor) -> tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Args: boxes_mask: int16 array, sized (num_box, H, W). Each channel represents a box. @@ -470,20 +477,20 @@ class SpatialCropBox(SpatialCrop): def __init__( self, - roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_slices: Optional[Sequence[slice]] = None, + roi_center: Sequence[int] | NdarrayOrTensor | None = None, + roi_size: Sequence[int] | NdarrayOrTensor | None = None, + roi_start: Sequence[int] | NdarrayOrTensor | None = None, + roi_end: Sequence[int] | NdarrayOrTensor | None = None, + roi_slices: Sequence[slice] | None = None, ) -> None: super().__init__(roi_center, roi_size, roi_start, roi_end, roi_slices) for s in self.slices: if s.start < 0 or s.stop < 0 or (s.step is not None and s.step < 0): raise ValueError("Currently negative indexing is not supported for SpatialCropBox.") - def __call__( # type: ignore - self, boxes: NdarrayOrTensor, labels: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor] - ): + def __call__( # type: ignore[override] + self, boxes: NdarrayTensor, labels: Sequence[NdarrayOrTensor] | NdarrayOrTensor + ) -> tuple[NdarrayTensor, tuple | NdarrayOrTensor]: """ Args: boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` @@ -526,11 +533,9 @@ class RotateBox90(Rotate90): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __call__(self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]): # type: ignore + def __call__(self, boxes: NdarrayTensor, spatial_size: Sequence[int] | int) -> NdarrayTensor: # type: ignore[override] """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - rot90: Callable = rot90_boxes - out: NdarrayOrTensor = rot90(boxes, spatial_size, self.k, self.spatial_axes) - return out + return rot90_boxes(boxes, spatial_size, self.k, self.spatial_axes) diff --git a/monai/apps/detection/transforms/box_ops.py b/monai/apps/detection/transforms/box_ops.py index d2445577d0..15de8a4342 100644 --- a/monai/apps/detection/transforms/box_ops.py +++ b/monai/apps/detection/transforms/box_ops.py @@ -9,13 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from collections.abc import Sequence from copy import deepcopy -from typing import Optional, Sequence, Tuple, Union import numpy as np import torch -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import DtypeLike, NdarrayOrTensor, NdarrayTensor from monai.data.box_utils import COMPUTE_DTYPE, TO_REMOVE, get_spatial_dims from monai.transforms import Resize from monai.transforms.utils import create_scale @@ -59,7 +61,7 @@ def _apply_affine_to_points(points: torch.Tensor, affine: torch.Tensor, include_ return points_affine -def apply_affine_to_boxes(boxes: NdarrayOrTensor, affine: NdarrayOrTensor) -> NdarrayOrTensor: +def apply_affine_to_boxes(boxes: NdarrayTensor, affine: NdarrayOrTensor) -> NdarrayTensor: """ This function applies affine matrices to the boxes @@ -96,10 +98,10 @@ def apply_affine_to_boxes(boxes: NdarrayOrTensor, affine: NdarrayOrTensor) -> Nd # convert tensor back to numpy if needed boxes_affine: NdarrayOrTensor boxes_affine, *_ = convert_to_dst_type(src=boxes_t_affine, dst=boxes) - return boxes_affine + return boxes_affine # type: ignore[return-value] -def zoom_boxes(boxes: NdarrayOrTensor, zoom: Union[Sequence[float], float]): +def zoom_boxes(boxes: NdarrayTensor, zoom: Sequence[float] | float) -> NdarrayTensor: """ Zoom boxes @@ -127,8 +129,8 @@ def zoom_boxes(boxes: NdarrayOrTensor, zoom: Union[Sequence[float], float]): def resize_boxes( - boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int], dst_spatial_size: Union[Sequence[int], int] -): + boxes: NdarrayOrTensor, src_spatial_size: Sequence[int] | int, dst_spatial_size: Sequence[int] | int +) -> NdarrayOrTensor: """ Resize boxes when the corresponding image is resized @@ -159,10 +161,8 @@ def resize_boxes( def flip_boxes( - boxes: NdarrayOrTensor, - spatial_size: Union[Sequence[int], int], - flip_axes: Optional[Union[Sequence[int], int]] = None, -): + boxes: NdarrayTensor, spatial_size: Sequence[int] | int, flip_axes: Sequence[int] | int | None = None +) -> NdarrayTensor: """ Flip boxes when the corresponding image is flipped @@ -185,10 +185,7 @@ def flip_boxes( flip_axes = ensure_tuple(flip_axes) # flip box - if isinstance(boxes, torch.Tensor): - _flip_boxes = boxes.clone() - else: - _flip_boxes = deepcopy(boxes) # type: ignore + _flip_boxes: NdarrayTensor = boxes.clone() if isinstance(boxes, torch.Tensor) else deepcopy(boxes) # type: ignore[assignment] for axis in flip_axes: _flip_boxes[:, axis + spatial_dims] = spatial_size[axis] - boxes[:, axis] - TO_REMOVE @@ -200,7 +197,7 @@ def flip_boxes( def convert_box_to_mask( boxes: NdarrayOrTensor, labels: NdarrayOrTensor, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, bg_label: int = -1, ellipse_mask: bool = False, ) -> NdarrayOrTensor: @@ -275,8 +272,11 @@ def convert_box_to_mask( def convert_mask_to_box( - boxes_mask: NdarrayOrTensor, bg_label: int = -1, box_dtype=torch.float32, label_dtype=torch.long -) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: + boxes_mask: NdarrayOrTensor, + bg_label: int = -1, + box_dtype: DtypeLike | torch.dtype = torch.float32, + label_dtype: DtypeLike | torch.dtype = torch.long, +) -> tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Convert int16 mask image to box, which has the same size with the input image @@ -325,8 +325,8 @@ def convert_mask_to_box( def select_labels( - labels: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor], keep: NdarrayOrTensor -) -> Union[Tuple, NdarrayOrTensor]: + labels: Sequence[NdarrayOrTensor] | NdarrayOrTensor, keep: NdarrayOrTensor +) -> tuple | NdarrayOrTensor: """ For element in labels, select indices keep from it. @@ -353,7 +353,7 @@ def select_labels( return tuple(labels_select_list) -def swapaxes_boxes(boxes: NdarrayOrTensor, axis1: int, axis2: int): +def swapaxes_boxes(boxes: NdarrayTensor, axis1: int, axis2: int) -> NdarrayTensor: """ Interchange two axes of boxes. @@ -377,12 +377,12 @@ def swapaxes_boxes(boxes: NdarrayOrTensor, axis1: int, axis2: int): boxes_swap[:, [spatial_dims + axis1, spatial_dims + axis2]] = boxes_swap[ :, [spatial_dims + axis2, spatial_dims + axis1] ] - return boxes_swap + return boxes_swap # type: ignore[return-value] def rot90_boxes( - boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int], k: int = 1, axes: Tuple[int, int] = (0, 1) -): + boxes: NdarrayTensor, spatial_size: Sequence[int] | int, k: int = 1, axes: tuple[int, int] = (0, 1) +) -> NdarrayTensor: """ Rotate boxes by 90 degrees in the plane specified by axes. Rotation direction is from the first towards the second axis. diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py index fa365895b5..f77c5f4c48 100644 --- a/monai/apps/detection/transforms/dictionary.py +++ b/monai/apps/detection/transforms/dictionary.py @@ -14,8 +14,12 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ + +from __future__ import annotations + +from collections.abc import Hashable, Mapping, Sequence from copy import deepcopy -from typing import Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any import numpy as np import torch @@ -34,7 +38,7 @@ ) from monai.apps.detection.transforms.box_ops import convert_box_to_mask from monai.config import KeysCollection, SequenceStr -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import DtypeLike, NdarrayOrTensor from monai.data.box_utils import COMPUTE_DTYPE, BoxMode, clip_boxes_to_image from monai.data.meta_tensor import MetaTensor, get_track_meta from monai.data.utils import orientation_ras_lps @@ -109,8 +113,8 @@ class ConvertBoxModed(MapTransform, InvertibleTransform): def __init__( self, box_keys: KeysCollection, - src_mode: Union[str, BoxMode, Type[BoxMode], None] = None, - dst_mode: Union[str, BoxMode, Type[BoxMode], None] = None, + src_mode: str | BoxMode | type[BoxMode] | None = None, + dst_mode: str | BoxMode | type[BoxMode] | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -127,14 +131,14 @@ def __init__( super().__init__(box_keys, allow_missing_keys) self.converter = ConvertBoxMode(src_mode=src_mode, dst_mode=dst_mode) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) self.push_transform(d, key, extra_info={"src": self.converter.src_mode, "dst": self.converter.dst_mode}) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): tr = self.get_most_recent_transform(d, key) @@ -167,7 +171,7 @@ class ConvertBoxToStandardModed(MapTransform, InvertibleTransform): def __init__( self, box_keys: KeysCollection, - mode: Union[str, BoxMode, Type[BoxMode], None] = None, + mode: str | BoxMode | type[BoxMode] | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -182,14 +186,14 @@ def __init__( super().__init__(box_keys, allow_missing_keys) self.converter = ConvertBoxToStandardMode(mode=mode) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) self.push_transform(d, key, extra_info={"mode": self.converter.mode}) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): tr = self.get_most_recent_transform(d, key) @@ -230,9 +234,9 @@ def __init__( box_keys: KeysCollection, box_ref_image_keys: str, allow_missing_keys: bool = False, - image_meta_key: Union[str, None] = None, - image_meta_key_postfix: Union[str, None] = DEFAULT_POST_FIX, - affine_lps_to_ras=False, + image_meta_key: str | None = None, + image_meta_key_postfix: str | None = DEFAULT_POST_FIX, + affine_lps_to_ras: bool = False, ) -> None: super().__init__(box_keys, allow_missing_keys) box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys) @@ -246,7 +250,7 @@ def __init__( self.converter_to_image_coordinate = AffineBox() self.affine_lps_to_ras = affine_lps_to_ras - def extract_affine(self, data: Mapping[Hashable, torch.Tensor]) -> Tuple[NdarrayOrTensor, torch.Tensor]: + def extract_affine(self, data: Mapping[Hashable, torch.Tensor]) -> tuple[NdarrayOrTensor, torch.Tensor]: d = dict(data) meta_key = self.image_meta_key @@ -274,7 +278,7 @@ def extract_affine(self, data: Mapping[Hashable, torch.Tensor]) -> Tuple[Ndarray inv_affine_t = torch.inverse(affine_t.to(COMPUTE_DTYPE)) return affine, inv_affine_t - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) affine, inv_affine_t = self.extract_affine(data) # type: ignore @@ -284,7 +288,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N self.push_transform(d, key, extra_info={"affine": affine}) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -322,16 +326,16 @@ def __init__( box_keys: KeysCollection, box_ref_image_keys: str, allow_missing_keys: bool = False, - image_meta_key: Union[str, None] = None, - image_meta_key_postfix: Union[str, None] = DEFAULT_POST_FIX, - affine_lps_to_ras=False, + image_meta_key: str | None = None, + image_meta_key_postfix: str | None = DEFAULT_POST_FIX, + affine_lps_to_ras: bool = False, ) -> None: super().__init__( box_keys, box_ref_image_keys, allow_missing_keys, image_meta_key, image_meta_key_postfix, affine_lps_to_ras ) self.converter_to_world_coordinate = AffineBox() - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) affine, inv_affine_t = self.extract_affine(data) # type: ignore @@ -379,13 +383,13 @@ def __init__( image_keys: KeysCollection, box_keys: KeysCollection, box_ref_image_keys: KeysCollection, - zoom: Union[Sequence[float], float], + zoom: Sequence[float] | float, mode: SequenceStr = InterpolateMode.AREA, padding_mode: SequenceStr = NumpyPadMode.EDGE, - align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + align_corners: Sequence[bool | None] | bool | None = None, keep_size: bool = True, allow_missing_keys: bool = False, - **kwargs, + **kwargs: Any, ) -> None: self.image_keys = ensure_tuple(image_keys) self.box_keys = ensure_tuple(box_keys) @@ -398,8 +402,8 @@ def __init__( self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs) self.keep_size = keep_size - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: - d = dict(data) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d: dict[Hashable, torch.Tensor] = dict(data) # zoom box for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys): @@ -423,8 +427,8 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: - d = dict(data) + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d: dict[Hashable, torch.Tensor] = dict(data) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key, check=False) @@ -493,14 +497,14 @@ def __init__( box_keys: KeysCollection, box_ref_image_keys: KeysCollection, prob: float = 0.1, - min_zoom: Union[Sequence[float], float] = 0.9, - max_zoom: Union[Sequence[float], float] = 1.1, + min_zoom: Sequence[float] | float = 0.9, + max_zoom: Sequence[float] | float = 1.1, mode: SequenceStr = InterpolateMode.AREA, padding_mode: SequenceStr = NumpyPadMode.EDGE, - align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + align_corners: Sequence[bool | None] | bool | None = None, keep_size: bool = True, allow_missing_keys: bool = False, - **kwargs, + **kwargs: Any, ) -> None: self.image_keys = ensure_tuple(image_keys) self.box_keys = ensure_tuple(box_keys) @@ -514,14 +518,12 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.image_keys)) self.keep_size = keep_size - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandZoomBoxd": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandZoomBoxd: super().set_random_state(seed, state) self.rand_zoom.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): @@ -564,7 +566,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): @@ -609,7 +611,7 @@ def __init__( image_keys: KeysCollection, box_keys: KeysCollection, box_ref_image_keys: KeysCollection, - spatial_axis: Optional[Union[Sequence[int], int]] = None, + spatial_axis: Sequence[int] | int | None = None, allow_missing_keys: bool = False, ) -> None: self.image_keys = ensure_tuple(image_keys) @@ -620,7 +622,7 @@ def __init__( self.flipper = Flip(spatial_axis=spatial_axis) self.box_flipper = FlipBox(spatial_axis=self.flipper.spatial_axis) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.image_keys: @@ -632,7 +634,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc self.push_transform(d, box_key, extra_info={"spatial_size": spatial_size, "type": "box_key"}) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): @@ -673,7 +675,7 @@ def __init__( box_keys: KeysCollection, box_ref_image_keys: KeysCollection, prob: float = 0.1, - spatial_axis: Optional[Union[Sequence[int], int]] = None, + spatial_axis: Sequence[int] | int | None = None, allow_missing_keys: bool = False, ) -> None: self.image_keys = ensure_tuple(image_keys) @@ -685,13 +687,11 @@ def __init__( self.flipper = Flip(spatial_axis=spatial_axis) self.box_flipper = FlipBox(spatial_axis=spatial_axis) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandFlipBoxd": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandFlipBoxd: super().set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) self.randomize(None) @@ -711,7 +711,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc self.push_transform(d, box_key, extra_info={"spatial_size": spatial_size, "type": "box_key"}) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): @@ -784,7 +784,7 @@ def __init__( self.box_ref_image_keys = box_ref_image_keys_tuple[0] self.clipper = ClipBoxToImage(remove_empty=remove_empty) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) spatial_size = d[self.box_ref_image_keys].shape[1:] labels = [d[label_key] for label_key in self.label_keys] # could be multiple arrays @@ -872,7 +872,7 @@ def __init__( self.bg_label = min_fg_label - 1 # make sure background label is always smaller than fg labels. self.converter = BoxToMask(bg_label=self.bg_label, ellipse_mask=ellipse_mask) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for box_key, label_key, box_mask_key, box_ref_image_key in zip( @@ -940,8 +940,8 @@ def __init__( box_mask_keys: KeysCollection, label_keys: KeysCollection, min_fg_label: int, - box_dtype=torch.float32, - label_dtype=torch.long, + box_dtype: DtypeLike | torch.dtype = torch.float32, + label_dtype: DtypeLike | torch.dtype = torch.long, allow_missing_keys: bool = False, ) -> None: super().__init__(box_keys, allow_missing_keys) @@ -954,7 +954,7 @@ def __init__( self.converter = MaskToBox(bg_label=self.bg_label, box_dtype=box_dtype, label_dtype=label_dtype) self.box_dtype = box_dtype - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for box_key, label_key, box_mask_key in zip(self.box_keys, self.label_keys, self.box_mask_keys): @@ -1021,16 +1021,16 @@ def __init__( image_keys: KeysCollection, box_keys: str, label_keys: KeysCollection, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, pos: float = 1.0, neg: float = 1.0, num_samples: int = 1, whole_box: bool = True, - thresh_image_key: Optional[str] = None, + thresh_image_key: str | None = None, image_threshold: float = 0.0, - fg_indices_key: Optional[str] = None, - bg_indices_key: Optional[str] = None, - meta_keys: Optional[KeysCollection] = None, + fg_indices_key: str | None = None, + bg_indices_key: str | None = None, + meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, allow_smaller: bool = False, allow_missing_keys: bool = False, @@ -1050,7 +1050,7 @@ def __init__( self.box_keys = box_keys_tuple[0] self.label_keys = ensure_tuple(label_keys) - self.spatial_size_: Union[Tuple[int, ...], Sequence[int], int] = spatial_size + self.spatial_size_: tuple[int, ...] | Sequence[int] | int = spatial_size if pos < 0 or neg < 0: raise ValueError(f"pos and neg must be nonnegative, got pos={pos} neg={neg}.") @@ -1071,7 +1071,7 @@ def __init__( if len(self.image_keys) != len(self.meta_keys): raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.image_keys)) - self.centers: Optional[List[List[int]]] = None + self.centers: tuple[tuple] | None = None self.allow_smaller = allow_smaller def generate_fg_center_boxes_np(self, boxes: NdarrayOrTensor, image_size: Sequence[int]) -> np.ndarray: @@ -1104,9 +1104,9 @@ def randomize( # type: ignore self, boxes: NdarrayOrTensor, image_size: Sequence[int], - fg_indices: Optional[NdarrayOrTensor] = None, - bg_indices: Optional[NdarrayOrTensor] = None, - thresh_image: Optional[NdarrayOrTensor] = None, + fg_indices: NdarrayOrTensor | None = None, + bg_indices: NdarrayOrTensor | None = None, + thresh_image: NdarrayOrTensor | None = None, ) -> None: if fg_indices is None or bg_indices is None: # We don't require crop center to be within the boxes. @@ -1133,7 +1133,7 @@ def randomize( # type: ignore self.allow_smaller, ) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]: d = dict(data) image_size = d[self.image_keys[0]].shape[1:] self.spatial_size = fall_back_tuple(self.spatial_size_, image_size) @@ -1150,7 +1150,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, torch.Tensor]] = [dict(d) for _ in range(self.num_samples)] + results: list[dict[Hashable, torch.Tensor]] = [dict(d) for _ in range(self.num_samples)] # crop images and boxes for each center. for i, center in enumerate(self.centers): @@ -1198,7 +1198,7 @@ def __init__( box_keys: KeysCollection, box_ref_image_keys: KeysCollection, k: int = 1, - spatial_axes: Tuple[int, int] = (0, 1), + spatial_axes: tuple[int, int] = (0, 1), allow_missing_keys: bool = False, ) -> None: self.image_keys = ensure_tuple(image_keys) @@ -1225,7 +1225,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t d[key] = self.img_rotator(d[key]) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): @@ -1270,7 +1270,7 @@ def __init__( box_ref_image_keys: KeysCollection, prob: float = 0.1, max_k: int = 3, - spatial_axes: Tuple[int, int] = (0, 1), + spatial_axes: tuple[int, int] = (0, 1), allow_missing_keys: bool = False, ) -> None: self.image_keys = ensure_tuple(image_keys) @@ -1316,7 +1316,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t self.push_transform(d[key], extra_info=xform) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) if self._rand_k % 4 == 0: return d diff --git a/monai/apps/detection/utils/ATSS_matcher.py b/monai/apps/detection/utils/ATSS_matcher.py index c208fcd41c..cc9e238862 100644 --- a/monai/apps/detection/utils/ATSS_matcher.py +++ b/monai/apps/detection/utils/ATSS_matcher.py @@ -72,9 +72,12 @@ 5) add support for float16 cpu """ +from __future__ import annotations + import logging from abc import ABC, abstractmethod -from typing import Callable, Sequence, Tuple, TypeVar +from collections.abc import Callable, Sequence +from typing import TypeVar import torch from torch import Tensor @@ -103,7 +106,7 @@ def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou): def __call__( self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ Compute matches for a single image @@ -141,7 +144,7 @@ def __call__( @abstractmethod def compute_matches( self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ Compute matches @@ -194,7 +197,7 @@ def __init__( def compute_matches( self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ Compute matches according to ATTS for a single image Adapted from diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py index 55c256248a..baaa7ce874 100644 --- a/monai/apps/detection/utils/anchor_utils.py +++ b/monai/apps/detection/utils/anchor_utils.py @@ -37,7 +37,9 @@ https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/anchor_utils.py """ -from typing import List, Sequence, Union +from __future__ import annotations + +from typing import List, Sequence import torch from torch import Tensor, nn @@ -148,7 +150,7 @@ def generate_anchors( scales: Sequence, aspect_ratios: Sequence, dtype: torch.dtype = torch.float32, - device: Union[torch.device, None] = None, + device: torch.device | None = None, ) -> torch.Tensor: """ Compute cell anchor shapes at multiple sizes and aspect ratios for the current feature map. @@ -203,7 +205,7 @@ def generate_anchors( return base_anchors.round() - def set_cell_anchors(self, dtype: torch.dtype, device: torch.device): + def set_cell_anchors(self, dtype: torch.dtype, device: torch.device) -> None: """ Convert each element in self.cell_anchors to ``dtype`` and send to ``device``. """ @@ -215,7 +217,7 @@ def num_anchors_per_location(self): """ return [c.shape[0] for c in self.cell_anchors] - def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]: + def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]]) -> list[Tensor]: """ Every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:spatial_dims) corresponds to a feature map. @@ -279,7 +281,7 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) return anchors - def forward(self, images: Tensor, feature_maps: List[Tensor]) -> List[Tensor]: + def forward(self, images: Tensor, feature_maps: list[Tensor]) -> list[Tensor]: """ Generate anchor boxes for each image. @@ -366,16 +368,11 @@ class AnchorGeneratorWithAnchorShape(AnchorGenerator): def __init__( self, - feature_map_scales: Union[Sequence[int], Sequence[float]] = (1, 2, 4, 8), - base_anchor_shapes: Union[Sequence[Sequence[int]], Sequence[Sequence[float]]] = ( - (32, 32, 32), - (48, 20, 20), - (20, 48, 20), - (20, 20, 48), - ), + feature_map_scales: Sequence[int] | Sequence[float] = (1, 2, 4, 8), + base_anchor_shapes: Sequence[Sequence[int]] + | Sequence[Sequence[float]] = ((32, 32, 32), (48, 20, 20), (20, 48, 20), (20, 20, 48)), indexing: str = "ij", ) -> None: - nn.Module.__init__(self) spatial_dims = len(base_anchor_shapes[0]) @@ -389,7 +386,7 @@ def __init__( @staticmethod def generate_anchors_using_shape( - anchor_shapes: torch.Tensor, dtype: torch.dtype = torch.float32, device: Union[torch.device, None] = None + anchor_shapes: torch.Tensor, dtype: torch.dtype = torch.float32, device: torch.device | None = None ) -> torch.Tensor: """ Compute cell anchor shapes at multiple sizes and aspect ratios for the current feature map. diff --git a/monai/apps/detection/utils/box_coder.py b/monai/apps/detection/utils/box_coder.py index 6458360fcd..504ae21d0f 100644 --- a/monai/apps/detection/utils/box_coder.py +++ b/monai/apps/detection/utils/box_coder.py @@ -49,8 +49,10 @@ https://github.com/pytorch/vision/blob/main/torchvision/models/detection/_utils.py """ +from __future__ import annotations + import math -from typing import Sequence, Tuple, Union +from collections.abc import Sequence import torch from torch import Tensor @@ -120,14 +122,14 @@ class BoxCoder: # We expect gt_back to be equal to gt_boxes """ - def __init__(self, weights: Tuple[float], boxes_xform_clip: Union[float, None] = None) -> None: + def __init__(self, weights: Sequence[float], boxes_xform_clip: float | None = None) -> None: if boxes_xform_clip is None: boxes_xform_clip = math.log(1000.0 / 16) self.spatial_dims = look_up_option(len(weights), [4, 6]) // 2 self.weights = weights self.boxes_xform_clip = boxes_xform_clip - def encode(self, gt_boxes: Sequence[Tensor], proposals: Sequence[Tensor]) -> Tuple[Tensor]: + def encode(self, gt_boxes: Sequence[Tensor], proposals: Sequence[Tensor]) -> tuple[Tensor]: """ Encode a set of proposals with respect to some ground truth (gt) boxes. @@ -146,7 +148,7 @@ def encode(self, gt_boxes: Sequence[Tensor], proposals: Sequence[Tensor]) -> Tup concat_proposals = torch.cat(tuple(proposals), dim=0) concat_targets = self.encode_single(concat_gt_boxes, concat_proposals) # split to tuple - targets: Tuple[Tensor] = concat_targets.split(boxes_per_image, 0) + targets: tuple[Tensor] = concat_targets.split(boxes_per_image, 0) return targets def encode_single(self, gt_boxes: Tensor, proposals: Tensor) -> Tensor: diff --git a/monai/apps/detection/utils/box_selector.py b/monai/apps/detection/utils/box_selector.py index e0e82dbef7..c3da858880 100644 --- a/monai/apps/detection/utils/box_selector.py +++ b/monai/apps/detection/utils/box_selector.py @@ -37,7 +37,9 @@ https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py """ -from typing import Callable, List, Tuple, Union +from __future__ import annotations + +from collections.abc import Callable import torch from torch import Tensor @@ -100,7 +102,7 @@ def __init__( self.nms_thresh = nms_thresh self.detections_per_img = detections_per_img - def select_top_score_idx_per_level(self, logits: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def select_top_score_idx_per_level(self, logits: Tensor) -> tuple[Tensor, Tensor, Tensor]: """ Select indices with highest scores. @@ -144,8 +146,8 @@ def select_top_score_idx_per_level(self, logits: Tensor) -> Tuple[Tensor, Tensor return topk_idxs, selected_scores, selected_labels # type: ignore def select_boxes_per_image( - self, boxes_list: List[Tensor], logits_list: List[Tensor], spatial_size: Union[List[int], Tuple[int]] - ) -> Tuple[Tensor, Tensor, Tensor]: + self, boxes_list: list[Tensor], logits_list: list[Tensor], spatial_size: list[int] | tuple[int] + ) -> tuple[Tensor, Tensor, Tensor]: """ Postprocessing to generate detection result from classification logits and boxes. diff --git a/monai/apps/detection/utils/detector_utils.py b/monai/apps/detection/utils/detector_utils.py index d7693da62c..493ce5b216 100644 --- a/monai/apps/detection/utils/detector_utils.py +++ b/monai/apps/detection/utils/detector_utils.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any import torch import torch.nn.functional as F @@ -20,7 +23,7 @@ from monai.utils import PytorchPadMode, ensure_tuple_rep -def check_input_images(input_images: Union[List[Tensor], Tensor], spatial_dims: int) -> None: +def check_input_images(input_images: list[Tensor] | Tensor, spatial_dims: int) -> None: """ Validate the input dimensionality (raise a `ValueError` if invalid). @@ -35,7 +38,7 @@ def check_input_images(input_images: Union[List[Tensor], Tensor], spatial_dims: "When input_images is a Tensor, its need to be (spatial_dims + 2)-D." f"In this case, it should be a {(spatial_dims + 2)}-D Tensor, got Tensor shape {input_images.shape}." ) - elif isinstance(input_images, List): + elif isinstance(input_images, list): for img in input_images: if len(img.shape) != spatial_dims + 1: raise ValueError( @@ -48,8 +51,8 @@ def check_input_images(input_images: Union[List[Tensor], Tensor], spatial_dims: def check_training_targets( - input_images: Union[List[Tensor], Tensor], - targets: Union[List[Dict[str, Tensor]], None], + input_images: list[Tensor] | Tensor, + targets: list[dict[str, Tensor]] | None, spatial_dims: int, target_label_key: str, target_box_key: str, @@ -89,12 +92,12 @@ def check_training_targets( def pad_images( - input_images: Union[List[Tensor], Tensor], + input_images: list[Tensor] | Tensor, spatial_dims: int, - size_divisible: Union[int, Sequence[int]], - mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, - **kwargs, -) -> Tuple[Tensor, List[List[int]]]: + size_divisible: int | Sequence[int], + mode: PytorchPadMode | str = PytorchPadMode.CONSTANT, + **kwargs: Any, +) -> tuple[Tensor, list[list[int]]]: """ Pad the input images, so that the output spatial sizes are divisible by `size_divisible`. It pads them at the end to create a (B, C, H, W) or (B, C, H, W, D) Tensor. @@ -146,7 +149,7 @@ def pad_images( max_spatial_size = compute_divisible_spatial_size(spatial_shape=list(max_spatial_size_t), k=size_divisible) # allocate memory for the padded images - images = torch.zeros([len(image_sizes), in_channels] + max_spatial_size, dtype=dtype, device=device) + images = torch.zeros([len(image_sizes), in_channels] + list(max_spatial_size), dtype=dtype, device=device) # Use `SpatialPad` to match sizes, padding in the end will not affect boxes padder = SpatialPad(spatial_size=max_spatial_size, method="end", mode=mode, **kwargs) @@ -157,12 +160,12 @@ def pad_images( def preprocess_images( - input_images: Union[List[Tensor], Tensor], + input_images: list[Tensor] | Tensor, spatial_dims: int, - size_divisible: Union[int, Sequence[int]], - mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, - **kwargs, -) -> Tuple[Tensor, List[List[int]]]: + size_divisible: int | Sequence[int], + mode: PytorchPadMode | str = PytorchPadMode.CONSTANT, + **kwargs: Any, +) -> tuple[Tensor, list[list[int]]]: """ Preprocess the input images, including diff --git a/monai/apps/detection/utils/hard_negative_sampler.py b/monai/apps/detection/utils/hard_negative_sampler.py index ee423bb4bc..4c8dcf5d45 100644 --- a/monai/apps/detection/utils/hard_negative_sampler.py +++ b/monai/apps/detection/utils/hard_negative_sampler.py @@ -29,8 +29,9 @@ https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/sampler.py """ +from __future__ import annotations + import logging -from typing import List, Tuple import torch from torch import Tensor @@ -125,7 +126,7 @@ def __init__( self.positive_fraction = positive_fraction logging.info("Sampling hard negatives on a per batch basis") - def __call__(self, target_labels: List[Tensor], concat_fg_probs: Tensor) -> Tuple[List[Tensor], List[Tensor]]: + def __call__(self, target_labels: list[Tensor], concat_fg_probs: Tensor) -> tuple[list[Tensor], list[Tensor]]: """ Select positives and hard negatives from list samples per image. Hard negative sampler will be applied to each image independently. @@ -158,8 +159,8 @@ def __call__(self, target_labels: List[Tensor], concat_fg_probs: Tensor) -> Tupl return self.select_samples_img_list(target_labels, fg_probs) def select_samples_img_list( - self, target_labels: List[Tensor], fg_probs: List[Tensor] - ) -> Tuple[List[Tensor], List[Tensor]]: + self, target_labels: list[Tensor], fg_probs: list[Tensor] + ) -> tuple[list[Tensor], list[Tensor]]: """ Select positives and hard negatives from list samples per image. Hard negative sampler will be applied to each image independently. @@ -205,7 +206,7 @@ def select_samples_img_list( return pos_idx, neg_idx - def select_samples_per_img(self, labels_per_img: Tensor, fg_probs_per_img: Tensor) -> Tuple[Tensor, Tensor]: + def select_samples_per_img(self, labels_per_img: Tensor, fg_probs_per_img: Tensor) -> tuple[Tensor, Tensor]: """ Select positives and hard negatives from samples. diff --git a/monai/apps/detection/utils/predict_utils.py b/monai/apps/detection/utils/predict_utils.py index a11aa97ce7..d030320fa1 100644 --- a/monai/apps/detection/utils/predict_utils.py +++ b/monai/apps/detection/utils/predict_utils.py @@ -9,15 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from __future__ import annotations import torch -from torch import Tensor +from torch import Tensor, nn from monai.inferers import SlidingWindowInferer -def ensure_dict_value_to_list_(head_outputs: Dict[str, List[Tensor]], keys: Optional[List[str]] = None) -> None: +def ensure_dict_value_to_list_(head_outputs: dict[str, list[Tensor]], keys: list[str] | None = None) -> None: """ An in-place function. We expect ``head_outputs`` to be Dict[str, List[Tensor]]. Yet if it is Dict[str, Tensor], this func converts it to Dict[str, List[Tensor]]. @@ -41,7 +41,7 @@ def ensure_dict_value_to_list_(head_outputs: Dict[str, List[Tensor]], keys: Opti raise ValueError("The output of network should be Dict[str, List[Tensor]] or Dict[str, Tensor].") -def check_dict_values_same_length(head_outputs: Dict[str, List[Tensor]], keys: Optional[List[str]] = None) -> None: +def check_dict_values_same_length(head_outputs: dict[str, list[Tensor]], keys: list[str] | None = None) -> None: """ We expect the values in ``head_outputs``: Dict[str, List[Tensor]] to have the same length. Will raise ValueError if not. @@ -54,13 +54,13 @@ def check_dict_values_same_length(head_outputs: Dict[str, List[Tensor]], keys: O if keys is None: keys = list(head_outputs.keys()) - num_output_levels_list: List[int] = [len(head_outputs[k]) for k in keys] + num_output_levels_list: list[int] = [len(head_outputs[k]) for k in keys] num_output_levels = torch.unique(torch.tensor(num_output_levels_list)) if len(num_output_levels) != 1: raise ValueError(f"The values in the input dict should have the same length, Got {num_output_levels_list}.") -def _network_sequence_output(images: Tensor, network, keys: Optional[List[str]] = None) -> List[Tensor]: +def _network_sequence_output(images: Tensor, network: nn.Module, keys: list[str] | None = None) -> list[Tensor]: """ Decompose the output of network (a dict) into a list. @@ -84,8 +84,8 @@ def _network_sequence_output(images: Tensor, network, keys: Optional[List[str]] def predict_with_inferer( - images: Tensor, network, keys: List[str], inferer: Optional[SlidingWindowInferer] = None -) -> Dict[str, List[Tensor]]: + images: Tensor, network: nn.Module, keys: list[str], inferer: SlidingWindowInferer | None = None +) -> dict[str, list[Tensor]]: """ Predict network dict output with an inferer. Compared with directly output network(images), it enables a sliding window inferer that can be used to handle large inputs. diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index 8f1448bb06..5534ed6951 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index 6e1770b19e..31c88a17be 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -15,11 +15,14 @@ - https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html """ +from __future__ import annotations + import json import os import warnings +from collections.abc import Mapping from pathlib import Path -from typing import Mapping, Optional, Union +from typing import Any import torch @@ -35,7 +38,7 @@ __all__ = ["get_model_spec", "download_mmar", "load_from_mmar"] -def get_model_spec(idx: Union[int, str]): +def get_model_spec(idx: int | str) -> dict | Any: """get model specification by `idx`. `idx` could be index of the constant tuple of dict or the actual model ID.""" if isinstance(idx, int): return MODEL_DESC[idx] @@ -92,17 +95,17 @@ def _get_all_ngc_models(pattern, page_index=0, page_size=50): return model_dict -def _get_ngc_url(model_name: str, version: str, model_prefix=""): +def _get_ngc_url(model_name: str, version: str, model_prefix: str = "") -> str: return f"https://api.ngc.nvidia.com/v2/models/{model_prefix}{model_name}/versions/{version}/zip" -def _get_ngc_doc_url(model_name: str, model_prefix=""): +def _get_ngc_doc_url(model_name: str, model_prefix: str = "") -> str: return f"https://ngc.nvidia.com/catalog/models/{model_prefix}{model_name}" def download_mmar( - item, mmar_dir: Optional[PathLike] = None, progress: bool = True, api: bool = True, version: int = -1 -): + item: str | Mapping, mmar_dir: PathLike | None = None, progress: bool = True, api: bool = True, version: int = -1 +) -> Path: """ Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train. @@ -135,19 +138,20 @@ def download_mmar( mmar_dir = Path(get_dir()) / "mmars" else: raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") - mmar_dir = Path(mmar_dir) + _mmar_dir = Path(mmar_dir) + model_dir: Path if api: model_dict = _get_all_ngc_models(item.get(Keys.NAME, f"{item}") if isinstance(item, Mapping) else f"{item}") if len(model_dict) == 0: raise ValueError(f"api query returns no item for pattern {item}. Please change or shorten it.") - model_dir_list = [] + model_dir_list: list[Path] = [] for k, v in model_dict.items(): ver = v["latest"] if version == -1 else str(version) download_url = _get_ngc_url(k, ver) - model_dir = mmar_dir / v["name"] + model_dir = _mmar_dir / v["name"] download_and_extract( url=download_url, - filepath=mmar_dir / f'{v["name"]}_{ver}.zip', + filepath=_mmar_dir / f'{v["name"]}_{ver}.zip', output_dir=model_dir, hash_val=None, hash_type="md5", @@ -166,11 +170,11 @@ def download_mmar( if version > 0: ver = str(version) model_fullname = f"{item[Keys.NAME]}_{ver}" - model_dir = mmar_dir / model_fullname + model_dir = _mmar_dir / model_fullname model_url = item.get(Keys.URL) or _get_ngc_url(item[Keys.NAME], version=ver, model_prefix="nvidia/med/") download_and_extract( url=model_url, - filepath=mmar_dir / f"{model_fullname}.{item[Keys.FILE_TYPE]}", + filepath=_mmar_dir / f"{model_fullname}.{item[Keys.FILE_TYPE]}", output_dir=model_dir, hash_val=item[Keys.HASH_VAL], hash_type=item[Keys.HASH_TYPE], @@ -182,17 +186,17 @@ def download_mmar( def load_from_mmar( - item, - mmar_dir: Optional[PathLike] = None, + item: Mapping | str | int, + mmar_dir: PathLike | None = None, progress: bool = True, version: int = -1, - map_location=None, - pretrained=True, - weights_only=False, + map_location: Any | None = None, + pretrained: bool = True, + weights_only: bool = False, model_key: str = "model", api: bool = True, - model_file=None, -): + model_file: PathLike | None = None, +) -> Any: """ Download and extract Medical Model Archive (MMAR) model weights from Nvidia Clara Train. @@ -225,19 +229,19 @@ def load_from_mmar( model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version, api=api) if model_file is None: model_file = os.path.join("models", "model.pt") - model_file = model_dir / item.get(Keys.MODEL_FILE, model_file) + _model_file = model_dir / item.get(Keys.MODEL_FILE, model_file) logger.info(f'\n*** "{item.get(Keys.NAME)}" available at {model_dir}.') # loading with `torch.jit.load` - if model_file.name.endswith(".ts"): + if _model_file.name.endswith(".ts"): if not pretrained: warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.") if weights_only: warnings.warn("Loading a ScriptModule, 'weights_only' option ignored.") - return torch.jit.load(model_file, map_location=map_location) + return torch.jit.load(_model_file, map_location=map_location) # loading with `torch.load` - model_dict = torch.load(model_file, map_location=map_location) + model_dict = torch.load(_model_file, map_location=map_location) if weights_only: return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly @@ -294,7 +298,7 @@ def load_from_mmar( return model_inst -def _get_val(input_dict: Mapping, key="model", default=None): +def _get_val(input_dict: Mapping, key: str = "model", default: Any | None = None) -> Any | None: """ Search for the item with `key` in `config_dict`. Returns: the first occurrence of `key` in a breadth first search. diff --git a/monai/apps/mmars/model_desc.py b/monai/apps/mmars/model_desc.py index e0a7f26117..a3963689fb 100644 --- a/monai/apps/mmars/model_desc.py +++ b/monai/apps/mmars/model_desc.py @@ -15,8 +15,10 @@ - https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html """ +from __future__ import annotations + import os -from typing import Any, Dict, Tuple +from typing import Any __all__ = ["MODEL_DESC", "RemoteMMARKeys"] @@ -39,7 +41,7 @@ class RemoteMMARKeys: VERSION = "version" # version of the MMAR -MODEL_DESC: Tuple[Dict[Any, Any], ...] = ( +MODEL_DESC: tuple[dict[Any, Any], ...] = ( { RemoteMMARKeys.ID: "clara_pt_spleen_ct_segmentation_1", RemoteMMARKeys.NAME: "clara_pt_spleen_ct_segmentation", diff --git a/monai/apps/nuclick/transforms.py b/monai/apps/nuclick/transforms.py index f080961e4c..f22ea764be 100644 --- a/monai/apps/nuclick/transforms.py +++ b/monai/apps/nuclick/transforms.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math -from typing import Any, Optional, Tuple, Union +from typing import Any import numpy as np import torch @@ -87,7 +89,7 @@ def __init__( self, keys: KeysCollection, centroid_key: str = NuclickKeys.CENTROID, - patch_size: Union[Tuple[int, int], int] = 128, + patch_size: tuple[int, int] | int = 128, allow_missing_keys: bool = False, **kwargs: Any, ): @@ -144,12 +146,11 @@ def __init__( self, keys: KeysCollection, others: str = NuclickKeys.OTHERS, - mask_value: Optional[str] = NuclickKeys.MASK_VALUE, + mask_value: str | None = NuclickKeys.MASK_VALUE, min_area: int = 5, others_value: int = 0, to_binary_mask: bool = True, ): - super().__init__(keys, allow_missing_keys=False) self.others = others self.mask_value = mask_value @@ -409,10 +410,10 @@ def __init__( image: str = NuclickKeys.IMAGE, foreground: str = NuclickKeys.FOREGROUND, bb_size: int = 128, - gaussian=False, - sigma=1.0, + gaussian: bool = False, + sigma: float = 1.0, truncated: float = 2.0, - add_exclusion_map=True, + add_exclusion_map: bool = True, ): self.image = image self.foreground = foreground @@ -601,7 +602,7 @@ class AddLabelAsGuidanced(MapTransform): source: label/source key which gets added as additional guidance channel """ - def __init__(self, keys: KeysCollection, source="label") -> None: + def __init__(self, keys: KeysCollection, source: str = "label") -> None: super().__init__(keys, allow_missing_keys=False) self.source = source @@ -627,7 +628,7 @@ class SetLabelClassd(MapTransform): offset: offset value to be added to the mask value to determine the final class """ - def __init__(self, keys: KeysCollection, offset=-1) -> None: + def __init__(self, keys: KeysCollection, offset: int = -1) -> None: super().__init__(keys, allow_missing_keys=False) self.offset = offset diff --git a/monai/apps/pathology/__init__.py b/monai/apps/pathology/__init__.py index da5cec7e6c..3de1c754c2 100644 --- a/monai/apps/pathology/__init__.py +++ b/monai/apps/pathology/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .data import MaskedInferenceWSIDataset, PatchWSIDataset, SmartCachePatchWSIDataset from .handlers import ProbMapProducer from .losses import HoVerNetLoss diff --git a/monai/apps/pathology/data/__init__.py b/monai/apps/pathology/data/__init__.py index e1b2ef7bd2..cfdd71cfe3 100644 --- a/monai/apps/pathology/data/__init__.py +++ b/monai/apps/pathology/data/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .datasets import MaskedInferenceWSIDataset, PatchWSIDataset, SmartCachePatchWSIDataset diff --git a/monai/apps/pathology/data/datasets.py b/monai/apps/pathology/data/datasets.py index 23ac9ac062..70a726e798 100644 --- a/monai/apps/pathology/data/datasets.py +++ b/monai/apps/pathology/data/datasets.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import sys -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast +from collections.abc import Callable, Sequence +from typing import Any, Tuple, cast import numpy as np @@ -52,13 +55,13 @@ class PatchWSIDataset(Dataset): def __init__( self, - data: List, - region_size: Union[int, Tuple[int, int]], - grid_shape: Union[int, Tuple[int, int]], - patch_size: Union[int, Tuple[int, int]], - transform: Optional[Callable] = None, + data: list, + region_size: int | tuple[int, int], + grid_shape: int | tuple[int, int], + patch_size: int | tuple[int, int], + transform: Callable | None = None, image_reader_name: str = "cuCIM", - **kwargs, + **kwargs: Any, ): super().__init__(data, transform) @@ -69,7 +72,7 @@ def __init__( self.image_path_list = list({x["image"] for x in self.data}) self.image_reader_name = image_reader_name.lower() self.image_reader = WSIReader(backend=image_reader_name, **kwargs) - self.wsi_object_dict: Optional[Dict] = None + self.wsi_object_dict: dict | None = None if self.image_reader_name != "openslide": # OpenSlide causes memory issue if we prefetch image objects self._fetch_wsi_objects() @@ -85,7 +88,7 @@ def __getitem__(self, index): if self.image_reader_name == "openslide": img_obj = self.image_reader.read(sample["image"]) else: - img_obj = cast(Dict, self.wsi_object_dict)[sample["image"]] + img_obj = cast(dict, self.wsi_object_dict)[sample["image"]] location = [sample["location"][i] - self.region_size[i] // 2 for i in range(len(self.region_size))] images, _ = self.image_reader.get_data( img=img_obj, @@ -140,21 +143,21 @@ class SmartCachePatchWSIDataset(SmartCacheDataset): def __init__( self, - data: List, - region_size: Union[int, Tuple[int, int]], - grid_shape: Union[int, Tuple[int, int]], - patch_size: Union[int, Tuple[int, int]], - transform: Union[Sequence[Callable], Callable], + data: list, + region_size: int | tuple[int, int], + grid_shape: int | tuple[int, int], + patch_size: int | tuple[int, int], + transform: Sequence[Callable] | Callable, image_reader_name: str = "cuCIM", replace_rate: float = 0.5, cache_num: int = sys.maxsize, cache_rate: float = 1.0, - num_init_workers: Optional[int] = 1, - num_replace_workers: Optional[int] = 1, + num_init_workers: int | None = 1, + num_replace_workers: int | None = 1, progress: bool = True, copy_cache: bool = True, as_contiguous: bool = True, - **kwargs, + **kwargs: Any, ): patch_wsi_dataset = PatchWSIDataset( data=data, @@ -201,11 +204,11 @@ class MaskedInferenceWSIDataset(Dataset): def __init__( self, - data: List[Dict["str", "str"]], - patch_size: Union[int, Tuple[int, int]], - transform: Optional[Callable] = None, + data: list[dict[str, str]], + patch_size: int | tuple[int, int], + transform: Callable | None = None, image_reader_name: str = "cuCIM", - **kwargs, + **kwargs: Any, ) -> None: super().__init__(data, transform) @@ -223,14 +226,14 @@ def __init__( self.num_patches = sum(self.num_patches_per_sample) self.cum_num_patches = np.cumsum([0] + self.num_patches_per_sample[:-1]) - def _prepare_data(self, input_data: List[Dict["str", "str"]]) -> List[Dict]: + def _prepare_data(self, input_data: list[dict[str, str]]) -> list[dict]: prepared_data = [] for sample in input_data: prepared_sample = self._prepare_a_sample(sample) prepared_data.append(prepared_sample) return prepared_data - def _prepare_a_sample(self, sample: Dict["str", "str"]) -> Dict: + def _prepare_a_sample(self, sample: dict[str, str]) -> dict: """ Preprocess input data to load WSIReader object and the foreground mask, and define the locations where patches need to be extracted. @@ -272,7 +275,7 @@ def _prepare_a_sample(self, sample: Dict["str", "str"]) -> Dict: "level": level, } - def _calculate_mask_level(self, image: np.ndarray, mask: np.ndarray) -> Tuple[int, float]: + def _calculate_mask_level(self, image: np.ndarray, mask: np.ndarray) -> tuple[int, float]: """ Calculate level of the mask and its ratio with respect to the whole slide image diff --git a/monai/apps/pathology/engines/__init__.py b/monai/apps/pathology/engines/__init__.py index 68c084d40d..b32c148c2e 100644 --- a/monai/apps/pathology/engines/__init__.py +++ b/monai/apps/pathology/engines/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .utils import PrepareBatchHoVerNet diff --git a/monai/apps/pathology/engines/utils.py b/monai/apps/pathology/engines/utils.py index 895638a01b..02249ed640 100644 --- a/monai/apps/pathology/engines/utils.py +++ b/monai/apps/pathology/engines/utils.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Sequence, Union +from __future__ import annotations + +from typing import Any, Sequence import torch @@ -38,11 +40,11 @@ def __init__(self, extra_keys: Sequence[str]) -> None: def __call__( self, - batchdata: Dict[str, torch.Tensor], - device: Optional[Union[str, torch.device]] = None, + batchdata: dict[str, torch.Tensor], + device: str | torch.device | None = None, non_blocking: bool = False, - **kwargs, - ): + **kwargs: Any, + ) -> tuple[torch.Tensor, dict[HoVerNetBranch, torch.Tensor]]: """ Args `batchdata`, `device`, `non_blocking` refer to the ignite API: https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. diff --git a/monai/apps/pathology/handlers/__init__.py b/monai/apps/pathology/handlers/__init__.py index 0638950bd8..3f2aab272f 100644 --- a/monai/apps/pathology/handlers/__init__.py +++ b/monai/apps/pathology/handlers/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .prob_map_producer import ProbMapProducer diff --git a/monai/apps/pathology/handlers/prob_map_producer.py b/monai/apps/pathology/handlers/prob_map_producer.py index d5b1b50c47..7b43653710 100644 --- a/monai/apps/pathology/handlers/prob_map_producer.py +++ b/monai/apps/pathology/handlers/prob_map_producer.py @@ -9,9 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING import numpy as np @@ -38,11 +40,7 @@ class ProbMapProducer: """ def __init__( - self, - output_dir: str = "./", - output_postfix: str = "", - dtype: DtypeLike = np.float64, - name: Optional[str] = None, + self, output_dir: str = "./", output_postfix: str = "", dtype: DtypeLike = np.float64, name: str | None = None ) -> None: """ Args: @@ -57,9 +55,9 @@ def __init__( self.output_dir = output_dir self.output_postfix = output_postfix self.dtype = dtype - self.prob_map: Dict[str, np.ndarray] = {} - self.level: Dict[str, int] = {} - self.counter: Dict[str, int] = {} + self.prob_map: dict[str, np.ndarray] = {} + self.level: dict[str, int] = {} + self.counter: dict[str, int] = {} self.num_done_images: int = 0 self.num_images: int = 0 @@ -120,5 +118,5 @@ def save_prob_map(self, name: str) -> None: del self.counter[name] del self.level[name] - def finalize(self, engine: Engine): + def finalize(self, engine: Engine) -> None: self.logger.info(f"Probability map is created for {self.num_done_images}/{self.num_images} images!") diff --git a/monai/apps/pathology/handlers/utils.py b/monai/apps/pathology/handlers/utils.py index 8daac57143..4c11f2e859 100644 --- a/monai/apps/pathology/handlers/utils.py +++ b/monai/apps/pathology/handlers/utils.py @@ -8,13 +8,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Hashable, Tuple + +from __future__ import annotations + +from collections.abc import Callable, Hashable +from typing import Any from monai.config import KeysCollection from monai.utils import ensure_tuple -def from_engine_hovernet(keys: KeysCollection, nested_key: str): +def from_engine_hovernet(keys: KeysCollection, nested_key: str) -> Callable[[Any], Any]: """ Since the output of HoVerNet is a dictionary, this function is to extend `monai.handlers.from_engine` to work with HoVerNet. @@ -38,7 +42,7 @@ def from_engine_hovernet(keys: KeysCollection, nested_key: str): nested_key: specified key to extract nested data from dictionary or decollated list of dictionaries. """ - _keys: Tuple[Hashable, ...] = ensure_tuple(keys) + _keys: tuple[Hashable, ...] = ensure_tuple(keys) def _wrapper(data): if isinstance(data, dict): diff --git a/monai/apps/pathology/inferers/__init__.py b/monai/apps/pathology/inferers/__init__.py index c3571c87a3..3549b8ec29 100644 --- a/monai/apps/pathology/inferers/__init__.py +++ b/monai/apps/pathology/inferers/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .inferer import SlidingWindowHoVerNetInferer diff --git a/monai/apps/pathology/inferers/inferer.py b/monai/apps/pathology/inferers/inferer.py index 1aacb5d99d..da4ac4fd7a 100644 --- a/monai/apps/pathology/inferers/inferer.py +++ b/monai/apps/pathology/inferers/inferer.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import Any, Callable, Sequence import numpy as np import torch @@ -73,19 +75,19 @@ class SlidingWindowHoVerNetInferer(SlidingWindowInferer): def __init__( self, - roi_size: Union[Sequence[int], int], + roi_size: Sequence[int] | int, sw_batch_size: int = 1, overlap: float = 0.25, - mode: Union[BlendMode, str] = BlendMode.CONSTANT, - sigma_scale: Union[Sequence[float], float] = 0.125, - padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + mode: BlendMode | str = BlendMode.CONSTANT, + sigma_scale: Sequence[float] | float = 0.125, + padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT, cval: float = 0.0, - sw_device: Optional[Union[torch.device, str]] = None, - device: Optional[Union[torch.device, str]] = None, + sw_device: torch.device | str | None = None, + device: torch.device | str | None = None, progress: bool = False, cache_roi_weight_map: bool = False, - cpu_thresh: Optional[int] = None, - extra_input_padding: Optional[Tuple[int]] = None, + cpu_thresh: int | None = None, + extra_input_padding: tuple[int] | None = None, ) -> None: super().__init__( roi_size=roi_size, @@ -130,10 +132,10 @@ def process_output(self, seg_prob_tuple, window_data, importance_map_): def __call__( self, inputs: torch.Tensor, - network: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]], + network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], *args: Any, **kwargs: Any, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: """ Args: @@ -179,7 +181,7 @@ def __call__( ) if self.extra_input_padding: - extra_slicing: List[slice] = [] + extra_slicing: list[slice] = [] num_padded_dims = len(self.extra_input_padding) // 2 for sp in range(num_padded_dims): slice_dim = slice( diff --git a/monai/apps/pathology/losses/__init__.py b/monai/apps/pathology/losses/__init__.py index 5e960b34cf..09c4d43836 100644 --- a/monai/apps/pathology/losses/__init__.py +++ b/monai/apps/pathology/losses/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .hovernet_loss import HoVerNetLoss diff --git a/monai/apps/pathology/losses/hovernet_loss.py b/monai/apps/pathology/losses/hovernet_loss.py index 5f35d9c509..72b4d8a6ff 100644 --- a/monai/apps/pathology/losses/hovernet_loss.py +++ b/monai/apps/pathology/losses/hovernet_loss.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from __future__ import annotations import torch from torch.nn import CrossEntropyLoss @@ -95,7 +95,7 @@ def _mse_gradient_loss(self, prediction: torch.Tensor, target: torch.Tensor, foc return loss - def forward(self, prediction: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor]) -> torch.Tensor: + def forward(self, prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor]) -> torch.Tensor: """ Args: prediction: dictionary of predicted outputs for three branches, diff --git a/monai/apps/pathology/metrics/__init__.py b/monai/apps/pathology/metrics/__init__.py index f19811dcaf..4f77d6a852 100644 --- a/monai/apps/pathology/metrics/__init__.py +++ b/monai/apps/pathology/metrics/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .lesion_froc import LesionFROC diff --git a/monai/apps/pathology/metrics/lesion_froc.py b/monai/apps/pathology/metrics/lesion_froc.py index 6c7965bae6..d3f541bfa3 100644 --- a/monai/apps/pathology/metrics/lesion_froc.py +++ b/monai/apps/pathology/metrics/lesion_froc.py @@ -9,11 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np from monai.apps.pathology.utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask +from monai.config import NdarrayOrTensor from monai.data.image_reader import WSIReader from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score from monai.utils import min_version, optional_import @@ -63,16 +66,15 @@ class LesionFROC: def __init__( self, - data: List[Dict], + data: list[dict], grow_distance: int = 75, itc_diameter: int = 200, - eval_thresholds: Tuple = (0.25, 0.5, 1, 2, 4, 8), + eval_thresholds: tuple = (0.25, 0.5, 1, 2, 4, 8), nms_sigma: float = 0.0, nms_prob_threshold: float = 0.5, nms_box_size: int = 48, image_reader_name: str = "cuCIM", ) -> None: - self.data = data self.grow_distance = grow_distance self.itc_diameter = itc_diameter @@ -80,7 +82,7 @@ def __init__( self.image_reader = WSIReader(image_reader_name) self.nms = PathologyProbNMS(sigma=nms_sigma, prob_threshold=nms_prob_threshold, box_size=nms_box_size) - def prepare_inference_result(self, sample: Dict): + def prepare_inference_result(self, sample: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Prepare the probability map for detection evaluation. @@ -127,7 +129,8 @@ def compute_fp_tp(self): by comparing the model outputs with the prepared ground truths for all samples """ - total_fp_probs, total_tp_probs = [], [] + total_fp_probs: list[NdarrayOrTensor] = [] + total_tp_probs: list[NdarrayOrTensor] = [] total_num_targets = 0 num_images = len(self.data) diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index c30d2c1a76..af2254074a 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .post.array import ( GenerateDistanceMap, GenerateInstanceBorder, diff --git a/monai/apps/pathology/transforms/post/__init__.py b/monai/apps/pathology/transforms/post/__init__.py index c5b928b991..dfdf7d31eb 100644 --- a/monai/apps/pathology/transforms/post/__init__.py +++ b/monai/apps/pathology/transforms/post/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .array import ( GenerateDistanceMap, GenerateInstanceBorder, diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index aa54147b3d..5289dc101c 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Callable, Sequence import numpy as np import torch @@ -68,12 +70,12 @@ class Watershed(Transform): backend = [TransformBackends.NUMPY] - def __init__(self, connectivity: Optional[int] = 1, dtype: DtypeLike = np.int64) -> None: + def __init__(self, connectivity: int | None = 1, dtype: DtypeLike = np.int64) -> None: self.connectivity = connectivity self.dtype = dtype def __call__( - self, image: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None, markers: Optional[NdarrayOrTensor] = None + self, image: NdarrayOrTensor, mask: NdarrayOrTensor | None = None, markers: NdarrayOrTensor | None = None ) -> NdarrayOrTensor: """ Args: @@ -112,8 +114,8 @@ class GenerateWatershedMask(Transform): def __init__( self, - activation: Union[str, Callable] = "softmax", - threshold: Optional[float] = None, + activation: str | Callable = "softmax", + threshold: float | None = None, min_object_size: int = 10, dtype: DtypeLike = np.uint8, ) -> None: @@ -249,8 +251,7 @@ class GenerateDistanceMap(Transform): backend = [TransformBackends.NUMPY] - def __init__(self, smooth_fn: Optional[Callable] = None, dtype: DtypeLike = np.float32) -> None: - + def __init__(self, smooth_fn: Callable | None = None, dtype: DtypeLike = np.float32) -> None: self.smooth_fn = smooth_fn if smooth_fn is not None else GaussianSmooth() self.dtype = dtype @@ -304,7 +305,7 @@ def __init__( threshold: float = 0.4, radius: int = 2, min_object_size: int = 10, - postprocess_fn: Optional[Callable] = None, + postprocess_fn: Callable | None = None, dtype: DtypeLike = np.int64, ) -> None: self.threshold = threshold @@ -367,7 +368,7 @@ def __init__(self, height: int, width: int) -> None: self.height = height self.width = width - def _generate_contour_coord(self, current: np.ndarray, previous: np.ndarray) -> Tuple[int, int]: + def _generate_contour_coord(self, current: np.ndarray, previous: np.ndarray) -> tuple[int, int]: """ Generate contour coordinates. Given the previous and current coordinates of border positions, returns the int pixel that marks the extremity of the segmented pixels. @@ -394,7 +395,7 @@ def _generate_contour_coord(self, current: np.ndarray, previous: np.ndarray) -> return row, col - def _calculate_distance_from_top_left(self, sequence: Sequence[Tuple[int, int]]) -> int: + def _calculate_distance_from_top_left(self, sequence: Sequence[tuple[int, int]]) -> int: """ Each sequence of coordinates describes a boundary between foreground and background starting and ending at two sides of the bounding box. To order the sequences correctly, we compute the distance from the top-left of the bounding box @@ -419,18 +420,18 @@ def _calculate_distance_from_top_left(self, sequence: Sequence[Tuple[int, int]]) return distance - def __call__(self, contours: List[np.ndarray]) -> np.ndarray: + def __call__(self, contours: list[np.ndarray]) -> np.ndarray: """ Args: contours: list of (n, 2)-ndarrays, scipy-style clockwise line segments, with lines separating foreground/background. Each contour is an ndarray of shape (n, 2), consisting of n (row, column) coordinates along the contour. """ - pixels: List[Tuple[int, int]] = [] + pixels: list[tuple[int, int]] = [] sequences = [] corners = [False, False, False, False] for group in contours: - sequence: List[Tuple[int, int]] = [] + sequence: list[tuple[int, int]] = [] last_added = None prev = None corner = -1 @@ -544,11 +545,11 @@ class GenerateInstanceContour(Transform): backend = [TransformBackends.NUMPY] - def __init__(self, min_num_points: int = 3, contour_level: Optional[float] = None) -> None: + def __init__(self, min_num_points: int = 3, contour_level: float | None = None) -> None: self.contour_level = contour_level self.min_num_points = min_num_points - def __call__(self, inst_mask: NdarrayOrTensor, offset: Optional[Sequence[int]] = (0, 0)) -> Optional[np.ndarray]: + def __call__(self, inst_mask: NdarrayOrTensor, offset: Sequence[int] | None = (0, 0)) -> np.ndarray | None: """ Args: inst_mask: segmentation mask for a single instance. Shape should be [1, H, W, [D]] @@ -587,10 +588,10 @@ class GenerateInstanceCentroid(Transform): backend = [TransformBackends.NUMPY] - def __init__(self, dtype: Optional[DtypeLike] = int) -> None: + def __init__(self, dtype: DtypeLike | None = int) -> None: self.dtype = dtype - def __call__(self, inst_mask: NdarrayOrTensor, offset: Union[Sequence[int], int] = 0) -> NdarrayOrTensor: + def __call__(self, inst_mask: NdarrayOrTensor, offset: Sequence[int] | int = 0) -> NdarrayOrTensor: """ Args: inst_mask: segmentation mask for a single instance. Shape should be [1, H, W, [D]] @@ -618,7 +619,7 @@ class GenerateInstanceType(Transform): def __call__( # type: ignore self, type_pred: NdarrayOrTensor, seg_pred: NdarrayOrTensor, bbox: np.ndarray, instance_id: int - ) -> Tuple[int, float]: + ) -> tuple[int, float]: """ Args: type_pred: pixel-level type prediction map after activation function. @@ -674,17 +675,17 @@ class HoVerNetInstanceMapPostProcessing(Transform): def __init__( self, - activation: Union[str, Callable] = "softmax", - mask_threshold: Optional[float] = None, + activation: str | Callable = "softmax", + mask_threshold: float | None = None, min_object_size: int = 10, sobel_kernel_size: int = 5, - distance_smooth_fn: Optional[Callable] = None, + distance_smooth_fn: Callable | None = None, marker_threshold: float = 0.4, marker_radius: int = 2, - marker_postprocess_fn: Optional[Callable] = None, - watershed_connectivity: Optional[int] = 1, + marker_postprocess_fn: Callable | None = None, + watershed_connectivity: int | None = 1, min_num_points: int = 3, - contour_level: Optional[float] = None, + contour_level: float | None = None, ) -> None: super().__init__() @@ -707,7 +708,7 @@ def __init__( def __call__( # type: ignore self, nuclear_prediction: NdarrayOrTensor, hover_map: NdarrayOrTensor - ) -> Tuple[Dict, NdarrayOrTensor]: + ) -> tuple[dict, NdarrayOrTensor]: """post-process instance segmentation branches (NP and HV) to generate instance segmentation map. Args: @@ -761,10 +762,7 @@ class HoVerNetNuclearTypePostProcessing(Transform): """ def __init__( - self, - activation: Union[str, Callable] = "softmax", - threshold: Optional[float] = None, - return_type_map: bool = True, + self, activation: str | Callable = "softmax", threshold: float | None = None, return_type_map: bool = True ) -> None: super().__init__() self.return_type_map = return_type_map @@ -795,8 +793,8 @@ def __init__( self.as_discrete = AsDiscrete(threshold=threshold, argmax=use_softmax) def __call__( # type: ignore - self, type_prediction: NdarrayOrTensor, instance_info: Dict[int, Dict], instance_map: NdarrayOrTensor - ) -> Tuple[Dict, Optional[NdarrayOrTensor]]: + self, type_prediction: NdarrayOrTensor, instance_info: dict[int, dict], instance_map: NdarrayOrTensor + ) -> tuple[dict, NdarrayOrTensor | None]: """Process NC (type prediction) branch and combine it with instance segmentation It updates the instance_info with instance type and associated probability, and generate instance type map. diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py index ee9654eaca..ef6de1b596 100644 --- a/monai/apps/pathology/transforms/post/dictionary.py +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Hashable, Mapping, Optional, Union +from __future__ import annotations + +from collections.abc import Callable, Hashable, Mapping import numpy as np @@ -100,9 +102,9 @@ class Watershedd(MapTransform): def __init__( self, keys: KeysCollection, - mask_key: Optional[str] = "mask", - markers_key: Optional[str] = None, - connectivity: Optional[int] = 1, + mask_key: str | None = "mask", + markers_key: str | None = None, + connectivity: int | None = 1, dtype: DtypeLike = np.uint8, allow_missing_keys: bool = False, ) -> None: @@ -111,7 +113,7 @@ def __init__( self.markers_key = markers_key self.transform = Watershed(connectivity=connectivity, dtype=dtype) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) markers = d[self.markers_key] if self.markers_key else None mask = d[self.mask_key] if self.mask_key else None @@ -144,8 +146,8 @@ def __init__( self, keys: KeysCollection, mask_key: str = "mask", - activation: Union[str, Callable] = "softmax", - threshold: Optional[float] = None, + activation: str | Callable = "softmax", + threshold: float | None = None, min_object_size: int = 10, dtype: DtypeLike = np.uint8, allow_missing_keys: bool = False, @@ -156,7 +158,7 @@ def __init__( activation=activation, threshold=threshold, min_object_size=min_object_size, dtype=dtype ) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): mask = self.transform(d[key]) @@ -199,7 +201,7 @@ def __init__( self.border_key = border_key self.transform = GenerateInstanceBorder(kernel_size=kernel_size, dtype=dtype) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) if self.border_key in d: raise KeyError(f"The key '{self.border_key}' for instance border map already exists.") @@ -227,7 +229,7 @@ def __init__( mask_key: str = "mask", border_key: str = "border", dist_map_key: str = "dist_map", - smooth_fn: Optional[Callable] = None, + smooth_fn: Callable | None = None, dtype: DtypeLike = np.float32, ) -> None: self.mask_key = mask_key @@ -235,7 +237,7 @@ def __init__( self.dist_map_key = dist_map_key self.transform = GenerateDistanceMap(smooth_fn=smooth_fn, dtype=dtype) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) if self.dist_map_key in d: raise KeyError(f"The key '{self.dist_map_key}' for distance map already exists.") @@ -270,7 +272,7 @@ def __init__( threshold: float = 0.4, radius: int = 2, min_object_size: int = 10, - postprocess_fn: Optional[Callable] = None, + postprocess_fn: Callable | None = None, dtype: DtypeLike = np.uint8, ) -> None: self.mask_key = mask_key @@ -284,7 +286,7 @@ def __init__( dtype=dtype, ) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) if self.markers_key in d: raise KeyError(f"The key '{self.markers_key}' for markers already exists.") @@ -345,9 +347,9 @@ def __init__( self, keys: KeysCollection, contour_key_postfix: str = "contour", - offset_key: Optional[str] = None, + offset_key: str | None = None, min_num_points: int = 3, - level: Optional[float] = None, + level: float | None = None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -388,8 +390,8 @@ def __init__( self, keys: KeysCollection, centroid_key_postfix: str = "centroid", - offset_key: Optional[str] = None, - dtype: Optional[DtypeLike] = int, + offset_key: str | None = None, + dtype: DtypeLike | None = int, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -494,17 +496,17 @@ def __init__( hover_map_key: str = HoVerNetBranch.HV.value, instance_info_key: str = "instance_info", instance_map_key: str = "instance_map", - activation: Union[str, Callable] = "softmax", - mask_threshold: Optional[float] = None, + activation: str | Callable = "softmax", + mask_threshold: float | None = None, min_object_size: int = 10, sobel_kernel_size: int = 5, - distance_smooth_fn: Optional[Callable] = None, + distance_smooth_fn: Callable | None = None, marker_threshold: float = 0.4, marker_radius: int = 2, - marker_postprocess_fn: Optional[Callable] = None, - watershed_connectivity: Optional[int] = 1, + marker_postprocess_fn: Callable | None = None, + watershed_connectivity: int | None = 1, min_num_points: int = 3, - contour_level: Optional[float] = None, + contour_level: float | None = None, ) -> None: super().__init__() self.instance_map_post_process = HoVerNetInstanceMapPostProcessing( @@ -561,8 +563,8 @@ def __init__( instance_info_key: str = "instance_info", instance_map_key: str = "instance_map", type_map_key: str = "type_map", - activation: Union[str, Callable] = "softmax", - threshold: Optional[float] = None, + activation: str | Callable = "softmax", + threshold: float | None = None, return_type_map: bool = True, ) -> None: super().__init__() diff --git a/monai/apps/pathology/transforms/spatial/__init__.py b/monai/apps/pathology/transforms/spatial/__init__.py index eed111d2b6..7e0f6c75d0 100644 --- a/monai/apps/pathology/transforms/spatial/__init__.py +++ b/monai/apps/pathology/transforms/spatial/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .array import SplitOnGrid, TileOnGrid from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index ce22b86d49..ea8a8c89a9 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import numpy as np import torch @@ -41,9 +43,7 @@ class SplitOnGrid(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__( - self, grid_size: Union[int, Tuple[int, int]] = (2, 2), patch_size: Optional[Union[int, Tuple[int, int]]] = None - ): + def __init__(self, grid_size: int | tuple[int, int] = (2, 2), patch_size: int | tuple[int, int] | None = None): # Grid size if isinstance(grid_size, int): self.grid_size = (grid_size, grid_size) @@ -137,9 +137,9 @@ class TileOnGrid(Randomizable, Transform): def __init__( self, - tile_count: Optional[int] = None, + tile_count: int | None = None, tile_size: int = 256, - step: Optional[int] = None, + step: int | None = None, random_offset: bool = False, pad_full: bool = False, background_val: int = 255, @@ -165,7 +165,6 @@ def __init__( raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode)) def randomize(self, img_size: Sequence[int]) -> None: - c, h, w = img_size self.offset = (0, 0) @@ -239,7 +238,6 @@ def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: else: if len(img_np) > self.tile_count: - if self.filter_mode == "min": # default, keep non-background tiles (smallest values) idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count] diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index 022d82a053..8166a5891d 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import copy -from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union +from collections.abc import Hashable, Mapping +from typing import Any from monai.config import KeysCollection from monai.config.type_definitions import NdarrayOrTensor @@ -43,14 +46,14 @@ class SplitOnGridd(MapTransform): def __init__( self, keys: KeysCollection, - grid_size: Union[int, Tuple[int, int]] = (2, 2), - patch_size: Optional[Union[int, Tuple[int, int]]] = None, + grid_size: int | tuple[int, int] = (2, 2), + patch_size: int | tuple[int, int] | None = None, allow_missing_keys: bool = False, ): super().__init__(keys, allow_missing_keys) self.splitter = SplitOnGrid(grid_size=grid_size, patch_size=patch_size) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.splitter(d[key]) @@ -87,9 +90,9 @@ class TileOnGridd(Randomizable, MapTransform): def __init__( self, keys: KeysCollection, - tile_count: Optional[int] = None, + tile_count: int | None = None, tile_size: int = 256, - step: Optional[int] = None, + step: int | None = None, random_offset: bool = False, pad_full: bool = False, background_val: int = 255, @@ -117,8 +120,7 @@ def randomize(self, data: Any = None) -> None: def __call__( self, data: Mapping[Hashable, NdarrayOrTensor] - ) -> Union[Dict[Hashable, NdarrayOrTensor], List[Dict[Hashable, NdarrayOrTensor]]]: - + ) -> dict[Hashable, NdarrayOrTensor] | list[dict[Hashable, NdarrayOrTensor]]: self.randomize() d = dict(data) diff --git a/monai/apps/pathology/transforms/stain/__init__.py b/monai/apps/pathology/transforms/stain/__init__.py index dfa235de55..3239fcbce3 100644 --- a/monai/apps/pathology/transforms/stain/__init__.py +++ b/monai/apps/pathology/transforms/stain/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .array import ExtractHEStains, NormalizeHEStains from .dictionary import ( ExtractHEStainsd, diff --git a/monai/apps/pathology/transforms/stain/array.py b/monai/apps/pathology/transforms/stain/array.py index 3b3a293451..5df9ad7ef3 100644 --- a/monai/apps/pathology/transforms/stain/array.py +++ b/monai/apps/pathology/transforms/stain/array.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from __future__ import annotations import numpy as np @@ -37,11 +37,7 @@ class ExtractHEStains(Transform): """ def __init__( - self, - tli: float = 240, - alpha: float = 1, - beta: float = 0.15, - max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), + self, tli: float = 240, alpha: float = 1, beta: float = 0.15, max_cref: tuple | np.ndarray = (1.9705, 1.0308) ) -> None: self.tli = tli self.alpha = alpha @@ -145,8 +141,8 @@ def __init__( tli: float = 240, alpha: float = 1, beta: float = 0.15, - target_he: Union[tuple, np.ndarray] = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)), - max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), + target_he: tuple | np.ndarray = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)), + max_cref: tuple | np.ndarray = (1.9705, 1.0308), ) -> None: self.tli = tli self.target_he = np.array(target_he) diff --git a/monai/apps/pathology/transforms/stain/dictionary.py b/monai/apps/pathology/transforms/stain/dictionary.py index eb8eba43f8..aa77301cb6 100644 --- a/monai/apps/pathology/transforms/stain/dictionary.py +++ b/monai/apps/pathology/transforms/stain/dictionary.py @@ -15,7 +15,9 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from typing import Dict, Hashable, Mapping, Union +from __future__ import annotations + +from collections.abc import Hashable, Mapping import numpy as np @@ -48,13 +50,13 @@ def __init__( tli: float = 240, alpha: float = 1, beta: float = 0.15, - max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), + max_cref: tuple | np.ndarray = (1.9705, 1.0308), allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.extractor = ExtractHEStains(tli=tli, alpha=alpha, beta=beta, max_cref=max_cref) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): d[key] = self.extractor(d[key]) @@ -93,14 +95,14 @@ def __init__( tli: float = 240, alpha: float = 1, beta: float = 0.15, - target_he: Union[tuple, np.ndarray] = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)), - max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), + target_he: tuple | np.ndarray = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)), + max_cref: tuple | np.ndarray = (1.9705, 1.0308), allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.normalizer = NormalizeHEStains(tli=tli, alpha=alpha, beta=beta, target_he=target_he, max_cref=max_cref) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): d[key] = self.normalizer(d[key]) diff --git a/monai/apps/pathology/utils.py b/monai/apps/pathology/utils.py index 5a57364a11..d3ebe0a7a6 100644 --- a/monai/apps/pathology/utils.py +++ b/monai/apps/pathology/utils.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from __future__ import annotations + +from typing import Any import numpy as np import torch @@ -21,7 +23,7 @@ ndimage, _ = optional_import("scipy.ndimage") -def compute_multi_instance_mask(mask: np.ndarray, threshold: float): +def compute_multi_instance_mask(mask: np.ndarray, threshold: float) -> Any: """ This method computes the segmentation mask according to the binary tumor mask. @@ -40,7 +42,7 @@ def compute_multi_instance_mask(mask: np.ndarray, threshold: float): return multi_instance_mask -def compute_isolated_tumor_cells(tumor_mask: np.ndarray, threshold: float) -> List[int]: +def compute_isolated_tumor_cells(tumor_mask: np.ndarray, threshold: float) -> list[int]: """ This method computes identifies Isolated Tumor Cells (ITC) and return their labels. @@ -50,7 +52,7 @@ def compute_isolated_tumor_cells(tumor_mask: np.ndarray, threshold: float) -> Li A region with the longest diameter less than this threshold is considered as an ITC. """ max_label = np.amax(tumor_mask) - properties = measure.regionprops(tumor_mask, coordinates="rc") + properties = measure.regionprops(tumor_mask) itc_list = [i + 1 for i in range(max_label) if properties[i].major_axis_length < threshold] return itc_list @@ -62,7 +64,7 @@ class PathologyProbNMS(ProbNMS): Pathology. """ - def __call__(self, probs_map: Union[np.ndarray, torch.Tensor], resolution_level: int = 0): + def __call__(self, probs_map: np.ndarray | torch.Tensor, resolution_level: int = 0) -> list[list]: """ probs_map: the input probabilities map, it must have shape (H[, W, ...]). resolution_level: the level at which the probabilities map is made. diff --git a/monai/apps/reconstruction/complex_utils.py b/monai/apps/reconstruction/complex_utils.py index 0a5cdccd0d..0436578f5d 100644 --- a/monai/apps/reconstruction/complex_utils.py +++ b/monai/apps/reconstruction/complex_utils.py @@ -12,8 +12,9 @@ This script contains utility functions for complex-value PyTorch tensor. """ +from __future__ import annotations + import re -from typing import Optional import numpy as np import torch @@ -24,9 +25,9 @@ def convert_to_tensor_complex( - data, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + data: NdarrayOrTensor | list | int | float, + dtype: torch.dtype | None = None, + device: torch.device | None = None, wrap_sequence: bool = True, track_meta: bool = False, ) -> Tensor: @@ -90,7 +91,7 @@ def convert_to_tensor_complex( elif isinstance(data, list): data = convert_to_numpy(data, wrap_sequence=True) - data = np.stack((data.real, data.imag), axis=-1).tolist() + data = np.stack((data.real, data.imag), axis=-1).tolist() # type: ignore converted_data = convert_to_tensor( data, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta diff --git a/monai/apps/reconstruction/fastmri_reader.py b/monai/apps/reconstruction/fastmri_reader.py index 52bd5c0db3..5f0d6c2d23 100644 --- a/monai/apps/reconstruction/fastmri_reader.py +++ b/monai/apps/reconstruction/fastmri_reader.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os -from typing import Dict, Sequence, Tuple, Union +from collections.abc import Sequence import numpy as np from numpy import ndarray @@ -41,7 +43,7 @@ class FastMRIReader(ImageReader): - patient_id (str): the patient's id whose measurements were recorded """ - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file format is supported by h5py reader. @@ -51,7 +53,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: suffixes: Sequence[str] = [".h5"] return has_h5py and is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[PathLike], PathLike]) -> Dict: # type: ignore + def read(self, data: Sequence[PathLike] | PathLike) -> dict: # type: ignore """ Read data from specified h5 file. Note that the returned object is a dictionary. @@ -73,7 +75,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike]) -> Dict: # type: igno return dat - def get_data(self, dat: Dict) -> Tuple[ndarray, dict]: + def get_data(self, dat: dict) -> tuple[ndarray, dict]: """ Extract data array and metadata from the loaded data and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. @@ -90,7 +92,7 @@ def get_data(self, dat: Dict) -> Tuple[ndarray, dict]: ) return data, header - def _get_meta_dict(self, dat) -> Dict: + def _get_meta_dict(self, dat: dict) -> dict: """ Get all the metadata of the loaded dict and return the meta dict. diff --git a/monai/apps/reconstruction/mri_utils.py b/monai/apps/reconstruction/mri_utils.py index 9c06b492d5..f51040509e 100644 --- a/monai/apps/reconstruction/mri_utils.py +++ b/monai/apps/reconstruction/mri_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from torch import Tensor from monai.config.type_definitions import NdarrayOrTensor diff --git a/monai/apps/reconstruction/networks/blocks/varnetblock.py b/monai/apps/reconstruction/networks/blocks/varnetblock.py index daaa3efbf3..289505a057 100644 --- a/monai/apps/reconstruction/networks/blocks/varnetblock.py +++ b/monai/apps/reconstruction/networks/blocks/varnetblock.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch import torch.nn as nn from torch import Tensor diff --git a/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py b/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py index 94568db90f..91a9f3d8d3 100644 --- a/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py +++ b/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -54,13 +56,13 @@ def __init__( self, spatial_dims: int = 2, features: Sequence[int] = (32, 32, 64, 128, 256, 32), - act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), - norm: Union[str, tuple] = ("instance", {"affine": True}), + act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), + norm: str | tuple = ("instance", {"affine": True}), bias: bool = True, - dropout: Union[float, tuple] = 0.0, + dropout: float | tuple = 0.0, upsample: str = "deconv", coil_dim: int = 1, - conv_net: Optional[nn.Module] = None, + conv_net: nn.Module | None = None, ): super().__init__() if conv_net is None: @@ -83,7 +85,7 @@ def __init__( self.spatial_dims = spatial_dims self.coil_dim = coil_dim - def get_fully_sampled_region(self, mask: Tensor) -> Tuple[int, int]: + def get_fully_sampled_region(self, mask: Tensor) -> tuple[int, int]: """ Extracts the size of the fully-sampled part of the kspace. Note that when a kspace is under-sampled, a part of its center is fully sampled. This part is called the Auto diff --git a/monai/apps/reconstruction/networks/nets/complex_unet.py b/monai/apps/reconstruction/networks/nets/complex_unet.py index ccbb5731a1..1ca5fd5eec 100644 --- a/monai/apps/reconstruction/networks/nets/complex_unet.py +++ b/monai/apps/reconstruction/networks/nets/complex_unet.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import torch.nn as nn from torch import Tensor @@ -56,15 +58,16 @@ def __init__( self, spatial_dims: int = 2, features: Sequence[int] = (32, 32, 64, 128, 256, 32), - act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), - norm: Union[str, tuple] = ("instance", {"affine": True}), + act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), + norm: str | tuple = ("instance", {"affine": True}), bias: bool = True, - dropout: Union[float, tuple] = 0.0, + dropout: float | tuple = 0.0, upsample: str = "deconv", pad_factor: int = 16, - conv_net: Optional[nn.Module] = None, + conv_net: nn.Module | None = None, ): super().__init__() + self.unet: nn.Module if conv_net is None: self.unet = BasicUNet( spatial_dims=spatial_dims, diff --git a/monai/apps/reconstruction/networks/nets/utils.py b/monai/apps/reconstruction/networks/nets/utils.py index b97cdab786..419ef5e6ff 100644 --- a/monai/apps/reconstruction/networks/nets/utils.py +++ b/monai/apps/reconstruction/networks/nets/utils.py @@ -12,8 +12,9 @@ This script contains utility functions for developing new networks/blocks in PyTorch. """ +from __future__ import annotations + import math -from typing import Tuple from torch import Tensor from torch.nn import functional as F @@ -75,7 +76,7 @@ def reshape_channel_complex_to_last_dim(x: Tensor) -> Tensor: raise ValueError(f"only 2D (B,C*2,H,W) and 3D (B,C*2,H,W,D) data are supported but x has shape {x.shape}") -def reshape_channel_to_batch_dim(x: Tensor) -> Tuple[Tensor, int]: +def reshape_channel_to_batch_dim(x: Tensor) -> tuple[Tensor, int]: """ Combines batch and channel dimensions. @@ -125,7 +126,7 @@ def reshape_batch_channel_to_channel_dim(x: Tensor, batch_size: int) -> Tensor: raise ValueError(f"only 2D (B*C,1,H,W,2) and 3D (B*C,1,H,W,D,2) data are supported but x has shape {x.shape}") -def complex_normalize(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: +def complex_normalize(x: Tensor) -> tuple[Tensor, Tensor, Tensor]: """ Performs layer mean-std normalization for complex data. Normalization is done for each batch member along each part (part refers to real and imaginary parts), separately. @@ -167,7 +168,7 @@ def complex_normalize(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: def divisible_pad_t( x: Tensor, k: int = 16 -) -> Tuple[Tensor, Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], int, int, int]]: +) -> tuple[Tensor, tuple[tuple[int, int], tuple[int, int], tuple[int, int], int, int, int]]: """ Pad input to feed into the network (torch script compatible) @@ -228,7 +229,7 @@ def divisible_pad_t( def inverse_divisible_pad_t( - x: Tensor, pad_sizes: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], int, int, int] + x: Tensor, pad_sizes: tuple[tuple[int, int], tuple[int, int], tuple[int, int], int, int, int] ) -> Tensor: """ De-pad network output to match its original shape @@ -252,7 +253,7 @@ def inverse_divisible_pad_t( raise ValueError(f"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}") -def floor_ceil(n: float) -> Tuple[int, int]: +def floor_ceil(n: float) -> tuple[int, int]: """ Returns floor and ceil of the input diff --git a/monai/apps/reconstruction/networks/nets/varnet.py b/monai/apps/reconstruction/networks/nets/varnet.py index 33b93b3d82..de4deb9afc 100644 --- a/monai/apps/reconstruction/networks/nets/varnet.py +++ b/monai/apps/reconstruction/networks/nets/varnet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import copy import torch.nn as nn diff --git a/monai/apps/reconstruction/transforms/array.py b/monai/apps/reconstruction/transforms/array.py index cd2936de41..911d7a06bb 100644 --- a/monai/apps/reconstruction/transforms/array.py +++ b/monai/apps/reconstruction/transforms/array.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from abc import abstractmethod -from typing import Sequence +from collections.abc import Sequence import numpy as np from torch import Tensor @@ -70,7 +72,7 @@ def __init__( self.is_complex = is_complex @abstractmethod - def __call__(self, kspace: NdarrayOrTensor): + def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]: """ This is an extra instance to allow for defining new mask generators. For creating other mask transforms, define a new class and simply diff --git a/monai/apps/reconstruction/transforms/dictionary.py b/monai/apps/reconstruction/transforms/dictionary.py index 4f3a2e03cf..11454b0b6b 100644 --- a/monai/apps/reconstruction/transforms/dictionary.py +++ b/monai/apps/reconstruction/transforms/dictionary.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Hashable, Mapping, Optional, Sequence +from __future__ import annotations + +from collections.abc import Hashable, Mapping, Sequence import numpy as np from numpy import ndarray @@ -46,7 +48,7 @@ def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool MapTransform.__init__(self, keys, allow_missing_keys) self.meta_key = meta_key - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, Tensor]: """ Args: data: is a dictionary containing (key,value) pairs from the @@ -113,13 +115,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandomKspaceMaskd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandomKspaceMaskd: super().set_random_state(seed, state) self.masker.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, Tensor]: """ Args: data: is a dictionary containing (key,value) pairs from the @@ -181,8 +183,8 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "EquispacedKspaceMaskd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> EquispacedKspaceMaskd: super().set_random_state(seed, state) self.masker.set_random_state(seed, state) return self @@ -211,11 +213,10 @@ class ReferenceBasedSpatialCropd(Cropd): """ def __init__(self, keys: KeysCollection, ref_key: str, allow_missing_keys: bool = False) -> None: - super().__init__(keys, cropper=None, allow_missing_keys=allow_missing_keys) # type: ignore self.ref_key = ref_key - def __call__(self, data: Mapping[Hashable, Tensor]) -> Dict[Hashable, Tensor]: + def __call__(self, data: Mapping[Hashable, Tensor]) -> dict[Hashable, Tensor]: """ This transform can support to crop ND spatial (channel-first) data. It also supports pseudo ND spatial data (e.g., (C,H,W) is a pseudo-3D @@ -277,8 +278,8 @@ def __init__( self, keys: KeysCollection, ref_key: str, - subtrahend: Optional[NdarrayOrTensor] = None, - divisor: Optional[NdarrayOrTensor] = None, + subtrahend: NdarrayOrTensor | None = None, + divisor: NdarrayOrTensor | None = None, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32, @@ -288,7 +289,7 @@ def __init__( self.default_normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) self.ref_key = ref_key - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: """ This transform can support to normalize ND spatial (channel-first) data. It also supports pseudo ND spatial data (e.g., (C,H,W) is a pseudo-3D diff --git a/monai/apps/tcia/__init__.py b/monai/apps/tcia/__init__.py index bd266704b8..af3d44fd14 100644 --- a/monai/apps/tcia/__init__.py +++ b/monai/apps/tcia/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .label_desc import TCIA_LABEL_DICT from .utils import download_tcia_series_instance, get_tcia_metadata, get_tcia_ref_uid, match_tcia_ref_uid_in_study diff --git a/monai/apps/tcia/label_desc.py b/monai/apps/tcia/label_desc.py index e3875e4095..29ae1fad1f 100644 --- a/monai/apps/tcia/label_desc.py +++ b/monai/apps/tcia/label_desc.py @@ -9,11 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from __future__ import annotations __all__ = ["TCIA_LABEL_DICT"] -TCIA_LABEL_DICT: Dict[str, Dict[str, int]] = { +TCIA_LABEL_DICT: dict[str, dict[str, int]] = { "C4KC-KiTS": {"Kidney": 0, "Renal Tumor": 1}, "NSCLC-Radiomics": { "Esophagus": 0, diff --git a/monai/apps/tcia/utils.py b/monai/apps/tcia/utils.py index ad95596223..9c120f0072 100644 --- a/monai/apps/tcia/utils.py +++ b/monai/apps/tcia/utils.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os -from typing import List, Optional +from typing import Iterable import monai from monai.config.type_definitions import PathLike @@ -24,7 +26,7 @@ BASE_URL = "https://services.cancerimagingarchive.net/nbia-api/services/v1/" -def get_tcia_metadata(query: str, attribute: Optional[str] = None): +def get_tcia_metadata(query: str, attribute: str | None = None) -> list: """ Achieve metadata of a public The Cancer Imaging Archive (TCIA) dataset. @@ -51,7 +53,7 @@ def get_tcia_metadata(query: str, attribute: Optional[str] = None): full_url = f"{BASE_URL}{query}" resp = requests_get(full_url) resp.raise_for_status() - metadata_list: List = [] + metadata_list: list = [] if len(resp.text) == 0: return metadata_list for d in resp.json(): @@ -70,7 +72,7 @@ def download_tcia_series_instance( check_md5: bool = False, hashes_filename: str = "md5hashes.csv", progress: bool = True, -): +) -> None: """ Download a dicom series from a public The Cancer Imaging Archive (TCIA) dataset. The downloaded compressed file will be stored in `download_dir`, and the uncompressed folder will be saved @@ -104,7 +106,12 @@ def download_tcia_series_instance( monai.apps.utils.check_hash(filepath=os.path.join(output_dir, dcm), val=md5hash, hash_type="md5") -def get_tcia_ref_uid(ds, find_sop: bool = False, ref_series_uid_tag=(0x0020, 0x000E), ref_sop_uid_tag=(0x0008, 0x1155)): +def get_tcia_ref_uid( + ds: Iterable, + find_sop: bool = False, + ref_series_uid_tag: tuple = (0x0020, 0x000E), + ref_sop_uid_tag: tuple = (0x0008, 0x1155), +) -> str: """ Achieve the referenced UID from the referenced Series Sequence for the input pydicom dataset object. The referenced UID could be Series Instance UID or SOP Instance UID. The UID will be detected from @@ -126,7 +133,7 @@ def get_tcia_ref_uid(ds, find_sop: bool = False, ref_series_uid_tag=(0x0020, 0x0 for item in elem: output = get_tcia_ref_uid(item, find_sop) if elem.tag == ref_uid_tag: - return elem.value + return elem.value # type: ignore[no-any-return] return output diff --git a/monai/apps/utils.py b/monai/apps/utils.py index cbfdcd7423..a36caf2e66 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import hashlib import logging import os @@ -19,7 +21,7 @@ import warnings import zipfile from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any from urllib.error import ContentTooShortError, HTTPError, URLError from urllib.parse import urlparse from urllib.request import urlretrieve @@ -45,9 +47,9 @@ def get_logger( module_name: str = "monai.apps", fmt: str = DEFAULT_FMT, - datefmt: Optional[str] = None, - logger_handler: Optional[logging.Handler] = None, -): + datefmt: str | None = None, + logger_handler: logging.Handler | None = None, +) -> logging.Logger: """ Get a `module_name` logger with the specified format and date format. By default, the logger will print to `stdout` at the INFO level. @@ -56,13 +58,15 @@ def get_logger( (https://docs.python.org/3/library/logging.html#formatter-objects). `logger_handler` can be used to add an additional handler. """ + adds_stdout_handler = module_name is not None and module_name not in logging.root.manager.loggerDict logger = logging.getLogger(module_name) logger.propagate = False logger.setLevel(logging.INFO) - handler = logging.StreamHandler(sys.stdout) - formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) - handler.setFormatter(formatter) - logger.addHandler(handler) + if adds_stdout_handler: # don't add multiple stdout or add to the root + handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) + handler.setFormatter(formatter) + logger.addHandler(handler) if logger_handler is not None: logger.addHandler(logger_handler) return logger @@ -79,7 +83,7 @@ def _basename(p: PathLike) -> str: return Path(f"{p}".rstrip(sep)).name -def _download_with_progress(url, filepath, progress: bool = True): +def _download_with_progress(url: str, filepath: Path, progress: bool = True) -> None: """ Retrieve file from `url` to `filepath`, optionally showing a progress bar. """ @@ -92,7 +96,7 @@ class TqdmUpTo(tqdm): Inspired by the example in https://github.com/tqdm/tqdm. """ - def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None): + def update_to(self, b: int = 1, bsize: int = 1, tsize: int | None = None) -> None: """ Args: b: number of blocks transferred so far, default: 1. @@ -114,7 +118,7 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None): raise e -def check_hash(filepath: PathLike, val: Optional[str] = None, hash_type: str = "md5") -> bool: +def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5") -> bool: """ Verify hash signature of specified file. @@ -149,10 +153,10 @@ def check_hash(filepath: PathLike, val: Optional[str] = None, hash_type: str = " def download_url( url: str, filepath: PathLike = "", - hash_val: Optional[str] = None, + hash_val: str | None = None, hash_type: str = "md5", progress: bool = True, - **gdown_kwargs, + **gdown_kwargs: Any, ) -> None: """ Download file from specified URL link, support process bar and hash check. @@ -222,7 +226,7 @@ def download_url( def extractall( filepath: PathLike, output_dir: PathLike = ".", - hash_val: Optional[str] = None, + hash_val: str | None = None, hash_type: str = "md5", file_type: str = "", has_base: bool = True, @@ -282,7 +286,7 @@ def download_and_extract( url: str, filepath: PathLike = "", output_dir: PathLike = ".", - hash_val: Optional[str] = None, + hash_val: str | None = None, hash_type: str = "md5", file_type: str = "", has_base: bool = True, diff --git a/monai/auto3dseg/__init__.py b/monai/auto3dseg/__init__.py index 9d35026045..4e5d15613b 100644 --- a/monai/auto3dseg/__init__.py +++ b/monai/auto3dseg/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .algo_gen import Algo, AlgoGen from .analyzer import ( Analyzer, diff --git a/monai/auto3dseg/algo_gen.py b/monai/auto3dseg/algo_gen.py index ad185117a4..5ebe479aec 100644 --- a/monai/auto3dseg/algo_gen.py +++ b/monai/auto3dseg/algo_gen.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from monai.transforms import Randomizable diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index 386ca2f99e..c726502a53 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -9,10 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import time from abc import ABC, abstractmethod +from collections.abc import Hashable, Mapping from copy import deepcopy -from typing import Any, Dict, Hashable, List, Mapping, Optional, Union +from typing import Any import numpy as np import torch @@ -71,7 +74,7 @@ def __init__(self, stats_name: str, report_format: dict) -> None: self.stats_name = stats_name self.ops = ConfigParser({}, globals=False) - def update_ops(self, key: str, op): + def update_ops(self, key: str, op: Operations) -> None: """ Register a statistical operation to the Analyzer and update the report_format. @@ -88,7 +91,7 @@ def update_ops(self, key: str, op): self.report_format = parser.get("") - def update_ops_nested_label(self, nested_key: str, op): + def update_ops_nested_label(self, nested_key: str, op: Operations) -> None: """ Update operations for nested label format. Operation value in report_format will be resolved to a dict with only keys. @@ -113,7 +116,7 @@ def update_ops_nested_label(self, nested_key: str, op): if parser.get(nested_key, "NA") != "NA": parser[nested_key] = op - def get_report_format(self): + def get_report_format(self) -> dict: """ Get the report format by resolving the registered operations recursively. @@ -122,7 +125,7 @@ def get_report_format(self): """ self.resolve_format(self.report_format) - return self.report_format + return self.report_format # type: ignore[no-any-return] @staticmethod def unwrap_ops(func): @@ -146,7 +149,7 @@ def unwrap_ops(func): ret.update({key: None}) return ret - def resolve_format(self, report: dict): + def resolve_format(self, report: dict) -> None: """ Resolve the format of the pre-defined report. @@ -163,7 +166,7 @@ def resolve_format(self, report: dict): report[k] = v @abstractmethod - def __call__(self, data: Any): + def __call__(self, data: Any) -> dict: """Analyze the dict format dataset, return the summary report""" raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -196,7 +199,6 @@ class ImageStats(Analyzer): """ def __init__(self, image_key: str, stats_name: str = "image_stats") -> None: - if not isinstance(image_key, str): raise ValueError("image_key input must be str") @@ -207,6 +209,7 @@ def __init__(self, image_key: str, stats_name: str = "image_stats") -> None: ImageStatsKeys.CHANNELS: None, ImageStatsKeys.CROPPED_SHAPE: None, ImageStatsKeys.SPACING: None, + ImageStatsKeys.SIZEMM: None, ImageStatsKeys.INTENSITY: None, } @@ -251,6 +254,12 @@ def __call__(self, data): if isinstance(data[self.image_key], MetaTensor) else [1.0] * min(3, data[self.image_key].ndim) ) + + report[ImageStatsKeys.SIZEMM] = [ + np.multiply(x, y).astype(int, copy=False).tolist() + for x, y in zip(report[ImageStatsKeys.SHAPE], report[ImageStatsKeys.SPACING]) + ] + report[ImageStatsKeys.INTENSITY] = [ self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds ] @@ -289,7 +298,6 @@ class FgImageStats(Analyzer): """ def __init__(self, image_key: str, label_key: str, stats_name: str = "image_foreground_stats"): - self.image_key = image_key self.label_key = label_key @@ -298,7 +306,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = "image_fore super().__init__(stats_name, report_format) self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations()) - def __call__(self, data) -> dict: + def __call__(self, data: Mapping) -> dict: """ Callable to execute the pre-defined functions @@ -371,13 +379,12 @@ class LabelStats(Analyzer): """ - def __init__(self, image_key: str, label_key: str, stats_name: str = "label_stats", do_ccp: Optional[bool] = True): - + def __init__(self, image_key: str, label_key: str, stats_name: str = "label_stats", do_ccp: bool | None = True): self.image_key = image_key self.label_key = label_key self.do_ccp = do_ccp - report_format: Dict[LabelStatsKeys, Any] = { + report_format: dict[LabelStatsKeys, Any] = { LabelStatsKeys.LABEL_UID: None, LabelStatsKeys.IMAGE_INTST: None, LabelStatsKeys.LABEL: [{LabelStatsKeys.PIXEL_PCT: None, LabelStatsKeys.IMAGE_INTST: None}], @@ -394,7 +401,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = "label_stat id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.IMAGE_INTST]) self.update_ops_nested_label(id_seq, SampleOperations()) - def __call__(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTensor]: + def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor | dict]: """ Callable to execute the pre-defined functions. @@ -442,7 +449,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTe The stats operation uses numpy and torch to compute max, min, and other functions. If the input has nan/inf, the stats results will be nan/inf. """ - d: Dict[Hashable, MetaTensor] = dict(data) + d: dict[Hashable, MetaTensor] = dict(data) start = time.time() if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda": using_cuda = True @@ -451,13 +458,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTe restore_grad_state = torch.is_grad_enabled() torch.set_grad_enabled(False) - ndas: List[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore + ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore ndas_label: MetaTensor = d[self.label_key] # (H,W,D) if ndas_label.shape != ndas[0].shape: raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") - nda_foregrounds: List[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas] + nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas] nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds] unique_label = unique(ndas_label) @@ -471,7 +478,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTe pixel_arr = [] for index in unique_label: start_label = time.time() - label_dict: Dict[str, Any] = {} + label_dict: dict[str, Any] = {} mask_index = ndas_label == index nda_masks = [nda[mask_index] for nda in ndas] @@ -508,11 +515,11 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTe if not verify_report_format(report, self.get_report_format()): raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") - d[self.stats_name] = report + d[self.stats_name] = report # type: ignore[assignment] torch.set_grad_enabled(restore_grad_state) logger.debug(f"Get label stats spent {time.time()-start}") - return d + return d # type: ignore[return-value] class ImageStatsSumm(Analyzer): @@ -527,13 +534,14 @@ class ImageStatsSumm(Analyzer): """ - def __init__(self, stats_name: str = "image_stats", average: Optional[bool] = True): + def __init__(self, stats_name: str = "image_stats", average: bool | None = True): self.summary_average = average report_format = { ImageStatsKeys.SHAPE: None, ImageStatsKeys.CHANNELS: None, ImageStatsKeys.CROPPED_SHAPE: None, ImageStatsKeys.SPACING: None, + ImageStatsKeys.SIZEMM: None, ImageStatsKeys.INTENSITY: None, } super().__init__(stats_name, report_format) @@ -542,9 +550,10 @@ def __init__(self, stats_name: str = "image_stats", average: Optional[bool] = Tr self.update_ops(ImageStatsKeys.CHANNELS, SampleOperations()) self.update_ops(ImageStatsKeys.CROPPED_SHAPE, SampleOperations()) self.update_ops(ImageStatsKeys.SPACING, SampleOperations()) + self.update_ops(ImageStatsKeys.SIZEMM, SampleOperations()) self.update_ops(ImageStatsKeys.INTENSITY, SummaryOperations()) - def __call__(self, data: List[Dict]): + def __call__(self, data: list[dict]) -> dict: """ Callable to execute the pre-defined functions @@ -563,6 +572,7 @@ def __call__(self, data: List[Dict]): ImageStatsKeys.CHANNELS: {...}, ImageStatsKeys.CROPPED_SHAPE: {...}, ImageStatsKeys.SPACING: {...}, + ImageStatsKeys.SIZEMM: {...}, ImageStatsKeys.INTENSITY: {...}, } @@ -571,17 +581,23 @@ def __call__(self, data: List[Dict]): functions. If the input has nan/inf, the stats results will be nan/inf. """ if not isinstance(data, list): - return ValueError(f"Callable {self.__class__} requires list inputs") + raise ValueError(f"Callable {self.__class__} requires list inputs") if len(data) == 0: - return ValueError(f"Callable {self.__class__} input list is empty") + raise ValueError(f"Callable {self.__class__} input list is empty") if self.stats_name not in data[0]: - return KeyError(f"{self.stats_name} is not in input data") + raise KeyError(f"{self.stats_name} is not in input data") report = deepcopy(self.get_report_format()) - for k in [ImageStatsKeys.SHAPE, ImageStatsKeys.CHANNELS, ImageStatsKeys.CROPPED_SHAPE, ImageStatsKeys.SPACING]: + for k in [ + ImageStatsKeys.SHAPE, + ImageStatsKeys.CHANNELS, + ImageStatsKeys.CROPPED_SHAPE, + ImageStatsKeys.SPACING, + ImageStatsKeys.SIZEMM, + ]: v_np = concat_val_to_np(data, [self.stats_name, k]) report[k] = self.ops[k].evaluate(v_np, dim=(0, 1) if v_np.ndim > 2 and self.summary_average else 0) @@ -608,14 +624,14 @@ class FgImageStatsSumm(Analyzer): """ - def __init__(self, stats_name: str = "image_foreground_stats", average: Optional[bool] = True): + def __init__(self, stats_name: str = "image_foreground_stats", average: bool | None = True): self.summary_average = average report_format = {ImageStatsKeys.INTENSITY: None} super().__init__(stats_name, report_format) self.update_ops(ImageStatsKeys.INTENSITY, SummaryOperations()) - def __call__(self, data: List[Dict]): + def __call__(self, data: list[dict]) -> dict: """ Callable to execute the pre-defined functions. @@ -639,13 +655,13 @@ def __call__(self, data: List[Dict]): functions. If the input has nan/inf, the stats results will be nan/inf. """ if not isinstance(data, list): - return ValueError(f"Callable {self.__class__} requires list inputs") + raise ValueError(f"Callable {self.__class__} requires list inputs") if len(data) == 0: - return ValueError(f"Callable {self.__class__} input list is empty") + raise ValueError(f"Callable {self.__class__} input list is empty") if self.stats_name not in data[0]: - return KeyError(f"{self.stats_name} is not in input data.") + raise KeyError(f"{self.stats_name} is not in input data.") report = deepcopy(self.get_report_format()) intst_str = ImageStatsKeys.INTENSITY @@ -672,11 +688,11 @@ class LabelStatsSumm(Analyzer): """ - def __init__(self, stats_name: str = "label_stats", average: Optional[bool] = True, do_ccp: Optional[bool] = True): + def __init__(self, stats_name: str = "label_stats", average: bool | None = True, do_ccp: bool | None = True): self.summary_average = average self.do_ccp = do_ccp - report_format: Dict[str, Any] = { + report_format: dict[str, Any] = { LabelStatsKeys.LABEL_UID: None, LabelStatsKeys.IMAGE_INTST: None, LabelStatsKeys.LABEL: [{LabelStatsKeys.PIXEL_PCT: None, LabelStatsKeys.IMAGE_INTST: None}], @@ -702,7 +718,7 @@ def __init__(self, stats_name: str = "label_stats", average: Optional[bool] = Tr id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.LABEL_NCOMP]) self.update_ops_nested_label(id_seq, SampleOperations()) - def __call__(self, data: List[Dict]): + def __call__(self, data: list[dict]) -> dict: """ Callable to execute the pre-defined functions @@ -721,13 +737,13 @@ def __call__(self, data: List[Dict]): functions. If the input has nan/inf, the stats results will be nan/inf. """ if not isinstance(data, list): - return ValueError(f"Callable {self.__class__} requires list inputs") + raise ValueError(f"Callable {self.__class__} requires list inputs") if len(data) == 0: - return ValueError(f"Callable {self.__class__} input list is empty") + raise ValueError(f"Callable {self.__class__} input list is empty") if self.stats_name not in data[0]: - return KeyError(f"{self.stats_name} is not in input data") + raise KeyError(f"{self.stats_name} is not in input data") report = deepcopy(self.get_report_format()) # unique class ID @@ -800,7 +816,7 @@ class FilenameStats(Analyzer): """ - def __init__(self, key: Optional[str], stats_name: str) -> None: + def __init__(self, key: str | None, stats_name: str) -> None: self.key = key super().__init__(stats_name, {}) @@ -851,14 +867,13 @@ def __init__( self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_HISTOGRAM, - hist_bins: Union[List[int], int, None] = None, - hist_range: Optional[list] = None, + hist_bins: list[int] | int | None = None, + hist_range: list | None = None, ): - self.image_key = image_key # set defaults - self.hist_bins: List[int] = ( + self.hist_bins: list[int] = ( [100] if hist_bins is None else hist_bins if isinstance(hist_bins, list) else [hist_bins] ) self.hist_range: list = [-500, 500] if hist_range is None else hist_range @@ -883,7 +898,7 @@ def __init__( if not isinstance(_hist_range, list) or len(_hist_range) != 2: raise ValueError(f"Expected {i+1}. hist_range values to be list of length 2 but received {_hist_range}") - def __call__(self, data) -> dict: + def __call__(self, data: dict) -> dict: """ Callable to execute the pre-defined functions @@ -949,14 +964,14 @@ class ImageHistogramSumm(Analyzer): """ - def __init__(self, stats_name: str = DataStatsKeys.IMAGE_HISTOGRAM, average: Optional[bool] = True): + def __init__(self, stats_name: str = DataStatsKeys.IMAGE_HISTOGRAM, average: bool | None = True): self.summary_average = average report_format = {ImageStatsKeys.HISTOGRAM: None} super().__init__(stats_name, report_format) self.update_ops(ImageStatsKeys.HISTOGRAM, SummaryOperations()) - def __call__(self, data: List[Dict]): + def __call__(self, data: list[dict]) -> dict: """ Callable to execute the pre-defined functions @@ -975,6 +990,7 @@ def __call__(self, data: List[Dict]): ImageStatsKeys.CHANNELS: {...}, ImageStatsKeys.CROPPED_SHAPE: {...}, ImageStatsKeys.SPACING: {...}, + ImageStatsKeys.SIZEMM: {...}, ImageStatsKeys.INTENSITY: {...}, } @@ -983,15 +999,15 @@ def __call__(self, data: List[Dict]): functions. If the input has nan/inf, the stats results will be nan/inf. """ if not isinstance(data, list): - return ValueError(f"Callable {self.__class__} requires list inputs") + raise ValueError(f"Callable {self.__class__} requires list inputs") if len(data) == 0: - return ValueError(f"Callable {self.__class__} input list is empty") + raise ValueError(f"Callable {self.__class__} input list is empty") if self.stats_name not in data[0]: - return KeyError(f"{self.stats_name} is not in input data") + raise KeyError(f"{self.stats_name} is not in input data") - summ_histogram: Dict = {} + summ_histogram: dict = {} for d in data: if not summ_histogram: diff --git a/monai/auto3dseg/operations.py b/monai/auto3dseg/operations.py index 45294549ef..404a6d326e 100644 --- a/monai/auto3dseg/operations.py +++ b/monai/auto3dseg/operations.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections import UserDict from functools import partial from typing import Any @@ -23,7 +25,7 @@ class Operations(UserDict): Base class of operation interface """ - def evaluate(self, data: Any, **kwargs) -> dict: + def evaluate(self, data: Any, **kwargs: Any) -> dict: """ For key-value pairs in the self.data, if the value is a callable, then this function will apply the callable to the input data. @@ -81,7 +83,7 @@ def __init__(self) -> None: "percentile_99_5": ("percentile", 3), } - def evaluate(self, data: Any, **kwargs) -> dict: + def evaluate(self, data: Any, **kwargs: Any) -> dict: """ Applies the callables to the data, and convert the numerics to list or Python numeric types (int/float). @@ -139,7 +141,7 @@ def __init__(self) -> None: "percentile_99_5": mean, } - def evaluate(self, data: Any, **kwargs) -> dict: + def evaluate(self, data: Any, **kwargs: Any) -> dict: """ Applies the callables to the data, and convert the numerics to list or Python numeric types (int/float). diff --git a/monai/auto3dseg/seg_summarizer.py b/monai/auto3dseg/seg_summarizer.py index 6f8093a40e..d38ad582ac 100644 --- a/monai/auto3dseg/seg_summarizer.py +++ b/monai/auto3dseg/seg_summarizer.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from __future__ import annotations + +from typing import Any from monai.auto3dseg.analyzer import ( + Analyzer, FgImageStats, FgImageStatsSumm, FilenameStats, @@ -80,22 +83,21 @@ class SegSummarizer(Compose): def __init__( self, image_key: str, - label_key: Optional[str], - average=True, + label_key: str | None, + average: bool = True, do_ccp: bool = True, - hist_bins: Union[List[int], int, None] = None, - hist_range: Optional[list] = None, + hist_bins: list[int] | int | None = None, + hist_range: list | None = None, histogram_only: bool = False, ) -> None: - self.image_key = image_key self.label_key = label_key # set defaults - self.hist_bins: Union[List[int], int] = [100] if hist_bins is None else hist_bins + self.hist_bins: list[int] | int = [100] if hist_bins is None else hist_bins self.hist_range: list = [-500, 500] if hist_range is None else hist_range self.histogram_only = histogram_only - self.summary_analyzers: List[Any] = [] + self.summary_analyzers: list[Any] = [] super().__init__() if not self.histogram_only: @@ -118,7 +120,7 @@ def __init__( ImageHistogram(image_key=image_key, hist_bins=hist_bins, hist_range=hist_range), ImageHistogramSumm() ) - def add_analyzer(self, case_analyzer, summary_analyzer) -> None: + def add_analyzer(self, case_analyzer: Analyzer, summary_analyzer: Analyzer | None) -> None: """ Add new analyzers to the engine so that the callable and summarize functions will utilize the new analyzers for stats computations. @@ -165,9 +167,10 @@ def __call__(self, data): """ self.transforms += (case_analyzer,) - self.summary_analyzers.append(summary_analyzer) + if summary_analyzer is not None: + self.summary_analyzers.append(summary_analyzer) - def summarize(self, data: List[Dict]): + def summarize(self, data: list[dict]) -> dict[str, dict]: """ Summarize the input list of data and generates a report ready for json/yaml export. @@ -196,7 +199,7 @@ def summarize(self, data: List[Dict]): if not isinstance(data, list): raise ValueError(f"{self.__class__} summarize function needs input to be a list of dict") - report: Dict[str, Dict] = {} + report: dict[str, dict] = {} if len(data) == 0: return report diff --git a/monai/auto3dseg/utils.py b/monai/auto3dseg/utils.py index 78593f8369..b1becc7677 100644 --- a/monai/auto3dseg/utils.py +++ b/monai/auto3dseg/utils.py @@ -9,13 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import pickle import sys import warnings from copy import deepcopy from numbers import Number -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, cast import numpy as np import torch @@ -41,10 +43,9 @@ measure_np, has_measure = optional_import("skimage.measure", "0.14.2", min_version) cp, has_cp = optional_import("cupy") -cucim, has_cucim = optional_import("cucim") -def get_foreground_image(image: MetaTensor): +def get_foreground_image(image: MetaTensor) -> np.ndarray: """ Get a foreground image by removing all-zero rectangles on the edges of the image Note for the developer: update select_fn if the foreground is defined differently. @@ -61,7 +62,7 @@ def get_foreground_image(image: MetaTensor): copper = CropForeground(select_fn=lambda x: x > 0) image_foreground = copper(image) - return image_foreground + return cast(np.ndarray, image_foreground) def get_foreground_label(image: MetaTensor, label: MetaTensor) -> MetaTensor: @@ -80,7 +81,7 @@ def get_foreground_label(image: MetaTensor, label: MetaTensor) -> MetaTensor: return label_foreground -def get_label_ccp(mask_index: MetaTensor, use_gpu: bool = True) -> Tuple[List[Any], int]: +def get_label_ccp(mask_index: MetaTensor, use_gpu: bool = True) -> tuple[list[Any], int]: """ Find all connected components and their bounding shape. Backend can be cuPy/cuCIM or Numpy depending on the hardware. @@ -91,7 +92,7 @@ def get_label_ccp(mask_index: MetaTensor, use_gpu: bool = True) -> Tuple[List[An regardless of this setting. """ - + cucim, has_cucim = optional_import("cucim") shape_list = [] if mask_index.device.type == "cuda" and has_cp and has_cucim and use_gpu: mask_cupy = ToCupy()(mask_index.short()) @@ -124,12 +125,12 @@ def get_label_ccp(mask_index: MetaTensor, use_gpu: bool = True) -> Tuple[List[An def concat_val_to_np( - data_list: List[Dict], - fixed_keys: List[Union[str, int]], - ragged: Optional[bool] = False, - allow_missing: Optional[bool] = False, - **kwargs, -): + data_list: list[dict], + fixed_keys: list[str | int], + ragged: bool | None = False, + allow_missing: bool | None = False, + **kwargs: Any, +) -> np.ndarray: """ Get the nested value in a list of dictionary that shares the same structure. @@ -144,14 +145,14 @@ def concat_val_to_np( """ - np_list: List[Optional[np.ndarray]] = [] + np_list: list[np.ndarray | None] = [] for data in data_list: parser = ConfigParser(data) for i, key in enumerate(fixed_keys): fixed_keys[i] = str(key) val: Any - val = parser.get(ID_SEP_KEY.join(cast(Iterable[str], fixed_keys))) + val = parser.get(ID_SEP_KEY.join(fixed_keys)) # type: ignore if val is None: if allow_missing: @@ -181,8 +182,8 @@ def concat_val_to_np( def concat_multikeys_to_dict( - data_list: List[Dict], fixed_keys: List[Union[str, int]], keys: List[str], zero_insert: bool = True, **kwargs -): + data_list: list[dict], fixed_keys: list[str | int], keys: list[str], zero_insert: bool = True, **kwargs: Any +) -> dict[str, np.ndarray]: """ Get the nested value in a list of dictionary that shares the same structure iteratively on all keys. It returns a dictionary with keys with the found values in nd.ndarray. @@ -200,14 +201,14 @@ def concat_multikeys_to_dict( ret_dict = {} for key in keys: - addon: List[Union[str, int]] = [0, key] if zero_insert else [key] + addon: list[str | int] = [0, key] if zero_insert else [key] val = concat_val_to_np(data_list, fixed_keys + addon, **kwargs) ret_dict.update({key: val}) return ret_dict -def datafold_read(datalist: Union[str, Dict], basedir: str, fold: int = 0, key: str = "training") -> Tuple[List, List]: +def datafold_read(datalist: str | dict, basedir: str, fold: int = 0, key: str = "training") -> tuple[list, list]: """ Read a list of data dictionary `datalist` @@ -246,7 +247,7 @@ def datafold_read(datalist: Union[str, Dict], basedir: str, fold: int = 0, key: return tr, val -def verify_report_format(report: dict, report_format: dict): +def verify_report_format(report: dict, report_format: dict) -> bool: """ Compares the report and the report_format that has only keys. @@ -268,10 +269,10 @@ def verify_report_format(report: dict, report_format: dict): else: return False - return True + return True -def algo_to_pickle(algo: Algo, **algo_meta_data) -> str: +def algo_to_pickle(algo: Algo, **algo_meta_data: Any) -> str: """ Export the Algo object to pickle file @@ -294,7 +295,7 @@ def algo_to_pickle(algo: Algo, **algo_meta_data) -> str: return pkl_filename -def algo_from_pickle(pkl_filename: str, **kwargs) -> Any: +def algo_from_pickle(pkl_filename: str, **kwargs: Any) -> Any: """ Import the Algo object from a pickle file diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index c5cb0da978..4b10f71a11 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser +from .properties import InferProperties, TrainProperties from .reference_resolver import ReferenceResolver from .scripts import ( ckpt_export, @@ -20,7 +23,6 @@ get_bundle_versions, init_bundle, load, - patch_bundle_tracking, run, verify_metadata, verify_net_in_out, @@ -34,3 +36,4 @@ MACRO_KEY, load_bundle_config, ) +from .workflows import BundleWorkflow, ConfigWorkflow diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index a9671fe385..b5142ca781 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from monai.bundle.scripts import ckpt_export, download, init_bundle, run, verify_metadata, verify_net_in_out if __name__ == "__main__": diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index 75a20b5e6a..f3d1cf69fd 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -9,16 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import ast import inspect import sys import warnings from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from importlib import import_module -from typing import Any, Dict, List, Mapping, Optional, Sequence, Union +from typing import Any from monai.bundle.utils import EXPR_KEY -from monai.utils import ensure_tuple, first, instantiate, optional_import, run_debug, run_eval +from monai.utils import CompInitMode, ensure_tuple, first, instantiate, optional_import, run_debug, run_eval __all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent", "Instantiable"] @@ -55,18 +58,18 @@ class ComponentLocator: MOD_START = "monai" - def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None): + def __init__(self, excludes: Sequence[str] | str | None = None): self.excludes = [] if excludes is None else ensure_tuple(excludes) - self._components_table: Optional[Dict[str, List]] = None + self._components_table: dict[str, list] | None = None - def _find_module_names(self) -> List[str]: + def _find_module_names(self) -> list[str]: """ Find all the modules start with MOD_START and don't contain any of `excludes`. """ return [m for m in sys.modules if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes)] - def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dict[str, List]: + def _find_classes_or_functions(self, modnames: Sequence[str] | str) -> dict[str, list]: """ Find all the classes and functions in the modules with specified `modnames`. @@ -74,7 +77,7 @@ def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dic modnames: names of the target modules to find all the classes and functions. """ - table: Dict[str, List] = {} + table: dict[str, list] = {} # all the MONAI modules are already loaded by `load_submodules` for modname in ensure_tuple(modnames): try: @@ -89,7 +92,7 @@ def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dic pass return table - def get_component_module_name(self, name: str) -> Optional[Union[List[str], str]]: + def get_component_module_name(self, name: str) -> list[str] | str | None: """ Get the full module name of the class or function with specified ``name``. If target component name exists in multiple packages or modules, return a list of full module names. @@ -104,7 +107,7 @@ def get_component_module_name(self, name: str) -> Optional[Union[List[str], str] # init component and module mapping table self._components_table = self._find_classes_or_functions(self._find_module_names()) - mods: Optional[Union[List[str], str]] = self._components_table.get(name) + mods: list[str] | str | None = self._components_table.get(name) if isinstance(mods, list) and len(mods) == 1: mods = mods[0] return mods @@ -135,7 +138,7 @@ def get_id(self) -> str: """ return self.id - def update_config(self, config: Any): + def update_config(self, config: Any) -> None: """ Replace the content of `self.config` with new `config`. A typical usage is to modify the initial config content at runtime. @@ -165,8 +168,8 @@ class ConfigComponent(ConfigItem, Instantiable): Currently, three special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals: - class or function identifier of the python module, specified by ``"_target_"``, - indicating a build-in python class or function such as ``"LoadImageDict"``, - or a full module name, such as ``"monai.transforms.LoadImageDict"``. + indicating a monai built-in Python class or function such as ``"LoadImageDict"``, + or a full module name, e.g. ``"monai.transforms.LoadImageDict"``, or a callable, e.g. ``"$@model.forward"``. - ``"_requires_"`` (optional): specifies reference IDs (string starts with ``"@"``) or ``ConfigExpression`` of the dependencies for this ``ConfigComponent`` object. These dependencies will be evaluated/instantiated before this object is instantiated. It is useful when the @@ -174,6 +177,11 @@ class ConfigComponent(ConfigItem, Instantiable): but requires the dependencies to be instantiated/evaluated beforehand. - ``"_disabled_"`` (optional): a flag to indicate whether to skip the instantiation. - ``"_desc_"`` (optional): free text descriptions of the component for code readability. + - ``"_mode_"`` (optional): operating mode for invoking the callable ``component`` defined by ``"_target_"``: + + - ``"default"``: returns ``component(**kwargs)`` + - ``"partial"``: returns ``functools.partial(component, **kwargs)`` + - ``"debug"``: returns ``pdb.runcall(component, **kwargs)`` Other fields in the config content are input arguments to the python module. @@ -201,14 +209,14 @@ class ConfigComponent(ConfigItem, Instantiable): """ - non_arg_keys = {"_target_", "_disabled_", "_requires_", "_desc_"} + non_arg_keys = {"_target_", "_disabled_", "_requires_", "_desc_", "_mode_"} def __init__( self, config: Any, id: str = "", - locator: Optional[ComponentLocator] = None, - excludes: Optional[Union[Sequence[str], str]] = None, + locator: ComponentLocator | None = None, + excludes: Sequence[str] | str | None = None, ) -> None: super().__init__(config=config, id=id) self.locator = ComponentLocator(excludes=excludes) if locator is None else locator @@ -233,7 +241,7 @@ def resolve_module_name(self): config = dict(self.get_config()) target = config.get("_target_") if not isinstance(target, str): - raise ValueError("must provide a string for the `_target_` of component to instantiate.") + return target # for feature discussed in project-monai/monai#5852 module = self.locator.get_component_module_name(target) if module is None: @@ -263,7 +271,7 @@ def is_disabled(self) -> bool: _is_disabled = self.get_config().get("_disabled_", False) return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled) - def instantiate(self, **kwargs) -> object: + def instantiate(self, **kwargs: Any) -> object: """ Instantiate component based on ``self.config`` content. The target component must be a `class` or a `function`, otherwise, return `None`. @@ -277,10 +285,11 @@ def instantiate(self, **kwargs) -> object: return None modname = self.resolve_module_name() + mode = self.get_config().get("_mode_", CompInitMode.DEFAULT) args = self.resolve_args() args.update(kwargs) try: - return instantiate(modname, **args) + return instantiate(modname, mode, **args) except Exception as e: raise RuntimeError(f"Failed to instantiate {self}.") from e @@ -315,11 +324,11 @@ class ConfigExpression(ConfigItem): prefix = EXPR_KEY run_eval = run_eval - def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None: + def __init__(self, config: Any, id: str = "", globals: dict | None = None) -> None: super().__init__(config=config, id=id) self.globals = globals if globals is not None else {} - def _parse_import_string(self, import_string: str): + def _parse_import_string(self, import_string: str) -> Any | None: """parse single import statement such as "from monai.transforms import Resize""" node = first(ast.iter_child_nodes(ast.parse(import_string))) if not isinstance(node, (ast.Import, ast.ImportFrom)): @@ -338,7 +347,7 @@ def _parse_import_string(self, import_string: str): return self.globals[asname] return None - def evaluate(self, globals: Optional[Dict] = None, locals: Optional[Dict] = None): + def evaluate(self, globals: dict | None = None, locals: dict | None = None) -> str | Any | None: """ Execute the current config content and return the result if it is expression, based on Python `eval()`. For more details: https://docs.python.org/3/library/functions.html#eval. @@ -370,10 +379,11 @@ def evaluate(self, globals: Optional[Dict] = None, locals: Optional[Dict] = None ) import pdb - return pdb.run(value[len(self.prefix) :], globals_, locals) + pdb.run(value[len(self.prefix) :], globals_, locals) + return None @classmethod - def is_expression(cls, config: Union[Dict, List, str]) -> bool: + def is_expression(cls, config: dict | list | str) -> bool: """ Check whether the config is an executable expression string. Currently, a string starts with ``"$"`` character is interpreted as an expression. @@ -385,7 +395,7 @@ def is_expression(cls, config: Union[Dict, List, str]) -> bool: return isinstance(config, str) and config.startswith(cls.prefix) @classmethod - def is_import_statement(cls, config: Union[Dict, List, str]) -> bool: + def is_import_statement(cls, config: dict | list | str) -> bool: """ Check whether the config is an import statement (a special case of expression). diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index d57238cfaa..613ad4e44a 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -9,11 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import re +from collections.abc import Sequence from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.reference_resolver import ReferenceResolver @@ -21,7 +24,10 @@ from monai.config import PathLike from monai.utils import ensure_tuple, look_up_option, optional_import -yaml, _ = optional_import("yaml") +if TYPE_CHECKING: + import yaml +else: + yaml, _ = optional_import("yaml") __all__ = ["ConfigParser"] @@ -95,11 +101,11 @@ class ConfigParser: def __init__( self, config: Any = None, - excludes: Optional[Union[Sequence[str], str]] = None, - globals: Union[Dict[str, Any], None, bool] = None, + excludes: Sequence[str] | str | None = None, + globals: dict[str, Any] | None | bool = None, ): - self.config = None - self.globals: Dict[str, Any] = {} + self.config: ConfigItem | None = None + self.globals: dict[str, Any] = {} _globals = _default_globals.copy() if isinstance(_globals, dict) and globals not in (None, False): _globals.update(globals) # type: ignore @@ -116,7 +122,20 @@ def __init__( def __repr__(self): return f"{self.config}" - def __getitem__(self, id: Union[str, int]): + def __getattr__(self, id): + """ + Get the parsed result of ``ConfigItem`` with the specified ``id`` + with default arguments (e.g. ``lazy=True``, ``instantiate=True`` and ``eval_expr=True``). + + Args: + id: id of the ``ConfigItem``. + + See also: + :py:meth:`get_parsed_content` + """ + return self.get_parsed_content(id) + + def __getitem__(self, id: str | int) -> Any: """ Get the config by id. @@ -141,7 +160,7 @@ def __getitem__(self, id: Union[str, int]): raise KeyError(f"query key: {k}") from e return config - def __setitem__(self, id: Union[str, int], config: Any): + def __setitem__(self, id: str | int, config: Any) -> None: """ Set config by ``id``. Note that this method should be used before ``parse()`` or ``get_parsed_content()`` to ensure the updates are included in the parsed content. @@ -168,7 +187,7 @@ def __setitem__(self, id: Union[str, int], config: Any): self.ref_resolver.reset() return - def get(self, id: str = "", default: Optional[Any] = None): + def get(self, id: str = "", default: Any | None = None) -> Any: """ Get the config by id. @@ -182,7 +201,7 @@ def get(self, id: str = "", default: Optional[Any] = None): except (KeyError, IndexError, ValueError): # Index error for integer indexing return default - def set(self, config: Any, id: str = "", recursive: bool = True): + def set(self, config: Any, id: str = "", recursive: bool = True) -> None: """ Set config by ``id``. @@ -204,7 +223,7 @@ def set(self, config: Any, id: str = "", recursive: bool = True): conf_ = conf_[k if isinstance(conf_, dict) else int(k)] self[id] = config - def update(self, pairs: Dict[str, Any]): + def update(self, pairs: dict[str, Any]) -> None: """ Set the ``id`` and the corresponding config content in pairs, see also :py:meth:`__setitem__`. For example, ``parser.update({"train#epoch": 100, "train#lr": 0.02})`` @@ -216,7 +235,7 @@ def update(self, pairs: Dict[str, Any]): for k, v in pairs.items(): self[k] = v - def __contains__(self, id: Union[str, int]) -> bool: + def __contains__(self, id: str | int) -> bool: """ Returns True if `id` is stored in this configuration. @@ -229,7 +248,7 @@ def __contains__(self, id: Union[str, int]) -> bool: except (KeyError, IndexError, ValueError): # Index error for integer indexing return False - def parse(self, reset: bool = True): + def parse(self, reset: bool = True) -> None: """ Recursively resolve `self.config` to replace the macro tokens with target content. Then recursively parse the config source, add every item as ``ConfigItem`` to the reference resolver. @@ -243,7 +262,7 @@ def parse(self, reset: bool = True): self.resolve_macro_and_relative_ids() self._do_parse(config=self.get()) - def get_parsed_content(self, id: str = "", **kwargs): + def get_parsed_content(self, id: str = "", **kwargs: Any) -> Any: """ Get the parsed result of ``ConfigItem`` with the specified ``id``. @@ -270,7 +289,7 @@ def get_parsed_content(self, id: str = "", **kwargs): self.parse(reset=not kwargs.get("lazy", True)) return self.ref_resolver.get_resolved_content(id=id, **kwargs) - def read_meta(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs): + def read_meta(self, f: PathLike | Sequence[PathLike] | dict, **kwargs: Any) -> None: """ Read the metadata from specified JSON or YAML file. The metadata as a dictionary will be stored at ``self.config["_meta_"]``. @@ -284,7 +303,7 @@ def read_meta(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs): """ self.set(self.load_config_files(f, **kwargs), self.meta_key) - def read_config(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs): + def read_config(self, f: PathLike | Sequence[PathLike] | dict, **kwargs: Any) -> None: """ Read the config from specified JSON or YAML file. The config content in the `self.config` dictionary. @@ -300,7 +319,7 @@ def read_config(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs): content.update(self.load_config_files(f, **kwargs)) self.set(config=content) - def _do_resolve(self, config: Any, id: str = ""): + def _do_resolve(self, config: Any, id: str = "") -> Any: """ Recursively resolve `self.config` to replace the relative ids with absolute ids, for example, `@##A` means `A` in the upper level. and replace the macro tokens with target content, @@ -338,7 +357,7 @@ def resolve_macro_and_relative_ids(self): """ self.set(self._do_resolve(config=deepcopy(self.get()))) - def _do_parse(self, config, id: str = ""): + def _do_parse(self, config: Any, id: str = "") -> None: """ Recursively parse the nested data in config source, add every item as `ConfigItem` to the resolver. @@ -365,7 +384,7 @@ def _do_parse(self, config, id: str = ""): self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id)) @classmethod - def load_config_file(cls, filepath: PathLike, **kwargs): + def load_config_file(cls, filepath: PathLike, **kwargs: Any) -> dict: """ Load config file with specified file path (currently support JSON and YAML files). @@ -381,13 +400,13 @@ def load_config_file(cls, filepath: PathLike, **kwargs): raise ValueError(f'unknown file input: "{filepath}"') with open(_filepath) as f: if _filepath.lower().endswith(cls.suffixes[0]): - return json.load(f, **kwargs) + return json.load(f, **kwargs) # type: ignore[no-any-return] if _filepath.lower().endswith(cls.suffixes[1:]): - return yaml.safe_load(f, **kwargs) + return yaml.safe_load(f, **kwargs) # type: ignore[no-any-return] raise ValueError(f"only support JSON or YAML config file so far, got name {_filepath}.") @classmethod - def load_config_files(cls, files: Union[PathLike, Sequence[PathLike], dict], **kwargs) -> Dict: + def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs: Any) -> dict: """ Load config files into a single config dict. The latter config file in the list will override or add the former config file. @@ -407,7 +426,7 @@ def load_config_files(cls, files: Union[PathLike, Sequence[PathLike], dict], **k return parser.get() # type: ignore @classmethod - def export_config_file(cls, config: Dict, filepath: PathLike, fmt="json", **kwargs): + def export_config_file(cls, config: dict, filepath: PathLike, fmt: str = "json", **kwargs: Any) -> None: """ Export the config content to the specified file path (currently support JSON and YAML files). @@ -429,7 +448,7 @@ def export_config_file(cls, config: Dict, filepath: PathLike, fmt="json", **kwar raise ValueError(f"only support JSON or YAML config file so far, got {writer}.") @classmethod - def split_path_id(cls, src: str) -> Tuple[str, str]: + def split_path_id(cls, src: str) -> tuple[str, str]: """ Split `src` string into two parts: a config file path and component id. The file path should end with `(json|yaml|yml)`. The component id should be separated by `#` if it exists. diff --git a/monai/bundle/properties.py b/monai/bundle/properties.py new file mode 100644 index 0000000000..456d84a3b3 --- /dev/null +++ b/monai/bundle/properties.py @@ -0,0 +1,190 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The predefined properties for a bundle workflow, other applications can leverage the properties +to interact with the bundle workflow. +Some properties are required and some are optional, optional properties mean: if some component of the +bundle workflow refer to the property, the property must be defined, otherwise, the property can be None. +Every item in this `TrainProperties` or `InferProperties` dictionary is a property, +the key is the property name and the values include: +1. description. +2. whether it's a required property. +3. config item ID name (only applicable when the bundle workflow is defined in config). +4. reference config item ID name (only applicable when the bundle workflow is defined in config). + +""" + +from __future__ import annotations + +from monai.bundle.utils import ID_SEP_KEY +from monai.utils import BundleProperty, BundlePropertyConfig + +TrainProperties = { + "bundle_root": { + BundleProperty.DESC: "root path of the bundle.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: "bundle_root", + }, + "device": { + BundleProperty.DESC: "target device to execute the bundle workflow.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: "device", + }, + "dataset_dir": { + BundleProperty.DESC: "directory path of the dataset.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: "dataset_dir", + }, + "trainer": { + BundleProperty.DESC: "training workflow engine.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: f"train{ID_SEP_KEY}trainer", + }, + "max_epochs": { + BundleProperty.DESC: "max number of epochs to execute the training.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: f"train{ID_SEP_KEY}trainer{ID_SEP_KEY}max_epochs", + }, + "train_dataset": { + BundleProperty.DESC: "PyTorch dataset object for the training logic.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: f"train{ID_SEP_KEY}dataset", + }, + "train_dataset_data": { + BundleProperty.DESC: "data source for the training dataset.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: f"train{ID_SEP_KEY}dataset{ID_SEP_KEY}data", + }, + "train_inferer": { + BundleProperty.DESC: "MONAI Inferer object to execute the model computation in training.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: f"train{ID_SEP_KEY}inferer", + }, + "train_handlers": { + BundleProperty.DESC: "event-handlers for the training logic.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"train{ID_SEP_KEY}handlers", + BundlePropertyConfig.REF_ID: f"train{ID_SEP_KEY}trainer{ID_SEP_KEY}train_handlers", + }, + "train_preprocessing": { + BundleProperty.DESC: "preprocessing for the training input data.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"train{ID_SEP_KEY}preprocessing", + BundlePropertyConfig.REF_ID: f"train{ID_SEP_KEY}dataset{ID_SEP_KEY}transform", + }, + "train_postprocessing": { + BundleProperty.DESC: "postprocessing for the training model output data.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"train{ID_SEP_KEY}postprocessing", + BundlePropertyConfig.REF_ID: f"train{ID_SEP_KEY}trainer{ID_SEP_KEY}postprocessing", + }, + "train_key_metric": { + BundleProperty.DESC: "key metric to compute on the training data.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"train{ID_SEP_KEY}key_metric", + BundlePropertyConfig.REF_ID: f"train{ID_SEP_KEY}trainer{ID_SEP_KEY}key_train_metric", + }, + "evaluator": { + BundleProperty.DESC: "validation workflow engine.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}evaluator", + BundlePropertyConfig.REF_ID: "validator", # this REF_ID is the arg name of `ValidationHandler` + }, + "val_interval": { + BundleProperty.DESC: "validation interval during the training.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: "val_interval", + BundlePropertyConfig.REF_ID: "interval", # this REF_ID is the arg name of `ValidationHandler` + }, + "val_handlers": { + BundleProperty.DESC: "event-handlers for the validation logic.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}handlers", + BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}val_handlers", + }, + "val_dataset": { + BundleProperty.DESC: "PyTorch dataset object for the validation logic.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}dataset", + BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}dataloader{ID_SEP_KEY}dataset", + }, + "val_dataset_data": { + BundleProperty.DESC: "data source for the validation dataset.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}dataset{ID_SEP_KEY}data", + BundlePropertyConfig.REF_ID: None, # no reference to this ID + }, + "val_inferer": { + BundleProperty.DESC: "MONAI Inferer object to execute the model computation in validation.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}inferer", + BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}inferer", + }, + "val_preprocessing": { + BundleProperty.DESC: "preprocessing for the validation input data.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}preprocessing", + BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}dataset{ID_SEP_KEY}transform", + }, + "val_postprocessing": { + BundleProperty.DESC: "postprocessing for the validation model output data.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}postprocessing", + BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}postprocessing", + }, + "val_key_metric": { + BundleProperty.DESC: "key metric to compute on the validation data.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}key_metric", + BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}key_val_metric", + }, +} + +InferProperties = { + "bundle_root": { + BundleProperty.DESC: "root path of the bundle.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: "bundle_root", + }, + "device": { + BundleProperty.DESC: "target device to execute the bundle workflow.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: "device", + }, + "network_def": { + BundleProperty.DESC: "network module for the inference.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: "network_def", + }, + "inferer": { + BundleProperty.DESC: "MONAI Inferer object to execute the model computation in inference.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: "inferer", + }, + "preprocessing": { + BundleProperty.DESC: "preprocessing for the input data.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: "preprocessing", + BundlePropertyConfig.REF_ID: f"dataset{ID_SEP_KEY}transform", + }, + "postprocessing": { + BundleProperty.DESC: "postprocessing for the model output data.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: "postprocessing", + BundlePropertyConfig.REF_ID: f"evaluator{ID_SEP_KEY}postprocessing", + }, + "key_metric": { + BundleProperty.DESC: "the key metric during evaluation.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: "key_metric", + BundlePropertyConfig.REF_ID: f"evaluator{ID_SEP_KEY}key_val_metric", + }, +} diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index ff2eaf4053..e09317aac2 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import re import warnings -from typing import Any, Dict, Optional, Sequence, Set +from collections.abc import Sequence +from typing import Any from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY @@ -54,10 +57,10 @@ class ReferenceResolver: # if `allow_missing_reference` and can't find a reference ID, will just raise a warning and don't update the config allow_missing_reference = allow_missing_reference - def __init__(self, items: Optional[Sequence[ConfigItem]] = None): + def __init__(self, items: Sequence[ConfigItem] | None = None): # save the items in a dictionary with the `ConfigItem.id` as key - self.items: Dict[str, Any] = {} if items is None else {i.get_id(): i for i in items} - self.resolved_content: Dict[str, Any] = {} + self.items: dict[str, ConfigItem] = {} if items is None else {i.get_id(): i for i in items} + self.resolved_content: dict[str, ConfigExpression | str | Any | None] = {} def reset(self): """ @@ -70,7 +73,7 @@ def reset(self): def is_resolved(self) -> bool: return bool(self.resolved_content) - def add_item(self, item: ConfigItem): + def add_item(self, item: ConfigItem) -> None: """ Add a ``ConfigItem`` to the resolver. @@ -83,7 +86,7 @@ def add_item(self, item: ConfigItem): return self.items[id] = item - def get_item(self, id: str, resolve: bool = False, **kwargs): + def get_item(self, id: str, resolve: bool = False, **kwargs: Any) -> ConfigItem | None: """ Get the ``ConfigItem`` by id. @@ -100,7 +103,9 @@ def get_item(self, id: str, resolve: bool = False, **kwargs): self._resolve_one_item(id=id, **kwargs) return self.items.get(id) - def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **kwargs): + def _resolve_one_item( + self, id: str, waiting_list: set[str] | None = None, **kwargs: Any + ) -> ConfigExpression | str | Any | None: """ Resolve and return one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``. If it has unresolved references, recursively resolve the referring items first. @@ -170,7 +175,7 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, ** self.resolved_content[id] = new_config return self.resolved_content[id] - def get_resolved_content(self, id: str, **kwargs): + def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str | Any | None: """ Get the resolved ``ConfigItem`` by id. @@ -185,7 +190,7 @@ def get_resolved_content(self, id: str, **kwargs): return self._resolve_one_item(id=id, **kwargs) @classmethod - def match_refs_pattern(cls, value: str) -> Dict[str, int]: + def match_refs_pattern(cls, value: str) -> dict[str, int]: """ Match regular expression for the input string to find the references. The reference string starts with ``"@"``, like: ``"@XXX#YYY#ZZZ"``. @@ -194,7 +199,7 @@ def match_refs_pattern(cls, value: str) -> Dict[str, int]: value: input value to match regular expression. """ - refs: Dict[str, int] = {} + refs: dict[str, int] = {} # regular expression pattern to match "@XXX" or "@XXX#YYY" result = cls.id_matcher.findall(value) value_is_expr = ConfigExpression.is_expression(value) @@ -206,7 +211,7 @@ def match_refs_pattern(cls, value: str) -> Dict[str, int]: return refs @classmethod - def update_refs_pattern(cls, value: str, refs: Dict) -> str: + def update_refs_pattern(cls, value: str, refs: dict) -> str: """ Match regular expression for the input string to update content with the references. The reference part starts with ``"@"``, like: ``"@XXX#YYY#ZZZ"``. @@ -219,6 +224,9 @@ def update_refs_pattern(cls, value: str, refs: Dict) -> str: """ # regular expression pattern to match "@XXX" or "@XXX#YYY" result = cls.id_matcher.findall(value) + # reversely sort the matched references by length + # and handle the longer first in case a reference item is substring of another longer item + result.sort(key=len, reverse=True) value_is_expr = ConfigExpression.is_expression(value) for item in result: # only update reference when string starts with "$" or the whole content is "@XXX" @@ -241,7 +249,7 @@ def update_refs_pattern(cls, value: str, refs: Dict) -> str: return value @classmethod - def find_refs_in_config(cls, config, id: str, refs: Optional[Dict[str, int]] = None) -> Dict[str, int]: + def find_refs_in_config(cls, config: Any, id: str, refs: dict[str, int] | None = None) -> dict[str, int]: """ Recursively search all the content of input config item to get the ids of references. References mean: the IDs of other config items (``"@XXX"`` in this config item), or the @@ -254,7 +262,7 @@ def find_refs_in_config(cls, config, id: str, refs: Optional[Dict[str, int]] = N refs: dict of the ID name and count of found references, default to `None`. """ - refs_: Dict[str, int] = refs or {} + refs_: dict[str, int] = refs or {} if isinstance(config, str): for id, count in cls.match_refs_pattern(value=config).items(): refs_[id] = refs_.get(id, 0) + count @@ -268,7 +276,7 @@ def find_refs_in_config(cls, config, id: str, refs: Optional[Dict[str, int]] = N return refs_ @classmethod - def update_config_with_refs(cls, config, id: str, refs: Optional[Dict] = None): + def update_config_with_refs(cls, config: Any, id: str, refs: dict | None = None) -> Any: """ With all the references in ``refs``, update the input config content with references and return the new config. @@ -279,7 +287,7 @@ def update_config_with_refs(cls, config, id: str, refs: Optional[Dict] = None): refs: all the referring content with ids, default to `None`. """ - refs_: Dict = refs or {} + refs_: dict = refs or {} if isinstance(config, str): return cls.update_refs_pattern(config, refs_) if not isinstance(config, (list, dict)): diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index b625578049..55182e429c 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -9,18 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import ast import json import os -import pprint import re -import time import warnings -from logging.config import fileConfig +from collections.abc import Mapping, Sequence from pathlib import Path from shutil import copyfile from textwrap import dedent -from typing import Dict, Mapping, Optional, Sequence, Tuple, Union +from typing import Any import torch from torch.cuda import is_available @@ -29,12 +29,20 @@ from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser -from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, DEFAULT_INFERENCE, DEFAULT_METADATA +from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA +from monai.bundle.workflows import ConfigWorkflow from monai.config import IgniteInfo, PathLike from monai.data import load_net_with_metadata, save_net_with_metadata from monai.networks import convert_to_torchscript, copy_model_state, get_state_dict, save_state -from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import -from monai.utils.misc import ensure_tuple +from monai.utils import ( + check_parent_dir, + deprecated_arg, + ensure_tuple, + get_equivalent_dtype, + min_version, + optional_import, + pprint_edges, +) validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") @@ -45,9 +53,10 @@ # set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC source in default for bundle download download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github") +PPRINT_CONFIG_N = 5 -def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = True, **kwargs) -> Dict: +def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: """ Update the `args` with the input `kwargs`. For dict data, recursively update the content based on the keys. @@ -58,7 +67,7 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr kwargs: destination args to update. """ - args_: Dict = args if isinstance(args, dict) else {} + args_: dict = args if isinstance(args, dict) else {} if isinstance(args, str): # args are defined in a structured file args_ = ConfigParser.load_config_file(args) @@ -74,7 +83,7 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr return args_ -def _pop_args(src: Dict, *args, **kwargs): +def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: """ Pop args from the `src` dictionary based on specified keys in `args` and (key, default value) pairs in `kwargs`. @@ -82,14 +91,14 @@ def _pop_args(src: Dict, *args, **kwargs): return tuple([src.pop(i) for i in args] + [src.pop(k, v) for k, v in kwargs.items()]) -def _log_input_summary(tag, args: Dict): +def _log_input_summary(tag: str, args: dict) -> None: logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---") for name, val in args.items(): - logger.info(f"> {name}: {pprint.pformat(val)}") + logger.info(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}") logger.info("---\n\n") -def _get_var_names(expr: str): +def _get_var_names(expr: str) -> list[str]: """ Parse the expression and discover what variables are present in it based on ast module. @@ -101,7 +110,7 @@ def _get_var_names(expr: str): return [m.id for m in ast.walk(tree) if isinstance(m, ast.Name)] -def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1) -> Tuple: +def _get_fake_spatial_shape(shape: Sequence[str | int], p: int = 1, n: int = 1, any: int = 1) -> tuple: """ Get spatial shape for fake data according to the specified shape pattern. It supports `int` number and `string` with formats like: "32", "32 * n", "32 ** p", "32 ** p *n". @@ -130,15 +139,15 @@ def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int return tuple(ret) -def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filename: str): +def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filename: str) -> str: return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}" -def _get_ngc_bundle_url(model_name: str, version: str): +def _get_ngc_bundle_url(model_name: str, version: str) -> str: return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name}/versions/{version}/zip" -def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True): +def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None: repo_owner, repo_name, tag_name = repo.split("/") if ".zip" not in filename: filename += ".zip" @@ -148,19 +157,21 @@ def _download_from_github(repo: str, download_path: Path, filename: str, progres extractall(filepath=filepath, output_dir=download_path, has_base=True) -def _add_ngc_prefix(name: str, prefix: str = "monai_"): +def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str: if name.startswith(prefix): return name return f"{prefix}{name}" -def _remove_ngc_prefix(name: str, prefix: str = "monai_"): +def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str: if name.startswith(prefix): return name[len(prefix) :] return name -def _download_from_ngc(download_path: Path, filename: str, version: str, remove_prefix: Optional[str], progress: bool): +def _download_from_ngc( + download_path: Path, filename: str, version: str, remove_prefix: str | None, progress: bool +) -> None: # ensure prefix is contained filename = _add_ngc_prefix(filename) url = _get_ngc_bundle_url(model_name=filename, version=version) @@ -172,7 +183,7 @@ def _download_from_ngc(download_path: Path, filename: str, version: str, remove_ extractall(filepath=filepath, output_dir=extract_path, has_base=True) -def _get_latest_bundle_version(source: str, name: str, repo: str): +def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None: if source == "ngc": name = _add_ngc_prefix(name) model_dict = _get_all_ngc_models(name) @@ -182,12 +193,12 @@ def _get_latest_bundle_version(source: str, name: str, repo: str): return None elif source == "github": repo_owner, repo_name, tag_name = repo.split("/") - return get_bundle_versions(name, repo=os.path.join(repo_owner, repo_name), tag=tag_name)["latest_version"] + return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"] else: raise ValueError(f"To get the latest bundle version, source should be 'github' or 'ngc', got {source}.") -def _process_bundle_dir(bundle_dir: Optional[PathLike] = None): +def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path: if bundle_dir is None: get_dir, has_home = optional_import("torch.hub", name="get_dir") if has_home: @@ -198,16 +209,16 @@ def _process_bundle_dir(bundle_dir: Optional[PathLike] = None): def download( - name: Optional[str] = None, - version: Optional[str] = None, - bundle_dir: Optional[PathLike] = None, + name: str | None = None, + version: str | None = None, + bundle_dir: PathLike | None = None, source: str = download_source, - repo: Optional[str] = None, - url: Optional[str] = None, - remove_prefix: Optional[str] = "monai_", + repo: str | None = None, + url: str | None = None, + remove_prefix: str | None = "monai_", progress: bool = True, - args_file: Optional[str] = None, -): + args_file: str | None = None, +) -> None: """ download bundle from the specified source or url. The bundle should be a zip file and it will be extracted after downloading. @@ -320,20 +331,20 @@ def download( def load( name: str, - version: Optional[str] = None, - model_file: Optional[str] = None, + version: str | None = None, + model_file: str | None = None, load_ts_module: bool = False, - bundle_dir: Optional[PathLike] = None, + bundle_dir: PathLike | None = None, source: str = download_source, - repo: Optional[str] = None, - remove_prefix: Optional[str] = "monai_", + repo: str | None = None, + remove_prefix: str | None = "monai_", progress: bool = True, - device: Optional[str] = None, - key_in_ckpt: Optional[str] = None, + device: str | None = None, + key_in_ckpt: str | None = None, config_files: Sequence[str] = (), - net_name: Optional[str] = None, - **net_kwargs, -): + net_name: str | None = None, + **net_kwargs: Any, +) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ Load model weights or TorchScript module of a bundle. @@ -422,8 +433,8 @@ def load( def _get_all_bundles_info( - repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: Optional[str] = None -): + repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None +) -> dict[str, dict[str, dict[str, Any]]]: if has_requests: request_url = f"https://api.github.com/repos/{repo}/releases" if auth_token is not None: @@ -436,7 +447,7 @@ def _get_all_bundles_info( raise ValueError("requests package is required, please install it.") releases_list = json.loads(resp.text) bundle_name_pattern = re.compile(r"_v\d*.") - bundles_info: Dict = {} + bundles_info: dict[str, dict[str, dict[str, Any]]] = {} for release in releases_list: if release["tag_name"] == tag: @@ -459,8 +470,8 @@ def _get_all_bundles_info( def get_all_bundles_list( - repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: Optional[str] = None -): + repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None +) -> list[tuple[str, str]]: """ Get all bundles names (and the latest versions) that are stored in the release of specified repository with the provided tag. The default values of arguments correspond to the release of MONAI model zoo. @@ -494,8 +505,8 @@ def get_bundle_versions( bundle_name: str, repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", - auth_token: Optional[str] = None, -): + auth_token: str | None = None, +) -> dict[str, list[str] | str]: """ Get the latest version, as well as all existing versions of a bundle that is stored in the release of specified repository with the provided tag. @@ -528,11 +539,11 @@ def get_bundle_versions( def get_bundle_info( bundle_name: str, - version: Optional[str] = None, + version: str | None = None, repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", - auth_token: Optional[str] = None, -): + auth_token: str | None = None, +) -> dict[str, Any]: """ Get all information (include "id", "name", "size", "download_count", "browser_download_url", "created_at", "updated_at") of a bundle @@ -568,51 +579,20 @@ def get_bundle_info( return bundle_info[version] -def patch_bundle_tracking(parser: ConfigParser, settings: dict): - """ - Patch the loaded bundle config with a new handler logic to enable experiment tracking features. - - Args: - parser: loaded config content to patch the handler. - settings: settings for the experiment tracking, should follow the pattern of default settings. - - """ - for k, v in settings["configs"].items(): - if k in settings["handlers_id"]: - engine = parser.get(settings["handlers_id"][k]["id"]) - if engine is not None: - handlers = parser.get(settings["handlers_id"][k]["handlers"]) - if handlers is None: - engine["train_handlers" if k == "trainer" else "val_handlers"] = [v] - else: - handlers.append(v) - elif k not in parser: - parser[k] = v - # save the executed config into file - default_name = f"config_{time.strftime('%Y%m%d_%H%M%S')}.json" - filepath = parser.get("execute_config", None) - if filepath is None: - if "output_dir" not in parser: - # if no "output_dir" in the bundle config, default to "/eval" - parser["output_dir"] = "$@bundle_root + '/eval'" - # experiment management tools can refer to this config item to track the config info - parser["execute_config"] = parser["output_dir"] + f" + '/{default_name}'" - filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name) - Path(filepath).parent.mkdir(parents=True, exist_ok=True) - parser.export_config_file(parser.get(), filepath) - - +@deprecated_arg("runner_id", since="1.1", removed="1.3", new_name="run_id", msg_suffix="please use `run_id` instead.") def run( - runner_id: Optional[Union[str, Sequence[str]]] = None, - meta_file: Optional[Union[str, Sequence[str]]] = None, - config_file: Optional[Union[str, Sequence[str]]] = None, - logging_file: Optional[str] = None, - tracking: Optional[Union[str, dict]] = None, - args_file: Optional[str] = None, - **override, -): + run_id: str | None = None, + init_id: str | None = None, + final_id: str | None = None, + meta_file: str | Sequence[str] | None = None, + config_file: str | Sequence[str] | None = None, + logging_file: str | None = None, + tracking: str | dict | None = None, + args_file: str | None = None, + **override: Any, +) -> None: """ - Specify `meta_file` and `config_file` to run monai bundle components and workflows. + Specify `config_file` to run monai bundle components and workflows. Typical usage examples: @@ -635,65 +615,23 @@ def run( python -m monai.bundle run --args_file "/workspace/data/args.json" --config_file Args: - runner_id: ID name of the expected config expression to run, can also be a list of IDs to run in order. + run_id: ID name of the expected config expression to run, default to "run". + init_id: ID name of the expected config expression to initialize before running, default to "initialize". + final_id: ID name of the expected config expression to finalize after running, default to "finalize". meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. + Default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo. config_file: filepath of the config file, if `None`, must be provided in `args_file`. if it is a list of file paths, the content of them will be merged. - logging_file: config file for `logging` module in the program, default to `None`. for more details: + logging_file: config file for `logging` module in the program. for more details: https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. - tracking: enable the experiment tracking feature at runtime with optionally configurable and extensible. - if "mlflow", will add `MLFlowHandler` to the parsed bundle with default logging settings, - if other string, treat it as file path to load the logging settings, if `dict`, - treat it as logging settings, otherwise, use all the default settings. + Default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo. + tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible. + if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings, + if other string, treat it as file path to load the tracking settings. + if `dict`, treat it as tracking settings. will patch the target config content with `tracking handlers` and the top-level items of `configs`. - example of customized settings: - - .. code-block:: python - - tracking = { - "handlers_id": { - "trainer": {"id": "train#trainer", "handlers": "train#handlers"}, - "validator": {"id": "evaluate#evaluator", "handlers": "evaluate#handlers"}, - "evaluator": {"id": "evaluator", "handlers": "handlers"}, - }, - "configs": { - "tracking_uri": "", - "experiment_name": "monai_experiment", - "run_name": None, - "is_not_rank0": ( - "$torch.distributed.is_available() \ - and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0" - ), - "trainer": { - "_target_": "MLFlowHandler", - "_disabled_": "@is_not_rank0", - "tracking_uri": "@tracking_uri", - "experiment_name": "@experiment_name", - "run_name": "@run_name", - "iteration_log": True, - "output_transform": "$monai.handlers.from_engine(['loss'], first=True)", - "close_on_complete": True, - }, - "validator": { - "_target_": "MLFlowHandler", - "_disabled_": "@is_not_rank0", - "tracking_uri": "@tracking_uri", - "experiment_name": "@experiment_name", - "run_name": "@run_name", - "iteration_log": False, - }, - "evaluator": { - "_target_": "MLFlowHandler", - "_disabled_": "@is_not_rank0", - "tracking_uri": "@tracking_uri", - "experiment_name": "@experiment_name", - "run_name": "@run_name", - "iteration_log": False, - "close_on_complete": True, - }, - }, - }, - + for detailed usage examples, plesae check the tutorial: + https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb. args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`, `config_file`, `logging`, and override pairs. so that the command line inputs can be simplified. override: id-value pairs to override or add the corresponding config content. @@ -703,7 +641,9 @@ def run( _args = _update_args( args=args_file, - runner_id=runner_id, + run_id=run_id, + init_id=init_id, + final_id=final_id, meta_file=meta_file, config_file=config_file, logging_file=logging_file, @@ -713,44 +653,40 @@ def run( if "config_file" not in _args: warnings.warn("`config_file` not provided for 'monai.bundle run'.") _log_input_summary(tag="run", args=_args) - config_file_, meta_file_, runner_id_, logging_file_, tracking_ = _pop_args( - _args, config_file=None, meta_file=None, runner_id="", logging_file=None, tracking=None + config_file_, meta_file_, init_id_, run_id_, final_id_, logging_file_, tracking_ = _pop_args( + _args, + config_file=None, + meta_file="configs/metadata.json", + init_id="initialize", + run_id="run", + final_id="finalize", + logging_file="configs/logging.conf", + tracking=None, ) - if logging_file_ is not None: - if not os.path.exists(logging_file_): - raise FileNotFoundError(f"can't find the logging config file: {logging_file_}.") - logger.info(f"set logging properties based on config: {logging_file_}.") - fileConfig(logging_file_, disable_existing_loggers=False) - - parser = ConfigParser() - parser.read_config(f=config_file_) - if meta_file_ is not None: - parser.read_meta(f=meta_file_) - - # the rest key-values in the _args are to override config content - parser.update(pairs=_args) - - # set tracking configs for experiment management - if tracking_ is not None: - if isinstance(tracking_, str) and tracking_ in DEFAULT_EXP_MGMT_SETTINGS: - settings_ = DEFAULT_EXP_MGMT_SETTINGS[tracking_] - else: - settings_ = ConfigParser.load_config_files(tracking_) - patch_bundle_tracking(parser=parser, settings=settings_) - - # resolve and execute the specified runner expressions in the config, return the results - return [parser.get_parsed_content(i, lazy=True, eval_expr=True, instantiate=True) for i in ensure_tuple(runner_id_)] + workflow = ConfigWorkflow( + config_file=config_file_, + meta_file=meta_file_, + logging_file=logging_file_, + init_id=init_id_, + run_id=run_id_, + final_id=final_id_, + tracking=tracking_, + **_args, + ) + workflow.initialize() + workflow.run() + workflow.finalize() def verify_metadata( - meta_file: Optional[Union[str, Sequence[str]]] = None, - filepath: Optional[PathLike] = None, - create_dir: Optional[bool] = None, - hash_val: Optional[str] = None, - hash_type: Optional[str] = None, - args_file: Optional[str] = None, - **kwargs, -): + meta_file: str | Sequence[str] | None = None, + filepath: PathLike | None = None, + create_dir: bool | None = None, + hash_val: str | None = None, + hash_type: str | None = None, + args_file: str | None = None, + **kwargs: Any, +) -> None: """ Verify the provided `metadata` file based on the predefined `schema`. `metadata` content must contain the `schema` field for the URL of schema file to download. @@ -804,16 +740,16 @@ def verify_metadata( def verify_net_in_out( - net_id: Optional[str] = None, - meta_file: Optional[Union[str, Sequence[str]]] = None, - config_file: Optional[Union[str, Sequence[str]]] = None, - device: Optional[str] = None, - p: Optional[int] = None, - n: Optional[int] = None, - any: Optional[int] = None, - args_file: Optional[str] = None, - **override, -): + net_id: str | None = None, + meta_file: str | Sequence[str] | None = None, + config_file: str | Sequence[str] | None = None, + device: str | None = None, + p: int | None = None, + n: int | None = None, + any: int | None = None, + args_file: str | None = None, + **override: Any, +) -> None: """ Verify the input and output data shape and data type of network defined in the metadata. Will test with fake Tensor data according to the required data shape in `metadata`. @@ -903,15 +839,15 @@ def verify_net_in_out( def ckpt_export( - net_id: Optional[str] = None, - filepath: Optional[PathLike] = None, - ckpt_file: Optional[str] = None, - meta_file: Optional[Union[str, Sequence[str]]] = None, - config_file: Optional[Union[str, Sequence[str]]] = None, - key_in_ckpt: Optional[str] = None, - args_file: Optional[str] = None, - **override, -): + net_id: str | None = None, + filepath: PathLike | None = None, + ckpt_file: str | None = None, + meta_file: str | Sequence[str] | None = None, + config_file: str | Sequence[str] | None = None, + key_in_ckpt: str | None = None, + args_file: str | None = None, + **override: Any, +) -> None: """ Export the model checkpoint to the given filepath with metadata and config included as JSON files. @@ -974,7 +910,7 @@ def ckpt_export( # convert to TorchScript model and save with metadata, config content net = convert_to_torchscript(model=net) - extra_files: Dict = {} + extra_files: dict = {} for i in ensure_tuple(config_file_): # split the filename and directory filename = os.path.basename(i) @@ -1002,12 +938,12 @@ def ckpt_export( def init_bundle( bundle_dir: PathLike, - ckpt_file: Optional[PathLike] = None, - network: Optional[torch.nn.Module] = None, + ckpt_file: PathLike | None = None, + network: torch.nn.Module | None = None, dataset_license: bool = False, - metadata_str: Union[Dict, str, None] = None, - inference_str: Union[Dict, str, None] = None, -): + metadata_str: dict | str | None = None, + inference_str: dict | str | None = None, +) -> None: """ Initialise a new bundle directory with some default configuration files and optionally network weights. diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index 33ff3ff28f..62d4975d94 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import os import zipfile @@ -156,7 +158,7 @@ DEFAULT_EXP_MGMT_SETTINGS = {"mlflow": DEFAULT_MLFLOW_SETTINGS} # default experiment management settings -def load_bundle_config(bundle_path: str, *config_names, **load_kw_args) -> Any: +def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any) -> Any: """ Load the metadata and nominated configuration files from a MONAI bundle without loading the network itself. diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py new file mode 100644 index 0000000000..ace08b3ec8 --- /dev/null +++ b/monai/bundle/workflows.py @@ -0,0 +1,368 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import time +import warnings +from abc import ABC, abstractmethod +from logging.config import fileConfig +from pathlib import Path +from typing import Any, Sequence + +from monai.apps.utils import get_logger +from monai.bundle.config_parser import ConfigParser +from monai.bundle.properties import InferProperties, TrainProperties +from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY +from monai.utils import BundleProperty, BundlePropertyConfig + +__all__ = ["BundleWorkflow", "ConfigWorkflow"] + +logger = get_logger(module_name=__name__) + + +class BundleWorkflow(ABC): + """ + Base class for the workflow specification in bundle, it can be a training, evaluation or inference workflow. + It defines the basic interfaces for the bundle workflow behavior: `initialize`, `run`, `finalize`, etc. + And also provides the interface to get / set public properties to interact with a bundle workflow. + + Args: + workflow: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for common workflow. + + """ + + def __init__(self, workflow: str | None = None): + if workflow is None: + self.properties = None + self.workflow = None + return + if workflow.lower() in ("train", "training"): + self.properties = TrainProperties + self.workflow = "train" + elif workflow.lower() in ("infer", "inference", "eval", "evaluation"): + self.properties = InferProperties + self.workflow = "infer" + else: + raise ValueError(f"Unsupported workflow type: '{workflow}'.") + + @abstractmethod + def initialize(self, *args: Any, **kwargs: Any) -> Any: + """ + Initialize the bundle workflow before running. + + """ + raise NotImplementedError() + + @abstractmethod + def run(self, *args: Any, **kwargs: Any) -> Any: + """ + Run the bundle workflow, it can be a training, evaluation or inference. + + """ + raise NotImplementedError() + + @abstractmethod + def finalize(self, *args: Any, **kwargs: Any) -> Any: + """ + Finalize step after the running of bundle workflow. + + """ + raise NotImplementedError() + + @abstractmethod + def _get_property(self, name: str, property: dict) -> Any: + """ + With specified property name and information, get the expected property value. + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + + """ + raise NotImplementedError() + + @abstractmethod + def _set_property(self, name: str, property: dict, value: Any) -> Any: + """ + With specified property name and information, set value for the expected property. + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + value: value to set for the property. + + """ + raise NotImplementedError() + + def __getattr__(self, name): + if self.properties is not None and name in self.properties: + return self._get_property(name=name, property=self.properties[name]) + else: + return self.__getattribute__(name) # getting regular attribute + + def __setattr__(self, name, value): + if name != "properties" and self.properties is not None and name in self.properties: + self._set_property(name=name, property=self.properties[name], value=value) + else: + super().__setattr__(name, value) # setting regular attribute + + def get_workflow_type(self): + """ + Get the workflow type, it can be `None`, "train", or "infer". + + """ + return self.workflow + + def check_properties(self) -> list[str] | None: + """ + Check whether the required properties are existing in the bundle workflow. + If no workflow type specified, return None, otherwise, return a list of required but missing properties. + + """ + if self.properties is None: + return None + return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)] + + +class ConfigWorkflow(BundleWorkflow): + """ + Specification for the config-based bundle workflow. + Standardized the `initialize`, `run`, `finalize` behavior in a config-based training, evaluation, or inference. + For more information: https://docs.monai.io/en/latest/mb_specification.html. + + Args: + run_id: ID name of the expected config expression to run, default to "run". + init_id: ID name of the expected config expression to initialize before running, default to "initialize". + final_id: ID name of the expected config expression to finalize after running, default to "finalize". + meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. + Default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo. + config_file: filepath of the config file, if it is a list of file paths, the content of them will be merged. + logging_file: config file for `logging` module in the program. for more details: + https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. + Default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo. + tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible. + if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings, + if other string, treat it as file path to load the tracking settings. + if `dict`, treat it as tracking settings. + will patch the target config content with `tracking handlers` and the top-level items of `configs`. + for detailed usage examples, plesae check the tutorial: + https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb. + workflow: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for common workflow. + override: id-value pairs to override or add the corresponding config content. + e.g. ``--net#input_chns 42``. + + """ + + def __init__( + self, + config_file: str | Sequence[str], + meta_file: str | Sequence[str] | None = "configs/metadata.json", + logging_file: str | None = "configs/logging.conf", + init_id: str = "initialize", + run_id: str = "run", + final_id: str = "finalize", + tracking: str | dict | None = None, + workflow: str | None = None, + **override: dict, + ) -> None: + super().__init__(workflow=workflow) + if logging_file is not None: + if not os.path.exists(logging_file): + if logging_file == "configs/logging.conf": + warnings.warn("Default logging file in 'configs/logging.conf' does not exist, skipping logging.") + else: + raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") + else: + logger.info(f"Setting logging properties based on config: {logging_file}.") + fileConfig(logging_file, disable_existing_loggers=False) + + self.parser = ConfigParser() + self.parser.read_config(f=config_file) + if meta_file is not None: + if isinstance(meta_file, str) and not os.path.exists(meta_file): + if meta_file == "configs/metadata.json": + warnings.warn("Default metadata file in 'configs/metadata.json' does not exist, skipping loading.") + else: + raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") + else: + self.parser.read_meta(f=meta_file) + + # the rest key-values in the _args are to override config content + self.parser.update(pairs=override) + self.init_id = init_id + self.run_id = run_id + self.final_id = final_id + # set tracking configs for experiment management + if tracking is not None: + if isinstance(tracking, str) and tracking in DEFAULT_EXP_MGMT_SETTINGS: + settings_ = DEFAULT_EXP_MGMT_SETTINGS[tracking] + else: + settings_ = ConfigParser.load_config_files(tracking) + self.patch_bundle_tracking(parser=self.parser, settings=settings_) + + def initialize(self) -> Any: + """ + Initialize the bundle workflow before running. + + """ + # reset the "reference_resolver" buffer at initialization stage + self.parser.parse(reset=True) + return self._run_expr(id=self.init_id) + + def run(self) -> Any: + """ + Run the bundle workflow, it can be a training, evaluation or inference. + + """ + return self._run_expr(id=self.run_id) + + def finalize(self) -> Any: + """ + Finalize step after the running of bundle workflow. + + """ + return self._run_expr(id=self.final_id) + + def check_properties(self) -> list[str] | None: + """ + Check whether the required properties are existing in the bundle workflow. + If the optional properties have reference in the config, will also check whether the properties are exising. + If no workflow type specified, return None, otherwise, return a list of required but missing properites. + + """ + ret = super().check_properties() + if self.properties is None: + warnings.warn("No available properties had been set, skipping check.") + return None + if ret: + warnings.warn(f"Loaded bundle does not contain the following required properties: {ret}") + # also check whether the optional properties use correct ID name if existing + wrong_props = [] + for n, p in self.properties.items(): + if not p.get(BundleProperty.REQUIRED, False) and not self._check_optional_id(name=n, property=p): + wrong_props.append(n) + if wrong_props: + warnings.warn(f"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}") + if ret is not None: + ret.extend(wrong_props) + return ret + + def _run_expr(self, id: str, **kwargs: dict) -> Any: + return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None + + def _get_prop_id(self, name: str, property: dict) -> Any: + prop_id = property[BundlePropertyConfig.ID] + if prop_id not in self.parser: + if not property.get(BundleProperty.REQUIRED, False): + return None + else: + raise KeyError(f"Property '{name}' with config ID '{prop_id}' not in the config.") + return prop_id + + def _get_property(self, name: str, property: dict) -> Any: + """ + With specified property name and information, get the parsed property value from config. + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + + """ + if not self.parser.ref_resolver.is_resolved(): + raise RuntimeError("Please execute 'initialize' before getting any parsed content.") + prop_id = self._get_prop_id(name, property) + return self.parser.get_parsed_content(id=prop_id) if prop_id is not None else None + + def _set_property(self, name: str, property: dict, value: Any) -> None: + """ + With specified property name and information, set value for the expected property. + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + value: value to set for the property. + + """ + prop_id = self._get_prop_id(name, property) + if prop_id is not None: + self.parser[prop_id] = value + # must parse the config again after changing the content + self.parser.ref_resolver.reset() + + def _check_optional_id(self, name: str, property: dict) -> bool: + """ + If an optional property has reference in the config, check whether the property is existing. + If `ValidationHandler` is defined for a training workflow, will check whether the optional properties + "evaluator" and "val_interval" are existing. + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + + """ + id = property.get(BundlePropertyConfig.ID, None) + ref_id = property.get(BundlePropertyConfig.REF_ID, None) + if ref_id is None: + # no ID of reference config item, skipping check for this optional property + return True + # check validation `validator` and `interval` properties as the handler index of ValidationHandler is unknown + if name in ("evaluator", "val_interval"): + if f"train{ID_SEP_KEY}handlers" in self.parser: + for h in self.parser[f"train{ID_SEP_KEY}handlers"]: + if h["_target_"] == "ValidationHandler": + ref = h.get(ref_id, None) + else: + ref = self.parser.get(ref_id, None) + if ref is not None and ref != ID_REF_KEY + id: + return False + return True + + @staticmethod + def patch_bundle_tracking(parser: ConfigParser, settings: dict) -> None: + """ + Patch the loaded bundle config with a new handler logic to enable experiment tracking features. + + Args: + parser: loaded config content to patch the handler. + settings: settings for the experiment tracking, should follow the pattern of default settings. + + """ + for k, v in settings["configs"].items(): + if k in settings["handlers_id"]: + engine = parser.get(settings["handlers_id"][k]["id"]) + if engine is not None: + handlers = parser.get(settings["handlers_id"][k]["handlers"]) + if handlers is None: + engine["train_handlers" if k == "trainer" else "val_handlers"] = [v] + else: + handlers.append(v) + elif k not in parser: + parser[k] = v + # save the executed config into file + default_name = f"config_{time.strftime('%Y%m%d_%H%M%S')}.json" + filepath = parser.get("execute_config", None) + if filepath is None: + if "output_dir" not in parser: + # if no "output_dir" in the bundle config, default to "/eval" + parser["output_dir"] = f"{EXPR_KEY}{ID_REF_KEY}bundle_root + '/eval'" + # experiment management tools can refer to this config item to track the config info + parser["execute_config"] = parser["output_dir"] + f" + '/{default_name}'" + filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name) + Path(filepath).parent.mkdir(parents=True, exist_ok=True) + parser.export_config_file(parser.get(), filepath) diff --git a/monai/config/__init__.py b/monai/config/__init__.py index f494202a56..c814e1f8eb 100644 --- a/monai/config/__init__.py +++ b/monai/config/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .deviceconfig import ( USE_COMPILED, USE_META_DICT, diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 87d46895aa..6ee454ac06 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -9,11 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import platform import re import sys from collections import OrderedDict +from typing import TextIO import numpy as np import torch @@ -65,6 +68,7 @@ def get_optional_config_values(): output = OrderedDict() output["Pytorch Ignite"] = get_package_version("ignite") + output["ITK"] = get_package_version("itk") output["Nibabel"] = get_package_version("nibabel") output["scikit-image"] = get_package_version("skimage") output["Pillow"] = get_package_version("PIL") @@ -173,7 +177,7 @@ def get_system_info() -> OrderedDict: return output -def print_system_info(file=sys.stdout) -> None: +def print_system_info(file: TextIO = sys.stdout) -> None: """ Print system info to `file`. Requires the optional library, `psutil`. @@ -188,7 +192,6 @@ def print_system_info(file=sys.stdout) -> None: def get_gpu_info() -> OrderedDict: - output: OrderedDict = OrderedDict() num_gpus = torch.cuda.device_count() @@ -220,7 +223,7 @@ def get_gpu_info() -> OrderedDict: return output -def print_gpu_info(file=sys.stdout) -> None: +def print_gpu_info(file: TextIO = sys.stdout) -> None: """ Print GPU info to `file`. @@ -231,7 +234,7 @@ def print_gpu_info(file=sys.stdout) -> None: print(f"{k}: {v}", file=file, flush=True) -def print_debug_info(file=sys.stdout) -> None: +def print_debug_info(file: TextIO = sys.stdout) -> None: """ Print config (installed dependencies, etc.) and system info for debugging. diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py index 5c360b5536..57454a94e1 100644 --- a/monai/config/type_definitions.py +++ b/monai/config/type_definitions.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os from typing import Collection, Hashable, Iterable, Sequence, TypeVar, Union diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index ac43e6fd3e..b56a7454c7 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // filtering m.def("bilateral_filter", &BilateralFilter, "Bilateral Filter"); m.def("phl_filter", &PermutohedralFilter, "Permutohedral Filter"); + m.def("tbf_forward", &TrainableBilateralFilterForward, "Trainable Bilateral Filter Forward"); + m.def("tbf_backward", &TrainableBilateralFilterBackward, "Trainable Bilateral Filter Backward"); + m.def("tjbf_forward", &TrainableJointBilateralFilterForward, "Trainable Joint Bilateral Filter Forward"); + m.def("tjbf_backward", &TrainableJointBilateralFilterBackward, "Trainable Joint Bilateral Filter Backward"); // lltm m.def("lltm_forward", &lltm_forward, "LLTM forward"); diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp index 51573ebbc0..39f6d99a8c 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp @@ -15,38 +15,7 @@ limitations under the License. #include #include "utils/tensor_description.h" - -struct Indexer { - public: - Indexer(int dimensions, int* sizes) { - m_dimensions = dimensions; - m_sizes = sizes; - m_index = new int[dimensions]{0}; - } - - bool operator++(int) { - for (int i = 0; i < m_dimensions; i++) { - m_index[i] += 1; - - if (m_index[i] < m_sizes[i]) { - return true; - } else { - m_index[i] = 0; - } - } - - return false; - } - - int& operator[](int dimensionIndex) { - return m_index[dimensionIndex]; - } - - private: - int m_dimensions; - int* m_sizes; - int* m_index; -}; +#include "utils/tensor_indexing.h" template void BilateralFilterCpu(torch::Tensor inputTensor, torch::Tensor outputTensor, float spatialSigma, float colorSigma) { diff --git a/monai/csrc/filtering/filtering.h b/monai/csrc/filtering/filtering.h index 3e680010ed..1ff799cf5e 100644 --- a/monai/csrc/filtering/filtering.h +++ b/monai/csrc/filtering/filtering.h @@ -15,3 +15,5 @@ limitations under the License. #include "bilateral/bilateral.h" #include "permutohedral/permutohedral.h" +#include "trainable_bilateral/trainable_bilateral.h" +#include "trainable_joint_bilateral/trainable_joint_bilateral.h" diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp new file mode 100644 index 0000000000..01840dc2e3 --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp @@ -0,0 +1,232 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "trainable_bilateral.h" +#include "utils/tensor_description.h" +#include "utils/tensor_indexing.h" + +template +void BilateralFilterCpuBackward_3d( + torch::Tensor gradientInputTensor, + torch::Tensor gradientOutputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(gradientInputTensor); + + // Raw tensor data pointers. + scalar_t* gradientInputTensorData = gradientInputTensor.data_ptr(); + scalar_t* gradientOutputTensorData = gradientOutputTensor.data_ptr(); + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr(); + scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr(); + + // Pre-calculate common values + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + // Set kernel sizes with respect to the defined spatial sigmas. + int* kernelSizes = new int[desc.dimensions]; + + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + // Pre-calculate gaussian kernel and distance map in 1D. + scalar_t* gaussianKernel_x = new scalar_t[windowSize_x]; + scalar_t* gaussianKernel_y = new scalar_t[windowSize_y]; + scalar_t* gaussianKernel_z = new scalar_t[windowSize_z]; + scalar_t* xDistanceSquared = new scalar_t[windowSize_x]; + scalar_t* yDistanceSquared = new scalar_t[windowSize_y]; + scalar_t* zDistanceSquared = new scalar_t[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Looping over the batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Looping over all dimensions for the home element + for (int z = 0; z < desc.sizes[2]; z++) +#pragma omp parallel for + for (int y = 0; y < desc.sizes[1]; y++) { + for (int x = 0; x < desc.sizes[0]; x++) { + // Calculating indexing offset for the home element + int homeOffset = batchOffset; + + int homeIndex[] = {x, y, z}; + homeOffset += x * desc.strides[0]; + homeOffset += y * desc.strides[1]; + homeOffset += z * desc.strides[2]; + + // Zero kernel aggregates. + scalar_t filter_kernel = 0; + scalar_t valueSum = 0; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + bool flagNotClamped = true; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + // Euclidean color distance. + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[neighbourOffset + i * desc.channelStride] - + inputTensorData[homeOffset + + i * desc.channelStride]; // Be careful: Here it is (X_k - X_i) and not (X_i - X_q) + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + spatialWeight = + gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]]; + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + for (int i = 0; i < desc.channelCount; i++) { + // Distinguish cases for k!=i (calculation is done here) + // and k==i (partial derivatives are precalculated). + // If statement replaces center element of neighborhood/kernel. + if (kernelIndex[0] != halfWindowSize_x || kernelIndex[1] != halfWindowSize_y || + kernelIndex[2] != halfWindowSize_z) { + filter_kernel = -(1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * + outputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight * colorDistance / + (colorSigma * colorSigma) + + (1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * totalWeight * + (1 + + inputTensorData[homeOffset + i * desc.channelStride] * colorDistance / + (colorSigma * colorSigma)); // inputTensorData[homeOffset] !! + } else { + filter_kernel = dO_dx_kiData[homeOffset + i * desc.channelStride]; + } + + valueSum += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel; + } + } + } while (kernelIndex++); + + // Do the filtering and calculate the values for the backward pass. + for (int i = 0; i < desc.channelCount; i++) { + // Filtering: + gradientOutputTensorData[homeOffset + i * desc.channelStride] = valueSum; + } + } + } + } + + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +torch::Tensor BilateralFilterCpuBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Preparing output tensor. + torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradientInputTensor.scalar_type(), "BilateralFilterCpuBackward_3d", ([&] { + BilateralFilterCpuBackward_3d( + gradientInputTensor, + gradientOutputTensor, + inputTensor, + outputTensor, + outputWeightsTensor, + dO_dx_ki, + sigma_x, + sigma_y, + sigma_z, + colorSigma); + })); + + return gradientOutputTensor; +} diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp new file mode 100644 index 0000000000..b5ef0c077a --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp @@ -0,0 +1,269 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "trainable_bilateral.h" +#include "utils/tensor_description.h" +#include "utils/tensor_indexing.h" + +template +void BilateralFilterCpuForward_3d( + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + torch::Tensor dO_dsig_r, + torch::Tensor dO_dsig_x, + torch::Tensor dO_dsig_y, + torch::Tensor dO_dsig_z, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Raw tensor data pointers. + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr(); + scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr(); + scalar_t* dO_dsig_rData = dO_dsig_r.data_ptr(); + scalar_t* dO_dsig_xData = dO_dsig_x.data_ptr(); + scalar_t* dO_dsig_yData = dO_dsig_y.data_ptr(); + scalar_t* dO_dsig_zData = dO_dsig_z.data_ptr(); + + // Pre-calculate common values + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + // Set kernel sizes with respect to the defined spatial sigmas. + int* kernelSizes = new int[desc.dimensions]; + + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + // Pre-calculate gaussian kernel and distance map in 1D. + scalar_t* gaussianKernel_x = new scalar_t[windowSize_x]; + scalar_t* gaussianKernel_y = new scalar_t[windowSize_y]; + scalar_t* gaussianKernel_z = new scalar_t[windowSize_z]; + scalar_t* xDistanceSquared = new scalar_t[windowSize_x]; + scalar_t* yDistanceSquared = new scalar_t[windowSize_y]; + scalar_t* zDistanceSquared = new scalar_t[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Looping over the batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Looping over all dimensions for the home element + for (int z = 0; z < desc.sizes[2]; z++) +#pragma omp parallel for + for (int y = 0; y < desc.sizes[1]; y++) { + for (int x = 0; x < desc.sizes[0]; x++) { + // Calculating indexing offset for the home element + int homeOffset = batchOffset; + + int homeIndex[] = {x, y, z}; + homeOffset += x * desc.strides[0]; + homeOffset += y * desc.strides[1]; + homeOffset += z * desc.strides[2]; + + // Zero kernel aggregates. + scalar_t valueSum = 0; + scalar_t dw_dx_ki = 0; + scalar_t dfilter_dx_ki = 0; + scalar_t colorSum_w = 0; + scalar_t colorSum_alpha = 0; + scalar_t xSum_w = 0; + scalar_t xSum_alpha = 0; + scalar_t ySum_w = 0; + scalar_t ySum_alpha = 0; + scalar_t zSum_w = 0; + scalar_t zSum_alpha = 0; + + scalar_t weightSum = 0.0f; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + bool flagNotClamped = true; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + // Euclidean color distance. + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[homeOffset + i * desc.channelStride] - + inputTensorData[neighbourOffset + i * desc.channelStride]; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + spatialWeight = + gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]]; + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + for (int i = 0; i < desc.channelCount; i++) { + valueSum += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight; + + // Derivative of weights with respect to X_i while i=k. + dw_dx_ki += (-1) * totalWeight * colorDistance / (colorSigma * colorSigma); + // Derivative of convolved image with respect to X_i while i=k. + dfilter_dx_ki += (-1) * totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + colorDistance / + (colorSigma * + colorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData + + colorSum_w += totalWeight * colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma); + colorSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma); + + xSum_w += totalWeight * xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x); + xSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x); + + ySum_w += totalWeight * yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y); + ySum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y); + + zSum_w += totalWeight * zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z); + zSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z); + } + + weightSum += totalWeight; + } + } while (kernelIndex++); + + // Do the filtering and calculate the values for the backward pass. + for (int i = 0; i < desc.channelCount; i++) { + // Filtering: + outputTensorData[homeOffset + i * desc.channelStride] = valueSum / weightSum; + + // Pre-computations for the backward pass: + outputWeightsTensorData[homeOffset + i * desc.channelStride] = weightSum; + dO_dx_kiData[homeOffset + i * desc.channelStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dx_ki + + (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here + dO_dsig_rData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha; + dO_dsig_xData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha; + dO_dsig_yData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha; + dO_dsig_zData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha; + } + } + } + } + + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +std::tuple +BilateralFilterCpuForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma) { + // Preparing output tensor. + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor); + torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), "BilateralFilterCpuForward_3d", ([&] { + BilateralFilterCpuForward_3d( + inputTensor, + outputTensor, + outputWeightsTensor, + dO_dx_ki, + dO_dsig_r, + dO_dsig_x, + dO_dsig_y, + dO_dsig_z, + sigma_x, + sigma_y, + sigma_z, + colorSigma); + })); + + return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z}; +} diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu new file mode 100644 index 0000000000..973f3e1639 --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu @@ -0,0 +1,296 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "trainable_bilateral.h" +//#include "../utils/cuda_error_check.h" +#include "utils/meta_macros.h" +#include "utils/tensor_description.h" + +__constant__ int cBatchStrideBack; +__constant__ int cColorStrideBack; + +__constant__ int cSizesBack[3]; +__constant__ int cStridesBack[3]; + +__constant__ int cKernelSizesBack[3]; +__constant__ int cHalfWindowSize_arrBack[3]; +__constant__ float cGaussianKernel_xBack[256]; +__constant__ float cGaussianKernel_yBack[256]; +__constant__ float cGaussianKernel_zBack[256]; +__constant__ float cXDistanceSquaredBack[256]; +__constant__ float cYDistanceSquaredBack[256]; +__constant__ float cZDistanceSquaredBack[256]; +__constant__ float cColorExponentConstantBack; +__constant__ float cSigma_xBack; +__constant__ float cSigma_yBack; +__constant__ float cSigma_zBack; +__constant__ float cColorSigmaBack; + +template +__global__ void BilateralFilterCudaKernel3DBackward( + scalar_t* gradientInputTensor, + scalar_t* gradientOutputTensor, + scalar_t* inputTensor, + scalar_t* outputTensor, + scalar_t* outputWeightsTensor, + scalar_t* dO_dx_ki) { + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStrideBack; + + if (homeOffset >= cColorStrideBack) + return; + + int homeX = homeOffset / cStridesBack[0]; + int homeY = (homeOffset - homeX * cStridesBack[0]) / cStridesBack[1]; + int homeZ = (homeOffset - homeX * cStridesBack[0] - homeY * cStridesBack[1]) / cStridesBack[2]; + int homeIndex[] = {homeX, homeY, homeZ}; + + // Zero kernel aggregates. + scalar_t valueSum = 0; + + for (int kernelX = 0; kernelX < cKernelSizesBack[0]; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arrBack[0]), cSizesBack[0] - 1)); + scalar_t gaussianX = cGaussianKernel_xBack[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSizesBack[1]; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arrBack[1]), cSizesBack[1] - 1)); + scalar_t gaussianY = cGaussianKernel_yBack[kernelY]; + + for (int kernelZ = 0; kernelZ < cKernelSizesBack[2]; kernelZ++) { + int neighbourZ = max(0, min(homeZ + (kernelZ - cHalfWindowSize_arrBack[2]), cSizesBack[2] - 1)); + scalar_t gaussianZ = cGaussianKernel_zBack[kernelZ]; + + int neighbourOffset = neighbourX * cStridesBack[0] + neighbourY * cStridesBack[1] + neighbourZ; + + bool flagNotClamped = true; + int kernelIndex[] = {kernelX, kernelY, kernelZ}; + int dimensions = 3; // Must equal the number of spatial dimensions. + + for (int i = 0; i < dimensions; i++) { + int HalfWindowSizeBack = cHalfWindowSize_arrBack[i]; // Define constant memory as new variable here (!!), + // otherwise: cudaErrorMisalignedAddress + int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; + int neighbourIndexClamped = min(cSizesBack[i] - 1, max(0, neighbourIndex)); + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = inputTensor[batchOffset + neighbourOffset + c * cColorStrideBack]; + scalar_t b = inputTensor[batchOffset + homeOffset + c * cColorStrideBack]; // Be careful: Here it is (X_k - + // X_i) and not (X_i - X_q) + scalar_t diff = a - b; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentConstantBack * colorDistanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + scalar_t filter_kernel_back; + +#pragma unroll + for (int c = 0; c < C; c++) { + // Distinguish cases for k!=i (calculation is done here) + // and k==i (partial derivatives are precalculated). + // If statement replaces center element of neighborhood/kernel. + if (kernelX != cHalfWindowSize_arrBack[0] || kernelY != cHalfWindowSize_arrBack[1] || + kernelZ != cHalfWindowSize_arrBack[2]) { + filter_kernel_back = -(1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * + outputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * totalWeight * colorDistance / + (cColorSigmaBack * cColorSigmaBack) + + (1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * totalWeight * + (1 + + inputTensor[batchOffset + homeOffset + c * cColorStrideBack] * colorDistance / + (cColorSigmaBack * cColorSigmaBack)); // inputTensorData[homeOffset] !! + } else { + filter_kernel_back = dO_dx_ki[batchOffset + homeOffset + c * cColorStrideBack]; + } + + valueSum += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_back; + } + } + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + gradientOutputTensor[batchOffset + homeOffset + c * cColorStrideBack] = valueSum; + } +} + +template +void BilateralFilterCudaBackwardFunction( + torch::Tensor gradientInputTensor, + torch::Tensor gradientOutputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating gaussian kernel. + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + int* kernelSizes = new int[desc.dimensions]; + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + auto* gaussianKernel_x = new float[windowSize_x]; + auto* gaussianKernel_y = new float[windowSize_y]; + auto* gaussianKernel_z = new float[windowSize_z]; + auto* xDistanceSquared = new float[windowSize_x]; + auto* yDistanceSquared = new float[windowSize_y]; + auto* zDistanceSquared = new float[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Writing constant memory. + cudaMemcpyToSymbol(cBatchStrideBack, &desc.batchStride, sizeof(int)); + cudaMemcpyToSymbol(cColorStrideBack, &desc.channelStride, sizeof(int)); + cudaMemcpyToSymbol(cSizesBack, desc.sizes, sizeof(int) * 3); + cudaMemcpyToSymbol(cStridesBack, desc.strides, sizeof(int) * 3); + cudaMemcpyToSymbol(cKernelSizesBack, kernelSizes, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cHalfWindowSize_arrBack, halfWindowSize_arr, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cGaussianKernel_xBack, gaussianKernel_x, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cGaussianKernel_yBack, gaussianKernel_y, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cGaussianKernel_zBack, gaussianKernel_z, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cXDistanceSquaredBack, xDistanceSquared, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cYDistanceSquaredBack, yDistanceSquared, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cZDistanceSquaredBack, zDistanceSquared, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cColorExponentConstantBack, &colorExpConstant, sizeof(float)); + cudaMemcpyToSymbol(cSigma_xBack, &sigma_x, sizeof(float)); + cudaMemcpyToSymbol(cSigma_yBack, &sigma_y, sizeof(float)); + cudaMemcpyToSymbol(cSigma_zBack, &sigma_z, sizeof(float)); + cudaMemcpyToSymbol(cColorSigmaBack, &colorSigma, sizeof(float)); + + // cuda_error_check("Cuda check before kernel call."); + +#define BLOCK_SIZE 32 + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputTensor.scalar_type(), "BilateralFilterCudaKernel3DBackward", ([&] { + BilateralFilterCudaKernel3DBackward + <<>>( + gradientInputTensor.data_ptr(), + gradientOutputTensor.data_ptr(), + inputTensor.data_ptr(), + outputTensor.data_ptr(), + outputWeightsTensor.data_ptr(), + dO_dx_ki.data_ptr()); + })); + + // cuda_error_check("Cuda check after kernel call."); + // delete[] kernel; + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +// Function to choose template implementation based on dynamic, channels and dimensions +torch::Tensor BilateralFilterCudaBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor); + // cuda_error_check("beginning"); + +#define CASE(c, d) \ + BilateralFilterCudaBackwardFunction( \ + gradientInputTensor, \ + gradientOutputTensor, \ + inputTensor, \ + outputTensor, \ + outputWeightsTensor, \ + dO_dx_ki, \ + sigma_x, \ + sigma_y, \ + sigma_z, \ + colorSigma); + SWITCH_AB( + CASE, + BF_CUDA_MAX_CHANNELS, + BF_CUDA_MAX_SPATIAL_DIMENSION, + gradientInputTensor.size(1), + gradientInputTensor.dim() - 2); + + return gradientOutputTensor; +} diff --git a/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu new file mode 100644 index 0000000000..b9856e2d1e --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu @@ -0,0 +1,330 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "trainable_bilateral.h" +//#include "../utils/cuda_error_check.h" +#include "utils/meta_macros.h" +#include "utils/tensor_description.h" + +__constant__ int cBatchStride; +__constant__ int cColorStride; + +__constant__ int cSizes[3]; +__constant__ int cStrides[3]; + +__constant__ int cKernelSizes[3]; +__constant__ int cHalfWindowSize_arr[3]; +__constant__ float cGaussianKernel_x[256]; +__constant__ float cGaussianKernel_y[256]; +__constant__ float cGaussianKernel_z[256]; +__constant__ float cXDistanceSquared[256]; +__constant__ float cYDistanceSquared[256]; +__constant__ float cZDistanceSquared[256]; +__constant__ float cColorExponentConstant; +__constant__ float cSigma_x; +__constant__ float cSigma_y; +__constant__ float cSigma_z; +__constant__ float cColorSigma; + +template +__global__ void BilateralFilterCudaKernel3DForward( + scalar_t* input, + scalar_t* output, + scalar_t* outputWeightsTensor, + scalar_t* dO_dx_ki, + scalar_t* dO_dsig_r, + scalar_t* dO_dsig_x, + scalar_t* dO_dsig_y, + scalar_t* dO_dsig_z) { + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + if (homeOffset >= cColorStride) + return; + + int homeX = homeOffset / cStrides[0]; + int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; + int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2]; + int homeIndex[] = {homeX, homeY, homeZ}; + + // Zero kernel aggregates. + scalar_t valueSum = 0; + scalar_t dw_dx_ki = 0; + scalar_t dfilter_dx_ki = 0; + scalar_t colorSum_w = 0; + scalar_t colorSum_alpha = 0; + scalar_t xSum_w = 0; + scalar_t xSum_alpha = 0; + scalar_t ySum_w = 0; + scalar_t ySum_alpha = 0; + scalar_t zSum_w = 0; + scalar_t zSum_alpha = 0; + scalar_t weightSum = 0; + + for (int kernelX = 0; kernelX < cKernelSizes[0]; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arr[0]), cSizes[0] - 1)); + scalar_t gaussianX = cGaussianKernel_x[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSizes[1]; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arr[1]), cSizes[1] - 1)); + scalar_t gaussianY = cGaussianKernel_y[kernelY]; + + for (int kernelZ = 0; kernelZ < cKernelSizes[2]; kernelZ++) { + int neighbourZ = max(0, min(homeZ + (kernelZ - cHalfWindowSize_arr[2]), cSizes[2] - 1)); + scalar_t gaussianZ = cGaussianKernel_z[kernelZ]; + + int neighbourOffset = neighbourX * cStrides[0] + neighbourY * cStrides[1] + neighbourZ; + + bool flagNotClamped = true; + int kernelIndex[] = {kernelX, kernelY, kernelZ}; + int dimensions = 3; // Must equal the number of spatial dimensions. + + for (int i = 0; i < dimensions; i++) { + int HalfWindowSizeBack = cHalfWindowSize_arr[i]; // Define constant memory as new variable here (!!), + // otherwise: cudaErrorMisalignedAddress + int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; + int neighbourIndexClamped = min(cSizes[i] - 1, max(0, neighbourIndex)); + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; // Home - neighbor (!!) in backward the + // other way around !! + scalar_t diff = a - b; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentConstant * colorDistanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { +#pragma unroll + for (int c = 0; c < C; c++) { + valueSum += input[batchOffset + neighbourOffset + c * cColorStride] * totalWeight; + + // Derivative of weights with respect to X_i while i=k. + dw_dx_ki += (-1) * totalWeight * colorDistance / (cColorSigma * cColorSigma); + // Derivative of convolved image with respect to X_i while i=k. + dfilter_dx_ki += (-1) * totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + colorDistance / + (cColorSigma * + cColorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData + + colorSum_w += totalWeight * colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma); + colorSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma); + + xSum_w += totalWeight * cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x); + xSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x); + + ySum_w += totalWeight * cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y); + ySum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y); + + zSum_w += totalWeight * cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z); + zSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z); + } + + weightSum += totalWeight; + } + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + // output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + output[batchOffset + homeOffset + c * cColorStride] = valueSum / weightSum; + + // Pre-computations for the backward pass: + outputWeightsTensor[batchOffset + homeOffset + c * cColorStride] = weightSum; + dO_dx_ki[batchOffset + homeOffset + c * cColorStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dx_ki + + (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here + dO_dsig_r[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha; + dO_dsig_x[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha; + dO_dsig_y[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha; + dO_dsig_z[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha; + } +} + +template +void BilateralFilterCudaForwardFunction( + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + torch::Tensor dO_dsig_r, + torch::Tensor dO_dsig_x, + torch::Tensor dO_dsig_y, + torch::Tensor dO_dsig_z, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating gaussian kernel. + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + int* kernelSizes = new int[desc.dimensions]; + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + auto* gaussianKernel_x = new float[windowSize_x]; + auto* gaussianKernel_y = new float[windowSize_y]; + auto* gaussianKernel_z = new float[windowSize_z]; + auto* xDistanceSquared = new float[windowSize_x]; + auto* yDistanceSquared = new float[windowSize_y]; + auto* zDistanceSquared = new float[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Writing constant memory. + cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int)); + cudaMemcpyToSymbol(cColorStride, &desc.channelStride, sizeof(int)); + cudaMemcpyToSymbol(cSizes, desc.sizes, sizeof(int) * 3); + cudaMemcpyToSymbol(cStrides, desc.strides, sizeof(int) * 3); + cudaMemcpyToSymbol(cKernelSizes, kernelSizes, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cHalfWindowSize_arr, halfWindowSize_arr, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cGaussianKernel_x, gaussianKernel_x, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cGaussianKernel_y, gaussianKernel_y, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cGaussianKernel_z, gaussianKernel_z, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cXDistanceSquared, xDistanceSquared, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cYDistanceSquared, yDistanceSquared, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cZDistanceSquared, zDistanceSquared, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cColorExponentConstant, &colorExpConstant, sizeof(float)); + cudaMemcpyToSymbol(cSigma_x, &sigma_x, sizeof(float)); + cudaMemcpyToSymbol(cSigma_y, &sigma_y, sizeof(float)); + cudaMemcpyToSymbol(cSigma_z, &sigma_z, sizeof(float)); + cudaMemcpyToSymbol(cColorSigma, &colorSigma, sizeof(float)); + + // cuda_error_check("Cuda check before kernel call."); + +#define BLOCK_SIZE 32 + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputTensor.scalar_type(), "BilateralFilterCudaKernel3DForward", ([&] { + BilateralFilterCudaKernel3DForward + <<>>( + inputTensor.data_ptr(), + outputTensor.data_ptr(), + outputWeightsTensor.data_ptr(), + dO_dx_ki.data_ptr(), + dO_dsig_r.data_ptr(), + dO_dsig_x.data_ptr(), + dO_dsig_y.data_ptr(), + dO_dsig_z.data_ptr()); + })); + + // cuda_error_check("Cuda check after kernel call."); + // delete[] kernel; + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +// Function to choose template implementation based on dynamic, channels and dimensions +std::tuple +BilateralFilterCudaForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma) { + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor); + torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor); + // cuda_error_check("beginning"); + +#define CASE(c, d) \ + BilateralFilterCudaForwardFunction( \ + inputTensor, \ + outputTensor, \ + outputWeightsTensor, \ + dO_dx_ki, \ + dO_dsig_r, \ + dO_dsig_x, \ + dO_dsig_y, \ + dO_dsig_z, \ + sigma_x, \ + sigma_y, \ + sigma_z, \ + colorSigma); + SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2); + + return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z}; +} diff --git a/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp new file mode 100644 index 0000000000..0c8d38fd6c --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp @@ -0,0 +1,121 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include + +#include "trainable_bilateral.h" +#include "utils/common_utils.h" + +std::tuple +TrainableBilateralFilterForward( + torch::Tensor inputTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + std::tuple ( + *filterFunction)(torch::Tensor, float, float, float, float); + +#ifdef WITH_CUDA + + if (torch::cuda::is_available() && inputTensor.is_cuda()) { + CHECK_CONTIGUOUS_CUDA(inputTensor); + + if (inputTensor.size(1) > BF_CUDA_MAX_CHANNELS) { + throw std::runtime_error( + "Bilateral filtering not implemented for channel count > " + std::to_string(BF_CUDA_MAX_CHANNELS)); + } + + if (inputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) { + throw std::runtime_error( + "Bilateral filtering not implemented for spatial dimension > " + + std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION)); + } + + filterFunction = &BilateralFilterCudaForward; + } else { + filterFunction = &BilateralFilterCpuForward; + } +#else + filterFunction = &BilateralFilterCpuForward; +#endif + + return filterFunction(inputTensor, sigma_x, sigma_y, sigma_z, colorSigma); +} + +torch::Tensor TrainableBilateralFilterBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + torch::Tensor (*filterFunction)( + torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, float, float, float, float); + +#ifdef WITH_CUDA + + if (torch::cuda::is_available() && gradientInputTensor.is_cuda()) { + CHECK_CONTIGUOUS_CUDA(gradientInputTensor); + + if (gradientInputTensor.size(1) > BF_CUDA_MAX_CHANNELS) { + throw std::runtime_error( + "Bilateral filtering not implemented for channel count > " + std::to_string(BF_CUDA_MAX_CHANNELS)); + } + + if (gradientInputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) { + throw std::runtime_error( + "Bilateral filtering not implemented for spatial dimension > " + + std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION)); + } + + filterFunction = &BilateralFilterCudaBackward; + } else { + filterFunction = &BilateralFilterCpuBackward; + } +#else + filterFunction = &BilateralFilterCpuBackward; +#endif + + return filterFunction( + gradientInputTensor, + inputTensor, + outputTensor, + outputWeightsTensor, + dO_dx_ki, + sigma_x, + sigma_y, + sigma_z, + colorSigma); +} diff --git a/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h new file mode 100644 index 0000000000..7420fe82fd --- /dev/null +++ b/monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h @@ -0,0 +1,88 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include +#include +#include +#include "utils/common_utils.h" +//#include "utils/tensor_description.h" + +#define BF_CUDA_MAX_CHANNELS 16 +#define BF_CUDA_MAX_SPATIAL_DIMENSION 3 + +#ifdef WITH_CUDA +std::tuple +BilateralFilterCudaForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma); +torch::Tensor BilateralFilterCudaBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); +#endif + +std::tuple +BilateralFilterCpuForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma); + +torch::Tensor BilateralFilterCpuBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); + +std::tuple +TrainableBilateralFilterForward( + torch::Tensor inputTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); + +torch::Tensor TrainableBilateralFilterBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); diff --git a/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_cpu_backward.cpp b/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_cpu_backward.cpp new file mode 100644 index 0000000000..810c0b3fda --- /dev/null +++ b/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_cpu_backward.cpp @@ -0,0 +1,246 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "trainable_joint_bilateral.h" +#include "utils/tensor_description.h" +#include "utils/tensor_indexing.h" + +template +void JointBilateralFilterCpuBackward_3d( + torch::Tensor gradientInputTensor, + torch::Tensor gradientGuidanceTensor, + torch::Tensor gradientOutputTensor, + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dz_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(gradientInputTensor); + + // Raw tensor data pointers. + scalar_t* gradientInputTensorData = gradientInputTensor.data_ptr(); + scalar_t* gradientGuidanceTensorData = gradientGuidanceTensor.data_ptr(); + scalar_t* gradientOutputTensorData = gradientOutputTensor.data_ptr(); + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* guidanceTensorData = guidanceTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr(); + scalar_t* dO_dz_kiData = dO_dz_ki.data_ptr(); + // scalar_t* dw_dx_kiData = dw_dx_ki_Tensor.data_ptr(); + // scalar_t* dfilter_dx_kiData = dfilter_dx_ki_Tensor.data_ptr(); + + // Pre-calculate common values + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + // Set kernel sizes with respect to the defined spatial sigmas. + int* kernelSizes = new int[desc.dimensions]; + + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + // Pre-calculate gaussian kernel and distance map in 1D. + scalar_t* gaussianKernel_x = new scalar_t[windowSize_x]; + scalar_t* gaussianKernel_y = new scalar_t[windowSize_y]; + scalar_t* gaussianKernel_z = new scalar_t[windowSize_z]; + scalar_t* xDistanceSquared = new scalar_t[windowSize_x]; + scalar_t* yDistanceSquared = new scalar_t[windowSize_y]; + scalar_t* zDistanceSquared = new scalar_t[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Looping over the batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Looping over all dimensions for the home element + for (int z = 0; z < desc.sizes[2]; z++) +#pragma omp parallel for + for (int y = 0; y < desc.sizes[1]; y++) { + for (int x = 0; x < desc.sizes[0]; x++) { + // Calculating indexing offset for the home element + int homeOffset = batchOffset; + + int homeIndex[] = {x, y, z}; + homeOffset += x * desc.strides[0]; + homeOffset += y * desc.strides[1]; + homeOffset += z * desc.strides[2]; + + // Zero kernel aggregates. + scalar_t filter_kernel_guidance = 0; + scalar_t valueSumGuidance = 0; + scalar_t valueSumInput = 0; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + bool flagNotClamped = true; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + // Euclidean color distance. + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = guidanceTensorData[neighbourOffset + i * desc.channelStride] - + guidanceTensorData[homeOffset + i * desc.channelStride]; // Be careful: Here it is (Z_k - Z_i) and not + // (Z_i - Z_q) + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + spatialWeight = + gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]]; + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + for (int i = 0; i < desc.channelCount; i++) { + // Distinguish cases for k!=i (calculation is done here) + // and k==i (partial derivatives are precalculated). + // If statement replaces center element of neighborhood/kernel. + if (kernelIndex[0] != halfWindowSize_x || kernelIndex[1] != halfWindowSize_y || + kernelIndex[2] != halfWindowSize_z) { + filter_kernel_guidance = -(1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * + outputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight * colorDistance / + (colorSigma * colorSigma) + + (1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * totalWeight * + (inputTensorData[homeOffset + i * desc.channelStride] * colorDistance / + (colorSigma * colorSigma)); // inputTensorData[homeOffset] !!, no +1!! + } else { + filter_kernel_guidance = dO_dz_kiData[homeOffset + i * desc.channelStride]; + } + + valueSumGuidance += + gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel_guidance; + valueSumInput += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * + (1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * totalWeight; + } + } + } while (kernelIndex++); + + // Do the filtering and calculate the values for the backward pass. + for (int i = 0; i < desc.channelCount; i++) { + // Filtering: + gradientGuidanceTensorData[homeOffset + i * desc.channelStride] = valueSumGuidance; + gradientOutputTensorData[homeOffset + i * desc.channelStride] = valueSumInput; + } + } + } + } + + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +std::tuple JointBilateralFilterCpuBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dz_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Preparing output tensor. + torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor); + torch::Tensor gradientGuidanceTensor = torch::zeros_like(gradientInputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradientInputTensor.scalar_type(), "JointBilateralFilterCpuBackward_3d", ([&] { + JointBilateralFilterCpuBackward_3d( + gradientInputTensor, + gradientGuidanceTensor, + gradientOutputTensor, + inputTensor, + guidanceTensor, + outputTensor, + outputWeightsTensor, + dO_dz_ki, + sigma_x, + sigma_y, + sigma_z, + colorSigma); + })); + + return {gradientOutputTensor, gradientGuidanceTensor}; +} diff --git a/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_cpu_forward.cpp b/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_cpu_forward.cpp new file mode 100644 index 0000000000..041b10f7d8 --- /dev/null +++ b/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_cpu_forward.cpp @@ -0,0 +1,278 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "trainable_joint_bilateral.h" +#include "utils/tensor_description.h" +#include "utils/tensor_indexing.h" + +template +void JointBilateralFilterCpuForward_3d( + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dz_ki, + torch::Tensor dO_dsig_r, + torch::Tensor dO_dsig_x, + torch::Tensor dO_dsig_y, + torch::Tensor dO_dsig_z, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Raw tensor data pointers. + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* guidanceTensorData = guidanceTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr(); + scalar_t* dO_dz_kiData = dO_dz_ki.data_ptr(); + scalar_t* dO_dsig_rData = dO_dsig_r.data_ptr(); + scalar_t* dO_dsig_xData = dO_dsig_x.data_ptr(); + scalar_t* dO_dsig_yData = dO_dsig_y.data_ptr(); + scalar_t* dO_dsig_zData = dO_dsig_z.data_ptr(); + + // Pre-calculate common values + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + // Set kernel sizes with respect to the defined spatial sigmas. + int* kernelSizes = new int[desc.dimensions]; + + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + // Pre-calculate gaussian kernel and distance map in 1D. + scalar_t* gaussianKernel_x = new scalar_t[windowSize_x]; + scalar_t* gaussianKernel_y = new scalar_t[windowSize_y]; + scalar_t* gaussianKernel_z = new scalar_t[windowSize_z]; + scalar_t* xDistanceSquared = new scalar_t[windowSize_x]; + scalar_t* yDistanceSquared = new scalar_t[windowSize_y]; + scalar_t* zDistanceSquared = new scalar_t[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Looping over the batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Looping over all dimensions for the home element + for (int z = 0; z < desc.sizes[2]; z++) +#pragma omp parallel for + for (int y = 0; y < desc.sizes[1]; y++) { + for (int x = 0; x < desc.sizes[0]; x++) { + // Calculating indexing offset for the home element + int homeOffset = batchOffset; + + int homeIndex[] = {x, y, z}; + homeOffset += x * desc.strides[0]; + homeOffset += y * desc.strides[1]; + homeOffset += z * desc.strides[2]; + + // Zero kernel aggregates. + scalar_t valueSum = 0; + scalar_t dw_dz_ki = 0; + scalar_t dfilter_dz_ki = 0; + scalar_t colorSum_w = 0; + scalar_t colorSum_alpha = 0; + scalar_t xSum_w = 0; + scalar_t xSum_alpha = 0; + scalar_t ySum_w = 0; + scalar_t ySum_alpha = 0; + scalar_t zSum_w = 0; + scalar_t zSum_alpha = 0; + + scalar_t weightSum = 0.0f; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + bool flagNotClamped = true; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i]; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + // Euclidean color distance. + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = guidanceTensorData[homeOffset + i * desc.channelStride] - + guidanceTensorData[neighbourOffset + i * desc.channelStride]; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + spatialWeight = + gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]]; + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + for (int i = 0; i < desc.channelCount; i++) { + valueSum += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight; + + // Derivative of weights with respect to X_i while i=k. + dw_dz_ki += (-1) * totalWeight * colorDistance / (colorSigma * colorSigma); + // Derivative of convolved image with respect to X_i while i=k. + dfilter_dz_ki += (-1) * totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + colorDistance / + (colorSigma * + colorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData + + colorSum_w += totalWeight * colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma); + colorSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma); + + xSum_w += totalWeight * xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x); + xSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x); + + ySum_w += totalWeight * yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y); + ySum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y); + + zSum_w += totalWeight * zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z); + zSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] * + zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z); + } + + weightSum += totalWeight; + } + } while (kernelIndex++); + + // Do the filtering and calculate the values for the backward pass. + for (int i = 0; i < desc.channelCount; i++) { + // Filtering: + outputTensorData[homeOffset + i * desc.channelStride] = valueSum / weightSum; + + // Pre-computations for the backward pass: + outputWeightsTensorData[homeOffset + i * desc.channelStride] = weightSum; + dO_dz_kiData[homeOffset + i * desc.channelStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dz_ki + + (1 / weightSum) * (dfilter_dz_ki); // no +1 for dfilter_dz_ki for JBF added here! + dO_dsig_rData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha; + dO_dsig_xData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha; + dO_dsig_yData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha; + dO_dsig_zData[homeOffset + i * desc.channelStride] = + -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha; + } + } + } + } + + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +std::tuple +JointBilateralFilterCpuForward( + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Preparing output tensor. + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor); + torch::Tensor dO_dz_ki = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), "JointBilateralFilterCpuForward_3d", ([&] { + JointBilateralFilterCpuForward_3d( + inputTensor, + guidanceTensor, + outputTensor, + outputWeightsTensor, + dO_dz_ki, + dO_dsig_r, + dO_dsig_x, + dO_dsig_y, + dO_dsig_z, + sigma_x, + sigma_y, + sigma_z, + colorSigma); + })); + + return {outputTensor, outputWeightsTensor, dO_dz_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z}; +} diff --git a/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_gpu_backward.cu b/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_gpu_backward.cu new file mode 100644 index 0000000000..3989914015 --- /dev/null +++ b/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_gpu_backward.cu @@ -0,0 +1,311 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "trainable_joint_bilateral.h" +//#include "../utils/cuda_error_check.h" +#include "utils/meta_macros.h" +#include "utils/tensor_description.h" + +__constant__ int cBatchStrideBack; +__constant__ int cColorStrideBack; + +__constant__ int cSizesBack[3]; +__constant__ int cStridesBack[3]; + +__constant__ int cKernelSizesBack[3]; +__constant__ int cHalfWindowSize_arrBack[3]; +__constant__ float cGaussianKernel_xBack[256]; +__constant__ float cGaussianKernel_yBack[256]; +__constant__ float cGaussianKernel_zBack[256]; +__constant__ float cXDistanceSquaredBack[256]; +__constant__ float cYDistanceSquaredBack[256]; +__constant__ float cZDistanceSquaredBack[256]; +__constant__ float cColorExponentConstantBack; +__constant__ float cSigma_xBack; +__constant__ float cSigma_yBack; +__constant__ float cSigma_zBack; +__constant__ float cColorSigmaBack; + +template +__global__ void JointBilateralFilterCudaKernel3DBackward( + scalar_t* gradientInputTensor, + scalar_t* gradientGuidanceTensor, + scalar_t* gradientOutputTensor, + scalar_t* inputTensor, + scalar_t* guidanceTensor, + scalar_t* outputTensor, + scalar_t* outputWeightsTensor, + scalar_t* dO_dz_ki) { + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStrideBack; + + if (homeOffset >= cColorStrideBack) + return; + + int homeX = homeOffset / cStridesBack[0]; + int homeY = (homeOffset - homeX * cStridesBack[0]) / cStridesBack[1]; + int homeZ = (homeOffset - homeX * cStridesBack[0] - homeY * cStridesBack[1]) / cStridesBack[2]; + int homeIndex[] = {homeX, homeY, homeZ}; + + // Zero kernel aggregates. + scalar_t valueSumGuidance = 0; + scalar_t valueSumInput = 0; + + for (int kernelX = 0; kernelX < cKernelSizesBack[0]; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arrBack[0]), cSizesBack[0] - 1)); + scalar_t gaussianX = cGaussianKernel_xBack[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSizesBack[1]; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arrBack[1]), cSizesBack[1] - 1)); + scalar_t gaussianY = cGaussianKernel_yBack[kernelY]; + + for (int kernelZ = 0; kernelZ < cKernelSizesBack[2]; kernelZ++) { + int neighbourZ = max(0, min(homeZ + (kernelZ - cHalfWindowSize_arrBack[2]), cSizesBack[2] - 1)); + scalar_t gaussianZ = cGaussianKernel_zBack[kernelZ]; + + int neighbourOffset = neighbourX * cStridesBack[0] + neighbourY * cStridesBack[1] + neighbourZ; + + bool flagNotClamped = true; + int kernelIndex[] = {kernelX, kernelY, kernelZ}; + int dimensions = 3; // Must equal the number of spatial dimensions. + + for (int i = 0; i < dimensions; i++) { + int HalfWindowSizeBack = cHalfWindowSize_arrBack[i]; // Define constant memory as new variable here (!!), + // otherwise: cudaErrorMisalignedAddress + int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; + int neighbourIndexClamped = min(cSizesBack[i] - 1, max(0, neighbourIndex)); + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = guidanceTensor[batchOffset + neighbourOffset + c * cColorStrideBack]; + scalar_t b = guidanceTensor[batchOffset + homeOffset + c * cColorStrideBack]; // Be careful: Here it is (Z_k - + // Z_i) and not (Z_i - Z_q) + scalar_t diff = a - b; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentConstantBack * colorDistanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { + scalar_t filter_kernel_guidance_back; + +#pragma unroll + for (int c = 0; c < C; c++) { + // Distinguish cases for k!=i (calculation is done here) + // and k==i (partial derivatives are precalculated). + // If statement replaces center element of neighborhood/kernel. + if (kernelX != cHalfWindowSize_arrBack[0] || kernelY != cHalfWindowSize_arrBack[1] || + kernelZ != cHalfWindowSize_arrBack[2]) { + filter_kernel_guidance_back = + -(1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * + outputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * totalWeight * colorDistance / + (cColorSigmaBack * cColorSigmaBack) + + (1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * totalWeight * + (inputTensor[batchOffset + homeOffset + c * cColorStrideBack] * colorDistance / + (cColorSigmaBack * cColorSigmaBack)); // inputTensorData[homeOffset] !!, no +1!! + } else { + filter_kernel_guidance_back = dO_dz_ki[batchOffset + homeOffset + c * cColorStrideBack]; + } + + valueSumGuidance += + gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_guidance_back; + valueSumInput += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * + (1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * totalWeight; + } + } + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + gradientGuidanceTensor[batchOffset + homeOffset + c * cColorStrideBack] = valueSumGuidance; + gradientOutputTensor[batchOffset + homeOffset + c * cColorStrideBack] = valueSumInput; + } +} + +template +void JointBilateralFilterCudaBackwardFunction( + torch::Tensor gradientInputTensor, + torch::Tensor gradientGuidanceTensor, + torch::Tensor gradientOutputTensor, + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dz_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating gaussian kernel. + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + int* kernelSizes = new int[desc.dimensions]; + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + auto* gaussianKernel_x = new float[windowSize_x]; + auto* gaussianKernel_y = new float[windowSize_y]; + auto* gaussianKernel_z = new float[windowSize_z]; + auto* xDistanceSquared = new float[windowSize_x]; + auto* yDistanceSquared = new float[windowSize_y]; + auto* zDistanceSquared = new float[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Writing constant memory. + cudaMemcpyToSymbol(cBatchStrideBack, &desc.batchStride, sizeof(int)); + cudaMemcpyToSymbol(cColorStrideBack, &desc.channelStride, sizeof(int)); + cudaMemcpyToSymbol(cSizesBack, desc.sizes, sizeof(int) * 3); + cudaMemcpyToSymbol(cStridesBack, desc.strides, sizeof(int) * 3); + cudaMemcpyToSymbol(cKernelSizesBack, kernelSizes, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cHalfWindowSize_arrBack, halfWindowSize_arr, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cGaussianKernel_xBack, gaussianKernel_x, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cGaussianKernel_yBack, gaussianKernel_y, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cGaussianKernel_zBack, gaussianKernel_z, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cXDistanceSquaredBack, xDistanceSquared, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cYDistanceSquaredBack, yDistanceSquared, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cZDistanceSquaredBack, zDistanceSquared, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cColorExponentConstantBack, &colorExpConstant, sizeof(float)); + cudaMemcpyToSymbol(cSigma_xBack, &sigma_x, sizeof(float)); + cudaMemcpyToSymbol(cSigma_yBack, &sigma_y, sizeof(float)); + cudaMemcpyToSymbol(cSigma_zBack, &sigma_z, sizeof(float)); + cudaMemcpyToSymbol(cColorSigmaBack, &colorSigma, sizeof(float)); + + // cuda_error_check("Cuda check before kernel call."); + +#define BLOCK_SIZE 32 + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputTensor.scalar_type(), "JointBilateralFilterCudaKernel3DBackward", ([&] { + JointBilateralFilterCudaKernel3DBackward + <<>>( + gradientInputTensor.data_ptr(), + gradientGuidanceTensor.data_ptr(), + gradientOutputTensor.data_ptr(), + inputTensor.data_ptr(), + guidanceTensor.data_ptr(), + outputTensor.data_ptr(), + outputWeightsTensor.data_ptr(), + dO_dz_ki.data_ptr()); + })); + + // cuda_error_check("Cuda check after kernel call."); + // delete[] kernel; + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +// Function to choose template implementation based on dynamic, channels and dimensions +std::tuple JointBilateralFilterCudaBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dz_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor); + torch::Tensor gradientGuidanceTensor = torch::zeros_like(gradientInputTensor); + // cuda_error_check("beginning"); + +#define CASE(c, d) \ + JointBilateralFilterCudaBackwardFunction( \ + gradientInputTensor, \ + gradientGuidanceTensor, \ + gradientOutputTensor, \ + inputTensor, \ + guidanceTensor, \ + outputTensor, \ + outputWeightsTensor, \ + dO_dz_ki, \ + sigma_x, \ + sigma_y, \ + sigma_z, \ + colorSigma); + SWITCH_AB( + CASE, + BF_CUDA_MAX_CHANNELS, + BF_CUDA_MAX_SPATIAL_DIMENSION, + gradientInputTensor.size(1), + gradientInputTensor.dim() - 2); + + return {gradientOutputTensor, gradientGuidanceTensor}; +} diff --git a/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_gpu_forward.cu b/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_gpu_forward.cu new file mode 100644 index 0000000000..a3125d11f7 --- /dev/null +++ b/monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_gpu_forward.cu @@ -0,0 +1,340 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "trainable_joint_bilateral.h" +//#include "../utils/cuda_error_check.h" +#include "utils/meta_macros.h" +#include "utils/tensor_description.h" + +__constant__ int cBatchStride; +__constant__ int cColorStride; + +__constant__ int cSizes[3]; +__constant__ int cStrides[3]; + +__constant__ int cKernelSizes[3]; +__constant__ int cHalfWindowSize_arr[3]; +__constant__ float cGaussianKernel_x[256]; +__constant__ float cGaussianKernel_y[256]; +__constant__ float cGaussianKernel_z[256]; +__constant__ float cXDistanceSquared[256]; +__constant__ float cYDistanceSquared[256]; +__constant__ float cZDistanceSquared[256]; +__constant__ float cColorExponentConstant; +__constant__ float cSigma_x; +__constant__ float cSigma_y; +__constant__ float cSigma_z; +__constant__ float cColorSigma; + +template +__global__ void JointBilateralFilterCudaKernel3DForward( + scalar_t* input, + scalar_t* guidance, + scalar_t* output, + scalar_t* outputWeightsTensor, + scalar_t* dO_dz_ki, + scalar_t* dO_dsig_r, + scalar_t* dO_dsig_x, + scalar_t* dO_dsig_y, + scalar_t* dO_dsig_z) { + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + if (homeOffset >= cColorStride) + return; + + int homeX = homeOffset / cStrides[0]; + int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; + int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2]; + int homeIndex[] = {homeX, homeY, homeZ}; + + // Zero kernel aggregates. + scalar_t valueSum = 0; + scalar_t dw_dz_ki = 0; + scalar_t dfilter_dz_ki = 0; + scalar_t colorSum_w = 0; + scalar_t colorSum_alpha = 0; + scalar_t xSum_w = 0; + scalar_t xSum_alpha = 0; + scalar_t ySum_w = 0; + scalar_t ySum_alpha = 0; + scalar_t zSum_w = 0; + scalar_t zSum_alpha = 0; + scalar_t weightSum = 0; + + for (int kernelX = 0; kernelX < cKernelSizes[0]; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arr[0]), cSizes[0] - 1)); + scalar_t gaussianX = cGaussianKernel_x[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSizes[1]; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arr[1]), cSizes[1] - 1)); + scalar_t gaussianY = cGaussianKernel_y[kernelY]; + + for (int kernelZ = 0; kernelZ < cKernelSizes[2]; kernelZ++) { + int neighbourZ = max(0, min(homeZ + (kernelZ - cHalfWindowSize_arr[2]), cSizes[2] - 1)); + scalar_t gaussianZ = cGaussianKernel_z[kernelZ]; + + int neighbourOffset = neighbourX * cStrides[0] + neighbourY * cStrides[1] + neighbourZ; + + bool flagNotClamped = true; + int kernelIndex[] = {kernelX, kernelY, kernelZ}; + int dimensions = 3; // Must equal the number of spatial dimensions. + + for (int i = 0; i < dimensions; i++) { + int HalfWindowSizeBack = cHalfWindowSize_arr[i]; // Define constant memory as new variable here (!!), + // otherwise: cudaErrorMisalignedAddress + int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack; + int neighbourIndexClamped = min(cSizes[i] - 1, max(0, neighbourIndex)); + if (neighbourIndex != neighbourIndexClamped) { + flagNotClamped = false; + } + } + + scalar_t colorDistance = 0; + scalar_t colorDistanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = guidance[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = guidance[batchOffset + neighbourOffset + c * cColorStride]; // Home - neighbor (!!) in backward + // the other way around !! + scalar_t diff = a - b; + colorDistance += diff; // Do not take the absolute value here. Be careful with the signs. + colorDistanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentConstant * colorDistanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded. + if (flagNotClamped) { +#pragma unroll + for (int c = 0; c < C; c++) { + valueSum += input[batchOffset + neighbourOffset + c * cColorStride] * totalWeight; + + // Derivative of weights with respect to X_i while i=k. + dw_dz_ki += (-1) * totalWeight * colorDistance / (cColorSigma * cColorSigma); + // Derivative of convolved image with respect to X_i while i=k. + dfilter_dz_ki += (-1) * totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + colorDistance / + (cColorSigma * + cColorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData + + colorSum_w += totalWeight * colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma); + colorSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma); + + xSum_w += totalWeight * cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x); + xSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x); + + ySum_w += totalWeight * cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y); + ySum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y); + + zSum_w += totalWeight * cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z); + zSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] * + cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z); + } + + weightSum += totalWeight; + } + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + // output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + output[batchOffset + homeOffset + c * cColorStride] = valueSum / weightSum; + + // Pre-computations for the backward pass: + outputWeightsTensor[batchOffset + homeOffset + c * cColorStride] = weightSum; + dO_dz_ki[batchOffset + homeOffset + c * cColorStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dz_ki + + (1 / weightSum) * (dfilter_dz_ki); // no +1 for dfilter_dz_ki for JBF added here! + dO_dsig_r[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha; + dO_dsig_x[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha; + dO_dsig_y[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha; + dO_dsig_z[batchOffset + homeOffset + c * cColorStride] = + -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha; + } +} + +template +void JointBilateralFilterCudaForwardFunction( + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dz_ki, + torch::Tensor dO_dsig_r, + torch::Tensor dO_dsig_x, + torch::Tensor dO_dsig_y, + torch::Tensor dO_dsig_z, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating gaussian kernel. + int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size + int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size + int halfWindowSize_x = floor(0.5f * windowSize_x); + int halfWindowSize_y = floor(0.5f * windowSize_y); + int halfWindowSize_z = floor(0.5f * windowSize_z); + int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z}; + float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x); + float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y); + float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z); + float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + int* kernelSizes = new int[desc.dimensions]; + kernelSizes[0] = windowSize_x; + kernelSizes[1] = windowSize_y; + kernelSizes[2] = windowSize_z; + + auto* gaussianKernel_x = new float[windowSize_x]; + auto* gaussianKernel_y = new float[windowSize_y]; + auto* gaussianKernel_z = new float[windowSize_z]; + auto* xDistanceSquared = new float[windowSize_x]; + auto* yDistanceSquared = new float[windowSize_y]; + auto* zDistanceSquared = new float[windowSize_z]; + + for (int i = 0; i < windowSize_x; i++) { + int distance = i - halfWindowSize_x; + gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x); + xDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_y; i++) { + int distance = i - halfWindowSize_y; + gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y); + yDistanceSquared[i] = distance * distance; + } + for (int i = 0; i < windowSize_z; i++) { + int distance = i - halfWindowSize_z; + gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z); + zDistanceSquared[i] = distance * distance; + } + + // Writing constant memory. + cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int)); + cudaMemcpyToSymbol(cColorStride, &desc.channelStride, sizeof(int)); + cudaMemcpyToSymbol(cSizes, desc.sizes, sizeof(int) * 3); + cudaMemcpyToSymbol(cStrides, desc.strides, sizeof(int) * 3); + cudaMemcpyToSymbol(cKernelSizes, kernelSizes, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cHalfWindowSize_arr, halfWindowSize_arr, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cGaussianKernel_x, gaussianKernel_x, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cGaussianKernel_y, gaussianKernel_y, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cGaussianKernel_z, gaussianKernel_z, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cXDistanceSquared, xDistanceSquared, sizeof(float) * windowSize_x); + cudaMemcpyToSymbol(cYDistanceSquared, yDistanceSquared, sizeof(float) * windowSize_y); + cudaMemcpyToSymbol(cZDistanceSquared, zDistanceSquared, sizeof(float) * windowSize_z); + cudaMemcpyToSymbol(cColorExponentConstant, &colorExpConstant, sizeof(float)); + cudaMemcpyToSymbol(cSigma_x, &sigma_x, sizeof(float)); + cudaMemcpyToSymbol(cSigma_y, &sigma_y, sizeof(float)); + cudaMemcpyToSymbol(cSigma_z, &sigma_z, sizeof(float)); + cudaMemcpyToSymbol(cColorSigma, &colorSigma, sizeof(float)); + + // cuda_error_check("Cuda check before kernel call."); + +#define BLOCK_SIZE 32 + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputTensor.scalar_type(), "JointBilateralFilterCudaKernel3DForward", ([&] { + JointBilateralFilterCudaKernel3DForward + <<>>( + inputTensor.data_ptr(), + guidanceTensor.data_ptr(), + outputTensor.data_ptr(), + outputWeightsTensor.data_ptr(), + dO_dz_ki.data_ptr(), + dO_dsig_r.data_ptr(), + dO_dsig_x.data_ptr(), + dO_dsig_y.data_ptr(), + dO_dsig_z.data_ptr()); + })); + + // cuda_error_check("Cuda check after kernel call."); + // delete[] kernel; + delete[] kernelSizes; + delete[] gaussianKernel_x; + delete[] gaussianKernel_y; + delete[] gaussianKernel_z; + delete[] xDistanceSquared; + delete[] yDistanceSquared; + delete[] zDistanceSquared; +} + +// Function to choose template implementation based on dynamic, channels and dimensions +std::tuple +JointBilateralFilterCudaForward( + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor); + torch::Tensor dO_dz_ki = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor); + torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor); + // cuda_error_check("beginning"); + +#define CASE(c, d) \ + JointBilateralFilterCudaForwardFunction( \ + inputTensor, \ + guidanceTensor, \ + outputTensor, \ + outputWeightsTensor, \ + dO_dz_ki, \ + dO_dsig_r, \ + dO_dsig_x, \ + dO_dsig_y, \ + dO_dsig_z, \ + sigma_x, \ + sigma_y, \ + sigma_z, \ + colorSigma); + SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2); + + return {outputTensor, outputWeightsTensor, dO_dz_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z}; +} diff --git a/monai/csrc/filtering/trainable_joint_bilateral/trainable_joint_bilateral.cpp b/monai/csrc/filtering/trainable_joint_bilateral/trainable_joint_bilateral.cpp new file mode 100644 index 0000000000..e06b97113a --- /dev/null +++ b/monai/csrc/filtering/trainable_joint_bilateral/trainable_joint_bilateral.cpp @@ -0,0 +1,133 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include + +#include "trainable_joint_bilateral.h" +#include "utils/common_utils.h" + +std::tuple +TrainableJointBilateralFilterForward( + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + std::tuple ( + *filterFunction)(torch::Tensor, torch::Tensor, float, float, float, float); + +#ifdef WITH_CUDA + + if (torch::cuda::is_available() && inputTensor.is_cuda()) { + CHECK_CONTIGUOUS_CUDA(inputTensor); + + if (inputTensor.size(1) > BF_CUDA_MAX_CHANNELS) { + throw std::runtime_error( + "Bilateral filtering not implemented for channel count > " + std::to_string(BF_CUDA_MAX_CHANNELS)); + } + + if (inputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) { + throw std::runtime_error( + "Bilateral filtering not implemented for spatial dimension > " + + std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION)); + } + + filterFunction = &JointBilateralFilterCudaForward; + } else { + filterFunction = &JointBilateralFilterCpuForward; + } +#else + filterFunction = &JointBilateralFilterCpuForward; +#endif + + return filterFunction(inputTensor, guidanceTensor, sigma_x, sigma_y, sigma_z, colorSigma); +} + +std::tuple TrainableJointBilateralFilterBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma) { + std::tuple (*filterFunction)( + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + float, + float, + float, + float); + +#ifdef WITH_CUDA + + if (torch::cuda::is_available() && gradientInputTensor.is_cuda()) { + CHECK_CONTIGUOUS_CUDA(gradientInputTensor); + + if (gradientInputTensor.size(1) > BF_CUDA_MAX_CHANNELS) { + throw std::runtime_error( + "Bilateral filtering not implemented for channel count > " + std::to_string(BF_CUDA_MAX_CHANNELS)); + } + + if (gradientInputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) { + throw std::runtime_error( + "Bilateral filtering not implemented for spatial dimension > " + + std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION)); + } + + filterFunction = &JointBilateralFilterCudaBackward; + } else { + filterFunction = &JointBilateralFilterCpuBackward; + } +#else + filterFunction = &JointBilateralFilterCpuBackward; +#endif + + return filterFunction( + gradientInputTensor, + inputTensor, + guidanceTensor, + outputTensor, + outputWeightsTensor, + dO_dx_ki, + sigma_x, + sigma_y, + sigma_z, + colorSigma); +} diff --git a/monai/csrc/filtering/trainable_joint_bilateral/trainable_joint_bilateral.h b/monai/csrc/filtering/trainable_joint_bilateral/trainable_joint_bilateral.h new file mode 100644 index 0000000000..2f370d3870 --- /dev/null +++ b/monai/csrc/filtering/trainable_joint_bilateral/trainable_joint_bilateral.h @@ -0,0 +1,104 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +========================================================================= +Adapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source +which has the following license... +https://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE + +Copyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include +#include +#include +#include "utils/common_utils.h" +//#include "utils/tensor_description.h" + +#define BF_CUDA_MAX_CHANNELS 16 +#define BF_CUDA_MAX_SPATIAL_DIMENSION 3 + +#ifdef WITH_CUDA +std::tuple +JointBilateralFilterCudaForward( + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); +std::tuple JointBilateralFilterCudaBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dz_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); +#endif + +std::tuple +JointBilateralFilterCpuForward( + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); + +std::tuple JointBilateralFilterCpuBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dz_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); + +std::tuple +TrainableJointBilateralFilterForward( + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); + +std::tuple TrainableJointBilateralFilterBackward( + torch::Tensor gradientInputTensor, + torch::Tensor inputTensor, + torch::Tensor guidanceTensor, + torch::Tensor outputTensor, + torch::Tensor outputWeightsTensor, + torch::Tensor dO_dx_ki, + float sigma_x, + float sigma_y, + float sigma_z, + float colorSigma); diff --git a/monai/csrc/utils/tensor_indexing.h b/monai/csrc/utils/tensor_indexing.h new file mode 100644 index 0000000000..9c02ba0691 --- /dev/null +++ b/monai/csrc/utils/tensor_indexing.h @@ -0,0 +1,50 @@ +/* +Copyright (c) MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include + +// Struct to easily index input tensors. +struct Indexer { + public: + Indexer(int dimensions, int* sizes) { + m_dimensions = dimensions; + m_sizes = sizes; + m_index = new int[dimensions]{0}; + } + ~Indexer() { + delete[] m_index; + } + + bool operator++(int) { + for (int i = 0; i < m_dimensions; i++) { + m_index[i] += 1; + + if (m_index[i] < m_sizes[i]) { + return true; + } else { + m_index[i] = 0; + } + } + + return false; + } + + int& operator[](int dimensionIndex) { + return m_index[dimensionIndex]; + } + + private: + int m_dimensions; + int* m_sizes; + int* m_index; +}; diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 65ee8c377f..8d8297deaf 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import contextlib from .box_utils import ( @@ -59,12 +61,16 @@ resolve_writer, ) from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer +from .itk_torch_bridge import ( + get_itk_image_center, + itk_image_to_metatensor, + itk_to_monai_affine, + metatensor_to_itk_image, + monai_to_itk_affine, + monai_to_itk_ddf, +) from .meta_obj import MetaObj, get_track_meta, set_track_meta from .meta_tensor import MetaTensor -from .nifti_saver import NiftiSaver -from .nifti_writer import write_nifti -from .png_saver import PNGSaver -from .png_writer import write_png from .samplers import DistributedSampler, DistributedWeightedRandomSampler from .synthetic import create_test_image_2d, create_test_image_3d from .test_time_augmentation import TestTimeAugmentation @@ -115,18 +121,32 @@ with contextlib.suppress(BaseException): from multiprocessing.reduction import ForkingPickler - def _rebuild_meta(cls, storage, metadata): - storage_offset, size, stride, meta_dict = metadata - t = cls([], dtype=storage.dtype, device=storage.device) - t.set_(storage._untyped() if hasattr(storage, "_untyped") else storage, storage_offset, size, stride) + def _rebuild_meta(cls, storage, dtype, metadata): + storage_offset, size, stride, requires_grad, meta_dict = metadata + storage = storage._untyped_storage if hasattr(storage, "_untyped_storage") else storage + t = cls([], dtype=dtype, device=storage.device) + t.set_(storage, storage_offset, size, stride) + t.requires_grad = requires_grad t.__dict__ = meta_dict return t def reduce_meta_tensor(meta_tensor): - storage = meta_tensor.storage() - if storage.is_cuda: + if hasattr(meta_tensor, "untyped_storage"): + storage = meta_tensor.untyped_storage() + elif hasattr(meta_tensor, "_typed_storage"): # gh pytorch 44dac51/torch/_tensor.py#L231-L233 + storage = meta_tensor._typed_storage() + else: + storage = meta_tensor.storage() + dtype = meta_tensor.dtype + if meta_tensor.is_cuda: raise NotImplementedError("sharing CUDA metatensor across processes not implemented") - metadata = (meta_tensor.storage_offset(), meta_tensor.size(), meta_tensor.stride(), meta_tensor.__dict__) - return _rebuild_meta, (type(meta_tensor), storage, metadata) + metadata = ( + meta_tensor.storage_offset(), + meta_tensor.size(), + meta_tensor.stride(), + meta_tensor.requires_grad, + meta_tensor.__dict__, + ) + return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata) ForkingPickler.register(MetaTensor, reduce_meta_tensor) diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py index 162c7cae26..b040119626 100644 --- a/monai/data/box_utils.py +++ b/monai/data/box_utils.py @@ -18,16 +18,18 @@ the rest of the detection pipelines mainly assumes boxes in `StandardMode`. """ +from __future__ import annotations + import inspect import warnings from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence from copy import deepcopy -from typing import Callable, Dict, Sequence, Tuple, Type, Union import numpy as np import torch -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.utils import look_up_option from monai.utils.enums import BoxModeName from monai.utils.type_conversion import convert_data_type, convert_to_dst_type @@ -77,7 +79,7 @@ class BoxMode(ABC): """ # a dictionary that maps spatial_dims to monai.utils.enums.BoxModeName. - name: Dict[int, BoxModeName] = {} + name: dict[int, BoxModeName] = {} @classmethod def get_name(cls, spatial_dims: int) -> str: @@ -93,7 +95,7 @@ def get_name(cls, spatial_dims: int) -> str: return cls.name[spatial_dims].value @abstractmethod - def boxes_to_corners(self, boxes: torch.Tensor) -> Tuple: + def boxes_to_corners(self, boxes: torch.Tensor) -> tuple: """ Convert the bounding boxes of the current mode to corners. @@ -101,7 +103,7 @@ def boxes_to_corners(self, boxes: torch.Tensor) -> Tuple: boxes: bounding boxes, Nx4 or Nx6 torch tensor Returns: - ``Tuple``: corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. + ``tuple``: corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax) Example: @@ -151,8 +153,8 @@ class CornerCornerModeTypeA(BoxMode): name = {2: BoxModeName.XYXY, 3: BoxModeName.XYZXYZ} - def boxes_to_corners(self, boxes: torch.Tensor) -> Tuple: - corners: Tuple + def boxes_to_corners(self, boxes: torch.Tensor) -> tuple: + corners: tuple corners = boxes.split(1, dim=-1) return corners @@ -178,8 +180,8 @@ class CornerCornerModeTypeB(BoxMode): name = {2: BoxModeName.XXYY, 3: BoxModeName.XXYYZZ} - def boxes_to_corners(self, boxes: torch.Tensor) -> Tuple: - corners: Tuple + def boxes_to_corners(self, boxes: torch.Tensor) -> tuple: + corners: tuple spatial_dims = get_spatial_dims(boxes=boxes) if spatial_dims == 3: xmin, xmax, ymin, ymax, zmin, zmax = boxes.split(1, dim=-1) @@ -215,8 +217,8 @@ class CornerCornerModeTypeC(BoxMode): name = {2: BoxModeName.XYXY, 3: BoxModeName.XYXYZZ} - def boxes_to_corners(self, boxes: torch.Tensor) -> Tuple: - corners: Tuple + def boxes_to_corners(self, boxes: torch.Tensor) -> tuple: + corners: tuple spatial_dims = get_spatial_dims(boxes=boxes) if spatial_dims == 3: xmin, ymin, xmax, ymax, zmin, zmax = boxes.split(1, dim=-1) @@ -251,8 +253,8 @@ class CornerSizeMode(BoxMode): name = {2: BoxModeName.XYWH, 3: BoxModeName.XYZWHD} - def boxes_to_corners(self, boxes: torch.Tensor) -> Tuple: - corners: Tuple + def boxes_to_corners(self, boxes: torch.Tensor) -> tuple: + corners: tuple # convert to float32 when computing torch.clamp, which does not support float16 box_dtype = boxes.dtype @@ -300,8 +302,8 @@ class CenterSizeMode(BoxMode): name = {2: BoxModeName.CCWH, 3: BoxModeName.CCCWHD} - def boxes_to_corners(self, boxes: torch.Tensor) -> Tuple: - corners: Tuple + def boxes_to_corners(self, boxes: torch.Tensor) -> tuple: + corners: tuple # convert to float32 when computing torch.clamp, which does not support float16 box_dtype = boxes.dtype @@ -361,10 +363,10 @@ def corners_to_boxes(self, corners: Sequence) -> torch.Tensor: def get_spatial_dims( - boxes: Union[torch.Tensor, np.ndarray, None] = None, - points: Union[torch.Tensor, np.ndarray, None] = None, - corners: Union[Sequence, None] = None, - spatial_size: Union[Sequence[int], torch.Tensor, np.ndarray, None] = None, + boxes: torch.Tensor | np.ndarray | None = None, + points: torch.Tensor | np.ndarray | None = None, + corners: Sequence | None = None, + spatial_size: Sequence[int] | torch.Tensor | np.ndarray | None = None, ) -> int: """ Get spatial dimension for the giving setting and check the validity of them. @@ -430,7 +432,7 @@ def get_spatial_dims( raise ValueError("The dimensions of multiple inputs should match with each other.") -def get_boxmode(mode: Union[str, BoxMode, Type[BoxMode], None] = None, *args, **kwargs) -> BoxMode: +def get_boxmode(mode: str | BoxMode | type[BoxMode] | None = None, *args, **kwargs) -> BoxMode: """ This function that return a :class:`~monai.data.box_utils.BoxMode` object giving a representation of box mode @@ -494,8 +496,8 @@ def get_boxmode(mode: Union[str, BoxMode, Type[BoxMode], None] = None, *args, ** def convert_box_mode( boxes: NdarrayOrTensor, - src_mode: Union[str, BoxMode, Type[BoxMode], None] = None, - dst_mode: Union[str, BoxMode, Type[BoxMode], None] = None, + src_mode: str | BoxMode | type[BoxMode] | None = None, + dst_mode: str | BoxMode | type[BoxMode] | None = None, ) -> NdarrayOrTensor: """ This function converts the boxes in src_mode to the dst_mode. @@ -549,7 +551,7 @@ def convert_box_mode( def convert_box_to_standard_mode( - boxes: NdarrayOrTensor, mode: Union[str, BoxMode, Type[BoxMode], None] = None + boxes: NdarrayOrTensor, mode: str | BoxMode | type[BoxMode] | None = None ) -> NdarrayOrTensor: """ Convert given boxes to standard mode. @@ -624,7 +626,7 @@ def centers_in_boxes(centers: NdarrayOrTensor, boxes: NdarrayOrTensor, eps: floa def boxes_center_distance( boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor, euclidean: bool = True -) -> Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]: +) -> tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]: """ Distance of center points between two sets of boxes @@ -726,7 +728,7 @@ def box_area(boxes: NdarrayOrTensor) -> NdarrayOrTensor: def _box_inter_union( boxes1_t: torch.Tensor, boxes2_t: torch.Tensor, compute_dtype: torch.dtype = torch.float32 -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ This internal function computes the intersection and union area of two set of boxes. @@ -937,11 +939,11 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr def spatial_crop_boxes( - boxes: NdarrayOrTensor, - roi_start: Union[Sequence[int], NdarrayOrTensor], - roi_end: Union[Sequence[int], NdarrayOrTensor], + boxes: NdarrayTensor, + roi_start: Sequence[int] | NdarrayOrTensor, + roi_end: Sequence[int] | NdarrayOrTensor, remove_empty: bool = True, -) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: +) -> tuple[NdarrayTensor, NdarrayOrTensor]: """ This function generate the new boxes when the corresponding image is cropped to the given ROI. When ``remove_empty=True``, it makes sure the bounding boxes are within the new cropped image. @@ -994,8 +996,8 @@ def spatial_crop_boxes( def clip_boxes_to_image( - boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], NdarrayOrTensor], remove_empty: bool = True -) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: + boxes: NdarrayOrTensor, spatial_size: Sequence[int] | NdarrayOrTensor, remove_empty: bool = True +) -> tuple[NdarrayOrTensor, NdarrayOrTensor]: """ This function clips the ``boxes`` to makes sure the bounding boxes are within the image. diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index 36d159d4be..f2b483765f 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -9,11 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import warnings from collections import OrderedDict from pathlib import Path -from typing import Dict, Optional, Union import numpy as np import torch @@ -81,7 +82,7 @@ def finalize(self) -> None: # clear cache content after writing self.reset_cache() - def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save(self, data: torch.Tensor | np.ndarray, meta_data: dict | None = None) -> None: """Save data into the cache dictionary. The metadata should have the following key: - ``'filename_or_obj'`` -- save the data corresponding to file name or object. If meta_data is None, use the default index from 0 to save data instead. @@ -97,7 +98,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] data = data.detach().cpu().numpy() self._cache_dict[save_key] = np.asarray(data, dtype=float) - def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save_batch(self, batch_data: torch.Tensor | np.ndarray, meta_data: dict | None = None) -> None: """Save a batch of data into the cache dictionary. Args: diff --git a/monai/data/dataloader.py b/monai/data/dataloader.py index f43211f184..9de0d28b96 100644 --- a/monai/data/dataloader.py +++ b/monai/data/dataloader.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch from torch.utils.data import DataLoader as _TorchDataLoader from torch.utils.data import Dataset diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 1e9d67f358..040d583b0d 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import collections.abc import math import pickle @@ -18,11 +20,12 @@ import threading import time import warnings +from collections.abc import Callable, Sequence from copy import copy, deepcopy from multiprocessing.managers import ListProxy from multiprocessing.pool import ThreadPool from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union +from typing import IO, TYPE_CHECKING, Any, cast import numpy as np import torch @@ -42,7 +45,7 @@ convert_to_contiguous, reset_ops_id, ) -from monai.utils import MAX_SEED, deprecated_arg, get_seed, look_up_option, min_version, optional_import +from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import from monai.utils.misc import first if TYPE_CHECKING: @@ -72,7 +75,7 @@ class Dataset(_TorchDataset): }, }, }] """ - def __init__(self, data: Sequence, transform: Optional[Callable] = None) -> None: + def __init__(self, data: Sequence, transform: Callable | None = None) -> None: """ Args: data: input data to load and transform to generate dataset for model. @@ -92,7 +95,7 @@ def _transform(self, index: int): data_i = self.data[index] return apply_transform(self.transform, data_i) if self.transform is not None else data_i - def __getitem__(self, index: Union[int, slice, Sequence[int]]): + def __getitem__(self, index: int | slice | Sequence[int]): """ Returns a `Subset` if `index` is a slice or Sequence, a data item otherwise. """ @@ -143,7 +146,7 @@ def __init__(self, data: Any, func: Callable, **kwargs) -> None: self.kwargs = kwargs self.reset() - def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **kwargs): + def reset(self, data: Any | None = None, func: Callable | None = None, **kwargs): """ Reset the dataset items with specified `func`. @@ -211,12 +214,12 @@ class PersistentDataset(Dataset): def __init__( self, data: Sequence, - transform: Union[Sequence[Callable], Callable], - cache_dir: Optional[Union[Path, str]], + transform: Sequence[Callable] | Callable, + cache_dir: Path | str | None, hash_func: Callable[..., bytes] = pickle_hashing, pickle_module: str = "pickle", pickle_protocol: int = DEFAULT_PROTOCOL, - hash_transform: Optional[Callable[..., bytes]] = None, + hash_transform: Callable[..., bytes] | None = None, reset_ops_id: bool = True, ) -> None: """ @@ -420,13 +423,13 @@ class CacheNTransDataset(PersistentDataset): def __init__( self, data: Sequence, - transform: Union[Sequence[Callable], Callable], + transform: Sequence[Callable] | Callable, cache_n_trans: int, - cache_dir: Optional[Union[Path, str]], + cache_dir: Path | str | None, hash_func: Callable[..., bytes] = pickle_hashing, pickle_module: str = "pickle", pickle_protocol: int = DEFAULT_PROTOCOL, - hash_transform: Optional[Callable[..., bytes]] = None, + hash_transform: Callable[..., bytes] | None = None, reset_ops_id: bool = True, ) -> None: """ @@ -534,15 +537,15 @@ class LMDBDataset(PersistentDataset): def __init__( self, data: Sequence, - transform: Union[Sequence[Callable], Callable], - cache_dir: Union[Path, str] = "cache", + transform: Sequence[Callable] | Callable, + cache_dir: Path | str = "cache", hash_func: Callable[..., bytes] = pickle_hashing, db_name: str = "monai_cache", progress: bool = True, pickle_protocol=pickle.HIGHEST_PROTOCOL, - hash_transform: Optional[Callable[..., bytes]] = None, + hash_transform: Callable[..., bytes] | None = None, reset_ops_id: bool = True, - lmdb_kwargs: Optional[dict] = None, + lmdb_kwargs: dict | None = None, ) -> None: """ Args: @@ -589,7 +592,7 @@ def __init__( self.lmdb_kwargs["map_size"] = 1024**4 # default map_size # lmdb is single-writer multi-reader by default # the cache is created without multi-threading - self._read_env: Optional[Any] = None + self._read_env: Any | None = None # this runs on the primary thread/process self._fill_cache_start_reader(show_progress=self.progress) print(f"Accessing lmdb file: {self.db_file.absolute()}.") @@ -742,16 +745,16 @@ class CacheDataset(Dataset): def __init__( self, data: Sequence, - transform: Optional[Union[Sequence[Callable], Callable]] = None, + transform: Sequence[Callable] | Callable | None = None, cache_num: int = sys.maxsize, cache_rate: float = 1.0, - num_workers: Optional[int] = 1, + num_workers: int | None = 1, progress: bool = True, copy_cache: bool = True, as_contiguous: bool = True, hash_as_key: bool = False, hash_func: Callable[..., bytes] = pickle_hashing, - runtime_cache: Union[bool, str, List, ListProxy] = False, + runtime_cache: bool | str | list | ListProxy = False, ) -> None: """ Args: @@ -809,8 +812,8 @@ def __init__( self.num_workers = max(int(self.num_workers), 1) self.runtime_cache = runtime_cache self.cache_num = 0 - self._cache: Union[List, ListProxy] = [] - self._hash_keys: List = [] + self._cache: list | ListProxy = [] + self._hash_keys: list = [] self.set_data(data) def set_data(self, data: Sequence) -> None: @@ -850,7 +853,7 @@ def _compute_cache_num(data_len: int): self._cache = self.runtime_cache # type: ignore return - def _fill_cache(self, indices=None) -> List: + def _fill_cache(self, indices=None) -> list: """ Compute and fill the cache content from data source. @@ -999,12 +1002,12 @@ class SmartCacheDataset(Randomizable, CacheDataset): def __init__( self, data: Sequence, - transform: Optional[Union[Sequence[Callable], Callable]] = None, + transform: Sequence[Callable] | Callable | None = None, replace_rate: float = 0.1, cache_num: int = sys.maxsize, cache_rate: float = 1.0, - num_init_workers: Optional[int] = 1, - num_replace_workers: Optional[int] = 1, + num_init_workers: int | None = 1, + num_replace_workers: int | None = 1, progress: bool = True, shuffle: bool = True, seed: int = 0, @@ -1020,7 +1023,7 @@ def __init__( self._update_lock: threading.Lock = threading.Lock() self._round: int = 1 self._replace_done: bool = False - self._replace_mgr: Optional[threading.Thread] = None + self._replace_mgr: threading.Thread | None = None if runtime_cache is not False: raise NotImplementedError("Options other than `runtime_cache=False` is not implemented yet.") @@ -1044,14 +1047,14 @@ def __init__( if replace_rate <= 0: raise ValueError("replace_rate must be greater than 0, otherwise, please use monai.data.CacheDataset.") - self.num_replace_workers: Optional[int] = num_replace_workers + self.num_replace_workers: int | None = num_replace_workers if self.num_replace_workers is not None: self.num_replace_workers = max(int(self.num_replace_workers), 1) self._total_num: int = len(data) self._replace_num: int = min(math.ceil(self.cache_num * replace_rate), len(data) - self.cache_num) - self._replacements: List[Any] = [None for _ in range(self._replace_num)] - self._replace_data_idx: List[int] = list(range(self._replace_num)) + self._replacements: list[Any] = [None for _ in range(self._replace_num)] + self._replace_data_idx: list[int] = list(range(self._replace_num)) self._compute_data_idx() def set_data(self, data: Sequence): @@ -1250,7 +1253,7 @@ class ZipDataset(Dataset): """ - def __init__(self, datasets: Sequence, transform: Optional[Callable] = None) -> None: + def __init__(self, datasets: Sequence, transform: Callable | None = None) -> None: """ Args: datasets: list of datasets to zip together. @@ -1327,11 +1330,11 @@ def __call__(self, input_): def __init__( self, img: Sequence, - img_transform: Optional[Callable] = None, - seg: Optional[Sequence] = None, - seg_transform: Optional[Callable] = None, - labels: Optional[Sequence] = None, - label_transform: Optional[Callable] = None, + img_transform: Callable | None = None, + seg: Sequence | None = None, + seg_transform: Callable | None = None, + labels: Sequence | None = None, + label_transform: Callable | None = None, ) -> None: """ Initializes the dataset with the filename lists. The transform `img_transform` is applied @@ -1356,7 +1359,7 @@ def __init__( def __len__(self) -> int: return len(self.dataset) - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: self._seed = self.R.randint(MAX_SEED, dtype="uint32") def __getitem__(self, index: int): @@ -1390,17 +1393,17 @@ class NPZDictItemDataset(Dataset): def __init__( self, - npzfile: Union[str, IO], - keys: Dict[str, str], - transform: Optional[Callable[..., Dict[str, Any]]] = None, - other_keys: Optional[Sequence[str]] = (), + npzfile: str | IO, + keys: dict[str, str], + transform: Callable[..., dict[str, Any]] | None = None, + other_keys: Sequence[str] | None = (), ): - self.npzfile: Union[str, IO] = npzfile if isinstance(npzfile, str) else "STREAM" - self.keys: Dict[str, str] = dict(keys) + self.npzfile: str | IO = npzfile if isinstance(npzfile, str) else "STREAM" + self.keys: dict[str, str] = dict(keys) dat = np.load(npzfile) self.arrays = {storedk: dat[datak] for datak, storedk in self.keys.items()} - self.length = self.arrays[first(self.keys.values())].shape[0] + self.length = self.arrays[cast(str, first(self.keys.values()))].shape[0] self.other_keys = {} if other_keys is None else {k: dat[k] for k in other_keys} @@ -1480,20 +1483,19 @@ class CSVDataset(Dataset): """ - @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.") def __init__( self, - src: Optional[Union[str, Sequence[str]]] = None, # also can be `DataFrame` or a sequence of `DataFrame` - row_indices: Optional[Sequence[Union[int, str]]] = None, - col_names: Optional[Sequence[str]] = None, - col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None, - col_groups: Optional[Dict[str, Sequence[str]]] = None, - transform: Optional[Callable] = None, - kwargs_read_csv: Optional[Dict] = None, + src: str | Sequence[str] | None = None, # also can be `DataFrame` or a sequence of `DataFrame` + row_indices: Sequence[int | str] | None = None, + col_names: Sequence[str] | None = None, + col_types: dict[str, dict[str, Any] | None] | None = None, + col_groups: dict[str, Sequence[str]] | None = None, + transform: Callable | None = None, + kwargs_read_csv: dict | None = None, **kwargs, ): srcs = (src,) if not isinstance(src, (tuple, list)) else src - dfs: List = [] + dfs: list = [] for i in srcs: if isinstance(i, str): dfs.append(pd.read_csv(i, **kwargs_read_csv) if kwargs_read_csv else pd.read_csv(i)) @@ -1502,9 +1504,6 @@ def __init__( else: raise ValueError("`src` must be file path or pandas `DataFrame`.") - # in case treating deprecated arg `filename` as kwargs, remove it from `kwargs` - kwargs.pop("filename", None) - data = convert_tables_to_dicts( dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs ) diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py index 785e8c7b88..769ae33b46 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -9,9 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings from itertools import chain -from typing import List, Optional import numpy as np import torch @@ -44,9 +45,9 @@ class DatasetSummary: def __init__( self, dataset: Dataset, - image_key: Optional[str] = "image", - label_key: Optional[str] = "label", - meta_key: Optional[KeysCollection] = None, + image_key: str | None = "image", + label_key: str | None = "label", + meta_key: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, num_workers: int = 0, **kwargs, @@ -75,7 +76,7 @@ def __init__( self.image_key = image_key self.label_key = label_key self.meta_key = meta_key or f"{image_key}_{meta_key_postfix}" - self.all_meta_data: List = [] + self.all_meta_data: list = [] def collect_meta_data(self): """ diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index c1bbabceb9..6f163f972e 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -9,11 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import os import warnings +from collections.abc import Sequence from pathlib import Path -from typing import Dict, List, Optional, Sequence, Union, overload +from typing import overload from monai.config import KeysCollection, PathLike from monai.data.utils import partition_dataset, select_cross_validation_folds @@ -26,7 +29,7 @@ def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = Fals @overload -def _compute_path(base_dir: PathLike, element: List[PathLike], check_path: bool = False) -> List[str]: +def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ... @@ -60,7 +63,7 @@ def _join_path(base_dir: PathLike, item: PathLike): return element -def _append_paths(base_dir: PathLike, is_segmentation: bool, items: List[Dict]) -> List[Dict]: +def _append_paths(base_dir: PathLike, is_segmentation: bool, items: list[dict]) -> list[dict]: """ Args: base_dir: the base directory of the dataset. @@ -87,8 +90,8 @@ def load_decathlon_datalist( data_list_file_path: PathLike, is_segmentation: bool = True, data_list_key: str = "training", - base_dir: Optional[PathLike] = None, -) -> List[Dict]: + base_dir: PathLike | None = None, +) -> list[dict]: """Load image/label paths of decathlon challenge from JSON file Json file is similar to what you get from http://medicaldecathlon.com/ @@ -132,7 +135,7 @@ def load_decathlon_datalist( return _append_paths(base_dir, is_segmentation, expected_data) -def load_decathlon_properties(data_property_file_path: PathLike, property_keys: Union[Sequence[str], str]) -> Dict: +def load_decathlon_properties(data_property_file_path: PathLike, property_keys: Sequence[str] | str) -> dict: """Load the properties from the JSON file contains data property with specified `property_keys`. Args: @@ -158,7 +161,7 @@ def load_decathlon_properties(data_property_file_path: PathLike, property_keys: def check_missing_files( - datalist: List[Dict], keys: KeysCollection, root_dir: Optional[PathLike] = None, allow_missing_keys: bool = False + datalist: list[dict], keys: KeysCollection, root_dir: PathLike | None = None, allow_missing_keys: bool = False ): """Checks whether some files in the Decathlon datalist are missing. It would be helpful to check missing files before a heavy training run. @@ -196,18 +199,18 @@ def check_missing_files( def create_cross_validation_datalist( - datalist: List[Dict], + datalist: list[dict], nfolds: int, - train_folds: Union[Sequence[int], int], - val_folds: Union[Sequence[int], int], + train_folds: Sequence[int] | int, + val_folds: Sequence[int] | int, train_key: str = "training", val_key: str = "validation", - filename: Optional[Union[Path, str]] = None, + filename: Path | str | None = None, shuffle: bool = True, seed: int = 0, check_missing: bool = False, - keys: Optional[KeysCollection] = None, - root_dir: Optional[str] = None, + keys: KeysCollection | None = None, + root_dir: str | None = None, allow_missing_keys: bool = False, raise_error: bool = True, ): diff --git a/monai/data/fft_utils.py b/monai/data/fft_utils.py index 19083aa711..d26a31d656 100644 --- a/monai/data/fft_utils.py +++ b/monai/data/fft_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch from monai.config.type_definitions import NdarrayOrTensor diff --git a/monai/data/folder_layout.py b/monai/data/folder_layout.py index 2da9543e7d..190a07334d 100644 --- a/monai/data/folder_layout.py +++ b/monai/data/folder_layout.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import monai from monai.config import PathLike from monai.data.utils import create_file_basename @@ -19,7 +21,11 @@ def default_name_formatter(metadict, saver): """Returns a kwargs dict for :py:meth:`FolderLayout.filename`, according to the input metadata and SaveImage transform.""" - subject = metadict[monai.utils.ImageMetaKey.FILENAME_OR_OBJ] if metadict else getattr(saver, "_data_index", 0) + subject = ( + metadict.get(monai.utils.ImageMetaKey.FILENAME_OR_OBJ, getattr(saver, "_data_index", 0)) + if metadict + else getattr(saver, "_data_index", 0) + ) patch_index = metadict.get(monai.utils.ImageMetaKey.PATCH_INDEX, None) if metadict else None return {"subject": f"{subject}", "idx": patch_index} diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 2b28949419..fc8175f630 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence from copy import deepcopy -from typing import Callable, Dict, Hashable, Iterable, Mapping, Optional, Sequence, Union import numpy as np @@ -19,7 +21,7 @@ from monai.data.iterable_dataset import IterableDataset from monai.data.utils import iter_patch from monai.transforms import apply_transform -from monai.utils import NumpyPadMode, deprecated_arg, ensure_tuple, first, look_up_option +from monai.utils import NumpyPadMode, ensure_tuple, first, look_up_option __all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"] @@ -32,7 +34,7 @@ class PatchIter: """ def __init__( - self, patch_size: Sequence[int], start_pos: Sequence[int] = (), mode: str = NumpyPadMode.WRAP, **pad_opts: Dict + self, patch_size: Sequence[int], start_pos: Sequence[int] = (), mode: str = NumpyPadMode.WRAP, **pad_opts: dict ): """ @@ -173,12 +175,11 @@ class GridPatchDataset(IterableDataset): """ - @deprecated_arg(name="dataset", new_name="data", since="0.8", msg_suffix="please use `data` instead.") def __init__( self, - data: Union[Iterable, Sequence], + data: Iterable | Sequence, patch_iter: Callable, - transform: Optional[Callable] = None, + transform: Callable | None = None, with_coordinates: bool = True, ) -> None: super().__init__(data=data, transform=None) @@ -242,9 +243,8 @@ class PatchDataset(Dataset): """ - @deprecated_arg(name="dataset", new_name="data", since="0.8", msg_suffix="please use `data` instead.") def __init__( - self, data: Sequence, patch_func: Callable, samples_per_image: int = 1, transform: Optional[Callable] = None + self, data: Sequence, patch_func: Callable, samples_per_image: int = 1, transform: Callable | None = None ) -> None: """ Args: diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index 89694a4bb8..6c8ddcf8de 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Callable, Sequence +from typing import Any import numpy as np from torch.utils.data import Dataset @@ -33,15 +36,15 @@ class ImageDataset(Dataset, Randomizable): def __init__( self, image_files: Sequence[str], - seg_files: Optional[Sequence[str]] = None, - labels: Optional[Sequence[float]] = None, - transform: Optional[Callable] = None, - seg_transform: Optional[Callable] = None, - label_transform: Optional[Callable] = None, + seg_files: Sequence[str] | None = None, + labels: Sequence[float] | None = None, + transform: Callable | None = None, + seg_transform: Callable | None = None, + label_transform: Callable | None = None, image_only: bool = True, transform_with_metadata: bool = False, dtype: DtypeLike = np.float32, - reader: Optional[Union[ImageReader, str]] = None, + reader: ImageReader | str | None = None, *args, **kwargs, ) -> None: @@ -93,7 +96,7 @@ def __init__( def __len__(self) -> int: return len(self.image_files) - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: self._seed = self.R.randint(MAX_SEED, dtype="uint32") def __getitem__(self, index: int): diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 6919c10ffe..03bffbb1e8 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -9,13 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import glob import os import warnings from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -24,6 +27,7 @@ from monai.data.utils import ( affine_to_spacing, correct_nifti_header_if_necessary, + is_no_channel, is_supported_format, orientation_ras_lps, ) @@ -58,7 +62,6 @@ nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) OpenSlide, _ = optional_import("openslide", name="OpenSlide") -CuImage, _ = optional_import("cucim", name="CuImage") TiffFile, _ = optional_import("tifffile", name="TiffFile") __all__ = [ @@ -93,7 +96,7 @@ class ImageReader(ABC): """ @abstractmethod - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified `filename` is supported by the current reader. This method should return True if the reader is able to read the format suggested by the @@ -107,7 +110,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Sequence[Any], Any]: + def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any: """ Read image data from specified file or files. Note that it returns a data object or a sequence of data objects. @@ -120,7 +123,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Seq raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def get_data(self, img) -> Tuple[np.ndarray, Dict]: + def get_data(self, img) -> tuple[np.ndarray, dict]: """ Extract data array and metadata from loaded image and return them. This function must return two objects, the first is a numpy array of image data, @@ -133,7 +136,7 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -def _copy_compatible_dict(from_dict: Dict, to_dict: Dict): +def _copy_compatible_dict(from_dict: dict, to_dict: dict): if not isinstance(to_dict, dict): raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.") if not to_dict: @@ -156,10 +159,10 @@ def _copy_compatible_dict(from_dict: Dict, to_dict: Dict): ) -def _stack_images(image_list: List, meta_dict: Dict): +def _stack_images(image_list: list, meta_dict: dict): if len(image_list) <= 1: return image_list[0] - if meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None) not in ("no_channel", None): + if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) return np.concatenate(image_list, axis=channel_dim) # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified @@ -201,7 +204,7 @@ class ITKReader(ImageReader): def __init__( self, - channel_dim: Optional[int] = None, + channel_dim: str | int | None = None, series_name: str = "", reverse_indexing: bool = False, series_meta: bool = False, @@ -210,13 +213,13 @@ def __init__( ): super().__init__() self.kwargs = kwargs - self.channel_dim = channel_dim + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.series_name = series_name self.reverse_indexing = reverse_indexing self.series_meta = series_meta self.affine_lps_to_ras = affine_lps_to_ras - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by ITK reader. @@ -227,7 +230,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ return has_itk - def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): """ Read image data from specified file or files, it can read a list of images and stack them together as multi-channel data in `get_data()`. @@ -277,7 +280,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): img_.append(itk.imread(name, **kwargs_)) return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img) -> Tuple[np.ndarray, Dict]: + def get_data(self, img) -> tuple[np.ndarray, dict]: """ Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. @@ -289,8 +292,8 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]: img: an ITK image object loaded from an image file or a list of ITK image objects. """ - img_array: List[np.ndarray] = [] - compatible_meta: Dict = {} + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} for i in ensure_tuple(img): data = self._get_array_data(i) @@ -302,7 +305,7 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]: header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) if self.channel_dim is None: # default to "no_channel" or -1 header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - "no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 ) else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim @@ -310,7 +313,7 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]: return _stack_images(img_array, compatible_meta), compatible_meta - def _get_meta_dict(self, img) -> Dict: + def _get_meta_dict(self, img) -> dict: """ Get all the metadata of the image and convert to dict type. @@ -363,7 +366,7 @@ def _get_spatial_shape(self, img): sr = itk.array_from_matrix(img.GetDirection()).shape[0] sr = max(min(sr, 3), 1) _size = list(itk.size(img)) - if self.channel_dim is not None: + if isinstance(self.channel_dim, int): _size.pop(self.channel_dim) return np.asarray(_size[:sr]) @@ -432,22 +435,22 @@ class PydicomReader(ImageReader): def __init__( self, - channel_dim: Optional[int] = None, + channel_dim: str | int | None = None, affine_lps_to_ras: bool = True, swap_ij: bool = True, prune_metadata: bool = True, - label_dict: Optional[Dict] = None, + label_dict: dict | None = None, **kwargs, ): super().__init__() self.kwargs = kwargs - self.channel_dim = channel_dim + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.affine_lps_to_ras = affine_lps_to_ras self.swap_ij = swap_ij self.prune_metadata = prune_metadata self.label_dict = label_dict - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by Pydicom reader. @@ -458,7 +461,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ return has_pydicom - def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): """ Read image data from specified file or files, it can read a list of images and stack them together as multi-channel data in `get_data()`. @@ -515,7 +518,7 @@ def _combine_dicom_series(self, data: Iterable): Returns: a tuple that consisted with data array and metadata. """ - slices: List = [] + slices: list = [] # for a dicom series for slc_ds in data: if hasattr(slc_ds, "InstanceNumber"): @@ -564,7 +567,7 @@ def _combine_dicom_series(self, data: Iterable): return stack_array, stack_metadata - def get_data(self, data) -> Tuple[np.ndarray, Dict]: + def get_data(self, data) -> tuple[np.ndarray, dict]: """ Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. @@ -591,7 +594,7 @@ def get_data(self, data) -> Tuple[np.ndarray, Dict]: # combine dicom series if exists if self.has_series is True: # a list, all objects within a list belong to one dicom series - if not isinstance(data[0], List): + if not isinstance(data[0], list): dicom_data.append(self._combine_dicom_series(data)) # a list of list, each inner list represents a dicom series else: @@ -599,7 +602,7 @@ def get_data(self, data) -> Tuple[np.ndarray, Dict]: dicom_data.append(self._combine_dicom_series(series)) else: # a single pydicom dataset object - if not isinstance(data, List): + if not isinstance(data, list): data = [data] for d in data: if hasattr(d, "SegmentSequence"): @@ -610,10 +613,10 @@ def get_data(self, data) -> Tuple[np.ndarray, Dict]: metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape dicom_data.append((data_array, metadata)) - img_array: List[np.ndarray] = [] - compatible_meta: Dict = {} + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} - for (data_array, metadata) in ensure_tuple(dicom_data): + for data_array, metadata in ensure_tuple(dicom_data): img_array.append(np.ascontiguousarray(np.swapaxes(data_array, 0, 1) if self.swap_ij else data_array)) affine = self._get_affine(metadata, self.affine_lps_to_ras) metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS @@ -626,7 +629,7 @@ def get_data(self, data) -> Tuple[np.ndarray, Dict]: metadata[MetaKeys.AFFINE] = affine.copy() if self.channel_dim is None: # default to "no_channel" or -1 metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - "no_channel" if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1 + float("nan") if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1 ) else: metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim @@ -638,7 +641,7 @@ def get_data(self, data) -> Tuple[np.ndarray, Dict]: return _stack_images(img_array, compatible_meta), compatible_meta - def _get_meta_dict(self, img) -> Dict: + def _get_meta_dict(self, img) -> dict: """ Get all the metadata of the image and convert to dict type. @@ -664,7 +667,7 @@ def _get_meta_dict(self, img) -> Dict: return metadata # type: ignore - def _get_affine(self, metadata: Dict, lps_to_ras: bool = True): + def _get_affine(self, metadata: dict, lps_to_ras: bool = True): """ Get or construct the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. @@ -880,20 +883,20 @@ class NibabelReader(ImageReader): @deprecated_arg("dtype", since="1.0", msg_suffix="please modify dtype of the returned by ``get_data`` instead.") def __init__( self, - channel_dim: Optional[int] = None, + channel_dim: str | int | None = None, as_closest_canonical: bool = False, squeeze_non_spatial_dims: bool = False, dtype: DtypeLike = np.float32, **kwargs, ): super().__init__() - self.channel_dim = channel_dim + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.as_closest_canonical = as_closest_canonical self.squeeze_non_spatial_dims = squeeze_non_spatial_dims self.dtype = dtype # deprecated self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by Nibabel reader. @@ -905,7 +908,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: suffixes: Sequence[str] = ["nii", "nii.gz"] return has_nib and is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): """ Read image data from specified file or files, it can read a list of images and stack them together as multi-channel data in `get_data()`. @@ -918,7 +921,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py """ - img_: List[Nifti1Image] = [] + img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -929,7 +932,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): img_.append(img) return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img) -> Tuple[np.ndarray, Dict]: + def get_data(self, img) -> tuple[np.ndarray, dict]: """ Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. @@ -941,8 +944,8 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]: img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. """ - img_array: List[np.ndarray] = [] - compatible_meta: Dict = {} + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} for i in ensure_tuple(img): header = self._get_meta_dict(i) @@ -962,7 +965,7 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]: img_array.append(data) if self.channel_dim is None: # default to "no_channel" or -1 header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - "no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 ) else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim @@ -970,7 +973,7 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]: return _stack_images(img_array, compatible_meta), compatible_meta - def _get_meta_dict(self, img) -> Dict: + def _get_meta_dict(self, img) -> dict: """ Get the all the metadata of the image and convert to dict type. @@ -1015,8 +1018,8 @@ def _get_spatial_shape(self, img): dim = np.insert(dim, 0, 3) ndim = dim[0] size = list(dim[1:]) - if self.channel_dim is not None: - size.pop(self.channel_dim) + if not is_no_channel(self.channel_dim): + size.pop(int(self.channel_dim)) # type: ignore spatial_rank = max(min(ndim, 3), 1) return np.asarray(size[:spatial_rank]) @@ -1028,7 +1031,7 @@ def _get_array_data(self, img): img: a Nibabel image object loaded from an image file. """ - return np.asanyarray(img.dataobj) + return np.asanyarray(img.dataobj, order="C") class NumpyReader(ImageReader): @@ -1046,15 +1049,15 @@ class NumpyReader(ImageReader): """ - def __init__(self, npz_keys: Optional[KeysCollection] = None, channel_dim: Optional[int] = None, **kwargs): + def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: str | int | None = None, **kwargs): super().__init__() if npz_keys is not None: npz_keys = ensure_tuple(npz_keys) self.npz_keys = npz_keys - self.channel_dim = channel_dim + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by Numpy reader. @@ -1065,7 +1068,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: suffixes: Sequence[str] = ["npz", "npy"] return is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): """ Read image data from specified file or files, it can read a list of data files and stack them together as multi-channel data in `get_data()`. @@ -1078,7 +1081,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): https://numpy.org/doc/stable/reference/generated/numpy.load.html """ - img_: List[Nifti1Image] = [] + img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -1095,7 +1098,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): return img_ if len(img_) > 1 else img_[0] - def get_data(self, img) -> Tuple[np.ndarray, Dict]: + def get_data(self, img) -> tuple[np.ndarray, dict]: """ Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. @@ -1107,13 +1110,13 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]: img: a Numpy array loaded from a file or a list of Numpy arrays. """ - img_array: List[np.ndarray] = [] - compatible_meta: Dict = {} + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} if isinstance(img, np.ndarray): img = (img,) for i in ensure_tuple(img): - header: Dict[MetaKeys, Any] = {} + header: dict[MetaKeys, Any] = {} if isinstance(i, np.ndarray): # if `channel_dim` is None, can not detect the channel dim, use all the dims as spatial_shape spatial_shape = np.asarray(i.shape) @@ -1123,7 +1126,7 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]: header[MetaKeys.SPACE] = SpaceKeys.RAS img_array.append(i) header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - self.channel_dim if isinstance(self.channel_dim, int) else "no_channel" + self.channel_dim if isinstance(self.channel_dim, int) else float("nan") ) _copy_compatible_dict(header, compatible_meta) @@ -1138,16 +1141,20 @@ class PILReader(ImageReader): Args: converter: additional function to convert the image data after `read()`. for example, use `converter=lambda image: image.convert("LA")` to convert image format. + reverse_indexing: whether to swap axis 0 and 1 after loading the array, this is enabled by default, + so that output of the reader is consistent with the other readers. Set this option to ``False`` to use + the PIL backend's original spatial axes convention. kwargs: additional args for `Image.open` API in `read()`, mode details about available args: https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open """ - def __init__(self, converter: Optional[Callable] = None, **kwargs): + def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs): super().__init__() self.converter = converter + self.reverse_indexing = reverse_indexing self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by PIL reader. @@ -1158,7 +1165,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"] return has_pil and is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): """ Read image data from specified file or files, it can read a list of images and stack them together as multi-channel data in `get_data()`. @@ -1171,7 +1178,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open """ - img_: List[PILImage.Image] = [] + img_: list[PILImage.Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -1184,36 +1191,36 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img) -> Tuple[np.ndarray, Dict]: + def get_data(self, img) -> tuple[np.ndarray, dict]: """ Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. It computes `spatial_shape` and stores it in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the metadata of the first image is used to represent the output metadata. - Note that it will swap axis 0 and 1 after loading the array because the `HW` definition in PIL - is different from other common medical packages. + Note that by default `self.reverse_indexing` is set to ``True``, which swaps axis 0 and 1 after loading + the array because the spatial axes definition in PIL is different from other common medical packages. Args: img: a PIL Image object loaded from a file or a list of PIL Image objects. """ - img_array: List[np.ndarray] = [] - compatible_meta: Dict = {} + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} for i in ensure_tuple(img): header = self._get_meta_dict(i) header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) - data = np.moveaxis(np.asarray(i), 0, 1) + data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i) img_array.append(data) header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - "no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 ) _copy_compatible_dict(header, compatible_meta) return _stack_images(img_array, compatible_meta), compatible_meta - def _get_meta_dict(self, img) -> Dict: + def _get_meta_dict(self, img) -> dict: """ Get the all the metadata of the image and convert to dict type. Args: @@ -1266,12 +1273,12 @@ def _set_reader(backend: str): if backend == "openslide": return OpenSlide if backend == "cucim": - return CuImage + return optional_import("cucim", name="CuImage")[0] if backend == "tifffile": return TiffFile raise ValueError("`backend` should be 'cuCIM', 'OpenSlide' or 'TiffFile'.") - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -1281,7 +1288,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ return is_supported_format(filename, ["tif", "tiff"]) - def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): """ Read image data from given file or list of files. @@ -1297,7 +1304,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): image object or list of image objects """ - img_: List = [] + img_: list = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -1313,12 +1320,12 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): def get_data( self, img, - location: Tuple[int, int] = (0, 0), - size: Optional[Tuple[int, int]] = None, - level: Optional[int] = None, + location: tuple[int, int] = (0, 0), + size: tuple[int, int] | None = None, + level: int | None = None, dtype: DtypeLike = np.uint8, - grid_shape: Tuple[int, int] = (1, 1), - patch_size: Optional[Union[int, Tuple[int, int]]] = None, + grid_shape: tuple[int, int] = (1, 1), + patch_size: int | tuple[int, int] | None = None, ): """ Extract regions as numpy array from WSI image and return them. @@ -1345,7 +1352,7 @@ def get_data( region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype) # Add necessary metadata - metadata: Dict = {} + metadata: dict = {} metadata[MetaKeys.SPATIAL_SHAPE] = np.asarray(region.shape[:-1]) metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = -1 @@ -1403,8 +1410,8 @@ def _get_image_size(self, img, size, level, location): def _extract_region( self, img_obj, - size: Optional[Tuple[int, int]], - location: Tuple[int, int] = (0, 0), + size: tuple[int, int] | None, + location: tuple[int, int] = (0, 0), level: int = 0, dtype: DtypeLike = np.uint8, ): @@ -1464,8 +1471,8 @@ def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8): def _extract_patches( self, region: np.ndarray, - grid_shape: Tuple[int, int] = (1, 1), - patch_size: Optional[Tuple[int, int]] = None, + grid_shape: tuple[int, int] = (1, 1), + patch_size: tuple[int, int] | None = None, dtype: DtypeLike = np.uint8, ): if patch_size is None and grid_shape == (1, 1): @@ -1515,6 +1522,9 @@ class NrrdReader(ImageReader): dtype: dtype of the data array when loading image. index_order: Specify whether the returned data array should be in C-order (‘C’) or Fortran-order (‘F’). Numpy is usually in C-order, but default on the NRRD header is F + affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. + Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix is unmodified. + kwargs: additional args for `nrrd.read` API. more details about available args: https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py @@ -1522,17 +1532,19 @@ class NrrdReader(ImageReader): def __init__( self, - channel_dim: Optional[int] = None, - dtype: Union[np.dtype, type, str, None] = np.float32, + channel_dim: str | int | None = None, + dtype: np.dtype | type | str | None = np.float32, index_order: str = "F", + affine_lps_to_ras: bool = True, **kwargs, ): - self.channel_dim = channel_dim + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.dtype = dtype self.index_order = index_order + self.affine_lps_to_ras = affine_lps_to_ras self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified `filename` is supported by pynrrd reader. @@ -1544,7 +1556,7 @@ def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: suffixes: Sequence[str] = ["nrrd", "seg.nrrd"] return has_nrrd and is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Sequence[Any], Any]: + def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any: """ Read image data from specified file or files. Note that it returns a data object or a sequence of data objects. @@ -1554,7 +1566,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Seq kwargs: additional args for actual `read` API of 3rd party libs. """ - img_: List = [] + img_: list = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) @@ -1563,7 +1575,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Seq img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, Dict]: + def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: """ Extract data array and metadata from loaded image and return them. This function must return two objects, the first is a numpy array of image data, @@ -1573,8 +1585,8 @@ def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, img: a `NrrdImage` loaded from an image file or a list of image objects. """ - img_array: List[np.ndarray] = [] - compatible_meta: Dict = {} + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} for i in ensure_tuple(img): data = i.array.astype(self.dtype) @@ -1583,14 +1595,17 @@ def get_data(self, img: Union[NrrdImage, List[NrrdImage]]) -> Tuple[np.ndarray, if self.index_order == "C": header = self._convert_f_to_c_order(header) header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) - header = self._switch_lps_ras(header) + + if self.affine_lps_to_ras: + header = self._switch_lps_ras(header) + header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy() header[MetaKeys.SPATIAL_SHAPE] = header["sizes"] [header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header if self.channel_dim is None: # default to "no_channel" or -1 header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - "no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0 + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0 ) else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 8d42d032c9..0274c44900 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence, Union, cast +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, cast import numpy as np @@ -58,7 +61,7 @@ "logger", ] -SUPPORTED_WRITERS: Dict = {} +SUPPORTED_WRITERS: dict = {} def register_writer(ext_name, *im_writers): @@ -178,14 +181,14 @@ def __init__(self, **kwargs): The current member in the base class is ``self.data_obj``, the subclasses can add more members, so that necessary meta information can be stored in the object and shared among the class methods. """ - self.data_obj: Union[Any, NdarrayOrTensor] = None + self.data_obj: Any | NdarrayOrTensor = None for k, v in kwargs.items(): setattr(self, k, v) def set_data_array(self, data_array, **kwargs): raise NotImplementedError(f"Subclasses of {self.__class__.__name__} must implement this method.") - def set_metadata(self, meta_dict: Optional[Mapping], **options): + def set_metadata(self, meta_dict: Mapping | None, **options): raise NotImplementedError(f"Subclasses of {self.__class__.__name__} must implement this method.") def write(self, filename: PathLike, verbose: bool = True, **kwargs): @@ -205,9 +208,9 @@ def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray def resample_if_needed( cls, data_array: NdarrayOrTensor, - affine: Optional[NdarrayOrTensor] = None, - target_affine: Optional[NdarrayOrTensor] = None, - output_spatial_shape: Union[Sequence[int], int, None] = None, + affine: NdarrayOrTensor | None = None, + target_affine: NdarrayOrTensor | None = None, + output_spatial_shape: Sequence[int] | int | None = None, mode: str = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, @@ -267,7 +270,9 @@ def resample_if_needed( if affine is not None: data_array.affine = convert_to_tensor(affine, track_meta=False) # type: ignore resampler = SpatialResample(mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype) - output_array = resampler(data_array[None], dst_affine=target_affine, spatial_size=output_spatial_shape) + output_array = resampler( + data_array[None], dst_affine=target_affine, spatial_size=output_spatial_shape # type: ignore + ) # convert back at the end if isinstance(output_array, MetaTensor): output_array.applied_operations = [] @@ -279,9 +284,9 @@ def resample_if_needed( def convert_to_channel_last( cls, data: NdarrayOrTensor, - channel_dim: Union[None, int, Sequence[int]] = 0, + channel_dim: None | int | Sequence[int] = 0, squeeze_end_dims: bool = True, - spatial_ndim: Optional[int] = 3, + spatial_ndim: int | None = 3, contiguous: bool = False, ): """ @@ -325,7 +330,7 @@ def convert_to_channel_last( return data @classmethod - def get_meta_info(cls, metadata: Optional[Mapping] = None): + def get_meta_info(cls, metadata: Mapping | None = None): """ Extracts relevant meta information from the metadata object (using ``.get``). Optional keys are ``"spatial_shape"``, ``MetaKeys.AFFINE``, ``"original_affine"``. @@ -366,7 +371,7 @@ class ITKWriter(ImageWriter): """ output_dtype: DtypeLike = None - channel_dim: Optional[int] + channel_dim: int | None def __init__(self, output_dtype: DtypeLike = np.float32, affine_lps_to_ras: bool = True, **kwargs): """ @@ -388,7 +393,7 @@ def __init__(self, output_dtype: DtypeLike = np.float32, affine_lps_to_ras: bool ) def set_data_array( - self, data_array: NdarrayOrTensor, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs + self, data_array: NdarrayOrTensor, channel_dim: int | None = 0, squeeze_end_dims: bool = True, **kwargs ): """ Convert ``data_array`` into 'channel-last' numpy ndarray. @@ -413,7 +418,7 @@ def set_data_array( channel_dim if self.data_obj is not None and len(self.data_obj.shape) >= _r else None ) # channel dim is at the end - def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): + def set_metadata(self, meta_dict: Mapping | None = None, resample: bool = True, **options): """ Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. @@ -470,8 +475,8 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): def create_backend_obj( cls, data_array: NdarrayOrTensor, - channel_dim: Optional[int] = 0, - affine: Optional[NdarrayOrTensor] = None, + channel_dim: int | None = 0, + affine: NdarrayOrTensor | None = None, dtype: DtypeLike = np.float32, affine_lps_to_ras: bool = True, **kwargs, @@ -551,7 +556,7 @@ def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs): super().__init__(output_dtype=output_dtype, affine=None, **kwargs) def set_data_array( - self, data_array: NdarrayOrTensor, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs + self, data_array: NdarrayOrTensor, channel_dim: int | None = 0, squeeze_end_dims: bool = True, **kwargs ): """ Convert ``data_array`` into 'channel-last' numpy ndarray. @@ -571,7 +576,7 @@ def set_data_array( spatial_ndim=kwargs.pop("spatial_ndim", 3), ) - def set_metadata(self, meta_dict: Optional[Mapping], resample: bool = True, **options): + def set_metadata(self, meta_dict: Mapping | None, resample: bool = True, **options): """ Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. @@ -626,7 +631,7 @@ def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs): @classmethod def create_backend_obj( - cls, data_array: NdarrayOrTensor, affine: Optional[NdarrayOrTensor] = None, dtype: DtypeLike = None, **kwargs + cls, data_array: NdarrayOrTensor, affine: NdarrayOrTensor | None = None, dtype: DtypeLike = None, **kwargs ): """ Create an Nifti1Image object from ``data_array``. This method assumes a 'channel-last' ``data_array``. @@ -678,11 +683,11 @@ class PILWriter(ImageWriter): """ output_dtype: DtypeLike - channel_dim: Optional[int] - scale: Optional[int] + channel_dim: int | None + scale: int | None def __init__( - self, output_dtype: DtypeLike = np.float32, channel_dim: Optional[int] = 0, scale: Optional[int] = 255, **kwargs + self, output_dtype: DtypeLike = np.float32, channel_dim: int | None = 0, scale: int | None = 255, **kwargs ): """ Args: @@ -698,7 +703,7 @@ def __init__( def set_data_array( self, data_array: NdarrayOrTensor, - channel_dim: Optional[int] = 0, + channel_dim: int | None = 0, squeeze_end_dims: bool = True, contiguous: bool = False, **kwargs, @@ -723,7 +728,7 @@ def set_data_array( contiguous=contiguous, ) - def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): + def set_metadata(self, meta_dict: Mapping | None = None, resample: bool = True, **options): """ Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. @@ -769,14 +774,14 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): self.data_obj.save(filename, **kwargs) @classmethod - def get_meta_info(cls, metadata: Optional[Mapping] = None): + def get_meta_info(cls, metadata: Mapping | None = None): return None if not metadata else metadata.get(MetaKeys.SPATIAL_SHAPE) @classmethod def resample_and_clip( cls, data_array: NdarrayOrTensor, - output_spatial_shape: Optional[Sequence[int]] = None, + output_spatial_shape: Sequence[int] | None = None, mode: str = InterpolateMode.BICUBIC, ) -> np.ndarray: """ @@ -810,7 +815,7 @@ def create_backend_obj( cls, data_array: NdarrayOrTensor, dtype: DtypeLike = None, - scale: Optional[int] = 255, + scale: int | None = 255, reverse_indexing: bool = True, **kwargs, ): diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index 1d6b4b06b2..4c476b2f9d 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Callable, Iterable, Iterator, Sequence +from typing import Any from torch.utils.data import IterableDataset as _TorchIterableDataset from torch.utils.data import get_worker_info @@ -17,7 +20,7 @@ from monai.data.utils import convert_tables_to_dicts from monai.transforms import apply_transform from monai.transforms.transform import Randomizable -from monai.utils import deprecated_arg, optional_import +from monai.utils import optional_import pd, _ = optional_import("pandas") @@ -37,7 +40,7 @@ class IterableDataset(_TorchIterableDataset): """ - def __init__(self, data: Iterable[Any], transform: Optional[Callable] = None) -> None: + def __init__(self, data: Iterable[Any], transform: Callable | None = None) -> None: """ Args: data: input data source to load and transform to generate dataset for model. @@ -45,7 +48,7 @@ def __init__(self, data: Iterable[Any], transform: Optional[Callable] = None) -> """ self.data = data self.transform = transform - self.source: Optional[Iterator[Any]] = None + self.source: Iterator[Any] | None = None def __iter__(self): info = get_worker_info() @@ -110,7 +113,7 @@ def randomized_pop(self, buffer): def generate_item(self): """Fill a `buffer` list up to `self.size`, then generate randomly popped items.""" - buffer: List[Any] = [] + buffer: list[Any] = [] for item in iter(self.data): if len(buffer) >= self.size: yield self.randomized_pop(buffer) @@ -197,19 +200,18 @@ class CSVIterableDataset(IterableDataset): """ - @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.") def __init__( self, - src: Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]], + src: str | Sequence[str] | Iterable | Sequence[Iterable], chunksize: int = 1000, - buffer_size: Optional[int] = None, - col_names: Optional[Sequence[str]] = None, - col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None, - col_groups: Optional[Dict[str, Sequence[str]]] = None, - transform: Optional[Callable] = None, + buffer_size: int | None = None, + col_names: Sequence[str] | None = None, + col_types: dict[str, dict[str, Any] | None] | None = None, + col_groups: dict[str, Sequence[str]] | None = None, + transform: Callable | None = None, shuffle: bool = False, seed: int = 0, - kwargs_read_csv: Optional[Dict] = None, + kwargs_read_csv: dict | None = None, **kwargs, ): self.src = src @@ -225,11 +227,10 @@ def __init__( kwargs.pop("filename", None) self.kwargs = kwargs - self.iters: List[Iterable] = self.reset() + self.iters: list[Iterable] = self.reset() super().__init__(data=None, transform=transform) # type: ignore - @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.") - def reset(self, src: Optional[Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]]] = None): + def reset(self, src: str | Sequence[str] | Iterable | Sequence[Iterable] | None = None): """ Reset the pandas `TextFileReader` iterable object to read data. For more details, please check: https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration. diff --git a/monai/data/itk_torch_bridge.py b/monai/data/itk_torch_bridge.py new file mode 100644 index 0000000000..3dc25ad0bd --- /dev/null +++ b/monai/data/itk_torch_bridge.py @@ -0,0 +1,338 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +import numpy as np +import torch + +from monai.config.type_definitions import DtypeLike +from monai.data import ITKReader, ITKWriter +from monai.data.meta_tensor import MetaTensor +from monai.transforms import EnsureChannelFirst +from monai.utils import convert_to_dst_type, optional_import + +if TYPE_CHECKING: + import itk + + has_itk = True +else: + itk, has_itk = optional_import("itk") + +__all__ = [ + "itk_image_to_metatensor", + "metatensor_to_itk_image", + "itk_to_monai_affine", + "monai_to_itk_affine", + "get_itk_image_center", + "monai_to_itk_ddf", +] + + +def itk_image_to_metatensor( + image, channel_dim: str | int | None = None, dtype: DtypeLike | torch.dtype = float +) -> MetaTensor: + """ + Converts an ITK image to a MetaTensor object. + + Args: + image: The ITK image to be converted. + channel_dim: the channel dimension of the input image, default is None. + This is used to set original_channel_dim in the metadata, EnsureChannelFirst reads this field. + If None, the channel_dim is inferred automatically. + If the input array doesn't have a channel dim, this value should be ``'no_channel'``. + dtype: output dtype, defaults to the Python built-in `float`. + + Returns: + A MetaTensor object containing the array data and metadata in ChannelFirst format. + """ + reader = ITKReader(affine_lps_to_ras=False, channel_dim=channel_dim) + image_array, meta_data = reader.get_data(image) + image_array = convert_to_dst_type(image_array, dst=image_array, dtype=dtype)[0] + metatensor = MetaTensor.ensure_torch_and_prune_meta(image_array, meta_data) + metatensor = EnsureChannelFirst(channel_dim=channel_dim)(metatensor) + + return cast(MetaTensor, metatensor) + + +def metatensor_to_itk_image( + meta_tensor: MetaTensor, channel_dim: int | None = 0, dtype: DtypeLike = np.float32, **kwargs +): + """ + Converts a MetaTensor object to an ITK image. Expects the MetaTensor to be in ChannelFirst format. + + Args: + meta_tensor: The MetaTensor to be converted. + channel_dim: channel dimension of the data array, defaults to ``0`` (Channel-first). + ``None`` indicates no channel dimension. This is used to create a Vector Image if it is not ``None``. + dtype: output data type, defaults to `np.float32`. + kwargs: additional keyword arguments. Currently `itk.GetImageFromArray` will get ``ttype`` from this dictionary. + + Returns: + The ITK image. + + See also: :py:func:`ITKWriter.create_backend_obj` + """ + writer = ITKWriter(output_dtype=dtype, affine_lps_to_ras=False) + writer.set_data_array(data_array=meta_tensor.data, channel_dim=channel_dim, squeeze_end_dims=True) + return writer.create_backend_obj( + writer.data_obj, + channel_dim=writer.channel_dim, + affine=meta_tensor.affine, + affine_lps_to_ras=False, # False if the affine is in itk convention + dtype=writer.output_dtype, + kwargs=kwargs, + ) + + +def itk_to_monai_affine(image, matrix, translation, center_of_rotation=None, reference_image=None) -> torch.Tensor: + """ + Converts an ITK affine matrix (2x2 for 2D or 3x3 for 3D matrix and translation vector) to a MONAI affine matrix. + + Args: + image: The ITK image object. This is used to extract the spacing and direction information. + matrix: The 2x2 or 3x3 ITK affine matrix. + translation: The 2-element or 3-element ITK affine translation vector. + center_of_rotation: The center of rotation. If provided, the affine + matrix will be adjusted to account for the difference + between the center of the image and the center of rotation. + reference_image: The coordinate space that matrix and translation were defined + in respect to. If not supplied, the coordinate space of image + is used. + + Returns: + A 4x4 MONAI affine matrix. + """ + + _assert_itk_regions_match_array(image) + ndim = image.ndim + # If there is a reference image, compute an affine matrix that maps the image space to the + # reference image space. + if reference_image: + reference_affine_matrix = _compute_reference_space_affine_matrix(image, reference_image) + else: + reference_affine_matrix = torch.eye(ndim + 1, dtype=torch.float64) + + # Create affine matrix that includes translation + affine_matrix = torch.eye(ndim + 1, dtype=torch.float64) + affine_matrix[:ndim, :ndim] = torch.tensor(matrix, dtype=torch.float64) + affine_matrix[:ndim, ndim] = torch.tensor(translation, dtype=torch.float64) + + # Adjust offset when center of rotation is different from center of the image + if center_of_rotation: + offset_matrix, inverse_offset_matrix = _compute_offset_matrix(image, center_of_rotation) + affine_matrix = inverse_offset_matrix @ affine_matrix @ offset_matrix + + # Adjust direction + direction_matrix, inverse_direction_matrix = _compute_direction_matrix(image) + affine_matrix = inverse_direction_matrix @ affine_matrix @ direction_matrix + + # Adjust based on spacing. It is required because MONAI does not update the + # pixel array according to the spacing after a transformation. For example, + # a rotation of 90deg for an image with different spacing along the two axis + # will just rotate the image array by 90deg without also scaling accordingly. + spacing_matrix, inverse_spacing_matrix = _compute_spacing_matrix(image) + affine_matrix = inverse_spacing_matrix @ affine_matrix @ spacing_matrix + + return affine_matrix @ reference_affine_matrix + + +def monai_to_itk_affine(image, affine_matrix, center_of_rotation=None): + """ + Converts a MONAI affine matrix to an ITK affine matrix (2x2 for 2D or 3x3 for + 3D matrix and translation vector). See also 'itk_to_monai_affine'. + + Args: + image: The ITK image object. This is used to extract the spacing and direction information. + affine_matrix: The 3x3 for 2D or 4x4 for 3D MONAI affine matrix. + center_of_rotation: The center of rotation. If provided, the affine + matrix will be adjusted to account for the difference + between the center of the image and the center of rotation. + + Returns: + The ITK matrix and the translation vector. + """ + _assert_itk_regions_match_array(image) + + # Adjust spacing + spacing_matrix, inverse_spacing_matrix = _compute_spacing_matrix(image) + affine_matrix = spacing_matrix @ affine_matrix @ inverse_spacing_matrix + + # Adjust direction + direction_matrix, inverse_direction_matrix = _compute_direction_matrix(image) + affine_matrix = direction_matrix @ affine_matrix @ inverse_direction_matrix + + # Adjust offset when center of rotation is different from center of the image + if center_of_rotation: + offset_matrix, inverse_offset_matrix = _compute_offset_matrix(image, center_of_rotation) + affine_matrix = offset_matrix @ affine_matrix @ inverse_offset_matrix + + ndim = image.ndim + matrix = affine_matrix[:ndim, :ndim].numpy() + translation = affine_matrix[:ndim, ndim].tolist() + + return matrix, translation + + +def get_itk_image_center(image): + """ + Calculates the center of the ITK image based on its origin, size, and spacing. + This center is equivalent to the implicit image center that MONAI uses. + + Args: + image: The ITK image. + + Returns: + The center of the image as a list of coordinates. + """ + image_size = np.asarray(image.GetLargestPossibleRegion().GetSize(), np.float32) + spacing = np.asarray(image.GetSpacing()) + origin = np.asarray(image.GetOrigin()) + center = image.GetDirection() @ ((image_size / 2 - 0.5) * spacing) + origin + + return center.tolist() + + +def _assert_itk_regions_match_array(image): + # Note: Make it more compact? Also, are there redundant checks? + largest_region = image.GetLargestPossibleRegion() + buffered_region = image.GetBufferedRegion() + requested_region = image.GetRequestedRegion() + + largest_region_size = np.array(largest_region.GetSize()) + buffered_region_size = np.array(buffered_region.GetSize()) + requested_region_size = np.array(requested_region.GetSize()) + array_size = np.array(image.shape)[::-1] + + largest_region_index = np.array(largest_region.GetIndex()) + buffered_region_index = np.array(buffered_region.GetIndex()) + requested_region_index = np.array(requested_region.GetIndex()) + + indices_are_zeros = ( + np.all(largest_region_index == 0) and np.all(buffered_region_index == 0) and np.all(requested_region_index == 0) + ) + + sizes_match = ( + np.array_equal(array_size, largest_region_size) + and np.array_equal(largest_region_size, buffered_region_size) + and np.array_equal(buffered_region_size, requested_region_size) + ) + + if not indices_are_zeros: + raise AssertionError("ITK-MONAI bridge: non-zero ITK region indices encountered") + if not sizes_match: + raise AssertionError("ITK-MONAI bridge: ITK regions should be of the same shape") + + +def _compute_offset_matrix(image, center_of_rotation) -> tuple[torch.Tensor, torch.Tensor]: + ndim = image.ndim + offset = np.asarray(get_itk_image_center(image)) - np.asarray(center_of_rotation) + offset_matrix = torch.eye(ndim + 1, dtype=torch.float64) + offset_matrix[:ndim, ndim] = torch.tensor(offset, dtype=torch.float64) + inverse_offset_matrix = torch.eye(ndim + 1, dtype=torch.float64) + inverse_offset_matrix[:ndim, ndim] = -torch.tensor(offset, dtype=torch.float64) + + return offset_matrix, inverse_offset_matrix + + +def _compute_spacing_matrix(image) -> tuple[torch.Tensor, torch.Tensor]: + ndim = image.ndim + spacing = np.asarray(image.GetSpacing(), dtype=np.float64) + spacing_matrix = torch.eye(ndim + 1, dtype=torch.float64) + inverse_spacing_matrix = torch.eye(ndim + 1, dtype=torch.float64) + for i, e in enumerate(spacing): + spacing_matrix[i, i] = e + inverse_spacing_matrix[i, i] = 1 / e + + return spacing_matrix, inverse_spacing_matrix + + +def _compute_direction_matrix(image) -> tuple[torch.Tensor, torch.Tensor]: + ndim = image.ndim + direction = itk.array_from_matrix(image.GetDirection()) + direction_matrix = torch.eye(ndim + 1, dtype=torch.float64) + direction_matrix[:ndim, :ndim] = torch.tensor(direction, dtype=torch.float64) + inverse_direction = itk.array_from_matrix(image.GetInverseDirection()) + inverse_direction_matrix = torch.eye(ndim + 1, dtype=torch.float64) + inverse_direction_matrix[:ndim, :ndim] = torch.tensor(inverse_direction, dtype=torch.float64) + + return direction_matrix, inverse_direction_matrix + + +def _compute_reference_space_affine_matrix(image, ref_image) -> torch.Tensor: + ndim = ref_image.ndim + + # Spacing and direction as matrices + spacing_matrix, inv_spacing_matrix = (m[:ndim, :ndim].numpy() for m in _compute_spacing_matrix(image)) + ref_spacing_matrix, ref_inv_spacing_matrix = (m[:ndim, :ndim].numpy() for m in _compute_spacing_matrix(ref_image)) + + direction_matrix, inv_direction_matrix = (m[:ndim, :ndim].numpy() for m in _compute_direction_matrix(image)) + ref_direction_matrix, ref_inv_direction_matrix = ( + m[:ndim, :ndim].numpy() for m in _compute_direction_matrix(ref_image) + ) + + # Matrix calculation + matrix = ref_direction_matrix @ ref_spacing_matrix @ inv_spacing_matrix @ inv_direction_matrix + + # Offset calculation + pixel_offset = -1 + image_size = np.asarray(ref_image.GetLargestPossibleRegion().GetSize(), np.float32) + translation = ( + (ref_direction_matrix @ ref_spacing_matrix - direction_matrix @ spacing_matrix) + @ (image_size + pixel_offset) + / 2 + ) + translation += np.asarray(ref_image.GetOrigin()) - np.asarray(image.GetOrigin()) + + # Convert matrix ITK matrix and translation to MONAI affine matrix + ref_affine_matrix = itk_to_monai_affine(image, matrix=matrix, translation=translation) + + return ref_affine_matrix + + +def monai_to_itk_ddf(image, ddf): + """ + converting the dense displacement field from the MONAI space to the ITK + Args: + image: itk image of array shape 2D: (H, W) or 3D: (D, H, W) + ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W) + Returns: + displacement_field: itk image of the corresponding displacement field + + """ + # 3, D, H, W -> D, H, W, 3 + ndim = image.ndim + ddf = ddf.transpose(tuple(list(range(1, ndim + 1)) + [0])) + # x, y, z -> z, x, y + ddf = ddf[..., ::-1] + + # Correct for spacing + spacing = np.asarray(image.GetSpacing(), dtype=np.float64) + ddf *= np.array(spacing, ndmin=ndim + 1) + + # Correct for direction + direction = np.asarray(image.GetDirection(), dtype=np.float64) + ddf = np.einsum("ij,...j->...i", direction, ddf, dtype=np.float64).astype(np.float32) + + # initialise displacement field - + vector_component_type = itk.F + vector_pixel_type = itk.Vector[vector_component_type, ndim] + displacement_field_type = itk.Image[vector_pixel_type, ndim] + displacement_field = itk.GetImageFromArray(ddf, ttype=displacement_field_type) + + # Set image metadata + displacement_field.SetSpacing(image.GetSpacing()) + displacement_field.SetOrigin(image.GetOrigin()) + displacement_field.SetDirection(image.GetDirection()) + + return displacement_field diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 67f4109c86..86ce7e33fb 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -19,8 +19,7 @@ import numpy as np import torch -from monai.utils.enums import TraceKeys -from monai.utils.misc import first +from monai.utils import TraceKeys, first, is_immutable _TRACK_META = True @@ -107,13 +106,15 @@ def flatten_meta_objs(*args: Iterable): @staticmethod def copy_items(data): """returns a copy of the data. list and dict are shallow copied for efficiency purposes.""" + if is_immutable(data): + return data if isinstance(data, (list, dict, np.ndarray)): return data.copy() if isinstance(data, torch.Tensor): return data.detach().clone() return deepcopy(data) - def copy_meta_from(self, input_objs, copy_attr=True) -> None: + def copy_meta_from(self, input_objs, copy_attr=True, keys=None): """ Copy metadata from a `MetaObj` or an iterable of `MetaObj` instances. @@ -121,13 +122,19 @@ def copy_meta_from(self, input_objs, copy_attr=True) -> None: input_objs: list of `MetaObj` to copy data from. copy_attr: whether to copy each attribute with `MetaObj.copy_item`. note that if the attribute is a nested list or dict, only a shallow copy will be done. + keys: the keys of attributes to copy from the ``input_objs``. + If None, all keys from the input_objs will be copied. """ first_meta = input_objs if isinstance(input_objs, MetaObj) else first(input_objs, default=self) + if not hasattr(first_meta, "__dict__"): + return self first_meta = first_meta.__dict__ + keys = first_meta.keys() if keys is None else keys if not copy_attr: - self.__dict__ = first_meta.copy() # shallow copy for performance + self.__dict__ = {a: first_meta[a] for a in keys if a in first_meta} # shallow copy for performance else: - self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in first_meta}) + self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in keys if a in first_meta}) + return self @staticmethod def get_default_meta() -> dict: diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index f6b93cab2c..3bbd243b4a 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -25,7 +25,7 @@ from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata from monai.utils import look_up_option from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys -from monai.utils.type_conversion import convert_data_type, convert_to_numpy, convert_to_tensor +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor __all__ = ["MetaTensor"] @@ -235,11 +235,19 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: # respectively. Don't need to do anything with the metadata. if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0: ret_meta = decollate_batch(args[0], detach=False)[batch_idx] - if isinstance(ret_meta, list): # e.g. batch[0:2], re-collate - ret_meta = list_data_collate(ret_meta) - else: # e.g. `batch[0]` or `batch[0, 1]`, batch index is an integer + if isinstance(ret_meta, list) and ret_meta: # e.g. batch[0:2], re-collate + try: + ret_meta = list_data_collate(ret_meta) + except (TypeError, ValueError, RuntimeError, IndexError) as e: + raise ValueError( + "Inconsistent batched metadata dicts when slicing a batch of MetaTensors, " + "please convert it into a torch Tensor using `x.as_tensor()` or " + "a numpy array using `x.array`." + ) from e + elif isinstance(ret_meta, MetaObj): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int ret_meta.is_batch = False - ret.__dict__ = ret_meta.__dict__.copy() + if hasattr(ret_meta, "__dict__"): + ret.__dict__ = ret_meta.__dict__.copy() # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. # But we only want to split the batch if the `unbind` is along the 0th # dimension. @@ -453,7 +461,7 @@ def affine(self) -> torch.Tensor: @affine.setter def affine(self, d: NdarrayTensor) -> None: """Set the affine.""" - self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu")) + self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) @property def pixdim(self): @@ -463,7 +471,10 @@ def pixdim(self): return affine_to_spacing(self.affine) def peek_pending_shape(self): - """Get the currently expected spatial shape as if all the pending operations are executed.""" + """ + Get the currently expected spatial shape as if all the pending operations are executed. + For tensors that have more than 3 spatial dimensions, only the shapes of the top 3 dimensions will be returned. + """ res = None if self.pending_operations: res = self.pending_operations[-1].get(LazyAttr.SHAPE, None) @@ -471,10 +482,22 @@ def peek_pending_shape(self): return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res def peek_pending_affine(self): - res = None - if self.pending_operations: - res = self.pending_operations[-1].get(LazyAttr.AFFINE, None) - return self.affine if res is None else res + res = self.affine + r = len(res) - 1 + if r not in (2, 3): + warnings.warn(f"Only 2d and 3d affine are supported, got {r}d input.") + for p in self.pending_operations: + next_matrix = convert_to_tensor(p.get(LazyAttr.AFFINE), dtype=torch.float64) + if next_matrix is None: + continue + res = convert_to_dst_type(res, next_matrix)[0] + next_matrix = monai.data.utils.to_affine_nd(r, next_matrix) + res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix) + return res + + def peek_pending_rank(self): + a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine + return 1 if a is None else int(max(1, len(a) - 1)) def new_empty(self, size, dtype=None, device=None, requires_grad=False): """ @@ -495,15 +518,15 @@ def clone(self): @staticmethod def ensure_torch_and_prune_meta( - im: NdarrayTensor, meta: dict, simple_keys: bool = False, pattern: str | None = None, sep: str = "." + im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "." ): """ - Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary, + Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary, convert that to `torch.Tensor`, too. Remove any superfluous metadata. Args: im: Input image (`np.ndarray` or `torch.Tensor`) - meta: Metadata dictionary. + meta: Metadata dictionary. When it's None, the metadata is not tracked, this method returns a torch.Tensor. simple_keys: whether to keep only a simple subset of metadata keys. pattern: combined with `sep`, a regular expression used to match and prune keys in the metadata (nested dictionary), default to None, no key deletion. @@ -513,14 +536,17 @@ def ensure_torch_and_prune_meta( Returns: By default, a `MetaTensor` is returned. - However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned. + However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned. """ - img = convert_to_tensor(im) # potentially ascontiguousarray + img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray # if not tracking metadata, return `torch.Tensor` - if not get_track_meta() or meta is None: + if not isinstance(img, MetaTensor): return img + if meta is None: + meta = {} + # remove any superfluous metadata. if simple_keys: # ensure affine is of type `torch.Tensor` @@ -532,7 +558,14 @@ def ensure_torch_and_prune_meta( meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta) # return the `MetaTensor` - return MetaTensor(img, meta=meta) + if meta is None: + meta = {} + img.meta = meta + if MetaKeys.AFFINE in meta: + img.affine = meta[MetaKeys.AFFINE] # this uses the affine property setter + else: + img.affine = MetaTensor.get_default_affine() + return img def __repr__(self): """ diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py deleted file mode 100644 index ddc5e10f63..0000000000 --- a/monai/data/nifti_saver.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Optional, Union - -import numpy as np -import torch - -from monai.config import DtypeLike, PathLike -from monai.data.nifti_writer import write_nifti -from monai.data.utils import create_file_basename -from monai.utils import GridSampleMode, GridSamplePadMode -from monai.utils import ImageMetaKey as Key -from monai.utils import deprecated - - -@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.") -class NiftiSaver: - """ - Save the data as NIfTI file, it can support single data content or a batch of data. - Typically, the data can be segmentation predictions, call `save` for single data - or call `save_batch` to save a batch of data together. - The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, - where the input image name is extracted from the provided metadata dictionary. - If no metadata provided, use index from 0 as the filename prefix. - - Note: image should include channel dimension: [B],C,H,W,[D]. - - .. deprecated:: 0.8 - Use :py:class:`monai.transforms.SaveImage` instead. - - """ - - def __init__( - self, - output_dir: PathLike = "./", - output_postfix: str = "seg", - output_ext: str = ".nii.gz", - resample: bool = True, - mode: str = GridSampleMode.BILINEAR, - padding_mode: str = GridSamplePadMode.BORDER, - align_corners: bool = False, - dtype: DtypeLike = np.float64, - output_dtype: DtypeLike = np.float32, - squeeze_end_dims: bool = True, - data_root_dir: PathLike = "", - separate_folder: bool = True, - print_log: bool = True, - ) -> None: - """ - Args: - output_dir: output image directory. - output_postfix: a string appended to all output file names. - output_ext: output file extension name. - resample: whether to convert the data array to it's original coordinate system - based on `original_affine` in the `meta_data`. - mode: {``"bilinear"``, ``"nearest"``} - This option is used when ``resample = True``. - Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} - This option is used when ``resample = True``. - Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. - If None, use the data type of input data. - output_dtype: data type for saving data. Defaults to ``np.float32``. - squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel - has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and - then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, - image will always be saved as (H,W,D,C). - data_root_dir: if not empty, it specifies the beginning parts of the input file's - absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from - `data_root_dir` to preserve folder structure when saving in case there are files in different - folders with the same file names. for example: - input_file_name: /foo/bar/test1/image.nii, - postfix: seg - output_ext: nii.gz - output_dir: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg.nii.gz - separate_folder: whether to save every file in a separate folder, for example: if input filename is - `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: - `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. - print_log: whether to print log about the saved NIfTI file path, etc. default to `True`. - - """ - self.output_dir = output_dir - self.output_postfix = output_postfix - self.output_ext = output_ext - self.resample = resample - self.mode: str = GridSampleMode(mode) - self.padding_mode: str = GridSamplePadMode(padding_mode) - self.align_corners = align_corners - self.dtype = dtype - self.output_dtype = output_dtype - self._data_index = 0 - self.squeeze_end_dims = squeeze_end_dims - self.data_root_dir = data_root_dir - self.separate_folder = separate_folder - self.print_log = print_log - - def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: - """ - Save data into a NIfTI file. - The meta_data could optionally have the following keys: - - - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object. - - ``'original_affine'`` -- for data orientation handling, defaulting to an identity matrix. - - ``'affine'`` -- for data output affine, defaulting to an identity matrix. - - ``'spatial_shape'`` -- for data output shape. - - ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename. - - When meta_data is specified and `resample=True`, the saver will try to resample batch data from the space - defined by "affine" to the space defined by "original_affine". - - If meta_data is None, use the default index (starting from 0) as the filename. - - Args: - data: target data content that to be saved as a NIfTI format file. - Assuming the data shape starts with a channel dimension and followed by spatial dimensions. - meta_data: the metadata information corresponding to the data. - - See Also - :py:meth:`monai.data.nifti_writer.write_nifti` - """ - filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) - self._data_index += 1 - original_affine = meta_data.get("original_affine", None) if meta_data and self.resample else None - affine = meta_data.get("affine", None) if meta_data else None - spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None - patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None - - if isinstance(data, torch.Tensor): - data = data.detach().cpu().numpy() - - path = create_file_basename( - postfix=self.output_postfix, - input_file_name=filename, - folder_path=self.output_dir, - data_root_dir=self.data_root_dir, - separate_folder=self.separate_folder, - patch_index=patch_index, - ) - path = f"{path}{self.output_ext}" - # change data shape to be (channel, h, w, d) - while len(data.shape) < 4: - data = np.expand_dims(data, -1) - # change data to "channel last" format and write to NIfTI format file - data = np.moveaxis(np.asarray(data), 0, -1) - - # if desired, remove trailing singleton dimensions - if self.squeeze_end_dims: - while data.shape[-1] == 1: - data = np.squeeze(data, -1) - - write_nifti( - data, - file_name=path, - affine=affine, - target_affine=original_affine, - resample=True, - output_spatial_shape=spatial_shape, - mode=self.mode, - padding_mode=self.padding_mode, - align_corners=self.align_corners, - dtype=self.dtype, - output_dtype=self.output_dtype, - ) - - if self.print_log: - print(f"file written: {path}.") - - def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: - """ - Save a batch of data into NIfTI format files. - - Spatially it supports up to three dimensions, that is, H, HW, HWD for - 1D, 2D, 3D respectively (with resampling supports for 2D and 3D only). - - When saving multiple time steps or multiple channels `batch_data`, - time and/or modality axes should be appended after the batch dimensions. - For example, the shape of a batch of 2D eight-class - segmentation probabilities to be saved could be `(batch, 8, 64, 64)`; - in this case each item in the batch will be saved as (64, 64, 1, 8) - NIfTI file (the third dimension is reserved as a spatial dimension). - - Args: - batch_data: target batch data content that save into NIfTI format. - meta_data: every key-value in the meta_data is corresponding to a batch of data. - - """ - for i, data in enumerate(batch_data): # save a batch of files - self.save(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py deleted file mode 100644 index 234f5b0a22..0000000000 --- a/monai/data/nifti_writer.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Sequence, Union - -import numpy as np -import torch - -from monai.config import DtypeLike -from monai.config.type_definitions import NdarrayOrTensor -from monai.data.utils import compute_shape_offset, to_affine_nd -from monai.networks.layers import AffineTransform -from monai.transforms.utils_pytorch_numpy_unification import allclose -from monai.utils import GridSampleMode, GridSamplePadMode, deprecated, optional_import -from monai.utils.type_conversion import convert_data_type - -nib, _ = optional_import("nibabel") - - -@deprecated(since="0.8", msg_suffix="use monai.data.NibabelWriter instead.") -def write_nifti( - data: NdarrayOrTensor, - file_name: str, - affine: Optional[NdarrayOrTensor] = None, - target_affine: Optional[np.ndarray] = None, - resample: bool = True, - output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None, - mode: str = GridSampleMode.BILINEAR, - padding_mode: str = GridSamplePadMode.BORDER, - align_corners: bool = False, - dtype: DtypeLike = np.float64, - output_dtype: DtypeLike = np.float32, -) -> None: - """ - Write numpy data into NIfTI files to disk. This function converts data - into the coordinate system defined by `target_affine` when `target_affine` - is specified. - - If the coordinate transform between `affine` and `target_affine` could be - achieved by simply transposing and flipping `data`, no resampling will - happen. otherwise this function will resample `data` using the coordinate - transform computed from `affine` and `target_affine`. Note that the shape - of the resampled `data` may subject to some rounding errors. For example, - resampling a 20x20 pixel image from pixel size (1.5, 1.5)-mm to (3.0, - 3.0)-mm space will return a 10x10-pixel image. However, resampling a - 20x20-pixel image from pixel size (2.0, 2.0)-mm to (3.0, 3.0)-mma space - will output a 14x14-pixel image, where the image shape is rounded from - 13.333x13.333 pixels. In this case `output_spatial_shape` could be specified so - that this function writes image data to a designated shape. - - The saved `affine` matrix follows: - - If `affine` equals to `target_affine`, save the data with `target_affine`. - - If `resample=False`, transform `affine` to `new_affine` based on the orientation - of `target_affine` and save the data with `new_affine`. - - If `resample=True`, save the data with `target_affine`, if explicitly specify - the `output_spatial_shape`, the shape of saved data is not computed by `target_affine`. - - If `target_affine` is None, set `target_affine=affine` and save. - - If `affine` and `target_affine` are None, the data will be saved with an identity - matrix as the image affine. - - This function assumes the NIfTI dimension notations. - Spatially it supports up to three dimensions, that is, H, HW, HWD for - 1D, 2D, 3D respectively. - When saving multiple time steps or multiple channels `data`, time and/or - modality axes should be appended after the first three dimensions. For - example, shape of 2D eight-class segmentation probabilities to be saved - could be `(64, 64, 1, 8)`. Also, data in shape (64, 64, 8), (64, 64, 8, 1) - will be considered as a single-channel 3D image. - - Args: - data: input data to write to file. - file_name: expected file name that saved on disk. - affine: the current affine of `data`. Defaults to `np.eye(4)` - target_affine: before saving - the (`data`, `affine`) as a Nifti1Image, - transform the data into the coordinates defined by `target_affine`. - resample: whether to run resampling when the target affine - could not be achieved by swapping/flipping data axes. - output_spatial_shape: spatial shape of the output image. - This option is used when resample = True. - mode: {``"bilinear"``, ``"nearest"``} - This option is used when ``resample = True``. - Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} - This option is used when ``resample = True``. - Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. - If None, use the data type of input data. - output_dtype: data type for saving data. Defaults to ``np.float32``. - - .. deprecated:: 0.8 - Use :py:meth:`monai.data.NibabelWriter` instead. - - """ - data, *_ = convert_data_type(data, np.ndarray) - affine, *_ = convert_data_type(affine, np.ndarray) - if not isinstance(data, np.ndarray): - raise AssertionError("input data must be numpy array or torch tensor.") - dtype = dtype or data.dtype - sr = min(data.ndim, 3) - if affine is None: - affine = np.eye(4, dtype=np.float64) - affine = to_affine_nd(sr, affine) - - if target_affine is None: - target_affine = affine - target_affine, *_ = convert_data_type(to_affine_nd(sr, target_affine), np.ndarray) - - if allclose(affine, target_affine, atol=1e-3): - # no affine changes, save (data, affine) - results_img = nib.Nifti1Image(data.astype(output_dtype, copy=False), to_affine_nd(3, target_affine)) - nib.save(results_img, file_name) - return - - # resolve orientation - start_ornt = nib.orientations.io_orientation(affine) - target_ornt = nib.orientations.io_orientation(target_affine) - ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) - data_shape = data.shape - data = nib.orientations.apply_orientation(data, ornt_transform) - _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) - if allclose(_affine, target_affine, atol=1e-3) or not resample: - results_img = nib.Nifti1Image(data.astype(output_dtype, copy=False), to_affine_nd(3, _affine)) # type: ignore - nib.save(results_img, file_name) - return - - # need resampling - affine_xform = AffineTransform( - normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True - ) - transform = np.linalg.inv(_affine) @ target_affine - if output_spatial_shape is None: - output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine) - output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else [] - if data.ndim > 3: # multi channel, resampling each channel - while len(output_spatial_shape_) < 3: - output_spatial_shape_ = output_spatial_shape_ + [1] - spatial_shape, channel_shape = data.shape[:3], data.shape[3:] - data_np: np.ndarray = data.reshape(list(spatial_shape) + [-1]) # type: ignore - data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch - data_torch = affine_xform( - torch.as_tensor(np.ascontiguousarray(data_np, dtype=dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform, dtype=dtype)), - spatial_size=output_spatial_shape_[:3], - ) - data_np = data_torch.squeeze(0).detach().cpu().numpy() - data_np = np.moveaxis(data_np, 0, -1) # channel last for nifti - data_np = data_np.reshape(list(data_np.shape[:3]) + list(channel_shape)) - else: # single channel image, need to expand to have batch and channel - while len(output_spatial_shape_) < len(data.shape): - output_spatial_shape_ = output_spatial_shape_ + [1] - data_torch = affine_xform( - torch.as_tensor(np.ascontiguousarray(data, dtype=dtype)[None, None]), - torch.as_tensor(np.ascontiguousarray(transform, dtype=dtype)), - spatial_size=output_spatial_shape_[: len(data.shape)], - ) - data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy() - - results_img = nib.Nifti1Image(data_np.astype(output_dtype, copy=False), to_affine_nd(3, target_affine)) - nib.save(results_img, file_name) - return diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py deleted file mode 100644 index 5b6e3b5a30..0000000000 --- a/monai/data/png_saver.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Optional, Union - -import numpy as np -import torch - -from monai.config.type_definitions import PathLike -from monai.data.png_writer import write_png -from monai.data.utils import create_file_basename -from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, deprecated, look_up_option - - -@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.") -class PNGSaver: - """ - Save the data as png file, it can support single data content or a batch of data. - Typically, the data can be segmentation predictions, call `save` for single data - or call `save_batch` to save a batch of data together. - The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, - where the input image name is extracted from the provided metadata dictionary. - If no metadata provided, use index from 0 as the filename prefix. - - .. deprecated:: 0.8 - Use :py:class:`monai.transforms.SaveImage` instead. - - """ - - def __init__( - self, - output_dir: PathLike = "./", - output_postfix: str = "seg", - output_ext: str = ".png", - resample: bool = True, - mode: str = InterpolateMode.NEAREST, - scale: Optional[int] = None, - data_root_dir: PathLike = "", - separate_folder: bool = True, - print_log: bool = True, - ) -> None: - """ - Args: - output_dir: output image directory. - output_postfix: a string appended to all output file names. - output_ext: output file extension name. - resample: whether to resample and resize if providing spatial_shape in the metadata. - mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. Defaults to ``"nearest"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling - [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. - data_root_dir: if not empty, it specifies the beginning parts of the input file's - absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from - `data_root_dir` to preserve folder structure when saving in case there are files in different - folders with the same file names. for example: - input_file_name: /foo/bar/test1/image.png, - postfix: seg - output_ext: png - output_dir: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg.png - separate_folder: whether to save every file in a separate folder, for example: if input filename is - `image.png`, postfix is `seg` and folder_path is `output`, if `True`, save as: - `output/image/image_seg.png`, if `False`, save as `output/image_seg.nii`. default to `True`. - print_log: whether to print log about the saved PNG file path, etc. default to `True`. - - """ - self.output_dir = output_dir - self.output_postfix = output_postfix - self.output_ext = output_ext - self.resample = resample - self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) - self.scale = scale - self.data_root_dir = data_root_dir - self.separate_folder = separate_folder - self.print_log = print_log - - self._data_index = 0 - - def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: - """ - Save data into a png file. - The meta_data could optionally have the following keys: - - - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object. - - ``'spatial_shape'`` -- for data output shape. - - ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename. - - If meta_data is None, use the default index (starting from 0) as the filename. - - Args: - data: target data content that to be saved as a png format file. - Assuming the data shape are spatial dimensions. - Shape of the spatial dimensions (C,H,W). - C should be 1, 3 or 4 - meta_data: the metadata information corresponding to the data. - - Raises: - ValueError: When ``data`` channels is not one of [1, 3, 4]. - - See Also - :py:meth:`monai.data.png_writer.write_png` - - """ - filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) - self._data_index += 1 - spatial_shape = meta_data.get("spatial_shape", None) if meta_data and self.resample else None - patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None - - if isinstance(data, torch.Tensor): - data = data.detach().cpu().numpy() - - path = create_file_basename( - postfix=self.output_postfix, - input_file_name=filename, - folder_path=self.output_dir, - data_root_dir=self.data_root_dir, - separate_folder=self.separate_folder, - patch_index=patch_index, - ) - path = f"{path}{self.output_ext}" - - if data.shape[0] == 1: - data = data.squeeze(0) - elif 2 < data.shape[0] < 5: - data = np.moveaxis(np.asarray(data), 0, -1) - else: - raise ValueError(f"Unsupported number of channels: {data.shape[0]}, available options are [1, 3, 4]") - - write_png( - np.asarray(data), file_name=path, output_spatial_shape=spatial_shape, mode=self.mode, scale=self.scale - ) - - if self.print_log: - print(f"file written: {path}.") - - def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: - """Save a batch of data into png format files. - - Args: - batch_data: target batch data content that save into png format. - meta_data: every key-value in the meta_data is corresponding to a batch of data. - - """ - for i, data in enumerate(batch_data): # save a batch of files - self.save(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py deleted file mode 100644 index 8c49944843..0000000000 --- a/monai/data/png_writer.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Sequence - -import numpy as np - -from monai.transforms.spatial.array import Resize -from monai.utils import ( - InterpolateMode, - convert_data_type, - deprecated, - ensure_tuple_rep, - look_up_option, - optional_import, -) - -Image, _ = optional_import("PIL", name="Image") - - -@deprecated(since="0.8", msg_suffix="use monai.data.PILWriter instead.") -def write_png( - data: np.ndarray, - file_name: str, - output_spatial_shape: Optional[Sequence[int]] = None, - mode: str = InterpolateMode.BICUBIC, - scale: Optional[int] = None, -) -> None: - """ - Write numpy data into png files to disk. - Spatially it supports HW for 2D.(H,W) or (H,W,3) or (H,W,4). - If `scale` is None, expect the input data in `np.uint8` or `np.uint16` type. - It's based on the Image module in PIL library: - https://pillow.readthedocs.io/en/stable/reference/Image.html - - Args: - data: input data to write to file. - file_name: expected file name that saved on disk. - output_spatial_shape: spatial shape of the output image. - mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. Defaults to ``"bicubic"``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling to - [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. - - Raises: - ValueError: When ``scale`` is not one of [255, 65535]. - - .. deprecated:: 0.8 - Use :py:meth:`monai.data.PILWriter` instead. - - """ - if not isinstance(data, np.ndarray): - raise ValueError("input data must be numpy array.") - if len(data.shape) == 3 and data.shape[2] == 1: # PIL Image can't save image with 1 channel - data = data.squeeze(2) - if output_spatial_shape is not None: - output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2) - mode = look_up_option(mode, InterpolateMode) - align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False - xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners) - _min, _max = np.min(data), np.max(data) - if len(data.shape) == 3: - data = np.moveaxis(data, -1, 0) # to channel first - data = xform(data) # type: ignore - data = np.moveaxis(data, 0, -1) - else: # (H, W) - data = np.expand_dims(data, 0) # make a channel - data = xform(data)[0] # type: ignore - if mode != InterpolateMode.NEAREST: - data = np.clip(data, _min, _max) - - if scale is not None: - data = np.clip(data, 0.0, 1.0) # png writer only can scale data in range [0, 1] - if scale == np.iinfo(np.uint8).max: - data = convert_data_type((scale * data), np.ndarray, dtype=np.uint8)[0] - elif scale == np.iinfo(np.uint16).max: - data = convert_data_type((scale * data), np.ndarray, dtype=np.uint16)[0] - else: - raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") - - # PNG data must be int number - if data.dtype not in (np.uint8, np.uint16): - data = data.astype(np.uint8, copy=False) - - data = np.moveaxis(data, 0, 1) - img = Image.fromarray(data) - img.save(file_name, "PNG") - return diff --git a/monai/data/samplers.py b/monai/data/samplers.py index 9392ba562d..12cb7cf584 100644 --- a/monai/data/samplers.py +++ b/monai/data/samplers.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence +from __future__ import annotations + +from collections.abc import Sequence import torch from torch.utils.data import Dataset @@ -42,8 +44,8 @@ def __init__( self, dataset: Dataset, even_divisible: bool = True, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, + num_replicas: int | None = None, + rank: int | None = None, shuffle: bool = True, **kwargs, ): @@ -88,11 +90,11 @@ def __init__( self, dataset: Dataset, weights: Sequence[float], - num_samples_per_rank: Optional[int] = None, - generator: Optional[torch.Generator] = None, + num_samples_per_rank: int | None = None, + generator: torch.Generator | None = None, even_divisible: bool = True, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, + num_replicas: int | None = None, + rank: int | None = None, shuffle: bool = True, **kwargs, ): diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py index a1a85338fe..97ed57ba7c 100644 --- a/monai/data/synthetic.py +++ b/monai/data/synthetic.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from __future__ import annotations import numpy as np @@ -26,9 +26,9 @@ def create_test_image_2d( rad_min: int = 5, noise_max: float = 0.0, num_seg_classes: int = 5, - channel_dim: Optional[int] = None, - random_state: Optional[np.random.RandomState] = None, -) -> Tuple[np.ndarray, np.ndarray]: + channel_dim: int | None = None, + random_state: np.random.RandomState | None = None, +) -> tuple[np.ndarray, np.ndarray]: """ Return a noisy 2D image with `num_objs` circles and a 2D mask image. The maximum and minimum radii of the circles are given as `rad_max` and `rad_min`. The mask will have `num_seg_classes` number of classes for segmentations labeled @@ -54,12 +54,12 @@ def create_test_image_2d( """ if rad_max <= rad_min: - raise ValueError("`rad_min` should be less than `rad_max`.") + raise ValueError(f"`rad_min` {rad_min} should be less than `rad_max` {rad_max}.") if rad_min < 1: - raise ValueError("`rad_min` should be no less than 1.") + raise ValueError(f"`rad_min` {rad_min} should be no less than 1.") min_size = min(height, width) if min_size <= 2 * rad_max: - raise ValueError("the minimal size of the image should be larger than `2 * rad_max`.") + raise ValueError(f"the minimal size {min_size} of the image should be larger than `2 * rad_max` 2x{rad_max}.") image = np.zeros((height, width)) rs: np.random.RandomState = np.random.random.__self__ if random_state is None else random_state # type: ignore @@ -103,9 +103,9 @@ def create_test_image_3d( rad_min: int = 5, noise_max: float = 0.0, num_seg_classes: int = 5, - channel_dim: Optional[int] = None, - random_state: Optional[np.random.RandomState] = None, -) -> Tuple[np.ndarray, np.ndarray]: + channel_dim: int | None = None, + random_state: np.random.RandomState | None = None, +) -> tuple[np.ndarray, np.ndarray]: """ Return a noisy 3D image and segmentation. @@ -131,12 +131,12 @@ def create_test_image_3d( """ if rad_max <= rad_min: - raise ValueError("`rad_min` should be less than `rad_max`.") + raise ValueError(f"`rad_min` {rad_min} should be less than `rad_max` {rad_max}.") if rad_min < 1: - raise ValueError("`rad_min` should be no less than 1.") + raise ValueError("f`rad_min` {rad_min} should be no less than 1.") min_size = min(height, width, depth) if min_size <= 2 * rad_max: - raise ValueError("the minimal size of the image should be larger than `2 * rad_max`.") + raise ValueError(f"the minimal size {min_size} of the image should be larger than `2 * rad_max` 2x{rad_max}.") image = np.zeros((height, width, depth)) rs: np.random.RandomState = np.random.random.__self__ if random_state is None else random_state # type: ignore diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index fdff86b745..23572dcef4 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings +from collections.abc import Callable from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any import numpy as np import torch @@ -109,14 +112,14 @@ def __init__( batch_size: int, num_workers: int = 0, inferrer_fn: Callable = _identity, - device: Union[str, torch.device] = "cpu", + device: str | torch.device = "cpu", image_key=CommonKeys.IMAGE, orig_key=CommonKeys.LABEL, nearest_interp: bool = True, - orig_meta_keys: Optional[str] = None, + orig_meta_keys: str | None = None, meta_key_postfix=DEFAULT_POST_FIX, to_tensor: bool = True, - output_device: Union[str, torch.device] = "cpu", + output_device: str | torch.device = "cpu", post_func: Callable = _identity, return_full_data: bool = False, progress: bool = True, @@ -163,8 +166,8 @@ def _check_transforms(self): ) def __call__( - self, data: Dict[str, Any], num_examples: int = 10 - ) -> Union[Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float], NdarrayOrTensor]: + self, data: dict[str, Any], num_examples: int = 10 + ) -> tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float] | NdarrayOrTensor: """ Args: data: dictionary data to be processed. @@ -189,7 +192,7 @@ def __call__( ds = Dataset(data_in, self.transform) dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - outs: List = [] + outs: list = [] for b in tqdm(dl) if has_tqdm and self.progress else dl: # do model forward pass diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py index 30bccd1ec9..fc7826fb15 100644 --- a/monai/data/thread_buffer.py +++ b/monai/data/thread_buffer.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from multiprocessing.context import SpawnContext from queue import Empty, Full, Queue from threading import Thread -from typing import Optional import torch @@ -41,7 +42,7 @@ def __init__(self, src, buffer_size: int = 1, timeout: float = 0.01): self.buffer_size = buffer_size self.timeout = timeout self.buffer: Queue = Queue(self.buffer_size) - self.gen_thread: Optional[Thread] = None + self.gen_thread: Thread | None = None self.is_running = False def enqueue_values(self): @@ -65,7 +66,6 @@ def stop(self): self.gen_thread = None def __iter__(self): - self.is_running = True self.gen_thread = Thread(target=self.enqueue_values, daemon=True) self.gen_thread.start() diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index ca46dd9dc4..cabf06ce89 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -9,10 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import datetime import json import os -from typing import IO, Any, Mapping, Optional, Sequence, Tuple, Union +from collections.abc import Mapping, Sequence +from typing import IO, Any import torch @@ -24,11 +27,11 @@ def save_net_with_metadata( jit_obj: torch.nn.Module, - filename_prefix_or_stream: Union[str, IO[Any]], + filename_prefix_or_stream: str | IO[Any], include_config_vals: bool = True, append_timestamp: bool = False, - meta_values: Optional[Mapping[str, Any]] = None, - more_extra_files: Optional[Mapping[str, bytes]] = None, + meta_values: Mapping[str, Any] | None = None, + more_extra_files: Mapping[str, bytes] | None = None, ) -> None: """ Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata @@ -98,10 +101,10 @@ def save_net_with_metadata( def load_net_with_metadata( - filename_prefix_or_stream: Union[str, IO[Any]], - map_location: Optional[torch.device] = None, + filename_prefix_or_stream: str | IO[Any], + map_location: torch.device | None = None, more_extra_files: Sequence[str] = (), -) -> Tuple[torch.nn.Module, dict, dict]: +) -> tuple[torch.nn.Module, dict, dict]: """ Load the module object from the given Torchscript filename or stream, and convert the stored JSON metadata back to a dict object. This will produce an empty dict if the metadata file is not present. diff --git a/monai/data/utils.py b/monai/data/utils.py index ce566d0c31..5461fda937 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import hashlib import json import logging @@ -17,11 +19,12 @@ import pickle import warnings from collections import abc, defaultdict +from collections.abc import Generator, Iterable, Mapping, Sequence, Sized from copy import deepcopy from functools import reduce from itertools import product, starmap, zip_longest from pathlib import PurePath -from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Sized, Tuple, Union +from typing import Any import numpy as np import torch @@ -43,6 +46,7 @@ ensure_tuple_size, fall_back_tuple, first, + get_equivalent_dtype, issequenceiterable, look_up_option, optional_import, @@ -90,6 +94,7 @@ "remove_extra_metadata", "get_extra_metadata_keys", "PICKLE_KEY_SUFFIX", + "is_no_channel", ] # module to be used by `torch.save` @@ -100,8 +105,8 @@ def get_random_patch( - dims: Sequence[int], patch_size: Sequence[int], rand_state: Optional[np.random.RandomState] = None -) -> Tuple[slice, ...]: + dims: Sequence[int], patch_size: Sequence[int], rand_state: np.random.RandomState | None = None +) -> tuple[slice, ...]: """ Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size` or the as close to it as possible within the given dimension. It is expected that `patch_size` is a valid patch for a source @@ -126,11 +131,11 @@ def get_random_patch( def iter_patch_slices( image_size: Sequence[int], - patch_size: Union[Sequence[int], int], + patch_size: Sequence[int] | int, start_pos: Sequence[int] = (), - overlap: Union[Sequence[float], float] = 0.0, + overlap: Sequence[float] | float = 0.0, padded: bool = True, -) -> Generator[Tuple[slice, ...], None, None]: +) -> Generator[tuple[slice, ...], None, None]: """ Yield successive tuples of slices defining patches of size `patch_size` from an array of dimensions `image_size`. The iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each @@ -160,7 +165,7 @@ def iter_patch_slices( def dense_patch_slices( image_size: Sequence[int], patch_size: Sequence[int], scan_interval: Sequence[int] -) -> List[Tuple[slice, ...]]: +) -> list[tuple[slice, ...]]: """ Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image. @@ -200,9 +205,9 @@ def dense_patch_slices( def iter_patch_position( image_size: Sequence[int], - patch_size: Union[Sequence[int], int, np.ndarray], + patch_size: Sequence[int] | int | np.ndarray, start_pos: Sequence[int] = (), - overlap: Union[Sequence[float], float] = 0.0, + overlap: Sequence[float] | float = 0.0, padded: bool = False, ): """ @@ -243,12 +248,12 @@ def iter_patch_position( def iter_patch( arr: np.ndarray, - patch_size: Union[Sequence[int], int] = 0, + patch_size: Sequence[int] | int = 0, start_pos: Sequence[int] = (), - overlap: Union[Sequence[float], float] = 0.0, + overlap: Sequence[float] | float = 0.0, copy_back: bool = True, - mode: Optional[str] = NumpyPadMode.WRAP, - **pad_opts: Dict, + mode: str | None = NumpyPadMode.WRAP, + **pad_opts: dict, ): """ Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr` @@ -317,9 +322,7 @@ def iter_patch( arr[...] = arrpad[slices] -def get_valid_patch_size( - image_size: Sequence[int], patch_size: Union[Sequence[int], int, np.ndarray] -) -> Tuple[int, ...]: +def get_valid_patch_size(image_size: Sequence[int], patch_size: Sequence[int] | int | np.ndarray) -> tuple[int, ...]: """ Given an image of dimensions `image_size`, return a patch size tuple taking the dimension from `patch_size` if this is not 0/None. Otherwise, or if `patch_size` is shorter than `image_size`, the dimension from `image_size` is taken. This ensures @@ -495,14 +498,14 @@ def list_data_collate(batch: Sequence): raise TypeError(re_str) from re -def _non_zipping_check(batch_data: Union[Mapping, Iterable], detach: bool, pad: bool, fill_value): +def _non_zipping_check(batch_data: Mapping | Iterable, detach: bool, pad: bool, fill_value): """ Utility function based on `decollate_batch`, to identify the largest batch size from the collated data. returns batch_size, the list of non-iterable items, and the dictionary or list with their items decollated. See `decollate_batch` for more details. """ - _deco: Union[Mapping, Sequence] + _deco: Mapping | Sequence if isinstance(batch_data, Mapping): _deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data} elif isinstance(batch_data, Iterable): @@ -664,7 +667,7 @@ def worker_init_fn(worker_id: int) -> None: """ worker_info = torch.utils.data.get_worker_info() - set_rnd(worker_info.dataset, seed=worker_info.seed) + set_rnd(worker_info.dataset, seed=worker_info.seed) # type: ignore[union-attr] def set_rnd(obj, seed: int) -> int: @@ -779,7 +782,7 @@ def rectify_header_sform_qform(img_nii): return img_nii -def zoom_affine(affine: np.ndarray, scale: Union[np.ndarray, Sequence[float]], diagonal: bool = True): +def zoom_affine(affine: np.ndarray, scale: np.ndarray | Sequence[float], diagonal: bool = True): """ To make column norm of `affine` the same as `scale`. If diagonal is False, returns an affine that combines orthogonal rotation and the new scale. @@ -832,11 +835,11 @@ def zoom_affine(affine: np.ndarray, scale: Union[np.ndarray, Sequence[float]], d def compute_shape_offset( - spatial_shape: Union[np.ndarray, Sequence[int]], + spatial_shape: np.ndarray | Sequence[int], in_affine: NdarrayOrTensor, out_affine: NdarrayOrTensor, scale_extent: bool = False, -) -> Tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray]: """ Given input and output affine, compute appropriate shapes in the output space based on the input array's shape. @@ -872,15 +875,14 @@ def compute_shape_offset( in_coords = [(-0.5, dim - 0.5) if scale_extent else (0.0, dim - 1.0) for dim in shape] corners: np.ndarray = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) corners = np.concatenate((corners, np.ones_like(corners[:1]))) - corners = in_affine_ @ corners try: - inv_mat = np.linalg.inv(out_affine_) + corners_out = np.linalg.solve(out_affine_, in_affine_) @ corners except np.linalg.LinAlgError as e: raise ValueError(f"Affine {out_affine_} is not invertible") from e - corners_out = inv_mat @ corners + corners = in_affine_ @ corners + all_dist = corners_out[:-1].copy() corners_out = corners_out[:-1] / corners_out[-1] out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0) - all_dist = inv_mat[:-1, :-1] @ corners[:-1, :] offset = None for i in range(corners.shape[1]): min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1) @@ -895,7 +897,7 @@ def compute_shape_offset( return out_shape.astype(int, copy=False), offset # type: ignore -def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.float64) -> NdarrayTensor: +def to_affine_nd(r: np.ndarray | int, affine: NdarrayTensor, dtype=np.float64) -> NdarrayTensor: """ Using elements from affine, to create a new affine matrix by assigning the rotation/zoom/scaling matrix and the translation vector. @@ -923,6 +925,7 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.floa an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type) """ + dtype = get_equivalent_dtype(dtype, np.ndarray) affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0] affine_np = affine_np.copy() if affine_np.ndim != 2: @@ -943,7 +946,7 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.floa def reorient_spatial_axes( data_shape: Sequence[int], init_affine: NdarrayOrTensor, target_affine: NdarrayOrTensor -) -> Tuple[np.ndarray, NdarrayOrTensor]: +) -> tuple[np.ndarray, NdarrayOrTensor]: """ Given the input ``init_affine``, compute the orientation transform between it and ``target_affine`` by rearranging/flipping the axes. @@ -1041,10 +1044,10 @@ def create_file_basename( def compute_importance_map( - patch_size: Tuple[int, ...], - mode: Union[BlendMode, str] = BlendMode.CONSTANT, - sigma_scale: Union[Sequence[float], float] = 0.125, - device: Union[torch.device, int, str] = "cpu", + patch_size: tuple[int, ...], + mode: BlendMode | str = BlendMode.CONSTANT, + sigma_scale: Sequence[float] | float = 0.125, + device: torch.device | int | str = "cpu", ) -> torch.Tensor: """Get importance map for different weight modes. @@ -1072,7 +1075,6 @@ def compute_importance_map( if mode == BlendMode.CONSTANT: importance_map = torch.ones(patch_size, device=device, dtype=torch.float) elif mode == BlendMode.GAUSSIAN: - sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size)) sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)] @@ -1089,7 +1091,7 @@ def compute_importance_map( return importance_map -def is_supported_format(filename: Union[Sequence[PathLike], PathLike], suffixes: Sequence[str]) -> bool: +def is_supported_format(filename: Sequence[PathLike] | PathLike, suffixes: Sequence[str]) -> bool: """ Verify whether the specified file or files format match supported suffixes. If supported suffixes is None, skip the verification and return True. @@ -1111,8 +1113,8 @@ def is_supported_format(filename: Union[Sequence[PathLike], PathLike], suffixes: def partition_dataset( data: Sequence, - ratios: Optional[Sequence[float]] = None, - num_partitions: Optional[int] = None, + ratios: Sequence[float] | None = None, + num_partitions: int | None = None, shuffle: bool = False, seed: int = 0, drop_last: bool = False, @@ -1222,8 +1224,8 @@ def partition_dataset( def partition_dataset_classes( data: Sequence, classes: Sequence[int], - ratios: Optional[Sequence[float]] = None, - num_partitions: Optional[int] = None, + ratios: Sequence[float] | None = None, + num_partitions: int | None = None, shuffle: bool = False, seed: int = 0, drop_last: bool = False, @@ -1261,7 +1263,7 @@ def partition_dataset_classes( for i, c in enumerate(classes): class_indices[c].append(i) - class_partition_indices: List[Sequence] = [] + class_partition_indices: list[Sequence] = [] for _, per_class_indices in sorted(class_indices.items()): per_class_partition_indices = partition_dataset( data=per_class_indices, @@ -1302,7 +1304,7 @@ def resample_datalist(data: Sequence, factor: float, random_pick: bool = False, """ scale, repeats = math.modf(factor) - ret: List = list() + ret: list = list() for _ in range(int(repeats)): ret.extend(list(deepcopy(data))) @@ -1312,7 +1314,7 @@ def resample_datalist(data: Sequence, factor: float, random_pick: bool = False, return ret -def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[Sequence[int], int]) -> List: +def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Sequence[int] | int) -> list: """ Select cross validation data based on data partitions and specified fold index. if a list of fold indices is provided, concatenate the partitions of these folds. @@ -1375,12 +1377,12 @@ def sorted_dict(item, key=None, reverse=False): def convert_tables_to_dicts( dfs, - row_indices: Optional[Sequence[Union[int, str]]] = None, - col_names: Optional[Sequence[str]] = None, - col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None, - col_groups: Optional[Dict[str, Sequence[str]]] = None, + row_indices: Sequence[int | str] | None = None, + col_names: Sequence[str] | None = None, + col_types: dict[str, dict[str, Any] | None] | None = None, + col_groups: dict[str, Sequence[str]] | None = None, **kwargs, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ Utility to join pandas tables, select rows, columns and generate groups. Will return a list of dictionaries, every dictionary maps to a row of data in tables. @@ -1414,7 +1416,7 @@ def convert_tables_to_dicts( """ df = reduce(lambda l, r: pd.merge(l, r, **kwargs), ensure_tuple(dfs)) # parse row indices - rows: List[Union[int, str]] = [] + rows: list[int | str] = [] if row_indices is None: rows = slice(df.shape[0]) # type: ignore else: @@ -1437,11 +1439,11 @@ def convert_tables_to_dicts( types = {k: v["type"] for k, v in col_types.items() if v is not None and "type" in v} if types: data_ = data_.astype(dtype=types, copy=False) - data: List[Dict] = data_.to_dict(orient="records") + data: list[dict] = data_.to_dict(orient="records") # group columns to generate new column if col_groups is not None: - groups: Dict[str, List] = {} + groups: dict[str, list] = {} for name, cols in col_groups.items(): groups[name] = df.loc[rows, cols].values # invert items of groups to every row of data @@ -1466,7 +1468,7 @@ def orientation_ras_lps(affine: NdarrayTensor) -> NdarrayTensor: return np.diag(flip_diag).astype(affine.dtype) @ affine # type: ignore -def remove_keys(data: dict, keys: List[str]) -> None: +def remove_keys(data: dict, keys: list[str]) -> None: """ Remove keys from a dictionary. Operates in-place so nothing is returned. @@ -1495,7 +1497,7 @@ def remove_extra_metadata(meta: dict) -> None: remove_keys(data=meta, keys=keys) -def get_extra_metadata_keys() -> List[str]: +def get_extra_metadata_keys() -> list[str]: """ Get a list of unnecessary keys for metadata that can be removed. @@ -1527,3 +1529,14 @@ def get_extra_metadata_keys() -> List[str]: # ] return keys + + +def is_no_channel(val) -> bool: + """Returns whether `val` indicates "no_channel", for MetaKeys.ORIGINAL_CHANNEL_DIM.""" + if isinstance(val, torch.Tensor): + return bool(torch.isnan(val)) + if isinstance(val, str): + return val == "no_channel" + if np.isscalar(val): + return bool(np.isnan(val)) + return val is None diff --git a/monai/data/video_dataset.py b/monai/data/video_dataset.py index fda262b398..be3bcf5bd5 100644 --- a/monai/data/video_dataset.py +++ b/monai/data/video_dataset.py @@ -9,10 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import sys import tempfile -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any import numpy as np from torch.utils.data import Dataset, IterableDataset @@ -62,9 +65,9 @@ class VideoDataset: def __init__( self, - video_source: Union[str, int], - transform: Optional[Callable] = None, - max_num_frames: Optional[int] = None, + video_source: str | int, + transform: Callable | None = None, + max_num_frames: int | None = None, color_order: str = ColorOrder.RGB, multiprocessing: bool = False, channel_dim: int = 0, @@ -107,7 +110,7 @@ def __init__( self.max_num_frames = max_num_frames @staticmethod - def open_video(video_source: Union[str, int]): + def open_video(video_source: str | int): """ Use OpenCV to open a video source from either file or capture device. @@ -162,7 +165,7 @@ def __init__(self, *args, **kwargs) -> None: self.max_num_frames = num_frames @staticmethod - def get_available_codecs() -> Dict[str, str]: + def get_available_codecs() -> dict[str, str]: """Try different codecs, see which are available. Returns a dictionary with of available codecs with codecs as keys and file extensions as values.""" if not has_cv2: diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index d4b70f7f0a..0ed8bf02ec 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -9,9 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import inspect import os -from typing import Callable, Dict, Optional, Sequence, Tuple, Union +from collections.abc import Callable, Sequence import numpy as np import torch @@ -34,8 +36,8 @@ class PatchWSIDataset(Dataset): Args: data: the list of input samples including image, location, and label (see the note below for more details). - size: the size of patch to be extracted from the whole slide image. - level: the level at which the patches to be extracted (default to 0). + patch_size: the size of patch to be extracted from the whole slide image. + patch_level: the level at which the patches to be extracted (default to 0). transform: transforms to be executed on input data. include_label: whether to load and include labels in the output center_location: whether the input location information is the position of the center of the patch @@ -67,12 +69,12 @@ class PatchWSIDataset(Dataset): def __init__( self, data: Sequence, - patch_size: Optional[Union[int, Tuple[int, int]]] = None, - patch_level: Optional[int] = None, - transform: Optional[Callable] = None, + patch_size: int | tuple[int, int] | None = None, + patch_level: int | None = None, + transform: Callable | None = None, include_label: bool = True, center_location: bool = True, - additional_meta_keys: Optional[Sequence[str]] = None, + additional_meta_keys: Sequence[str] | None = None, reader="cuCIM", **kwargs, ): @@ -91,7 +93,7 @@ def __init__( patch_level = 0 # Setup the WSI reader - self.wsi_reader: Union[WSIReader, BaseWSIReader] + self.wsi_reader: WSIReader | BaseWSIReader if isinstance(reader, str): self.wsi_reader = WSIReader(backend=reader, level=patch_level, **kwargs) elif inspect.isclass(reader) and issubclass(reader, BaseWSIReader): @@ -107,35 +109,35 @@ def __init__( self.additional_meta_keys = additional_meta_keys or [] # Initialized an empty whole slide image object dict - self.wsi_object_dict: Dict = {} + self.wsi_object_dict: dict = {} - def _get_wsi_object(self, sample: Dict): + def _get_wsi_object(self, sample: dict): image_path = sample[CommonKeys.IMAGE] if image_path not in self.wsi_object_dict: self.wsi_object_dict[image_path] = self.wsi_reader.read(image_path) return self.wsi_object_dict[image_path] - def _get_label(self, sample: Dict): + def _get_label(self, sample: dict): return torch.tensor(sample[CommonKeys.LABEL], dtype=torch.float32) - def _get_location(self, sample: Dict): + def _get_location(self, sample: dict): if self.center_location: size = self._get_size(sample) return [sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))] else: return sample[WSIPatchKeys.LOCATION] - def _get_level(self, sample: Dict): + def _get_level(self, sample: dict): if self.patch_level is None: return sample.get(WSIPatchKeys.LEVEL, 0) return self.patch_level - def _get_size(self, sample: Dict): + def _get_size(self, sample: dict): if self.patch_size is None: return ensure_tuple_rep(sample.get(WSIPatchKeys.SIZE), 2) return self.patch_size - def _get_data(self, sample: Dict): + def _get_data(self, sample: dict): # Don't store OpenSlide objects to avoid issues with OpenSlide internal cache if self.backend == "openslide": self.wsi_object_dict = {} @@ -147,7 +149,7 @@ def _get_data(self, sample: Dict): def _transform(self, index: int): # Get a single entry of data - sample: Dict = self.data[index] + sample: dict = self.data[index] # Extract patch image and associated metadata image, metadata = self._get_data(sample) @@ -169,26 +171,29 @@ def _transform(self, index: int): class SlidingPatchWSIDataset(Randomizable, PatchWSIDataset): """ - This dataset extracts patches from whole slide images (without loading the whole image) + This dataset extracts patches in sliding-window manner from whole slide images (without loading the whole image). It also reads labels for each patch and provides each patch with its associated class labels. Args: data: the list of input samples including image, location, and label (see the note below for more details). - size: the size of patch to be extracted from the whole slide image. - level: the level at which the patches to be extracted (default to 0). + patch_size: the size of patch to be extracted from the whole slide image. + patch_level: the level at which the patches to be extracted (default to 0). + mask_level: the resolution level at which the mask/map is created (for `ProbMapProducer` for instance). + overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). + If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. offset: the offset of image to extract patches (the starting position of the upper left patch). offset_limits: if offset is set to "random", a tuple of integers defining the lower and upper limit of the random offset for all dimensions, or a tuple of tuples that defines the limits for each dimension. - overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). - If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. transform: transforms to be executed on input data. + include_label: whether to load and include labels in the output + center_location: whether the input location information is the position of the center of the patch + additional_meta_keys: the list of keys for items to be copied to the output metadata from the input data reader: the module to be used for loading whole slide imaging. Defaults to cuCIM. If `reader` is - a string, it defines the backend of `monai.data.WSIReader`. - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader, - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader. - map_level: the resolution level at which the output map is created. seed: random seed to randomly generate offsets. Defaults to 0. kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class @@ -202,23 +207,23 @@ class SlidingPatchWSIDataset(Randomizable, PatchWSIDataset): {"image": "path/to/image2.tiff", "patch_size": [20, 20], "patch_level": 2} ] + Unlike `MaskedPatchWSIDataset`, this dataset does not filter any patches. """ def __init__( self, data: Sequence, - patch_size: Optional[Union[int, Tuple[int, int]]] = None, - patch_level: Optional[int] = None, + patch_size: int | tuple[int, int] | None = None, + patch_level: int | None = None, mask_level: int = 0, - overlap: Union[Tuple[float, float], float] = 0.0, - offset: Union[Tuple[int, int], int, str] = (0, 0), - offset_limits: Optional[Union[Tuple[Tuple[int, int], Tuple[int, int]], Tuple[int, int]]] = None, - transform: Optional[Callable] = None, + overlap: tuple[float, float] | float = 0.0, + offset: tuple[int, int] | int | str = (0, 0), + offset_limits: tuple[tuple[int, int], tuple[int, int]] | tuple[int, int] | None = None, + transform: Callable | None = None, include_label: bool = False, center_location: bool = False, additional_meta_keys: Sequence[str] = (ProbMapKeys.LOCATION, ProbMapKeys.SIZE, ProbMapKeys.COUNT), reader="cuCIM", - map_level: int = 0, seed: int = 0, **kwargs, ): @@ -240,7 +245,7 @@ def __init__( if isinstance(offset, str): if offset == "random": self.random_offset = True - self.offset_limits: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] + self.offset_limits: tuple[tuple[int, int], tuple[int, int]] | None if offset_limits is None: self.offset_limits = None elif isinstance(offset_limits, tuple): @@ -320,8 +325,8 @@ class MaskedPatchWSIDataset(PatchWSIDataset): Args: data: the list of input samples including image, location, and label (see the note below for more details). - size: the size of patch to be extracted from the whole slide image. - level: the level at which the patches to be extracted (default to 0). + patch_size: the size of patch to be extracted from the whole slide image. + patch_level: the level at which the patches to be extracted (default to 0). mask_level: the resolution level at which the mask is created. transform: transforms to be executed on input data. include_label: whether to load and include labels in the output @@ -350,10 +355,10 @@ class MaskedPatchWSIDataset(PatchWSIDataset): def __init__( self, data: Sequence, - patch_size: Optional[Union[int, Tuple[int, int]]] = None, - patch_level: Optional[int] = None, + patch_size: int | tuple[int, int] | None = None, + patch_level: int | None = None, mask_level: int = 7, - transform: Optional[Callable] = None, + transform: Callable | None = None, include_label: bool = False, center_location: bool = False, additional_meta_keys: Sequence[str] = (ProbMapKeys.LOCATION, ProbMapKeys.NAME), diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index d18d40935a..4f88139677 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -9,18 +9,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from abc import abstractmethod +from collections.abc import Sequence from os.path import abspath -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any import numpy as np +import torch -from monai.config import DtypeLike, PathLike +from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data.image_reader import ImageReader, _stack_images from monai.data.utils import is_supported_format -from monai.utils import WSIPatchKeys, ensure_tuple, optional_import, require_pkg +from monai.utils import ( + WSIPatchKeys, + dtype_numpy_to_torch, + dtype_torch_to_numpy, + ensure_tuple, + optional_import, + require_pkg, +) -CuImage, _ = optional_import("cucim", name="CuImage") OpenSlide, _ = optional_import("openslide", name="OpenSlide") TiffFile, _ = optional_import("tifffile", name="TiffFile") @@ -31,12 +41,21 @@ class BaseWSIReader(ImageReader): """ An abstract class that defines APIs to load patches from whole slide image files. + Args: + level: the whole slide image level at which the image is extracted. + channel_dim: the desired dimension for color channel. + dtype: the data type of output image. + device: target device to put the extracted patch. Note that if device is "cuda"", + the output will be converted to torch tenor and sent to the gpu even if the dtype is numpy. + mode: the output image color mode, e.g., "RGB" or "RGBA". + kwargs: additional args for the reader + Typical usage of a concrete implementation of this class is: .. code-block:: python image_reader = MyWSIReader() - wsi = image_reader.read(, **kwargs) + wsi = image_reader.read(filepath, **kwargs) img_data, meta_data = image_reader.get_data(wsi) - The `read` call converts an image filename into whole slide image object, @@ -53,18 +72,42 @@ class BaseWSIReader(ImageReader): """ - supported_suffixes: List[str] = [] + supported_suffixes: list[str] = [] backend = "" - def __init__(self, level: int = 0, channel_dim: int = 0, **kwargs): + def __init__( + self, + level: int, + channel_dim: int, + dtype: DtypeLike | torch.dtype, + device: torch.device | str | None, + mode: str, + **kwargs, + ): super().__init__() self.level = level self.channel_dim = channel_dim + self.set_dtype(dtype) + self.set_device(device) + self.mode = mode self.kwargs = kwargs - self.metadata: Dict[Any, Any] = {} + self.metadata: dict[Any, Any] = {} + + def set_dtype(self, dtype): + self.dtype: torch.dtype | np.dtype + if isinstance(dtype, torch.dtype): + self.dtype = dtype + else: + self.dtype = np.dtype(dtype) + + def set_device(self, device): + if device is None or isinstance(device, (torch.device, str)): + self.device = device + else: + raise ValueError(f"`device` must be `torch.device`, `str` or `None` but {type(device)} is given.") @abstractmethod - def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]: + def get_size(self, wsi, level: int | None = None) -> tuple[int, int]: """ Returns the size (height, width) of the whole slide image at a given level. @@ -87,7 +130,7 @@ def get_level_count(self, wsi) -> int: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def get_downsample_ratio(self, wsi, level: Optional[int] = None) -> float: + def get_downsample_ratio(self, wsi, level: int | None = None) -> float: """ Returns the down-sampling ratio of the whole slide image at a given level. @@ -105,7 +148,7 @@ def get_file_path(self, wsi) -> str: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]: + def get_mpp(self, wsi, level: int | None = None) -> tuple[float, float]: """ Returns the micro-per-pixel resolution of the whole slide image at a given level. @@ -118,7 +161,7 @@ def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]: @abstractmethod def _get_patch( - self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + self, wsi, location: tuple[int, int], size: tuple[int, int], level: int, dtype: DtypeLike, mode: str ) -> np.ndarray: """ Extracts and returns a patch image form the whole slide image. @@ -136,8 +179,8 @@ def _get_patch( raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") def _get_metadata( - self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int - ) -> Dict: + self, wsi, patch: NdarrayOrTensor, location: tuple[int, int], size: tuple[int, int], level: int + ) -> dict: """ Returns metadata of the extracted patch from the whole slide image. @@ -155,7 +198,7 @@ def _get_metadata( f"The desired channel_dim ({self.channel_dim}) is out of bound for image shape: {patch.shape}" ) channel_dim: int = self.channel_dim + (len(patch.shape) if self.channel_dim < 0 else 0) - metadata: Dict = { + metadata: dict = { "backend": self.backend, "original_channel_dim": channel_dim, "spatial_shape": np.array(patch.shape[:channel_dim] + patch.shape[channel_dim + 1 :]), @@ -170,12 +213,11 @@ def _get_metadata( def get_data( self, wsi, - location: Tuple[int, int] = (0, 0), - size: Optional[Tuple[int, int]] = None, - level: Optional[int] = None, - dtype: DtypeLike = np.uint8, - mode: str = "RGB", - ) -> Tuple[np.ndarray, Dict]: + location: tuple[int, int] = (0, 0), + size: tuple[int, int] | None = None, + level: int | None = None, + mode: str | None = None, + ) -> tuple[np.ndarray, dict]: """ Verifies inputs, extracts patches from WSI image and generates metadata, and return them. @@ -183,19 +225,20 @@ def get_data( wsi: a whole slide image object loaded from a file or a list of such objects location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). - If None, it is set to the full image size at the given level. + If not provided or None, it is set to the full image size at the given level. level: the level number. Defaults to 0 - dtype: the data type of output image - mode: the output image mode, 'RGB' or 'RGBA' + mode: the output image color mode, "RGB" or "RGBA". If not provided the default of "RGB" is used. Returns: a tuples, where the first element is an image patch [CxHxW] or stack of patches, and second element is a dictionary of metadata """ - patch_list: List = [] - metadata_list: List = [] + if mode is None: + mode = self.mode + patch_list: list = [] + metadata_list: list = [] # CuImage object is iterable, so ensure_tuple won't work on single object - if not isinstance(wsi, List): + if not isinstance(wsi, list): wsi = [wsi] for each_wsi in ensure_tuple(wsi): # Verify magnification level @@ -221,8 +264,25 @@ def get_data( if size[0] <= 0 or size[1] <= 0: raise ValueError(f"Patch size should be greater than zero, provided: patch size = {size}") + # Get numpy dtype if it is not already. + dtype_np = dtype_torch_to_numpy(self.dtype) if isinstance(self.dtype, torch.dtype) else self.dtype # Extract a patch or the entire image - patch = self._get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype, mode=mode) + patch: NdarrayOrTensor + patch = self._get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype_np, mode=mode) + + # Convert the patch to torch.Tensor if dtype is torch + if isinstance(self.dtype, torch.dtype) or ( + self.device is not None and torch.device(self.device).type == "cuda" + ): + # Ensure dtype is torch.dtype if the device is not "cpu" + dtype_torch = ( + dtype_numpy_to_torch(self.dtype) if not isinstance(self.dtype, torch.dtype) else self.dtype + ) + # Copy the numpy array if it is not writable + if patch.flags["WRITEABLE"]: + patch = torch.as_tensor(patch, dtype=dtype_torch, device=self.device) + else: + patch = torch.tensor(patch, dtype=dtype_torch, device=self.device) # check if the image has three dimensions (2D + color) if patch.ndim != 3: @@ -257,7 +317,7 @@ def get_data( metadata[key] = [m[key] for m in metadata_list] return _stack_images(patch_list, metadata), metadata - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -278,26 +338,53 @@ class WSIReader(BaseWSIReader): backend: the name of backend whole slide image reader library, the default is cuCIM. level: the level at which patches are extracted. channel_dim: the desired dimension for color channel. Default to 0 (channel first). + dtype: the data type of output image. Defaults to `np.uint8`. + mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB". + device: target device to put the extracted patch. Note that if device is "cuda"", + the output will be converted to torch tenor and sent to the gpu even if the dtype is numpy. num_workers: number of workers for multi-thread image loading (cucim backend only). kwargs: additional arguments to be passed to the backend library """ - def __init__(self, backend="cucim", level: int = 0, channel_dim: int = 0, **kwargs): - super().__init__(level, channel_dim, **kwargs) + supported_backends = ["cucim", "openslide", "tifffile"] + + def __init__( + self, + backend="cucim", + level: int = 0, + channel_dim: int = 0, + dtype: DtypeLike | torch.dtype = np.uint8, + device: torch.device | str | None = None, + mode: str = "RGB", + **kwargs, + ): self.backend = backend.lower() - self.reader: Union[CuCIMWSIReader, OpenSlideWSIReader, TiffFileWSIReader] + self.reader: CuCIMWSIReader | OpenSlideWSIReader | TiffFileWSIReader if self.backend == "cucim": - self.reader = CuCIMWSIReader(level=level, channel_dim=channel_dim, **kwargs) + self.reader = CuCIMWSIReader( + level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs + ) elif self.backend == "openslide": - self.reader = OpenSlideWSIReader(level=level, channel_dim=channel_dim, **kwargs) + self.reader = OpenSlideWSIReader( + level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs + ) elif self.backend == "tifffile": - self.reader = TiffFileWSIReader(level=level, channel_dim=channel_dim, **kwargs) + self.reader = TiffFileWSIReader( + level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs + ) else: raise ValueError( f"The supported backends are cucim, openslide, and tifffile but '{self.backend}' was given." ) self.supported_suffixes = self.reader.supported_suffixes + self.level = self.reader.level + self.channel_dim = self.reader.channel_dim + self.dtype = self.reader.dtype + self.device = self.reader.device + self.mode = self.reader.mode + self.kwargs = self.reader.kwargs + self.metadata = self.reader.metadata def get_level_count(self, wsi) -> int: """ @@ -309,7 +396,7 @@ def get_level_count(self, wsi) -> int: """ return self.reader.get_level_count(wsi) - def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]: + def get_size(self, wsi, level: int | None = None) -> tuple[int, int]: """ Returns the size (height, width) of the whole slide image at a given level. @@ -324,7 +411,7 @@ def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]: return self.reader.get_size(wsi, level) - def get_downsample_ratio(self, wsi, level: Optional[int] = None) -> float: + def get_downsample_ratio(self, wsi, level: int | None = None) -> float: """ Returns the down-sampling ratio of the whole slide image at a given level. @@ -343,7 +430,7 @@ def get_file_path(self, wsi) -> str: """Return the file path for the WSI object""" return self.reader.get_file_path(wsi) - def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]: + def get_mpp(self, wsi, level: int | None = None) -> tuple[float, float]: """ Returns the micro-per-pixel resolution of the whole slide image at a given level. @@ -359,7 +446,7 @@ def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]: return self.reader.get_mpp(wsi, level) def _get_patch( - self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + self, wsi, location: tuple[int, int], size: tuple[int, int], level: int, dtype: DtypeLike, mode: str ) -> np.ndarray: """ Extracts and returns a patch image form the whole slide image. @@ -376,7 +463,7 @@ def _get_patch( """ return self.reader._get_patch(wsi=wsi, location=location, size=size, level=level, dtype=dtype, mode=mode) - def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): """ Read whole slide image objects from given file or list of files. @@ -400,6 +487,10 @@ class CuCIMWSIReader(BaseWSIReader): level: the whole slide image level at which the image is extracted. (default=0) This is overridden if the level argument is provided in `get_data`. channel_dim: the desired dimension for color channel. Default to 0 (channel first). + dtype: the data type of output image. Defaults to `np.uint8`. + device: target device to put the extracted patch. Note that if device is "cuda"", + the output will be converted to torch tenor and sent to the gpu even if the dtype is numpy. + mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB". num_workers: number of workers for multi-thread image loading kwargs: additional args for `cucim.CuImage` module: https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h @@ -409,8 +500,17 @@ class CuCIMWSIReader(BaseWSIReader): supported_suffixes = ["tif", "tiff", "svs"] backend = "cucim" - def __init__(self, level: int = 0, channel_dim: int = 0, num_workers: int = 0, **kwargs): - super().__init__(level, channel_dim, **kwargs) + def __init__( + self, + level: int = 0, + channel_dim: int = 0, + dtype: DtypeLike | torch.dtype = np.uint8, + device: torch.device | str | None = None, + mode: str = "RGB", + num_workers: int = 0, + **kwargs, + ): + super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs) self.num_workers = num_workers @staticmethod @@ -424,7 +524,7 @@ def get_level_count(wsi) -> int: """ return wsi.resolutions["level_count"] # type: ignore - def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]: + def get_size(self, wsi, level: int | None = None) -> tuple[int, int]: """ Returns the size (height, width) of the whole slide image at a given level. @@ -439,7 +539,7 @@ def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]: return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0]) - def get_downsample_ratio(self, wsi, level: Optional[int] = None) -> float: + def get_downsample_ratio(self, wsi, level: int | None = None) -> float: """ Returns the down-sampling ratio of the whole slide image at a given level. @@ -459,7 +559,7 @@ def get_file_path(wsi) -> str: """Return the file path for the WSI object""" return str(abspath(wsi.path)) - def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]: + def get_mpp(self, wsi, level: int | None = None) -> tuple[float, float]: """ Returns the micro-per-pixel resolution of the whole slide image at a given level. @@ -475,7 +575,7 @@ def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]: factor = float(wsi.resolutions["level_downsamples"][level]) return (wsi.metadata["cucim"]["spacing"][1] * factor, wsi.metadata["cucim"]["spacing"][0] * factor) - def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): """ Read whole slide image objects from given file or list of files. @@ -488,19 +588,20 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): whole slide image object or list of such objects """ - wsi_list: List = [] + cuimage_cls, _ = optional_import("cucim", name="CuImage") + wsi_list: list = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for filename in filenames: - wsi = CuImage(filename, **kwargs_) + wsi = cuimage_cls(filename, **kwargs_) wsi_list.append(wsi) return wsi_list if len(filenames) > 1 else wsi_list[0] def _get_patch( - self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + self, wsi, location: tuple[int, int], size: tuple[int, int], level: int, dtype: DtypeLike, mode: str ) -> np.ndarray: """ Extracts and returns a patch image form the whole slide image. @@ -548,6 +649,10 @@ class OpenSlideWSIReader(BaseWSIReader): level: the whole slide image level at which the image is extracted. (default=0) This is overridden if the level argument is provided in `get_data`. channel_dim: the desired dimension for color channel. Default to 0 (channel first). + dtype: the data type of output image. Defaults to `np.uint8`. + device: target device to put the extracted patch. Note that if device is "cuda"", + the output will be converted to torch tenor and sent to the gpu even if the dtype is numpy. + mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB". kwargs: additional args for `openslide.OpenSlide` module. """ @@ -555,6 +660,17 @@ class OpenSlideWSIReader(BaseWSIReader): supported_suffixes = ["tif", "tiff", "svs"] backend = "openslide" + def __init__( + self, + level: int = 0, + channel_dim: int = 0, + dtype: DtypeLike | torch.dtype = np.uint8, + device: torch.device | str | None = None, + mode: str = "RGB", + **kwargs, + ): + super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs) + @staticmethod def get_level_count(wsi) -> int: """ @@ -566,7 +682,7 @@ def get_level_count(wsi) -> int: """ return wsi.level_count # type: ignore - def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]: + def get_size(self, wsi, level: int | None = None) -> tuple[int, int]: """ Returns the size (height, width) of the whole slide image at a given level. @@ -581,7 +697,7 @@ def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]: return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0]) - def get_downsample_ratio(self, wsi, level: Optional[int] = None) -> float: + def get_downsample_ratio(self, wsi, level: int | None = None) -> float: """ Returns the down-sampling ratio of the whole slide image at a given level. @@ -601,7 +717,7 @@ def get_file_path(wsi) -> str: """Return the file path for the WSI object""" return str(abspath(wsi._filename)) - def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]: + def get_mpp(self, wsi, level: int | None = None) -> tuple[float, float]: """ Returns the micro-per-pixel resolution of the whole slide image at a given level. @@ -628,7 +744,7 @@ def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]: factor *= wsi.level_downsamples[level] return (factor / float(wsi.properties["tiff.YResolution"]), factor / float(wsi.properties["tiff.XResolution"])) - def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): """ Read whole slide image objects from given file or list of files. @@ -640,7 +756,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): whole slide image object or list of such objects """ - wsi_list: List = [] + wsi_list: list = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -652,7 +768,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): return wsi_list if len(filenames) > 1 else wsi_list[0] def _get_patch( - self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + self, wsi, location: tuple[int, int], size: tuple[int, int], level: int, dtype: DtypeLike, mode: str ) -> np.ndarray: """ Extracts and returns a patch image form the whole slide image. @@ -692,6 +808,10 @@ class TiffFileWSIReader(BaseWSIReader): level: the whole slide image level at which the image is extracted. (default=0) This is overridden if the level argument is provided in `get_data`. channel_dim: the desired dimension for color channel. Default to 0 (channel first). + dtype: the data type of output image. Defaults to `np.uint8`. + device: target device to put the extracted patch. Note that if device is "cuda"", + the output will be converted to torch tenor and sent to the gpu even if the dtype is numpy. + mode: the output image color mode, "RGB" or "RGBA". Defaults to "RGB". kwargs: additional args for `tifffile.TiffFile` module. """ @@ -699,6 +819,17 @@ class TiffFileWSIReader(BaseWSIReader): supported_suffixes = ["tif", "tiff", "svs"] backend = "tifffile" + def __init__( + self, + level: int = 0, + channel_dim: int = 0, + dtype: DtypeLike | torch.dtype = np.uint8, + device: torch.device | str | None = None, + mode: str = "RGB", + **kwargs, + ): + super().__init__(level=level, channel_dim=channel_dim, dtype=dtype, device=device, mode=mode, **kwargs) + @staticmethod def get_level_count(wsi) -> int: """ @@ -710,7 +841,7 @@ def get_level_count(wsi) -> int: """ return len(wsi.pages) - def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]: + def get_size(self, wsi, level: int | None = None) -> tuple[int, int]: """ Returns the size (height, width) of the whole slide image at a given level. @@ -725,7 +856,7 @@ def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]: return (wsi.pages[level].imagelength, wsi.pages[level].imagewidth) - def get_downsample_ratio(self, wsi, level: Optional[int] = None) -> float: + def get_downsample_ratio(self, wsi, level: int | None = None) -> float: """ Returns the down-sampling ratio of the whole slide image at a given level. @@ -745,7 +876,7 @@ def get_file_path(wsi) -> str: """Return the file path for the WSI object""" return str(abspath(wsi.filehandle.path)) - def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]: + def get_mpp(self, wsi, level: int | None = None) -> tuple[float, float]: """ Returns the micro-per-pixel resolution of the whole slide image at a given level. @@ -775,7 +906,7 @@ def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]: xres = wsi.pages[level].tags["XResolution"].value return (factor * yres[1] / yres[0], factor * xres[1] / xres[0]) - def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): """ Read whole slide image objects from given file or list of files. @@ -787,7 +918,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): whole slide image object or list of such objects """ - wsi_list: List = [] + wsi_list: list = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -799,7 +930,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): return wsi_list if len(filenames) > 1 else wsi_list[0] def _get_patch( - self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + self, wsi, location: tuple[int, int], size: tuple[int, int], level: int, dtype: DtypeLike, mode: str ) -> np.ndarray: """ Extracts and returns a patch image form the whole slide image. @@ -822,7 +953,7 @@ def _get_patch( # Extract patch downsampling_ratio = self.get_downsample_ratio(wsi=wsi, level=level) location_ = [round(location[i] / downsampling_ratio) for i in range(len(location))] - patch = wsi_image[location_[0] : location_[0] + size[0], location_[1] : location_[1] + size[1], :].copy() + patch = wsi_image[location_[0] : location_[0] + size[0], location_[1] : location_[1] + size[1], :] # Make the channel to desired dimensions patch = np.moveaxis(patch, -1, self.channel_dim) diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index b6e54a6c4e..9e425ccbe2 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer from .trainer import GanTrainer, SupervisedTrainer, Trainer diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index cc6e3c4253..7c6ddd5bdd 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -99,7 +99,7 @@ def __init__( val_handlers: Sequence | None = None, amp: bool = False, mode: ForwardMode | str = ForwardMode.EVAL, - event_names: list[str | EventEnum] | None = None, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, to_kwargs: dict | None = None, @@ -133,7 +133,7 @@ def __init__( else: raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.") - def run(self, global_epoch: int = 1) -> None: + def run(self, global_epoch: int = 1) -> None: # type: ignore[override] """ Execute validation/evaluation based on Ignite Engine. @@ -237,7 +237,7 @@ def __init__( val_handlers: Sequence | None = None, amp: bool = False, mode: ForwardMode | str = ForwardMode.EVAL, - event_names: list[str | EventEnum] | None = None, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, to_kwargs: dict | None = None, @@ -267,7 +267,7 @@ def __init__( self.network = network self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]): + def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict: """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -295,10 +295,8 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} - # execute forward computation with engine.mode(engine.network): - if engine.amp: with torch.cuda.amp.autocast(**engine.amp_kwargs): engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) @@ -380,7 +378,7 @@ def __init__( val_handlers: Sequence | None = None, amp: bool = False, mode: ForwardMode | str = ForwardMode.EVAL, - event_names: list[str | EventEnum] | None = None, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, to_kwargs: dict | None = None, @@ -415,7 +413,7 @@ def __init__( raise ValueError("length of `pred_keys` must be same as the length of `networks`.") self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tensor]): + def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tensor]) -> dict: """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index 71ba40dd20..cb6bbe86cd 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -9,16 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING -import torch import torch.nn from torch.nn.parallel import DataParallel, DistributedDataParallel from torch.optim.optimizer import Optimizer from monai.config import IgniteInfo from monai.engines.utils import get_devices_spec -from monai.utils import min_version, optional_import +from monai.utils import deprecated, min_version, optional_import create_supervised_trainer, _ = optional_import( "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "create_supervised_trainer" @@ -47,20 +49,25 @@ def _default_transform(_x: torch.Tensor, _y: torch.Tensor, _y_pred: torch.Tensor def _default_eval_transform( x: torch.Tensor, y: torch.Tensor, y_pred: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: return y_pred, y +@deprecated( + since="1.1", + removed="1.3", + msg_suffix=("Native ignite engine lacks support of many MONAI features, please use `SupervisedTrainer` instead."), +) def create_multigpu_supervised_trainer( net: torch.nn.Module, optimizer: Optimizer, loss_fn: Callable, - devices: Optional[Sequence[Union[str, torch.device]]] = None, + devices: Sequence[str | torch.device] | None = None, non_blocking: bool = False, prepare_batch: Callable = _prepare_batch, output_transform: Callable = _default_transform, distributed: bool = False, -): +) -> Engine: """ Derived from `create_supervised_trainer` in Ignite. @@ -97,20 +104,33 @@ def create_multigpu_supervised_trainer( elif len(devices_) > 1: net = DataParallel(net) - return create_supervised_trainer( - net, optimizer, loss_fn, devices_[0], non_blocking, prepare_batch, output_transform + return create_supervised_trainer( # type: ignore[no-any-return] + model=net, + optimizer=optimizer, + loss_fn=loss_fn, + device=devices_[0], + non_blocking=non_blocking, + prepare_batch=prepare_batch, + output_transform=output_transform, ) +@deprecated( + since="1.1", + removed="1.3", + msg_suffix=( + "Native ignite evaluator lacks support of many MONAI features, please use `SupervisedEvaluator` instead." + ), +) def create_multigpu_supervised_evaluator( net: torch.nn.Module, - metrics: Optional[Dict[str, Metric]] = None, - devices: Optional[Sequence[Union[str, torch.device]]] = None, + metrics: dict[str, Metric] | None = None, + devices: Sequence[str | torch.device] | None = None, non_blocking: bool = False, prepare_batch: Callable = _prepare_batch, output_transform: Callable = _default_eval_transform, distributed: bool = False, -): +) -> Engine: """ Derived from `create_supervised_evaluator` in Ignite. @@ -150,4 +170,11 @@ def create_multigpu_supervised_evaluator( elif len(devices_) > 1: net = DataParallel(net) - return create_supervised_evaluator(net, metrics, devices_[0], non_blocking, prepare_batch, output_transform) + return create_supervised_evaluator( # type: ignore[no-any-return] + model=net, + metrics=metrics, + device=devices_[0], + non_blocking=non_blocking, + prepare_batch=prepare_batch, + output_transform=output_transform, + ) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 2394688b9e..b18050fd09 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -43,7 +43,7 @@ class Trainer(Workflow): """ - def run(self) -> None: + def run(self) -> None: # type: ignore[override] """ Execute training based on Ignite Engine. If call this function multiple times, it will continuously run from the previous state. @@ -151,7 +151,7 @@ def __init__( metric_cmp_fn: Callable = default_metric_cmp_fn, train_handlers: Sequence | None = None, amp: bool = False, - event_names: list[str | EventEnum] | None = None, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, optim_set_to_none: bool = False, @@ -185,7 +185,7 @@ def __init__( self.inferer = SimpleInferer() if inferer is None else inferer self.optim_set_to_none = optim_set_to_none - def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]): + def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict: """ Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 9d19737ab5..02c718cd14 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, cast import torch @@ -60,7 +63,7 @@ class IterationEvents(EventEnum): INNER_ITERATION_COMPLETED = "inner_iteration_completed" -def get_devices_spec(devices: Optional[Sequence[Union[torch.device, str]]] = None) -> List[torch.device]: +def get_devices_spec(devices: Sequence[torch.device | str] | None = None) -> list[torch.device]: """ Get a valid specification for one or more devices. If `devices` is None get devices for all CUDA devices available. If `devices` is and zero-length structure a single CPU compute device is returned. In any other cases `devices` is @@ -93,11 +96,11 @@ def get_devices_spec(devices: Optional[Sequence[Union[torch.device, str]]] = Non def default_prepare_batch( - batchdata: Union[Dict[str, torch.Tensor], torch.Tensor, Sequence[torch.Tensor]], - device: Optional[Union[str, torch.device]] = None, + batchdata: dict[str, torch.Tensor] | torch.Tensor | Sequence[torch.Tensor], + device: str | torch.device | None = None, non_blocking: bool = False, - **kwargs, -) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + **kwargs: Any, +) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor: """ Default function to prepare the data for current iteration. @@ -156,11 +159,11 @@ class PrepareBatch(ABC): @abstractmethod def __call__( self, - batchdata: Dict[str, torch.Tensor], - device: Optional[Union[str, torch.device]] = None, + batchdata: dict[str, torch.Tensor], + device: str | torch.device | None = None, non_blocking: bool = False, - **kwargs, - ): + **kwargs: Any, + ) -> Any: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -171,11 +174,11 @@ class PrepareBatchDefault(PrepareBatch): def __call__( self, - batchdata: Union[Dict[str, torch.Tensor], torch.Tensor, Sequence[torch.Tensor]], - device: Optional[Union[str, torch.device]] = None, + batchdata: dict[str, torch.Tensor] | torch.Tensor | Sequence[torch.Tensor], + device: str | torch.device | None = None, non_blocking: bool = False, - **kwargs, - ): + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor: """ Args `batchdata`, `device`, `non_blocking` refer to the ignite API: https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. @@ -198,16 +201,16 @@ class PrepareBatchExtraInput(PrepareBatch): dictionary keyed to `v`. """ - def __init__(self, extra_keys: Union[str, Sequence[str], Dict[str, str]]) -> None: + def __init__(self, extra_keys: str | Sequence[str] | dict[str, str]) -> None: self.extra_keys = extra_keys def __call__( self, - batchdata: Dict[str, torch.Tensor], - device: Optional[Union[str, torch.device]] = None, + batchdata: dict[str, torch.Tensor], + device: str | torch.device | None = None, non_blocking: bool = False, - **kwargs, - ): + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]: """ Args `batchdata`, `device`, `non_blocking` refer to the ignite API: https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. @@ -217,7 +220,7 @@ def __call__( args_ = list() kwargs_ = dict() - def _get_data(key: str): + def _get_data(key: str) -> torch.Tensor: data = batchdata[key] if isinstance(data, torch.Tensor): @@ -232,20 +235,20 @@ def _get_data(key: str): for k, v in self.extra_keys.items(): kwargs_.update({k: _get_data(v)}) - return image, label, tuple(args_), kwargs_ + return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_ def default_make_latent( num_latents: int, latent_size: int, - device: Optional[Union[str, torch.device]] = None, + device: str | torch.device | None = None, non_blocking: bool = False, - **kwargs, + **kwargs: Any, ) -> torch.Tensor: return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking, **kwargs) -def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dict]): +def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., dict]) -> tuple[Any, Any]: """ Apply transform on `batch` and `output`. If `batch` and `output` are dictionaries, temporarily combine them for the transform, diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index da6c086be9..30622c2b93 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any import torch import torch.distributed as dist @@ -24,7 +27,6 @@ from .utils import engine_apply_transform -IgniteEngine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="") State, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "State") Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") @@ -43,7 +45,7 @@ ) -class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import +class Workflow(Engine): """ Workflow defines the core work process inheriting from Ignite engine. All trainer, validator and evaluator share this same workflow as base class, @@ -101,24 +103,24 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona def __init__( self, - device: Union[torch.device, str], + device: torch.device | str, max_epochs: int, - data_loader: Union[Iterable, DataLoader], - epoch_length: Optional[int] = None, + data_loader: Iterable | DataLoader, + epoch_length: int | None = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable[[Engine, Any], Any]] = None, - postprocessing: Optional[Callable] = None, - key_metric: Optional[Dict[str, Metric]] = None, - additional_metrics: Optional[Dict[str, Metric]] = None, + iteration_update: Callable[[Engine, Any], Any] | None = None, + postprocessing: Callable | None = None, + key_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, metric_cmp_fn: Callable = default_metric_cmp_fn, - handlers: Optional[Sequence] = None, + handlers: Sequence | None = None, amp: bool = False, - event_names: Optional[List[Union[str, EventEnum]]] = None, - event_to_attr: Optional[dict] = None, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, + event_to_attr: dict | None = None, decollate: bool = True, - to_kwargs: Optional[Dict] = None, - amp_kwargs: Optional[Dict] = None, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, ) -> None: if iteration_update is not None: super().__init__(iteration_update) @@ -130,7 +132,7 @@ def __init__( if isinstance(sampler, DistributedSampler): @self.on(Events.EPOCH_STARTED) - def set_sampler_epoch(engine: Engine): + def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) if epoch_length is None: @@ -140,7 +142,7 @@ def set_sampler_epoch(engine: Engine): raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.") # set all sharable data for the workflow based on Ignite engine.state - self.state = State( + self.state: Any = State( rank=dist.get_rank() if dist.is_available() and dist.is_initialized() else 0, seed=0, iteration=0, @@ -164,21 +166,21 @@ def set_sampler_epoch(engine: Engine): self.amp = amp self.to_kwargs = {} if to_kwargs is None else to_kwargs self.amp_kwargs = {} if amp_kwargs is None else amp_kwargs - self.scaler: Optional[torch.cuda.amp.GradScaler] = None + self.scaler: torch.cuda.amp.GradScaler | None = None if event_names is None: - event_names = [IterationEvents] # type: ignore + event_names = [IterationEvents] else: if not isinstance(event_names, list): - raise ValueError("`event_names` must be a list or string or EventEnum.") - event_names += [IterationEvents] # type: ignore + raise ValueError("`event_names` must be a list of strings or EventEnums.") + event_names += [IterationEvents] for name in event_names: - if isinstance(name, str): - self.register_events(name, event_to_attr=event_to_attr) - elif issubclass(name, EventEnum): # type: ignore + if isinstance(name, (str, EventEnum)): + self.register_events(name, event_to_attr=event_to_attr) # type: ignore[arg-type] + elif issubclass(name, EventEnum): self.register_events(*name, event_to_attr=event_to_attr) else: - raise ValueError("`event_names` must be a list or string or EventEnum.") + raise ValueError("`event_names` must be a list of strings or EventEnums.") if decollate: self._register_decollate() @@ -207,7 +209,7 @@ def _decollate_data(engine: Engine) -> None: if isinstance(engine.state.output, (list, dict)): engine.state.output = transform(engine.state.output) - def _register_postprocessing(self, posttrans: Callable): + def _register_postprocessing(self, posttrans: Callable) -> None: """ Register the postprocessing logic to the engine, will execute them as a chain when iteration completed. @@ -223,7 +225,7 @@ def _run_postprocessing(engine: Engine) -> None: for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): engine.state.batch[i], engine.state.output[i] = engine_apply_transform(b, o, posttrans) - def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): + def _register_metrics(self, k_metric: dict, add_metrics: dict | None = None) -> None: """ Register the key metric and additional metrics to the engine, supports ignite Metrics. @@ -258,7 +260,7 @@ def _compare_metrics(engine: Workflow) -> None: engine.state.best_metric = current_val_metric engine.state.best_metric_epoch = engine.state.epoch - def _register_handlers(self, handlers: Sequence): + def _register_handlers(self, handlers: Sequence) -> None: """ Register the handlers to the engine, supports ignite Handlers with `attach` API. @@ -267,7 +269,7 @@ def _register_handlers(self, handlers: Sequence): for handler in handlers_: handler.attach(self) - def run(self) -> None: + def run(self) -> None: # type: ignore[override] """ Execute training, validation or evaluation based on Ignite Engine. """ @@ -280,7 +282,7 @@ def run(self) -> None: return super().run(data=self.data_loader, max_epochs=self.state.max_epochs) - def _iteration(self, engine, batchdata: Dict[str, torch.Tensor]): + def _iteration(self, engine: Any, batchdata: dict[str, torch.Tensor]) -> dict: """ Abstract callback function for the processing logic of 1 iteration in Ignite Engine. Need subclass to implement different logics, like SupervisedTrainer/Evaluator, GANTrainer, etc. diff --git a/monai/fl/client/__init__.py b/monai/fl/client/__init__.py index 7acb82c635..e9f1ab8601 100644 --- a/monai/fl/client/__init__.py +++ b/monai/fl/client/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .client_algo import BaseClient, ClientAlgo, ClientAlgoStats from .monai_algo import MonaiAlgo, MonaiAlgoStats diff --git a/monai/fl/client/client_algo.py b/monai/fl/client/client_algo.py index 9c54f2891b..25a88a9e66 100644 --- a/monai/fl/client/client_algo.py +++ b/monai/fl/client/client_algo.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from __future__ import annotations from monai.fl.utils.exchange_object import ExchangeObject @@ -27,7 +27,7 @@ class BaseClient: to help with lifecycle management of the class object. """ - def initialize(self, extra: Optional[dict] = None): + def initialize(self, extra: dict | None = None) -> None: """ Call to initialize the ClientAlgo class. @@ -36,7 +36,7 @@ def initialize(self, extra: Optional[dict] = None): """ pass - def finalize(self, extra: Optional[dict] = None): + def finalize(self, extra: dict | None = None) -> None: """ Call to finalize the ClientAlgo class. @@ -45,7 +45,7 @@ def finalize(self, extra: Optional[dict] = None): """ pass - def abort(self, extra: Optional[dict] = None): + def abort(self, extra: dict | None = None) -> None: """ Call to abort the ClientAlgo training or evaluation. @@ -57,7 +57,7 @@ def abort(self, extra: Optional[dict] = None): class ClientAlgoStats(BaseClient): - def get_data_stats(self, extra: Optional[dict] = None) -> ExchangeObject: + def get_data_stats(self, extra: dict | None = None) -> ExchangeObject: """ Get summary statistics about the local data. @@ -102,7 +102,7 @@ class ClientAlgo(ClientAlgoStats): to help with lifecycle management of the class object. """ - def train(self, data: ExchangeObject, extra: Optional[dict] = None) -> None: + def train(self, data: ExchangeObject, extra: dict | None = None) -> None: """ Train network and produce new network from train data. @@ -115,7 +115,7 @@ def train(self, data: ExchangeObject, extra: Optional[dict] = None) -> None: """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def get_weights(self, extra: Optional[dict] = None) -> ExchangeObject: + def get_weights(self, extra: dict | None = None) -> ExchangeObject: """ Get current local weights or weight differences. @@ -138,7 +138,7 @@ def get_weights(self, extra: Optional[dict] = None) -> ExchangeObject: """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def evaluate(self, data: ExchangeObject, extra: Optional[dict] = None) -> ExchangeObject: + def evaluate(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject: """ Get evaluation metrics on test data. diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 2cdabdca9a..031143c69b 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -9,19 +9,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os -import sys -from typing import Any, Dict, Mapping, MutableMapping, Optional, Union, cast +from collections.abc import Mapping, MutableMapping +from typing import Any, cast import torch import torch.distributed as dist import monai from monai.apps.auto3dseg.data_analyzer import DataAnalyzer +from monai.apps.utils import get_logger from monai.auto3dseg import SegSummarizer -from monai.bundle import DEFAULT_EXP_MGMT_SETTINGS, ConfigComponent, ConfigItem, ConfigParser, patch_bundle_tracking -from monai.engines import Trainer +from monai.bundle import DEFAULT_EXP_MGMT_SETTINGS, ConfigComponent, ConfigItem, ConfigParser, ConfigWorkflow +from monai.engines import SupervisedTrainer, Trainer from monai.fl.client import ClientAlgo, ClientAlgoStats from monai.fl.utils.constants import ( BundleKeys, @@ -38,10 +41,10 @@ from monai.utils import min_version, require_pkg from monai.utils.enums import DataStatsKeys -logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s - %(message)s") +logger = get_logger(__name__) -def convert_global_weights(global_weights: Mapping, local_var_dict: MutableMapping): +def convert_global_weights(global_weights: Mapping, local_var_dict: MutableMapping) -> tuple[MutableMapping, int]: """Helper function to convert global weights to local weights format""" # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation. model_keys = global_weights.keys() @@ -106,14 +109,14 @@ class MonaiAlgoStats(ClientAlgoStats): def __init__( self, bundle_root: str, - config_train_filename: Optional[Union[str, list]] = "configs/train.json", - config_filters_filename: Optional[Union[str, list]] = None, - train_data_key: Optional[str] = BundleKeys.TRAIN_DATA, - eval_data_key: Optional[str] = BundleKeys.VALID_DATA, - data_stats_transform_list: Optional[list] = None, + config_train_filename: str | list | None = "configs/train.json", + config_filters_filename: str | list | None = None, + train_data_key: str | None = BundleKeys.TRAIN_DATA, + eval_data_key: str | None = BundleKeys.VALID_DATA, + data_stats_transform_list: list | None = None, histogram_only: bool = False, ): - self.logger = logging.getLogger(self.__class__.__name__) + self.logger = logger self.bundle_root = bundle_root self.config_train_filename = config_train_filename self.config_filters_filename = config_filters_filename @@ -122,10 +125,10 @@ def __init__( self.data_stats_transform_list = data_stats_transform_list self.histogram_only = histogram_only - self.client_name: Optional[str] = None + self.client_name: str | None = None self.app_root: str = "" - self.train_parser: Optional[ConfigParser] = None - self.filter_parser: Optional[ConfigParser] = None + self.train_parser: ConfigParser | None = None + self.filter_parser: ConfigParser | None = None self.post_statistics_filters: Any = None self.phase = FlPhase.IDLE self.dataset_root: Any = None @@ -177,7 +180,7 @@ def initialize(self, extra=None): self.logger.info(f"Initialized {self.client_name}.") - def get_data_stats(self, extra: Optional[dict] = None) -> ExchangeObject: + def get_data_stats(self, extra: dict | None = None) -> ExchangeObject: """ Returns summary statistics about the local data. @@ -322,6 +325,7 @@ def _add_config_files(self, config_files): class MonaiAlgo(ClientAlgo, MonaiAlgoStats): """ Implementation of ``ClientAlgo`` to allow federated learning with MONAI bundle configurations. + FIXME: reimplement this class based on the bundle "ConfigWorkflow". Args: bundle_root: path of bundle. @@ -346,37 +350,13 @@ class MonaiAlgo(ClientAlgo, MonaiAlgoStats): multi_gpu: whether to run MonaiAlgo in a multi-GPU setting; defaults to `False`. backend: backend to use for torch.distributed; defaults to "nccl". init_method: init_method for torch.distributed; defaults to "env://". - tracking: enable the experiment tracking feature at runtime with optionally configurable and extensible. - if "mlflow", will add `MLFlowHandler` to the parsed bundle with default logging settings, - if other string, treat it as file path to load the logging settings, if `dict`, - treat it as logging settings, otherwise, use all the default settings. + tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible. + if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings, + if other string, treat it as file path to load the tracking settings. + if `dict`, treat it as tracking settings. will patch the target config content with `tracking handlers` and the top-level items of `configs`. - example of customized settings: - - .. code-block:: python - - tracking = { - "handlers_id": { - "trainer": {"id": "train#trainer", "handlers": "train#handlers"}, - "validator": {"id": "evaluate#evaluator", "handlers": "evaluate#handlers"}, - "evaluator": {"id": "evaluator", "handlers": "handlers"}, - }, - "configs": { - "tracking_uri": "", - "trainer": { - "_target_": "MLFlowHandler", - "tracking_uri": "@tracking_uri", - "iteration_log": True, - "output_transform": "$monai.handlers.from_engine(['loss'], first=True)", - }, - "validator": { - "_target_": "MLFlowHandler", "tracking_uri": "@tracking_uri", "iteration_log": False, - }, - "evaluator": { - "_target_": "MLFlowHandler", "tracking_uri": "@tracking_uri", "iteration_log": False, - }, - }, - }, + for detailed usage examples, plesae check the tutorial: + https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb. """ @@ -385,24 +365,24 @@ def __init__( bundle_root: str, local_epochs: int = 1, send_weight_diff: bool = True, - config_train_filename: Optional[Union[str, list]] = "configs/train.json", - config_evaluate_filename: Optional[Union[str, list]] = "default", - config_filters_filename: Optional[Union[str, list]] = None, + config_train_filename: str | list | None = "configs/train.json", + config_evaluate_filename: str | list | None = "default", + config_filters_filename: str | list | None = None, disable_ckpt_loading: bool = True, - best_model_filepath: Optional[str] = "models/model.pt", - final_model_filepath: Optional[str] = "models/model_final.pt", - save_dict_key: Optional[str] = "model", - seed: Optional[int] = None, + best_model_filepath: str | None = "models/model.pt", + final_model_filepath: str | None = "models/model_final.pt", + save_dict_key: str | None = "model", + seed: int | None = None, benchmark: bool = True, multi_gpu: bool = False, backend: str = "nccl", init_method: str = "env://", - train_data_key: Optional[str] = BundleKeys.TRAIN_DATA, - eval_data_key: Optional[str] = BundleKeys.VALID_DATA, - data_stats_transform_list: Optional[list] = None, - tracking: Optional[Union[str, dict]] = None, + train_data_key: str | None = BundleKeys.TRAIN_DATA, + eval_data_key: str | None = BundleKeys.VALID_DATA, + data_stats_transform_list: list | None = None, + tracking: str | dict | None = None, ): - self.logger = logging.getLogger(self.__class__.__name__) + self.logger = logger if config_evaluate_filename == "default": # by default, evaluator needs both training and evaluate to be instantiated. config_evaluate_filename = ["configs/train.json", "configs/evaluate.json"] @@ -426,16 +406,16 @@ def __init__( self.tracking = tracking self.app_root = "" - self.train_parser: Optional[ConfigParser] = None - self.eval_parser: Optional[ConfigParser] = None - self.filter_parser: Optional[ConfigParser] = None - self.trainer: Optional[Trainer] = None - self.evaluator: Optional[Any] = None + self.train_parser: ConfigParser | None = None + self.eval_parser: ConfigParser | None = None + self.filter_parser: ConfigParser | None = None + self.trainer: SupervisedTrainer | None = None + self.evaluator: Any | None = None self.pre_filters = None self.post_weight_filters = None self.post_evaluate_filters = None self.iter_of_start_time = 0 - self.global_weights: Optional[Mapping] = None + self.global_weights: Mapping | None = None self.rank = 0 self.phase = FlPhase.IDLE @@ -510,8 +490,8 @@ def initialize(self, extra=None): settings_ = DEFAULT_EXP_MGMT_SETTINGS[self.tracking] else: settings_ = ConfigParser.load_config_files(self.tracking) - patch_bundle_tracking(parser=self.train_parser, settings=settings_) - patch_bundle_tracking(parser=self.eval_parser, settings=settings_) + ConfigWorkflow.patch_bundle_tracking(parser=self.train_parser, settings=settings_) + ConfigWorkflow.patch_bundle_tracking(parser=self.eval_parser, settings=settings_) # Get trainer, evaluator self.trainer = self.train_parser.get_parsed_content( @@ -547,7 +527,7 @@ def initialize(self, extra=None): self.evaluator.logger.setLevel(logging.WARNING) self.logger.info(f"Initialized {self.client_name}.") - def train(self, data: ExchangeObject, extra=None): + def train(self, data: ExchangeObject, extra: dict | None = None) -> None: """ Train on client's local data. @@ -572,7 +552,7 @@ def train(self, data: ExchangeObject, extra=None): self.logger.info(f"Load {self.client_name} weights...") local_var_dict = get_state_dict(self.trainer.network) self.global_weights, n_converted = convert_global_weights( - global_weights=data.weights, local_var_dict=local_var_dict + global_weights=cast(dict, data.weights), local_var_dict=local_var_dict ) self._check_converted(data.weights, local_var_dict, n_converted) @@ -621,8 +601,8 @@ def get_weights(self, extra=None): # if weights contain several state dicts, use the one defined by `save_dict_key` if isinstance(weights, dict) and self.save_dict_key in weights: weights = weights.get(self.save_dict_key) - weigh_type: Optional[WeightType] = WeightType.WEIGHTS - stats: Dict = {} + weigh_type: WeightType | None = WeightType.WEIGHTS + stats: dict = {} self.logger.info(f"Returning {model_type} checkpoint weights from {model_path}.") else: raise ValueError( @@ -666,7 +646,7 @@ def get_weights(self, extra=None): return return_weights - def evaluate(self, data: ExchangeObject, extra=None): + def evaluate(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject: """ Evaluate on client's local data. @@ -694,7 +674,9 @@ def evaluate(self, data: ExchangeObject, extra=None): self.phase = FlPhase.EVALUATE self.logger.info(f"Load {self.client_name} weights...") local_var_dict = get_state_dict(self.evaluator.network) - global_weights, n_converted = convert_global_weights(global_weights=data.weights, local_var_dict=local_var_dict) + global_weights, n_converted = convert_global_weights( + global_weights=cast(dict, data.weights), local_var_dict=local_var_dict + ) self._check_converted(data.weights, local_var_dict, n_converted) _, updated_keys, _ = copy_model_state(src=global_weights, dst=self.evaluator.network) @@ -726,7 +708,7 @@ def abort(self, extra=None): self.logger.info(f"Aborting {self.client_name} evaluator...") self.evaluator.interrupt() - def finalize(self, extra=None): + def finalize(self, extra: dict | None = None) -> None: """ Finalize the training or evaluation. Args: diff --git a/monai/fl/utils/constants.py b/monai/fl/utils/constants.py index cd24e6093d..fbd18b364c 100644 --- a/monai/fl/utils/constants.py +++ b/monai/fl/utils/constants.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from monai.utils.enums import StrEnum diff --git a/monai/fl/utils/exchange_object.py b/monai/fl/utils/exchange_object.py index 3fba895101..b895d2d53e 100644 --- a/monai/fl/utils/exchange_object.py +++ b/monai/fl/utils/exchange_object.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +from __future__ import annotations from monai.fl.utils.constants import WeightType @@ -28,11 +28,11 @@ class ExchangeObject(dict): def __init__( self, - weights=None, - optim=None, - metrics: Optional[Dict] = None, - weight_type: Optional[WeightType] = None, - statistics: Optional[Dict] = None, + weights: dict | None = None, + optim: dict | None = None, + metrics: dict | None = None, + weight_type: WeightType | None = None, + statistics: dict | None = None, ): super().__init__() self.weights = weights @@ -40,7 +40,7 @@ def __init__( self.metrics = metrics self.weight_type = weight_type self.statistics = statistics - self._summary: Dict = {} + self._summary: dict = {} @property def metrics(self): diff --git a/monai/fl/utils/filters.py b/monai/fl/utils/filters.py index b205ffe668..15acabd9a2 100644 --- a/monai/fl/utils/filters.py +++ b/monai/fl/utils/filters.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import abc from monai.fl.utils.exchange_object import ExchangeObject @@ -20,7 +22,7 @@ class Filter(abc.ABC): """ @abc.abstractmethod - def __call__(self, data: ExchangeObject, extra=None) -> ExchangeObject: + def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject: """ Run the filtering. @@ -39,7 +41,7 @@ class SummaryFilter(Filter): Summary filter to content of ExchangeObject. """ - def __call__(self, data: ExchangeObject, extra=None) -> ExchangeObject: + def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject: """ Example filter that doesn't filter anything but only prints data summary. diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 9880e39817..f032191043 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .checkpoint_loader import CheckpointLoader from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver +from .clearml_handlers import ClearMLHandler, ClearMLImageHandler, ClearMLStatsHandler from .confusion_matrix import ConfusionMatrix from .decollate_batch import DecollateBatch from .earlystop_handler import EarlyStopHandler @@ -23,6 +26,7 @@ from .mean_dice import MeanDice from .mean_iou import MeanIoUHandler from .metric_logger import MetricLogger, MetricLoggerKeys +from .metrics_reloaded_handler import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler from .metrics_saver import MetricsSaver from .mlflow_handler import MLFlowHandler from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 91cfca354a..5b05e7055c 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -9,9 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import warnings -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING import torch @@ -69,9 +71,9 @@ class CheckpointLoader: def __init__( self, load_path: str, - load_dict: Dict, - name: Optional[str] = None, - map_location: Optional[Dict] = None, + load_dict: dict, + name: str | None = None, + map_location: dict | None = None, strict: bool = True, strict_shape: bool = True, ) -> None: @@ -112,7 +114,7 @@ def __call__(self, engine: Engine) -> None: checkpoint = {k: checkpoint} if not self.strict_shape: - pop_items: List[str] = [] + pop_items: list[str] = [] for k, obj in self.load_dict.items(): if isinstance(obj, torch.nn.Module): # skip items that don't match key name or data shape diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index 76f6458f3d..0651c6ff33 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -9,9 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging +import os import warnings -from typing import TYPE_CHECKING, Dict, Mapping, Optional +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any from monai.config import IgniteInfo from monai.utils import is_scalar, min_version, optional_import @@ -86,21 +90,21 @@ class CheckpointSaver: def __init__( self, save_dir: str, - save_dict: Dict, - name: Optional[str] = None, + save_dict: dict, + name: str | None = None, file_prefix: str = "", save_final: bool = False, - final_filename: Optional[str] = None, + final_filename: str | None = None, save_key_metric: bool = False, - key_metric_name: Optional[str] = None, + key_metric_name: str | None = None, key_metric_n_saved: int = 1, - key_metric_filename: Optional[str] = None, + key_metric_filename: str | None = None, key_metric_save_state: bool = False, key_metric_greater_or_equal: bool = False, key_metric_negative_sign: bool = False, epoch_level: bool = True, save_interval: int = 0, - n_saved: Optional[int] = None, + n_saved: int | None = None, ) -> None: if save_dir is None: raise AssertionError("must provide directory to save the checkpoints.") @@ -111,10 +115,11 @@ def __init__( self.logger = logging.getLogger(name) self.epoch_level = epoch_level self.save_interval = save_interval - self._final_checkpoint: Optional[Checkpoint] = None - self._key_metric_checkpoint: Optional[Checkpoint] = None - self._interval_checkpoint: Optional[Checkpoint] = None + self._final_checkpoint: Checkpoint | None = None + self._key_metric_checkpoint: Checkpoint | None = None + self._interval_checkpoint: Checkpoint | None = None self._name = name + self._final_filename = final_filename class _DiskSaver(DiskSaver): """ @@ -122,13 +127,13 @@ class _DiskSaver(DiskSaver): """ - def __init__(self, dirname: str, filename: Optional[str] = None): + def __init__(self, dirname: str, filename: str | None = None): # set `atomic=False` as `atomic=True` only gives read/write permission to the user who saved the file, # without group/others read permission super().__init__(dirname=dirname, require_empty=False, atomic=False) self.filename = filename - def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None: + def __call__(self, checkpoint: Mapping, filename: str, metadata: Mapping | None = None) -> None: if self.filename is not None: filename = self.filename super().__call__(checkpoint=checkpoint, filename=filename, metadata=metadata) @@ -140,12 +145,12 @@ def remove(self, filename: str) -> None: if save_final: - def _final_func(engine: Engine): + def _final_func(engine: Engine) -> Any: return engine.state.iteration self._final_checkpoint = Checkpoint( to_save=self.save_dict, - save_handler=_DiskSaver(dirname=self.save_dir, filename=final_filename), + save_handler=_DiskSaver(dirname=self.save_dir, filename=self._final_filename), filename_prefix=file_prefix, score_function=_final_func, score_name="final_iteration", @@ -153,7 +158,7 @@ def _final_func(engine: Engine): if save_key_metric: - def _score_func(engine: Engine): + def _score_func(engine: Engine) -> Any: if isinstance(key_metric_name, str): metric_name = key_metric_name elif hasattr(engine.state, "key_metric_name"): @@ -188,7 +193,7 @@ def _score_func(engine: Engine): if save_interval > 0: - def _interval_func(engine: Engine): + def _interval_func(engine: Engine) -> Any: return engine.state.epoch if self.epoch_level else engine.state.iteration self._interval_checkpoint = Checkpoint( @@ -200,7 +205,7 @@ def _interval_func(engine: Engine): n_saved=n_saved, ) - def load_state_dict(self, state_dict: Dict) -> None: + def load_state_dict(self, state_dict: dict) -> None: """ Utility to resume the internal state of key metric tracking list if configured to save checkpoints based on the key metric value. @@ -268,7 +273,11 @@ def completed(self, engine: Engine) -> None: raise AssertionError if not hasattr(self.logger, "info"): raise AssertionError("Error, provided logger has not info attribute.") - self.logger.info(f"Train completed, saved final checkpoint: {self._final_checkpoint.last_checkpoint}") + if self._final_filename is not None: + _final_checkpoint_path = os.path.join(self.save_dir, self._final_filename) + else: + _final_checkpoint_path = self._final_checkpoint.last_checkpoint # type: ignore[assignment] + self.logger.info(f"Train completed, saved final checkpoint: {_final_checkpoint_path}") def exception_raised(self, engine: Engine, e: Exception) -> None: """Callback for train or validation/evaluation exception raised Event. @@ -288,7 +297,11 @@ def exception_raised(self, engine: Engine, e: Exception) -> None: raise AssertionError if not hasattr(self.logger, "info"): raise AssertionError("Error, provided logger has not info attribute.") - self.logger.info(f"Exception raised, saved the last checkpoint: {self._final_checkpoint.last_checkpoint}") + if self._final_filename is not None: + _final_checkpoint_path = os.path.join(self.save_dir, self._final_filename) + else: + _final_checkpoint_path = self._final_checkpoint.last_checkpoint # type: ignore[assignment] + self.logger.info(f"Exception raised, saved the last checkpoint: {_final_checkpoint_path}") raise e def metrics_completed(self, engine: Engine) -> None: diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 75fb394177..831808f4fb 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import warnings -from typing import TYPE_CHECKING, Callable, List, Optional +from collections.abc import Callable +from typing import TYPE_CHECKING import torch @@ -43,9 +46,9 @@ def __init__( overwrite: bool = True, batch_transform: Callable = lambda x: x, output_transform: Callable = lambda x: x, - name: Optional[str] = None, + name: str | None = None, save_rank: int = 0, - saver: Optional[CSVSaver] = None, + saver: CSVSaver | None = None, ) -> None: """ Args: @@ -85,8 +88,8 @@ def __init__( self.logger = logging.getLogger(name) self._name = name - self._outputs: List[torch.Tensor] = [] - self._filenames: List[str] = [] + self._outputs: list[torch.Tensor] = [] + self._filenames: list[str] = [] def attach(self, engine: Engine) -> None: """ diff --git a/monai/handlers/clearml_handlers.py b/monai/handlers/clearml_handlers.py new file mode 100644 index 0000000000..f4d6f197d2 --- /dev/null +++ b/monai/handlers/clearml_handlers.py @@ -0,0 +1,178 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, Sequence + +from monai.utils import optional_import + +from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler + + +class ClearMLHandler: + """ + Base class for the handlers to log everything to ClearML. + For more details of ClearML usage, please refer to: + https://clear.ml/docs/latest/docs/references/sdk/task + + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + + """ + + def __init__( + self, + project_name: str | None, + task_name: str | None, + output_uri: str | bool, + tags: Sequence[str] | None, + reuse_last_task_id: bool, + continue_last_task: bool, + auto_connect_frameworks: bool | Mapping[str, bool | str | list], + auto_connect_arg_parser: bool | Mapping[str, bool], + ) -> None: + """ + Args: + project_name: ClearML project name, default to 'MONAI'. + task_name: ClearML task name, default to 'monai_experiment'. + output_uri: The default location for output models and other artifacts, default to 'True'. + tags: A list of tags (str) to the created Task, default to 'None'. + reuse_last_task_id: Force a new Task (experiment) with a previously used Task ID, default to 'True'. + continue_last_task: Continue the execution of a previously executed Task (experiment), default to 'False'. + auto_connect_frameworks: Automatically connect frameworks, default to 'True'. + auto_connect_arg_parser: Automatically connect an argparse object to the Task, default to 'True'. + + """ + + if TYPE_CHECKING: + import clearml + else: + clearml, _ = optional_import("clearml") + + # Always check if the user didn't already add a `task.init`` in before + # if so, use that task, otherwise create a new one. + if clearml.Task.current_task(): + self.clearml_task = clearml.Task.current_task() + else: + self.clearml_task = clearml.Task.init( + project_name=project_name, + task_name=task_name, + output_uri=output_uri, + tags=tags, + reuse_last_task_id=reuse_last_task_id, + continue_last_task=continue_last_task, + auto_connect_frameworks=auto_connect_frameworks, + auto_connect_arg_parser=auto_connect_arg_parser, + ) + + +class ClearMLStatsHandler(ClearMLHandler, TensorBoardStatsHandler): + """ + + Class to write tensorboard stats by inheriting TensorBoardStatsHandler class. + Everything from Tensorboard is logged automatically to ClearML. + + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + + """ + + def __init__( + self, + project_name: str | None = "MONAI", + task_name: str | None = "monai_experiment", + output_uri: str | bool = True, + tags: Sequence[str] | None = None, + reuse_last_task_id: bool = True, + continue_last_task: bool = False, + auto_connect_frameworks: bool | Mapping[str, bool | str | list] = True, + auto_connect_arg_parser: bool | Mapping[str, bool] = True, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Args: + project_name: ClearML project name, default to 'MONAI'. + task_name: ClearML task name, default to 'monai_experiment'. + output_uri: The default location for output models and other artifacts, default to 'True'. + tags: A list of tags (str) to the created Task, default to 'None'. + reuse_last_task_id: Force a new Task (experiment) with a previously used Task ID, default to 'True'. + continue_last_task: Continue the execution of a previously executed Task (experiment), default to 'False'. + auto_connect_frameworks: Automatically connect frameworks, default to 'True'. + auto_connect_arg_parser: Automatically connect an argparse object to the Task, default to 'True'. + + """ + + ClearMLHandler.__init__( + self, + project_name=project_name, + task_name=task_name, + output_uri=output_uri, + tags=tags, + reuse_last_task_id=reuse_last_task_id, + continue_last_task=continue_last_task, + auto_connect_frameworks=auto_connect_frameworks, + auto_connect_arg_parser=auto_connect_arg_parser, + ) + TensorBoardStatsHandler.__init__(self, *args, **kwargs) + + +class ClearMLImageHandler(ClearMLHandler, TensorBoardImageHandler): + """ + + This class inherits all functionality from TensorBoardImageHandler class. + Everything from Tensorboard is logged automatically to ClearML. + + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + + """ + + def __init__( + self, + project_name: str | None = "MONAI", + task_name: str | None = "monai_experiment", + output_uri: str | bool = True, + tags: Sequence[str] | None = None, + reuse_last_task_id: bool = True, + continue_last_task: bool = False, + auto_connect_frameworks: bool | Mapping[str, bool | str | list] = True, + auto_connect_arg_parser: bool | Mapping[str, bool] = True, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Args: + project_name: ClearML project name, default to 'MONAI'. + task_name: ClearML task name, default to 'monai_experiment'. + output_uri: The default location for output models and other artifacts, default to 'True'. + tags: A list of tags (str) to the created Task, default to 'None'. + reuse_last_task_id: Force a new Task (experiment) with a previously used Task ID, default to 'True'. + continue_last_task: Continue the execution of a previously executed Task (experiment), default to 'False'. + auto_connect_frameworks: Automatically connect frameworks, default to 'True'. + auto_connect_arg_parser: Automatically connect an argparse object to the Task, default to 'True'. + + """ + + ClearMLHandler.__init__( + self, + project_name=project_name, + task_name=task_name, + output_uri=output_uri, + tags=tags, + reuse_last_task_id=reuse_last_task_id, + continue_last_task=continue_last_task, + auto_connect_frameworks=auto_connect_frameworks, + auto_connect_arg_parser=auto_connect_arg_parser, + ) + + TensorBoardImageHandler.__init__(self, *args, **kwargs) diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index e3fc4bfbf1..0684c453a1 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union +from __future__ import annotations + +from collections.abc import Callable from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import ConfusionMatrixMetric @@ -26,7 +28,7 @@ def __init__( include_background: bool = True, metric_name: str = "hit_rate", compute_sample: bool = False, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py index a0d0ef3ad2..ac3aa94145 100644 --- a/monai/handlers/decollate_batch.py +++ b/monai/handlers/decollate_batch.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from monai.config import IgniteInfo, KeysCollection from monai.engines.utils import IterationEvents @@ -51,9 +53,9 @@ def __init__( event: str = "MODEL_COMPLETED", detach: bool = True, decollate_batch: bool = True, - batch_keys: Optional[KeysCollection] = None, + batch_keys: KeysCollection | None = None, decollate_output: bool = True, - output_keys: Optional[KeysCollection] = None, + output_keys: KeysCollection | None = None, allow_missing_keys: bool = False, ): event = event.upper() diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index 4f61fa3e00..e9995086f5 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Optional +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING from monai.config import IgniteInfo from monai.utils import min_version, optional_import @@ -55,7 +58,7 @@ def __init__( self, patience: int, score_function: Callable, - trainer: Optional[Engine] = None, + trainer: Engine | None = None, min_delta: float = 0.0, cumulative_delta: bool = False, epoch_level: bool = True, @@ -80,7 +83,7 @@ def attach(self, engine: Engine) -> None: else: engine.add_event_handler(Events.ITERATION_COMPLETED, self) - def set_trainer(self, trainer: Engine): + def set_trainer(self, trainer: Engine) -> None: """ Set trainer to execute early stop if not setting properly in `__init__()`. """ diff --git a/monai/handlers/garbage_collector.py b/monai/handlers/garbage_collector.py index 74ccac8a72..858c78095a 100644 --- a/monai/handlers/garbage_collector.py +++ b/monai/handlers/garbage_collector.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import gc from typing import TYPE_CHECKING diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index 739c9e9935..ef4136906c 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Union +from __future__ import annotations + +from collections.abc import Callable from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import HausdorffDistanceMetric @@ -25,9 +27,9 @@ def __init__( self, include_background: bool = False, distance_metric: str = "euclidean", - percentile: Optional[float] = None, + percentile: float | None = None, directed: bool = False, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index d6f3f50144..822da0aa18 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any import torch @@ -19,17 +22,20 @@ from monai.utils import min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") -Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base") -reinit__is_reduced, _ = optional_import( - "ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced", as_type="decorator" -) + if TYPE_CHECKING: from ignite.engine import Engine + from ignite.metrics import Metric + from ignite.metrics.metric import reinit__is_reduced else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base") + reinit__is_reduced, _ = optional_import( + "ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced", as_type="decorator" + ) -class IgniteMetric(Metric): # type: ignore[valid-type, misc] # due to optional_import +class IgniteMetric(Metric): """ Base Metric class based on ignite event handler mechanism. The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim, @@ -55,9 +61,9 @@ def __init__( self._is_reduced: bool = False self.metric_fn = metric_fn self.save_details = save_details - self._scores: List = [] - self._engine: Optional[Engine] = None - self._name: Optional[str] = None + self._scores: list = [] + self._engine: Engine | None = None + self._name: str | None = None super().__init__(output_transform) @reinit__is_reduced @@ -107,7 +113,7 @@ def compute(self) -> Any: result = result.item() return result - def attach(self, engine: Engine, name: str) -> None: + def attach(self, engine: Engine, name: str) -> None: # type: ignore[override] """ Attaches current metric to provided engine. On the end of engine's run, `engine.state.metrics` dictionary will contain computed metric's value under provided name. diff --git a/monai/handlers/logfile_handler.py b/monai/handlers/logfile_handler.py index 73c58431a9..df6ebd34a7 100644 --- a/monai/handlers/logfile_handler.py +++ b/monai/handlers/logfile_handler.py @@ -9,9 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from monai.config import IgniteInfo from monai.utils import min_version, optional_import @@ -58,8 +60,8 @@ def __init__( self.loglevel: int = loglevel self.formatter: str = formatter self.create_dir: bool = create_dir - self.logger: Optional[logging.Logger] = None - self.handler: Optional[logging.FileHandler] = None + self.logger: logging.Logger | None = None + self.handler: logging.FileHandler | None = None def attach(self, engine: Engine) -> None: self.logger = engine.logger diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py index 66059bba95..a79722517d 100644 --- a/monai/handlers/lr_schedule_handler.py +++ b/monai/handlers/lr_schedule_handler.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler @@ -33,9 +36,9 @@ class LrScheduleHandler: def __init__( self, - lr_scheduler: Union[_LRScheduler, ReduceLROnPlateau], + lr_scheduler: _LRScheduler | ReduceLROnPlateau, print_lr: bool = True, - name: Optional[str] = None, + name: str | None = None, epoch_level: bool = True, step_transform: Callable[[Engine], Any] = lambda engine: (), ) -> None: diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index c5609c6746..aa3b5b0763 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union +from __future__ import annotations + +from collections.abc import Callable from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import DiceMetric @@ -24,7 +26,7 @@ class MeanDice(IgniteMetric): def __init__( self, include_background: bool = True, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: diff --git a/monai/handlers/mean_iou.py b/monai/handlers/mean_iou.py index ee4602e6a7..2fc0d5f8ab 100644 --- a/monai/handlers/mean_iou.py +++ b/monai/handlers/mean_iou.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union +from __future__ import annotations + +from collections.abc import Callable from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import MeanIoU @@ -24,7 +26,7 @@ class MeanIoUHandler(IgniteMetric): def __init__( self, include_background: bool = True, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index 334f631b88..d59205a021 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -9,10 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections import defaultdict +from collections.abc import Callable, Mapping, Sequence from enum import Enum from threading import RLock -from typing import TYPE_CHECKING, Callable, DefaultDict, List, Optional +from typing import TYPE_CHECKING, Any from monai.config import IgniteInfo from monai.utils import min_version, optional_import @@ -27,7 +30,7 @@ ) -def _get_loss_from_output(output, loss_key: str = CommonKeys.LOSS): +def _get_loss_from_output(output: Sequence[Mapping[str, Any]], loss_key: str = CommonKeys.LOSS) -> Any: return output[0][loss_key] @@ -70,12 +73,12 @@ def __init__( self, loss_transform: Callable = _get_loss_from_output, metric_transform: Callable = lambda x: x, - evaluator: Optional[Engine] = None, + evaluator: Engine | None = None, ) -> None: self.loss_transform = loss_transform self.metric_transform = metric_transform - self.loss: List = [] - self.metrics: DefaultDict = defaultdict(list) + self.loss: list = [] + self.metrics: defaultdict = defaultdict(list) self.iteration = 0 self.lock = RLock() diff --git a/monai/handlers/metrics_reloaded_handler.py b/monai/handlers/metrics_reloaded_handler.py new file mode 100644 index 0000000000..7239c9dee0 --- /dev/null +++ b/monai/handlers/metrics_reloaded_handler.py @@ -0,0 +1,115 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable + +from monai.handlers.ignite_metric import IgniteMetric +from monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical +from monai.utils.enums import MetricReduction + + +class MetricsReloadedBinaryHandler(IgniteMetric): + """ + Handler of MetricsReloadedBinary, which wraps the binary pairwise metrics of MetricsReloaded. + """ + + def __init__( + self, + metric_name: str, + include_background: bool = True, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + output_transform: Callable = lambda x: x, + save_details: bool = True, + ) -> None: + """ + + Args: + metric_name: Name of a binary metric from the MetricsReloaded package. + include_background: whether to skip computation on the first channel of + the predicted output. Defaults to ``True``. + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). + Here `not_nans` count the number of not nans for the metric, + thus its shape equals to the shape of the metric. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image. + default to True, will save to `engine.state.metric_details` dict with the metric name as key. + + See also: + :py:meth:`monai.metrics.wrapper` + """ + metric_fn = MetricsReloadedBinary( + metric_name=metric_name, + include_background=include_background, + reduction=reduction, + get_not_nans=get_not_nans, + ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) + + +class MetricsReloadedCategoricalHandler(IgniteMetric): + """ + Handler of MetricsReloadedCategorical, which wraps the categorical pairwise metrics of MetricsReloaded. + """ + + def __init__( + self, + metric_name: str, + include_background: bool = True, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + smooth_dr: float = 1e-5, + output_transform: Callable = lambda x: x, + save_details: bool = True, + ) -> None: + """ + + Args: + metric_name: Name of a categorical metric from the MetricsReloaded package. + include_background: whether to skip computation on the first channel of + the predicted output. Defaults to ``True``. + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). + Here `not_nans` count the number of not nans for the metric, + thus its shape equals to the shape of the metric. + smooth_dr: a small constant added to the denominator to avoid nan. OBS: should be greater than zero. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image. + default to True, will save to `engine.state.metric_details` dict with the metric name as key. + + See also: + :py:meth:`monai.metrics.wrapper` + """ + metric_fn = MetricsReloadedCategorical( + metric_name=metric_name, + include_background=include_background, + reduction=reduction, + get_not_nans=get_not_nans, + smooth_dr=smooth_dr, + ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 0fc62c68ea..88a0926b91 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING from monai.config import IgniteInfo from monai.data import decollate_batch @@ -79,10 +82,10 @@ class mean median max 5percentile 95percentile notnans def __init__( self, save_dir: str, - metrics: Optional[Union[str, Sequence[str]]] = "*", - metric_details: Optional[Union[str, Sequence[str]]] = None, + metrics: str | Sequence[str] | None = "*", + metric_details: str | Sequence[str] | None = None, batch_transform: Callable = lambda x: x, - summary_ops: Optional[Union[str, Sequence[str]]] = None, + summary_ops: str | Sequence[str] | None = None, save_rank: int = 0, delimiter: str = ",", output_type: str = "csv", @@ -95,7 +98,7 @@ def __init__( self.save_rank = save_rank self.deli = delimiter self.output_type = output_type - self._filenames: List[str] = [] + self._filenames: list[str] = [] def attach(self, engine: Engine) -> None: """ diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index 340138a372..e49a2e967e 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -9,10 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import time +from collections.abc import Callable, Sequence from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any import torch @@ -57,7 +60,12 @@ class MLFlowHandler: to log data to a directory. The URI defaults to path `mlruns`. for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri. iteration_log: whether to log data to MLFlow when iteration completed, default to `True`. + ``iteration_log`` can be also a function and it will be interpreted as an event filter + (see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details). + Event filter function accepts as input engine and event value (iteration) and should return True/False. epoch_log: whether to log data to MLFlow when epoch completed, default to `True`. + ``epoch_log`` can be also a function and it will be interpreted as an event filter. + See ``iteration_log`` argument for more details. epoch_logger: customized callable logger for epoch level logging with MLFlow. Must accept parameter "engine", use default logger if None. iteration_logger: customized callable logger for iteration level logging with MLFlow. @@ -94,20 +102,20 @@ class MLFlowHandler: def __init__( self, - tracking_uri: Optional[str] = None, - iteration_log: bool = True, - epoch_log: bool = True, - epoch_logger: Optional[Callable[[Engine], Any]] = None, - iteration_logger: Optional[Callable[[Engine], Any]] = None, + tracking_uri: str | None = None, + iteration_log: bool | Callable[[Engine, int], bool] = True, + epoch_log: bool | Callable[[Engine, int], bool] = True, + epoch_logger: Callable[[Engine], Any] | None = None, + iteration_logger: Callable[[Engine], Any] | None = None, output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, - state_attributes: Optional[Sequence[str]] = None, + state_attributes: Sequence[str] | None = None, tag_name: str = DEFAULT_TAG, experiment_name: str = "monai_experiment", - run_name: Optional[str] = None, - experiment_param: Optional[Dict] = None, - artifacts: Optional[Union[str, Sequence[Path]]] = None, - optimizer_param_names: Union[str, Sequence[str]] = "lr", + run_name: str | None = None, + experiment_param: dict | None = None, + artifacts: str | Sequence[Path] | None = None, + optimizer_param_names: str | Sequence[str] = "lr", close_on_complete: bool = False, ) -> None: self.iteration_log = iteration_log @@ -128,7 +136,7 @@ def __init__( self.experiment = None self.cur_run = None - def _delete_exist_param_in_dict(self, param_dict: Dict) -> None: + def _delete_exist_param_in_dict(self, param_dict: dict) -> None: """ Delete parameters in given dict, if they are already logged by current mlflow run. @@ -156,9 +164,15 @@ def attach(self, engine: Engine) -> None: if not engine.has_event_handler(self.start, Events.STARTED): engine.add_event_handler(Events.STARTED, self.start) if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): - engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + event = Events.ITERATION_COMPLETED + if callable(self.iteration_log): # substitute event with new one using filter callable + event = event(event_filter=self.iteration_log) + engine.add_event_handler(event, self.iteration_completed) if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): - engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) + event = Events.EPOCH_COMPLETED + if callable(self.epoch_log): # substitute event with new one using filter callable + event = event(event_filter=self.epoch_log) + engine.add_event_handler(event, self.epoch_completed) if not engine.has_event_handler(self.complete, Events.COMPLETED): engine.add_event_handler(Events.COMPLETED, self.complete) if self.close_on_complete and (not engine.has_event_handler(self.close, Events.COMPLETED)): @@ -201,13 +215,13 @@ def _set_experiment(self): raise ValueError(f"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment") self.experiment = experiment - def _log_params(self, params: Dict[str, Any]) -> None: + def _log_params(self, params: dict[str, Any]) -> None: if not self.cur_run: raise ValueError("Current Run is not Active to log params") params_arr = [mlflow.entities.Param(key, str(value)) for key, value in params.items()] self.client.log_batch(run_id=self.cur_run.info.run_id, metrics=[], params=params_arr, tags=[]) - def _log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: + def _log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None: if not self.cur_run: raise ValueError("Current Run is not Active to log metrics") diff --git a/monai/handlers/nvtx_handlers.py b/monai/handlers/nvtx_handlers.py index 19f1b2e2bb..38eef6f05b 100644 --- a/monai/handlers/nvtx_handlers.py +++ b/monai/handlers/nvtx_handlers.py @@ -12,7 +12,9 @@ Wrapper around NVIDIA Tools Extension for profiling MONAI ignite workflow """ -from typing import TYPE_CHECKING, Optional, Tuple, Union +from __future__ import annotations + +from typing import TYPE_CHECKING from monai.config import IgniteInfo from monai.utils import ensure_tuple, min_version, optional_import @@ -52,9 +54,7 @@ class RangeHandler: If not provided, the name of first event will be assigned to the NVTX range. """ - def __init__( - self, events: Union[str, Tuple[Union[str, Events], Union[str, Events]]], msg: Optional[str] = None - ) -> None: + def __init__(self, events: str | tuple[str | Events, str | Events], msg: str | None = None) -> None: self.events = self.resolve_events(events) if msg is None: if isinstance(events, str): @@ -66,7 +66,7 @@ def __init__( self.msg = msg self.depth = None - def resolve_events(self, events: Union[str, Tuple]) -> Tuple[Events, Events]: + def resolve_events(self, events: str | tuple) -> tuple[Events, Events]: """ Resolve the input events to create a pair of Ignite events """ @@ -77,7 +77,7 @@ def resolve_events(self, events: Union[str, Tuple]) -> Tuple[Events, Events]: return self.get_event(events[0]), self.get_event(events[1]) raise ValueError(f"Exactly two Ignite events should be provided [received {len(events)}].") - def create_paired_events(self, event: str) -> Tuple[Events, Events]: + def create_paired_events(self, event: str) -> tuple[Events, Events]: """ Create pair of Ignite events from a event prefix name """ @@ -85,7 +85,7 @@ def create_paired_events(self, event: str) -> Tuple[Events, Events]: event_prefix = {"": "", "ENGINE": "", "EPOCH": "EPOCH_", "ITERATION": "ITERATION_", "BATCH": "GET_BATCH_"} return self.get_event(event_prefix[event] + "STARTED"), self.get_event(event_prefix[event] + "COMPLETED") - def get_event(self, event: Union[str, Events]) -> Events: + def get_event(self, event: str | Events) -> Events: return Events[event.upper()] if isinstance(event, str) else event def attach(self, engine: Engine) -> None: @@ -113,7 +113,7 @@ class RangePushHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Union[str, Events], msg: Optional[str] = None) -> None: + def __init__(self, event: str | Events, msg: str | None = None) -> None: self.event = Events[event.upper()] if isinstance(event, str) else event if msg is None: msg = self.event.name @@ -141,7 +141,7 @@ class RangePopHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Union[str, Events]) -> None: + def __init__(self, event: str | Events) -> None: self.event = Events[event.upper()] if isinstance(event, str) else event def attach(self, engine: Engine) -> None: @@ -164,7 +164,7 @@ class MarkHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Union[str, Events], msg: Optional[str] = None) -> None: + def __init__(self, event: str | Events, msg: str | None = None) -> None: self.event = Events[event.upper()] if isinstance(event, str) else event if msg is None: msg = self.event.name diff --git a/monai/handlers/panoptic_quality.py b/monai/handlers/panoptic_quality.py index ffa0aee03a..4bf561826c 100644 --- a/monai/handlers/panoptic_quality.py +++ b/monai/handlers/panoptic_quality.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union +from __future__ import annotations + +from collections.abc import Callable from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import PanopticQualityMetric @@ -25,7 +27,7 @@ def __init__( self, num_classes: int, metric_name: str = "pq", - reduction: Union[MetricReduction, str] = MetricReduction.MEAN_BATCH, + reduction: MetricReduction | str = MetricReduction.MEAN_BATCH, match_iou_threshold: float = 0.5, smooth_numerator: float = 1e-6, output_transform: Callable = lambda x: x, diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py index 233abca2e0..d12e6e072c 100644 --- a/monai/handlers/parameter_scheduler.py +++ b/monai/handlers/parameter_scheduler.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging from bisect import bisect_right -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING from monai.config import IgniteInfo from monai.utils import min_version, optional_import @@ -41,11 +44,11 @@ class ParamSchedulerHandler: def __init__( self, parameter_setter: Callable, - value_calculator: Union[str, Callable], - vc_kwargs: Dict, + value_calculator: str | Callable, + vc_kwargs: dict, epoch_level: bool = False, - name: Optional[str] = None, - event=None, + name: str | None = None, + event: str | None = None, ): self.epoch_level = epoch_level self.event = event if event is not None else Events.ITERATION_COMPLETED @@ -73,7 +76,7 @@ def _get_value_calculator(self, value_calculator): f"value_calculator must be either a string from {list(self._calculators.keys())} or a Callable." ) - def __call__(self, engine: Engine): + def __call__(self, engine: Engine) -> None: if self.epoch_level: self._vc_kwargs["current_step"] = engine.state.epoch else: @@ -156,7 +159,7 @@ def _step(initial_value: float, gamma: float, step_size: int, current_step: int) return initial_value * gamma ** (current_step // step_size) @staticmethod - def _multistep(initial_value: float, gamma: float, milestones: List[int], current_step: int) -> float: + def _multistep(initial_value: float, gamma: float, milestones: list[int], current_step: int) -> float: """ Decays the parameter value by gamma once the number of steps reaches one of the milestones. diff --git a/monai/handlers/postprocessing.py b/monai/handlers/postprocessing.py index 4a89c86f47..c698c84338 100644 --- a/monai/handlers/postprocessing.py +++ b/monai/handlers/postprocessing.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING from monai.config import IgniteInfo from monai.engines.utils import IterationEvents, engine_apply_transform diff --git a/monai/handlers/probability_maps.py b/monai/handlers/probability_maps.py index d39b640722..8a60fcc983 100644 --- a/monai/handlers/probability_maps.py +++ b/monai/handlers/probability_maps.py @@ -9,9 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import threading -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING import numpy as np @@ -41,7 +43,7 @@ def __init__( output_postfix: str = "", prob_key: str = "pred", dtype: DtypeLike = np.float64, - name: Optional[str] = None, + name: str | None = None, ) -> None: """ Args: @@ -65,8 +67,8 @@ def __init__( self._name = name self.prob_key = prob_key self.dtype = dtype - self.prob_map: Dict[str, np.ndarray] = {} - self.counter: Dict[str, int] = {} + self.prob_map: dict[str, np.ndarray] = {} + self.counter: dict[str, int] = {} self.num_done_images: int = 0 self.num_images: int = 0 self.lock = threading.Lock() @@ -128,5 +130,5 @@ def save_prob_map(self, name: str) -> None: del self.prob_map[name] del self.counter[name] - def finalize(self, engine: Engine): + def finalize(self, engine: Engine) -> None: self.logger.info(f"Probability map is created for {self.num_done_images}/{self.num_images} images!") diff --git a/monai/handlers/regression_metrics.py b/monai/handlers/regression_metrics.py index bf4ac3af1d..fee7238491 100644 --- a/monai/handlers/regression_metrics.py +++ b/monai/handlers/regression_metrics.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union +from __future__ import annotations + +from collections.abc import Callable from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric @@ -23,7 +25,7 @@ class MeanSquaredError(IgniteMetric): def __init__( self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -56,7 +58,7 @@ class MeanAbsoluteError(IgniteMetric): def __init__( self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -89,7 +91,7 @@ class RootMeanSquaredError(IgniteMetric): def __init__( self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: @@ -122,8 +124,8 @@ class PeakSignalToNoiseRatio(IgniteMetric): def __init__( self, - max_val: Union[int, float], - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + max_val: int | float, + reduction: MetricReduction | str = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 68cf2e655e..a521a4cc06 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union +from __future__ import annotations + +from collections.abc import Callable from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import ROCAUCMetric @@ -46,6 +48,6 @@ class ROCAUC(IgniteMetric): """ - def __init__(self, average: Union[Average, str] = Average.MACRO, output_transform: Callable = lambda x: x) -> None: + def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None: metric_fn = ROCAUCMetric(average=Average(average)) super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False) diff --git a/monai/handlers/smartcache_handler.py b/monai/handlers/smartcache_handler.py index 56fee78b1d..ee043635db 100644 --- a/monai/handlers/smartcache_handler.py +++ b/monai/handlers/smartcache_handler.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import TYPE_CHECKING from monai.config import IgniteInfo diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index b8acb4d0ad..4e4ad78798 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -9,14 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import warnings -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any import torch +from monai.apps import get_logger from monai.config import IgniteInfo -from monai.utils import is_scalar, min_version, optional_import +from monai.utils import deprecated_arg_default, is_scalar, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: @@ -36,11 +40,11 @@ class StatsHandler: It can be used for any Ignite Engine(trainer, validator and evaluator). And it can support logging for epoch level and iteration level with pre-defined loggers. - Note that if `name` arg is None, will leverage `engine.logger` as default logger directly, otherwise, - get logger from `logging.getLogger(name)`, we can setup a logger outside first with the same `name`. - As the default log level of `RootLogger` is `WARNING`, may need to call - `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` before running this handler to enable - the stats logging. + Note that if ``name`` is None, this class will leverage `engine.logger` as the logger, otherwise, + ``logging.getLogger(name)`` is used. In both cases, it's important to make sure that the logging level is at least + ``INFO``. To change the level of logging, please call ``import ignite; ignite.utils.setup_logger(name)`` + (when ``name`` is not None) or ``engine.logger = ignite.utils.setup_logger(engine.logger.name, reset=True)`` + (when ``name`` is None) before running the engine with this handler attached. Default behaviors: - When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``. @@ -49,36 +53,44 @@ class StatsHandler: Usage example:: - logging.basicConfig(stream=sys.stdout, level=logging.INFO) + import ignite + import monai - trainer = SupervisedTrainer(...) - StatsHandler(name="train_stats").attach(trainer) + trainer = ignite.engine.Engine(lambda x, y: [0.0]) # an example trainer + monai.handlers.StatsHandler(name="train_stats").attach(trainer) - trainer.run() + trainer.run(range(3), max_epochs=4) More details of example is available in the tutorial: https://github.com/Project-MONAI/tutorials/blob/master/modules/engines/unet_training_dict.py. """ + @deprecated_arg_default("name", old_default=None, new_default="StatsHandler", since="1.1", replaced="1.3") def __init__( self, - iteration_log: bool = True, - epoch_log: bool = True, - epoch_print_logger: Optional[Callable[[Engine], Any]] = None, - iteration_print_logger: Optional[Callable[[Engine], Any]] = None, + iteration_log: bool | Callable[[Engine, int], bool] = True, + epoch_log: bool | Callable[[Engine, int], bool] = True, + epoch_print_logger: Callable[[Engine], Any] | None = None, + iteration_print_logger: Callable[[Engine], Any] | None = None, output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, - state_attributes: Optional[Sequence[str]] = None, - name: Optional[str] = None, + state_attributes: Sequence[str] | None = None, + name: str | None = None, tag_name: str = DEFAULT_TAG, key_var_format: str = DEFAULT_KEY_VAL_FORMAT, ) -> None: """ Args: - iteration_log: whether to log data when iteration completed, default to `True`. - epoch_log: whether to log data when epoch completed, default to `True`. + iteration_log: whether to log data when iteration completed, default to `True`. ``iteration_log`` can + be also a function and it will be interpreted as an event filter + (see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details). + Event filter function accepts as input engine and event value (iteration) and should return True/False. + Event filtering can be helpful to customize iteration logging frequency. + epoch_log: whether to log data when epoch completed, default to `True`. ``epoch_log`` can be + also a function and it will be interpreted as an event filter. See ``iteration_log`` argument for more + details. epoch_print_logger: customized callable printer for epoch level logging. Must accept parameter "engine", use default printer if None. iteration_print_logger: customized callable printer for iteration level logging. @@ -113,7 +125,7 @@ def __init__( self.state_attributes = state_attributes self.tag_name = tag_name self.key_var_format = key_var_format - self.logger = logging.getLogger(name) # if `name` is None, will default to `engine.logger` when attached + self.logger = get_logger(name) # type: ignore self.name = name def attach(self, engine: Engine) -> None: @@ -126,15 +138,25 @@ def attach(self, engine: Engine) -> None: """ if self.name is None: self.logger = engine.logger - if self.logger.getEffectiveLevel() > logging.INFO or logging.root.getEffectiveLevel() > logging.INFO: + if self.logger.getEffectiveLevel() > logging.INFO: + suggested = f"\n\nimport ignite\nignite.utils.setup_logger('{self.logger.name}', reset=True)" + if self.logger.name != engine.logger.name: + suggested += f"\nignite.utils.setup_logger('{engine.logger.name}', reset=True)" + suggested += "\n\n" warnings.warn( - "the effective log level of engine logger or RootLogger is higher than INFO, may not record log," - " please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it." + f"the effective log level of {self.logger.name} is higher than INFO, StatsHandler may not output logs," + f"\nplease use the following code before running the engine to enable it: {suggested}" ) if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): - engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + event = Events.ITERATION_COMPLETED + if callable(self.iteration_log): # substitute event with new one using filter callable + event = event(event_filter=self.iteration_log) + engine.add_event_handler(event, self.iteration_completed) if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): - engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) + event = Events.EPOCH_COMPLETED + if callable(self.epoch_log): # substitute event with new one using filter callable + event = event(event_filter=self.epoch_log) + engine.add_event_handler(event, self.epoch_completed) if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED): engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised) diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index 77f0debfe9..eb80b41a07 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union +from __future__ import annotations + +from collections.abc import Callable from monai.handlers.ignite_metric import IgniteMetric from monai.metrics import SurfaceDistanceMetric @@ -26,7 +28,7 @@ def __init__( include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, output_transform: Callable = lambda x: x, save_details: bool = True, ) -> None: diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 14701e79d9..3eb2fe1280 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -9,25 +9,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any import numpy as np import torch from monai.config import IgniteInfo -from monai.utils import is_scalar, min_version, optional_import +from monai.utils import deprecated_arg, is_scalar, min_version, optional_import from monai.visualize import plot_2d_or_3d_image Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") -SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") if TYPE_CHECKING: from ignite.engine import Engine + from tensorboardX import SummaryWriter as SummaryWriterX + from torch.utils.tensorboard import SummaryWriter else: Engine, _ = optional_import( "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator" ) + SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") + SummaryWriterX, _ = optional_import("tensorboardX", name="SummaryWriter") DEFAULT_TAG = "Loss" @@ -43,7 +49,7 @@ class TensorBoardHandler: """ - def __init__(self, summary_writer=None, log_dir: str = "./runs"): + def __init__(self, summary_writer: SummaryWriter | SummaryWriterX | None = None, log_dir: str = "./runs"): if summary_writer is None: self._writer = SummaryWriter(log_dir=log_dir) self.internal_writer = True @@ -81,19 +87,21 @@ class TensorBoardStatsHandler(TensorBoardHandler): """ + @deprecated_arg("epoch_interval", since="1.1", removed="1.3") + @deprecated_arg("iteration_interval", since="1.1", removed="1.3") def __init__( self, - summary_writer=None, + summary_writer: SummaryWriter | SummaryWriterX | None = None, log_dir: str = "./runs", - iteration_log: bool = True, - epoch_log: bool = True, - epoch_event_writer: Optional[Callable[[Engine, Any], Any]] = None, + iteration_log: bool | Callable[[Engine, int], bool] = True, + epoch_log: bool | Callable[[Engine, int], bool] = True, + epoch_event_writer: Callable[[Engine, Any], Any] | None = None, epoch_interval: int = 1, - iteration_event_writer: Optional[Callable[[Engine, Any], Any]] = None, + iteration_event_writer: Callable[[Engine, Any], Any] | None = None, iteration_interval: int = 1, output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, - state_attributes: Optional[Sequence[str]] = None, + state_attributes: Sequence[str] | None = None, tag_name: str = DEFAULT_TAG, ) -> None: """ @@ -102,13 +110,20 @@ def __init__( default to create a new TensorBoard writer. log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. iteration_log: whether to write data to TensorBoard when iteration completed, default to `True`. + ``iteration_log`` can be also a function and it will be interpreted as an event filter + (see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details). + Event filter function accepts as input engine and event value (iteration) and should return True/False. epoch_log: whether to write data to TensorBoard when epoch completed, default to `True`. + ``epoch_log`` can be also a function and it will be interpreted as an event filter. + See ``iteration_log`` argument for more details. epoch_event_writer: customized callable TensorBoard writer for epoch level. Must accept parameter "engine" and "summary_writer", use default event writer if None. epoch_interval: the epoch interval at which the epoch_event_writer is called. Defaults to 1. + ``epoch_interval`` must be 1 if ``epoch_log`` is callable. iteration_event_writer: customized callable TensorBoard writer for iteration level. Must accept parameter "engine" and "summary_writer", use default event writer if None. iteration_interval: the iteration interval at which the iteration_event_writer is called. Defaults to 1. + ``iteration_interval`` must be 1 if ``iteration_log`` is callable. output_transform: a callable that is used to transform the ``ignite.engine.state.output`` into a scalar to plot, or a dictionary of {key: scalar}. In the latter case, the output string will be formatted as key: value. @@ -125,6 +140,12 @@ def __init__( when epoch completed. tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``. """ + if callable(iteration_log) and iteration_interval > 1: + raise ValueError("If iteration_log is callable, then iteration_interval should be 1") + + if callable(epoch_log) and epoch_interval > 1: + raise ValueError("If epoch_log is callable, then epoch_interval should be 1") + super().__init__(summary_writer=summary_writer, log_dir=log_dir) self.iteration_log = iteration_log self.epoch_log = epoch_log @@ -146,11 +167,19 @@ def attach(self, engine: Engine) -> None: """ if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): - engine.add_event_handler( - Events.ITERATION_COMPLETED(every=self.iteration_interval), self.iteration_completed - ) + event = Events.ITERATION_COMPLETED + if callable(self.iteration_log): # substitute event with new one using filter callable + event = event(event_filter=self.iteration_log) + elif self.iteration_interval > 1: + event = event(every=self.iteration_interval) + engine.add_event_handler(event, self.iteration_completed) if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): - engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.epoch_interval), self.epoch_completed) + event = Events.EPOCH_COMPLETED + if callable(self.epoch_log): # substitute event with new one using filter callable + event = event(event_filter=self.epoch_log) + elif self.epoch_log > 1: + event = event(every=self.epoch_interval) + engine.add_event_handler(event, self.epoch_completed) def epoch_completed(self, engine: Engine) -> None: """ @@ -180,7 +209,9 @@ def iteration_completed(self, engine: Engine) -> None: else: self._default_iteration_writer(engine, self._writer) - def _write_scalar(self, _engine: Engine, writer, tag: str, value: Any, step: int) -> None: + def _write_scalar( + self, _engine: Engine, writer: SummaryWriter | SummaryWriterX, tag: str, value: Any, step: int + ) -> None: """ Write scale value into TensorBoard. Default to call `SummaryWriter.add_scalar()`. @@ -195,7 +226,7 @@ def _write_scalar(self, _engine: Engine, writer, tag: str, value: Any, step: int """ writer.add_scalar(tag, value, step) - def _default_epoch_writer(self, engine: Engine, writer) -> None: + def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter | SummaryWriterX) -> None: """ Execute epoch level event write operation. Default to write the values from Ignite `engine.state.metrics` dict and @@ -217,7 +248,7 @@ def _default_epoch_writer(self, engine: Engine, writer) -> None: self._write_scalar(engine, writer, attr, getattr(engine.state, attr, None), current_epoch) writer.flush() - def _default_iteration_writer(self, engine: Engine, writer) -> None: + def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter | SummaryWriterX) -> None: """ Execute iteration level event write operation based on Ignite `engine.state.output` data. Extract the values from `self.output_transform(engine.state.output)`. @@ -296,7 +327,7 @@ class TensorBoardImageHandler(TensorBoardHandler): def __init__( self, - summary_writer=None, + summary_writer: SummaryWriter | SummaryWriterX | None = None, log_dir: str = "./runs", interval: int = 1, epoch_level: bool = True, diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 9b2c4f716e..58a3fd36f3 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any import numpy as np import torch @@ -28,23 +31,23 @@ __all__ = ["stopping_fn_from_metric", "stopping_fn_from_loss", "write_metrics_reports", "from_engine"] -def stopping_fn_from_metric(metric_name: str): +def stopping_fn_from_metric(metric_name: str) -> Callable[[Engine], Any]: """ Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name. """ - def stopping_fn(engine: Engine): + def stopping_fn(engine: Engine) -> Any: return engine.state.metrics[metric_name] return stopping_fn -def stopping_fn_from_loss(): +def stopping_fn_from_loss() -> Callable[[Engine], Any]: """ Returns a stopping function for ignite.handlers.EarlyStopping using the loss value. """ - def stopping_fn(engine: Engine): + def stopping_fn(engine: Engine) -> Any: return -engine.state.output # type:ignore return stopping_fn @@ -52,13 +55,13 @@ def stopping_fn(engine: Engine): def write_metrics_reports( save_dir: PathLike, - images: Optional[Sequence[str]], - metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], - metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], - summary_ops: Optional[Union[str, Sequence[str]]], + images: Sequence[str] | None, + metrics: dict[str, torch.Tensor | np.ndarray] | None, + metric_details: dict[str, torch.Tensor | np.ndarray] | None, + summary_ops: str | Sequence[str] | None, deli: str = ",", output_type: str = "csv", -): +) -> None: """ Utility function to write the metrics into files, contains 3 parts: 1. if `metrics` dict is not None, write overall metrics into file, every line is a metric name and value pair. @@ -142,7 +145,7 @@ class mean median max 5percentile 95percentile notnans if "*" in ops: ops = tuple(supported_ops.keys()) - def _compute_op(op: str, d: np.ndarray): + def _compute_op(op: str, d: np.ndarray) -> Any: if not op.endswith("percentile"): c_op = look_up_option(op, supported_ops) return c_op(d) @@ -156,7 +159,7 @@ def _compute_op(op: str, d: np.ndarray): f.write(f"{class_labels[i]}{deli}{deli.join([f'{_compute_op(k, c):.4f}' for k in ops])}\n") -def from_engine(keys: KeysCollection, first: bool = False): +def from_engine(keys: KeysCollection, first: bool = False) -> Callable: """ Utility function to simplify the `batch_transform` or `output_transform` args of ignite components when handling dictionary or list of dictionaries(for example: `engine.state.batch` or `engine.state.output`). @@ -199,7 +202,7 @@ def _wrapper(data): return _wrapper -def ignore_data(x: Any): +def ignore_data(x: Any) -> None: """ Always return `None` for any input data. A typical usage is to avoid logging the engine output of every iteration during evaluation. diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py index 171c901fbb..19183fb4e2 100644 --- a/monai/handlers/validation_handler.py +++ b/monai/handlers/validation_handler.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from monai.config import IgniteInfo from monai.engines.evaluator import Evaluator @@ -29,7 +31,7 @@ class ValidationHandler: """ - def __init__(self, interval: int, validator: Optional[Evaluator] = None, epoch_level: bool = True) -> None: + def __init__(self, interval: int, validator: Evaluator | None = None, epoch_level: bool = True) -> None: """ Args: interval: do validation every N epochs or every N iterations during training. @@ -48,7 +50,7 @@ def __init__(self, interval: int, validator: Optional[Evaluator] = None, epoch_l self.interval = interval self.epoch_level = epoch_level - def set_validator(self, validator: Evaluator): + def set_validator(self, validator: Evaluator) -> None: """ Set validator if not setting in the __init__(). """ diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 3447782be9..ad0ec77bcf 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -9,5 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .inferer import Inferer, SaliencyInferer, SimpleInferer, SliceInferer, SlidingWindowInferer +from __future__ import annotations + +from .inferer import Inferer, PatchInferer, SaliencyInferer, SimpleInferer, SliceInferer, SlidingWindowInferer +from .merger import AvgMerger, Merger +from .splitter import SlidingWindowSplitter, Splitter from .utils import sliding_window_inference diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index d0d2f932b5..03b5e7a75f 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -9,18 +9,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from pydoc import locate +from typing import Any import torch import torch.nn as nn +from monai.data.meta_tensor import MetaTensor +from monai.inferers.merger import AvgMerger, Merger +from monai.inferers.splitter import Splitter from monai.inferers.utils import compute_importance_map, sliding_window_inference -from monai.utils import BlendMode, PytorchPadMode, ensure_tuple +from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp -__all__ = ["Inferer", "SimpleInferer", "SlidingWindowInferer", "SaliencyInferer", "SliceInferer"] +__all__ = ["Inferer", "PatchInferer", "SimpleInferer", "SlidingWindowInferer", "SaliencyInferer", "SliceInferer"] class Inferer(ABC): @@ -44,7 +51,7 @@ class Inferer(ABC): """ @abstractmethod - def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any): + def __call__(self, inputs: torch.Tensor, network: Callable, *args: Any, **kwargs: Any) -> Any: """ Run inference on `inputs` with the `network` model. @@ -61,6 +68,238 @@ def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], * raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") +class PatchInferer(Inferer): + """ + Inference on patches instead of the whole image based on Splitter and Merger. + This splits the input image into patches and then merge the resulted patches. + + Args: + splitter: a `Splitter` object that split the inputs into patches. Defaults to None. + If not provided or None, the inputs are considered to be already split into patches. + merger_cls: a `Merger` subclass that can be instantiated to merges patch outputs. + It can also be a string that matches the name of a class inherited from `Merger` class. + Defaults to `AvgMerger`. + batch_size: batch size for patches. If the input tensor is already batched [BxCxWxH], + this adds additional batching [(Bp*B)xCxWpxHp] for inference on patches. + Defaults to 1. + preprocessing: a callable that process patches before the being fed to the network. + Defaults to None. + postprocessing: a callable that process the output of the network. + Defaults to None. + output_keys: if the network output is a dictionary, this defines the keys of + the output dictionary to be used for merging. + Defaults to None, where all the keys are used. + merger_kwargs: arguments to be passed to `merger_cls` for instantiation. + `output_shape` is calculated automatically based on the input shape and + the output patch shape unless it is passed here. + """ + + def __init__( + self, + splitter: Splitter | Callable | None = None, + merger_cls: type[Merger] | str = AvgMerger, + batch_size: int = 1, + preprocessing: Callable | None = None, + postprocessing: Callable | None = None, + output_keys: Sequence | None = None, + **merger_kwargs: Any, + ) -> None: + Inferer.__init__(self) + + # splitter + if splitter is not None and not isinstance(splitter, Splitter): + if callable(splitter): + warnings.warn( + "`splitter` is a callable instead of `Splitter` object, please make sure that it returns " + "the correct values. Either Iterable[tuple[torch.Tensor, Sequence[int]]], or " + "a MetaTensor with defined `PatchKey.LOCATION` metadata." + ) + else: + raise TypeError( + f"'splitter' should be a `Splitter` object (or a callable that returns " + "an iterable of pairs of (patch, location) or a MetaTensor that has `PatchKeys.LOCATION` metadata)." + f"{type(splitter)} is given." + ) + self.splitter = splitter + + # merger + if isinstance(merger_cls, str): + valid_merger_cls: type[Merger] + # search amongst implemented mergers in MONAI + valid_merger_cls, merger_found = optional_import("monai.inferers.merger", name=merger_cls) + if not merger_found: + # try to locate the requested merger class (with dotted path) + valid_merger_cls = locate(merger_cls) # type: ignore + if valid_merger_cls is None: + raise ValueError(f"The requested `merger_cls` ['{merger_cls}'] does not exist.") + merger_cls = valid_merger_cls + if not issubclass(merger_cls, Merger): + raise TypeError(f"'merger' should be a subclass of `Merger`, {merger_cls} is given.") + self.merger_cls = merger_cls + self.merger_kwargs = merger_kwargs + + # pre-processor (process patch before the network) + if preprocessing is not None and not callable(preprocessing): + raise TypeError(f"'preprocessing' should be a callable object, {type(preprocessing)} is given.") + self.preprocessing = preprocessing + + # post-processor (process the output of the network) + if postprocessing is not None and not callable(postprocessing): + raise TypeError(f"'postprocessing' should be a callable object, {type(postprocessing)} is given.") + self.postprocessing = postprocessing + + # batch size for patches + self.batch_size = batch_size + + # model output keys + self.output_keys = output_keys + + def _batch_sampler( + self, patches: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor + ) -> Iterator[tuple[torch.Tensor, Sequence, int]]: + """Generate batch of patches and locations + + Args: + patches: a tensor or list of tensors + + Yields: + A batch of patches (torch.Tensor or MetaTensor), a sequence of location tuples, and the batch size + """ + if isinstance(patches, MetaTensor): + total_size = len(patches) + for i in range(0, total_size, self.batch_size): + batch_size = min(self.batch_size, total_size - i) + yield patches[i : i + batch_size], patches[i : i + batch_size].meta[PatchKeys.LOCATION], batch_size # type: ignore + else: + patch_batch: list[Any] = [None] * self.batch_size + location_batch: list[Any] = [None] * self.batch_size + idx_in_batch = 0 + for sample in patches: + patch_batch[idx_in_batch] = sample[0] + location_batch[idx_in_batch] = sample[1] + idx_in_batch += 1 + if idx_in_batch == self.batch_size: + # concatenate batch of patches to create a tensor + yield torch.cat(patch_batch), location_batch, idx_in_batch + patch_batch = [None] * self.batch_size + location_batch = [None] * self.batch_size + idx_in_batch = 0 + if idx_in_batch > 0: + # concatenate batch of patches to create a tensor + yield torch.cat(patch_batch[:idx_in_batch]), location_batch, idx_in_batch + + def _ensure_tuple_outputs(self, outputs: Any) -> tuple: + if isinstance(outputs, dict): + if self.output_keys is None: + self.output_keys = list(outputs.keys()) # model's output keys + return tuple(outputs[k] for k in self.output_keys) + return ensure_tuple(outputs, wrap_array=True) + + def _run_inference(self, network: Callable, patch: torch.Tensor, *args: Any, **kwargs: Any) -> tuple: + # pre-process + if self.preprocessing: + patch = self.preprocessing(patch) + # inference + outputs = network(patch, *args, **kwargs) + # post-process + if self.postprocessing: + outputs = self.postprocessing(outputs) + # ensure we have a tuple of model outputs to support multiple outputs + return self._ensure_tuple_outputs(outputs) + + def _initialize_mergers(self, inputs, outputs, patches, batch_size): + in_patch = torch.chunk(patches, batch_size)[0] + mergers = [] + ratios = [] + for out_patch_batch in outputs: + out_patch = torch.chunk(out_patch_batch, batch_size)[0] + # calculate the ratio of input and output patch sizes + ratio = tuple(op / ip for ip, op in zip(in_patch.shape[2:], out_patch.shape[2:])) + ratios.append(ratio) + # calculate output_shape only if it is not provided and splitter is not None. + if self.splitter is not None and "output_shape" not in self.merger_kwargs: + output_shape = self._get_output_shape(inputs, out_patch, ratio) + merger = self.merger_cls(output_shape=output_shape, **self.merger_kwargs) + else: + merger = self.merger_cls(**self.merger_kwargs) + mergers.append(merger) + return mergers, ratios + + def _aggregate(self, outputs, locations, batch_size, mergers, ratios): + for output_patches, merger, ratio in zip(outputs, mergers, ratios): + # split batched output into individual patches and then aggregate + for in_loc, out_patch in zip(locations, torch.chunk(output_patches, batch_size)): + out_loc = [round(l * r) for l, r in zip(in_loc, ratio)] + merger.aggregate(out_patch, out_loc) + + def _get_output_shape(self, inputs, out_patch, ratio): + """Define the shape of output merged tensors""" + in_spatial_shape = inputs.shape[2:] + out_spatial_shape = tuple(round(s * r) for s, r in zip(in_spatial_shape, ratio)) + output_shape = out_patch.shape[:2] + out_spatial_shape + return output_shape + + def __call__( + self, + inputs: torch.Tensor, + network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], + *args: Any, + **kwargs: Any, + ) -> Any: + """ + Args: + inputs: input data for inference, a torch.Tensor, representing an image or batch of images. + However if the data is already split, it can be fed by providing a list of tuple (patch, location), + or a MetaTensor that has metadata for `PatchKeys.LOCATION`. In both cases no splitter should be provided. + network: target model to execute inference. + supports callables such as ``lambda x: my_torch_model(x, additional_config)`` + args: optional args to be passed to ``network``. + kwargs: optional keyword args to be passed to ``network``. + + """ + patches_locations: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor + if self.splitter is None: + if isinstance(inputs, torch.Tensor): + if isinstance(inputs, MetaTensor): + if PatchKeys.LOCATION not in inputs.meta: + raise ValueError( + "`PatchKey.LOCATION` does not exists in `inputs.meta`. " + "If the inputs are already split into patches, the location of patches needs to be " + "provided as `PatchKey.LOCATION` metadata in a MetaTensor. " + "If the input is not already split, please provide `splitter`." + ) + else: + raise ValueError( + "`splitter` should be set if the input is not already split into patches. " + "For inputs that are split, the location of patches needs to be provided as " + "(image, location) pairs, or as `PatchKey.LOCATION` metadata in a MetaTensor. " + f"The provided inputs type is {type(inputs)}." + ) + patches_locations = inputs + else: + patches_locations = self.splitter(inputs) + + ratios: list[float] = [] + mergers: list[Merger] = [] + for patches, locations, batch_size in self._batch_sampler(patches_locations): + # run inference + outputs = self._run_inference(network, patches, *args, **kwargs) + # initialize the mergers + if not mergers: + mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size) + # aggregate outputs + self._aggregate(outputs, locations, batch_size, mergers, ratios) + + # finalize the mergers and get the results + merged_outputs = tuple(merger.finalize() for merger in mergers) + # return according to the model output + if self.output_keys: + return dict(zip(self.output_keys, merged_outputs)) + if len(merged_outputs) == 1: + return merged_outputs[0] + return merged_outputs + + class SimpleInferer(Inferer): """ SimpleInferer is the normal inference method that run model forward() directly. @@ -71,7 +310,9 @@ class SimpleInferer(Inferer): def __init__(self) -> None: Inferer.__init__(self) - def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any): + def __call__( + self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any + ) -> torch.Tensor: """Unified callable function API of Inferers. Args: @@ -134,18 +375,18 @@ class SlidingWindowInferer(Inferer): def __init__( self, - roi_size: Union[Sequence[int], int], + roi_size: Sequence[int] | int, sw_batch_size: int = 1, overlap: float = 0.25, - mode: Union[BlendMode, str] = BlendMode.CONSTANT, - sigma_scale: Union[Sequence[float], float] = 0.125, - padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + mode: BlendMode | str = BlendMode.CONSTANT, + sigma_scale: Sequence[float] | float = 0.125, + padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT, cval: float = 0.0, - sw_device: Union[torch.device, str, None] = None, - device: Union[torch.device, str, None] = None, + sw_device: torch.device | str | None = None, + device: torch.device | str | None = None, progress: bool = False, cache_roi_weight_map: bool = False, - cpu_thresh: Optional[int] = None, + cpu_thresh: int | None = None, ) -> None: super().__init__() self.roi_size = roi_size @@ -180,10 +421,10 @@ def __init__( def __call__( self, inputs: torch.Tensor, - network: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]], + network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], *args: Any, **kwargs: Any, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: """ Args: @@ -232,7 +473,9 @@ class SaliencyInferer(Inferer): """ - def __init__(self, cam_name: str, target_layers: str, class_idx: Optional[int] = None, *args, **kwargs) -> None: + def __init__( + self, cam_name: str, target_layers: str, class_idx: int | None = None, *args: Any, **kwargs: Any + ) -> None: Inferer.__init__(self) if cam_name.lower() not in ("cam", "gradcam", "gradcampp"): raise ValueError("cam_name should be: 'CAM', 'GradCAM' or 'GradCAMpp'.") @@ -253,7 +496,7 @@ def __call__(self, inputs: torch.Tensor, network: nn.Module, *args: Any, **kwarg kwargs: other optional keyword args to be passed to `__call__` of cam. """ - cam: Union[CAM, GradCAM, GradCAMpp] + cam: CAM | GradCAM | GradCAMpp if self.cam_name == "cam": cam = CAM(network, self.target_layers, *self.args, **self.kwargs) elif self.cam_name == "gradcam": @@ -286,7 +529,7 @@ class SliceInferer(SlidingWindowInferer): """ - def __init__(self, spatial_dim: int = 0, *args, **kwargs) -> None: + def __init__(self, spatial_dim: int = 0, *args: Any, **kwargs: Any) -> None: self.spatial_dim = spatial_dim super().__init__(*args, **kwargs) self.orig_roi_size = ensure_tuple(self.roi_size) @@ -294,10 +537,10 @@ def __init__(self, spatial_dim: int = 0, *args, **kwargs) -> None: def __call__( self, inputs: torch.Tensor, - network: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]], + network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], *args: Any, **kwargs: Any, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: """ Args: inputs: 3D input for inference @@ -322,11 +565,11 @@ def __call__( def network_wrapper( self, - network: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]], + network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], x: torch.Tensor, - *args, - **kwargs, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]: + *args: Any, + **kwargs: Any, + ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: """ Wrapper handles inference for 2D models over 3D volume inputs. """ diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py new file mode 100644 index 0000000000..868ccea446 --- /dev/null +++ b/monai/inferers/merger.py @@ -0,0 +1,162 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Any + +import torch + +from monai.utils import ensure_tuple_size + +__all__ = ["Merger", "AvgMerger"] + + +class Merger(ABC): + """ + A base class for merging patches. + Extend this class to support operations for `PatchInference`. + There are two methods that must be implemented in the concrete classes: + + - aggregate: aggregate the values at their corresponding locations + - finalize: perform any final process and return the merged output + + Args: + output_shape: the shape of the merged output tensor. Default to None. + device: the device where Merger tensors should reside. + """ + + def __init__(self, output_shape: Sequence[int] | None = None, device: torch.device | str | None = None) -> None: + self.output_shape = output_shape + self.device = device + self.is_finalized = False + + @abstractmethod + def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None: + """ + Aggregate values for merging. + This method is being called in a loop and should add values to their corresponding location in the merged output results. + + Args: + values: a tensor of shape BCHW[D], representing the values of inference output. + location: a tuple/list giving the top left location of the patch in the output. + + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def finalize(self) -> Any: + """ + Perform final operations for merging patches and return the final merged output. + + Returns: + The results of merged patches, which is commonly a torch.Tensor representing the merged result, or + a string representing the filepath to the merged results on disk. + + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class AvgMerger(Merger): + """Merge patches by taking average of the overlapping area + + Args: + output_shape: the shape of the merged output tensor. + device: the device for aggregator tensors and final results. + value_dtype: the dtype for value aggregating tensor and the final result. + count_dtype: the dtype for sample counting tensor. + """ + + def __init__( + self, + output_shape: Sequence[int], + device: torch.device | str = "cpu", + value_dtype: torch.dtype = torch.float32, + count_dtype: torch.dtype = torch.uint8, + ) -> None: + super().__init__(output_shape=output_shape, device=device) + if not self.output_shape: + raise ValueError(f"`output_shape` must be provided for `AvgMerger`. {self.output_shape} is give.") + self.value_dtype = value_dtype + self.count_dtype = count_dtype + self.values = torch.zeros(self.output_shape, dtype=self.value_dtype, device=self.device) + self.counts = torch.zeros(self.output_shape, dtype=self.count_dtype, device=self.device) + + def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None: + """ + Aggregate values for merging. + + Args: + values: a tensor of shape BCHW[D], representing the values of inference output. + location: a tuple/list giving the top left location of the patch in the original image. + + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + if self.is_finalized: + raise ValueError("`AvgMerger` is already finalized. Please instantiate a new object to aggregate.") + patch_size = values.shape[2:] + map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size)) + map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True) + self.values[map_slice] += values + self.counts[map_slice] += 1 + + def finalize(self) -> torch.Tensor: + """ + Finalize merging by dividing values by counts and return the merged tensor. + + Notes: + To avoid creating a new tensor for the final results (to save memory space), + after this method is called, `get_values()` method will return the "final" averaged values, + and not the accumulating values. Also calling `finalize()` multiple times does not have any effect. + + Returns: + torch.tensor: a tensor of merged patches + """ + # guard against multiple call to finalize + if not self.is_finalized: + # use in-place division to save space + self.values.div_(self.counts) + # set finalize flag to protect performing in-place division again + self.is_finalized = True + + return self.values + + def get_values(self) -> torch.Tensor: + """ + Get the accumulated values during aggregation or final averaged values after it is finalized. + + Returns: + Merged (averaged) output tensor. + + Notes: + - If called before calling `finalize()`, this method returns the accumulating values. + - If called after calling `finalize()`, this method returns the final merged [and averaged] values. + """ + return self.values + + def get_counts(self) -> torch.Tensor: + """ + Get the aggregator tensor for number of samples. + + Returns: + torch.Tensor: Number of accumulated samples at each location. + """ + return self.counts diff --git a/monai/inferers/splitter.py b/monai/inferers/splitter.py new file mode 100644 index 0000000000..6c4ccdce51 --- /dev/null +++ b/monai/inferers/splitter.py @@ -0,0 +1,200 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Sequence +from inspect import _empty, signature +from typing import Any + +import torch + +from monai.data.utils import iter_patch_position +from monai.utils.enums import PytorchPadMode +from monai.utils.misc import ensure_tuple, ensure_tuple_rep +from monai.utils.module import look_up_option + +__all__ = ["Splitter", "SlidingWindowSplitter"] + + +class Splitter(ABC): + """ + A base class for splitting the inputs into iterable tuple of patches and locations + Extend this class to support operations for `PatchInference`, e.g. SlidingPatchSplitter. + + Args: + patch_size: the size of patches to be generated. + device: the device where the patches are generated. + """ + + def __init__(self, patch_size: Sequence[int] | int, device: torch.device | str | None = None) -> None: + self.patch_size = patch_size + self.device = device + + @abstractmethod + def __call__(self, inputs: Any) -> Iterable[tuple[torch.Tensor, Sequence[int]]]: + """ + Split the input image (or batch of images) into patches and return pairs of (patch, location). + Where location is the coordinate of top left [front] corner of a patch. + + Args: + inputs: either a tensor of shape BCHW[D], representing a batch of images, + or a filename (str) or list of filenames to the image(s). + + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class SlidingWindowSplitter(Splitter): + def __init__( + self, + patch_size: Sequence[int] | int, + offset: Sequence[int] | int = 0, + overlap: Sequence[float] | float = 0.0, + filter_fn: Callable | None = None, + device: torch.device | str | None = None, + pad_mode: str | None = PytorchPadMode.CONSTANT, + **pad_kwargs: Any, + ) -> None: + """Split the input into patches with sliding window strategy and a possible overlap. + If also allows to offset the starting position and filter the patches. + + Args: + patch_size : the size of the patch to be generated. + offset: the amount of offset for the patches with respect to the original input. Defaults to 0. + overlap: the amount of overlap between patches in each dimension [0, 1). Defaults to 0.0. + filter_fn: a callable to filter patches. It should accepts exactly two parameters (patch, location), and + return True for a patch to keep. Defaults to no filtering. + device: the device where the patches are generated. Defaults to the device of inputs. + pad_mode: the pad mode when the patches are extracted from outside of the image + (either when the offset is negative or the image is non-divisible by the patch_size). + If set to `None`, the last incomplete patch will be dropped. Defaults to PytorchPadMode.CONSTANT. + pad_kwargs: other arguments for `torch.nn.functional.pad`. + + Note: + If only one scaler value is provided for `patch_size`, `offset`, or `overlap`, + it will be broadcasted to all the spatial dimensions. + """ + super().__init__(patch_size=patch_size, device=device) + self.offset = offset + if any(ov < 0 or ov >= 1 for ov in ensure_tuple(overlap)): + raise ValueError(f"Overlap must be between 0 and 1 but {overlap} is given.") + self.overlap = overlap + self.filter_fn = self._get_valid_filter_fn(filter_fn) + self.pad_mode = pad_mode + self.pad_kwargs = pad_kwargs + + def _get_valid_filter_fn(self, filter_fn): + if filter_fn is None: + return + if callable(filter_fn): + sig = signature(filter_fn) + n_params = len(sig.parameters) + n_pos_params = len([v for v in sig.parameters.values() if v.default is _empty]) + if n_params < 2: + raise ValueError( + f"`patch_filter_fn` requires to accept at least two parameters (patch, location)." + f"The provided callable ({filter_fn}) has {n_params} parameters." + ) + elif n_pos_params > 2: + raise ValueError( + f"`patch_filter_fn` can have at most two positional parameters (patch, location)." + f"The provided callable ({filter_fn}) has {n_pos_params} positional parameters." + ) + return filter_fn + raise ValueError( + "`patch_filter_fn` should be a callable with two input parameters (patch, location). " + f"{type(filter_fn)} is given." + ) + + def _get_valid_patch_size(self, spatial_ndim): + return ensure_tuple_rep(self.patch_size, spatial_ndim) + + def _get_valid_overlap(self, spatial_ndim, patch_size): + # broadcast overlap is possible + overlap = ensure_tuple_rep(self.overlap, spatial_ndim) + # keep overlap only in patching dimensions + return tuple(o if p else 0.0 for o, p in zip(overlap, patch_size)) + + def _get_valid_offset(self, spatial_shape, spatial_ndim, patch_size): + offset = ensure_tuple_rep(self.offset, spatial_ndim) + for off, ps, ins in zip(offset, patch_size, spatial_shape): + if off < 0 and not self.pad_mode: + raise ValueError( + f"Negative `offset` ({off}) requires a valid padding mode, " + f"but `pad_mod` is set to {self.pad_mode}." + ) + if off < -ps: + raise ValueError(f"Negative `offset` ({off}) cannot be larger than `patch_size` ({ps}) in magnitude.") + if off >= ins: + raise ValueError(f"`offset` ({off}) cannot be larger than inputs size ({ins}).") + return offset + + def _calculate_pad_size(self, spatial_shape, spatial_ndim, patch_size, offset, overlap): + if not self.pad_mode: + return [], False + # initialize with zero + pad_size = [0] * 2 * spatial_ndim + # set the starting pad size only if the offset is negative + pad_size[1::2] = (-min(off, 0) for off in offset) + # set the ending pad size only if it is not divisible by the patch size + pad_size[::2] = ( + 0 if ps == 0 else (off - ins + ps) % round(ps * (1.0 - ov)) + for ins, off, ps, ov in zip(spatial_shape, offset, patch_size, overlap) + ) + return pad_size, any(pad_size[1::2]) + + def __call__(self, inputs: torch.Tensor) -> Iterable[tuple[torch.Tensor, Sequence[int]]]: + """Split the input tensor into patches and return patches and locations. + + Args: + inputs: a torch.Tensor with BCHW[D] dimensions, representing an image or a batch of images. + + Yields: + tuple[torch.Tensor, Sequence[int]]: yields tuple of patch and location + """ + n_non_spatial_dims = 2 + spatial_ndim = inputs.ndim - n_non_spatial_dims + spatial_shape = inputs.shape[n_non_spatial_dims:] + patch_size = self._get_valid_patch_size(spatial_ndim) + overlap = self._get_valid_overlap(spatial_ndim, patch_size) + offset = self._get_valid_offset(spatial_shape, spatial_ndim, patch_size) + pad_size, is_start_padded = self._calculate_pad_size(spatial_shape, spatial_ndim, patch_size, offset, overlap) + + if any(pad_size): + # pad the inputs + inputs = torch.nn.functional.pad( + inputs, pad_size[::-1], look_up_option(self.pad_mode, PytorchPadMode).value, **self.pad_kwargs + ) + # update spatial shape + spatial_shape = inputs.shape[n_non_spatial_dims:] + # correct the offset with respect to the padded image + if is_start_padded: + offset = tuple(off + p for off, p in zip(offset, pad_size[1::2])) + + for location in iter_patch_position(spatial_shape, patch_size, offset, overlap, False): + slices = (slice(None),) * 2 + tuple(slice(loc, loc + ps) for loc, ps in zip(location, patch_size)) + patch = inputs[slices] + # send the patch to target device + if self.device: + patch.to(self.device) + # correct the location with respect to original inputs (remove starting pads) + if is_start_padded: + location = tuple(loc - p for loc, p in zip(location, pad_size[1::2])) + # filter patches and yield + if self.filter_fn is None: + yield patch, location + elif self.filter_fn(patch, location): + yield patch, location diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index fe83c416d6..c4405911d0 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from collections.abc import Callable, Mapping, Sequence +from typing import Any import torch import torch.nn.functional as F @@ -36,22 +39,22 @@ def sliding_window_inference( inputs: torch.Tensor, - roi_size: Union[Sequence[int], int], + roi_size: Sequence[int] | int, sw_batch_size: int, - predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]], + predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], overlap: float = 0.25, - mode: Union[BlendMode, str] = BlendMode.CONSTANT, - sigma_scale: Union[Sequence[float], float] = 0.125, - padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + mode: BlendMode | str = BlendMode.CONSTANT, + sigma_scale: Sequence[float] | float = 0.125, + padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT, cval: float = 0.0, - sw_device: Union[torch.device, str, None] = None, - device: Union[torch.device, str, None] = None, + sw_device: torch.device | str | None = None, + device: torch.device | str | None = None, progress: bool = False, - roi_weight_map: Optional[torch.Tensor] = None, - process_fn: Optional[Callable] = None, + roi_weight_map: torch.Tensor | None = None, + process_fn: Callable | None = None, *args: Any, **kwargs: Any, -) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]: +) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: """ Sliding window inference on `inputs` with `predictor`. @@ -139,7 +142,9 @@ def sliding_window_inference( diff = max(roi_size[k - 2] - inputs.shape[k], 0) half = diff // 2 pad_size.extend([half, diff - half]) - inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval) + + if max(pad_size) > 0: + inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval) scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) @@ -164,8 +169,8 @@ def sliding_window_inference( importance_map_ = convert_data_type(importance_map_, torch.Tensor, device, compute_dtype)[0] # handle non-positive weights - min_non_zero = max(importance_map_[importance_map_ != 0].min().item(), 1e-3) - importance_map_ = torch.clamp(importance_map_.to(torch.float32), min=min_non_zero).to(compute_dtype) + min_non_zero = max(torch.min(importance_map_).item(), 1e-3) + importance_map_ = torch.clamp_(importance_map_.to(torch.float32), min=min_non_zero).to(compute_dtype) # Perform predictions dict_key, output_image_list, count_map_list = None, [], [] @@ -185,7 +190,7 @@ def sliding_window_inference( seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory. - seg_prob_tuple: Tuple[torch.Tensor, ...] + seg_prob_tuple: tuple[torch.Tensor, ...] if isinstance(seg_prob_out, torch.Tensor): seg_prob_tuple = (seg_prob_out,) elif isinstance(seg_prob_out, Mapping): @@ -250,7 +255,11 @@ def sliding_window_inference( "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works." ) original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None) - importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype) + importance_map_zoom = ( + resizer(importance_map.unsqueeze(0))[0].to(compute_dtype) + if seg_prob.shape[2:] != importance_map.shape + else importance_map.to(compute_dtype) + ) # store results and weights output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g] count_map_list[ss][original_idx_zoom] += ( @@ -259,18 +268,19 @@ def sliding_window_inference( # account for any overlapping sections for ss in range(len(output_image_list)): - output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype) + output_image_list[ss] = output_image_list[ss] + _map = count_map_list.pop(0) + for _i in range(output_image_list[ss].shape[1]): + output_image_list[ss][:, _i : _i + 1, ...] /= _map + output_image_list[ss] = output_image_list[ss].to(compute_dtype) # remove padding if image_size smaller than roi_size for ss, output_i in enumerate(output_image_list): - if torch.isnan(output_i).any() or torch.isinf(output_i).any(): - warnings.warn("Sliding window inference results contain NaN or Inf.") - zoom_scale = [ seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size) ] - final_slicing: List[slice] = [] + final_slicing: list[slice] = [] for sp in range(num_spatial_dims): slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2]) slice_dim = slice( @@ -295,7 +305,7 @@ def sliding_window_inference( def _get_scan_interval( image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float -) -> Tuple[int, ...]: +) -> tuple[int, ...]: """ Compute scan interval according to the image size, roi size and overlap. Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0, diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index a3c4bf1c5c..9e09b0b123 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss from .dice import ( diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index e0ed4ceab0..6213091bf6 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -9,14 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from warnings import warn import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss -from monai.utils import deprecated_arg - class ContrastiveLoss(_Loss): """ @@ -30,8 +30,7 @@ class ContrastiveLoss(_Loss): """ - @deprecated_arg(name="reduction", since="0.8", msg_suffix="`reduction` is no longer supported.") - def __init__(self, temperature: float = 0.5, batch_size: int = -1, reduction="sum") -> None: + def __init__(self, temperature: float = 0.5, batch_size: int = -1) -> None: """ Args: temperature: Can be scaled between 0 and 1 for learning from negative samples, ideally set to 0.5. diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 0f5e263a53..dd03a8eb3d 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from __future__ import annotations import torch from torch.nn.modules.loss import _Loss @@ -52,7 +52,7 @@ class BendingEnergyLoss(_Loss): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__(self, normalize: bool = False, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None: + def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None: """ Args: normalize: diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 1af4c519b3..5d3b3e5476 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Callable, List, Optional, Sequence, Union +from collections.abc import Callable, Sequence +from typing import Any import numpy as np import torch @@ -47,10 +50,10 @@ def __init__( to_onehot_y: bool = False, sigmoid: bool = False, softmax: bool = False, - other_act: Optional[Callable] = None, + other_act: Callable | None = None, squared_pred: bool = False, jaccard: bool = False, - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, @@ -157,7 +160,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist() + reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis @@ -203,21 +206,21 @@ class MaskedDiceLoss(DiceLoss): """ - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: """ Args follow :py:class:`monai.losses.DiceLoss`. """ super().__init__(*args, **kwargs) self.spatial_weighted = MaskedLoss(loss=super().forward) - def forward(self, input: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None): + def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """ Args: input: the shape should be BNH[WD]. target: the shape should be BNH[WD]. mask: the shape should B1H[WD] or 11H[WD]. """ - return self.spatial_weighted(input=input, target=target, mask=mask) + return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return] class GeneralizedDiceLoss(_Loss): @@ -237,9 +240,9 @@ def __init__( to_onehot_y: bool = False, sigmoid: bool = False, softmax: bool = False, - other_act: Optional[Callable] = None, - w_type: Union[Weight, str] = Weight.SQUARE, - reduction: Union[LossReduction, str] = LossReduction.MEAN, + other_act: Callable | None = None, + w_type: Weight | str = Weight.SQUARE, + reduction: LossReduction | str = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, @@ -337,7 +340,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist() + reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: reduce_axis = [0] + reduce_axis intersection = torch.sum(target * input, reduce_axis) @@ -395,9 +398,9 @@ class GeneralizedWassersteinDiceLoss(_Loss): def __init__( self, - dist_matrix: Union[np.ndarray, torch.Tensor], + dist_matrix: np.ndarray | torch.Tensor, weighting_mode: str = "default", - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, ) -> None: @@ -621,14 +624,14 @@ def __init__( to_onehot_y: bool = False, sigmoid: bool = False, softmax: bool = False, - other_act: Optional[Callable] = None, + other_act: Callable | None = None, squared_pred: bool = False, jaccard: bool = False, reduction: str = "mean", smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, - ce_weight: Optional[torch.Tensor] = None, + ce_weight: torch.Tensor | None = None, lambda_dice: float = 1.0, lambda_ce: float = 1.0, ) -> None: @@ -693,7 +696,7 @@ def __init__( self.lambda_ce = lambda_ce self.old_pt_ver = not pytorch_after(1, 10) - def ce(self, input: torch.Tensor, target: torch.Tensor): + def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Compute CrossEntropy loss for the input and target. Will remove the channel dim according to PyTorch CrossEntropyLoss: @@ -713,7 +716,7 @@ def ce(self, input: torch.Tensor, target: torch.Tensor): elif not torch.is_floating_point(target): target = target.to(dtype=input.dtype) - return self.cross_entropy(input, target) + return self.cross_entropy(input, target) # type: ignore[no-any-return] def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -757,7 +760,7 @@ def __init__( to_onehot_y: bool = False, sigmoid: bool = False, softmax: bool = False, - other_act: Optional[Callable] = None, + other_act: Callable | None = None, squared_pred: bool = False, jaccard: bool = False, reduction: str = "mean", @@ -765,7 +768,7 @@ def __init__( smooth_dr: float = 1e-5, batch: bool = False, gamma: float = 2.0, - focal_weight: Optional[Union[Sequence[float], float, int, torch.Tensor]] = None, + focal_weight: Sequence[float] | float | int | torch.Tensor | None = None, lambda_dice: float = 1.0, lambda_focal: float = 1.0, ) -> None: @@ -907,14 +910,14 @@ def __init__( to_onehot_y: bool = False, sigmoid: bool = False, softmax: bool = False, - other_act: Optional[Callable] = None, - w_type: Union[Weight, str] = Weight.SQUARE, - reduction: Union[LossReduction, str] = LossReduction.MEAN, + other_act: Callable | None = None, + w_type: Weight | str = Weight.SQUARE, + reduction: LossReduction | str = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, gamma: float = 2.0, - focal_weight: Optional[Union[Sequence[float], float, int, torch.Tensor]] = None, + focal_weight: Sequence[float] | float | int | torch.Tensor | None = None, lambda_gdl: float = 1.0, lambda_focal: float = 1.0, ) -> None: diff --git a/monai/losses/ds_loss.py b/monai/losses/ds_loss.py index c0425ffdde..d92cbc1ccd 100644 --- a/monai/losses/ds_loss.py +++ b/monai/losses/ds_loss.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union +from __future__ import annotations import torch import torch.nn.functional as F @@ -24,7 +24,7 @@ class DeepSupervisionLoss(_Loss): supervised networks. The final loss is computed as the sum of weighted losses for each of deep supervision levels. """ - def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: Optional[List[float]] = None) -> None: + def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: list[float] | None = None) -> None: """ Args: loss: main loss instance, e.g DiceLoss(). @@ -42,7 +42,7 @@ def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: Optional[List self.weights = weights self.interp_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest" - def get_weights(self, levels: int = 1) -> List[float]: + def get_weights(self, levels: int = 1) -> list[float]: """ Calculates weights for a given number of scale levels """ @@ -60,7 +60,7 @@ def get_weights(self, levels: int = 1) -> List[float]: return weights - def get_loss(self, input: torch.Tensor, target: torch.Tensor): + def get_loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Calculates a loss output accounting for differences in shapes, and downsizing targets if necessary (using nearest neighbor interpolation) @@ -68,10 +68,9 @@ def get_loss(self, input: torch.Tensor, target: torch.Tensor): """ if input.shape[2:] != target.shape[2:]: target = F.interpolate(target, size=input.shape[2:], mode=self.interp_mode) - return self.loss(input, target) - - def forward(self, input: Union[torch.Tensor, List[torch.Tensor]], target: torch.Tensor): + return self.loss(input, target) # type: ignore[no-any-return] + def forward(self, input: torch.Tensor | list[torch.Tensor], target: torch.Tensor) -> torch.Tensor: if isinstance(input, (list, tuple)): weights = self.get_weights(levels=len(input)) loss = torch.tensor(0, dtype=torch.float, device=target.device) @@ -79,7 +78,7 @@ def forward(self, input: Union[torch.Tensor, List[torch.Tensor]], target: torch. loss += weights[l] * self.get_loss(input[l].float(), target) return loss - return self.loss(input.float(), target) + return self.loss(input.float(), target) # type: ignore[no-any-return] ds_loss = DeepSupervisionLoss diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index bf31682748..80c01c7b7f 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Optional, Sequence, Union +from collections.abc import Sequence import torch import torch.nn.functional as F @@ -67,8 +69,8 @@ def __init__( include_background: bool = True, to_onehot_y: bool = False, gamma: float = 2.0, - weight: Optional[Union[Sequence[float], float, int, torch.Tensor]] = None, - reduction: Union[LossReduction, str] = LossReduction.MEAN, + weight: Sequence[float] | float | int | torch.Tensor | None = None, + reduction: LossReduction | str = LossReduction.MEAN, ) -> None: """ Args: @@ -100,7 +102,7 @@ def __init__( self.include_background = include_background self.to_onehot_y = to_onehot_y self.gamma = gamma - self.weight: Optional[Union[Sequence[float], float, int, torch.Tensor]] = weight + self.weight: Sequence[float] | float | int | torch.Tensor | None = weight def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -152,7 +154,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log() if self.weight is not None: - class_weight: Optional[torch.Tensor] = None + class_weight: torch.Tensor | None = None if isinstance(self.weight, (float, int)): class_weight = torch.as_tensor([self.weight] * i.size(1)) else: diff --git a/monai/losses/giou_loss.py b/monai/losses/giou_loss.py index 623e55921b..7940660a08 100644 --- a/monai/losses/giou_loss.py +++ b/monai/losses/giou_loss.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from __future__ import annotations import torch from torch.nn.modules.loss import _Loss @@ -33,7 +33,7 @@ class BoxGIoULoss(_Loss): - ``"sum"``: the output will be summed. """ - def __init__(self, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None: + def __init__(self, reduction: LossReduction | str = LossReduction.MEAN) -> None: super().__init__(reduction=LossReduction(reduction).value) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 4351199aee..39219e059a 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -8,7 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union + +from __future__ import annotations import torch from torch.nn import functional as F @@ -65,7 +66,7 @@ def __init__( spatial_dims: int = 3, kernel_size: int = 3, kernel_type: str = "rectangular", - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, smooth_nr: float = 0.0, smooth_dr: float = 1e-5, ) -> None: @@ -175,7 +176,7 @@ def __init__( kernel_type: str = "gaussian", num_bins: int = 23, sigma_ratio: float = 0.5, - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7, ) -> None: @@ -221,7 +222,7 @@ def __init__( def parzen_windowing( self, pred: torch.Tensor, target: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if self.kernel_type == "gaussian": pred_weight, pred_probability = self.parzen_windowing_gaussian(pred) target_weight, target_probability = self.parzen_windowing_gaussian(target) @@ -234,7 +235,7 @@ def parzen_windowing( raise ValueError return pred_weight, pred_probability, target_weight, target_probability - def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> Tuple[torch.Tensor, torch.Tensor]: + def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> tuple[torch.Tensor, torch.Tensor]: """ Parzen windowing with b-spline kernel (adapted from ITK) @@ -287,7 +288,7 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> Tuple[torc probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bins) return weight, probability - def parzen_windowing_gaussian(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Parzen windowing with gaussian kernel (adapted from DeepReg implementation) Note: the input is expected to range between 0 and 1 diff --git a/monai/losses/multi_scale.py b/monai/losses/multi_scale.py index bef1ae1a5d..7119f51042 100644 --- a/monai/losses/multi_scale.py +++ b/monai/losses/multi_scale.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union +from __future__ import annotations import torch from torch.nn.modules.loss import _Loss @@ -49,9 +49,9 @@ class MultiScaleLoss(_Loss): def __init__( self, loss: _Loss, - scales: Optional[List] = None, + scales: list | None = None, kernel: str = "gaussian", - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, ) -> None: """ Args: diff --git a/monai/losses/spatial_mask.py b/monai/losses/spatial_mask.py index aa232f882e..0f823410dd 100644 --- a/monai/losses/spatial_mask.py +++ b/monai/losses/spatial_mask.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import inspect import warnings -from typing import Callable, Optional, Union +from collections.abc import Callable +from typing import Any import torch from torch.nn.modules.loss import _Loss @@ -28,7 +31,9 @@ class MaskedLoss(_Loss): - :py:class:`monai.losses.MaskedDiceLoss` """ - def __init__(self, loss: Union[Callable, _Loss], *loss_args, **loss_kwargs) -> None: + def __init__( + self, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | _Loss, *loss_args: Any, **loss_kwargs: Any + ) -> None: """ Args: loss: loss function to be wrapped, this could be a loss class or an instance of a loss class. @@ -36,11 +41,13 @@ def __init__(self, loss: Union[Callable, _Loss], *loss_args, **loss_kwargs) -> N loss_kwargs: keyword arguments to the loss function's constructor if `loss` is a class. """ super().__init__() - self.loss = loss(*loss_args, **loss_kwargs) if inspect.isclass(loss) else loss + self.loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = ( + loss(*loss_args, **loss_kwargs) if inspect.isclass(loss) else loss + ) if not callable(self.loss): raise ValueError("The loss function is not callable.") - def forward(self, input: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None): + def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """ Args: input: the shape should be BNH[WD]. diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py index 2dc0289a29..e8e5d0c2ba 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch from torch.nn.modules.loss import _Loss @@ -85,7 +87,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) -> data_range, self.win_size, self.k1, self.k2, self.spatial_dims )._compute_tensor(x, y) elif x.shape[0] > 1: - for i in range(x.shape[0]): ssim_val: torch.Tensor = SSIMMetric( data_range, self.win_size, self.k1, self.k2, self.spatial_dims @@ -93,7 +94,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) -> if i == 0: ssim_value = ssim_val else: - ssim_value = torch.cat((ssim_value.view(1), ssim_val.view(1)), dim=0) + ssim_value = torch.cat((ssim_value.view(i), ssim_val.view(1)), dim=0) else: raise ValueError("Batch size is not nonnegative integer value") diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index a0735c24e0..4f22bf84b4 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Callable, List, Optional, Union +from collections.abc import Callable import torch from torch.nn.modules.loss import _Loss @@ -37,10 +39,10 @@ def __init__( to_onehot_y: bool = False, sigmoid: bool = False, softmax: bool = False, - other_act: Optional[Callable] = None, + other_act: Callable | None = None, alpha: float = 0.5, beta: float = 0.5, - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, @@ -138,7 +140,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: g1 = 1 - g0 # reducing only spatial dimensions (not batch nor channels) - reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist() + reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 1e2bdae725..8484eb67ed 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Union import torch from torch.nn.modules.loss import _Loss @@ -37,7 +38,7 @@ def __init__( delta: float = 0.7, gamma: float = 0.75, epsilon: float = 1e-7, - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, ) -> None: """ Args: @@ -101,7 +102,7 @@ def __init__( delta: float = 0.7, gamma: float = 2, epsilon: float = 1e-7, - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, ): """ Args: @@ -160,7 +161,7 @@ def __init__( weight: float = 0.5, gamma: float = 0.5, delta: float = 0.7, - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, ): """ Args: diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index d04e8461c4..3c878f9040 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix from .cumulative_average import CumulativeAverage @@ -26,3 +28,4 @@ from .surface_dice import SurfaceDiceMetric, compute_surface_dice from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance from .utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background, is_binary_tensor +from .wrapper import MetricsReloadedBinary, MetricsReloadedCategorical diff --git a/monai/metrics/active_learning_metrics.py b/monai/metrics/active_learning_metrics.py index eddc82e87a..7a1654191e 100644 --- a/monai/metrics/active_learning_metrics.py +++ b/monai/metrics/active_learning_metrics.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings from typing import Any @@ -88,7 +90,7 @@ def __init__(self, include_background: bool = True, scalar_reduction: str = "sum self.include_background = include_background self.scalar_reduction = scalar_reduction - def __call__(self, y_pred: Any, y: Any): + def __call__(self, y_pred: Any, y: Any) -> torch.Tensor | None: """ Args: y_pred: Predicted segmentation, typically segmentation model output. @@ -109,7 +111,7 @@ def compute_variance( spatial_map: bool = False, scalar_reduction: str = "mean", threshold: float = 0.0005, -): +) -> torch.Tensor | None: """ Args: y_pred: [N, C, H, W, D] or [N, C, H, W] or [N, C, H] where N is repeats, C is channels and H, W, D stand for @@ -162,7 +164,7 @@ def compute_variance( def label_quality_score( y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, scalar_reduction: str = "mean" -): +) -> torch.Tensor | None: """ The assumption is that the DL model makes better predictions than the provided label quality, hence the difference can be treated as a label quality score diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 872819f1a9..35a5dd9764 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Sequence, Union +from collections.abc import Sequence import torch @@ -63,9 +65,9 @@ class ConfusionMatrixMetric(CumulativeIterationMetric): def __init__( self, include_background: bool = True, - metric_name: Union[Sequence[str], str] = "hit_rate", + metric_name: Sequence[str] | str = "hit_rate", compute_sample: bool = False, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ) -> None: super().__init__() @@ -75,7 +77,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Args: y_pred: input data to compute. It must be one-hot format and first dim is batch. @@ -100,7 +102,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background) - def aggregate(self, compute_sample: bool = False, reduction: Union[MetricReduction, str, None] = None): + def aggregate( + self, compute_sample: bool = False, reduction: MetricReduction | str | None = None + ) -> list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]: """ Execute reduction for the confusion matrix values. @@ -116,7 +120,7 @@ def aggregate(self, compute_sample: bool = False, reduction: Union[MetricReducti if not isinstance(data, torch.Tensor): raise ValueError("the data to aggregate must be PyTorch Tensor.") - results = [] + results: list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]] = [] for metric_name in self.metric_name: if compute_sample or self.compute_sample: sub_confusion_matrix = compute_confusion_matrix_metric(metric_name, data) @@ -131,7 +135,7 @@ def aggregate(self, compute_sample: bool = False, reduction: Union[MetricReducti return results -def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True): +def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor: """ Compute confusion matrix. A tensor with the shape [BC4] will be returned. Where, the third dimension represents the number of true positive, false positive, true negative and false negative values for @@ -179,7 +183,7 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou return torch.stack([tp, fp, tn, fn], dim=-1) -def compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Tensor): +def compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Tensor) -> torch.Tensor: """ This function is used to compute confusion matrix related metric. @@ -215,7 +219,7 @@ def compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Te n = fp + tn # calculate metric numerator: torch.Tensor - denominator: Union[torch.Tensor, float] + denominator: torch.Tensor | float nan_tensor = torch.tensor(float("nan"), device=confusion_matrix.device) if metric == "tpr": numerator, denominator = tp, p @@ -274,7 +278,7 @@ def compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Te return numerator / denominator -def check_confusion_matrix_metric_name(metric_name: str): +def check_confusion_matrix_metric_name(metric_name: str) -> str: """ There are many metrics related to confusion matrix, and some of the metrics have more than one names. In addition, some of the names are very long. diff --git a/monai/metrics/cumulative_average.py b/monai/metrics/cumulative_average.py index b099cdc2a4..e55e7b8576 100644 --- a/monai/metrics/cumulative_average.py +++ b/monai/metrics/cumulative_average.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Any, Optional +from typing import Any import torch import torch.distributed as dist @@ -100,7 +102,7 @@ def aggregate(self, to_numpy: bool = True) -> NdarrayOrTensor: val = val.cpu().numpy() return val - def append(self, val: Any, count: Optional[Any] = 1) -> None: + def append(self, val: Any, count: Any | None = 1) -> None: """ Append with a new value, and an optional count. Any data type is supported that is convertable with torch.as_tensor() e.g. number, list, numpy array, or Tensor. diff --git a/monai/metrics/f_beta_score.py b/monai/metrics/f_beta_score.py index 067c09e0d2..8b595fe86b 100644 --- a/monai/metrics/f_beta_score.py +++ b/monai/metrics/f_beta_score.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from __future__ import annotations + +from collections.abc import Sequence import torch @@ -24,7 +26,7 @@ def __init__( self, beta: float = 1.0, include_background: bool = True, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ) -> None: super().__init__() @@ -33,7 +35,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] is_binary_tensor(y_pred, "y_pred") is_binary_tensor(y, "y") @@ -42,12 +44,14 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor return get_f_beta_score(y_pred=y_pred, y=y, include_background=self.include_background) - def aggregate(self, compute_sample: bool = False, reduction: Union[MetricReduction, str, None] = None): + def aggregate( + self, compute_sample: bool = False, reduction: MetricReduction | str | None = None + ) -> Sequence[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]: data = self.get_buffer() if not isinstance(data, torch.Tensor): raise ValueError("the data to aggregate must be PyTorch Tensor.") - results = [] + results: list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]] = [] f, not_nans = do_metric_reduction(data, reduction or self.reduction) f = compute_f_beta_score(f, self.beta) if self.get_not_nans: @@ -58,7 +62,7 @@ def aggregate(self, compute_sample: bool = False, reduction: Union[MetricReducti return results -def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True): +def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor: if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) @@ -88,7 +92,7 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: return torch.stack([tp, fp, tn, fn], dim=-1) -def compute_f_beta_score(confusion_matrix: torch.Tensor, beta: float): +def compute_f_beta_score(confusion_matrix: torch.Tensor, beta: float) -> torch.Tensor: input_dim = confusion_matrix.ndimension() if input_dim == 1: confusion_matrix = confusion_matrix.unsqueeze(dim=0) diff --git a/monai/metrics/froc.py b/monai/metrics/froc.py index 56e0755b99..6fd367d1e4 100644 --- a/monai/metrics/froc.py +++ b/monai/metrics/froc.py @@ -9,20 +9,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from __future__ import annotations + +from typing import Any, cast import numpy as np import torch +from monai.config import NdarrayOrTensor + def compute_fp_tp_probs( - probs: Union[np.ndarray, torch.Tensor], - y_coord: Union[np.ndarray, torch.Tensor], - x_coord: Union[np.ndarray, torch.Tensor], - evaluation_mask: Union[np.ndarray, torch.Tensor], - labels_to_exclude: Optional[List] = None, + probs: NdarrayOrTensor, + y_coord: NdarrayOrTensor, + x_coord: NdarrayOrTensor, + evaluation_mask: NdarrayOrTensor, + labels_to_exclude: list | None = None, resolution_level: int = 0, -): +) -> tuple[NdarrayOrTensor, NdarrayOrTensor, int]: """ This function is modified from the official evaluation code of `CAMELYON 16 Challenge `_, and used to distinguish @@ -74,15 +78,12 @@ def compute_fp_tp_probs( tp_probs[i - 1] = probs[np.where(hittedlabel == i)].max() num_targets = max_label - len(labels_to_exclude) - return fp_probs, tp_probs, num_targets + return fp_probs, tp_probs, cast(int, num_targets) def compute_froc_curve_data( - fp_probs: Union[np.ndarray, torch.Tensor], - tp_probs: Union[np.ndarray, torch.Tensor], - num_targets: int, - num_images: int, -): + fp_probs: np.ndarray | torch.Tensor, tp_probs: np.ndarray | torch.Tensor, num_targets: int, num_images: int +) -> tuple[np.ndarray, np.ndarray]: """ This function is modified from the official evaluation code of `CAMELYON 16 Challenge `_, and used to compute @@ -117,8 +118,8 @@ def compute_froc_curve_data( def compute_froc_score( - fps_per_image: np.ndarray, total_sensitivity: np.ndarray, eval_thresholds: Tuple = (0.25, 0.5, 1, 2, 4, 8) -): + fps_per_image: np.ndarray, total_sensitivity: np.ndarray, eval_thresholds: tuple = (0.25, 0.5, 1, 2, 4, 8) +) -> Any: """ This function is modified from the official evaluation code of `CAMELYON 16 Challenge `_, and used to compute diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index 7fdfc61a14..ff3b2529ca 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from __future__ import annotations import torch @@ -46,8 +46,8 @@ class GeneralizedDiceScore(CumulativeIterationMetric): def __init__( self, include_background: bool = True, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN_BATCH, - weight_type: Union[Weight, str] = Weight.SQUARE, + reduction: MetricReduction | str = MetricReduction.MEAN_BATCH, + weight_type: Weight | str = Weight.SQUARE, ) -> None: super().__init__() self.include_background = include_background @@ -64,7 +64,7 @@ def __init__( raise ValueError(f"reduction must be one of {reduction_options}") self.weight_type = look_up_option(weight_type, Weight) - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """Computes the Generalized Dice Score and returns a tensor with its per image values. Args: @@ -80,7 +80,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type ) - def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor: """ Execute reduction logic for the output of `compute_generalized_dice`. @@ -106,10 +106,7 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None): def compute_generalized_dice( - y_pred: torch.Tensor, - y: torch.Tensor, - include_background: bool = True, - weight_type: Union[Weight, str] = Weight.SQUARE, + y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE ) -> torch.Tensor: """Computes the Generalized Dice Score and returns a tensor with its per image values. @@ -179,7 +176,9 @@ def compute_generalized_dice( y_pred_o = y_pred_o.sum(dim=-1) denom_zeros = denom == 0 generalized_dice_score[denom_zeros] = torch.where( - (y_pred_o == 0)[denom_zeros], torch.tensor(1.0), torch.tensor(0.0) + (y_pred_o == 0)[denom_zeros], + torch.tensor(1.0, device=generalized_dice_score.device), + torch.tensor(0.0, device=generalized_dice_score.device), ) return generalized_dice_score diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 54de8b1d4d..bba9301dd7 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Optional, Union import numpy as np import torch @@ -22,7 +23,7 @@ ignore_background, is_binary_tensor, ) -from monai.utils import MetricReduction +from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -62,9 +63,9 @@ def __init__( self, include_background: bool = False, distance_metric: str = "euclidean", - percentile: Optional[float] = None, + percentile: float | None = None, directed: bool = False, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ) -> None: super().__init__() @@ -75,7 +76,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Args: y_pred: input data to compute, typical segmentation model output. @@ -104,7 +105,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor directed=self.directed, ) - def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Execute reduction logic for the output of `compute_hausdorff_distance`. @@ -124,13 +127,13 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None): def compute_hausdorff_distance( - y_pred: Union[np.ndarray, torch.Tensor], - y: Union[np.ndarray, torch.Tensor], + y_pred: np.ndarray | torch.Tensor, + y: np.ndarray | torch.Tensor, include_background: bool = False, distance_metric: str = "euclidean", - percentile: Optional[float] = None, + percentile: float | None = None, directed: bool = False, -): +) -> torch.Tensor: """ Compute the Hausdorff distance. @@ -152,10 +155,8 @@ def compute_hausdorff_distance( if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - if isinstance(y, torch.Tensor): - y = y.float() - if isinstance(y_pred, torch.Tensor): - y_pred = y_pred.float() + y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] + y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") @@ -175,12 +176,12 @@ def compute_hausdorff_distance( else: distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile) hd[b, c] = max(distance_1, distance_2) - return torch.from_numpy(hd) + return convert_data_type(hd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] def compute_percent_hausdorff_distance( - edges_pred: np.ndarray, edges_gt: np.ndarray, distance_metric: str = "euclidean", percentile: Optional[float] = None -): + edges_pred: np.ndarray, edges_gt: np.ndarray, distance_metric: str = "euclidean", percentile: float | None = None +) -> float: """ This function is used to compute the directed Hausdorff distance. """ @@ -192,8 +193,8 @@ def compute_percent_hausdorff_distance( return np.nan if not percentile: - return surface_distance.max() + return surface_distance.max() # type: ignore[no-any-return] if 0 <= percentile <= 100: - return np.percentile(surface_distance, percentile) + return np.percentile(surface_distance, percentile) # type: ignore[no-any-return] raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.") diff --git a/monai/metrics/loss_metric.py b/monai/metrics/loss_metric.py index 01db1f7af1..2cc9755e36 100644 --- a/monai/metrics/loss_metric.py +++ b/monai/metrics/loss_metric.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from __future__ import annotations import torch from torch.nn.modules.loss import _Loss @@ -17,6 +17,7 @@ from monai.metrics.utils import do_metric_reduction from monai.utils import MetricReduction +from ..config import TensorOrList from .metric import CumulativeIterationMetric @@ -67,14 +68,16 @@ class LossMetric(CumulativeIterationMetric): """ def __init__( - self, loss_fn: _Loss, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False + self, loss_fn: _Loss, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False ) -> None: super().__init__() self.loss_fn = loss_fn self.reduction = reduction self.get_not_nans = get_not_nans - def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Returns the aggregated loss value across multiple iterations. @@ -89,7 +92,7 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None): f, not_nans = do_metric_reduction(data, reduction or self.reduction) return (f, not_nans) if self.get_not_nans else f - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor = None): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None) -> TensorOrList: """ Input `y_pred` is compared with ground truth `y`. Both `y_pred` and `y` are expected to be a batch-first Tensor (BC[HWD]). @@ -97,7 +100,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor = None): # type Returns: a tensor with shape (BC[HWD]), or a list of tensors, each tensor with shape (C[HWD]). """ - iter_loss = self.loss_fn(y_pred) if y is None else self.loss_fn(y_pred, y) + iter_loss: TensorOrList = self.loss_fn(y_pred) if y is None else self.loss_fn(y_pred, y) if isinstance(iter_loss, torch.Tensor): while iter_loss.dim() < 2: iter_loss = iter_loss[None] diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index a9d4e7182a..aa83edc9cb 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from __future__ import annotations import torch @@ -50,7 +50,7 @@ class DiceMetric(CumulativeIterationMetric): def __init__( self, include_background: bool = True, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ignore_empty: bool = True, ) -> None: @@ -60,7 +60,7 @@ def __init__( self.get_not_nans = get_not_nans self.ignore_empty = ignore_empty - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Args: y_pred: input data to compute, typical segmentation model output. @@ -84,7 +84,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty ) - def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Execute reduction logic for the output of `compute_meandice`. diff --git a/monai/metrics/meaniou.py b/monai/metrics/meaniou.py index 55fa73e1ff..bd9e11b323 100644 --- a/monai/metrics/meaniou.py +++ b/monai/metrics/meaniou.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from __future__ import annotations import torch @@ -51,7 +51,7 @@ class MeanIoU(CumulativeIterationMetric): def __init__( self, include_background: bool = True, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ignore_empty: bool = True, ) -> None: @@ -61,7 +61,7 @@ def __init__( self.get_not_nans = get_not_nans self.ignore_empty = ignore_empty - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Args: y_pred: input data to compute, typical segmentation model output. @@ -85,7 +85,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty ) - def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Execute reduction logic for the output of `compute_meaniou`. diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py index 99aff701c1..608d914808 100644 --- a/monai/metrics/metric.py +++ b/monai/metrics/metric.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, List, Optional +from collections.abc import Sequence +from typing import Any import torch @@ -28,7 +31,7 @@ class Metric(ABC): """ @abstractmethod - def __call__(self, *args: Any, **kwargs: Any): + def __call__(self, *args: Any, **kwargs: Any) -> Any: """ This method should take raw model outputs as inputs, and return values that measure the models' quality. """ @@ -45,7 +48,9 @@ class IterationMetric(Metric): Subclasses typically implement the `_compute_tensor` function for the actual tensor computation logic. """ - def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): + def __call__( + self, y_pred: TensorOrList, y: TensorOrList | None = None + ) -> torch.Tensor | Sequence[torch.Tensor | Sequence[torch.Tensor]]: """ Execute basic computation for model prediction `y_pred` and ground truth `y` (optional). It supports inputs of a list of "channel-first" Tensor and a "batch-first" Tensor. @@ -71,7 +76,9 @@ def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): return self._compute_tensor(y_pred.detach(), y_) raise ValueError("y_pred or y must be a list/tuple of `channel-first` Tensors or a `batch-first` Tensor.") - def _compute_list(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): + def _compute_list( + self, y_pred: TensorOrList, y: TensorOrList | None = None + ) -> torch.Tensor | list[torch.Tensor | Sequence[torch.Tensor]]: """ Execute the metric computation for `y_pred` and `y` in a list of "channel-first" tensors. @@ -92,14 +99,14 @@ def _compute_list(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): # concat the list of results (e.g. a batch of evaluation scores) if isinstance(ret[0], torch.Tensor): - return torch.cat(ret, dim=0) + return torch.cat(ret, dim=0) # type: ignore[arg-type] # the result is a list of sequence of tensors (e.g. a batch of multi-class results) if isinstance(ret[0], (list, tuple)) and all(isinstance(i, torch.Tensor) for i in ret[0]): return [torch.cat(batch_i, dim=0) for batch_i in zip(*ret)] return ret @abstractmethod - def _compute_tensor(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None): + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None) -> TensorOrList: """ Computation logic for `y_pred` and `y` of an iteration, the data should be "batch-first" Tensors. A subclass should implement its own computation logic. @@ -172,8 +179,8 @@ def __init__(self) -> None: `self._buffers` are local buffers, they are not usually used directly. `self._sync_buffers` are the buffers with all the results across all the nodes. """ - self._buffers: Optional[List[List[torch.Tensor]]] = None - self._synced_tensors: Optional[List[Optional[torch.Tensor]]] = None + self._buffers: list[list[torch.Tensor]] | None = None + self._synced_tensors: list[torch.Tensor | None] | None = None self._synced: bool = False self.reset() @@ -186,7 +193,7 @@ def reset(self): self._synced_tensors = None self._synced = False - def extend(self, *data) -> None: + def extend(self, *data: Any) -> None: """ Extend the local buffers with new ("batch-first") data. A buffer will be allocated for each `data` item. @@ -210,7 +217,7 @@ def extend(self, *data) -> None: ) from e self._synced = False - def append(self, *data) -> None: + def append(self, *data: Any) -> None: """ Add samples to the local cumulative buffers. A buffer will be allocated for each `data` item. @@ -231,7 +238,7 @@ def append(self, *data) -> None: self._synced = False @abstractmethod - def aggregate(self, *args: Any, **kwargs: Any): + def aggregate(self, *args: Any, **kwargs: Any) -> Any: """ Aggregate final results based on the gathered buffers. This method is expected to use `get_buffer` to gather the local buffer contents. @@ -310,7 +317,9 @@ class CumulativeIterationMetric(Cumulative, IterationMetric): """ - def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): + def __call__( + self, y_pred: TensorOrList, y: TensorOrList | None = None + ) -> torch.Tensor | Sequence[torch.Tensor | Sequence[torch.Tensor]]: """ Execute basic computation for model prediction and ground truth. It can support both `list of channel-first Tensor` and `batch-first Tensor`. diff --git a/monai/metrics/panoptic_quality.py b/monai/metrics/panoptic_quality.py index dc8cfb84b2..05175ba0fb 100644 --- a/monai/metrics/panoptic_quality.py +++ b/monai/metrics/panoptic_quality.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import torch @@ -59,8 +61,8 @@ class PanopticQualityMetric(CumulativeIterationMetric): def __init__( self, num_classes: int, - metric_name: Union[Sequence[str], str] = "pq", - reduction: Union[MetricReduction, str] = MetricReduction.MEAN_BATCH, + metric_name: Sequence[str] | str = "pq", + reduction: MetricReduction | str = MetricReduction.MEAN_BATCH, match_iou_threshold: float = 0.5, smooth_numerator: float = 1e-6, ) -> None: @@ -71,7 +73,7 @@ def __init__( self.smooth_numerator = smooth_numerator self.metric_name = ensure_tuple(metric_name) - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Args: y_pred: Predictions. It must be in the form of B2HW and have integer type. The first channel and the @@ -120,7 +122,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor return outputs - def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor | list[torch.Tensor]: """ Execute reduction logic for the output of `compute_panoptic_quality`. @@ -158,7 +160,7 @@ def compute_panoptic_quality( match_iou_threshold: float = 0.5, smooth_numerator: float = 1e-6, output_confusion_matrix: bool = False, -): +) -> torch.Tensor: """Computes Panoptic Quality (PQ). If specifying `metric_name` to "SQ" or "RQ", Segmentation Quality (SQ) or Recognition Quality (RQ) will be returned instead. @@ -217,7 +219,7 @@ def compute_panoptic_quality( return torch.as_tensor(iou_sum / (tp + 0.5 * fp + 0.5 * fn + smooth_numerator), device=pred.device) -def _get_id_list(gt: torch.Tensor): +def _get_id_list(gt: torch.Tensor) -> list[torch.Tensor]: id_list = list(gt.unique()) # ensure id 0 is included if 0 not in id_list: @@ -226,13 +228,15 @@ def _get_id_list(gt: torch.Tensor): return id_list -def _get_pairwise_iou(pred: torch.Tensor, gt: torch.Tensor, device: Union[str, torch.device] = "cpu"): +def _get_pairwise_iou( + pred: torch.Tensor, gt: torch.Tensor, device: str | torch.device = "cpu" +) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]: pred_id_list = _get_id_list(pred) true_id_list = _get_id_list(gt) pairwise_iou = torch.zeros([len(true_id_list) - 1, len(pred_id_list) - 1], dtype=torch.float, device=device) - true_masks: List[torch.Tensor] = [] - pred_masks: List[torch.Tensor] = [] + true_masks: list[torch.Tensor] = [] + pred_masks: list[torch.Tensor] = [] for t in true_id_list[1:]: t_mask = torch.as_tensor(gt == t, device=device).int() @@ -259,8 +263,8 @@ def _get_pairwise_iou(pred: torch.Tensor, gt: torch.Tensor, device: Union[str, t def _get_paired_iou( - pairwise_iou: torch.Tensor, match_iou_threshold: float = 0.5, device: Union[str, torch.device] = "cpu" -): + pairwise_iou: torch.Tensor, match_iou_threshold: float = 0.5, device: str | torch.device = "cpu" +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if match_iou_threshold >= 0.5: pairwise_iou[pairwise_iou <= match_iou_threshold] = 0.0 paired_true, paired_pred = torch.nonzero(pairwise_iou)[:, 0], torch.nonzero(pairwise_iou)[:, 1] @@ -280,7 +284,7 @@ def _get_paired_iou( return paired_iou, paired_true, paired_pred -def _check_panoptic_metric_name(metric_name: str): +def _check_panoptic_metric_name(metric_name: str) -> str: metric_name = metric_name.replace(" ", "_") metric_name = metric_name.lower() if metric_name in ["panoptic_quality", "pq"]: diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index b714fede86..92c36da715 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -9,10 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math from abc import abstractmethod +from collections.abc import Callable from functools import partial -from typing import Any, Tuple, Union +from typing import Any import torch import torch.nn.functional as F @@ -42,14 +45,14 @@ class RegressionMetric(CumulativeIterationMetric): """ - def __init__( - self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False - ) -> None: + def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False) -> None: super().__init__() self.reduction = reduction self.get_not_nans = get_not_nans - def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Args: reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, @@ -75,7 +78,7 @@ def _check_shape(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor): raise ValueError("y_pred and y must be PyTorch Tensor.") self._check_shape(y_pred, y) @@ -103,9 +106,7 @@ class MSEMetric(RegressionMetric): """ - def __init__( - self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False - ) -> None: + def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.sq_func = partial(torch.pow, exponent=2.0) @@ -137,9 +138,7 @@ class MAEMetric(RegressionMetric): """ - def __init__( - self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False - ) -> None: + def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.abs_func = torch.abs @@ -172,9 +171,7 @@ class RMSEMetric(RegressionMetric): """ - def __init__( - self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False - ) -> None: + def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.sq_func = partial(torch.pow, exponent=2.0) @@ -214,10 +211,7 @@ class PSNRMetric(RegressionMetric): """ def __init__( - self, - max_val: Union[int, float], - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, - get_not_nans: bool = False, + self, max_val: int | float, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False ) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.max_val = max_val @@ -231,7 +225,7 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> Any: return 20 * math.log10(self.max_val) - 10 * torch.log10(mse_out) -def compute_mean_error_metrics(y_pred: torch.Tensor, y: torch.Tensor, func) -> torch.Tensor: +def compute_mean_error_metrics(y_pred: torch.Tensor, y: torch.Tensor, func: Callable) -> torch.Tensor: # reducing in only channel + spatial dimensions (not batch) # reduction of batch handled inside __call__() using do_metric_reduction() in respective calling class flt = partial(torch.flatten, start_dim=1) @@ -275,7 +269,7 @@ def __init__( k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ): super().__init__(reduction=reduction, get_not_nans=get_not_nans) @@ -286,8 +280,7 @@ def __init__( self.cov_norm = (win_size**2) / (win_size**2 - 1) self.w = torch.ones([1, 1] + [win_size for _ in range(spatial_dims)]) / win_size**spatial_dims - def _compute_intermediate_statistics(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, ...]: - + def _compute_intermediate_statistics(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, ...]: data_range = self.data_range[(None,) * (self.spatial_dims + 2)] # determine whether to work with 2D convolution or 3D conv = getattr(F, f"conv{self.spatial_dims}d") @@ -335,7 +328,7 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ssim = torch.stack( [ - SSIMMetric(self.data_range, self.win_size, self.k1, self.k2, self.spatial_dims)( + SSIMMetric(self.data_range, self.win_size, self.k1, self.k2, self.spatial_dims)( # type: ignore[misc] x[:, i, ...].unsqueeze(1), y[:, i, ...].unsqueeze(1) ) for i in range(x.shape[1]) @@ -354,7 +347,7 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return ssim_per_batch - def _compute_metric_and_contrast(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _compute_metric_and_contrast(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Args: x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index 2bb8dc2b32..56d9faa9dd 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -9,10 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Union, cast +from typing import TYPE_CHECKING, cast import numpy as np + +if TYPE_CHECKING: + import numpy.typing as npt + import torch from monai.utils import Average, look_up_option @@ -44,14 +50,14 @@ class ROCAUCMetric(CumulativeIterationMetric): """ - def __init__(self, average: Union[Average, str] = Average.MACRO) -> None: + def __init__(self, average: Average | str = Average.MACRO) -> None: super().__init__() self.average = average - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override] return y_pred, y - def aggregate(self, average: Union[Average, str, None] = None): + def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike: """ Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration, This function reads the buffers and computes the area under the ROC. @@ -106,7 +112,9 @@ def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float: return auc / (nneg * (n - nneg)) -def compute_roc_auc(y_pred: torch.Tensor, y: torch.Tensor, average: Union[Average, str] = Average.MACRO): +def compute_roc_auc( + y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO +) -> np.ndarray | float | npt.ArrayLike: """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: `sklearn.metrics.roc_auc_score `_. @@ -170,5 +178,5 @@ def compute_roc_auc(y_pred: torch.Tensor, y: torch.Tensor, average: Union[Averag return np.mean(auc_values) if average == Average.WEIGHTED: weights = [sum(y_) for y_ in y] - return np.average(auc_values, weights=weights) + return np.average(auc_values, weights=weights) # type: ignore[no-any-return] raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].') diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 80869ce583..12c47dec8d 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import List, Union import numpy as np import torch @@ -53,10 +54,10 @@ class SurfaceDiceMetric(CumulativeIterationMetric): def __init__( self, - class_thresholds: List[float], + class_thresholds: list[float], include_background: bool = False, distance_metric: str = "euclidean", - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ) -> None: super().__init__() @@ -66,7 +67,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] r""" Args: y_pred: Predicted segmentation, typically segmentation model output. @@ -86,7 +87,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor distance_metric=self.distance_metric, ) - def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: r""" Aggregates the output of `_compute_tensor`. @@ -111,10 +114,10 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None): def compute_surface_dice( y_pred: torch.Tensor, y: torch.Tensor, - class_thresholds: List[float], + class_thresholds: list[float], include_background: bool = False, distance_metric: str = "euclidean", -): +) -> torch.Tensor: r""" This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as :math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation @@ -237,4 +240,4 @@ def compute_surface_dice( else: nsd[b, c] = boundary_correct / boundary_complete - return convert_data_type(nsd, torch.Tensor)[0] + return convert_data_type(nsd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 8bb688b4e0..f1b4979466 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Union import numpy as np import torch @@ -58,7 +59,7 @@ def __init__( include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ) -> None: super().__init__() @@ -68,7 +69,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Args: y_pred: input data to compute, typical segmentation model output. @@ -94,7 +95,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor distance_metric=self.distance_metric, ) - def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Execute reduction logic for the output of `compute_average_surface_distance`. @@ -114,12 +117,12 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None): def compute_average_surface_distance( - y_pred: Union[np.ndarray, torch.Tensor], - y: Union[np.ndarray, torch.Tensor], + y_pred: np.ndarray | torch.Tensor, + y: np.ndarray | torch.Tensor, include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", -): +) -> torch.Tensor: """ This function is used to compute the Average Surface Distance from `y_pred` to `y` under the default setting. @@ -144,10 +147,8 @@ def compute_average_surface_distance( if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - if isinstance(y, torch.Tensor): - y = y.float() - if isinstance(y_pred, torch.Tensor): - y_pred = y_pred.float() + y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] + y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") @@ -167,4 +168,4 @@ def compute_average_surface_distance( surface_distance = np.concatenate([surface_distance, surface_distance_2]) asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean() - return convert_data_type(asd, torch.Tensor)[0] + return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 6219663756..d0b5c28744 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -9,12 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Tuple, Union +from typing import Any import numpy as np import torch +from monai.config import NdarrayOrTensor, NdarrayTensor from monai.transforms.croppad.array import SpatialCrop from monai.transforms.utils import generate_spatial_bounding_box from monai.utils import MetricReduction, convert_data_type, look_up_option, optional_import @@ -26,7 +29,7 @@ __all__ = ["ignore_background", "do_metric_reduction", "get_mask_edges", "get_surface_distance", "is_binary_tensor"] -def ignore_background(y_pred: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]): +def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: """ This function is used to remove background (the first channel) for `y_pred` and `y`. @@ -38,12 +41,14 @@ def ignore_background(y_pred: Union[np.ndarray, torch.Tensor], y: Union[np.ndarr """ - y = y[:, 1:] if y.shape[1] > 1 else y - y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred + y = y[:, 1:] if y.shape[1] > 1 else y # type: ignore[assignment] + y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred # type: ignore[assignment] return y_pred, y -def do_metric_reduction(f: torch.Tensor, reduction: Union[MetricReduction, str] = MetricReduction.MEAN): +def do_metric_reduction( + f: torch.Tensor, reduction: MetricReduction | str = MetricReduction.MEAN +) -> tuple[torch.Tensor | Any, torch.Tensor]: """ This function is to do the metric reduction for calculated `not-nan` metrics of each sample's each class. The function also returns `not_nans`, which counts the number of not nans for the metric. @@ -103,7 +108,9 @@ def do_metric_reduction(f: torch.Tensor, reduction: Union[MetricReduction, str] return f, not_nans -def get_mask_edges(seg_pred, seg_gt, label_idx: int = 1, crop: bool = True) -> Tuple[np.ndarray, np.ndarray]: +def get_mask_edges( + seg_pred: NdarrayOrTensor, seg_gt: NdarrayOrTensor, label_idx: int = 1, crop: bool = True +) -> tuple[np.ndarray, np.ndarray]: """ Do binary erosion and use XOR for input to get the edges. This function is helpful to further calculate metrics such as Average Surface @@ -155,8 +162,8 @@ def get_mask_edges(seg_pred, seg_gt, label_idx: int = 1, crop: bool = True) -> T seg_pred, seg_gt = np.expand_dims(seg_pred, axis=channel_dim), np.expand_dims(seg_gt, axis=channel_dim) box_start, box_end = generate_spatial_bounding_box(np.asarray(seg_pred | seg_gt)) cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - seg_pred = convert_data_type(np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray)[0] - seg_gt = convert_data_type(np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray)[0] + seg_pred = convert_data_type(np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray)[0] # type: ignore[arg-type] + seg_gt = convert_data_type(np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray)[0] # type: ignore[arg-type] # Do binary erosion and use XOR to get edges edges_pred = binary_erosion(seg_pred) ^ seg_pred @@ -200,7 +207,7 @@ def get_surface_distance(seg_pred: np.ndarray, seg_gt: np.ndarray, distance_metr return np.asarray(dis[seg_pred]) -def is_binary_tensor(input: torch.Tensor, name: str): +def is_binary_tensor(input: torch.Tensor, name: str) -> None: """Determines whether the input tensor is torch binary tensor or not. Args: @@ -210,8 +217,8 @@ def is_binary_tensor(input: torch.Tensor, name: str): Raises: ValueError: if `input` is not a PyTorch Tensor. - Returns: - Union[str, None]: warning message, if the tensor is not binary. Otherwise, None. + Note: + A warning message is printed, if the tensor is not binary. """ if not isinstance(input, torch.Tensor): raise ValueError(f"{name} must be of type PyTorch Tensor.") @@ -219,7 +226,7 @@ def is_binary_tensor(input: torch.Tensor, name: str): warnings.warn(f"{name} should be a binarized tensor.") -def remap_instance_id(pred: torch.Tensor, by_size: bool = False): +def remap_instance_id(pred: torch.Tensor, by_size: bool = False) -> torch.Tensor: """ This function is used to rename all instance id of `pred`, so that the id is contiguous. diff --git a/monai/metrics/wrapper.py b/monai/metrics/wrapper.py new file mode 100644 index 0000000000..46a9d5fb73 --- /dev/null +++ b/monai/metrics/wrapper.py @@ -0,0 +1,306 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import cast + +import torch + +from monai.metrics.utils import do_metric_reduction, ignore_background, is_binary_tensor +from monai.utils import MetricReduction, convert_to_numpy, convert_to_tensor, optional_import + +from .metric import CumulativeIterationMetric + +BinaryPairwiseMeasures, _ = optional_import("MetricsReloaded.metrics.pairwise_measures", name="BinaryPairwiseMeasures") +MultiClassPairwiseMeasures, _ = optional_import( + "MetricsReloaded.metrics.pairwise_measures", name="MultiClassPairwiseMeasures" +) + +__all__ = ["MetricsReloadedBinary", "MetricsReloadedCategorical"] + + +class MetricsReloadedWrapper(CumulativeIterationMetric): + """Base class for defining MetricsReloaded metrics as a CumulativeIterationMetric. + + Args: + metric_name: Name of a metric from the MetricsReloaded package. + include_background: whether to skip computation on the first channel of + the predicted output. Defaults to ``True``. + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). + Here `not_nans` count the number of not nans for the metric, + thus its shape equals to the shape of the metric. + + """ + + def __init__( + self, + metric_name: str, + include_background: bool = True, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + ) -> None: + super().__init__() + self.metric_name = metric_name + self.include_background = include_background + self.reduction = reduction + self.get_not_nans = get_not_nans + + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + data = self.get_buffer() + if not isinstance(data, torch.Tensor): + raise ValueError("the data to aggregate must be PyTorch Tensor.") + # do metric reduction + f, not_nans = do_metric_reduction(data, reduction or self.reduction) + return (f, not_nans) if self.get_not_nans else f + + def prepare_onehot(self, y_pred, y): + """Prepares onehot encoded input for metric call.""" + is_binary_tensor(y_pred, "y_pred") + is_binary_tensor(y, "y") + y = y.float() + y_pred = y_pred.float() + if not self.include_background: + y_pred, y = ignore_background(y_pred=y_pred, y=y) + return y_pred, y, y_pred.device + + +class MetricsReloadedBinary(MetricsReloadedWrapper): + """ + Wraps the binary pairwise metrics of MetricsReloaded. + + Args: + metric_name: Name of a binary metric from the MetricsReloaded package. + include_background: whether to skip computation on the first channel of + the predicted output. Defaults to ``True``. + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). + Here `not_nans` count the number of not nans for the metric, + thus its shape equals to the shape of the metric. + + Example: + + .. code-block:: python + + import torch + from monai.metrics import MetricsReloadedBinary + + metric_name = "Cohens Kappa" + metric = MetricsReloadedBinary(metric_name=metric_name) + + # first iteration + # shape [batch=1, channel=1, 2, 2] + y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) + y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]) + print(metric(y_pred, y)) + + # second iteration + # shape [batch=1, channel=1, 2, 2] + y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]]) + y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]) + print(metric(y_pred, y)) + + # aggregate + # shape ([batch=2, channel=1]) + print(metric.aggregate(reduction="none")) # tensor([[0.5], [0.2]]) + + # reset + metric.reset() + + """ + + def __init__( + self, + metric_name: str, + include_background: bool = True, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + ) -> None: + super().__init__( + metric_name=metric_name, + include_background=include_background, + reduction=reduction, + get_not_nans=get_not_nans, + ) + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] + """Computes a binary (single-class) MetricsReloaded metric from a batch of + predictions and references. + + Args: + y_pred: Prediction with dimensions (batch, channel, *spatial), where channel=1. + The values should be binarized. + y: Ground-truth with dimensions (batch, channel, *spatial), where channel=1. + The values should be binarized. + + Raises: + ValueError: when `y` or `y_pred` is not a binarized tensor. + ValueError: when `y_pred` has less than three dimensions. + ValueError: when second dimension ~= 1 + + """ + # Preprocess + y_pred, y, device = self.prepare_onehot(y_pred, y) + + # Sanity check + dims = y_pred.ndimension() + if dims < 3: + raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.") + if y_pred.shape[1] != 1 or y.shape[1] != 1: + raise ValueError(f"y_pred.shape[1]={y_pred.shape[1]} and y.shape[1]={y.shape[1]} should be one.") + + # To numpy array + y_pred = convert_to_numpy(y_pred) + y = convert_to_numpy(y) + + # Create binary pairwise metric object + bpm = BinaryPairwiseMeasures(y_pred, y, axis=tuple(range(2, dims)), smooth_dr=1e-5) + + # Is requested metric available? + if self.metric_name not in bpm.metrics: + raise ValueError(f"Unsupported metric: {self.metric_name}") + + # Compute metric + metric = bpm.metrics[self.metric_name]() + + # Return metric as tensor + return convert_to_tensor(metric, device=device) # type: ignore[no-any-return] + + +class MetricsReloadedCategorical(MetricsReloadedWrapper): + """ + Wraps the categorical pairwise metrics of MetricsReloaded. + + + Args: + metric_name: Name of a categorical metric from the MetricsReloaded package. + include_background: whether to skip computation on the first channel of + the predicted output. Defaults to ``True``. + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). + Here `not_nans` count the number of not nans for the metric, + thus its shape equals to the shape of the metric. + smooth_dr: a small constant added to the denominator to avoid nan. OBS: should be greater than zero. + + Example: + + .. code-block:: python + + import torch + from monai.metrics import MetricsReloadedCategorical + + metric_name = "Weighted Cohens Kappa" + metric = MetricsReloadedCategorical(metric_name=metric_name) + + # first iteration + # shape [bach=1, channel=3, 2, 2] + y_pred = torch.tensor([[[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]]) + y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]]) + print(metric(y_pred, y)) + + # second iteration + # shape [batch=1, channel=3, 2, 2] + y_pred = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, 0], [0, 0]]]]) + y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]]) + print(metric(y_pred, y)) + + # aggregate + # shape ([batch=2, channel=1]) + print(metric.aggregate(reduction="none")) # tensor([[0.2727], [0.6000]]) + + # reset + metric.reset() + + """ + + def __init__( + self, + metric_name: str, + include_background: bool = True, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + smooth_dr: float = 1e-5, + ) -> None: + super().__init__( + metric_name=metric_name, + include_background=include_background, + reduction=reduction, + get_not_nans=get_not_nans, + ) + self.smooth_dr = smooth_dr + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] + """Computes a categorical (multi-class) MetricsReloaded metric from a batch of + predictions and references. + + Args: + y_pred: Prediction with dimensions (batch, channel, *spatial). The values should be + one-hot encoded and binarized. + y: Ground-truth with dimensions (batch, channel, *spatial). The values should be 1 + one-hot encoded and binarized. + + Raises: + ValueError: when `y` or `y_pred` is not a binarized tensor. + ValueError: when `y_pred` has less than three dimensions. + + """ + # Preprocess + y_pred, y, device = self.prepare_onehot(y_pred, y) + + # Sanity check + dims = y_pred.ndimension() + if dims < 3: + raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.") + + num_classes = y_pred.shape[1] + + # Reshape and permute for compatible dimension with MetricsReloaded + y_pred = y_pred.reshape(y_pred.shape[0], y_pred.shape[1], -1) + y_pred = y_pred.permute((0, 2, 1)) + y = y.reshape(y.shape[0], y.shape[1], -1) + y = y.permute((0, 2, 1)) + dims = y_pred.ndimension() + + # To numpy array + y_pred = convert_to_numpy(y_pred) + y = convert_to_numpy(y) + + # Create categorical pairwise metric object + bpm = MultiClassPairwiseMeasures( + y_pred, + y, + axis=tuple(range(1, dims)), + smooth_dr=self.smooth_dr, + list_values=list(range(num_classes)), + is_onehot=True, + ) + + # Is requested metric available? + if self.metric_name not in bpm.metrics: + raise ValueError(f"Unsupported metric: {self.metric_name}") + + # Compute metric + metric = bpm.metrics[self.metric_name]() + + # Put back singleton channel dimension + metric = metric[..., None] + + # Return metric as tensor + return cast(torch.Tensor, convert_to_tensor(metric, device=device)) diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index b2dc907c18..f109e97fd0 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .utils import ( convert_to_torchscript, copy_model_state, @@ -25,7 +27,6 @@ replace_modules_temp, save_state, set_named_module, - slice_channels, to_norm_affine, train_mode, ) diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 61d5fcf8f5..e67cb3376f 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .acti_norm import ADN -from .activation import MemoryEfficientSwish, Mish, Swish +from .activation import GEGLU, MemoryEfficientSwish, Mish, Swish from .aspp import SimpleASPP from .backbone_fpn_utils import BackboneWithFPN from .convolutions import Convolution, ResidualUnit diff --git a/monai/networks/blocks/acti_norm.py b/monai/networks/blocks/acti_norm.py index 6aeaa7d275..37bbd32f95 100644 --- a/monai/networks/blocks/acti_norm.py +++ b/monai/networks/blocks/acti_norm.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from __future__ import annotations import torch.nn as nn @@ -69,12 +69,12 @@ class ADN(nn.Sequential): def __init__( self, ordering: str = "NDA", - in_channels: Optional[int] = None, - act: Optional[Union[Tuple, str]] = "RELU", - norm: Optional[Union[Tuple, str]] = None, - norm_dim: Optional[int] = None, - dropout: Optional[Union[Tuple, str, float]] = None, - dropout_dim: Optional[int] = None, + in_channels: int | None = None, + act: tuple | str | None = "RELU", + norm: tuple | str | None = None, + norm_dim: int | None = None, + dropout: tuple | str | float | None = None, + dropout_dim: int | None = None, ) -> None: super().__init__() diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index 1526b37056..1e5e979dff 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch from torch import nn @@ -160,3 +162,23 @@ def __init__(self, inplace: bool = False): def forward(self, input: torch.Tensor): return monai_mish(input, self.inplace) + + +class GEGLU(nn.Module): + r"""Applies the element-wise function: + + .. math:: + \text{GEGLU}(x) = x_1 * \text{Sigmoid}(x_2) + + where :math:`x_1` and :math:`x_2` are split from the input tensor along the last dimension. + + Citation: GLU Variants Improve Transformer, Noam Shazeer, 2020, https://arxiv.org/abs/2002.05202. + + Shape: + - Input: :math:`(N, *, 2 * D)` + - Output: :math:`(N, *, D)`, where `*` means, any number of additional dimensions + """ + + def forward(self, input: torch.Tensor): + x, gate = input.chunk(2, dim=-1) + return x * nn.functional.gelu(gate) diff --git a/monai/networks/blocks/aspp.py b/monai/networks/blocks/aspp.py index 8d43530fa7..1f6c76c3af 100644 --- a/monai/networks/blocks/aspp.py +++ b/monai/networks/blocks/aspp.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -37,8 +39,8 @@ def __init__( conv_out_channels: int, kernel_sizes: Sequence[int] = (1, 3, 3, 3), dilations: Sequence[int] = (1, 2, 4, 6), - norm_type: Optional[Union[Tuple, str]] = "BATCH", - acti_type: Optional[Union[Tuple, str]] = "LEAKYRELU", + norm_type: tuple | str | None = "BATCH", + acti_type: tuple | str | None = "LEAKYRELU", bias: bool = False, ) -> None: """ diff --git a/monai/networks/blocks/backbone_fpn_utils.py b/monai/networks/blocks/backbone_fpn_utils.py index 145a4ac2e1..824b31a83b 100644 --- a/monai/networks/blocks/backbone_fpn_utils.py +++ b/monai/networks/blocks/backbone_fpn_utils.py @@ -50,7 +50,7 @@ https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py """ -from typing import Dict, List, Optional, Union +from __future__ import annotations from torch import Tensor, nn @@ -89,11 +89,11 @@ class BackboneWithFPN(nn.Module): def __init__( self, backbone: nn.Module, - return_layers: Dict[str, str], - in_channels_list: List[int], + return_layers: dict[str, str], + in_channels_list: list[int], out_channels: int, - spatial_dims: Union[int, None] = None, - extra_blocks: Optional[ExtraFPNBlock] = None, + spatial_dims: int | None = None, + extra_blocks: ExtraFPNBlock | None = None, ) -> None: super().__init__() @@ -120,7 +120,7 @@ def __init__( ) self.out_channels = out_channels - def forward(self, x: Tensor) -> Dict[str, Tensor]: + def forward(self, x: Tensor) -> dict[str, Tensor]: """ Computes the resulted feature maps of the network. @@ -131,7 +131,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: feature maps after FPN layers. They are ordered from highest resolution first. """ x = self.body(x) # backbone - y: Dict[str, Tensor] = self.fpn(x) # FPN + y: dict[str, Tensor] = self.fpn(x) # FPN return y @@ -139,8 +139,8 @@ def _resnet_fpn_extractor( backbone: resnet.ResNet, spatial_dims: int, trainable_layers: int = 5, - returned_layers: Optional[List[int]] = None, - extra_blocks: Optional[ExtraFPNBlock] = None, + returned_layers: list[int] | None = None, + extra_blocks: ExtraFPNBlock | None = None, ) -> BackboneWithFPN: """ Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py diff --git a/monai/networks/blocks/convolutions.py b/monai/networks/blocks/convolutions.py index 55735e2f58..8b18614364 100644 --- a/monai/networks/blocks/convolutions.py +++ b/monai/networks/blocks/convolutions.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import numpy as np import torch @@ -98,20 +100,20 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - strides: Union[Sequence[int], int] = 1, - kernel_size: Union[Sequence[int], int] = 3, + strides: Sequence[int] | int = 1, + kernel_size: Sequence[int] | int = 3, adn_ordering: str = "NDA", - act: Optional[Union[Tuple, str]] = "PRELU", - norm: Optional[Union[Tuple, str]] = "INSTANCE", - dropout: Optional[Union[Tuple, str, float]] = None, - dropout_dim: Optional[int] = 1, - dilation: Union[Sequence[int], int] = 1, + act: tuple | str | None = "PRELU", + norm: tuple | str | None = "INSTANCE", + dropout: tuple | str | float | None = None, + dropout_dim: int | None = 1, + dilation: Sequence[int] | int = 1, groups: int = 1, bias: bool = True, conv_only: bool = False, is_transposed: bool = False, - padding: Optional[Union[Sequence[int], int]] = None, - output_padding: Optional[Union[Sequence[int], int]] = None, + padding: Sequence[int] | int | None = None, + output_padding: Sequence[int] | int | None = None, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -248,18 +250,18 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - strides: Union[Sequence[int], int] = 1, - kernel_size: Union[Sequence[int], int] = 3, + strides: Sequence[int] | int = 1, + kernel_size: Sequence[int] | int = 3, subunits: int = 2, adn_ordering: str = "NDA", - act: Optional[Union[Tuple, str]] = "PRELU", - norm: Optional[Union[Tuple, str]] = "INSTANCE", - dropout: Optional[Union[Tuple, str, float]] = None, - dropout_dim: Optional[int] = 1, - dilation: Union[Sequence[int], int] = 1, + act: tuple | str | None = "PRELU", + norm: tuple | str | None = "INSTANCE", + dropout: tuple | str | float | None = None, + dropout_dim: int | None = 1, + dilation: Sequence[int] | int = 1, bias: bool = True, last_conv_only: bool = False, - padding: Optional[Union[Sequence[int], int]] = None, + padding: Sequence[int] | int | None = None, ) -> None: super().__init__() self.spatial_dims = spatial_dims diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index b6382adf5f..398b89882a 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from __future__ import annotations import torch from torch.nn.functional import softmax @@ -44,7 +44,7 @@ def __init__( bilateral_color_sigma: float = 0.5, gaussian_spatial_sigma: float = 5.0, update_factor: float = 3.0, - compatibility_matrix: Optional[torch.Tensor] = None, + compatibility_matrix: torch.Tensor | None = None, ): """ Args: @@ -92,7 +92,6 @@ def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor): # mean field loop for _ in range(self.iterations): - # message passing step for both kernels bilateral_output = PHLFilter.apply(output_tensor, bilateral_features) gaussian_output = PHLFilter.apply(output_tensor, gaussian_features) diff --git a/monai/networks/blocks/denseblock.py b/monai/networks/blocks/denseblock.py index dafd8d03a6..afd3183581 100644 --- a/monai/networks/blocks/denseblock.py +++ b/monai/networks/blocks/denseblock.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -66,16 +68,15 @@ def __init__( spatial_dims: int, in_channels: int, channels: Sequence[int], - dilations: Optional[Sequence[int]] = None, - kernel_size: Union[Sequence[int], int] = 3, + dilations: Sequence[int] | None = None, + kernel_size: Sequence[int] | int = 3, num_res_units: int = 0, adn_ordering: str = "NDA", - act: Optional[Union[Tuple, str]] = Act.PRELU, - norm: Optional[Union[Tuple, str]] = Norm.INSTANCE, - dropout: Optional[int] = None, + act: tuple | str | None = Act.PRELU, + norm: tuple | str | None = Norm.INSTANCE, + dropout: int | None = None, bias: bool = True, ): - self.spatial_dims = spatial_dims self.kernel_size = kernel_size self.num_res_units = num_res_units diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py index 1823845adf..d3ac7cf04b 100644 --- a/monai/networks/blocks/dints_block.py +++ b/monai/networks/blocks/dints_block.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from __future__ import annotations import torch @@ -29,8 +29,8 @@ def __init__( in_channel: int, out_channel: int, spatial_dims: int = 3, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), ): """ Args: @@ -80,8 +80,8 @@ def __init__( in_channel: int, out_channel: int, spatial_dims: int = 3, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), ): """ Args: @@ -148,8 +148,8 @@ def __init__( kernel_size: int, padding: int, mode: int = 0, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), ): """ Args: @@ -233,8 +233,8 @@ def __init__( kernel_size: int = 3, padding: int = 1, spatial_dims: int = 3, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), ): """ Args: diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 9b0d5dd4b9..2a6a60ff8a 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -27,9 +29,9 @@ class MaxAvgPool(nn.Module): def __init__( self, spatial_dims: int, - kernel_size: Union[Sequence[int], int], - stride: Optional[Union[Sequence[int], int]] = None, - padding: Union[Sequence[int], int] = 0, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int | None = None, + padding: Sequence[int] | int = 0, ceil_mode: bool = False, ) -> None: """ diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index 894686b50a..12afab3464 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import numpy as np import torch @@ -43,11 +45,11 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + norm_name: tuple | str, + act_name: tuple | str = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: tuple | str | float | None = None, ): super().__init__() self.conv1 = get_conv_layer( @@ -132,11 +134,11 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + norm_name: tuple | str, + act_name: tuple | str = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: tuple | str | float | None = None, ): super().__init__() self.conv1 = get_conv_layer( @@ -200,12 +202,12 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + upsample_kernel_size: Sequence[int] | int, + norm_name: tuple | str, + act_name: tuple | str = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: tuple | str | float | None = None, trans_bias: bool = False, ): super().__init__() @@ -244,7 +246,7 @@ def forward(self, inp, skip): class UnetOutBlock(nn.Module): def __init__( - self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None + self, spatial_dims: int, in_channels: int, out_channels: int, dropout: tuple | str | float | None = None ): super().__init__() self.conv = get_conv_layer( @@ -268,11 +270,11 @@ def get_conv_layer( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, - stride: Union[Sequence[int], int] = 1, - act: Optional[Union[Tuple, str]] = Act.PRELU, - norm: Optional[Union[Tuple, str]] = Norm.INSTANCE, - dropout: Optional[Union[Tuple, str, float]] = None, + kernel_size: Sequence[int] | int = 3, + stride: Sequence[int] | int = 1, + act: tuple | str | None = Act.PRELU, + norm: tuple | str | None = Norm.INSTANCE, + dropout: tuple | str | float | None = None, bias: bool = False, conv_only: bool = True, is_transposed: bool = False, @@ -298,10 +300,7 @@ def get_conv_layer( ) -def get_padding( - kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] -) -> Union[Tuple[int, ...], int]: - +def get_padding(kernel_size: Sequence[int] | int, stride: Sequence[int] | int) -> tuple[int, ...] | int: kernel_size_np = np.atleast_1d(kernel_size) stride_np = np.atleast_1d(stride) padding_np = (kernel_size_np - stride_np + 1) / 2 @@ -313,8 +312,8 @@ def get_padding( def get_output_padding( - kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int] -) -> Union[Tuple[int, ...], int]: + kernel_size: Sequence[int] | int, stride: Sequence[int] | int, padding: Sequence[int] | int +) -> tuple[int, ...] | int: kernel_size_np = np.atleast_1d(kernel_size) stride_np = np.atleast_1d(stride) padding_np = np.atleast_1d(padding) diff --git a/monai/networks/blocks/encoder.py b/monai/networks/blocks/encoder.py index 9d4dac8f57..419afec838 100644 --- a/monai/networks/blocks/encoder.py +++ b/monai/networks/blocks/encoder.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from abc import ABCMeta, abstractmethod -from typing import Dict, List, Tuple __all__ = ["BaseEncoder"] @@ -28,7 +29,7 @@ class BaseEncoder(metaclass=ABCMeta): @classmethod @abstractmethod - def get_encoder_parameters(cls) -> List[Dict]: + def get_encoder_parameters(cls) -> list[dict]: """ Get parameter list to initialize encoder networks. Each parameter dict must have `spatial_dims`, `in_channels` @@ -42,7 +43,7 @@ def get_encoder_parameters(cls) -> List[Dict]: @classmethod @abstractmethod - def num_channels_per_output(cls) -> List[Tuple[int, ...]]: + def num_channels_per_output(cls) -> list[tuple[int, ...]]: """ Get number of output features' channels. The reason that this function should return a list is that a @@ -56,7 +57,7 @@ def num_channels_per_output(cls) -> List[Tuple[int, ...]]: @classmethod @abstractmethod - def num_outputs(cls) -> List[int]: + def num_outputs(cls) -> list[int]: """ Get number of outputs of encoder. The reason that this function should return a list is that a @@ -70,7 +71,7 @@ def num_outputs(cls) -> List[int]: @classmethod @abstractmethod - def get_encoder_names(cls) -> List[str]: + def get_encoder_names(cls) -> list[str]: """ Get the name string of encoders which will be used to initialize flexible unet. diff --git a/monai/networks/blocks/fcn.py b/monai/networks/blocks/fcn.py index 5833d4a262..b44ea5f99a 100644 --- a/monai/networks/blocks/fcn.py +++ b/monai/networks/blocks/fcn.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Type +from __future__ import annotations import torch import torch.nn as nn @@ -38,7 +38,7 @@ def __init__(self, inplanes: int, planes: int, ks: int = 7): """ super().__init__() - conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] + conv2d_type: type[nn.Conv2d] = Conv[Conv.CONV, 2] self.conv_l1 = conv2d_type(in_channels=inplanes, out_channels=planes, kernel_size=(ks, 1), padding=(ks // 2, 0)) self.conv_l2 = conv2d_type(in_channels=planes, out_channels=planes, kernel_size=(1, ks), padding=(0, ks // 2)) self.conv_r1 = conv2d_type(in_channels=inplanes, out_channels=planes, kernel_size=(1, ks), padding=(0, ks // 2)) @@ -69,9 +69,9 @@ def __init__(self, planes: int): """ super().__init__() - relu_type: Type[nn.ReLU] = Act[Act.RELU] - conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] - norm2d_type: Type[nn.BatchNorm2d] = Norm[Norm.BATCH, 2] + relu_type: type[nn.ReLU] = Act[Act.RELU] + conv2d_type: type[nn.Conv2d] = Conv[Conv.CONV, 2] + norm2d_type: type[nn.BatchNorm2d] = Norm[Norm.BATCH, 2] self.bn = norm2d_type(num_features=planes) self.relu = relu_type(inplace=True) @@ -118,7 +118,7 @@ def __init__( ): super().__init__() - conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] + conv2d_type: type[nn.Conv2d] = Conv[Conv.CONV, 2] self.upsample_mode = upsample_mode self.conv2d_type = conv2d_type diff --git a/monai/networks/blocks/feature_pyramid_network.py b/monai/networks/blocks/feature_pyramid_network.py index adca3df6b4..cca7342078 100644 --- a/monai/networks/blocks/feature_pyramid_network.py +++ b/monai/networks/blocks/feature_pyramid_network.py @@ -50,8 +50,10 @@ https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py """ +from __future__ import annotations + from collections import OrderedDict -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from collections.abc import Callable import torch.nn.functional as F from torch import Tensor, nn @@ -68,7 +70,7 @@ class ExtraFPNBlock(nn.Module): Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py """ - def forward(self, results: List[Tensor], x: List[Tensor], names: List[str]): + def forward(self, results: list[Tensor], x: list[Tensor], names: list[str]): """ Compute extended set of results of the FPN and their names. @@ -92,10 +94,10 @@ class LastLevelMaxPool(ExtraFPNBlock): def __init__(self, spatial_dims: int): super().__init__() - pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] + pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims] self.maxpool = pool_type(kernel_size=1, stride=2, padding=0) - def forward(self, results: List[Tensor], x: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]: + def forward(self, results: list[Tensor], x: list[Tensor], names: list[str]) -> tuple[list[Tensor], list[str]]: names.append("pool") results.append(self.maxpool(results[-1])) return results, names @@ -118,7 +120,7 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int): nn.init.constant_(module.bias, 0) self.use_P5 = in_channels == out_channels - def forward(self, results: List[Tensor], x: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]: + def forward(self, results: list[Tensor], x: list[Tensor], names: list[str]) -> tuple[list[Tensor], list[str]]: p5, c5 = results[-1], x[-1] x5 = p5 if self.use_P5 else c5 p6 = self.p6(x5) @@ -170,9 +172,9 @@ class FeaturePyramidNetwork(nn.Module): def __init__( self, spatial_dims: int, - in_channels_list: List[int], + in_channels_list: list[int], out_channels: int, - extra_blocks: Optional[ExtraFPNBlock] = None, + extra_blocks: ExtraFPNBlock | None = None, ): super().__init__() @@ -189,7 +191,7 @@ def __init__( self.layer_blocks.append(layer_block_module) # initialize parameters now to avoid modifying the initialization of top_blocks - conv_type_: Type[nn.Module] = Conv[Conv.CONV, spatial_dims] + conv_type_: type[nn.Module] = Conv[Conv.CONV, spatial_dims] for m in self.modules(): if isinstance(m, conv_type_): nn.init.kaiming_uniform_(m.weight, a=1) # type: ignore @@ -228,7 +230,7 @@ def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor: out = module(x) return out - def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]: + def forward(self, x: dict[str, Tensor]) -> dict[str, Tensor]: """ Computes the FPN for a set of feature maps. @@ -240,7 +242,7 @@ def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]: """ # unpack OrderedDict into two lists for easier handling names = list(x.keys()) - x_values: List[Tensor] = list(x.values()) + x_values: list[Tensor] = list(x.values()) last_inner = self.get_result_from_inner_blocks(x_values[-1], -1) results = [] diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 1283f05c6b..2aa4054f64 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from __future__ import annotations import torch from torch import Tensor @@ -42,7 +42,7 @@ def roll_1d(x: Tensor, shift: int, shift_dim: int) -> Tensor: return torch.cat((right, left), dim=shift_dim) -def roll(x: Tensor, shift: List[int], shift_dims: List[int]) -> Tensor: +def roll(x: Tensor, shift: list[int], shift_dims: list[int]) -> Tensor: """ Similar to np.roll but applies to PyTorch Tensors @@ -66,7 +66,7 @@ def roll(x: Tensor, shift: List[int], shift_dims: List[int]) -> Tensor: return x -def fftshift(x: Tensor, shift_dims: List[int]) -> Tensor: +def fftshift(x: Tensor, shift_dims: list[int]) -> Tensor: """ Similar to np.fft.fftshift but applies to PyTorch Tensors @@ -88,7 +88,7 @@ def fftshift(x: Tensor, shift_dims: List[int]) -> Tensor: return roll(x, shift, shift_dims) -def ifftshift(x: Tensor, shift_dims: List[int]) -> Tensor: +def ifftshift(x: Tensor, shift_dims: list[int]) -> Tensor: """ Similar to np.fft.ifftshift but applies to PyTorch Tensors diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 41b76c7d4c..11808eabf7 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Type, Union +from __future__ import annotations + +from collections.abc import Sequence import torch from torch import nn @@ -24,9 +26,9 @@ def get_conv_block( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, - act: Optional[Union[Tuple, str]] = "RELU", - norm: Optional[Union[Tuple, str]] = "BATCH", + kernel_size: Sequence[int] | int = 3, + act: tuple | str | None = "RELU", + norm: tuple | str | None = "BATCH", ) -> nn.Module: padding = same_padding(kernel_size) mod: nn.Module = Convolution( @@ -44,7 +46,7 @@ def get_conv_block( def get_conv_layer( - spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] = 3 + spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int = 3 ) -> nn.Module: padding = same_padding(kernel_size) mod: nn.Module = Convolution( @@ -71,7 +73,7 @@ def get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) -> class ResidualBlock(nn.Module): def __init__( - self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] + self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int ) -> None: super().__init__() if in_channels != out_channels: @@ -121,7 +123,7 @@ class LocalNetDownSampleBlock(nn.Module): """ def __init__( - self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] + self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int ) -> None: """ Args: @@ -141,7 +143,7 @@ def __init__( ) self.max_pool = Pool[Pool.MAX, spatial_dims](kernel_size=2) - def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]: """ Halves the spatial dimensions. A tuple of (x, mid) is returned: @@ -166,7 +168,7 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: class LocalNetUpSampleBlock(nn.Module): """ - A up-sample module that can be used for LocalNet, based on: + An up-sample module that can be used for LocalNet, based on: `Weakly-supervised convolutional neural networks for multimodal image registration `_. `Label-driven weakly-supervised learning for multimodal deformable image registration @@ -176,12 +178,21 @@ class LocalNetUpSampleBlock(nn.Module): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None: + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + mode: str = "nearest", + align_corners: bool | None = None, + ) -> None: """ Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. + mode: interpolation mode of the additive upsampling, default to 'nearest'. + align_corners: whether to align corners for the additive upsampling, default to None. Raises: ValueError: when ``in_channels != 2 * out_channels`` """ @@ -199,9 +210,11 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> No f"got in_channels={in_channels}, out_channels={out_channels}" ) self.out_channels = out_channels + self.mode = mode + self.align_corners = align_corners - def addictive_upsampling(self, x, mid) -> torch.Tensor: - x = F.interpolate(x, mid.shape[2:]) + def additive_upsampling(self, x, mid) -> torch.Tensor: + x = F.interpolate(x, mid.shape[2:], mode=self.mode, align_corners=self.align_corners) # [(batch, out_channels, ...), (batch, out_channels, ...)] x = x.split(split_size=int(self.out_channels), dim=1) # (batch, out_channels, ...) @@ -226,7 +239,7 @@ def forward(self, x, mid) -> torch.Tensor: "expecting mid spatial dimensions be exactly the double of x spatial dimensions, " f"got x of shape {x.shape}, mid of shape {mid.shape}" ) - h0 = self.deconv_block(x) + self.addictive_upsampling(x, mid) + h0 = self.deconv_block(x) + self.additive_upsampling(x, mid) r1 = h0 + mid r2 = self.conv_block(h0) out: torch.Tensor = self.residual_block(r2, r1) @@ -250,7 +263,7 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - act: Optional[Union[Tuple, str]] = "RELU", + act: tuple | str | None = "RELU", initializer: str = "kaiming_uniform", ) -> None: """ @@ -265,7 +278,7 @@ def __init__( self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) - conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] + conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims] for m in self.conv_block.modules(): if isinstance(m, conv_type): if initializer == "kaiming_uniform": diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index 0feeb044f3..e3ab94b32a 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from __future__ import annotations import torch.nn as nn @@ -26,19 +26,14 @@ class MLPBlock(nn.Module): """ def __init__( - self, - hidden_size: int, - mlp_dim: int, - dropout_rate: float = 0.0, - act: Union[Tuple, str] = "GELU", - dropout_mode="vit", + self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0, act: tuple | str = "GELU", dropout_mode="vit" ) -> None: """ Args: hidden_size: dimension of hidden layer. mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used. dropout_rate: faction of the input units to drop. - act: activation type and arguments. Defaults to GELU. + act: activation type and arguments. Defaults to GELU. Also supports "GEGLU" and others. dropout_mode: dropout mode, can be "vit" or "swin". "vit" mode uses two dropout instances as implemented in https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87 @@ -53,7 +48,7 @@ def __init__( if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") mlp_dim = mlp_dim or hidden_size - self.linear1 = nn.Linear(hidden_size, mlp_dim) + self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) self.linear2 = nn.Linear(mlp_dim, hidden_size) self.fn = get_act_layer(act) self.drop1 = nn.Dropout(dropout_rate) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index a611a30b15..a0699eb108 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Type, Union +from __future__ import annotations + +from collections.abc import Sequence import numpy as np import torch @@ -40,8 +42,8 @@ class PatchEmbeddingBlock(nn.Module): def __init__( self, in_channels: int, - img_size: Union[Sequence[int], int], - patch_size: Union[Sequence[int], int], + img_size: Sequence[int] | int, + patch_size: Sequence[int] | int, hidden_size: int, num_heads: int, pos_embed: str, @@ -137,10 +139,10 @@ class PatchEmbed(nn.Module): def __init__( self, - patch_size: Union[Sequence[int], int] = 2, + patch_size: Sequence[int] | int = 2, in_chans: int = 1, embed_dim: int = 48, - norm_layer: Type[LayerNorm] = nn.LayerNorm, + norm_layer: type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3, ) -> None: """ diff --git a/monai/networks/blocks/regunet_block.py b/monai/networks/blocks/regunet_block.py index 78e2598b4b..562364373b 100644 --- a/monai/networks/blocks/regunet_block.py +++ b/monai/networks/blocks/regunet_block.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Type, Union +from __future__ import annotations + +from collections.abc import Sequence import torch from torch import nn @@ -23,12 +25,12 @@ def get_conv_block( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, + kernel_size: Sequence[int] | int = 3, strides: int = 1, - padding: Optional[Union[Tuple[int, ...], int]] = None, - act: Optional[Union[Tuple, str]] = "RELU", - norm: Optional[Union[Tuple, str]] = "BATCH", - initializer: Optional[str] = "kaiming_uniform", + padding: tuple[int, ...] | int | None = None, + act: tuple | str | None = "RELU", + norm: tuple | str | None = "BATCH", + initializer: str | None = "kaiming_uniform", ) -> nn.Module: if padding is None: padding = same_padding(kernel_size) @@ -44,7 +46,7 @@ def get_conv_block( conv_only=False, padding=padding, ) - conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] + conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims] for m in conv_block.modules(): if isinstance(m, conv_type): if initializer == "kaiming_uniform": @@ -59,7 +61,7 @@ def get_conv_block( def get_conv_layer( - spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] = 3 + spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int = 3 ) -> nn.Module: padding = same_padding(kernel_size) mod: nn.Module = Convolution( @@ -195,11 +197,13 @@ class RegistrationExtractionBlock(nn.Module): def __init__( self, spatial_dims: int, - extract_levels: Tuple[int], - num_channels: Union[Tuple[int], List[int]], + extract_levels: tuple[int], + num_channels: tuple[int] | list[int], out_channels: int, - kernel_initializer: Optional[str] = "kaiming_uniform", - activation: Optional[str] = None, + kernel_initializer: str | None = "kaiming_uniform", + activation: str | None = None, + mode: str = "nearest", + align_corners: bool | None = None, ): """ @@ -211,6 +215,8 @@ def __init__( out_channels: number of output channels kernel_initializer: kernel initializer activation: kernel activation function + mode: feature map interpolation mode, default to "nearest". + align_corners: whether to align corners for feature map interpolation. """ super().__init__() self.extract_levels = extract_levels @@ -228,8 +234,10 @@ def __init__( for d in extract_levels ] ) + self.mode = mode + self.align_corners = align_corners - def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor: + def forward(self, x: list[torch.Tensor], image_size: list[int]) -> torch.Tensor: """ Args: @@ -240,7 +248,9 @@ def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor: Tensor of shape (batch, `out_channels`, size1, size2, size3), where (size1, size2, size3) = ``image_size`` """ feature_list = [ - F.interpolate(layer(x[self.max_level - level]), size=image_size) + F.interpolate( + layer(x[self.max_level - level]), size=image_size, mode=self.mode, align_corners=self.align_corners + ) for layer, level in zip(self.layers, self.extract_levels) ] out: torch.Tensor = torch.mean(torch.stack(feature_list, dim=0), dim=0) diff --git a/monai/networks/blocks/segresnet_block.py b/monai/networks/blocks/segresnet_block.py index ded270ab52..3337f50043 100644 --- a/monai/networks/blocks/segresnet_block.py +++ b/monai/networks/blocks/segresnet_block.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from __future__ import annotations import torch.nn as nn @@ -22,14 +22,13 @@ def get_conv_layer( spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, bias: bool = False ): - return Convolution( spatial_dims, in_channels, out_channels, strides=stride, kernel_size=kernel_size, bias=bias, conv_only=True ) def get_upsample_layer( - spatial_dims: int, in_channels: int, upsample_mode: Union[UpsampleMode, str] = "nontrainable", scale_factor: int = 2 + spatial_dims: int, in_channels: int, upsample_mode: UpsampleMode | str = "nontrainable", scale_factor: int = 2 ): return UpSample( spatial_dims=spatial_dims, @@ -53,9 +52,9 @@ def __init__( self, spatial_dims: int, in_channels: int, - norm: Union[Tuple, str], + norm: tuple | str, kernel_size: int = 3, - act: Union[Tuple, str] = ("RELU", {"inplace": True}), + act: tuple | str = ("RELU", {"inplace": True}), ) -> None: """ Args: @@ -78,7 +77,6 @@ def __init__( self.conv2 = get_conv_layer(spatial_dims, in_channels=in_channels, out_channels=in_channels) def forward(self, x): - identity = x x = self.norm1(x) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index d0b87fda6b..519c8c7728 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch import torch.nn as nn diff --git a/monai/networks/blocks/squeeze_and_excitation.py b/monai/networks/blocks/squeeze_and_excitation.py index a9ac57aa4f..665e9020ff 100644 --- a/monai/networks/blocks/squeeze_and_excitation.py +++ b/monai/networks/blocks/squeeze_and_excitation.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math -from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -30,8 +31,8 @@ def __init__( spatial_dims: int, in_channels: int, r: int = 2, - acti_type_1: Union[Tuple[str, Dict], str] = ("relu", {"inplace": True}), - acti_type_2: Union[Tuple[str, Dict], str] = "sigmoid", + acti_type_1: tuple[str, dict] | str = ("relu", {"inplace": True}), + acti_type_2: tuple[str, dict] | str = "sigmoid", add_residual: bool = False, ) -> None: """ @@ -102,8 +103,8 @@ def __init__( spatial_dims: int, in_channels: int, r: int = 2, - acti_type_1: Union[Tuple[str, Dict], str] = "leakyrelu", - acti_type_2: Union[Tuple[str, Dict], str] = "relu", + acti_type_1: tuple[str, dict] | str = "leakyrelu", + acti_type_2: tuple[str, dict] | str = "relu", ) -> None: """ Args: @@ -145,14 +146,14 @@ def __init__( n_chns_1: int, n_chns_2: int, n_chns_3: int, - conv_param_1: Optional[Dict] = None, - conv_param_2: Optional[Dict] = None, - conv_param_3: Optional[Dict] = None, - project: Optional[Convolution] = None, + conv_param_1: dict | None = None, + conv_param_2: dict | None = None, + conv_param_3: dict | None = None, + project: Convolution | None = None, r: int = 2, - acti_type_1: Union[Tuple[str, Dict], str] = ("relu", {"inplace": True}), - acti_type_2: Union[Tuple[str, Dict], str] = "sigmoid", - acti_type_final: Optional[Union[Tuple[str, Dict], str]] = ("relu", {"inplace": True}), + acti_type_1: tuple[str, dict] | str = ("relu", {"inplace": True}), + acti_type_2: tuple[str, dict] | str = "sigmoid", + acti_type_final: tuple[str, dict] | str | None = ("relu", {"inplace": True}), ): """ Args: @@ -244,9 +245,8 @@ def __init__( groups: int, reduction: int, stride: int = 1, - downsample: Optional[Convolution] = None, + downsample: Convolution | None = None, ) -> None: - conv_param_1 = { "strides": 1, "kernel_size": 1, @@ -295,9 +295,8 @@ def __init__( groups: int, reduction: int, stride: int = 1, - downsample: Optional[Convolution] = None, + downsample: Convolution | None = None, ) -> None: - conv_param_1 = { "strides": stride, "kernel_size": 1, @@ -344,10 +343,9 @@ def __init__( groups: int, reduction: int, stride: int = 1, - downsample: Optional[Convolution] = None, + downsample: Convolution | None = None, base_width: int = 4, ) -> None: - conv_param_1 = { "strides": 1, "kernel_size": 1, diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 88b33acb09..3a4b507d69 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch.nn as nn from monai.networks.blocks.mlp import MLPBlock diff --git a/monai/networks/blocks/unetr_block.py b/monai/networks/blocks/unetr_block.py index 452a535a2a..13ea0ad4c8 100644 --- a/monai/networks/blocks/unetr_block.py +++ b/monai/networks/blocks/unetr_block.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -28,9 +30,9 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: Union[Tuple, str], + kernel_size: Sequence[int] | int, + upsample_kernel_size: Sequence[int] | int, + norm_name: tuple | str, res_block: bool = False, ) -> None: """ @@ -96,10 +98,10 @@ def __init__( in_channels: int, out_channels: int, num_layer: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: Union[Tuple, str], + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + upsample_kernel_size: Sequence[int] | int, + norm_name: tuple | str, conv_block: bool = False, res_block: bool = False, ) -> None: @@ -215,9 +217,9 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: Union[Tuple, str], + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + norm_name: tuple | str, res_block: bool = False, ) -> None: """ diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index ee03aa4e67..dee9966919 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -31,6 +33,9 @@ class UpSample(nn.Sequential): - "nontrainable": uses :py:class:`torch.nn.Upsample`. - "pixelshuffle": uses :py:class:`monai.networks.blocks.SubpixelUpsample`. + This operation will cause non-deterministic when ``mode`` is ``UpsampleMode.NONTRAINABLE``. + Please check the link below for more details: + https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms This module can optionally take a pre-convolution (often used to map the number of features from `in_channels` to `out_channels`). """ @@ -38,15 +43,15 @@ class UpSample(nn.Sequential): def __init__( self, spatial_dims: int, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - scale_factor: Union[Sequence[float], float] = 2, - kernel_size: Optional[Union[Sequence[float], float]] = None, - size: Optional[Union[Tuple[int], int]] = None, - mode: Union[UpsampleMode, str] = UpsampleMode.DECONV, - pre_conv: Optional[Union[nn.Module, str]] = "default", + in_channels: int | None = None, + out_channels: int | None = None, + scale_factor: Sequence[float] | float = 2, + kernel_size: Sequence[float] | float | None = None, + size: tuple[int] | int | None = None, + mode: UpsampleMode | str = UpsampleMode.DECONV, + pre_conv: nn.Module | str | None = "default", interp_mode: str = InterpolateMode.LINEAR, - align_corners: Optional[bool] = True, + align_corners: bool | None = True, bias: bool = True, apply_pad_pool: bool = True, ) -> None: @@ -203,10 +208,10 @@ class SubpixelUpsample(nn.Module): def __init__( self, spatial_dims: int, - in_channels: Optional[int], - out_channels: Optional[int] = None, + in_channels: int | None, + out_channels: int | None = None, scale_factor: int = 2, - conv_block: Optional[Union[nn.Module, str]] = "default", + conv_block: nn.Module | str | None = "default", apply_pad_pool: bool = True, bias: bool = True, ) -> None: diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 018d694407..10a115eff8 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import List import torch from torch import nn @@ -122,7 +123,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor): if not USE_COMPILED: # pytorch native grid_sample for i, dim in enumerate(grid.shape[1:-1]): grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1 - index_ordering: List[int] = list(range(spatial_dims - 1, -1, -1)) + index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1)) grid = grid[..., index_ordering] # z, y, x -> x, y, z return F.grid_sample( image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 31bd36dd8f..d61ed57f7f 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -9,20 +9,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding from .drop_path import DropPath from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args -from .filtering import BilateralFilter, PHLFilter +from .filtering import BilateralFilter, PHLFilter, TrainableBilateralFilter, TrainableJointBilateralFilter from .gmm import GaussianMixtureModel from .simplelayers import ( LLTM, + ApplyFilter, ChannelPad, + EllipticalFilter, Flatten, GaussianFilter, HilbertTransform, + LaplaceFilter, + MeanFilter, MedianFilter, Reshape, SavitzkyGolayFilter, + SharpenFilter, SkipConnection, apply_filter, median_filter, diff --git a/monai/networks/layers/convutils.py b/monai/networks/layers/convutils.py index fe688b24ff..fc8ea03809 100644 --- a/monai/networks/layers/convutils.py +++ b/monai/networks/layers/convutils.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import numpy as np import torch @@ -17,9 +19,7 @@ __all__ = ["same_padding", "stride_minus_kernel_padding", "calculate_out_shape", "gaussian_1d", "polyval"] -def same_padding( - kernel_size: Union[Sequence[int], int], dilation: Union[Sequence[int], int] = 1 -) -> Union[Tuple[int, ...], int]: +def same_padding(kernel_size: Sequence[int] | int, dilation: Sequence[int] | int = 1) -> tuple[int, ...] | int: """ Return the padding value needed to ensure a convolution using the given kernel size produces an output of the same shape as the input for a stride of 1, otherwise ensure a shape of the input divided by the stride rounded down. @@ -43,9 +43,7 @@ def same_padding( return padding if len(padding) > 1 else padding[0] -def stride_minus_kernel_padding( - kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] -) -> Union[Tuple[int, ...], int]: +def stride_minus_kernel_padding(kernel_size: Sequence[int] | int, stride: Sequence[int] | int) -> tuple[int, ...] | int: kernel_size_np = np.atleast_1d(kernel_size) stride_np = np.atleast_1d(stride) @@ -56,11 +54,11 @@ def stride_minus_kernel_padding( def calculate_out_shape( - in_shape: Union[Sequence[int], int, np.ndarray], - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - padding: Union[Sequence[int], int], -) -> Union[Tuple[int, ...], int]: + in_shape: Sequence[int] | int | np.ndarray, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + padding: Sequence[int] | int, +) -> tuple[int, ...] | int: """ Calculate the output tensor shape when applying a convolution to a tensor of shape `inShape` with kernel size `kernel_size`, stride value `stride`, and input padding value `padding`. All arguments can be scalars or multiple @@ -120,7 +118,7 @@ def gaussian_1d( out = out / (2.5066282 * sigma) elif approx.lower() == "scalespace": sigma2 = sigma * sigma - out_pos: List[Optional[torch.Tensor]] = [None] * (tail + 1) + out_pos: list[torch.Tensor | None] = [None] * (tail + 1) out_pos[0] = _modified_bessel_0(sigma2) out_pos[1] = _modified_bessel_1(sigma2) for k in range(2, len(out_pos)): diff --git a/monai/networks/layers/drop_path.py b/monai/networks/layers/drop_path.py index 7bb209ed25..6073ca4402 100644 --- a/monai/networks/layers/drop_path.py +++ b/monai/networks/layers/drop_path.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch.nn as nn diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index b6c64c6c19..37536a7cf6 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -59,8 +59,11 @@ def use_factory(fact_args): layer = use_factory( (fact.TEST, kwargs) ) """ +from __future__ import annotations + import warnings -from typing import Any, Callable, Dict, Tuple, Type, Union +from collections.abc import Callable +from typing import Any import torch import torch.nn as nn @@ -79,10 +82,10 @@ class LayerFactory: """ def __init__(self) -> None: - self.factories: Dict[str, Callable] = {} + self.factories: dict[str, Callable] = {} @property - def names(self) -> Tuple[str, ...]: + def names(self) -> tuple[str, ...]: """ Produces all factory names. """ @@ -203,7 +206,7 @@ def split_args(args): @Dropout.factory_function("dropout") -def dropout_factory(dim: int) -> Type[Union[nn.Dropout, nn.Dropout2d, nn.Dropout3d]]: +def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]: types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d) return types[dim - 1] @@ -214,34 +217,34 @@ def alpha_dropout_factory(_dim): @Norm.factory_function("instance") -def instance_factory(dim: int) -> Type[Union[nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d]]: +def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]: types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d) return types[dim - 1] @Norm.factory_function("batch") -def batch_factory(dim: int) -> Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]]: +def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]: types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) return types[dim - 1] @Norm.factory_function("group") -def group_factory(_dim) -> Type[nn.GroupNorm]: +def group_factory(_dim) -> type[nn.GroupNorm]: return nn.GroupNorm @Norm.factory_function("layer") -def layer_factory(_dim) -> Type[nn.LayerNorm]: +def layer_factory(_dim) -> type[nn.LayerNorm]: return nn.LayerNorm @Norm.factory_function("localresponse") -def local_response_factory(_dim) -> Type[nn.LocalResponseNorm]: +def local_response_factory(_dim) -> type[nn.LocalResponseNorm]: return nn.LocalResponseNorm @Norm.factory_function("syncbatch") -def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]: +def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]: return nn.SyncBatchNorm @@ -317,53 +320,56 @@ def mish_factory(): return Mish +@Act.factory_function("geglu") +def geglu_factory(): + from monai.networks.blocks.activation import GEGLU + + return GEGLU + + @Conv.factory_function("conv") -def conv_factory(dim: int) -> Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]]: +def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]: types = (nn.Conv1d, nn.Conv2d, nn.Conv3d) return types[dim - 1] @Conv.factory_function("convtrans") -def convtrans_factory(dim: int) -> Type[Union[nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d]]: +def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]: types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d) return types[dim - 1] @Pool.factory_function("max") -def maxpooling_factory(dim: int) -> Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]]: +def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]: types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d) return types[dim - 1] @Pool.factory_function("adaptivemax") -def adaptive_maxpooling_factory( - dim: int, -) -> Type[Union[nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d]]: +def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]: types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d) return types[dim - 1] @Pool.factory_function("avg") -def avgpooling_factory(dim: int) -> Type[Union[nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d]]: +def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]: types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d) return types[dim - 1] @Pool.factory_function("adaptiveavg") -def adaptive_avgpooling_factory( - dim: int, -) -> Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]]: +def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]: types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d) return types[dim - 1] @Pad.factory_function("replicationpad") -def replication_pad_factory(dim: int) -> Type[Union[nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d]]: +def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]: types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d) return types[dim - 1] @Pad.factory_function("constantpad") -def constant_pad_factory(dim: int) -> Type[Union[nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]]: +def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]: types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d) return types[dim - 1] diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index bbf925eba9..ae43eb3f73 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -9,13 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch from monai.utils.module import optional_import _C, _ = optional_import("monai._C") -__all__ = ["BilateralFilter", "PHLFilter"] +__all__ = ["BilateralFilter", "PHLFilter", "TrainableBilateralFilter", "TrainableJointBilateralFilter"] class BilateralFilter(torch.autograd.Function): @@ -31,10 +33,10 @@ class BilateralFilter(torch.autograd.Function): Args: input: input tensor. - spatial sigma: the standard deviation of the spatial blur. Higher values can + spatial_sigma: the standard deviation of the spatial blur. Higher values can hurt performance when not using the approximate method (see fast approx). - color sigma: the standard deviation of the color blur. Lower values preserve + color_sigma: the standard deviation of the color blur. Lower values preserve edges better whilst higher values tend to a simple gaussian spatial blur. fast approx: This flag chooses between two implementations. The approximate method may @@ -47,6 +49,7 @@ class BilateralFilter(torch.autograd.Function): @staticmethod def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True): + """autograd forward""" ctx.ss = spatial_sigma ctx.cs = color_sigma ctx.fa = fast_approx @@ -55,6 +58,7 @@ def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True): @staticmethod def backward(ctx, grad_output): + """autograd backward""" spatial_sigma, color_sigma, fast_approx = ctx.ss, ctx.cs, ctx.fa grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx) return grad_input, None, None, None @@ -83,7 +87,6 @@ class PHLFilter(torch.autograd.Function): @staticmethod def forward(ctx, input, features, sigmas=None): - scaled_features = features if sigmas is not None: for i in range(features.size(1)): @@ -99,3 +102,351 @@ def backward(ctx, grad_output): # scaled_features, = ctx.saved_variables # grad_input = _C.phl_filter(grad_output, scaled_features) # return grad_input + + +class TrainableBilateralFilterFunction(torch.autograd.Function): + """ + torch.autograd.Function for the TrainableBilateralFilter layer. + + See: + F. Wagner, et al., Ultralow-parameter denoising: Trainable bilateral filter layers in + computed tomography, Medical Physics (2022), https://doi.org/10.1002/mp.15718 + + Args: + input: input tensor to be filtered. + + sigma x: trainable standard deviation of the spatial filter kernel in x direction. + + sigma y: trainable standard deviation of the spatial filter kernel in y direction. + + sigma z: trainable standard deviation of the spatial filter kernel in z direction. + + color sigma: trainable standard deviation of the intensity range kernel. This filter + parameter determines the degree of edge preservation. + + Returns: + output (torch.Tensor): filtered tensor. + """ + + @staticmethod + def forward(ctx, input_img, sigma_x, sigma_y, sigma_z, color_sigma): + output_tensor, output_weights_tensor, do_dx_ki, do_dsig_r, do_dsig_x, do_dsig_y, do_dsig_z = _C.tbf_forward( + input_img, sigma_x, sigma_y, sigma_z, color_sigma + ) + + ctx.save_for_backward( + input_img, + sigma_x, + sigma_y, + sigma_z, + color_sigma, + output_tensor, + output_weights_tensor, + do_dx_ki, + do_dsig_r, + do_dsig_x, + do_dsig_y, + do_dsig_z, + ) + + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + input_img = ctx.saved_tensors[0] # input image + sigma_x = ctx.saved_tensors[1] + sigma_y = ctx.saved_tensors[2] + sigma_z = ctx.saved_tensors[3] + color_sigma = ctx.saved_tensors[4] + output_tensor = ctx.saved_tensors[5] # filtered image + output_weights_tensor = ctx.saved_tensors[6] # weights + do_dx_ki = ctx.saved_tensors[7] # derivative of output with respect to input, while k==i + do_dsig_r = ctx.saved_tensors[8] # derivative of output with respect to range sigma + do_dsig_x = ctx.saved_tensors[9] # derivative of output with respect to sigma x + do_dsig_y = ctx.saved_tensors[10] # derivative of output with respect to sigma y + do_dsig_z = ctx.saved_tensors[11] # derivative of output with respect to sigma z + + # calculate gradient with respect to the sigmas + grad_color_sigma = torch.sum(grad_output * do_dsig_r) + grad_sig_x = torch.sum(grad_output * do_dsig_x) + grad_sig_y = torch.sum(grad_output * do_dsig_y) + grad_sig_z = torch.sum(grad_output * do_dsig_z) + + grad_output_tensor = _C.tbf_backward( + grad_output, + input_img, + output_tensor, + output_weights_tensor, + do_dx_ki, + sigma_x, + sigma_y, + sigma_z, + color_sigma, + ) + + return grad_output_tensor, grad_sig_x, grad_sig_y, grad_sig_z, grad_color_sigma + + +class TrainableBilateralFilter(torch.nn.Module): + """ + Implementation of a trainable bilateral filter layer as proposed in the corresponding publication. + All filter parameters can be trained data-driven. The spatial filter kernels x, y, and z determine + image smoothing whereas the color parameter specifies the amount of edge preservation. + Can run on 1D, 2D, or 3D tensors (on top of Batch and Channel dimensions). + + See: + F. Wagner, et al., Ultralow-parameter denoising: Trainable bilateral filter layers in + computed tomography, Medical Physics (2022), https://doi.org/10.1002/mp.15718 + + Args: + input: input tensor to be filtered. + + spatial_sigma: tuple (sigma_x, sigma_y, sigma_z) initializing the trainable standard + deviations of the spatial filter kernels. Tuple length must equal the number of + spatial input dimensions. + + color_sigma: trainable standard deviation of the intensity range kernel. This filter + parameter determines the degree of edge preservation. + + Returns: + output (torch.Tensor): filtered tensor. + """ + + def __init__(self, spatial_sigma, color_sigma): + super().__init__() + + if isinstance(spatial_sigma, float): + spatial_sigma = [spatial_sigma, spatial_sigma, spatial_sigma] + self.len_spatial_sigma = 3 + elif len(spatial_sigma) == 1: + spatial_sigma = [spatial_sigma[0], 0.01, 0.01] + self.len_spatial_sigma = 1 + elif len(spatial_sigma) == 2: + spatial_sigma = [spatial_sigma[0], spatial_sigma[1], 0.01] + self.len_spatial_sigma = 2 + elif len(spatial_sigma) == 3: + spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]] + self.len_spatial_sigma = 3 + else: + raise ValueError( + f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." + ) + + # Register sigmas as trainable parameters. + self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0])) + self.sigma_y = torch.nn.Parameter(torch.tensor(spatial_sigma[1])) + self.sigma_z = torch.nn.Parameter(torch.tensor(spatial_sigma[2])) + self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) + + def forward(self, input_tensor): + if input_tensor.shape[1] != 1: + raise ValueError( + f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. " + "Please use multiple parallel filter layers if you want " + "to filter multiple channels." + ) + + len_input = len(input_tensor.shape) + + # C++ extension so far only supports 5-dim inputs. + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + if self.len_spatial_sigma != len_input: + raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + + prediction = TrainableBilateralFilterFunction.apply( + input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color + ) + + # Make sure to return tensor of the same shape as the input. + if len_input == 3: + prediction = prediction.squeeze(4).squeeze(3) + elif len_input == 4: + prediction = prediction.squeeze(4) + + return prediction + + +class TrainableJointBilateralFilterFunction(torch.autograd.Function): + """ + torch.autograd.Function for the TrainableJointBilateralFilter layer. + + See: + F. Wagner, et al., Trainable joint bilateral filters for enhanced prediction stability in + low-dose CT, Scientific Reports (2022), https://doi.org/10.1038/s41598-022-22530-4 + + Args: + input: input tensor to be filtered. + + guide: guidance image tensor to be used during filtering. + + sigma x: trainable standard deviation of the spatial filter kernel in x direction. + + sigma y: trainable standard deviation of the spatial filter kernel in y direction. + + sigma z: trainable standard deviation of the spatial filter kernel in z direction. + + color sigma: trainable standard deviation of the intensity range kernel. This filter + parameter determines the degree of edge preservation. + + Returns: + output (torch.Tensor): filtered tensor. + """ + + @staticmethod + def forward(ctx, input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma): + output_tensor, output_weights_tensor, do_dx_ki, do_dsig_r, do_dsig_x, do_dsig_y, do_dsig_z = _C.tjbf_forward( + input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma + ) + + ctx.save_for_backward( + input_img, + sigma_x, + sigma_y, + sigma_z, + color_sigma, + output_tensor, + output_weights_tensor, + do_dx_ki, + do_dsig_r, + do_dsig_x, + do_dsig_y, + do_dsig_z, + guidance_img, + ) + + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + input_img = ctx.saved_tensors[0] # input image + sigma_x = ctx.saved_tensors[1] + sigma_y = ctx.saved_tensors[2] + sigma_z = ctx.saved_tensors[3] + color_sigma = ctx.saved_tensors[4] + output_tensor = ctx.saved_tensors[5] # filtered image + output_weights_tensor = ctx.saved_tensors[6] # weights + do_dx_ki = ctx.saved_tensors[7] # derivative of output with respect to input, while k==i + do_dsig_r = ctx.saved_tensors[8] # derivative of output with respect to range sigma + do_dsig_x = ctx.saved_tensors[9] # derivative of output with respect to sigma x + do_dsig_y = ctx.saved_tensors[10] # derivative of output with respect to sigma y + do_dsig_z = ctx.saved_tensors[11] # derivative of output with respect to sigma z + guidance_img = ctx.saved_tensors[12] # guidance image + + # calculate gradient with respect to the sigmas + grad_color_sigma = torch.sum(grad_output * do_dsig_r) + grad_sig_x = torch.sum(grad_output * do_dsig_x) + grad_sig_y = torch.sum(grad_output * do_dsig_y) + grad_sig_z = torch.sum(grad_output * do_dsig_z) + + grad_output_tensor, grad_guidance_tensor = _C.tjbf_backward( + grad_output, + input_img, + guidance_img, + output_tensor, + output_weights_tensor, + do_dx_ki, + sigma_x, + sigma_y, + sigma_z, + color_sigma, + ) + + return grad_output_tensor, grad_guidance_tensor, grad_sig_x, grad_sig_y, grad_sig_z, grad_color_sigma + + +class TrainableJointBilateralFilter(torch.nn.Module): + """ + Implementation of a trainable joint bilateral filter layer as proposed in the corresponding publication. + The guidance image is used as additional (edge) information during filtering. All filter parameters and the + guidance image can be trained data-driven. The spatial filter kernels x, y, and z determine + image smoothing whereas the color parameter specifies the amount of edge preservation. + Can run on 1D, 2D, or 3D tensors (on top of Batch and Channel dimensions). Input tensor shape must match + guidance tensor shape. + + See: + F. Wagner, et al., Trainable joint bilateral filters for enhanced prediction stability in + low-dose CT, Scientific Reports (2022), https://doi.org/10.1038/s41598-022-22530-4 + + Args: + input: input tensor to be filtered. + + guide: guidance image tensor to be used during filtering. + + spatial_sigma: tuple (sigma_x, sigma_y, sigma_z) initializing the trainable standard + deviations of the spatial filter kernels. Tuple length must equal the number of + spatial input dimensions. + + color_sigma: trainable standard deviation of the intensity range kernel. This filter + parameter determines the degree of edge preservation. + + Returns: + output (torch.Tensor): filtered tensor. + """ + + def __init__(self, spatial_sigma, color_sigma): + super().__init__() + + if isinstance(spatial_sigma, float): + spatial_sigma = [spatial_sigma, spatial_sigma, spatial_sigma] + self.len_spatial_sigma = 3 + elif len(spatial_sigma) == 1: + spatial_sigma = [spatial_sigma[0], 0.01, 0.01] + self.len_spatial_sigma = 1 + elif len(spatial_sigma) == 2: + spatial_sigma = [spatial_sigma[0], spatial_sigma[1], 0.01] + self.len_spatial_sigma = 2 + elif len(spatial_sigma) == 3: + spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]] + self.len_spatial_sigma = 3 + else: + raise ValueError( + f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." + ) + + # Register sigmas as trainable parameters. + self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0])) + self.sigma_y = torch.nn.Parameter(torch.tensor(spatial_sigma[1])) + self.sigma_z = torch.nn.Parameter(torch.tensor(spatial_sigma[2])) + self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) + + def forward(self, input_tensor, guidance_tensor): + if input_tensor.shape[1] != 1: + raise ValueError( + f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. " + "Please use multiple parallel filter layers if you want " + "to filter multiple channels." + ) + if input_tensor.shape != guidance_tensor.shape: + raise ValueError( + "Shape of input image must equal shape of guidance image." + f"Got {input_tensor.shape} and {guidance_tensor.shape}." + ) + + len_input = len(input_tensor.shape) + + # C++ extension so far only supports 5-dim inputs. + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + guidance_tensor = guidance_tensor.unsqueeze(4) + + if self.len_spatial_sigma != len_input: + raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + + prediction = TrainableJointBilateralFilterFunction.apply( + input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color + ) + + # Make sure to return tensor of the same shape as the input. + if len_input == 3: + prediction = prediction.squeeze(4).squeeze(3) + elif len_input == 4: + prediction = prediction.squeeze(4) + + return prediction diff --git a/monai/networks/layers/gmm.py b/monai/networks/layers/gmm.py index eb9a3f91e4..94d619bb7a 100644 --- a/monai/networks/layers/gmm.py +++ b/monai/networks/layers/gmm.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch from monai._extensions.loader import load_module diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index bc72f8d1d3..a1122ceaa2 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -9,15 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math from copy import deepcopy -from typing import List, Optional, Sequence, Union +from typing import Sequence import torch import torch.nn.functional as F from torch import nn from torch.autograd import Function +from monai.config.type_definitions import NdarrayOrTensor from monai.networks.layers.convutils import gaussian_1d from monai.networks.layers.factories import Conv from monai.utils import ( @@ -57,11 +60,7 @@ class ChannelPad(nn.Module): """ def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - mode: Union[ChannelMatching, str] = ChannelMatching.PAD, + self, spatial_dims: int, in_channels: int, out_channels: int, mode: ChannelMatching | str = ChannelMatching.PAD ): """ @@ -112,7 +111,7 @@ class SkipConnection(nn.Module): The available modes are ``"cat"``, ``"add"``, ``"mul"``. """ - def __init__(self, submodule, dim: int = 1, mode: Union[str, SkipMode] = "cat") -> None: + def __init__(self, submodule, dim: int = 1, mode: str | SkipMode = "cat") -> None: """ Args: @@ -171,11 +170,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _separable_filtering_conv( input_: torch.Tensor, - kernels: List[torch.Tensor], + kernels: list[torch.Tensor], pad_mode: str, d: int, spatial_dims: int, - paddings: List[int], + paddings: list[int], num_channels: int, ) -> torch.Tensor: if d < 0: @@ -195,8 +194,8 @@ def _separable_filtering_conv( conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] # translate padding for input to torch.nn.functional.pad - _reversed_padding_repeated_twice: List[List[int]] = [[p, p] for p in reversed(_padding)] - _sum_reversed_padding_repeated_twice: List[int] = sum(_reversed_padding_repeated_twice, []) + _reversed_padding_repeated_twice: list[list[int]] = [[p, p] for p in reversed(_padding)] + _sum_reversed_padding_repeated_twice: list[int] = sum(_reversed_padding_repeated_twice, []) padded_input = F.pad(input_, _sum_reversed_padding_repeated_twice, mode=pad_mode) return conv_type( @@ -206,7 +205,7 @@ def _separable_filtering_conv( ) -def separable_filtering(x: torch.Tensor, kernels: List[torch.Tensor], mode: str = "zeros") -> torch.Tensor: +def separable_filtering(x: torch.Tensor, kernels: list[torch.Tensor], mode: str = "zeros") -> torch.Tensor: """ Apply 1-D convolutions along each spatial dimension of `x`. @@ -322,7 +321,6 @@ class SavitzkyGolayFilter(nn.Module): """ def __init__(self, window_length: int, order: int, axis: int = 2, mode: str = "zeros"): - super().__init__() if order >= window_length: raise ValueError("order must be less than window_length.") @@ -366,7 +364,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @staticmethod def _make_coeffs(window_length, order): - half_length, rem = divmod(window_length, 2) if rem == 0: raise ValueError("window_length must be odd.") @@ -391,8 +388,7 @@ class HilbertTransform(nn.Module): n: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``. """ - def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None: - + def __init__(self, axis: int = 2, n: int | None = None) -> None: super().__init__() self.axis = axis self.n = n @@ -458,7 +454,7 @@ def median_filter( in_tensor: torch.Tensor, kernel_size: Sequence[int] = (3, 3, 3), spatial_dims: int = 3, - kernel: Optional[torch.Tensor] = None, + kernel: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: """ @@ -536,7 +532,7 @@ class MedianFilter(nn.Module): """ - def __init__(self, radius: Union[Sequence[int], int], spatial_dims: int = 3, device="cpu") -> None: + def __init__(self, radius: Sequence[int] | int, spatial_dims: int = 3, device="cpu") -> None: super().__init__() self.spatial_dims = spatial_dims self.radius: Sequence[int] = ensure_tuple_rep(radius, spatial_dims) @@ -559,7 +555,7 @@ class GaussianFilter(nn.Module): def __init__( self, spatial_dims: int, - sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor], + sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor, truncated: float = 4.0, approx: str = "erf", requires_grad: bool = False, @@ -658,3 +654,90 @@ def reset_parameters(self): def forward(self, input, state): return LLTMFunction.apply(input, self.weights, self.bias, *state) + + +class ApplyFilter(nn.Module): + "Wrapper class to apply a filter to an image." + + def __init__(self, filter: NdarrayOrTensor) -> None: + super().__init__() + + self.filter = convert_to_tensor(filter, dtype=torch.float32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return apply_filter(x, self.filter) + + +class MeanFilter(ApplyFilter): + """ + Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image. + The mean filter used, is a `torch.Tensor` of all ones. + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + filter = torch.ones([size] * spatial_dims) + filter = filter + super().__init__(filter=filter) + + +class LaplaceFilter(ApplyFilter): + """ + Laplacian filtering for outline detection in images. Can be used to transform labels to contours. + The laplace filter used, is a `torch.Tensor` where all values are -1, except the center value + which is `size` ** `spatial_dims` + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + filter = torch.zeros([size] * spatial_dims).float() - 1 # make all -1 + center_point = tuple([size // 2] * spatial_dims) + filter[center_point] = (size**spatial_dims) - 1 + super().__init__(filter=filter) + + +class EllipticalFilter(ApplyFilter): + """ + Elliptical filter, can be used to dilate labels or label-contours. + The elliptical filter used here, is a `torch.Tensor` with shape (size, ) * ndim containing a circle/sphere of `1` + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + radius = size // 2 + grid = torch.meshgrid(*[torch.arange(0, size) for _ in range(spatial_dims)]) + squared_distances = torch.stack([(axis - radius) ** 2 for axis in grid], 0).sum(0) + filter = squared_distances <= radius**2 + super().__init__(filter=filter) + + +class SharpenFilter(EllipticalFilter): + """ + Convolutional filter to sharpen a 2D or 3D image. + The filter used contains a circle/sphere of `-1`, with the center value being + the absolute sum of all non-zero elements in the kernel + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + super().__init__(spatial_dims=spatial_dims, size=size) + center_point = tuple([size // 2] * spatial_dims) + center_value = self.filter.sum() + self.filter *= -1 + self.filter[center_point] = center_value diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index d1a0fed021..53f35e63f2 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -433,13 +435,13 @@ def grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b class AffineTransform(nn.Module): def __init__( self, - spatial_size: Optional[Union[Sequence[int], int]] = None, + spatial_size: Sequence[int] | int | None = None, normalized: bool = False, mode: str = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.ZEROS, - align_corners: bool = False, + align_corners: bool = True, reverse_indexing: bool = True, - zero_centered: Optional[bool] = None, + zero_centered: bool | None = None, ) -> None: """ Apply affine transformations with a batch of affine matrices. @@ -493,7 +495,7 @@ def __init__( self.zero_centered = zero_centered if zero_centered is not None else False def forward( - self, src: torch.Tensor, theta: torch.Tensor, spatial_size: Optional[Union[Sequence[int], int]] = None + self, src: torch.Tensor, theta: torch.Tensor, spatial_size: Sequence[int] | int | None = None ) -> torch.Tensor: """ ``theta`` must be an affine transformation matrix with shape @@ -535,6 +537,8 @@ def forward( theta = torch.cat([theta, pad_affine], dim=1) if tuple(theta.shape[1:]) not in ((3, 3), (4, 4)): raise ValueError(f"theta must be Nx3x3 or Nx4x4, got {theta.shape}.") + if not torch.is_floating_point(theta): + raise ValueError(f"theta must be floating point data, got {theta.dtype}") # validate `src` if not isinstance(src, torch.Tensor): @@ -557,7 +561,7 @@ def forward( affine=theta, src_size=src_size[2:], dst_size=dst_size[2:], - align_corners=self.align_corners, + align_corners=False, zero_centered=self.zero_centered, ) if self.reverse_indexing: diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index a630a5edc7..ace1af27b6 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from __future__ import annotations import torch.nn @@ -19,7 +19,7 @@ __all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"] -def get_norm_layer(name: Union[Tuple, str], spatial_dims: Optional[int] = 1, channels: Optional[int] = 1): +def get_norm_layer(name: tuple | str, spatial_dims: int | None = 1, channels: int | None = 1): """ Create a normalization layer instance. @@ -50,7 +50,7 @@ def get_norm_layer(name: Union[Tuple, str], spatial_dims: Optional[int] = 1, cha return norm_type(**kw_args) -def get_act_layer(name: Union[Tuple, str]): +def get_act_layer(name: tuple | str): """ Create an activation layer instance. @@ -73,7 +73,7 @@ def get_act_layer(name: Union[Tuple, str]): return act_type(**act_args) -def get_dropout_layer(name: Union[Tuple, str, float, int], dropout_dim: Optional[int] = 1): +def get_dropout_layer(name: tuple | str | float | int, dropout_dim: int | None = 1): """ Create a dropout layer instance. @@ -102,7 +102,7 @@ def get_dropout_layer(name: Union[Tuple, str, float, int], dropout_dim: Optional return drop_type(**drop_args) -def get_pool_layer(name: Union[Tuple, str], spatial_dims: Optional[int] = 1): +def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1): """ Create a pooling layer instance. diff --git a/monai/networks/layers/weight_init.py b/monai/networks/layers/weight_init.py index b0c6fae2c2..f413e19ca3 100644 --- a/monai/networks/layers/weight_init.py +++ b/monai/networks/layers/weight_init.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math import torch diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 18a85d802a..95ddad7842 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .ahnet import AHnet, Ahnet, AHNet from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 65d85a2054..ae248c0cd1 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math -from typing import Optional, Sequence, Type, Union +from collections.abc import Sequence import torch import torch.nn as nn @@ -23,7 +25,6 @@ class Bottleneck3x3x1(nn.Module): - expansion = 4 def __init__( @@ -31,16 +32,15 @@ def __init__( spatial_dims: int, inplanes: int, planes: int, - stride: Union[Sequence[int], int] = 1, - downsample: Optional[nn.Sequential] = None, + stride: Sequence[int] | int = 1, + downsample: nn.Sequential | None = None, ) -> None: - super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] - norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] - pool_type: Type[Union[nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] - relu_type: Type[nn.ReLU] = Act[Act.RELU] + norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] + pool_type: type[nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims] + relu_type: type[nn.ReLU] = Act[Act.RELU] self.conv1 = conv_type(inplanes, planes, kernel_size=1, bias=False) self.bn1 = norm_type(planes) @@ -90,8 +90,8 @@ def __init__(self, spatial_dims: int, num_input_features: int, num_output_featur super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] - norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] - relu_type: Type[nn.ReLU] = Act[Act.RELU] + norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] + relu_type: type[nn.ReLU] = Act[Act.RELU] self.add_module("norm", norm_type(num_input_features)) self.add_module("relu", relu_type(inplace=True)) @@ -123,8 +123,8 @@ def __init__( super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] - norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] - relu_type: Type[nn.ReLU] = Act[Act.RELU] + norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] + relu_type: type[nn.ReLU] = Act[Act.RELU] self.add_module("norm", norm_type(num_input_features)) self.add_module("relu", relu_type(inplace=True)) @@ -135,7 +135,7 @@ def __init__( "up", conv_trans_type(num_output_features, num_output_features, kernel_size=2, stride=2, bias=False) ) else: - align_corners: Optional[bool] = None + align_corners: bool | None = None if upsample_mode in ["trilinear", "bilinear"]: align_corners = True self.add_module("up", nn.Upsample(scale_factor=2, mode=upsample_mode, align_corners=align_corners)) @@ -148,8 +148,8 @@ def __init__( super().__init__() conv_type = Conv[Conv.CONV, spatial_dims] - norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] - relu_type: Type[nn.ReLU] = Act[Act.RELU] + norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] + relu_type: type[nn.ReLU] = Act[Act.RELU] self.add_module("norm", norm_type(num_input_features)) self.add_module("relu", relu_type(inplace=True)) @@ -170,7 +170,7 @@ def __init__( "up", conv_trans_type(num_output_features, num_output_features, kernel_size=2, stride=2, bias=False) ) else: - align_corners: Optional[bool] = None + align_corners: bool | None = None if upsample_mode in ["trilinear", "bilinear"]: align_corners = True self.add_module("up", nn.Upsample(scale_factor=2, mode=upsample_mode, align_corners=align_corners)) @@ -182,8 +182,8 @@ def __init__(self, spatial_dims: int, num_input_features: int, growth_rate: int, # 1x1x1 conv_type = Conv[Conv.CONV, spatial_dims] - norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] - relu_type: Type[nn.ReLU] = Act[Act.RELU] + norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] + relu_type: type[nn.ReLU] = Act[Act.RELU] self.bn1 = norm_type(num_input_features) self.relu1 = relu_type(inplace=True) @@ -247,7 +247,7 @@ def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_m super().__init__() self.up_modules = nn.ModuleList() conv_type = Conv[Conv.CONV, spatial_dims] - pool_type: Type[Union[nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] + pool_type: type[nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims] self.pool_modules = nn.ModuleList() self.project_modules = nn.ModuleList() @@ -273,15 +273,13 @@ def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_m def forward(self, x: torch.Tensor) -> torch.Tensor: outputs = [] if self.upsample_mode == "transpose": - for (project_module, pool_module, up_module) in zip( - self.project_modules, self.pool_modules, self.up_modules - ): + for project_module, pool_module, up_module in zip(self.project_modules, self.pool_modules, self.up_modules): output = up_module(project_module(pool_module(x))) outputs.append(output) else: - for (project_module, pool_module) in zip(self.project_modules, self.pool_modules): + for project_module, pool_module in zip(self.project_modules, self.pool_modules): interpolate_size = x.shape[2:] - align_corners: Optional[bool] = None + align_corners: bool | None = None if self.upsample_mode in ["trilinear", "bilinear"]: align_corners = True output = F.interpolate( @@ -347,10 +345,10 @@ def __init__( conv_type = Conv[Conv.CONV, spatial_dims] conv_trans_type = Conv[Conv.CONVTRANS, spatial_dims] norm_type = Norm[Norm.BATCH, spatial_dims] - pool_type: Type[Union[nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] - relu_type: Type[nn.ReLU] = Act[Act.RELU] - conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] - norm2d_type: Type[nn.BatchNorm2d] = Norm[Norm.BATCH, 2] + pool_type: type[nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims] + relu_type: type[nn.ReLU] = Act[Act.RELU] + conv2d_type: type[nn.Conv2d] = Conv[Conv.CONV, 2] + norm2d_type: type[nn.BatchNorm2d] = Norm[Norm.BATCH, 2] self.conv2d_type = conv2d_type self.norm2d_type = norm2d_type @@ -437,7 +435,7 @@ def __init__( net2d = FCN(pretrained=True, progress=progress) self.copy_from(net2d) - def _make_layer(self, block: Type[Bottleneck3x3x1], planes: int, blocks: int, stride: int = 1) -> nn.Sequential: + def _make_layer(self, block: type[Bottleneck3x3x1], planes: int, blocks: int, stride: int = 1) -> nn.Sequential: downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py index a57b57425e..362d63d636 100644 --- a/monai/networks/nets/attentionunet.py +++ b/monai/networks/nets/attentionunet.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -200,8 +202,8 @@ def __init__( out_channels: int, channels: Sequence[int], strides: Sequence[int], - kernel_size: Union[Sequence[int], int] = 3, - up_kernel_size: Union[Sequence[int], int] = 3, + kernel_size: Sequence[int] | int = 3, + up_kernel_size: Sequence[int] | int = 3, dropout: float = 0.0, ): super().__init__() diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py index a88d1861ad..8f093bcc1d 100644 --- a/monai/networks/nets/autoencoder.py +++ b/monai/networks/nets/autoencoder.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any import torch import torch.nn as nn @@ -91,18 +94,17 @@ def __init__( out_channels: int, channels: Sequence[int], strides: Sequence[int], - kernel_size: Union[Sequence[int], int] = 3, - up_kernel_size: Union[Sequence[int], int] = 3, + kernel_size: Sequence[int] | int = 3, + up_kernel_size: Sequence[int] | int = 3, num_res_units: int = 0, - inter_channels: Optional[list] = None, - inter_dilations: Optional[list] = None, + inter_channels: list | None = None, + inter_dilations: list | None = None, num_inter_units: int = 2, - act: Optional[Union[Tuple, str]] = Act.PRELU, - norm: Union[Tuple, str] = Norm.INSTANCE, - dropout: Optional[Union[Tuple, str, float]] = None, + act: tuple | str | None = Act.PRELU, + norm: tuple | str = Norm.INSTANCE, + dropout: tuple | str | float | None = None, bias: bool = True, ) -> None: - super().__init__() self.dimensions = spatial_dims self.in_channels = in_channels @@ -133,7 +135,7 @@ def __init__( def _get_encode_module( self, in_channels: int, channels: Sequence[int], strides: Sequence[int] - ) -> Tuple[nn.Sequential, int]: + ) -> tuple[nn.Sequential, int]: """ Returns the encode part of the network by building up a sequence of layers returned by `_get_encode_layer`. """ @@ -147,7 +149,7 @@ def _get_encode_module( return encode, layer_channels - def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tuple[nn.Module, int]: + def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> tuple[nn.Module, int]: """ Returns the intermediate block of the network which accepts input from the encoder and whose output goes to the decoder. @@ -198,7 +200,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tu def _get_decode_module( self, in_channels: int, channels: Sequence[int], strides: Sequence[int] - ) -> Tuple[nn.Sequential, int]: + ) -> tuple[nn.Sequential, int]: """ Returns the decode part of the network by building up a sequence of layers returned by `_get_decode_layer`. """ diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index 6fe77038fe..b26fdcb622 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -9,14 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn from monai.networks.blocks import Convolution, UpSample from monai.networks.layers.factories import Conv, Pool -from monai.utils import deprecated_arg, ensure_tuple_rep +from monai.utils import ensure_tuple_rep __all__ = ["BasicUnet", "Basicunet", "basicunet", "BasicUNet"] @@ -29,10 +31,10 @@ def __init__( spatial_dims: int, in_chns: int, out_chns: int, - act: Union[str, tuple], - norm: Union[str, tuple], + act: str | tuple, + norm: str | tuple, bias: bool, - dropout: Union[float, tuple] = 0.0, + dropout: float | tuple = 0.0, ): """ Args: @@ -63,10 +65,10 @@ def __init__( spatial_dims: int, in_chns: int, out_chns: int, - act: Union[str, tuple], - norm: Union[str, tuple], + act: str | tuple, + norm: str | tuple, bias: bool, - dropout: Union[float, tuple] = 0.0, + dropout: float | tuple = 0.0, ): """ Args: @@ -95,14 +97,14 @@ def __init__( in_chns: int, cat_chns: int, out_chns: int, - act: Union[str, tuple], - norm: Union[str, tuple], + act: str | tuple, + norm: str | tuple, bias: bool, - dropout: Union[float, tuple] = 0.0, + dropout: float | tuple = 0.0, upsample: str = "deconv", - pre_conv: Optional[Union[nn.Module, str]] = "default", + pre_conv: nn.Module | str | None = "default", interp_mode: str = "linear", - align_corners: Optional[bool] = True, + align_corners: bool | None = True, halves: bool = True, is_pad: bool = True, ): @@ -147,7 +149,7 @@ def __init__( self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout) self.is_pad = is_pad - def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]): + def forward(self, x: torch.Tensor, x_e: torch.Tensor | None): """ Args: @@ -173,21 +175,17 @@ def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]): class BasicUNet(nn.Module): - @deprecated_arg( - name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." - ) def __init__( self, spatial_dims: int = 3, in_channels: int = 1, out_channels: int = 2, features: Sequence[int] = (32, 32, 64, 128, 256, 32), - act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), - norm: Union[str, tuple] = ("instance", {"affine": True}), + act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), + norm: str | tuple = ("instance", {"affine": True}), bias: bool = True, - dropout: Union[float, tuple] = 0.0, + dropout: float | tuple = 0.0, upsample: str = "deconv", - dimensions: Optional[int] = None, ): """ A UNet implementation with 1D/2D/3D supports. @@ -217,9 +215,6 @@ def __init__( upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. - .. deprecated:: 0.6.0 - ``dimensions`` is deprecated, use ``spatial_dims`` instead. - Examples:: # for spatial 2D @@ -238,8 +233,6 @@ def __init__( """ super().__init__() - if dimensions is not None: - spatial_dims = dimensions fea = ensure_tuple_rep(features, 6) print(f"BasicUNet features: {fea}.") diff --git a/monai/networks/nets/basic_unetplusplus.py b/monai/networks/nets/basic_unetplusplus.py index 4f7d319aaa..28d4b4668a 100644 --- a/monai/networks/nets/basic_unetplusplus.py +++ b/monai/networks/nets/basic_unetplusplus.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -29,10 +31,10 @@ def __init__( out_channels: int = 2, features: Sequence[int] = (32, 32, 64, 128, 256, 32), deep_supervision: bool = False, - act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), - norm: Union[str, tuple] = ("instance", {"affine": True}), + act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), + norm: str | tuple = ("instance", {"affine": True}), bias: bool = True, - dropout: Union[float, tuple] = 0.0, + dropout: float | tuple = 0.0, upsample: str = "deconv", ): """ diff --git a/monai/networks/nets/classifier.py b/monai/networks/nets/classifier.py index 7f4e43eedb..dcedaf3e70 100644 --- a/monai/networks/nets/classifier.py +++ b/monai/networks/nets/classifier.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -46,13 +48,13 @@ def __init__( classes: int, channels: Sequence[int], strides: Sequence[int], - kernel_size: Union[Sequence[int], int] = 3, + kernel_size: Sequence[int] | int = 3, num_res_units: int = 2, act=Act.PRELU, norm=Norm.INSTANCE, - dropout: Optional[float] = None, + dropout: float | None = None, bias: bool = True, - last_act: Optional[str] = None, + last_act: str | None = None, ) -> None: super().__init__(in_shape, (classes,), channels, strides, kernel_size, num_res_units, act, norm, dropout, bias) @@ -86,11 +88,11 @@ def __init__( in_shape: Sequence[int], channels: Sequence[int], strides: Sequence[int], - kernel_size: Union[Sequence[int], int] = 3, + kernel_size: Sequence[int] | int = 3, num_res_units: int = 2, act=Act.PRELU, norm=Norm.INSTANCE, - dropout: Optional[float] = 0.25, + dropout: float | None = 0.25, bias: bool = True, last_act=Act.SIGMOID, ) -> None: @@ -120,11 +122,11 @@ def __init__( in_shape: Sequence[int], channels: Sequence[int], strides: Sequence[int], - kernel_size: Union[Sequence[int], int] = 3, + kernel_size: Sequence[int] | int = 3, num_res_units: int = 2, act=Act.PRELU, norm=Norm.INSTANCE, - dropout: Optional[float] = 0.25, + dropout: float | None = 0.25, bias: bool = True, ) -> None: super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, None) diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index 2f02ecf395..2100272d91 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -9,9 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import re from collections import OrderedDict -from typing import Callable, Sequence, Type, Union +from collections.abc import Callable, Sequence import torch import torch.nn as nn @@ -47,8 +49,8 @@ def __init__( growth_rate: int, bn_size: int, dropout_prob: float, - act: Union[str, tuple] = ("relu", {"inplace": True}), - norm: Union[str, tuple] = "batch", + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", ) -> None: """ Args: @@ -94,8 +96,8 @@ def __init__( bn_size: int, growth_rate: int, dropout_prob: float, - act: Union[str, tuple] = ("relu", {"inplace": True}), - norm: Union[str, tuple] = "batch", + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", ) -> None: """ Args: @@ -122,8 +124,8 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - act: Union[str, tuple] = ("relu", {"inplace": True}), - norm: Union[str, tuple] = "batch", + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", ) -> None: """ Args: @@ -175,16 +177,15 @@ def __init__( growth_rate: int = 32, block_config: Sequence[int] = (6, 12, 24, 16), bn_size: int = 4, - act: Union[str, tuple] = ("relu", {"inplace": True}), - norm: Union[str, tuple] = "batch", + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", dropout_prob: float = 0.0, ) -> None: - super().__init__() - conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] - pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] - avg_pool_type: Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]] = Pool[ + conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims] + pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims] + avg_pool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[ Pool.ADAPTIVEAVG, spatial_dims ] @@ -294,6 +295,9 @@ class DenseNet121(DenseNet): def __init__( self, + spatial_dims: int, + in_channels: int, + out_channels: int, init_features: int = 64, growth_rate: int = 32, block_config: Sequence[int] = (6, 12, 24, 16), @@ -301,9 +305,17 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) if pretrained: - if kwargs["spatial_dims"] > 2: + if spatial_dims > 2: raise NotImplementedError( "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" "provide pretrained models for more than two spatial dimensions." @@ -316,6 +328,9 @@ class DenseNet169(DenseNet): def __init__( self, + spatial_dims: int, + in_channels: int, + out_channels: int, init_features: int = 64, growth_rate: int = 32, block_config: Sequence[int] = (6, 12, 32, 32), @@ -323,9 +338,17 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) if pretrained: - if kwargs["spatial_dims"] > 2: + if spatial_dims > 2: raise NotImplementedError( "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" "provide pretrained models for more than two spatial dimensions." @@ -338,6 +361,9 @@ class DenseNet201(DenseNet): def __init__( self, + spatial_dims: int, + in_channels: int, + out_channels: int, init_features: int = 64, growth_rate: int = 32, block_config: Sequence[int] = (6, 12, 48, 32), @@ -345,9 +371,17 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) if pretrained: - if kwargs["spatial_dims"] > 2: + if spatial_dims > 2: raise NotImplementedError( "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" "provide pretrained models for more than two spatial dimensions." @@ -360,6 +394,9 @@ class DenseNet264(DenseNet): def __init__( self, + spatial_dims: int, + in_channels: int, + out_channels: int, init_features: int = 64, growth_rate: int = 32, block_config: Sequence[int] = (6, 12, 64, 48), @@ -367,7 +404,15 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) if pretrained: raise NotImplementedError("Currently PyTorch Hub does not provide densenet264 pretrained models.") diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index b62fe2b747..437789ef0c 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import datetime import warnings -from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -38,7 +40,7 @@ class CellInterface(torch.nn.Module): """interface for torchscriptable Cell""" - def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: # type: ignore + def forward(self, x: torch.Tensor, weight) -> torch.Tensor: # type: ignore pass @@ -75,15 +77,6 @@ def __init__(self, *args, **kwargs): self.ram_cost = 0 -class _CloseWithRAMCost(nn.Module): - def __init__(self): - super().__init__() - self.ram_cost = 0 - - def forward(self, x): - return torch.tensor(0.0, requires_grad=False).to(x) - - class _ActiConvNormBlockWithRAMCost(ActiConvNormBlock): """The class wraps monai layers with ram estimation. The ram_cost = total_ram/output_size is estimated. Here is the estimation: @@ -103,8 +96,8 @@ def __init__( kernel_size: int, padding: int, spatial_dims: int = 3, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), ): super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims, act_name, norm_name) self.ram_cost = 1 + in_channel / out_channel * 2 @@ -118,8 +111,8 @@ def __init__( kernel_size: int, padding: int, p3dmode: int = 0, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), ): super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode, act_name, norm_name) # 1 in_channel (activation) + 1 in_channel (convolution) + @@ -133,8 +126,8 @@ def __init__( in_channel: int, out_channel: int, spatial_dims: int = 3, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), ): super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name) # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. @@ -149,8 +142,8 @@ def __init__( in_channel: int, out_channel: int, spatial_dims: int = 3, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), ): super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name) # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. @@ -174,9 +167,10 @@ def __init__(self, c: int, ops: dict, arch_code_c=None): arch_code_c = np.ones(len(ops)) self.ops = nn.ModuleList() for arch_c, op_name in zip(arch_code_c, ops): - self.ops.append(_CloseWithRAMCost() if arch_c == 0 else ops[op_name](c)) + if arch_c > 0: + self.ops.append(ops[op_name](c)) - def forward(self, x: torch.Tensor, weight: torch.Tensor): + def forward(self, x: torch.Tensor, weight: torch.Tensor | None = None): """ Args: x: input tensor. @@ -185,9 +179,10 @@ def forward(self, x: torch.Tensor, weight: torch.Tensor): out: weighted average of the operation results. """ out = 0.0 - weight = weight.to(x) + if weight is not None: + weight = weight.to(x) for idx, _op in enumerate(self.ops): - out = out + _op(x) * weight[idx] + out = (out + _op(x)) if weight is None else out + _op(x) * weight[idx] return out @@ -244,8 +239,8 @@ def __init__( rate: int, arch_code_c=None, spatial_dims: int = 3, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), ): super().__init__() self._spatial_dims = spatial_dims @@ -303,7 +298,7 @@ def __init__( self.op = MixedOp(c, self.OPS, arch_code_c) - def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, weight: torch.Tensor | None) -> torch.Tensor: """ Args: x: input tensor @@ -357,8 +352,8 @@ def __init__( dints_space, in_channels: int, num_classes: int, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), spatial_dims: int = 3, use_downsample: bool = True, node_a=None, @@ -555,23 +550,30 @@ class TopologyConstruction(nn.Module): def __init__( self, - arch_code: Optional[list] = None, + arch_code: list | None = None, channel_mul: float = 1.0, cell=Cell, num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): - super().__init__() - self.filter_nums = [int(n_feat * channel_mul) for n_feat in (32, 64, 128, 256, 512)] + n_feats = tuple([32 * (2**_i) for _i in range(num_depths + 1)]) + self.filter_nums = [int(n_feat * channel_mul) for n_feat in n_feats] + self.num_blocks = num_blocks self.num_depths = num_depths + print( + "{} - Length of input patch is recommended to be a multiple of {:d}.".format( + datetime.datetime.now(), 2 ** (num_depths + int(use_downsample)) + ) + ) + self._spatial_dims = spatial_dims self._act_name = act_name self._norm_name = norm_name @@ -638,8 +640,8 @@ def __init__( num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): @@ -662,21 +664,19 @@ def __init__( device=device, ) - def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]: """ Args: x: input tensor. """ # generate path activation probability - inputs, outputs = x, [torch.tensor(0.0).to(x[0])] * self.num_depths + inputs = x for blk_idx in range(self.num_blocks): - outputs = [torch.tensor(0.0).to(x[0])] * self.num_depths + outputs = [torch.tensor(0.0, dtype=x[0].dtype, device=x[0].device)] * self.num_depths for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data): if activation: mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))] - _out = mod.forward( - x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx]) - ) + _out = mod.forward(x=inputs[self.arch_code2in[res_idx]], weight=None) outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out inputs = outputs @@ -729,19 +729,19 @@ class TopologySearch(TopologyConstruction): The return value will exclude path activation of all 0. """ - node2out: List[List] - node2in: List[List] + node2out: list[list] + node2in: list[list] def __init__( self, channel_mul: float = 1.0, cell=Cell, - arch_code: Optional[list] = None, + arch_code: list | None = None, num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, - act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: tuple | str = "RELU", + norm_name: tuple | str = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): @@ -884,13 +884,13 @@ def get_ram_cost_usage(self, in_size, full: bool = False): sizes = [] for res_idx in range(self.num_depths): sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod()) - sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample))) + sizes = torch.tensor(sizes, dtype=torch.float32, device=self.device) / (2 ** (int(self.use_downsample))) probs_a, arch_code_prob_a = self.get_prob_a(child=False) cell_prob = F.softmax(self.log_alpha_c, dim=-1) if full: arch_code_prob_a = arch_code_prob_a.detach() arch_code_prob_a.fill_(1) - ram_cost = torch.from_numpy(self.ram_cost).to(torch.float32).to(self.device) + ram_cost = torch.from_numpy(self.ram_cost).to(dtype=torch.float32, device=self.device) usage = 0.0 for blk_idx in range(self.num_blocks): # node activation for input diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 4eaba4bc42..a761f5993a 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Type, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -30,7 +32,7 @@ class DynUNetSkipLayer(nn.Module): forward passes of the network. """ - heads: Optional[List[torch.Tensor]] + heads: list[torch.Tensor] | None def __init__(self, index, downsample, upsample, next_layer, heads=None, super_head=None): super().__init__() @@ -130,13 +132,13 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: Sequence[Union[Sequence[int], int]], - strides: Sequence[Union[Sequence[int], int]], - upsample_kernel_size: Sequence[Union[Sequence[int], int]], - filters: Optional[Sequence[int]] = None, - dropout: Optional[Union[Tuple, str, float]] = None, - norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + kernel_size: Sequence[Sequence[int] | int], + strides: Sequence[Sequence[int] | int], + upsample_kernel_size: Sequence[Sequence[int] | int], + filters: Sequence[int] | None = None, + dropout: tuple | str | float | None = None, + norm_name: tuple | str = ("INSTANCE", {"affine": True}), + act_name: tuple | str = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), deep_supervision: bool = False, deep_supr_num: int = 1, res_block: bool = False, @@ -167,7 +169,7 @@ def __init__( self.deep_supervision = deep_supervision self.deep_supr_num = deep_supr_num # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on - self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num + self.heads: list[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num if self.deep_supervision: self.deep_supervision_heads = self.get_deep_supervision_heads() self.check_deep_supr_num() @@ -317,10 +319,10 @@ def get_module_list( self, in_channels: Sequence[int], out_channels: Sequence[int], - kernel_size: Sequence[Union[Sequence[int], int]], - strides: Sequence[Union[Sequence[int], int]], - conv_block: Type[nn.Module], - upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None, + kernel_size: Sequence[Sequence[int] | int], + strides: Sequence[Sequence[int] | int], + conv_block: type[nn.Module], + upsample_kernel_size: Sequence[Sequence[int] | int] | None = None, trans_bias: bool = False, ): layers = [] diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index eef3d68090..7c2a507fea 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -9,11 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math import operator import re from functools import reduce -from typing import Dict, List, NamedTuple, Optional, Tuple, Type, Union +from typing import NamedTuple import torch from torch import nn @@ -78,12 +80,12 @@ def __init__( out_channels: int, kernel_size: int, stride: int, - image_size: List[int], + image_size: list[int], expand_ratio: int, - se_ratio: Optional[float], - id_skip: Optional[bool] = True, - norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}), - drop_connect_rate: Optional[float] = 0.2, + se_ratio: float | None, + id_skip: bool | None = True, + norm: str | tuple = ("batch", {"eps": 1e-3, "momentum": 0.01}), + drop_connect_rate: float | None = 0.2, ) -> None: """ Mobile Inverted Residual Bottleneck Block. @@ -227,7 +229,7 @@ def set_swish(self, memory_efficient: bool = True) -> None: class EfficientNet(nn.Module): def __init__( self, - blocks_args_str: List[str], + blocks_args_str: list[str], spatial_dims: int = 2, in_channels: int = 3, num_classes: int = 1000, @@ -235,7 +237,7 @@ def __init__( depth_coefficient: float = 1.0, dropout_rate: float = 0.2, image_size: int = 224, - norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}), + norm: str | tuple = ("batch", {"eps": 1e-3, "momentum": 0.01}), drop_connect_rate: float = 0.2, depth_divisor: int = 8, ) -> None: @@ -264,8 +266,8 @@ def __init__( # select the type of N-Dimensional layers to use # these are based on spatial dims and selected from MONAI factories - conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv["conv", spatial_dims] - adaptivepool_type: Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]] = Pool[ + conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv["conv", spatial_dims] + adaptivepool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[ "adaptiveavg", spatial_dims ] @@ -478,7 +480,7 @@ def __init__( spatial_dims: int = 2, in_channels: int = 3, num_classes: int = 1000, - norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}), + norm: str | tuple = ("batch", {"eps": 1e-3, "momentum": 0.01}), adv_prop: bool = False, ) -> None: """ @@ -564,7 +566,7 @@ def __init__( spatial_dims: int = 2, in_channels: int = 3, num_classes: int = 1000, - norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}), + norm: str | tuple = ("batch", {"eps": 1e-3, "momentum": 0.01}), adv_prop: bool = False, ) -> None: """ @@ -653,7 +655,7 @@ class EfficientNetEncoder(EfficientNetBNFeatures, BaseEncoder): ] @classmethod - def get_encoder_parameters(cls) -> List[Dict]: + def get_encoder_parameters(cls) -> list[dict]: """ Get the initialization parameter for efficientnet backbones. """ @@ -674,7 +676,7 @@ def get_encoder_parameters(cls) -> List[Dict]: return parameter_list @classmethod - def num_channels_per_output(cls) -> List[Tuple[int, ...]]: + def num_channels_per_output(cls) -> list[tuple[int, ...]]: """ Get number of efficientnet backbone output feature maps' channel. """ @@ -692,7 +694,7 @@ def num_channels_per_output(cls) -> List[Tuple[int, ...]]: ] @classmethod - def num_outputs(cls) -> List[int]: + def num_outputs(cls) -> list[int]: """ Get number of efficientnet backbone output feature maps. Since every backbone contains the same 5 output feature maps, @@ -701,7 +703,7 @@ def num_outputs(cls) -> List[int]: return [5] * 10 @classmethod - def get_encoder_names(cls) -> List[str]: + def get_encoder_names(cls) -> list[str]: """ Get names of efficient backbone. """ @@ -762,7 +764,7 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor num_dims: int = len(inputs.shape) - 2 # build dimensions for random tensor, use num_dims to populate appropriate spatial dims - random_tensor_shape: List[int] = [batch_size, 1] + [1] * num_dims + random_tensor_shape: list[int] = [batch_size, 1] + [1] * num_dims # generate binary_tensor mask according to probability (p for 0, 1-p for 1) random_tensor: torch.Tensor = torch.rand(random_tensor_shape, dtype=inputs.dtype, device=inputs.device) @@ -798,8 +800,8 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool, adv_prop: bool def _get_same_padding_conv_nd( - image_size: List[int], kernel_size: Tuple[int, ...], dilation: Tuple[int, ...], stride: Tuple[int, ...] -) -> List[int]: + image_size: list[int], kernel_size: tuple[int, ...], dilation: tuple[int, ...], stride: tuple[int, ...] +) -> list[int]: """ Helper for getting padding (nn.ConstantPadNd) to be used to get SAME padding conv operations similar to Tensorflow's SAME padding. @@ -826,20 +828,20 @@ def _get_same_padding_conv_nd( stride = stride * num_dims # equation to calculate (pad^+ + pad^-) size - _pad_size: List[int] = [ + _pad_size: list[int] = [ max((math.ceil(_i_s / _s) - 1) * _s + (_k_s - 1) * _d + 1 - _i_s, 0) for _i_s, _k_s, _d, _s in zip(image_size, kernel_size, dilation, stride) ] # distribute paddings into pad^+ and pad^- following Tensorflow's same padding strategy - _paddings: List[Tuple[int, int]] = [(_p // 2, _p - _p // 2) for _p in _pad_size] + _paddings: list[tuple[int, int]] = [(_p // 2, _p - _p // 2) for _p in _pad_size] # unroll list of tuples to tuples, and then to list # reversed as nn.ConstantPadNd expects paddings starting with last dimension - _paddings_ret: List[int] = [outer for inner in reversed(_paddings) for outer in inner] + _paddings_ret: list[int] = [outer for inner in reversed(_paddings) for outer in inner] return _paddings_ret -def _make_same_padder(conv_op: Union[nn.Conv1d, nn.Conv2d, nn.Conv3d], image_size: List[int]): +def _make_same_padder(conv_op: nn.Conv1d | nn.Conv2d | nn.Conv3d, image_size: list[int]): """ Helper for initializing ConstantPadNd with SAME padding similar to Tensorflow. Uses output of _get_same_padding_conv_nd() to get the padding size. @@ -854,7 +856,7 @@ def _make_same_padder(conv_op: Union[nn.Conv1d, nn.Conv2d, nn.Conv3d], image_siz If padding required then nn.ConstandNd() padder initialized to paddings otherwise nn.Identity() """ # calculate padding required - padding: List[int] = _get_same_padding_conv_nd(image_size, conv_op.kernel_size, conv_op.dilation, conv_op.stride) + padding: list[int] = _get_same_padding_conv_nd(image_size, conv_op.kernel_size, conv_op.dilation, conv_op.stride) # initialize and return padder padder = Pad["constantpad", len(padding) // 2] @@ -863,7 +865,7 @@ def _make_same_padder(conv_op: Union[nn.Conv1d, nn.Conv2d, nn.Conv3d], image_siz return nn.Identity() -def _round_filters(filters: int, width_coefficient: Optional[float], depth_divisor: float) -> int: +def _round_filters(filters: int, width_coefficient: float | None, depth_divisor: float) -> int: """ Calculate and round number of filters based on width coefficient multiplier and depth divisor. @@ -890,7 +892,7 @@ def _round_filters(filters: int, width_coefficient: Optional[float], depth_divis return int(new_filters) -def _round_repeats(repeats: int, depth_coefficient: Optional[float]) -> int: +def _round_repeats(repeats: int, depth_coefficient: float | None) -> int: """ Re-calculate module's repeat number of a block based on depth coefficient multiplier. @@ -908,7 +910,7 @@ def _round_repeats(repeats: int, depth_coefficient: Optional[float]) -> int: return int(math.ceil(depth_coefficient * repeats)) -def _calculate_output_image_size(input_image_size: List[int], stride: Union[int, Tuple[int]]): +def _calculate_output_image_size(input_image_size: list[int], stride: int | tuple[int]): """ Calculates the output image size when using _make_same_padder with a stride. Required for static padding. @@ -946,7 +948,7 @@ class BlockArgs(NamedTuple): input_filters: int output_filters: int id_skip: bool - se_ratio: Optional[float] = None + se_ratio: float | None = None @staticmethod def from_string(block_string: str): diff --git a/monai/networks/nets/flexible_unet.py b/monai/networks/nets/flexible_unet.py index 28e0cedaa0..a880cafdc3 100644 --- a/monai/networks/nets/flexible_unet.py +++ b/monai/networks/nets/flexible_unet.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings +from collections.abc import Sequence from pydoc import locate -from typing import List, Optional, Sequence, Tuple, Type, Union +from typing import Any import torch from torch import nn @@ -38,7 +41,7 @@ class FlexUNetEncoderRegister: def __init__(self): self.register_dict = {} - def register_class(self, name: Union[Type, str]): + def register_class(self, name: type[Any] | str): """ Register a given class to the encoder dict. Please notice that input class must be a subclass of BaseEncoder. @@ -110,17 +113,16 @@ def __init__( spatial_dims: int, encoder_channels: Sequence[int], decoder_channels: Sequence[int], - act: Union[str, tuple], - norm: Union[str, tuple], - dropout: Union[float, tuple], + act: str | tuple, + norm: str | tuple, + dropout: float | tuple, bias: bool, upsample: str, - pre_conv: Optional[str], + pre_conv: str | None, interp_mode: str, - align_corners: Optional[bool], + align_corners: bool | None, is_pad: bool, ): - super().__init__() if len(encoder_channels) < 2: raise ValueError("the length of `encoder_channels` should be no less than 2.") @@ -153,7 +155,7 @@ def __init__( ) self.blocks = nn.ModuleList(blocks) - def forward(self, features: List[torch.Tensor], skip_connect: int = 4): + def forward(self, features: list[torch.Tensor], skip_connect: int = 4): skips = features[:-1][::-1] features = features[1:][::-1] @@ -190,10 +192,9 @@ def __init__( in_channels: int, out_channels: int, kernel_size: int = 3, - act: Optional[Union[Tuple, str]] = None, + act: tuple | str | None = None, scale_factor: float = 1.0, ): - conv_layer = Conv[Conv.CONV, spatial_dims]( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2 ) @@ -224,11 +225,11 @@ def __init__( out_channels: int, backbone: str, pretrained: bool = False, - decoder_channels: Tuple = (256, 128, 64, 32, 16), + decoder_channels: tuple = (256, 128, 64, 32, 16), spatial_dims: int = 2, - norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}), - act: Union[str, tuple] = ("relu", {"inplace": True}), - dropout: Union[float, tuple] = 0.0, + norm: str | tuple = ("batch", {"eps": 1e-3, "momentum": 0.1}), + act: str | tuple = ("relu", {"inplace": True}), + dropout: float | tuple = 0.0, decoder_bias: bool = False, upsample: str = "nontrainable", interp_mode: str = "nearest", @@ -308,7 +309,7 @@ def __init__( bias=decoder_bias, upsample=upsample, interp_mode=interp_mode, - pre_conv=None, + pre_conv="default", align_corners=None, is_pad=is_pad, ) diff --git a/monai/networks/nets/fullyconnectednet.py b/monai/networks/nets/fullyconnectednet.py index 810c07431b..fe9a8a0cd6 100644 --- a/monai/networks/nets/fullyconnectednet.py +++ b/monai/networks/nets/fullyconnectednet.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -20,9 +22,7 @@ __all__ = ["FullyConnectedNet", "VarFullyConnectedNet"] -def _get_adn_layer( - act: Optional[Union[Tuple, str]], dropout: Optional[Union[Tuple, str, float]], ordering: Optional[str] -) -> ADN: +def _get_adn_layer(act: tuple | str | None, dropout: tuple | str | float | None, ordering: str | None) -> ADN: if ordering: return ADN(act=act, dropout=dropout, dropout_dim=1, ordering=ordering) return ADN(act=act, dropout=dropout, dropout_dim=1) @@ -55,10 +55,10 @@ def __init__( in_channels: int, out_channels: int, hidden_channels: Sequence[int], - dropout: Optional[Union[Tuple, str, float]] = None, - act: Optional[Union[Tuple, str]] = Act.PRELU, + dropout: tuple | str | float | None = None, + act: tuple | str | None = Act.PRELU, bias: bool = True, - adn_ordering: Optional[str] = None, + adn_ordering: str | None = None, ) -> None: """ Defines a network accept input with `in_channels` channels, output of `out_channels` channels, and hidden layers @@ -118,10 +118,10 @@ def __init__( latent_size: int, encode_channels: Sequence[int], decode_channels: Sequence[int], - dropout: Optional[Union[Tuple, str, float]] = None, - act: Optional[Union[Tuple, str]] = Act.PRELU, + dropout: tuple | str | float | None = None, + act: tuple | str | None = Act.PRELU, bias: bool = True, - adn_ordering: Optional[str] = None, + adn_ordering: str | None = None, ) -> None: super().__init__() self.in_channels = in_channels @@ -154,7 +154,7 @@ def _get_layer(self, in_channels: int, out_channels: int, bias: bool) -> nn.Sequ seq.add_module("ADN", self.adn_layer) return seq - def encode_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def encode_forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x = self.encode(x) x = self.flatten(x) mu = self.mu(x) @@ -179,7 +179,7 @@ def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor return std.add_(mu) - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: mu, logvar = self.encode_forward(x) z = self.reparameterize(mu, logvar) return self.decode_forward(z), mu, logvar, z diff --git a/monai/networks/nets/generator.py b/monai/networks/nets/generator.py index a69cae4d7b..32428b2696 100644 --- a/monai/networks/nets/generator.py +++ b/monai/networks/nets/generator.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import numpy as np import torch @@ -62,11 +64,11 @@ def __init__( start_shape: Sequence[int], channels: Sequence[int], strides: Sequence[int], - kernel_size: Union[Sequence[int], int] = 3, + kernel_size: Sequence[int] | int = 3, num_res_units: int = 2, act=Act.PRELU, norm=Norm.INSTANCE, - dropout: Optional[float] = None, + dropout: float | None = None, bias: bool = True, ) -> None: super().__init__() @@ -100,14 +102,14 @@ def __init__( def _get_layer( self, in_channels: int, out_channels: int, strides: int, is_last: bool - ) -> Union[Convolution, nn.Sequential]: + ) -> Convolution | nn.Sequential: """ Returns a layer accepting inputs with `in_channels` number of channels and producing outputs of `out_channels` number of channels. The `strides` indicates upsampling factor, ie. transpose convolutional stride. If `is_last` is True this is the final layer and is not expected to include activation and normalization layers. """ - layer: Union[Convolution, nn.Sequential] + layer: Convolution | nn.Sequential layer = Convolution( in_channels=in_channels, diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index 891a65e67b..e71f8d193d 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -40,11 +42,11 @@ def __init__( in_channels: int, out_channels: int, kernels: Sequence[int] = (3, 3), - dilation: Union[Sequence[int], int] = 1, - norm_type: Union[Tuple, str] = ("batch", {"affine": True}), - acti_type: Union[Tuple, str] = ("relu", {"inplace": True}), + dilation: Sequence[int] | int = 1, + norm_type: tuple | str = ("batch", {"affine": True}), + acti_type: tuple | str = ("relu", {"inplace": True}), bias: bool = False, - channel_matching: Union[ChannelMatching, str] = ChannelMatching.PAD, + channel_matching: ChannelMatching | str = ChannelMatching.PAD, ) -> None: """ Args: @@ -138,14 +140,13 @@ def __init__( spatial_dims: int = 3, in_channels: int = 1, out_channels: int = 1, - norm_type: Union[str, tuple] = ("batch", {"affine": True}), - acti_type: Union[str, tuple] = ("relu", {"inplace": True}), - dropout_prob: Optional[Union[Tuple, str, float]] = 0.0, + norm_type: str | tuple = ("batch", {"affine": True}), + acti_type: str | tuple = ("relu", {"inplace": True}), + dropout_prob: tuple | str | float | None = 0.0, bias: bool = False, - layer_params: Sequence[Dict] = DEFAULT_LAYER_PARAMS_3D, - channel_matching: Union[ChannelMatching, str] = ChannelMatching.PAD, + layer_params: Sequence[dict] = DEFAULT_LAYER_PARAMS_3D, + channel_matching: ChannelMatching | str = ChannelMatching.PAD, ) -> None: - super().__init__() blocks = nn.ModuleList() @@ -166,7 +167,7 @@ def __init__( ) # residual blocks - for (idx, params) in enumerate(layer_params[1:-2]): # res blocks except the 1st and last two conv layers. + for idx, params in enumerate(layer_params[1:-2]): # res blocks except the 1st and last two conv layers. _in_chns, _out_chns = _out_chns, params["n_features"] _dilation = 2**idx for _ in range(params["repeat"]): diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py index 109d73a8cf..323e107fd7 100644 --- a/monai/networks/nets/hovernet.py +++ b/monai/networks/nets/hovernet.py @@ -27,11 +27,13 @@ # } # ========================================================================= +from __future__ import annotations + import os import re import warnings from collections import OrderedDict -from typing import Callable, Dict, List, Optional, Sequence, Type, Union +from collections.abc import Callable, Sequence import torch import torch.nn as nn @@ -53,8 +55,8 @@ def __init__( in_channels: int, out_channels: int, dropout_prob: float = 0.0, - act: Union[str, tuple] = ("relu", {"inplace": True}), - norm: Union[str, tuple] = "batch", + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", kernel_size: int = 3, padding: int = 0, ) -> None: @@ -90,7 +92,6 @@ def __init__( self.layers.add_module("dropout", dropout_type(dropout_prob)) def forward(self, x: torch.Tensor) -> torch.Tensor: - x1 = self.layers(x) if x1.shape[-1] != x.shape[-1]: trim = (x.shape[-1] - x1.shape[-1]) // 2 @@ -109,8 +110,8 @@ def __init__( in_channels: int, out_channels: int, dropout_prob: float = 0.0, - act: Union[str, tuple] = ("relu", {"inplace": True}), - norm: Union[str, tuple] = "batch", + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", kernel_size: int = 3, same_padding: bool = False, ) -> None: @@ -164,8 +165,8 @@ def __init__( in_channels: int, out_channels: int, dropout_prob: float = 0.0, - act: Union[str, tuple] = ("relu", {"inplace": True}), - norm: Union[str, tuple] = "batch", + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", drop_first_norm_relu: int = 0, kernel_size: int = 3, ) -> None: @@ -219,7 +220,7 @@ def __init__( class _Transition(nn.Sequential): def __init__( - self, in_channels: int, act: Union[str, tuple] = ("relu", {"inplace": True}), norm: Union[str, tuple] = "batch" + self, in_channels: int, act: str | tuple = ("relu", {"inplace": True}), norm: str | tuple = "batch" ) -> None: """ Args: @@ -241,8 +242,8 @@ def __init__( in_channels: int, out_channels: int, dropout_prob: float = 0.0, - act: Union[str, tuple] = ("relu", {"inplace": True}), - norm: Union[str, tuple] = "batch", + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", freeze_dense_layer: bool = False, freeze_block: bool = False, ) -> None: @@ -292,7 +293,6 @@ def __init__( self.requires_grad_(False) def forward(self, x: torch.Tensor) -> torch.Tensor: - sc = self.shortcut(x) if self.shortcut.stride == (2, 2): @@ -315,8 +315,8 @@ class _DecoderBranch(nn.ModuleList): def __init__( self, decode_config: Sequence[int] = (8, 4), - act: Union[str, tuple] = ("relu", {"inplace": True}), - norm: Union[str, tuple] = "batch", + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", dropout_prob: float = 0.0, out_channels: int = 2, kernel_size: int = 3, @@ -385,8 +385,7 @@ def __init__( 2, scale_factor=2, mode=UpsampleMode.NONTRAINABLE, interp_mode=InterpolateMode.BILINEAR, bias=False ) - def forward(self, xin: torch.Tensor, short_cuts: List[torch.Tensor]) -> torch.Tensor: - + def forward(self, xin: torch.Tensor, short_cuts: list[torch.Tensor]) -> torch.Tensor: block_number = len(short_cuts) - 1 x = xin + short_cuts[block_number] @@ -416,6 +415,10 @@ class HoVerNet(nn.Module): https://github.com/vqdang/hover_net https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html + This network is non-deterministic since it uses `torch.nn.Upsample` with ``UpsampleMode.NONTRAINABLE`` mode which + is implemented with torch.nn.functional.interpolate(). Please check the link below for more details: + https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms + Args: mode: use original implementation (`HoVerNetMODE.ORIGINAL` or "original") or a faster implementation (`HoVerNetMODE.FAST` or "fast"). Defaults to `HoVerNetMODE.FAST`. @@ -448,19 +451,18 @@ class HoVerNet(nn.Module): def __init__( self, - mode: Union[HoVerNetMode, str] = HoVerNetMode.FAST, + mode: HoVerNetMode | str = HoVerNetMode.FAST, in_channels: int = 3, np_out_channels: int = 2, out_classes: int = 0, - act: Union[str, tuple] = ("relu", {"inplace": True}), - norm: Union[str, tuple] = "batch", + act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", decoder_padding: bool = False, dropout_prob: float = 0.0, - pretrained_url: Optional[str] = None, + pretrained_url: str | None = None, adapt_standard_resnet: bool = False, freeze_encoder: bool = False, ) -> None: - super().__init__() if isinstance(mode, str): @@ -492,7 +494,7 @@ def __init__( _ksize = 5 _pad = 0 - conv_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2] + conv_type: type[nn.Conv2d] = Conv[Conv.CONV, 2] self.conv0 = nn.Sequential( OrderedDict( @@ -549,7 +551,7 @@ def __init__( kernel_size=_ksize, same_padding=decoder_padding, out_channels=np_out_channels ) self.horizontal_vertical = _DecoderBranch(kernel_size=_ksize, same_padding=decoder_padding) - self.type_prediction: Optional[_DecoderBranch] = ( + self.type_prediction: _DecoderBranch | None = ( _DecoderBranch(out_channels=out_classes, kernel_size=_ksize, same_padding=decoder_padding) if out_classes > 0 else None @@ -569,8 +571,7 @@ def __init__( weights = _remap_preact_resnet_model(pretrained_url) _load_pretrained_encoder(self, weights) - def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: - + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: if self.mode == HoVerNetMode.ORIGINAL.value: if x.shape[-1] != 270 or x.shape[-2] != 270: raise ValueError("Input size should be 270 x 270 when using HoVerNetMode.ORIGINAL") @@ -600,8 +601,7 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: return output -def _load_pretrained_encoder(model: nn.Module, state_dict: Union[OrderedDict, Dict]): - +def _load_pretrained_encoder(model: nn.Module, state_dict: OrderedDict | dict): model_dict = model.state_dict() state_dict = { k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) @@ -612,7 +612,6 @@ def _load_pretrained_encoder(model: nn.Module, state_dict: Union[OrderedDict, Di def _remap_preact_resnet_model(model_url: str): - pattern_conv0 = re.compile(r"^(conv0\.\/)(.+)$") pattern_block = re.compile(r"^(d\d+)\.(.+)$") pattern_layer = re.compile(r"^(.+\.d\d+)\.units\.(\d+)(.+)$") @@ -641,7 +640,6 @@ def _remap_preact_resnet_model(model_url: str): def _remap_standard_resnet_model(model_url: str): - pattern_conv0 = re.compile(r"^conv1\.(.+)$") pattern_bn1 = re.compile(r"^bn1\.(.+)$") pattern_block = re.compile(r"^layer(\d+)\.(\d+)\.(.+)$") diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py index 88468d39a6..0a25b7feec 100644 --- a/monai/networks/nets/milmodel.py +++ b/monai/networks/nets/milmodel.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Union, cast +from __future__ import annotations + +from typing import cast import torch import torch.nn as nn @@ -54,12 +56,11 @@ def __init__( num_classes: int, mil_mode: str = "att", pretrained: bool = True, - backbone: Optional[Union[str, nn.Module]] = None, - backbone_num_features: Optional[int] = None, + backbone: str | nn.Module | None = None, + backbone_num_features: int | None = None, trans_blocks: int = 4, trans_dropout: float = 0.0, ) -> None: - super().__init__() if num_classes <= 0: @@ -70,15 +71,14 @@ def __init__( self.mil_mode = mil_mode.lower() self.attention = nn.Sequential() - self.transformer = None # type: Optional[nn.Module] + self.transformer: nn.Module | None = None if backbone is None: - net = models.resnet50(pretrained=pretrained) nfc = net.fc.in_features # save the number of final features net.fc = torch.nn.Identity() # remove final linear layer - self.extra_outputs: Dict[str, torch.Tensor] = {} + self.extra_outputs: dict[str, torch.Tensor] = {} if mil_mode == "att_trans_pyramid": # register hooks to capture outputs of intermediate layers @@ -94,7 +94,6 @@ def hook(module, input, output): net.layer4.register_forward_hook(forward_hook("layer4")) elif isinstance(backbone, str): - # assume torchvision model string is provided torch_model = getattr(models, backbone, None) if torch_model is None: @@ -135,7 +134,6 @@ def hook(module, input, output): self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) elif self.mil_mode == "att_trans_pyramid": - transformer_list = nn.ModuleList( [ nn.TransformerEncoder( @@ -172,7 +170,6 @@ def hook(module, input, output): self.net = net def calc_head(self, x: torch.Tensor) -> torch.Tensor: - sh = x.shape if self.mil_mode == "mean": @@ -184,7 +181,6 @@ def calc_head(self, x: torch.Tensor) -> torch.Tensor: x, _ = torch.max(x, dim=1) elif self.mil_mode == "att": - a = self.attention(x) a = torch.softmax(a, dim=1) x = torch.sum(x * a, dim=1) @@ -192,7 +188,6 @@ def calc_head(self, x: torch.Tensor) -> torch.Tensor: x = self.myfc(x) elif self.mil_mode == "att_trans" and self.transformer is not None: - x = x.permute(1, 0, 2) x = self.transformer(x) x = x.permute(1, 0, 2) @@ -204,7 +199,6 @@ def calc_head(self, x: torch.Tensor) -> torch.Tensor: x = self.myfc(x) elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None: - l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) @@ -231,7 +225,6 @@ def calc_head(self, x: torch.Tensor) -> torch.Tensor: return x def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor: - sh = x.shape x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4]) diff --git a/monai/networks/nets/netadapter.py b/monai/networks/nets/netadapter.py index 39112e4d54..452c31be37 100644 --- a/monai/networks/nets/netadapter.py +++ b/monai/networks/nets/netadapter.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from __future__ import annotations + +from typing import Any, Dict import torch @@ -52,9 +54,9 @@ def __init__( model: torch.nn.Module, num_classes: int = 1, dim: int = 2, - in_channels: Optional[int] = None, + in_channels: int | None = None, use_conv: bool = False, - pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}), + pool: tuple[str, dict[str, Any]] | None = ("avg", {"kernel_size": 7, "stride": 1}), bias: bool = True, fc_name: str = "fc", node_name: str = "", @@ -94,7 +96,7 @@ def __init__( self.pool = get_pool_layer(name=pool, spatial_dims=dim) # create new fully connected layer or kernel size 1 convolutional layer - self.fc: Union[torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv3d] + self.fc: torch.nn.Linear | torch.nn.Conv2d | torch.nn.Conv3d if use_conv: self.fc = Conv[Conv.CONV, dim](in_channels=in_channels_, out_channels=num_classes, kernel_size=1, bias=bias) else: diff --git a/monai/networks/nets/regressor.py b/monai/networks/nets/regressor.py index 0a1e6258a9..a54b926bd0 100644 --- a/monai/networks/nets/regressor.py +++ b/monai/networks/nets/regressor.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import numpy as np import torch @@ -61,11 +63,11 @@ def __init__( out_shape: Sequence[int], channels: Sequence[int], strides: Sequence[int], - kernel_size: Union[Sequence[int], int] = 3, + kernel_size: Sequence[int] | int = 3, num_res_units: int = 2, act=Act.PRELU, norm=Norm.INSTANCE, - dropout: Optional[float] = None, + dropout: float | None = None, bias: bool = True, ) -> None: super().__init__() @@ -101,14 +103,14 @@ def __init__( def _get_layer( self, in_channels: int, out_channels: int, strides: int, is_last: bool - ) -> Union[ResidualUnit, Convolution]: + ) -> ResidualUnit | Convolution: """ Returns a layer accepting inputs with `in_channels` number of channels and producing outputs of `out_channels` number of channels. The `strides` indicates downsampling factor, ie. convolutional stride. If `is_last` is True this is the final layer and is not expected to include activation and normalization layers. """ - layer: Union[ResidualUnit, Convolution] + layer: ResidualUnit | Convolution if self.num_res_units > 0: layer = ResidualUnit( diff --git a/monai/networks/nets/regunet.py b/monai/networks/nets/regunet.py index ee78342459..1764c480f3 100644 --- a/monai/networks/nets/regunet.py +++ b/monai/networks/nets/regunet.py @@ -8,7 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union + +from __future__ import annotations import torch from torch import nn @@ -46,13 +47,13 @@ def __init__( in_channels: int, num_channel_initial: int, depth: int, - out_kernel_initializer: Optional[str] = "kaiming_uniform", - out_activation: Optional[str] = None, + out_kernel_initializer: str | None = "kaiming_uniform", + out_activation: str | None = None, out_channels: int = 3, - extract_levels: Optional[Tuple[int]] = None, + extract_levels: tuple[int] | None = None, pooling: bool = True, concat_skip: bool = False, - encode_kernel_sizes: Union[int, List[int]] = 3, + encode_kernel_sizes: int | list[int] = 3, ): """ Args: @@ -90,7 +91,7 @@ def __init__( encode_kernel_sizes = [encode_kernel_sizes] * (self.depth + 1) if len(encode_kernel_sizes) != self.depth + 1: raise AssertionError - self.encode_kernel_sizes: List[int] = encode_kernel_sizes + self.encode_kernel_sizes: list[int] = encode_kernel_sizes self.num_channels = [self.num_channel_initial * (2**d) for d in range(self.depth + 1)] self.min_extract_level = min(self.extract_levels) @@ -233,7 +234,7 @@ def forward(self, x): class AffineHead(nn.Module): - def __init__(self, spatial_dims: int, image_size: List[int], decode_size: List[int], in_channels: int): + def __init__(self, spatial_dims: int, image_size: list[int], decode_size: list[int], in_channels: int): super().__init__() self.spatial_dims = spatial_dims if spatial_dims == 2: @@ -255,7 +256,7 @@ def __init__(self, spatial_dims: int, image_size: List[int], decode_size: List[i self.fc.bias.data.copy_(out_init) @staticmethod - def get_reference_grid(image_size: Union[Tuple[int], List[int]]) -> torch.Tensor: + def get_reference_grid(image_size: tuple[int] | list[int]) -> torch.Tensor: mesh_points = [torch.arange(0, dim) for dim in image_size] grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...) return grid.to(dtype=torch.float) @@ -273,7 +274,7 @@ def affine_transform(self, theta: torch.Tensor): raise ValueError(f"do not support spatial_dims={self.spatial_dims}") return grid_warped - def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor: + def forward(self, x: list[torch.Tensor], image_size: list[int]) -> torch.Tensor: f = x[0] self.grid = self.grid.to(device=f.device) theta = self.fc(f.reshape(f.shape[0], -1)) @@ -294,16 +295,16 @@ class GlobalNet(RegUNet): def __init__( self, - image_size: List[int], + image_size: list[int], spatial_dims: int, in_channels: int, num_channel_initial: int, depth: int, - out_kernel_initializer: Optional[str] = "kaiming_uniform", - out_activation: Optional[str] = None, + out_kernel_initializer: str | None = "kaiming_uniform", + out_activation: str | None = None, pooling: bool = True, concat_skip: bool = False, - encode_kernel_sizes: Union[int, List[int]] = 3, + encode_kernel_sizes: int | list[int] = 3, ): for size in image_size: if size % (2**depth) != 0: @@ -337,14 +338,23 @@ def build_output_block(self): class AdditiveUpSampleBlock(nn.Module): - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + mode: str = "nearest", + align_corners: bool | None = None, + ): super().__init__() self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels) + self.mode = mode + self.align_corners = align_corners def forward(self, x: torch.Tensor) -> torch.Tensor: output_size = [size * 2 for size in x.shape[2:]] deconved = self.deconv(x) - resized = F.interpolate(x, output_size) + resized = F.interpolate(x, output_size, mode=self.mode, align_corners=self.align_corners) resized = torch.sum(torch.stack(resized.split(split_size=resized.shape[1] // 2, dim=1), dim=-1), dim=-1) out: torch.Tensor = deconved + resized return out @@ -367,13 +377,15 @@ def __init__( spatial_dims: int, in_channels: int, num_channel_initial: int, - extract_levels: Tuple[int], - out_kernel_initializer: Optional[str] = "kaiming_uniform", - out_activation: Optional[str] = None, + extract_levels: tuple[int], + out_kernel_initializer: str | None = "kaiming_uniform", + out_activation: str | None = None, out_channels: int = 3, pooling: bool = True, - use_addictive_sampling: bool = True, + use_additive_sampling: bool = True, concat_skip: bool = False, + mode: str = "nearest", + align_corners: bool | None = None, ): """ Args: @@ -385,10 +397,14 @@ def __init__( out_channels: number of channels for the output extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth`` pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv3d - use_addictive_sampling: whether use additive up-sampling layer for decoding. + use_additive_sampling: whether use additive up-sampling layer for decoding. concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition + mode: mode for interpolation when use_additive_sampling, default is "nearest". + align_corners: align_corners for interpolation when use_additive_sampling, default is None. """ - self.use_additive_upsampling = use_addictive_sampling + self.use_additive_upsampling = use_additive_sampling + self.mode = mode + self.align_corners = align_corners super().__init__( spatial_dims=spatial_dims, in_channels=in_channels, @@ -412,7 +428,11 @@ def build_bottom_block(self, in_channels: int, out_channels: int): def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module: if self.use_additive_upsampling: return AdditiveUpSampleBlock( - spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels + spatial_dims=self.spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + mode=self.mode, + align_corners=self.align_corners, ) return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index fca975d40e..02869d415f 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from collections.abc import Callable from functools import partial -from typing import Any, Callable, List, Tuple, Type, Union +from typing import Any import torch import torch.nn as nn @@ -51,7 +54,7 @@ def __init__( planes: int, spatial_dims: int = 3, stride: int = 1, - downsample: Union[nn.Module, partial, None] = None, + downsample: nn.Module | partial | None = None, ) -> None: """ Args: @@ -102,7 +105,7 @@ def __init__( planes: int, spatial_dims: int = 3, stride: int = 1, - downsample: Union[nn.Module, partial, None] = None, + downsample: nn.Module | partial | None = None, ) -> None: """ Args: @@ -181,13 +184,13 @@ class ResNet(nn.Module): def __init__( self, - block: Union[Type[Union[ResNetBlock, ResNetBottleneck]], str], - layers: List[int], - block_inplanes: List[int], + block: type[ResNetBlock | ResNetBottleneck] | str, + layers: list[int], + block_inplanes: list[int], spatial_dims: int = 3, n_input_channels: int = 3, - conv1_t_size: Union[Tuple[int], int] = 7, - conv1_t_stride: Union[Tuple[int], int] = 1, + conv1_t_size: tuple[int] | int = 7, + conv1_t_stride: tuple[int] | int = 1, no_max_pool: bool = False, shortcut_type: str = "B", widen_factor: float = 1.0, @@ -195,7 +198,6 @@ def __init__( feed_forward: bool = True, bias_downsample: bool = True, # for backwards compatibility (also see PR #5477) ) -> None: - super().__init__() if isinstance(block, str): @@ -206,10 +208,10 @@ def __init__( else: raise ValueError("Unknown block '%s', use basic or bottleneck" % block) - conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] - norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] - pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] - avgp_type: Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]] = Pool[ + conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims] + norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] + pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims] + avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[ Pool.ADAPTIVEAVG, spatial_dims ] @@ -258,18 +260,17 @@ def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spa def _make_layer( self, - block: Type[Union[ResNetBlock, ResNetBottleneck]], + block: type[ResNetBlock | ResNetBottleneck], planes: int, blocks: int, spatial_dims: int, shortcut_type: str, stride: int = 1, ) -> nn.Sequential: - conv_type: Callable = Conv[Conv.CONV, spatial_dims] norm_type: Callable = Norm[Norm.BATCH, spatial_dims] - downsample: Union[nn.Module, partial, None] = None + downsample: nn.Module | partial | None = None if stride != 1 or self.in_planes != planes * block.expansion: if look_up_option(shortcut_type, {"A", "B"}) == "A": downsample = partial( @@ -325,9 +326,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _resnet( arch: str, - block: Type[Union[ResNetBlock, ResNetBottleneck]], - layers: List[int], - block_inplanes: List[int], + block: type[ResNetBlock | ResNetBottleneck], + layers: list[int], + block_inplanes: list[int], pretrained: bool, progress: bool, **kwargs: Any, diff --git a/monai/networks/nets/segresnet.py b/monai/networks/nets/segresnet.py index cc908f1640..0c1c85f04a 100644 --- a/monai/networks/nets/segresnet.py +++ b/monai/networks/nets/segresnet.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import numpy as np import torch @@ -60,15 +62,15 @@ def __init__( init_filters: int = 8, in_channels: int = 1, out_channels: int = 2, - dropout_prob: Optional[float] = None, - act: Union[Tuple, str] = ("RELU", {"inplace": True}), - norm: Union[Tuple, str] = ("GROUP", {"num_groups": 8}), + dropout_prob: float | None = None, + act: tuple | str = ("RELU", {"inplace": True}), + norm: tuple | str = ("GROUP", {"num_groups": 8}), norm_name: str = "", num_groups: int = 8, use_conv_final: bool = True, blocks_down: tuple = (1, 2, 2, 4), blocks_up: tuple = (1, 1, 1), - upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE, + upsample_mode: UpsampleMode | str = UpsampleMode.NONTRAINABLE, ): super().__init__() @@ -151,7 +153,7 @@ def _make_final_conv(self, out_channels: int): get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True), ) - def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]: x = self.convInit(x) if self.dropout_prob is not None: x = self.dropout(x) @@ -164,7 +166,7 @@ def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: return x, down_x - def decode(self, x: torch.Tensor, down_x: List[torch.Tensor]) -> torch.Tensor: + def decode(self, x: torch.Tensor, down_x: list[torch.Tensor]) -> torch.Tensor: for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)): x = up(x) + down_x[i + 1] x = upl(x) @@ -225,13 +227,13 @@ def __init__( init_filters: int = 8, in_channels: int = 1, out_channels: int = 2, - dropout_prob: Optional[float] = None, - act: Union[str, tuple] = ("RELU", {"inplace": True}), - norm: Union[Tuple, str] = ("GROUP", {"num_groups": 8}), + dropout_prob: float | None = None, + act: str | tuple = ("RELU", {"inplace": True}), + norm: tuple | str = ("GROUP", {"num_groups": 8}), use_conv_final: bool = True, blocks_down: tuple = (1, 2, 2, 4), blocks_up: tuple = (1, 1, 1), - upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE, + upsample_mode: UpsampleMode | str = UpsampleMode.NONTRAINABLE, ): super().__init__( spatial_dims=spatial_dims, diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index 15129a7996..07f3824b51 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Optional, Tuple, Union +from __future__ import annotations + +from collections.abc import Callable import numpy as np import torch @@ -23,7 +25,7 @@ __all__ = ["SegResNetDS"] -def scales_for_resolution(resolution: Union[Tuple, List], n_stages: Optional[int] = None): +def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None): """ A helper function to compute a schedule of scale at different downsampling levels, given the input resolution. @@ -51,7 +53,7 @@ def scales_for_resolution(resolution: Union[Tuple, List], n_stages: Optional[int return scales -def aniso_kernel(scale: Union[Tuple, List]): +def aniso_kernel(scale: tuple | list): """ A helper function to compute kernel_size, padding and stride for the given scale @@ -73,9 +75,9 @@ def __init__( self, spatial_dims: int, in_channels: int, - norm: Union[Tuple, str], - kernel_size: Union[Tuple, int] = 3, - act: Union[Tuple, str] = "relu", + norm: tuple | str, + kernel_size: tuple | int = 3, + act: tuple | str = "relu", ) -> None: """ Args: @@ -144,13 +146,12 @@ def __init__( spatial_dims: int = 3, init_filters: int = 32, in_channels: int = 1, - act: Union[Tuple, str] = "relu", - norm: Union[Tuple, str] = "batch", + act: tuple | str = "relu", + norm: tuple | str = "batch", blocks_down: tuple = (1, 2, 2, 4), - head_module: Optional[nn.Module] = None, - anisotropic_scales: Optional[Tuple] = None, + head_module: nn.Module | None = None, + anisotropic_scales: tuple | None = None, ): - super().__init__() if spatial_dims not in (1, 2, 3): @@ -212,8 +213,7 @@ def __init__( self.act = act self.spatial_dims = spatial_dims - def _forward(self, x: torch.Tensor) -> List[torch.Tensor]: - + def _forward(self, x: torch.Tensor) -> list[torch.Tensor]: outputs = [] x = self.conv_init(x) @@ -227,7 +227,7 @@ def _forward(self, x: torch.Tensor) -> List[torch.Tensor]: return outputs - def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: return self._forward(x) @@ -262,16 +262,15 @@ def __init__( init_filters: int = 32, in_channels: int = 1, out_channels: int = 2, - act: Union[Tuple, str] = "relu", - norm: Union[Tuple, str] = "batch", + act: tuple | str = "relu", + norm: tuple | str = "batch", blocks_down: tuple = (1, 2, 2, 4), - blocks_up: Optional[Tuple] = None, + blocks_up: tuple | None = None, dsdepth: int = 1, - preprocess: Optional[Union[nn.Module, Callable]] = None, - upsample_mode: Union[UpsampleMode, str] = "deconv", - resolution: Optional[Tuple] = None, + preprocess: nn.Module | Callable | None = None, + upsample_mode: UpsampleMode | str = "deconv", + resolution: tuple | None = None, ): - super().__init__() if spatial_dims not in (1, 2, 3): @@ -328,7 +327,6 @@ def __init__( self.up_layers = nn.ModuleList() for i in range(n_up): - filters = filters // 2 kernel_size, _, stride = ( aniso_kernel(anisotropic_scales[len(blocks_up) - i - 1]) if anisotropic_scales else (3, 1, 2) @@ -389,8 +387,7 @@ def is_valid_shape(self, x): a = [i % j == 0 for i, j in zip(x.shape[2:], self.shape_factor())] return all(a) - def _forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]: - + def _forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]: if self.preprocess is not None: x = self.preprocess(x) @@ -405,7 +402,7 @@ def _forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]: if len(x_down) == 0: x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)] - outputs: List[torch.Tensor] = [] + outputs: list[torch.Tensor] = [] i = 0 for level in self.up_layers: @@ -426,5 +423,5 @@ def _forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]: # return a list of DS outputs return outputs - def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]: + def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]: return self._forward(x) diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index b4c024b1f2..51435a9ea2 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import re from collections import OrderedDict -from typing import Any, List, Optional, Sequence, Tuple, Type, Union +from collections.abc import Sequence +from typing import Any import torch import torch.nn as nn @@ -94,18 +97,17 @@ def __init__( self, spatial_dims: int, in_channels: int, - block: Union[Type[Union[SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck]], str], + block: type[SEBottleneck | SEResNetBottleneck | SEResNeXtBottleneck] | str, layers: Sequence[int], groups: int, reduction: int, - dropout_prob: Optional[float] = 0.2, + dropout_prob: float | None = 0.2, dropout_dim: int = 1, inplanes: int = 128, downsample_kernel_size: int = 3, input_3x3: bool = True, num_classes: int = 1000, ) -> None: - super().__init__() if isinstance(block, str): @@ -120,19 +122,19 @@ def __init__( "Unknown block '%s', use se_bottleneck, se_resnet_bottleneck or se_resnetxt_bottleneck" % block ) - relu_type: Type[nn.ReLU] = Act[Act.RELU] - conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] - pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims] - norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] - dropout_type: Type[Union[nn.Dropout, nn.Dropout2d, nn.Dropout3d]] = Dropout[Dropout.DROPOUT, dropout_dim] - avg_pool_type: Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]] = Pool[ + relu_type: type[nn.ReLU] = Act[Act.RELU] + conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims] + pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims] + norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] + dropout_type: type[nn.Dropout | nn.Dropout2d | nn.Dropout3d] = Dropout[Dropout.DROPOUT, dropout_dim] + avg_pool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[ Pool.ADAPTIVEAVG, spatial_dims ] self.inplanes = inplanes self.spatial_dims = spatial_dims - layer0_modules: List[Tuple[str, Any]] + layer0_modules: list[tuple[str, Any]] if input_3x3: layer0_modules = [ @@ -211,7 +213,7 @@ def __init__( def _make_layer( self, - block: Type[Union[SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck]], + block: type[SEBottleneck | SEResNetBottleneck | SEResNeXtBottleneck], planes: int, blocks: int, groups: int, @@ -219,7 +221,6 @@ def _make_layer( stride: int = 1, downsample_kernel_size: int = 1, ) -> nn.Sequential: - downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = Convolution( @@ -358,7 +359,7 @@ def __init__( layers: Sequence[int] = (3, 4, 6, 3), groups: int = 1, reduction: int = 16, - dropout_prob: Optional[float] = None, + dropout_prob: float | None = None, inplanes: int = 64, downsample_kernel_size: int = 1, input_3x3: bool = False, @@ -456,7 +457,7 @@ def __init__( layers: Sequence[int] = (3, 4, 6, 3), groups: int = 32, reduction: int = 16, - dropout_prob: Optional[float] = None, + dropout_prob: float | None = None, inplanes: int = 64, downsample_kernel_size: int = 1, input_3x3: bool = False, @@ -490,7 +491,7 @@ def __init__( layers: Sequence[int] = (3, 4, 23, 3), groups: int = 32, reduction: int = 16, - dropout_prob: Optional[float] = None, + dropout_prob: float | None = None, inplanes: int = 64, downsample_kernel_size: int = 1, input_3x3: bool = False, diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 3f18ac4d40..9f8204968f 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import itertools -from typing import Optional, Sequence, Tuple, Type, Union +from collections.abc import Sequence import numpy as np import torch @@ -49,13 +51,13 @@ class SwinUNETR(nn.Module): def __init__( self, - img_size: Union[Sequence[int], int], + img_size: Sequence[int] | int, in_channels: int, out_channels: int, depths: Sequence[int] = (2, 2, 2, 2), num_heads: Sequence[int] = (3, 6, 12, 24), feature_size: int = 24, - norm_name: Union[Tuple, str] = "instance", + norm_name: tuple | str = "instance", drop_rate: float = 0.0, attn_drop_rate: float = 0.0, dropout_path_rate: float = 0.0, @@ -244,7 +246,6 @@ def __init__( self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels) def load_from(self, weights): - with torch.no_grad(): self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"]) self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"]) @@ -530,7 +531,7 @@ def __init__( attn_drop: float = 0.0, drop_path: float = 0.0, act_layer: str = "GELU", - norm_layer: Type[LayerNorm] = nn.LayerNorm, + norm_layer: type[LayerNorm] = nn.LayerNorm, use_checkpoint: bool = False, ) -> None: """ @@ -684,7 +685,7 @@ class PatchMergingV2(nn.Module): https://github.com/microsoft/Swin-Transformer """ - def __init__(self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3) -> None: + def __init__(self, dim: int, norm_layer: type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3) -> None: """ Args: dim: number of feature channels. @@ -702,7 +703,6 @@ def __init__(self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial self.norm = norm_layer(4 * dim) def forward(self, x): - x_shape = x.size() if len(x_shape) == 5: b, d, h, w, c = x_shape @@ -814,8 +814,8 @@ def __init__( qkv_bias: bool = False, drop: float = 0.0, attn_drop: float = 0.0, - norm_layer: Type[LayerNorm] = nn.LayerNorm, - downsample: Optional[nn.Module] = None, + norm_layer: type[LayerNorm] = nn.LayerNorm, + downsample: nn.Module | None = None, use_checkpoint: bool = False, ) -> None: """ @@ -916,7 +916,7 @@ def __init__( drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, - norm_layer: Type[LayerNorm] = nn.LayerNorm, + norm_layer: type[LayerNorm] = nn.LayerNorm, patch_norm: bool = False, use_checkpoint: bool = False, spatial_dims: int = 3, diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index 85103a2c04..6f63a34951 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple +from __future__ import annotations + +from typing import Any from monai.networks.nets import NetAdapter from monai.utils import optional_import @@ -100,9 +102,9 @@ def __init__( model_name: str = "resnet18", num_classes: int = 1, dim: int = 2, - in_channels: Optional[int] = None, + in_channels: int | None = None, use_conv: bool = False, - pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}), + pool: tuple[str, dict[str, Any]] | None = ("avg", {"kernel_size": 7, "stride": 1}), bias: bool = True, pretrained: bool = False, fc_name: str = "fc", diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py index f138a36d09..31e27ffbf2 100644 --- a/monai/networks/nets/transchex.py +++ b/monai/networks/nets/transchex.py @@ -9,12 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math import os import shutil import tarfile import tempfile -from typing import List, Sequence, Tuple, Union +from collections.abc import Sequence import torch from torch import nn @@ -74,7 +76,6 @@ def from_pretrained( with tarfile.open(resolved_archive_file, "r:gz") as archive: def is_within_directory(directory, target): - abs_directory = os.path.abspath(directory) abs_target = os.path.abspath(target) @@ -83,7 +84,6 @@ def is_within_directory(directory, target): return prefix == abs_directory def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - for member in tar.getmembers(): member_path = os.path.join(path, member.name) if not is_within_directory(path, member_path): @@ -115,9 +115,9 @@ def safe_extract(tar, path=".", members=None, *, numeric_owner=False): new_keys.append(new_key) for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) - missing_keys: List = [] - unexpected_keys: List = [] - error_msgs: List = [] + missing_keys: list = [] + unexpected_keys: list = [] + error_msgs: list = [] metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: @@ -276,8 +276,8 @@ class Transchex(torch.nn.Module): def __init__( self, in_channels: int, - img_size: Union[Sequence[int], int], - patch_size: Union[int, Tuple[int, int]], + img_size: Sequence[int] | int, + patch_size: int | tuple[int, int], num_classes: int, num_language_layers: int, num_vision_layers: int, diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 6db41ef8fb..a48aabf915 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Optional, Sequence, Tuple, Union +from collections.abc import Sequence import torch import torch.nn as nn @@ -18,7 +20,7 @@ from monai.networks.blocks.convolutions import Convolution, ResidualUnit from monai.networks.layers.factories import Act, Norm from monai.networks.layers.simplelayers import SkipConnection -from monai.utils import alias, deprecated_arg, export +from monai.utils import alias, export __all__ = ["UNet", "Unet"] @@ -107,9 +109,6 @@ class UNet(nn.Module): """ - @deprecated_arg( - name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." - ) def __init__( self, spatial_dims: int, @@ -117,17 +116,15 @@ def __init__( out_channels: int, channels: Sequence[int], strides: Sequence[int], - kernel_size: Union[Sequence[int], int] = 3, - up_kernel_size: Union[Sequence[int], int] = 3, + kernel_size: Sequence[int] | int = 3, + up_kernel_size: Sequence[int] | int = 3, num_res_units: int = 0, - act: Union[Tuple, str] = Act.PRELU, - norm: Union[Tuple, str] = Norm.INSTANCE, + act: tuple | str = Act.PRELU, + norm: tuple | str = Norm.INSTANCE, dropout: float = 0.0, bias: bool = True, adn_ordering: str = "NDA", - dimensions: Optional[int] = None, ) -> None: - super().__init__() if len(channels) < 2: @@ -137,14 +134,10 @@ def __init__( raise ValueError("the length of `strides` should equal to `len(channels) - 1`.") if delta > 0: warnings.warn(f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.") - if dimensions is not None: - spatial_dims = dimensions - if isinstance(kernel_size, Sequence): - if len(kernel_size) != spatial_dims: - raise ValueError("the length of `kernel_size` should equal to `dimensions`.") - if isinstance(up_kernel_size, Sequence): - if len(up_kernel_size) != spatial_dims: - raise ValueError("the length of `up_kernel_size` should equal to `dimensions`.") + if isinstance(kernel_size, Sequence) and len(kernel_size) != spatial_dims: + raise ValueError("the length of `kernel_size` should equal to `dimensions`.") + if isinstance(up_kernel_size, Sequence) and len(up_kernel_size) != spatial_dims: + raise ValueError("the length of `up_kernel_size` should equal to `dimensions`.") self.dimensions = spatial_dims self.in_channels = in_channels @@ -221,7 +214,6 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_ """ mod: nn.Module if self.num_res_units > 0: - mod = ResidualUnit( self.dimensions, in_channels, @@ -271,7 +263,7 @@ def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_to strides: convolution stride. is_top: True if this is the top block. """ - conv: Union[Convolution, nn.Sequential] + conv: Convolution | nn.Sequential conv = Convolution( self.dimensions, diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py index f95428693a..7ad12daa89 100644 --- a/monai/networks/nets/unetr.py +++ b/monai/networks/nets/unetr.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import torch.nn as nn @@ -29,13 +31,13 @@ def __init__( self, in_channels: int, out_channels: int, - img_size: Union[Sequence[int], int], + img_size: Sequence[int] | int, feature_size: int = 16, hidden_size: int = 768, mlp_dim: int = 3072, num_heads: int = 12, pos_embed: str = "conv", - norm_name: Union[Tuple, str] = "instance", + norm_name: tuple | str = "instance", conv_block: bool = True, res_block: bool = True, dropout_rate: float = 0.0, diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py index 31c2a5cfe6..6cb8d6e40b 100644 --- a/monai/networks/nets/varautoencoder.py +++ b/monai/networks/nets/varautoencoder.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from collections.abc import Sequence import numpy as np import torch @@ -56,7 +58,7 @@ class VarAutoEncoder(AutoEncoder): # 3 layer network accepting images with dimensions (1, 32, 32) and using a latent vector with 2 values model = VarAutoEncoder( - dimensions=2, + spatial_dims=2, in_shape=(32, 32), # image spatial shape out_channels=1, latent_size=2, @@ -77,19 +79,18 @@ def __init__( latent_size: int, channels: Sequence[int], strides: Sequence[int], - kernel_size: Union[Sequence[int], int] = 3, - up_kernel_size: Union[Sequence[int], int] = 3, + kernel_size: Sequence[int] | int = 3, + up_kernel_size: Sequence[int] | int = 3, num_res_units: int = 0, - inter_channels: Optional[list] = None, - inter_dilations: Optional[list] = None, + inter_channels: list | None = None, + inter_dilations: list | None = None, num_inter_units: int = 2, - act: Optional[Union[Tuple, str]] = Act.PRELU, - norm: Union[Tuple, str] = Norm.INSTANCE, - dropout: Optional[Union[Tuple, str, float]] = None, + act: tuple | str | None = Act.PRELU, + norm: tuple | str = Norm.INSTANCE, + dropout: tuple | str | float | None = None, bias: bool = True, use_sigmoid: bool = True, ) -> None: - self.in_channels, *self.in_shape = in_shape self.use_sigmoid = use_sigmoid @@ -124,7 +125,7 @@ def __init__( self.logvar = nn.Linear(linear_size, self.latent_size) self.decodeL = nn.Linear(self.latent_size, linear_size) - def encode_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def encode_forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x = self.encode(x) x = self.intermediate(x) x = x.view(x.shape[0], -1) @@ -148,7 +149,7 @@ def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor return std.add_(mu) - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: mu, logvar = self.encode_forward(x) z = self.reparameterize(mu, logvar) return self.decode_forward(z, self.use_sigmoid), mu, logvar, z diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index e4166c78b6..f3896d76c4 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -31,8 +33,8 @@ class ViT(nn.Module): def __init__( self, in_channels: int, - img_size: Union[Sequence[int], int], - patch_size: Union[Sequence[int], int], + img_size: Sequence[int] | int, + patch_size: Sequence[int] | int, hidden_size: int = 768, mlp_dim: int = 3072, num_layers: int = 12, diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 6197f6bd99..ff6f637118 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence import torch import torch.nn as nn @@ -33,8 +35,8 @@ class ViTAutoEnc(nn.Module): def __init__( self, in_channels: int, - img_size: Union[Sequence[int], int], - patch_size: Union[Sequence[int], int], + img_size: Sequence[int] | int, + patch_size: Sequence[int] | int, out_channels: int = 1, deconv_chns: int = 16, hidden_size: int = 768, diff --git a/monai/networks/nets/vnet.py b/monai/networks/nets/vnet.py index 9abd2bc5e2..697547093a 100644 --- a/monai/networks/nets/vnet.py +++ b/monai/networks/nets/vnet.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple, Type, Union +from __future__ import annotations import torch import torch.nn as nn @@ -20,7 +20,7 @@ __all__ = ["VNet"] -def get_acti_layer(act: Union[Tuple[str, Dict], str], nchan: int = 0): +def get_acti_layer(act: tuple[str, dict] | str, nchan: int = 0): if act == "prelu": act = ("prelu", {"num_parameters": nchan}) act_name, act_args = split_args(act) @@ -29,7 +29,7 @@ def get_acti_layer(act: Union[Tuple[str, Dict], str], nchan: int = 0): class LUConv(nn.Module): - def __init__(self, spatial_dims: int, nchan: int, act: Union[Tuple[str, Dict], str], bias: bool = False): + def __init__(self, spatial_dims: int, nchan: int, act: tuple[str, dict] | str, bias: bool = False): super().__init__() self.act_function = get_acti_layer(act, nchan) @@ -49,7 +49,7 @@ def forward(self, x): return out -def _make_nconv(spatial_dims: int, nchan: int, depth: int, act: Union[Tuple[str, Dict], str], bias: bool = False): +def _make_nconv(spatial_dims: int, nchan: int, depth: int, act: tuple[str, dict] | str, bias: bool = False): layers = [] for _ in range(depth): layers.append(LUConv(spatial_dims, nchan, act, bias)) @@ -58,12 +58,7 @@ def _make_nconv(spatial_dims: int, nchan: int, depth: int, act: Union[Tuple[str, class InputTransition(nn.Module): def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - act: Union[Tuple[str, Dict], str], - bias: bool = False, + self, spatial_dims: int, in_channels: int, out_channels: int, act: tuple[str, dict] | str, bias: bool = False ): super().__init__() @@ -100,16 +95,16 @@ def __init__( spatial_dims: int, in_channels: int, nconvs: int, - act: Union[Tuple[str, Dict], str], - dropout_prob: Optional[float] = None, + act: tuple[str, dict] | str, + dropout_prob: float | None = None, dropout_dim: int = 3, bias: bool = False, ): super().__init__() - conv_type: Type[Union[nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] - norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] - dropout_type: Type[Union[nn.Dropout, nn.Dropout2d, nn.Dropout3d]] = Dropout[Dropout.DROPOUT, dropout_dim] + conv_type: type[nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims] + norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] + dropout_type: type[nn.Dropout | nn.Dropout2d | nn.Dropout3d] = Dropout[Dropout.DROPOUT, dropout_dim] out_channels = 2 * in_channels self.down_conv = conv_type(in_channels, out_channels, kernel_size=2, stride=2, bias=bias) @@ -137,15 +132,15 @@ def __init__( in_channels: int, out_channels: int, nconvs: int, - act: Union[Tuple[str, Dict], str], - dropout_prob: Optional[float] = None, + act: tuple[str, dict] | str, + dropout_prob: float | None = None, dropout_dim: int = 3, ): super().__init__() - conv_trans_type: Type[Union[nn.ConvTranspose2d, nn.ConvTranspose3d]] = Conv[Conv.CONVTRANS, spatial_dims] - norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] - dropout_type: Type[Union[nn.Dropout, nn.Dropout2d, nn.Dropout3d]] = Dropout[Dropout.DROPOUT, dropout_dim] + conv_trans_type: type[nn.ConvTranspose2d | nn.ConvTranspose3d] = Conv[Conv.CONVTRANS, spatial_dims] + norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] + dropout_type: type[nn.Dropout | nn.Dropout2d | nn.Dropout3d] = Dropout[Dropout.DROPOUT, dropout_dim] self.up_conv = conv_trans_type(in_channels, out_channels // 2, kernel_size=2, stride=2) self.bn1 = norm_type(out_channels // 2) @@ -170,16 +165,11 @@ def forward(self, x, skipx): class OutputTransition(nn.Module): def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - act: Union[Tuple[str, Dict], str], - bias: bool = False, + self, spatial_dims: int, in_channels: int, out_channels: int, act: tuple[str, dict] | str, bias: bool = False ): super().__init__() - conv_type: Type[Union[nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] + conv_type: type[nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims] self.act_function1 = get_acti_layer(act, out_channels) self.conv_block = Convolution( @@ -233,7 +223,7 @@ def __init__( spatial_dims: int = 3, in_channels: int = 1, out_channels: int = 1, - act: Union[Tuple[str, Dict], str] = ("elu", {"inplace": True}), + act: tuple[str, dict] | str = ("elu", {"inplace": True}), dropout_prob: float = 0.5, dropout_dim: int = 3, bias: bool = False, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index b58a41cc0c..0c8fbe8c28 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -11,26 +11,29 @@ """ Utilities and types for defining networks, these depend on PyTorch. """ + +from __future__ import annotations + import re import warnings from collections import OrderedDict +from collections.abc import Callable, Mapping, Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any +import numpy as np import torch import torch.nn as nn from monai.apps.utils import get_logger from monai.config import PathLike -from monai.utils.deprecate_utils import deprecated from monai.utils.misc import ensure_tuple, save_obj, set_determinism from monai.utils.module import look_up_option, pytorch_after -from monai.utils.type_conversion import convert_to_tensor +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor __all__ = [ "one_hot", - "slice_channels", "predict_segmentation", "normalize_transform", "to_norm_affine", @@ -161,19 +164,6 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f return labels -@deprecated(since="0.8.0", msg_suffix="use `monai.utils.misc.sample_slices` instead.") -def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Tensor: - """ - .. deprecated:: 0.8.0 - Use `monai.utils.misc.sample_slices` instead. - - """ - slices = [slice(None)] * len(tensor.shape) - slices[1] = slice(*slicevals) - - return tensor[slices] - - def predict_segmentation(logits: torch.Tensor, mutually_exclusive: bool = False, threshold: float = 0.0) -> Any: """ Given the logits from a network, computing the segmentation by thresholding all values above 0 @@ -196,8 +186,8 @@ def predict_segmentation(logits: torch.Tensor, mutually_exclusive: bool = False, def normalize_transform( shape, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, align_corners: bool = False, zero_centered: bool = False, ) -> torch.Tensor: @@ -208,8 +198,8 @@ def normalize_transform( - `align_corners=False`, `zero_centered=False`, normalizing from ``[-0.5, d-0.5]``. - `align_corners=True`, `zero_centered=False`, normalizing from ``[0, d-1]``. - - `align_corners=False`, `zero_centered=True`, normalizing from ``[-(d+1)/2, (d-1)/2]``. - - `align_corners=True`, `zero_centered=True`, normalizing from ``[-(d-1)/2, (d-1)/2]``. + - `align_corners=False`, `zero_centered=True`, normalizing from ``[-(d-1)/2, (d-1)/2]``. + - `align_corners=True`, `zero_centered=True`, normalizing from ``[-d/2, d/2]``. Args: shape: input spatial shape, a sequence of integers. @@ -225,15 +215,16 @@ def normalize_transform( norm = shape.clone().detach().to(dtype=torch.float64, device=device) # no in-place change if align_corners: norm[norm <= 1.0] = 2.0 - norm = 2.0 / (norm - 1.0) + norm = 2.0 / (norm if zero_centered else norm - 1.0) norm = torch.diag(torch.cat((norm, torch.ones((1,), dtype=torch.float64, device=device)))) if not zero_centered: # else shift is 0 norm[:-1, -1] = -1.0 else: norm[norm <= 0.0] = 2.0 - norm = 2.0 / norm + norm = 2.0 / (norm - 1.0 if zero_centered else norm) norm = torch.diag(torch.cat((norm, torch.ones((1,), dtype=torch.float64, device=device)))) - norm[:-1, -1] = 1.0 / shape - (0.0 if zero_centered else 1.0) + if not zero_centered: + norm[:-1, -1] = 1.0 / shape - 1.0 norm = norm.unsqueeze(0).to(dtype=dtype) norm.requires_grad = False return norm # type: ignore @@ -275,8 +266,8 @@ def to_norm_affine( raise ValueError(f"affine suggests {sr}D, got src={len(src_size)}D, dst={len(dst_size)}D.") src_xform = normalize_transform(src_size, affine.device, affine.dtype, align_corners, zero_centered) - dst_xform = normalize_transform(dst_size, affine.device, affine.dtype, align_corners, zero_centered) - return src_xform @ affine @ torch.inverse(dst_xform) + dst_xform = normalize_transform(dst_size, "cpu", affine.dtype, align_corners, zero_centered) + return src_xform @ affine @ convert_to_dst_type(np.linalg.inv(dst_xform.numpy()), dst=affine)[0] # monai#5983 def normal_init( @@ -385,17 +376,19 @@ def eval_mode(*nets: nn.Module): print(p(t).sum().backward()) # will correctly raise an exception as gradients are calculated """ - # Get original state of network(s) - training = [n for n in nets if n.training] + # Get original state of network(s). + # Check the training attribute in case it's TensorRT based models which don't have this attribute. + training = [n for n in nets if hasattr(n, "training") and n.training] try: # set to eval mode with torch.no_grad(): - yield [n.eval() for n in nets] + yield [n.eval() if hasattr(n, "eval") else n for n in nets] finally: # Return required networks to training for n in training: - n.train() + if hasattr(n, "train"): + n.train() @contextmanager @@ -420,19 +413,21 @@ def train_mode(*nets: nn.Module): """ # Get original state of network(s) - eval_list = [n for n in nets if not n.training] + # Check the training attribute in case it's TensorRT based models which don't have this attribute. + eval_list = [n for n in nets if hasattr(n, "training") and (not n.training)] try: # set to train mode with torch.set_grad_enabled(True): - yield [n.train() for n in nets] + yield [n.train() if hasattr(n, "train") else n for n in nets] finally: # Return required networks to eval_list for n in eval_list: - n.eval() + if hasattr(n, "eval"): + n.eval() -def get_state_dict(obj: Union[torch.nn.Module, Mapping]): +def get_state_dict(obj: torch.nn.Module | Mapping): """ Get the state dict of input object if has `state_dict`, otherwise, return object directly. For data parallel model, automatically convert it to regular model first. @@ -447,8 +442,8 @@ def get_state_dict(obj: Union[torch.nn.Module, Mapping]): def copy_model_state( - dst: Union[torch.nn.Module, Mapping], - src: Union[torch.nn.Module, Mapping], + dst: torch.nn.Module | Mapping, + src: torch.nn.Module | Mapping, dst_prefix="", mapping=None, exclude_vars=None, @@ -522,7 +517,7 @@ def copy_model_state( return dst_dict, updated_keys, unchanged_keys -def save_state(src: Union[torch.nn.Module, Dict], path: PathLike, **kwargs): +def save_state(src: torch.nn.Module | dict, path: PathLike, **kwargs): """ Save the state dict of input source data with PyTorch `save`. It can save `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`. @@ -546,7 +541,7 @@ def save_state(src: Union[torch.nn.Module, Dict], path: PathLike, **kwargs): """ - ckpt: Dict = {} + ckpt: dict = {} if isinstance(src, dict): for k, v in src.items(): ckpt[k] = get_state_dict(v) @@ -558,11 +553,11 @@ def save_state(src: Union[torch.nn.Module, Dict], path: PathLike, **kwargs): def convert_to_torchscript( model: nn.Module, - filename_or_obj: Optional[Any] = None, - extra_files: Optional[Dict] = None, + filename_or_obj: Any | None = None, + extra_files: dict | None = None, verify: bool = False, - inputs: Optional[Sequence[Any]] = None, - device: Optional[torch.device] = None, + inputs: Sequence[Any] | None = None, + device: torch.device | None = None, rtol: float = 1e-4, atol: float = 0.0, **kwargs, @@ -615,7 +610,7 @@ def convert_to_torchscript( for r1, r2 in zip(torch_out, torchscript_out): if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor): assert_fn = torch.testing.assert_close if pytorch_after(1, 11) else torch.testing.assert_allclose - assert_fn(r1, r2, rtol=rtol, atol=atol) + assert_fn(r1, r2, rtol=rtol, atol=atol) # type: ignore return script_module @@ -638,7 +633,7 @@ def _replace_modules( parent: torch.nn.Module, name: str, new_module: torch.nn.Module, - out: List[Tuple[str, torch.nn.Module]], + out: list[tuple[str, torch.nn.Module]], strict_match: bool = True, match_device: bool = True, ) -> None: @@ -656,7 +651,7 @@ def _replace_modules( parent_name = name[:idx] parent = getattr(parent, parent_name) name = name[idx + 1 :] - _out: List[Tuple[str, torch.nn.Module]] = [] + _out: list[tuple[str, torch.nn.Module]] = [] _replace_modules(parent, name, new_module, _out) # prepend the parent name out += [(f"{parent_name}.{r[0]}", r[1]) for r in _out] @@ -678,7 +673,7 @@ def replace_modules( new_module: torch.nn.Module, strict_match: bool = True, match_device: bool = True, -) -> List[Tuple[str, torch.nn.Module]]: +) -> list[tuple[str, torch.nn.Module]]: """ Replace sub-module(s) in a parent module. @@ -704,7 +699,7 @@ def replace_modules( Raises: AttributeError: if `strict_match` is `True` and `name` is not a named module in `parent`. """ - out: List[Tuple[str, torch.nn.Module]] = [] + out: list[tuple[str, torch.nn.Module]] = [] _replace_modules(parent, name, new_module, out, strict_match, match_device) return out @@ -722,7 +717,7 @@ def replace_modules_temp( See :py:class:`monai.networks.utils.replace_modules`. """ - replaced: List[Tuple[str, torch.nn.Module]] = [] + replaced: list[tuple[str, torch.nn.Module]] = [] try: # replace _replace_modules(parent, name, new_module, replaced, strict_match, match_device) diff --git a/monai/optimizers/__init__.py b/monai/optimizers/__init__.py index 8ce5d3f925..f0a3858ced 100644 --- a/monai/optimizers/__init__.py +++ b/monai/optimizers/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .lr_finder import LearningRateFinder from .lr_scheduler import ExponentialLR, LinearLR, WarmupCosineSchedule from .novograd import Novograd diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index ce092d33ab..3e7776c72f 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -9,10 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pickle +import types import warnings from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Callable import numpy as np import torch @@ -181,11 +184,11 @@ def __init__( model: nn.Module, optimizer: Optimizer, criterion: torch.nn.Module, - device: Optional[Union[str, torch.device]] = None, + device: str | torch.device | None = None, memory_cache: bool = True, - cache_dir: Optional[str] = None, + cache_dir: str | None = None, amp: bool = False, - pickle_module=pickle, + pickle_module: types.ModuleType = pickle, pickle_protocol: int = DEFAULT_PROTOCOL, verbose: bool = True, ) -> None: @@ -222,7 +225,7 @@ def __init__( self.model = model self.criterion = criterion - self.history: Dict[str, list] = {"lr": [], "loss": []} + self.history: dict[str, list] = {"lr": [], "loss": []} self.memory_cache = memory_cache self.cache_dir = cache_dir self.amp = amp @@ -250,10 +253,10 @@ def reset(self) -> None: def range_test( self, train_loader: DataLoader, - val_loader: Optional[DataLoader] = None, + val_loader: DataLoader | None = None, image_extractor: Callable = default_image_extractor, label_extractor: Callable = default_label_extractor, - start_lr: Optional[float] = None, + start_lr: float | None = None, end_lr: int = 10, num_iter: int = 100, step_mode: str = "exp", @@ -311,7 +314,7 @@ def range_test( raise ValueError("`num_iter` must be larger than 1") # Initialize the proper learning rate policy - lr_schedule: Union[ExponentialLR, LinearLR] + lr_schedule: ExponentialLR | LinearLR if step_mode.lower() == "exp": lr_schedule = ExponentialLR(self.optimizer, end_lr, num_iter) elif step_mode.lower() == "linear": @@ -327,7 +330,7 @@ def range_test( if val_loader: val_iter = ValDataLoaderIter(val_loader, image_extractor, label_extractor) - trange: Union[partial[tqdm.trange], Type[range]] + trange: partial[tqdm.trange] | type[range] if self.verbose and has_tqdm: trange = partial(tqdm.trange, desc="Computing optimal learning rate") tprint = tqdm.tqdm.write @@ -369,7 +372,7 @@ def range_test( print("Resetting model and optimizer") self.reset() - def _set_learning_rate(self, new_lrs: Union[float, list]) -> None: + def _set_learning_rate(self, new_lrs: float | list) -> None: """Set learning rate(s) for optimizer.""" if not isinstance(new_lrs, list): new_lrs = [new_lrs] * len(self.optimizer.param_groups) @@ -387,7 +390,9 @@ def _check_for_scheduler(self): if "initial_lr" in param_group: raise RuntimeError("Optimizer already has a scheduler attached to it") - def _train_batch(self, train_iter, accumulation_steps: int, non_blocking_transfer: bool = True) -> float: + def _train_batch( + self, train_iter: TrainDataLoaderIter, accumulation_steps: int, non_blocking_transfer: bool = True + ) -> float: self.model.train() total_loss = 0 @@ -437,7 +442,7 @@ def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = T return running_loss / len(val_iter.dataset) - def get_lrs_and_losses(self, skip_start: int = 0, skip_end: int = 0) -> Tuple[list, list]: + def get_lrs_and_losses(self, skip_start: int = 0, skip_end: int = 0) -> tuple[list, list]: """Get learning rates and their corresponding losses Args: @@ -457,9 +462,7 @@ def get_lrs_and_losses(self, skip_start: int = 0, skip_end: int = 0) -> Tuple[li return lrs, losses - def get_steepest_gradient( - self, skip_start: int = 0, skip_end: int = 0 - ) -> Union[Tuple[float, float], Tuple[None, None]]: + def get_steepest_gradient(self, skip_start: int = 0, skip_end: int = 0) -> tuple[float, float] | tuple[None, None]: """Get learning rate which has steepest gradient and its corresponding loss Args: @@ -478,7 +481,14 @@ def get_steepest_gradient( print("Failed to compute the gradients, there might not be enough points.") return None, None - def plot(self, skip_start: int = 0, skip_end: int = 0, log_lr: bool = True, ax=None, steepest_lr: bool = True): + def plot( + self, + skip_start: int = 0, + skip_end: int = 0, + log_lr: bool = True, + ax: Any | None = None, + steepest_lr: bool = True, + ) -> Any | None: """Plots the learning rate range test. Args: diff --git a/monai/optimizers/lr_scheduler.py b/monai/optimizers/lr_scheduler.py index dc76a3dda1..b056e06a01 100644 --- a/monai/optimizers/lr_scheduler.py +++ b/monai/optimizers/lr_scheduler.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math from torch.optim import Optimizer diff --git a/monai/optimizers/novograd.py b/monai/optimizers/novograd.py index 07a6aff90a..6675f6ef85 100644 --- a/monai/optimizers/novograd.py +++ b/monai/optimizers/novograd.py @@ -9,11 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Iterable, Optional, Tuple +from __future__ import annotations + +from collections.abc import Callable, Iterable +from typing import TypeVar import torch from torch.optim import Optimizer +T = TypeVar("T") + class Novograd(Optimizer): """ @@ -38,7 +43,7 @@ def __init__( self, params: Iterable, lr: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.98), + betas: tuple[float, float] = (0.9, 0.98), eps: float = 1e-8, weight_decay: float = 0, grad_averaging: bool = False, @@ -65,7 +70,7 @@ def __setstate__(self, state): for group in self.param_groups: group.setdefault("amsgrad", False) - def step(self, closure: Optional[Callable] = None): + def step(self, closure: Callable[[], T] | None = None) -> T | None: """Performs a single optimization step. Arguments: diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py index 5444aca191..7e566abb46 100644 --- a/monai/optimizers/utils.py +++ b/monai/optimizers/utils.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Sequence +from __future__ import annotations + +from collections.abc import Callable, Sequence import torch @@ -24,7 +26,7 @@ def generate_param_groups( match_types: Sequence[str], lr_values: Sequence[float], include_others: bool = True, -): +) -> list[dict]: """ Utility function to generate parameter groups with different LR values for optimizer. The output parameter groups have the same order as `layer_match` functions. diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ba8702ebd9..940485cbe0 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs from .compose import Compose, OneOf, RandomOrder from .croppad.array import ( @@ -87,6 +89,7 @@ SpatialPadD, SpatialPadDict, ) +from .croppad.functional import crop_func, crop_or_pad_nd, pad_func, pad_nd from .intensity.array import ( AdjustContrast, ComputeHoVerMaps, @@ -451,18 +454,9 @@ ZoomD, ZoomDict, ) -from .transform import ( - LazyTrait, - LazyTransform, - MapTransform, - MultiSampleTrait, - Randomizable, - RandomizableTrait, - RandomizableTransform, - ThreadUnsafe, - Transform, - apply_transform, -) +from .spatial.functional import spatial_resample +from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ThreadUnsafe +from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform from .utility.array import ( AddChannel, AddCoordinateChannels, @@ -478,11 +472,14 @@ EnsureType, FgBgToIndices, Identity, + ImageFilter, IntensityStats, LabelToMask, Lambda, MapLabelValue, RandCuCIM, + RandIdentity, + RandImageFilter, RandLambda, RemoveRepeatedChannel, RepeatChannel, @@ -553,6 +550,9 @@ Identityd, IdentityD, IdentityDict, + ImageFilterd, + ImageFilterD, + ImageFilterDict, IntensityStatsd, IntensityStatsD, IntensityStatsDict, @@ -568,6 +568,9 @@ RandCuCIMd, RandCuCIMD, RandCuCIMDict, + RandImageFilterd, + RandImageFilterD, + RandImageFilterDict, RandLambdad, RandLambdaD, RandLambdaDict, diff --git a/monai/transforms/adaptors.py b/monai/transforms/adaptors.py index 1edbcc63e2..5729740690 100644 --- a/monai/transforms/adaptors.py +++ b/monai/transforms/adaptors.py @@ -121,6 +121,8 @@ def __call__(self, img, seg): """ +from __future__ import annotations + from typing import Callable from monai.utils import export as _monai_export @@ -146,7 +148,6 @@ def map_only_names(ditems, input_map): return {v: ditems[k] for k, v in input_map.items()} def _inner(ditems): - sig = FunctionSignature(function) if sig.found_kwargs: @@ -216,7 +217,6 @@ def _inner(ditems): @_monai_export("monai.transforms") def apply_alias(fn, name_map): def _inner(data): - # map names pre_call = dict(data) for _from, _to in name_map.items(): diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 0ce4433218..45e706e143 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -12,8 +12,11 @@ A collection of generic interfaces for MONAI transforms. """ +from __future__ import annotations + import warnings -from typing import Any, Callable, Mapping, Optional, Sequence, Union +from collections.abc import Callable, Mapping, Sequence +from typing import Any import numpy as np @@ -116,7 +119,7 @@ class Compose(Randomizable, InvertibleTransform): def __init__( self, - transforms: Optional[Union[Sequence[Callable], Callable]] = None, + transforms: Sequence[Callable] | Callable | None = None, map_items: bool = True, unpack_items: bool = False, log_stats: bool = False, @@ -129,7 +132,7 @@ def __init__( self.log_stats = log_stats self.set_random_state(seed=get_seed()) - def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Compose: super().set_random_state(seed=seed, state=state) for _transform in self.transforms: if not isinstance(_transform, Randomizable): @@ -137,7 +140,7 @@ def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random _transform.set_random_state(seed=self.R.randint(MAX_SEED, dtype="uint32")) return self - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: for _transform in self.transforms: if not isinstance(_transform, Randomizable): continue @@ -206,8 +209,8 @@ class OneOf(Compose): def __init__( self, - transforms: Optional[Union[Sequence[Callable], Callable]] = None, - weights: Optional[Union[Sequence[float], float]] = None, + transforms: Sequence[Callable] | Callable | None = None, + weights: Sequence[float] | float | None = None, map_items: bool = True, unpack_items: bool = False, log_stats: bool = False, @@ -218,7 +221,10 @@ def __init__( elif weights is None or isinstance(weights, float): weights = [1.0 / len(self.transforms)] * len(self.transforms) if len(weights) != len(self.transforms): - raise AssertionError("transforms and weights should be same size if both specified as sequences.") + raise ValueError( + "transforms and weights should be same size if both specified as sequences, " + f"got {len(weights)} and {len(self.transforms)}." + ) self.weights = ensure_tuple(self._normalize_probabilities(weights)) def _normalize_probabilities(self, weights): @@ -226,9 +232,9 @@ def _normalize_probabilities(self, weights): return weights weights = np.array(weights) if np.any(weights < 0): - raise AssertionError("Probabilities must be greater than or equal to zero.") + raise ValueError(f"Probabilities must be greater than or equal to zero, got {weights}.") if np.all(weights == 0): - raise AssertionError("At least one probability must be greater than zero.") + raise ValueError(f"At least one probability must be greater than zero, got {weights}.") weights = weights / weights.sum() return list(weights) @@ -275,7 +281,9 @@ def inverse(self, data): if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data: index = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] else: - raise RuntimeError("Inverse only implemented for Mapping (dictionary) or MetaTensor data.") + raise RuntimeError( + f"Inverse only implemented for Mapping (dictionary) or MetaTensor data, got type {type(data)}." + ) if index is None: # no invertible transforms have been applied return data @@ -303,7 +311,7 @@ class RandomOrder(Compose): def __init__( self, - transforms: Optional[Union[Sequence[Callable], Callable]] = None, + transforms: Sequence[Callable] | Callable | None = None, map_items: bool = True, unpack_items: bool = False, log_stats: bool = False, @@ -339,7 +347,9 @@ def inverse(self, data): if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data: applied_order = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["applied_order"] else: - raise RuntimeError("Inverse only implemented for Mapping (dictionary) or MetaTensor data.") + raise RuntimeError( + f"Inverse only implemented for Mapping (dictionary) or MetaTensor data, got type {type(data)}." + ) if applied_order is None: # no invertible transforms have been applied return data diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index a2ee02c553..aa13d54c51 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -13,25 +13,28 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +from __future__ import annotations + +import warnings +from collections.abc import Callable, Sequence from itertools import chain from math import ceil -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any import numpy as np import torch -from torch.nn.functional import pad as pad_pt from monai.config import IndexSelection from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import get_random_patch, get_valid_patch_size +from monai.transforms.croppad.functional import crop_func, pad_func from monai.transforms.inverse import InvertibleTransform, TraceableTransform -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.traits import MultiSampleTrait +from monai.transforms.transform import LazyTransform, Randomizable, Transform from monai.transforms.utils import ( compute_divisible_spatial_size, - convert_pad_mode, - create_translate, generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -42,13 +45,14 @@ ) from monai.utils import ImageMetaKey as Key from monai.utils import ( + LazyAttr, Method, PytorchPadMode, TraceKeys, TransformBackends, convert_data_type, - convert_to_dst_type, convert_to_tensor, + deprecated_arg_default, ensure_tuple, ensure_tuple_rep, fall_back_tuple, @@ -77,7 +81,7 @@ ] -class Pad(InvertibleTransform): +class Pad(InvertibleTransform, LazyTransform): """ Perform padding for a given an amount of padding in each dimension. @@ -102,13 +106,13 @@ class Pad(InvertibleTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, to_pad: Optional[List[Tuple[int, int]]] = None, mode: str = PytorchPadMode.CONSTANT, **kwargs + self, to_pad: tuple[tuple[int, int]] | None = None, mode: str = PytorchPadMode.CONSTANT, **kwargs ) -> None: self.to_pad = to_pad self.mode = mode self.kwargs = kwargs - def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: """ dynamically compute the pad width according to the spatial shape. the output is the amount of padding for all dimensions including the channel. @@ -119,26 +123,8 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - @staticmethod - def _np_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: - img_np = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img - mode = convert_pad_mode(dst=img_np, mode=mode).value - if mode == "constant" and "value" in kwargs: - val = kwargs.pop("value") - kwargs["constant_values"] = val - out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs)) - if isinstance(img, MetaTensor): - out = convert_to_dst_type(out, dst=img)[0] - return out - - @staticmethod - def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: - pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1] - # torch.pad expects `[B, C, H, W, [D]]` shape - return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) - - def __call__( # type: ignore - self, img: torch.Tensor, to_pad: Optional[List[Tuple[int, int]]] = None, mode: Optional[str] = None, **kwargs + def __call__( # type: ignore[override] + self, img: torch.Tensor, to_pad: tuple[tuple[int, int]] | None = None, mode: str | None = None, **kwargs ) -> torch.Tensor: """ Args: @@ -157,52 +143,14 @@ def __call__( # type: ignore """ to_pad_ = self.to_pad if to_pad is None else to_pad if to_pad_ is None: - to_pad_ = self.compute_pad_width(img.shape[1:]) + spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + to_pad_ = self.compute_pad_width(spatial_shape) mode_ = self.mode if mode is None else mode kwargs_ = dict(self.kwargs) kwargs_.update(kwargs) img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) - _orig_size = img_t.shape[1:] - - # all zeros, skip padding - if np.asarray(to_pad_).any(): - to_pad_ = list(to_pad_) - if len(to_pad_) < len(img_t.shape): - to_pad_ = list(to_pad_) + [(0, 0)] * (len(img_t.shape) - len(to_pad_)) - if mode_ in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}: - out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) - else: - mode_ = convert_pad_mode(dst=img_t, mode=mode_).value - try: - _pad = ( - self._pt_pad - if mode_ in {"reflect", "replicate"} - and img_t.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8} - else self._np_pad - ) - out = _pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) - except (ValueError, TypeError, RuntimeError) as err: - if isinstance(err, NotImplementedError) or any( - k in str(err) for k in ("supported", "unexpected keyword", "implemented") - ): - out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) - else: - raise ValueError( - f"{img_t.shape} {to_pad_} {mode_} {kwargs_} {img_t.dtype} {img_t.device}" - ) from err - else: - out = img_t - if get_track_meta(): - self.update_meta(tensor=out, to_pad=to_pad_) # type: ignore - self.push_transform(out, orig_size=_orig_size, extra_info={"padded": to_pad_}) - return out - - def update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]): - spatial_rank = max(len(tensor.affine) - 1, 1) - to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad - mat = create_translate(spatial_rank, to_shift) - tensor.affine = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] + return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_) def inverse(self, data: MetaTensor) -> MetaTensor: transform = self.pop_transform(data) @@ -243,7 +191,7 @@ class SpatialPad(Pad): def __init__( self, - spatial_size: Union[Sequence[int], int, Tuple[Union[Tuple[int, ...], int], ...]], + spatial_size: Sequence[int] | int | tuple[tuple[int, ...] | int, ...], method: str = Method.SYMMETRIC, mode: str = PytorchPadMode.CONSTANT, **kwargs, @@ -252,7 +200,7 @@ def __init__( self.method: Method = look_up_option(method, Method) super().__init__(mode=mode, **kwargs) - def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: """ dynamically compute the pad width according to the spatial shape. @@ -265,10 +213,10 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int pad_width = [] for i, sp_i in enumerate(spatial_size): width = max(sp_i - spatial_shape[i], 0) - pad_width.append((width // 2, width - (width // 2))) + pad_width.append((int(width // 2), int(width - (width // 2)))) else: - pad_width = [(0, max(sp_i - spatial_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] - return [(0, 0)] + pad_width + pad_width = [(0, int(max(sp_i - spatial_shape[i], 0))) for i, sp_i in enumerate(spatial_size)] + return tuple([(0, 0)] + pad_width) # type: ignore class BorderPad(Pad): @@ -297,30 +245,30 @@ class BorderPad(Pad): """ - def __init__( - self, spatial_border: Union[Sequence[int], int], mode: str = PytorchPadMode.CONSTANT, **kwargs - ) -> None: + def __init__(self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMode.CONSTANT, **kwargs) -> None: self.spatial_border = spatial_border super().__init__(mode=mode, **kwargs) - def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: spatial_border = ensure_tuple(self.spatial_border) if not all(isinstance(b, int) for b in spatial_border): raise ValueError(f"self.spatial_border must contain only ints, got {spatial_border}.") spatial_border = tuple(max(0, b) for b in spatial_border) if len(spatial_border) == 1: - data_pad_width = [(spatial_border[0], spatial_border[0]) for _ in spatial_shape] + data_pad_width = [(int(spatial_border[0]), int(spatial_border[0])) for _ in spatial_shape] elif len(spatial_border) == len(spatial_shape): - data_pad_width = [(sp, sp) for sp in spatial_border[: len(spatial_shape)]] + data_pad_width = [(int(sp), int(sp)) for sp in spatial_border[: len(spatial_shape)]] elif len(spatial_border) == len(spatial_shape) * 2: - data_pad_width = [(spatial_border[2 * i], spatial_border[2 * i + 1]) for i in range(len(spatial_shape))] + data_pad_width = [ + (int(spatial_border[2 * i]), int(spatial_border[2 * i + 1])) for i in range(len(spatial_shape)) + ] else: raise ValueError( f"Unsupported spatial_border length: {len(spatial_border)}, available options are " f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]." ) - return [(0, 0)] + data_pad_width + return tuple([(0, 0)] + data_pad_width) # type: ignore class DivisiblePad(Pad): @@ -331,11 +279,7 @@ class DivisiblePad(Pad): backend = SpatialPad.backend def __init__( - self, - k: Union[Sequence[int], int], - mode: str = PytorchPadMode.CONSTANT, - method: str = Method.SYMMETRIC, - **kwargs, + self, k: Sequence[int] | int, mode: str = PytorchPadMode.CONSTANT, method: str = Method.SYMMETRIC, **kwargs ) -> None: """ Args: @@ -359,13 +303,13 @@ def __init__( self.method: Method = Method(method) super().__init__(mode=mode, **kwargs) - def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int]]: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: new_size = compute_divisible_spatial_size(spatial_shape=spatial_shape, k=self.k) spatial_pad = SpatialPad(spatial_size=new_size, method=self.method) return spatial_pad.compute_pad_width(spatial_shape) -class Crop(InvertibleTransform): +class Crop(InvertibleTransform, LazyTransform): """ Perform crop operations on the input image. @@ -375,12 +319,12 @@ class Crop(InvertibleTransform): @staticmethod def compute_slices( - roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_slices: Optional[Sequence[slice]] = None, - ): + roi_center: Sequence[int] | NdarrayOrTensor | None = None, + roi_size: Sequence[int] | NdarrayOrTensor | None = None, + roi_start: Sequence[int] | NdarrayOrTensor | None = None, + roi_end: Sequence[int] | NdarrayOrTensor | None = None, + roi_slices: Sequence[slice] | None = None, + ) -> tuple[slice]: """ Compute the crop slices based on specified `center & size` or `start & end` or `slices`. @@ -398,8 +342,8 @@ def compute_slices( if roi_slices: if not all(s.step is None or s.step == 1 for s in roi_slices): - raise ValueError("only slice steps of 1/None are currently supported") - return list(roi_slices) + raise ValueError(f"only slice steps of 1/None are currently supported, got {roi_slices}.") + return ensure_tuple(roi_slices) # type: ignore else: if roi_center is not None and roi_size is not None: roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu") @@ -421,40 +365,26 @@ def compute_slices( roi_end_t = torch.maximum(roi_end_t, roi_start_t) # convert to slices (accounting for 1d) if roi_start_t.numel() == 1: - return [slice(int(roi_start_t.item()), int(roi_end_t.item()))] - else: - return [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] + return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))]) # type: ignore + return ensure_tuple( # type: ignore + [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] + ) - def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore[override] """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ - orig_size = img.shape[1:] slices_ = list(slices) - sd = len(img.shape[1:]) # spatial dims + sd = len(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]) # spatial dims if len(slices_) < sd: slices_ += [slice(None)] * (sd - len(slices_)) # Add in the channel (no cropping) - slices = tuple([slice(None)] + slices_[:sd]) + slices_ = list([slice(None)] + slices_[:sd]) img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta()) - _orig_size = img_t.shape[1:] - img_t = img_t[slices] # type: ignore - if get_track_meta(): - self.update_meta(tensor=img_t, slices=slices) - cropped_from_start = np.asarray([s.indices(o)[0] for s, o in zip(slices[1:], orig_size)]) - cropped_from_end = np.asarray(orig_size) - img_t.shape[1:] - cropped_from_start - cropped = list(chain(*zip(cropped_from_start.tolist(), cropped_from_end.tolist()))) - self.push_transform(img_t, orig_size=_orig_size, extra_info={"cropped": cropped}) - return img_t - - def update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...]): - spatial_rank = max(len(tensor.affine) - 1, 1) - to_shift = [s.start if s.start is not None else 0 for s in ensure_tuple(slices)[1:]] - mat = create_translate(spatial_rank, to_shift) - tensor.affine = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] + return crop_func(img_t, tuple(slices_), self.get_transform_info()) def inverse(self, img: MetaTensor) -> MetaTensor: transform = self.pop_transform(img) @@ -482,11 +412,11 @@ class SpatialCrop(Crop): def __init__( self, - roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None, - roi_slices: Optional[Sequence[slice]] = None, + roi_center: Sequence[int] | NdarrayOrTensor | None = None, + roi_size: Sequence[int] | NdarrayOrTensor | None = None, + roi_start: Sequence[int] | NdarrayOrTensor | None = None, + roi_end: Sequence[int] | NdarrayOrTensor | None = None, + roi_slices: Sequence[slice] | None = None, ) -> None: """ Args: @@ -502,13 +432,13 @@ def __init__( roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices ) - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ - return super().__call__(img=img, slices=self.slices) + return super().__call__(img=img, slices=ensure_tuple(self.slices)) class CenterSpatialCrop(Crop): @@ -526,21 +456,24 @@ class CenterSpatialCrop(Crop): the spatial size of output data will be [32, 40, 40]. """ - def __init__(self, roi_size: Union[Sequence[int], int]) -> None: + def __init__(self, roi_size: Sequence[int] | int) -> None: self.roi_size = roi_size - def compute_slices(self, spatial_size: Sequence[int]): # type: ignore + def compute_slices(self, spatial_size: Sequence[int]) -> tuple[slice]: # type: ignore[override] roi_size = fall_back_tuple(self.roi_size, spatial_size) roi_center = [i // 2 for i in spatial_size] return super().compute_slices(roi_center=roi_center, roi_size=roi_size) - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ - return super().__call__(img=img, slices=self.compute_slices(img.shape[1:])) + return super().__call__( + img=img, + slices=self.compute_slices(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]), + ) class CenterScaleCrop(Crop): @@ -553,15 +486,15 @@ class CenterScaleCrop(Crop): """ - def __init__(self, roi_scale: Union[Sequence[float], float]): + def __init__(self, roi_scale: Sequence[float] | float): self.roi_scale = roi_scale - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore - img_size = img.shape[1:] + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] + img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] cropper = CenterSpatialCrop(roi_size=roi_size) - return super().__call__(img=img, slices=cropper.compute_slices(img.shape[1:])) + return super().__call__(img=img, slices=cropper.compute_slices(img_size)) class RandSpatialCrop(Randomizable, Crop): @@ -588,10 +521,11 @@ class RandSpatialCrop(Randomizable, Crop): if True, the actual size is sampled from `randint(roi_size, max_roi_size + 1)`. """ + @deprecated_arg_default("random_size", True, False, since="1.1", replaced="1.3") def __init__( self, - roi_size: Union[Sequence[int], int], - max_roi_size: Optional[Union[Sequence[int], int]] = None, + roi_size: Sequence[int] | int, + max_roi_size: Sequence[int] | int | None = None, random_center: bool = True, random_size: bool = True, ) -> None: @@ -599,8 +533,8 @@ def __init__( self.max_roi_size = max_roi_size self.random_center = random_center self.random_size = random_size - self._size: Optional[Sequence[int]] = None - self._slices: Tuple[slice, ...] + self._size: Sequence[int] | None = None + self._slices: tuple[slice, ...] def randomize(self, img_size: Sequence[int]) -> None: self._size = fall_back_tuple(self.roi_size, img_size) @@ -619,14 +553,15 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: slicing doesn't apply to the channel dim. """ + img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] if randomize: - self.randomize(img.shape[1:]) + self.randomize(img_size) if self._size is None: raise RuntimeError("self._size not specified.") if self.random_center: return super().__call__(img=img, slices=self._slices) cropper = CenterSpatialCrop(self._size) - return super().__call__(img=img, slices=cropper.compute_slices(img.shape[1:])) + return super().__call__(img=img, slices=cropper.compute_slices(img_size)) class RandScaleCrop(RandSpatialCrop): @@ -650,10 +585,11 @@ class RandScaleCrop(RandSpatialCrop): `randint(roi_scale * image spatial size, max_roi_scale * image spatial size + 1)`. """ + @deprecated_arg_default("random_size", True, False, since="1.1", replaced="1.3") def __init__( self, - roi_scale: Union[Sequence[float], float], - max_roi_scale: Optional[Union[Sequence[float], float]] = None, + roi_scale: Sequence[float] | float, + max_roi_scale: Sequence[float] | float | None = None, random_center: bool = True, random_size: bool = True, ) -> None: @@ -679,11 +615,11 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: slicing doesn't apply to the channel dim. """ - self.get_max_roi_size(img.shape[1:]) + self.get_max_roi_size(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]) return super().__call__(img=img, randomize=randomize) -class RandSpatialCropSamples(Randomizable, TraceableTransform): +class RandSpatialCropSamples(Randomizable, TraceableTransform, LazyTransform, MultiSampleTrait): """ Crop image with random size or specific size ROI to generate a list of N samples. It can crop at a random position as center or at the image center. And allows to set @@ -716,11 +652,12 @@ class RandSpatialCropSamples(Randomizable, TraceableTransform): backend = RandSpatialCrop.backend + @deprecated_arg_default("random_size", True, False, since="1.1", replaced="1.3") def __init__( self, - roi_size: Union[Sequence[int], int], + roi_size: Sequence[int] | int, num_samples: int, - max_roi_size: Optional[Union[Sequence[int], int]] = None, + max_roi_size: Sequence[int] | int | None = None, random_center: bool = True, random_size: bool = True, ) -> None: @@ -730,27 +667,31 @@ def __init__( self.cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandSpatialCropSamples": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandSpatialCropSamples: super().set_random_state(seed, state) self.cropper.set_random_state(seed, state) return self - def randomize(self, data: Optional[Any] = None) -> None: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, value: bool) -> None: + self._lazy_evaluation = value + self.cropper.lazy_evaluation = value + + def randomize(self, data: Any | None = None) -> None: pass - def __call__(self, img: torch.Tensor) -> List[torch.Tensor]: + def __call__(self, img: torch.Tensor) -> list[torch.Tensor]: """ Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. """ ret = [] - orig_size = img.shape[1:] for i in range(self.num_samples): cropped = self.cropper(img) if get_track_meta(): cropped.meta[Key.PATCH_INDEX] = i # type: ignore - self.push_transform(cropped, orig_size=orig_size, extra_info=self.pop_transform(cropped, check=False)) + self.push_transform(cropped, replace=True) # track as this class instead of RandSpatialCrop ret.append(cropped) return ret @@ -790,11 +731,11 @@ def threshold_at_one(x): def __init__( self, select_fn: Callable = is_positive, - channel_indices: Optional[IndexSelection] = None, - margin: Union[Sequence[int], int] = 0, + channel_indices: IndexSelection | None = None, + margin: Sequence[int] | int = 0, allow_smaller: bool = True, return_coords: bool = False, - k_divisible: Union[Sequence[int], int] = 1, + k_divisible: Sequence[int] | int = 1, mode: str = PytorchPadMode.CONSTANT, **pad_kwargs, ) -> None: @@ -828,12 +769,19 @@ def __init__( self.k_divisible = k_divisible self.padder = Pad(mode=mode, **pad_kwargs) - def compute_bounding_box(self, img: torch.Tensor): + @Crop.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, _val: bool): + self._lazy_evaluation = _val + self.padder.lazy_evaluation = _val + + def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarray]: """ Compute the start points and end points of bounding box to crop. And adjust bounding box coords to be divisible by `k`. """ + if isinstance(img, MetaTensor) and img.pending_operations: + warnings.warn("foreground computation may not be accurate if the image has pending operations.") box_start, box_end = generate_spatial_bounding_box( img, self.select_fn, self.channel_indices, self.margin, self.allow_smaller ) @@ -848,8 +796,8 @@ def compute_bounding_box(self, img: torch.Tensor): return box_start_, box_end_ def crop_pad( - self, img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, mode: Optional[str] = None, **pad_kwargs - ): + self, img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, mode: str | None = None, **pad_kwargs + ) -> torch.Tensor: """ Crop and pad based on the bounding box. @@ -857,19 +805,24 @@ def crop_pad( slices = self.compute_slices(roi_start=box_start, roi_end=box_end) cropped = super().__call__(img=img, slices=slices) pad_to_start = np.maximum(-box_start, 0) - pad_to_end = np.maximum(box_end - np.asarray(img.shape[1:]), 0) + pad_to_end = np.maximum( + box_end - np.asarray(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]), 0 + ) pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - pad_width = BorderPad(spatial_border=pad).compute_pad_width(cropped.shape[1:]) + pad_width = BorderPad(spatial_border=pad).compute_pad_width( + cropped.peek_pending_shape() if isinstance(cropped, MetaTensor) else cropped.shape[1:] + ) ret = self.padder.__call__(img=cropped, to_pad=pad_width, mode=mode, **pad_kwargs) # combine the traced cropping and padding into one transformation # by taking the padded info and placing it in a key inside the crop info. - if get_track_meta(): - ret_: MetaTensor = ret # type: ignore - app_op = ret_.applied_operations.pop(-1) - ret_.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = app_op + if get_track_meta() and isinstance(ret, MetaTensor): + if not self.lazy_evaluation: + ret.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = ret.applied_operations.pop() return ret - def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs): # type: ignore + def __call__( # type: ignore[override] + self, img: torch.Tensor, mode: str | None = None, **pad_kwargs + ) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. @@ -878,7 +831,7 @@ def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs): cropped = self.crop_pad(img, box_start, box_end, mode, **pad_kwargs) if self.return_coords: - return cropped, box_start, box_end + return cropped, box_start, box_end # type: ignore[return-value] return cropped def inverse(self, img: MetaTensor) -> MetaTensor: @@ -892,7 +845,7 @@ def inverse(self, img: MetaTensor) -> MetaTensor: return super().inverse(inv) -class RandWeightedCrop(Randomizable, TraceableTransform): +class RandWeightedCrop(Randomizable, TraceableTransform, LazyTransform, MultiSampleTrait): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -908,24 +861,27 @@ class RandWeightedCrop(Randomizable, TraceableTransform): backend = SpatialCrop.backend def __init__( - self, - spatial_size: Union[Sequence[int], int], - num_samples: int = 1, - weight_map: Optional[NdarrayOrTensor] = None, + self, spatial_size: Sequence[int] | int, num_samples: int = 1, weight_map: NdarrayOrTensor | None = None ): self.spatial_size = ensure_tuple(spatial_size) self.num_samples = int(num_samples) self.weight_map = weight_map - self.centers: List[np.ndarray] = [] + self.centers: list[np.ndarray] = [] def randomize(self, weight_map: NdarrayOrTensor) -> None: + if isinstance(weight_map, MetaTensor) and weight_map.pending_operations: + warnings.warn("weight map has pending operations, the sampling may not be correct.") self.centers = weighted_patch_samples( spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) # using only the first channel as weight map + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, _val: bool): + self._lazy_evaluation = _val + def __call__( - self, img: torch.Tensor, weight_map: Optional[NdarrayOrTensor] = None, randomize: bool = True - ) -> List[torch.Tensor]: + self, img: torch.Tensor, weight_map: NdarrayOrTensor | None = None, randomize: bool = True + ) -> list[torch.Tensor]: """ Args: img: input image to sample patches from. assuming `img` is a channel-first array. @@ -937,30 +893,34 @@ def __call__( Returns: A list of image patches """ - if weight_map is None: - weight_map = self.weight_map - if weight_map is None: - raise ValueError("weight map must be provided for weighted patch sampling.") - if img.shape[1:] != weight_map.shape[1:]: - raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.") + img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] if randomize: + if weight_map is None: + weight_map = self.weight_map + if weight_map is None: + raise ValueError("weight map must be provided for weighted patch sampling.") + w_shape = weight_map.peek_pending_shape() if isinstance(weight_map, MetaTensor) else weight_map.shape[1:] + if img_shape != w_shape: + warnings.warn(f"image and weight map spatial shape mismatch: {img_shape} vs {w_shape}.") self.randomize(weight_map) - _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) - results: List[torch.Tensor] = [] - orig_size = img.shape[1:] + + _spatial_size = fall_back_tuple(self.spatial_size, img_shape) + results: list[torch.Tensor] = [] for i, center in enumerate(self.centers): - cropped = SpatialCrop(roi_center=center, roi_size=_spatial_size)(img) + cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) + cropper.lazy_evaluation = self.lazy_evaluation + cropped = cropper(img) if get_track_meta(): ret_: MetaTensor = cropped # type: ignore ret_.meta[Key.PATCH_INDEX] = i ret_.meta["crop_center"] = center - self.push_transform(ret_, orig_size=orig_size, extra_info=self.pop_transform(ret_, check=False)) + self.push_transform(ret_, replace=True) results.append(cropped) return results -class RandCropByPosNegLabel(Randomizable, TraceableTransform): +class RandCropByPosNegLabel(Randomizable, TraceableTransform, LazyTransform, MultiSampleTrait): """ Crop random fixed sized regions with the center being a foreground or background voxel based on the Pos Neg Ratio. @@ -1019,15 +979,15 @@ class RandCropByPosNegLabel(Randomizable, TraceableTransform): def __init__( self, - spatial_size: Union[Sequence[int], int], - label: Optional[torch.Tensor] = None, + spatial_size: Sequence[int] | int, + label: torch.Tensor | None = None, pos: float = 1.0, neg: float = 1.0, num_samples: int = 1, - image: Optional[torch.Tensor] = None, + image: torch.Tensor | None = None, image_threshold: float = 0.0, - fg_indices: Optional[NdarrayOrTensor] = None, - bg_indices: Optional[NdarrayOrTensor] = None, + fg_indices: NdarrayOrTensor | None = None, + bg_indices: NdarrayOrTensor | None = None, allow_smaller: bool = False, ) -> None: self.spatial_size = spatial_size @@ -1040,47 +1000,59 @@ def __init__( self.num_samples = num_samples self.image = image self.image_threshold = image_threshold - self.centers: Optional[List[List[int]]] = None + self.centers: tuple[tuple] | None = None self.fg_indices = fg_indices self.bg_indices = bg_indices self.allow_smaller = allow_smaller def randomize( self, - label: torch.Tensor, - fg_indices: Optional[NdarrayOrTensor] = None, - bg_indices: Optional[NdarrayOrTensor] = None, - image: Optional[torch.Tensor] = None, + label: torch.Tensor | None = None, + fg_indices: NdarrayOrTensor | None = None, + bg_indices: NdarrayOrTensor | None = None, + image: torch.Tensor | None = None, ) -> None: - if fg_indices is None or bg_indices is None: - if self.fg_indices is not None and self.bg_indices is not None: - fg_indices_ = self.fg_indices - bg_indices_ = self.bg_indices - else: - fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) - else: - fg_indices_ = fg_indices - bg_indices_ = bg_indices + fg_indices_ = self.fg_indices if fg_indices is None else fg_indices + bg_indices_ = self.bg_indices if bg_indices is None else bg_indices + if fg_indices_ is None or bg_indices_ is None: + if isinstance(label, MetaTensor) and label.pending_operations: + warnings.warn("label has pending operations, the fg/bg indices may be incorrect.") + if isinstance(image, MetaTensor) and image.pending_operations: + warnings.warn("image has pending operations, the fg/bg indices may be incorrect.") + if label is None: + raise ValueError("label must be provided.") + fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) + _shape = None + if label is not None: + _shape = label.peek_pending_shape() if isinstance(label, MetaTensor) else label.shape[1:] + elif image is not None: + _shape = image.peek_pending_shape() if isinstance(image, MetaTensor) else image.shape[1:] + if _shape is None: + raise ValueError("label or image must be provided to get the spatial shape.") self.centers = generate_pos_neg_label_crop_centers( self.spatial_size, self.num_samples, self.pos_ratio, - label.shape[1:], + _shape, fg_indices_, bg_indices_, self.R, self.allow_smaller, ) + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, _val: bool): + self._lazy_evaluation = _val + def __call__( self, img: torch.Tensor, - label: Optional[torch.Tensor] = None, - image: Optional[torch.Tensor] = None, - fg_indices: Optional[NdarrayOrTensor] = None, - bg_indices: Optional[NdarrayOrTensor] = None, + label: torch.Tensor | None = None, + image: torch.Tensor | None = None, + fg_indices: NdarrayOrTensor | None = None, + bg_indices: NdarrayOrTensor | None = None, randomize: bool = True, - ) -> List[torch.Tensor]: + ) -> list[torch.Tensor]: """ Args: img: input data to crop samples from based on the pos/neg ratio of `label` and `image`. @@ -1096,31 +1068,30 @@ def __call__( randomize: whether to execute the random operations, default to `True`. """ - if label is None: - label = self.label - if label is None: - raise ValueError("label should be provided.") if image is None: image = self.image - if randomize: + if label is None: + label = self.label self.randomize(label, fg_indices, bg_indices, image) - results: List[torch.Tensor] = [] - orig_size = img.shape[1:] + results: list[torch.Tensor] = [] if self.centers is not None: + img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + roi_size = fall_back_tuple(self.spatial_size, default=img_shape) for i, center in enumerate(self.centers): - roi_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - cropped = SpatialCrop(roi_center=center, roi_size=roi_size)(img) + cropper = SpatialCrop(roi_center=center, roi_size=roi_size) + cropper.lazy_evaluation = self.lazy_evaluation + cropped = cropper(img) if get_track_meta(): ret_: MetaTensor = cropped # type: ignore ret_.meta[Key.PATCH_INDEX] = i ret_.meta["crop_center"] = center - self.push_transform(ret_, orig_size=orig_size, extra_info=self.pop_transform(ret_, check=False)) + self.push_transform(ret_, replace=True) results.append(cropped) return results -class RandCropByLabelClasses(Randomizable, TraceableTransform): +class RandCropByLabelClasses(Randomizable, TraceableTransform, LazyTransform, MultiSampleTrait): """ Crop random fixed sized regions with the center being a class based on the specified ratios of every class. The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the @@ -1167,7 +1138,7 @@ class RandCropByLabelClasses(Randomizable, TraceableTransform): the spatial size of output data will be [32, 40, 40]. ratios: specified ratios of every class in the label to generate crop centers, including background class. if None, every class will have the same ratio to generate crop centers. - label: the label image that is used for finding every classes, if None, must set at `self.__call__`. + label: the label image that is used for finding every class, if None, must set at `self.__call__`. num_classes: number of classes for argmax label, not necessary for One-Hot label. num_samples: number of samples (crop regions) to take in each list. image: if image is not None, only return the indices of every class that are within the valid @@ -1181,6 +1152,7 @@ class RandCropByLabelClasses(Randomizable, TraceableTransform): allow_smaller: if `False`, an exception will be raised if the image is smaller than the requested ROI in any dimension. If `True`, any smaller dimensions will remain unchanged. + warn: if `True` prints a warning if a class is not present in the label. """ @@ -1188,15 +1160,16 @@ class RandCropByLabelClasses(Randomizable, TraceableTransform): def __init__( self, - spatial_size: Union[Sequence[int], int], - ratios: Optional[List[Union[float, int]]] = None, - label: Optional[torch.Tensor] = None, - num_classes: Optional[int] = None, + spatial_size: Sequence[int] | int, + ratios: list[float | int] | None = None, + label: torch.Tensor | None = None, + num_classes: int | None = None, num_samples: int = 1, - image: Optional[torch.Tensor] = None, + image: torch.Tensor | None = None, image_threshold: float = 0.0, - indices: Optional[List[NdarrayOrTensor]] = None, + indices: list[NdarrayOrTensor] | None = None, allow_smaller: bool = False, + warn: bool = True, ) -> None: self.spatial_size = spatial_size self.ratios = ratios @@ -1205,33 +1178,49 @@ def __init__( self.num_samples = num_samples self.image = image self.image_threshold = image_threshold - self.centers: Optional[List[List[int]]] = None + self.centers: tuple[tuple] | None = None self.indices = indices self.allow_smaller = allow_smaller + self.warn = warn def randomize( - self, label: torch.Tensor, indices: Optional[List[NdarrayOrTensor]] = None, image: Optional[torch.Tensor] = None + self, + label: torch.Tensor | None = None, + indices: list[NdarrayOrTensor] | None = None, + image: torch.Tensor | None = None, ) -> None: - indices_: Sequence[NdarrayOrTensor] - if indices is None: - if self.indices is not None: - indices_ = self.indices - else: - indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) - else: - indices_ = indices + indices_ = self.indices if indices is None else indices + if indices_ is None: + if isinstance(label, MetaTensor) and label.pending_operations: + warnings.warn("label has pending operations, the fg/bg indices may be incorrect.") + if isinstance(image, MetaTensor) and image.pending_operations: + warnings.warn("image has pending operations, the fg/bg indices may be incorrect.") + if label is None: + raise ValueError("label must not be None.") + indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) + _shape = None + if label is not None: + _shape = label.peek_pending_shape() if isinstance(label, MetaTensor) else label.shape[1:] + elif image is not None: + _shape = image.peek_pending_shape() if isinstance(image, MetaTensor) else image.shape[1:] + if _shape is None: + raise ValueError("label or image must be provided to infer the output spatial shape.") self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller + self.spatial_size, self.num_samples, _shape, indices_, self.ratios, self.R, self.allow_smaller, self.warn ) + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, _val: bool): + self._lazy_evaluation = _val + def __call__( self, img: torch.Tensor, - label: Optional[torch.Tensor] = None, - image: Optional[torch.Tensor] = None, - indices: Optional[List[NdarrayOrTensor]] = None, + label: torch.Tensor | None = None, + image: torch.Tensor | None = None, + indices: list[NdarrayOrTensor] | None = None, randomize: bool = True, - ) -> List[torch.Tensor]: + ) -> list[torch.Tensor]: """ Args: img: input data to crop samples from based on the ratios of every class, assumes `img` is a @@ -1243,32 +1232,31 @@ def __call__( randomize: whether to execute the random operations, default to `True`. """ - if label is None: - label = self.label - if label is None: - raise ValueError("label should be provided.") if image is None: image = self.image - if randomize: + if label is None: + label = self.label self.randomize(label, indices, image) - results: List[torch.Tensor] = [] - orig_size = img.shape[1:] + results: list[torch.Tensor] = [] if self.centers is not None: + img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + roi_size = fall_back_tuple(self.spatial_size, default=img_shape) for i, center in enumerate(self.centers): - roi_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - cropped = SpatialCrop(roi_center=tuple(center), roi_size=roi_size)(img) + cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) + cropper.lazy_evaluation = self.lazy_evaluation + cropped = cropper(img) if get_track_meta(): ret_: MetaTensor = cropped # type: ignore ret_.meta[Key.PATCH_INDEX] = i ret_.meta["crop_center"] = center - self.push_transform(ret_, orig_size=orig_size, extra_info=self.pop_transform(ret_, check=False)) + self.push_transform(ret_, replace=True) results.append(cropped) return results -class ResizeWithPadOrCrop(InvertibleTransform): +class ResizeWithPadOrCrop(InvertibleTransform, LazyTransform): """ Resize an image to a target spatial size by either centrally cropping the image or padding it evenly with a user-specified mode. @@ -1295,7 +1283,7 @@ class ResizeWithPadOrCrop(InvertibleTransform): def __init__( self, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, method: str = Method.SYMMETRIC, mode: str = PytorchPadMode.CONSTANT, **pad_kwargs, @@ -1303,7 +1291,13 @@ def __init__( self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **pad_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) - def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) -> torch.Tensor: # type: ignore + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool): + self.padder.lazy_evaluation = val + self.cropper.lazy_evaluation = val + self._lazy_evaluation = val + + def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs) -> torch.Tensor: # type: ignore """ Args: img: data to pad or crop, assuming `img` is channel-first and @@ -1318,14 +1312,29 @@ def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) note that `np.pad` treats channel dimension as the first dimension. """ - orig_size = img.shape[1:] ret = self.padder(self.cropper(img), mode=mode, **pad_kwargs) # remove the individual info and combine if get_track_meta(): ret_: MetaTensor = ret # type: ignore - pad_info = ret_.applied_operations.pop(-1) - crop_info = ret_.applied_operations.pop(-1) - self.push_transform(ret_, orig_size=orig_size, extra_info={"pad_info": pad_info, "crop_info": crop_info}) + if not self.lazy_evaluation: + pad_info = ret_.applied_operations.pop() + crop_info = ret_.applied_operations.pop() + orig_size = crop_info.get(TraceKeys.ORIG_SIZE) + self.push_transform( + ret_, orig_size=orig_size, extra_info={"pad_info": pad_info, "crop_info": crop_info} + ) + else: + pad_info = ret_.pending_operations.pop() + crop_info = ret_.pending_operations.pop() + orig_size = crop_info.get(TraceKeys.ORIG_SIZE) + self.push_transform( + ret_, + orig_size=orig_size, + sp_size=pad_info[LazyAttr.SHAPE], + affine=crop_info[LazyAttr.AFFINE] @ pad_info[LazyAttr.AFFINE], + extra_info={"pad_info": pad_info, "crop_info": crop_info}, + ) + return ret def inverse(self, img: MetaTensor) -> MetaTensor: diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index a3cb15144b..bd8e7bf50f 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -13,7 +13,10 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -from typing import Any, Dict, Hashable, Mapping +from __future__ import annotations + +from collections.abc import Hashable, Mapping +from typing import Any import numpy as np import torch @@ -112,9 +115,9 @@ def __call__(self, batch: Any): return list_data_collate(batch) @staticmethod - def inverse(data: dict) -> Dict[Hashable, np.ndarray]: + def inverse(data: dict) -> dict[Hashable, np.ndarray]: if not isinstance(data, Mapping): - raise RuntimeError("Inverse can only currently be applied on dictionaries.") + raise RuntimeError(f"Inverse can only currently be applied on dictionaries, got type {type(data)}.") d = dict(data) for key in d: @@ -125,7 +128,7 @@ def inverse(data: dict) -> Dict[Hashable, np.ndarray]: transform_key = InvertibleTransform.trace_key(key) if transform_key in d: transforms = d[transform_key] - if not transforms or not isinstance(transforms[-1], Dict): + if not transforms or not isinstance(transforms[-1], dict): continue if transforms[-1].get(TraceKeys.CLASS_NAME) == PadListDataCollate.__name__: xform = transforms.pop() diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index bae6705c22..ab4ce28941 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -15,8 +15,11 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from __future__ import annotations + +from collections.abc import Callable, Hashable, Mapping, Sequence from copy import deepcopy -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union +from typing import Any import numpy as np import torch @@ -44,10 +47,10 @@ SpatialPad, ) from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.traits import MultiSampleTrait +from monai.transforms.transform import LazyTransform, MapTransform, Randomizable from monai.transforms.utils import is_positive -from monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep -from monai.utils.deprecate_utils import deprecated_arg +from monai.utils import MAX_SEED, Method, PytorchPadMode, deprecated_arg_default, ensure_tuple_rep __all__ = [ "Padd", @@ -107,7 +110,7 @@ ] -class Padd(MapTransform, InvertibleTransform): +class Padd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Pad`. @@ -141,13 +144,19 @@ def __init__( self.padder = padder self.mode = ensure_tuple_rep(mode, len(self.keys)) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, value: bool) -> None: + self._lazy_evaluation = value + if isinstance(self.padder, LazyTransform): + self.padder.lazy_evaluation = value + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTensor]: + def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.padder.inverse(d[key]) @@ -164,7 +173,7 @@ class SpatialPadd(Padd): def __init__( self, keys: KeysCollection, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, method: str = Method.SYMMETRIC, mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, @@ -208,7 +217,7 @@ class BorderPadd(Padd): def __init__( self, keys: KeysCollection, - spatial_border: Union[Sequence[int], int], + spatial_border: Sequence[int] | int, mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, **kwargs, @@ -255,7 +264,7 @@ class DivisiblePadd(Padd): def __init__( self, keys: KeysCollection, - k: Union[Sequence[int], int], + k: Sequence[int] | int, mode: SequenceStr = PytorchPadMode.CONSTANT, method: str = Method.SYMMETRIC, allow_missing_keys: bool = False, @@ -288,7 +297,7 @@ def __init__( super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) -class Cropd(MapTransform, InvertibleTransform): +class Cropd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of abstract class :py:class:`monai.transforms.Crop`. @@ -306,13 +315,19 @@ def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.cropper = cropper - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, value: bool) -> None: + self._lazy_evaluation = value + if isinstance(self.cropper, LazyTransform): + self.cropper.lazy_evaluation = value + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.cropper(d[key]) # type: ignore return d - def inverse(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTensor]: + def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.cropper.inverse(d[key]) @@ -336,9 +351,7 @@ class RandCropd(Cropd, Randomizable): def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandCropd": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandCropd: super().set_random_state(seed, state) if isinstance(self.cropper, Randomizable): self.cropper.set_random_state(seed, state) @@ -348,10 +361,11 @@ def randomize(self, img_size: Sequence[int]) -> None: if isinstance(self.cropper, Randomizable): self.cropper.randomize(img_size) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) # the first key must exist to execute random operations - self.randomize(d[self.first_key(d)].shape[1:]) + first_item = d[self.first_key(d)] + self.randomize(first_item.peek_pending_shape() if isinstance(first_item, MetaTensor) else first_item.shape[1:]) for key in self.key_iterator(d): kwargs = {"randomize": False} if isinstance(self.cropper, Randomizable) else {} d[key] = self.cropper(d[key], **kwargs) # type: ignore @@ -376,11 +390,11 @@ class SpatialCropd(Cropd): def __init__( self, keys: KeysCollection, - roi_center: Optional[Sequence[int]] = None, - roi_size: Optional[Sequence[int]] = None, - roi_start: Optional[Sequence[int]] = None, - roi_end: Optional[Sequence[int]] = None, - roi_slices: Optional[Sequence[slice]] = None, + roi_center: Sequence[int] | None = None, + roi_size: Sequence[int] | None = None, + roi_start: Sequence[int] | None = None, + roi_end: Sequence[int] | None = None, + roi_slices: Sequence[slice] | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -419,9 +433,7 @@ class CenterSpatialCropd(Cropd): allow_missing_keys: don't raise exception if key is missing. """ - def __init__( - self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False - ) -> None: + def __init__(self, keys: KeysCollection, roi_size: Sequence[int] | int, allow_missing_keys: bool = False) -> None: cropper = CenterSpatialCrop(roi_size) super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) @@ -441,7 +453,7 @@ class CenterScaleCropd(Cropd): """ def __init__( - self, keys: KeysCollection, roi_scale: Union[Sequence[float], float], allow_missing_keys: bool = False + self, keys: KeysCollection, roi_scale: Sequence[float] | float, allow_missing_keys: bool = False ) -> None: cropper = CenterScaleCrop(roi_scale) super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) @@ -477,11 +489,12 @@ class RandSpatialCropd(RandCropd): allow_missing_keys: don't raise exception if key is missing. """ + @deprecated_arg_default("random_size", True, False, since="1.1", replaced="1.3") def __init__( self, keys: KeysCollection, - roi_size: Union[Sequence[int], int], - max_roi_size: Optional[Union[Sequence[int], int]] = None, + roi_size: Sequence[int] | int, + max_roi_size: Sequence[int] | int | None = None, random_center: bool = True, random_size: bool = True, allow_missing_keys: bool = False, @@ -515,11 +528,12 @@ class RandScaleCropd(RandCropd): allow_missing_keys: don't raise exception if key is missing. """ + @deprecated_arg_default("random_size", True, False, since="1.1", replaced="1.3") def __init__( self, keys: KeysCollection, - roi_scale: Union[Sequence[float], float], - max_roi_scale: Optional[Union[Sequence[float], float]] = None, + roi_scale: Sequence[float] | float, + max_roi_scale: Sequence[float] | float | None = None, random_center: bool = True, random_size: bool = True, allow_missing_keys: bool = False, @@ -528,7 +542,7 @@ def __init__( super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) -class RandSpatialCropSamplesd(Randomizable, MapTransform): +class RandSpatialCropSamplesd(Randomizable, MapTransform, LazyTransform, MultiSampleTrait): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`. Crop image with random size or specific size ROI to generate a list of N samples. @@ -566,28 +580,30 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform): backend = RandSpatialCropSamples.backend - @deprecated_arg(name="meta_keys", since="0.9") - @deprecated_arg(name="meta_key_postfix", since="0.9") + @deprecated_arg_default("random_size", True, False, since="1.1", replaced="1.3") def __init__( self, keys: KeysCollection, - roi_size: Union[Sequence[int], int], + roi_size: Sequence[int] | int, num_samples: int, - max_roi_size: Optional[Union[Sequence[int], int]] = None, + max_roi_size: Sequence[int] | int | None = None, random_center: bool = True, random_size: bool = True, - meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.cropper = RandSpatialCropSamples(roi_size, num_samples, max_roi_size, random_center, random_size) - def randomize(self, data: Optional[Any] = None) -> None: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, value: bool) -> None: + self._lazy_evaluation = value + self.cropper.lazy_evaluation = value + + def randomize(self, data: Any | None = None) -> None: self.sub_seed = self.R.randint(MAX_SEED, dtype="uint32") - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: - ret: List[Dict[Hashable, torch.Tensor]] = [dict(data) for _ in range(self.cropper.num_samples)] + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]: + ret: list[dict[Hashable, torch.Tensor]] = [dict(data) for _ in range(self.cropper.num_samples)] # deep copy all the unmodified data for i in range(self.cropper.num_samples): for key in set(data.keys()).difference(set(self.keys)): @@ -620,10 +636,10 @@ def __init__( keys: KeysCollection, source_key: str, select_fn: Callable = is_positive, - channel_indices: Optional[IndexSelection] = None, - margin: Union[Sequence[int], int] = 0, + channel_indices: IndexSelection | None = None, + margin: Sequence[int] | int = 0, allow_smaller: bool = True, - k_divisible: Union[Sequence[int], int] = 1, + k_divisible: Sequence[int] | int = 1, mode: SequenceStr = PytorchPadMode.CONSTANT, start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", @@ -672,20 +688,25 @@ def __init__( super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, value: bool) -> None: + self._lazy_evaluation = value + self.cropper.lazy_evaluation = value + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) self.cropper: CropForeground box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) if self.start_coord_key is not None: - d[self.start_coord_key] = box_start + d[self.start_coord_key] = box_start # type: ignore if self.end_coord_key is not None: - d[self.end_coord_key] = box_end + d[self.end_coord_key] = box_end # type: ignore for key, m in self.key_iterator(d, self.mode): d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m) return d -class RandWeightedCropd(Randomizable, MapTransform): +class RandWeightedCropd(Randomizable, MapTransform, LazyTransform, MultiSampleTrait): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -705,18 +726,12 @@ class RandWeightedCropd(Randomizable, MapTransform): backend = SpatialCrop.backend - @deprecated_arg(name="meta_keys", since="0.9") - @deprecated_arg(name="meta_key_postfix", since="0.9") - @deprecated_arg(name="center_coord_key", since="0.9", msg_suffix="coords stored in img.meta['crop_center']") def __init__( self, keys: KeysCollection, w_key: str, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, num_samples: int = 1, - center_coord_key: Optional[str] = None, - meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) @@ -724,8 +739,8 @@ def __init__( self.cropper = RandWeightedCrop(spatial_size, num_samples) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandWeightedCropd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandWeightedCropd: super().set_random_state(seed, state) self.cropper.set_random_state(seed, state) return self @@ -733,9 +748,14 @@ def set_random_state( def randomize(self, weight_map: NdarrayOrTensor) -> None: self.cropper.randomize(weight_map) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, value: bool) -> None: + self._lazy_evaluation = value + self.cropper.lazy_evaluation = value + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]: # output starts as empty list of dictionaries - ret: List = [dict(data) for _ in range(self.cropper.num_samples)] + ret: list = [dict(data) for _ in range(self.cropper.num_samples)] # deep copy all the unmodified data for i in range(self.cropper.num_samples): for key in set(data.keys()).difference(set(self.keys)): @@ -743,12 +763,12 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, self.randomize(weight_map=data[self.w_key]) for key in self.key_iterator(data): - for i, im in enumerate(self.cropper(data[key], weight_map=data[self.w_key], randomize=False)): + for i, im in enumerate(self.cropper(data[key], randomize=False)): ret[i][key] = im return ret -class RandCropByPosNegLabeld(Randomizable, MapTransform): +class RandCropByPosNegLabeld(Randomizable, MapTransform, LazyTransform, MultiSampleTrait): """ Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`. Crop random fixed sized regions with the center being a foreground or background voxel @@ -802,22 +822,18 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform): backend = RandCropByPosNegLabel.backend - @deprecated_arg(name="meta_keys", since="0.9") - @deprecated_arg(name="meta_key_postfix", since="0.9") def __init__( self, keys: KeysCollection, label_key: str, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, pos: float = 1.0, neg: float = 1.0, num_samples: int = 1, - image_key: Optional[str] = None, + image_key: str | None = None, image_threshold: float = 0.0, - fg_indices_key: Optional[str] = None, - bg_indices_key: Optional[str] = None, - meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + fg_indices_key: str | None = None, + bg_indices_key: str | None = None, allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: @@ -836,44 +852,47 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandCropByPosNegLabeld": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandCropByPosNegLabeld: super().set_random_state(seed, state) self.cropper.set_random_state(seed, state) return self def randomize( self, - label: torch.Tensor, - fg_indices: Optional[NdarrayOrTensor] = None, - bg_indices: Optional[NdarrayOrTensor] = None, - image: Optional[torch.Tensor] = None, + label: torch.Tensor | None = None, + fg_indices: NdarrayOrTensor | None = None, + bg_indices: NdarrayOrTensor | None = None, + image: torch.Tensor | None = None, ) -> None: self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, value: bool) -> None: + self._lazy_evaluation = value + self.cropper.lazy_evaluation = value + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]: d = dict(data) - label = d[self.label_key] - image = d[self.image_key] if self.image_key else None - fg_indices = d.pop(self.fg_indices_key, None) if self.fg_indices_key is not None else None - bg_indices = d.pop(self.bg_indices_key, None) if self.bg_indices_key is not None else None + fg_indices = d.pop(self.fg_indices_key, None) + bg_indices = d.pop(self.bg_indices_key, None) - self.randomize(label, fg_indices, bg_indices, image) + self.randomize(d.get(self.label_key), fg_indices, bg_indices, d.get(self.image_key)) # initialize returned list with shallow copy to preserve key ordering - ret: List = [dict(d) for _ in range(self.cropper.num_samples)] + ret: list = [dict(d) for _ in range(self.cropper.num_samples)] # deep copy all the unmodified data for i in range(self.cropper.num_samples): for key in set(d.keys()).difference(set(self.keys)): ret[i][key] = deepcopy(d[key]) for key in self.key_iterator(d): - for i, im in enumerate(self.cropper(d[key], label=label, randomize=False)): + for i, im in enumerate(self.cropper(d[key], randomize=False)): ret[i][key] = im return ret -class RandCropByLabelClassesd(Randomizable, MapTransform): +class RandCropByLabelClassesd(Randomizable, MapTransform, LazyTransform, MultiSampleTrait): """ Dictionary-based version :py:class:`monai.transforms.RandCropByLabelClasses`. Crop random fixed sized regions with the center being a class based on the specified ratios of every class. @@ -942,28 +961,27 @@ class RandCropByLabelClassesd(Randomizable, MapTransform): the requested ROI in any dimension. If `True`, any smaller dimensions will remain unchanged. allow_missing_keys: don't raise exception if key is missing. + warn: if `True` prints a warning if a class is not present in the label. + """ backend = RandCropByLabelClasses.backend - @deprecated_arg(name="meta_keys", since="0.9") - @deprecated_arg(name="meta_key_postfix", since="0.9") def __init__( self, keys: KeysCollection, label_key: str, - spatial_size: Union[Sequence[int], int], - ratios: Optional[List[Union[float, int]]] = None, - num_classes: Optional[int] = None, + spatial_size: Sequence[int] | int, + ratios: list[float | int] | None = None, + num_classes: int | None = None, num_samples: int = 1, - image_key: Optional[str] = None, + image_key: str | None = None, image_threshold: float = 0.0, - indices_key: Optional[str] = None, - meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + indices_key: str | None = None, allow_smaller: bool = False, allow_missing_keys: bool = False, + warn: bool = True, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.label_key = label_key @@ -976,37 +994,39 @@ def __init__( num_samples=num_samples, image_threshold=image_threshold, allow_smaller=allow_smaller, + warn=warn, ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandCropByLabelClassesd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandCropByLabelClassesd: super().set_random_state(seed, state) self.cropper.set_random_state(seed, state) return self def randomize( - self, label: torch.Tensor, indices: Optional[List[NdarrayOrTensor]] = None, image: Optional[torch.Tensor] = None + self, label: torch.Tensor, indices: list[NdarrayOrTensor] | None = None, image: torch.Tensor | None = None ) -> None: self.cropper.randomize(label=label, indices=indices, image=image) - def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, torch.Tensor]]: - d = dict(data) - label = d[self.label_key] - image = d[self.image_key] if self.image_key else None - indices = d.pop(self.indices_key, None) if self.indices_key is not None else None + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, value: bool) -> None: + self._lazy_evaluation = value + self.cropper.lazy_evaluation = value - self.randomize(label, indices, image) + def __call__(self, data: Mapping[Hashable, Any]) -> list[dict[Hashable, torch.Tensor]]: + d = dict(data) + self.randomize(d.get(self.label_key), d.pop(self.indices_key, None), d.get(self.image_key)) # type: ignore # initialize returned list with shallow copy to preserve key ordering - ret: List = [dict(d) for _ in range(self.cropper.num_samples)] + ret: list = [dict(d) for _ in range(self.cropper.num_samples)] # deep copy all the unmodified data for i in range(self.cropper.num_samples): for key in set(d.keys()).difference(set(self.keys)): ret[i][key] = deepcopy(d[key]) for key in self.key_iterator(d): - for i, im in enumerate(self.cropper(d[key], label=label, randomize=False)): + for i, im in enumerate(self.cropper(d[key], randomize=False)): ret[i][key] = im return ret @@ -1038,7 +1058,7 @@ class ResizeWithPadOrCropd(Padd): def __init__( self, keys: KeysCollection, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, method: str = Method.SYMMETRIC, @@ -1074,7 +1094,7 @@ def __init__( self.bbox = BoundingRect(select_fn=select_fn) self.bbox_key_postfix = bbox_key_postfix - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py new file mode 100644 index 0000000000..fa95958bd5 --- /dev/null +++ b/monai/transforms/croppad/functional.py @@ -0,0 +1,236 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of "functional" transforms for spatial operations +https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design +""" + +from __future__ import annotations + +import warnings + +import numpy as np +import torch +from torch.nn.functional import pad as pad_pt + +from monai.data.meta_obj import get_track_meta +from monai.data.meta_tensor import MetaTensor +from monai.data.utils import to_affine_nd +from monai.transforms.inverse import TraceableTransform +from monai.transforms.utils import convert_pad_mode, create_translate +from monai.utils import ( + PytorchPadMode, + TraceKeys, + convert_to_dst_type, + convert_to_numpy, + convert_to_tensor, + ensure_tuple, +) + +__all__ = ["pad_nd", "pad_func", "crop_func", "crop_or_pad_nd"] + + +def _convert_pt_pad_mode(padding_mode): + """get the most similar mode of `pad` from ``padding_mode`` of the spatial resampling.""" + if padding_mode is None or padding_mode in ("zeros", "constant", "grid-constant"): + return PytorchPadMode.CONSTANT + elif padding_mode in ("reflection", "reflect", "mirror", "grid-mirror"): + return PytorchPadMode.REFLECT + elif padding_mode in ("wrap", "grid-wrap"): + return PytorchPadMode.CIRCULAR + return PytorchPadMode.REPLICATE # "nearest", "border", and others + + +def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: + if isinstance(img, torch.Tensor): + if img.is_cuda: + warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.") + img_np = img.detach().cpu().numpy() + else: + img_np = img + mode = convert_pad_mode(dst=img_np, mode=mode).value + if mode == "constant" and "value" in kwargs: + kwargs["constant_values"] = kwargs.pop("value") + out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs)) # type: ignore + if isinstance(img, MetaTensor): + out = convert_to_dst_type(out, dst=img)[0] + return out + + +def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: + mode = convert_pad_mode(dst=img, mode=mode).value + if mode == "constant" and "constant_values" in kwargs: + _kwargs = kwargs.copy() + _kwargs["value"] = _kwargs.pop("constant_values") + else: + _kwargs = kwargs + pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1] + # torch.pad expects `[B, C, H, W, [D]]` shape + return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0) + + +def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs): + """ + PyTorch/Numpy pad ``img`` with integers ``to_pad`` amounts. Depending on the ``mode`` and input dtype, + a suitable backend will be used automatically. + + Args: + img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. + to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. + default to `self.to_pad`. + mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. + """ + if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}: + return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) + mode = convert_pad_mode(dst=img, mode=mode).value + try: + _pad = ( + _np_pad + if mode in {"reflect", "replicate"} and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8} + else _pt_pad + ) + return _pad(img, pad_width=to_pad, mode=mode, **kwargs) + except (ValueError, TypeError, RuntimeError) as err: + if isinstance(err, NotImplementedError) or any( + k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value") + ): + return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) + raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device}") from err + + +def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, ...], mode: str, **kwargs): + """ + Crop or pad using the translation matrix and spatial size. The translation coefficients are rounded + to the nearest integers. For a more generic implementation, please see :py:class:`monai.transforms.SpatialResample`. + + Args: + img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. + translation_mat: the translation matrix to be applied to the image. A translation matrix generated by, + for example, :py:func:`monai.transforms.utils.create_translate`. The translation coefficients are rounded + to the nearest integers. + spatial_size: the spatial size of the output image. + mode: the padding mode. + kwargs: other arguments for the `np.pad` or `torch.pad` function. + """ + ndim = len(img.shape) - 1 + matrix_np = np.round(to_affine_nd(ndim, convert_to_numpy(translation_mat, wrap_sequence=True).copy())) + matrix_np = to_affine_nd(len(spatial_size), matrix_np) + cc = np.asarray(np.meshgrid(*[[0.5, x - 0.5] for x in spatial_size], indexing="ij")) + cc = cc.reshape((len(spatial_size), -1)) + src_cc = np.floor(matrix_np @ np.concatenate((cc, np.ones_like(cc[:1])))) + src_start, src_end = src_cc.min(axis=1), src_cc.max(axis=1) + to_pad, to_crop, do_pad, do_crop = [(0, 0)], [slice(None)], False, False + for s, e, sp in zip(src_start, src_end, img.shape[1:]): + do_pad, do_crop = do_pad or s < 0 or e > sp - 1, do_crop or s > 0 or e < sp - 1 + to_pad += [(0 if s >= 0 else int(-s), 0 if e < sp - 1 else int(e - sp + 1))] + to_crop += [slice(int(max(s, 0)), int(e + 1 + to_pad[-1][0]))] + if do_pad: + _mode = _convert_pt_pad_mode(mode) + img = pad_nd(img, to_pad, mode=_mode, **kwargs) + if do_crop: + img = img[to_crop] + return img + + +def pad_func( + img: torch.Tensor, to_pad: tuple[tuple[int, int]], mode: str, transform_info: dict, kwargs +) -> torch.Tensor: + """ + Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according + to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + + Args: + img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. + to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. + note that it including channel dimension. + mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. Defaults to ``"constant"``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + transform_info: a dictionary with the relevant information pertaining to an applied transform. + kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. + """ + extra_info = {"padded": to_pad, "mode": f"{mode}"} + img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + spatial_rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else 3 + do_pad = np.asarray(to_pad).any() + if do_pad: + to_pad_list = [(int(p[0]), int(p[1])) for p in to_pad] + if len(to_pad_list) < len(img.shape): + to_pad_list += [(0, 0)] * (len(img.shape) - len(to_pad_list)) + to_shift = [-s[0] for s in to_pad_list[1:]] # skipping the channel pad + xform = create_translate(spatial_rank, to_shift) + shape = [d + s + e for d, (s, e) in zip(img_size, to_pad_list[1:])] + else: + shape = img_size + xform = torch.eye(int(spatial_rank) + 1, device=torch.device("cpu"), dtype=torch.float64) + meta_info = TraceableTransform.track_transform_meta( + img, + sp_size=shape, + affine=xform, + extra_info=extra_info, + orig_size=img_size, + transform_info=transform_info, + lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + ) + out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta()) + if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore + out = pad_nd(out, to_pad_list, mode, **kwargs) if do_pad else out + out = convert_to_tensor(out, track_meta=get_track_meta()) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore + + +def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict) -> torch.Tensor: + """ + Functional implementation of cropping a MetaTensor. This function operates eagerly or lazily according + to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + + Args: + img: data to be transformed, assuming `img` is channel-first and cropping doesn't apply to the channel dim. + slices: the crop slices computed based on specified `center & size` or `start & end` or `slices`. + transform_info: a dictionary with the relevant information pertaining to an applied transform. + """ + img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + spatial_rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else 3 + cropped = np.asarray([[s.indices(o)[0], o - s.indices(o)[1]] for s, o in zip(slices[1:], img_size)]) + extra_info = {"cropped": cropped.flatten().tolist()} + to_shift = [] + for i, s in enumerate(ensure_tuple(slices)[1:]): + if s.start is not None: + to_shift.append(img_size[i] + s.start if s.start < 0 else s.start) + else: + to_shift.append(0) + shape = [s.indices(o)[1] - s.indices(o)[0] for s, o in zip(slices[1:], img_size)] + meta_info = TraceableTransform.track_transform_meta( + img, + sp_size=shape, + affine=create_translate(spatial_rank, to_shift), + extra_info=extra_info, + orig_size=img_size, + transform_info=transform_info, + lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + ) + out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta()) + if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore + out = out[slices] + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index c136b3aa09..588913f579 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -13,10 +13,12 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +from __future__ import annotations + from abc import abstractmethod -from collections.abc import Iterable +from collections.abc import Callable, Iterable, Sequence from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any from warnings import warn import numpy as np @@ -96,9 +98,9 @@ def __init__(self, prob: float = 0.1, mean: float = 0.0, std: float = 0.1, dtype self.mean = mean self.std = std self.dtype = dtype - self.noise: Optional[np.ndarray] = None + self.noise: np.ndarray | None = None - def randomize(self, img: NdarrayOrTensor, mean: Optional[float] = None) -> None: + def randomize(self, img: NdarrayOrTensor, mean: float | None = None) -> None: super().randomize(None) if not self._do_transform: return None @@ -107,7 +109,7 @@ def randomize(self, img: NdarrayOrTensor, mean: Optional[float] = None) -> None: # noise is float64 array, convert to the output dtype to save memory self.noise, *_ = convert_data_type(noise, dtype=self.dtype) - def __call__(self, img: NdarrayOrTensor, mean: Optional[float] = None, randomize: bool = True) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, mean: float | None = None, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. """ @@ -155,8 +157,8 @@ class RandRicianNoise(RandomizableTransform): def __init__( self, prob: float = 0.1, - mean: Union[Sequence[float], float] = 0.0, - std: Union[Sequence[float], float] = 1.0, + mean: Sequence[float] | float = 0.0, + std: Sequence[float] | float = 1.0, channel_wise: bool = False, relative: bool = False, sample_std: bool = True, @@ -204,12 +206,12 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen img[i] = self._add_noise(d, mean=_mean[i], std=_std[i] * d.std() if self.relative else _std[i]) else: if not isinstance(self.mean, (int, float)): - raise RuntimeError("If channel_wise is False, mean must be a float or int number.") + raise RuntimeError(f"If channel_wise is False, mean must be a float or int, got {type(self.mean)}.") if not isinstance(self.std, (int, float)): - raise RuntimeError("If channel_wise is False, std must be a float or int number.") + raise RuntimeError(f"If channel_wise is False, std must be a float or int, got {type(self.std)}.") std = self.std * img.std().item() if self.relative else self.std if not isinstance(std, (int, float)): - raise RuntimeError("std must be a float or int number.") + raise RuntimeError(f"std must be a float or int number, got {type(std)}.") img = self._add_noise(img, mean=self.mean, std=std) return img @@ -230,7 +232,7 @@ def __init__(self, offset: float, safe: bool = False) -> None: self.offset = offset self.safe = safe - def __call__(self, img: NdarrayOrTensor, offset: Optional[float] = None) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, offset: float | None = None) -> NdarrayOrTensor: """ Apply the transform to `img`. """ @@ -250,7 +252,7 @@ class RandShiftIntensity(RandomizableTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, offsets: Union[Tuple[float, float], float], safe: bool = False, prob: float = 0.1) -> None: + def __init__(self, offsets: tuple[float, float] | float, safe: bool = False, prob: float = 0.1) -> None: """ Args: offsets: offset range to randomly shift. @@ -263,19 +265,19 @@ def __init__(self, offsets: Union[Tuple[float, float], float], safe: bool = Fals if isinstance(offsets, (int, float)): self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) elif len(offsets) != 2: - raise ValueError("offsets should be a number or pair of numbers.") + raise ValueError(f"offsets should be a number or pair of numbers, got {offsets}.") else: self.offsets = (min(offsets), max(offsets)) self._offset = self.offsets[0] self._shifter = ShiftIntensity(self._offset, safe) - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) - def __call__(self, img: NdarrayOrTensor, factor: Optional[float] = None, randomize: bool = True) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, factor: float | None = None, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`. @@ -359,7 +361,7 @@ class RandStdShiftIntensity(RandomizableTransform): def __init__( self, - factors: Union[Tuple[float, float], float], + factors: tuple[float, float] | float, prob: float = 0.1, nonzero: bool = False, channel_wise: bool = False, @@ -379,7 +381,7 @@ def __init__( if isinstance(factors, (int, float)): self.factors = (min(-factors, factors), max(-factors, factors)) elif len(factors) != 2: - raise ValueError("factors should be a number or pair of numbers.") + raise ValueError(f"factors should be a number or pair of numbers, got {factors}.") else: self.factors = (min(factors), max(factors)) self.factor = self.factors[0] @@ -387,7 +389,7 @@ def __init__( self.channel_wise = channel_wise self.dtype = dtype - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None @@ -420,9 +422,9 @@ class ScaleIntensity(Transform): def __init__( self, - minv: Optional[float] = 0.0, - maxv: Optional[float] = 1.0, - factor: Optional[float] = None, + minv: float | None = 0.0, + maxv: float | None = 1.0, + factor: float | None = None, channel_wise: bool = False, dtype: DtypeLike = np.float32, ) -> None: @@ -473,9 +475,7 @@ class RandScaleIntensity(RandomizableTransform): backend = ScaleIntensity.backend - def __init__( - self, factors: Union[Tuple[float, float], float], prob: float = 0.1, dtype: DtypeLike = np.float32 - ) -> None: + def __init__(self, factors: tuple[float, float] | float, prob: float = 0.1, dtype: DtypeLike = np.float32) -> None: """ Args: factors: factor range to randomly scale by ``v = v * (1 + factor)``. @@ -488,13 +488,13 @@ def __init__( if isinstance(factors, (int, float)): self.factors = (min(-factors, factors), max(-factors, factors)) elif len(factors) != 2: - raise ValueError("factors should be a number or pair of numbers.") + raise ValueError(f"factors should be a number or pair of numbers, got {factors}.") else: self.factors = (min(factors), max(factors)) self.factor = self.factors[0] self.dtype = dtype - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None @@ -539,13 +539,13 @@ class RandBiasField(RandomizableTransform): def __init__( self, degree: int = 3, - coeff_range: Tuple[float, float] = (0.0, 0.1), + coeff_range: tuple[float, float] = (0.0, 0.1), dtype: DtypeLike = np.float32, prob: float = 0.1, ) -> None: RandomizableTransform.__init__(self, prob) if degree < 1: - raise ValueError("degree should be no less than 1.") + raise ValueError(f"degree should be no less than 1, got {degree}.") self.degree = degree self.coeff_range = coeff_range self.dtype = dtype @@ -563,7 +563,7 @@ def _generate_random_field(self, spatial_shape: Sequence[int], degree: int, coef coeff_mat[np.tril_indices(degree + 1)] = coeff return np.polynomial.legendre.leggrid2d(coords[0], coords[1], coeff_mat) if rank == 3: - pts: List[List[int]] = [[0, 0, 0]] + pts: list[list[int]] = [[0, 0, 0]] for i in range(degree + 1): for j in range(degree + 1 - i): for k in range(degree + 1 - i - j): @@ -629,8 +629,8 @@ class NormalizeIntensity(Transform): def __init__( self, - subtrahend: Union[Sequence, NdarrayOrTensor, None] = None, - divisor: Union[Sequence, NdarrayOrTensor, None] = None, + subtrahend: Sequence | NdarrayOrTensor | None = None, + divisor: Sequence | NdarrayOrTensor | None = None, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32, @@ -725,7 +725,7 @@ class ThresholdIntensity(Transform): def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> None: if not isinstance(threshold, (int, float)): - raise ValueError("threshold must be a float or int number.") + raise ValueError(f"threshold must be a float or int number, got {type(threshold)} {threshold}.") self.threshold = threshold self.above = above self.cval = cval @@ -764,8 +764,8 @@ def __init__( self, a_min: float, a_max: float, - b_min: Optional[float] = None, - b_max: Optional[float] = None, + b_min: float | None = None, + b_max: float | None = None, clip: bool = False, dtype: DtypeLike = np.float32, ) -> None: @@ -812,7 +812,7 @@ class AdjustContrast(Transform): def __init__(self, gamma: float) -> None: if not isinstance(gamma, (int, float)): - raise ValueError("gamma must be a float or int number.") + raise ValueError(f"gamma must be a float or int number, got {type(gamma)} {gamma}.") self.gamma = gamma def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: @@ -841,13 +841,13 @@ class RandAdjustContrast(RandomizableTransform): backend = AdjustContrast.backend - def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0.5, 4.5)) -> None: + def __init__(self, prob: float = 0.1, gamma: Sequence[float] | float = (0.5, 4.5)) -> None: RandomizableTransform.__init__(self, prob) if isinstance(gamma, (int, float)): if gamma <= 0.5: raise ValueError( - "if gamma is single number, must greater than 0.5 and value is picked from (0.5, gamma)" + f"if gamma is a number, must greater than 0.5 and value is picked from (0.5, gamma), got {gamma}" ) self.gamma = (0.5, gamma) elif len(gamma) != 2: @@ -855,9 +855,9 @@ def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0. else: self.gamma = (min(gamma), max(gamma)) - self.gamma_value: Optional[float] = None + self.gamma_value: float | None = None - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None @@ -947,8 +947,8 @@ def __init__( self, lower: float, upper: float, - b_min: Optional[float], - b_max: Optional[float], + b_min: float | None, + b_max: float | None, clip: bool = False, relative: bool = False, channel_wise: bool = False, @@ -1020,11 +1020,11 @@ class MaskIntensity(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, mask_data: Optional[NdarrayOrTensor] = None, select_fn: Callable = is_positive) -> None: + def __init__(self, mask_data: NdarrayOrTensor | None = None, select_fn: Callable = is_positive) -> None: self.mask_data = mask_data self.select_fn = select_fn - def __call__(self, img: NdarrayOrTensor, mask_data: Optional[NdarrayOrTensor] = None) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor | None = None) -> NdarrayOrTensor: """ Args: mask_data: if mask data is single channel, apply to every channel @@ -1069,7 +1069,6 @@ class SavitzkyGolaySmooth(Transform): backend = [TransformBackends.TORCH] def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "zeros"): - if axis < 0: raise ValueError("axis must be zero or positive.") @@ -1113,8 +1112,7 @@ class DetectEnvelope(Transform): backend = [TransformBackends.TORCH] - def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: - + def __init__(self, axis: int = 1, n: int | None = None) -> None: if axis < 0: raise ValueError("axis must be zero or positive.") @@ -1157,7 +1155,7 @@ class MedianSmooth(Transform): backend = [TransformBackends.TORCH] - def __init__(self, radius: Union[Sequence[int], int] = 1) -> None: + def __init__(self, radius: Sequence[int] | int = 1) -> None: self.radius = radius def __call__(self, img: NdarrayTensor) -> NdarrayTensor: @@ -1187,14 +1185,14 @@ class GaussianSmooth(Transform): backend = [TransformBackends.TORCH] - def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "erf") -> None: + def __init__(self, sigma: Sequence[float] | float = 1.0, approx: str = "erf") -> None: self.sigma = sigma self.approx = approx def __call__(self, img: NdarrayTensor) -> NdarrayTensor: img = convert_to_tensor(img, track_meta=get_track_meta()) img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) - sigma: Union[Sequence[torch.Tensor], torch.Tensor] + sigma: Sequence[torch.Tensor] | torch.Tensor if isinstance(self.sigma, Sequence): sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma] else: @@ -1224,9 +1222,9 @@ class RandGaussianSmooth(RandomizableTransform): def __init__( self, - sigma_x: Tuple[float, float] = (0.25, 1.5), - sigma_y: Tuple[float, float] = (0.25, 1.5), - sigma_z: Tuple[float, float] = (0.25, 1.5), + sigma_x: tuple[float, float] = (0.25, 1.5), + sigma_y: tuple[float, float] = (0.25, 1.5), + sigma_z: tuple[float, float] = (0.25, 1.5), prob: float = 0.1, approx: str = "erf", ) -> None: @@ -1240,7 +1238,7 @@ def __init__( self.y = self.sigma_y[0] self.z = self.sigma_z[0] - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None @@ -1256,7 +1254,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if not self._do_transform: return img - sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=img.ndim - 1) + sigma = ensure_tuple_size(vals=(self.x, self.y, self.z), dim=img.ndim - 1) return GaussianSmooth(sigma=sigma, approx=self.approx)(img) @@ -1291,8 +1289,8 @@ class GaussianSharpen(Transform): def __init__( self, - sigma1: Union[Sequence[float], float] = 3.0, - sigma2: Union[Sequence[float], float] = 1.0, + sigma1: Sequence[float] | float = 3.0, + sigma2: Sequence[float] | float = 1.0, alpha: float = 30.0, approx: str = "erf", ) -> None: @@ -1342,13 +1340,13 @@ class RandGaussianSharpen(RandomizableTransform): def __init__( self, - sigma1_x: Tuple[float, float] = (0.5, 1.0), - sigma1_y: Tuple[float, float] = (0.5, 1.0), - sigma1_z: Tuple[float, float] = (0.5, 1.0), - sigma2_x: Union[Tuple[float, float], float] = 0.5, - sigma2_y: Union[Tuple[float, float], float] = 0.5, - sigma2_z: Union[Tuple[float, float], float] = 0.5, - alpha: Tuple[float, float] = (10.0, 30.0), + sigma1_x: tuple[float, float] = (0.5, 1.0), + sigma1_y: tuple[float, float] = (0.5, 1.0), + sigma1_z: tuple[float, float] = (0.5, 1.0), + sigma2_x: tuple[float, float] | float = 0.5, + sigma2_y: tuple[float, float] | float = 0.5, + sigma2_z: tuple[float, float] | float = 0.5, + alpha: tuple[float, float] = (10.0, 30.0), approx: str = "erf", prob: float = 0.1, ) -> None: @@ -1361,15 +1359,15 @@ def __init__( self.sigma2_z = sigma2_z self.alpha = alpha self.approx = approx - self.x1: Optional[float] = None - self.y1: Optional[float] = None - self.z1: Optional[float] = None - self.x2: Optional[float] = None - self.y2: Optional[float] = None - self.z2: Optional[float] = None - self.a: Optional[float] = None - - def randomize(self, data: Optional[Any] = None) -> None: + self.x1: float | None = None + self.y1: float | None = None + self.z1: float | None = None + self.x2: float | None = None + self.y2: float | None = None + self.z2: float | None = None + self.a: float | None = None + + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None @@ -1394,8 +1392,8 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if self.x2 is None or self.y2 is None or self.z2 is None or self.a is None: raise RuntimeError("please call the `randomize()` function first.") - sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=img.ndim - 1) - sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=img.ndim - 1) + sigma1 = ensure_tuple_size(vals=(self.x1, self.y1, self.z1), dim=img.ndim - 1) + sigma2 = ensure_tuple_size(vals=(self.x2, self.y2, self.z2), dim=img.ndim - 1) return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img) @@ -1412,7 +1410,7 @@ class RandHistogramShift(RandomizableTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None: + def __init__(self, num_control_points: tuple[int, int] | int = 10, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) if isinstance(num_control_points, int): @@ -1445,7 +1443,7 @@ def interp(self, x: NdarrayOrTensor, xp: NdarrayOrTensor, fp: NdarrayOrTensor) - f[x > xp[-1]] = fp[-1] return f - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None @@ -1468,9 +1466,15 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if self.reference_control_points is None or self.floating_control_points is None: raise RuntimeError("please call the `randomize()` function first.") img_t = convert_to_tensor(img, track_meta=False) + img_min, img_max = img_t.min(), img_t.max() + if img_min == img_max: + warn( + f"The image's intensity is a single value {img_min}. " + "The original image is simply returned, no histogram shift is done." + ) + return img xp, *_ = convert_to_dst_type(self.reference_control_points, dst=img_t) yp, *_ = convert_to_dst_type(self.floating_control_points, dst=img_t) - img_min, img_max = img_t.min(), img_t.max() reference_control_points_scaled = xp * (img_max - img_min) + img_min floating_control_points_scaled = yp * (img_max - img_min) + img_min img_t = self.interp(img_t, reference_control_points_scaled, floating_control_points_scaled) @@ -1500,7 +1504,6 @@ class GibbsNoise(Transform, Fourier): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, alpha: float = 0.1) -> None: - if alpha > 1 or alpha < 0: raise ValueError("alpha must take values in the interval [0, 1].") self.alpha = alpha @@ -1649,8 +1652,7 @@ class KSpaceSpikeNoise(Transform, Fourier): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None): - + def __init__(self, loc: tuple | Sequence[tuple], k_intensity: Sequence[float] | float | None = None): self.loc = ensure_tuple(loc) self.k_intensity = k_intensity @@ -1724,7 +1726,7 @@ def _check_indices(self, img) -> None: f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image." ) - def _set_spike(self, k: NdarrayOrTensor, idx: Tuple, val: Union[Sequence[float], float]): + def _set_spike(self, k: NdarrayOrTensor, idx: tuple, val: Sequence[float] | float): """ Helper function to introduce a given intensity at given location. @@ -1781,14 +1783,13 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): def __init__( self, prob: float = 0.1, - intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, + intensity_range: Sequence[Sequence[float] | float] | None = None, channel_wise: bool = True, ): - self.intensity_range = intensity_range self.channel_wise = channel_wise - self.sampled_k_intensity: List = [] - self.sampled_locs: List[Tuple] = [] + self.sampled_k_intensity: list = [] + self.sampled_locs: list[tuple] = [] if intensity_range is not None and isinstance(intensity_range[0], Sequence) and not channel_wise: raise ValueError("When channel_wise = False, intensity_range should be a 2-tuple (low, high) or None.") @@ -1909,9 +1910,9 @@ class RandCoarseTransform(RandomizableTransform): def __init__( self, holes: int, - spatial_size: Union[Sequence[int], int], - max_holes: Optional[int] = None, - max_spatial_size: Optional[Union[Sequence[int], int]] = None, + spatial_size: Sequence[int] | int, + max_holes: int | None = None, + max_spatial_size: Sequence[int] | int | None = None, prob: float = 0.1, ) -> None: RandomizableTransform.__init__(self, prob) @@ -1921,7 +1922,7 @@ def __init__( self.spatial_size = spatial_size self.max_holes = max_holes self.max_spatial_size = max_spatial_size - self.hole_coords: List = [] + self.hole_coords: list = [] def randomize(self, img_size: Sequence[int]) -> None: super().randomize(None) @@ -1993,11 +1994,11 @@ class RandCoarseDropout(RandCoarseTransform): def __init__( self, holes: int, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, dropout_holes: bool = True, - fill_value: Optional[Union[Tuple[float, float], float]] = None, - max_holes: Optional[int] = None, - max_spatial_size: Optional[Union[Sequence[int], int]] = None, + fill_value: tuple[float, float] | float | None = None, + max_holes: int | None = None, + max_spatial_size: Sequence[int] | int | None = None, prob: float = 0.1, ) -> None: super().__init__( @@ -2098,7 +2099,7 @@ def __init__( num_bins: int = 256, min: int = 0, max: int = 255, - mask: Optional[NdarrayOrTensor] = None, + mask: NdarrayOrTensor | None = None, dtype: DtypeLike = np.float32, ) -> None: self.num_bins = num_bins @@ -2107,11 +2108,11 @@ def __init__( self.mask = mask self.dtype = dtype - def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, mask: NdarrayOrTensor | None = None) -> NdarrayOrTensor: img = convert_to_tensor(img, track_meta=get_track_meta()) img_np, *_ = convert_data_type(img, np.ndarray) mask = mask if mask is not None else self.mask - mask_np: Optional[np.ndarray] = None + mask_np: np.ndarray | None = None if mask is not None: mask_np, *_ = convert_data_type(mask, np.ndarray) @@ -2146,7 +2147,6 @@ class IntensityRemap(RandomizableTransform): """ def __init__(self, kernel_size: int = 30, slope: float = 0.7): - super().__init__() self.kernel_size = kernel_size @@ -2201,7 +2201,6 @@ class RandIntensityRemap(RandomizableTransform): """ def __init__(self, prob: float = 0.1, kernel_size: int = 30, slope: float = 0.7, channel_wise: bool = True): - RandomizableTransform.__init__(self, prob=prob) self.kernel_size = kernel_size self.slope = slope @@ -2252,11 +2251,11 @@ class ForegroundMask(Transform): def __init__( self, - threshold: Union[Dict, Callable, str, float, int] = "otsu", - hsv_threshold: Optional[Union[Dict, Callable, str, float, int]] = None, + threshold: dict | Callable | str | float | int = "otsu", + hsv_threshold: dict | Callable | str | float | int | None = None, invert: bool = False, ) -> None: - self.thresholds: Dict[str, Union[Callable, float]] = {} + self.thresholds: dict[str, Callable | float] = {} if threshold is not None: if isinstance(threshold, dict): for mode, th in threshold.items(): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 2e087652a2..790cb38671 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -15,7 +15,9 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import Callable, Hashable, Mapping, Sequence import numpy as np @@ -195,13 +197,13 @@ def __init__( self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandGaussianNoised": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandGaussianNoised: super().set_random_state(seed, state) self.rand_gaussian_noise.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -253,8 +255,8 @@ def __init__( self, keys: KeysCollection, prob: float = 0.1, - mean: Union[Sequence[float], float] = 0.0, - std: Union[Sequence[float], float] = 1.0, + mean: Sequence[float] | float = 0.0, + std: Sequence[float] | float = 1.0, channel_wise: bool = False, relative: bool = False, sample_std: bool = True, @@ -273,14 +275,12 @@ def __init__( dtype=dtype, ) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandRicianNoised": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandRicianNoised: super().set_random_state(seed, state) self.rand_rician_noise.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -305,8 +305,8 @@ def __init__( keys: KeysCollection, offset: float, safe: bool = False, - factor_key: Optional[str] = None, - meta_keys: Optional[KeysCollection] = None, + factor_key: str | None = None, + meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, ) -> None: @@ -341,13 +341,13 @@ def __init__( self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) self.shifter = ShiftIntensity(offset, safe) - def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, factor_key, meta_key, meta_key_postfix in self.key_iterator( d, self.factor_key, self.meta_keys, self.meta_key_postfix ): meta_key = meta_key or f"{key}_{meta_key_postfix}" - factor: Optional[float] = d[meta_key].get(factor_key) if meta_key in d else None + factor: float | None = d[meta_key].get(factor_key) if meta_key in d else None offset = None if factor is None else self.shifter.offset * factor d[key] = self.shifter(d[key], offset=offset) return d @@ -363,10 +363,10 @@ class RandShiftIntensityd(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - offsets: Union[Tuple[float, float], float], + offsets: tuple[float, float] | float, safe: bool = False, - factor_key: Optional[str] = None, - meta_keys: Optional[KeysCollection] = None, + factor_key: str | None = None, + meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, prob: float = 0.1, allow_missing_keys: bool = False, @@ -408,13 +408,13 @@ def __init__( self.shifter = RandShiftIntensity(offsets=offsets, safe=safe, prob=1.0) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandShiftIntensityd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandShiftIntensityd: super().set_random_state(seed, state) self.shifter.set_random_state(seed, state) return self - def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -428,7 +428,7 @@ def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d, self.factor_key, self.meta_keys, self.meta_key_postfix ): meta_key = meta_key or f"{key}_{meta_key_postfix}" - factor: Optional[float] = d[meta_key].get(factor_key) if meta_key in d else None + factor: float | None = d[meta_key].get(factor_key) if meta_key in d else None d[key] = self.shifter(d[key], factor=factor, randomize=False) return d @@ -463,7 +463,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.shifter = StdShiftIntensity(factor, nonzero, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.shifter(d[key]) @@ -480,7 +480,7 @@ class RandStdShiftIntensityd(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - factors: Union[Tuple[float, float], float], + factors: tuple[float, float] | float, prob: float = 0.1, nonzero: bool = False, channel_wise: bool = False, @@ -506,13 +506,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandStdShiftIntensityd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandStdShiftIntensityd: super().set_random_state(seed, state) self.shifter.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -539,9 +539,9 @@ class ScaleIntensityd(MapTransform): def __init__( self, keys: KeysCollection, - minv: Optional[float] = 0.0, - maxv: Optional[float] = 1.0, - factor: Optional[float] = None, + minv: float | None = 0.0, + maxv: float | None = 1.0, + factor: float | None = None, channel_wise: bool = False, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, @@ -563,7 +563,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensity(minv, maxv, factor, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -580,7 +580,7 @@ class RandScaleIntensityd(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - factors: Union[Tuple[float, float], float], + factors: tuple[float, float] | float, prob: float = 0.1, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, @@ -602,13 +602,13 @@ def __init__( self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandScaleIntensityd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandScaleIntensityd: super().set_random_state(seed, state) self.scaler.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -634,7 +634,7 @@ def __init__( self, keys: KeysCollection, degree: int = 3, - coeff_range: Tuple[float, float] = (0.0, 0.1), + coeff_range: tuple[float, float] = (0.0, 0.1), dtype: DtypeLike = np.float32, prob: float = 0.1, allow_missing_keys: bool = False, @@ -656,14 +656,12 @@ def __init__( self.rand_bias_field = RandBiasField(degree=degree, coeff_range=coeff_range, dtype=dtype, prob=1.0) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandBiasFieldd": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandBiasFieldd: super().set_random_state(seed, state) self.rand_bias_field.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -708,8 +706,8 @@ class NormalizeIntensityd(MapTransform): def __init__( self, keys: KeysCollection, - subtrahend: Optional[NdarrayOrTensor] = None, - divisor: Optional[NdarrayOrTensor] = None, + subtrahend: NdarrayOrTensor | None = None, + divisor: NdarrayOrTensor | None = None, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32, @@ -718,7 +716,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.normalizer(d[key]) @@ -751,7 +749,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.filter = ThresholdIntensity(threshold, above, cval) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.filter(d[key]) @@ -781,8 +779,8 @@ def __init__( keys: KeysCollection, a_min: float, a_max: float, - b_min: Optional[float] = None, - b_max: Optional[float] = None, + b_min: float | None = None, + b_max: float | None = None, clip: bool = False, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, @@ -790,7 +788,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip, dtype) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -817,7 +815,7 @@ def __init__(self, keys: KeysCollection, gamma: float, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.adjuster = AdjustContrast(gamma) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.adjuster(d[key]) @@ -846,7 +844,7 @@ def __init__( self, keys: KeysCollection, prob: float = 0.1, - gamma: Union[Tuple[float, float], float] = (0.5, 4.5), + gamma: tuple[float, float] | float = (0.5, 4.5), allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -854,13 +852,13 @@ def __init__( self.adjuster = RandAdjustContrast(gamma=gamma, prob=1.0) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandAdjustContrastd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandAdjustContrastd: super().set_random_state(seed, state) self.adjuster.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -901,8 +899,8 @@ def __init__( keys: KeysCollection, lower: float, upper: float, - b_min: Optional[float], - b_max: Optional[float], + b_min: float | None, + b_max: float | None, clip: bool = False, relative: bool = False, channel_wise: bool = False, @@ -912,7 +910,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -945,8 +943,8 @@ class MaskIntensityd(MapTransform): def __init__( self, keys: KeysCollection, - mask_data: Optional[NdarrayOrTensor] = None, - mask_key: Optional[str] = None, + mask_data: NdarrayOrTensor | None = None, + mask_key: str | None = None, select_fn: Callable = is_positive, allow_missing_keys: bool = False, ) -> None: @@ -954,7 +952,7 @@ def __init__( self.converter = MaskIntensity(mask_data=mask_data, select_fn=select_fn) self.mask_key = mask_key if mask_data is None else None - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key], d[self.mask_key]) if self.mask_key is not None else self.converter(d[key]) @@ -991,7 +989,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = SavitzkyGolaySmooth(window_length=window_length, order=order, axis=axis, mode=mode) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1014,13 +1012,11 @@ class MedianSmoothd(MapTransform): backend = MedianSmooth.backend - def __init__( - self, keys: KeysCollection, radius: Union[Sequence[int], int], allow_missing_keys: bool = False - ) -> None: + def __init__(self, keys: KeysCollection, radius: Sequence[int] | int, allow_missing_keys: bool = False) -> None: super().__init__(keys, allow_missing_keys) self.converter = MedianSmooth(radius) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1048,14 +1044,14 @@ class GaussianSmoothd(MapTransform): def __init__( self, keys: KeysCollection, - sigma: Union[Sequence[float], float], + sigma: Sequence[float] | float, approx: str = "erf", allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.converter = GaussianSmooth(sigma, approx=approx) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1084,9 +1080,9 @@ class RandGaussianSmoothd(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - sigma_x: Tuple[float, float] = (0.25, 1.5), - sigma_y: Tuple[float, float] = (0.25, 1.5), - sigma_z: Tuple[float, float] = (0.25, 1.5), + sigma_x: tuple[float, float] = (0.25, 1.5), + sigma_y: tuple[float, float] = (0.25, 1.5), + sigma_z: tuple[float, float] = (0.25, 1.5), approx: str = "erf", prob: float = 0.1, allow_missing_keys: bool = False, @@ -1098,13 +1094,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandGaussianSmoothd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandGaussianSmoothd: super().set_random_state(seed, state) self.rand_smooth.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -1144,8 +1140,8 @@ class GaussianSharpend(MapTransform): def __init__( self, keys: KeysCollection, - sigma1: Union[Sequence[float], float] = 3.0, - sigma2: Union[Sequence[float], float] = 1.0, + sigma1: Sequence[float] | float = 3.0, + sigma2: Sequence[float] | float = 1.0, alpha: float = 30.0, approx: str = "erf", allow_missing_keys: bool = False, @@ -1153,7 +1149,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = GaussianSharpen(sigma1, sigma2, alpha, approx=approx) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1189,13 +1185,13 @@ class RandGaussianSharpend(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - sigma1_x: Tuple[float, float] = (0.5, 1.0), - sigma1_y: Tuple[float, float] = (0.5, 1.0), - sigma1_z: Tuple[float, float] = (0.5, 1.0), - sigma2_x: Union[Tuple[float, float], float] = 0.5, - sigma2_y: Union[Tuple[float, float], float] = 0.5, - sigma2_z: Union[Tuple[float, float], float] = 0.5, - alpha: Tuple[float, float] = (10.0, 30.0), + sigma1_x: tuple[float, float] = (0.5, 1.0), + sigma1_y: tuple[float, float] = (0.5, 1.0), + sigma1_z: tuple[float, float] = (0.5, 1.0), + sigma2_x: tuple[float, float] | float = 0.5, + sigma2_y: tuple[float, float] | float = 0.5, + sigma2_z: tuple[float, float] | float = 0.5, + alpha: tuple[float, float] = (10.0, 30.0), approx: str = "erf", prob: float = 0.1, allow_missing_keys: bool = False, @@ -1215,13 +1211,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandGaussianSharpend": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandGaussianSharpend: super().set_random_state(seed, state) self.rand_sharpen.set_random_state(seed, state) return self - def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -1256,7 +1252,7 @@ class RandHistogramShiftd(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - num_control_points: Union[Tuple[int, int], int] = 10, + num_control_points: tuple[int, int] | int = 10, prob: float = 0.1, allow_missing_keys: bool = False, ) -> None: @@ -1265,13 +1261,13 @@ def __init__( self.shifter = RandHistogramShift(num_control_points=num_control_points, prob=1.0) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandHistogramShiftd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandHistogramShiftd: super().set_random_state(seed, state) self.shifter.set_random_state(seed, state) return self - def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -1320,19 +1316,16 @@ def __init__( alpha: Sequence[float] = (0.0, 1.0), allow_missing_keys: bool = False, ) -> None: - MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob=prob) self.rand_gibbs_noise = RandGibbsNoise(alpha=alpha, prob=1.0) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandGibbsNoised": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandGibbsNoised: super().set_random_state(seed, state) self.rand_gibbs_noise.set_random_state(seed, state) return self - def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -1369,12 +1362,10 @@ class GibbsNoised(MapTransform): backend = GibbsNoise.backend def __init__(self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False) -> None: - MapTransform.__init__(self, keys, allow_missing_keys) self.transform = GibbsNoise(alpha) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.transform(d[key]) @@ -1427,15 +1418,14 @@ class KSpaceSpikeNoised(MapTransform): def __init__( self, keys: KeysCollection, - loc: Union[Tuple, Sequence[Tuple]], - k_intensity: Optional[Union[Sequence[float], float]] = None, + loc: tuple | Sequence[tuple], + k_intensity: Sequence[float] | float | None = None, allow_missing_keys: bool = False, ) -> None: - super().__init__(keys, allow_missing_keys) self.transform = KSpaceSpikeNoise(loc, k_intensity) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: """ Args: data: Expects image/label to have dimensions (C, H, W) or @@ -1490,7 +1480,7 @@ def __init__( self, keys: KeysCollection, prob: float = 0.1, - intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, + intensity_range: Sequence[Sequence[float] | float] | None = None, channel_wise: bool = True, allow_missing_keys: bool = False, ): @@ -1499,13 +1489,13 @@ def __init__( self.rand_noise = RandKSpaceSpikeNoise(prob=1.0, intensity_range=intensity_range, channel_wise=channel_wise) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandKSpaceSpikeNoised": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandKSpaceSpikeNoised: super().set_random_state(seed, state) self.rand_noise.set_random_state(seed, state) return self - def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -1556,11 +1546,11 @@ def __init__( self, keys: KeysCollection, holes: int, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, dropout_holes: bool = True, - fill_value: Optional[Union[Tuple[float, float], float]] = None, - max_holes: Optional[int] = None, - max_spatial_size: Optional[Union[Sequence[int], int]] = None, + fill_value: tuple[float, float] | float | None = None, + max_holes: int | None = None, + max_spatial_size: Sequence[int] | int | None = None, prob: float = 0.1, allow_missing_keys: bool = False, ): @@ -1577,13 +1567,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandCoarseDropoutd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandCoarseDropoutd: super().set_random_state(seed, state) self.dropper.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -1637,9 +1627,9 @@ def __init__( self, keys: KeysCollection, holes: int, - spatial_size: Union[Sequence[int], int], - max_holes: Optional[int] = None, - max_spatial_size: Optional[Union[Sequence[int], int]] = None, + spatial_size: Sequence[int] | int, + max_holes: int | None = None, + max_spatial_size: Sequence[int] | int | None = None, prob: float = 0.1, allow_missing_keys: bool = False, ): @@ -1650,13 +1640,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandCoarseShuffled": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandCoarseShuffled: super().set_random_state(seed, state) self.shuffle.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize(None) if not self._do_transform: @@ -1706,8 +1696,8 @@ def __init__( num_bins: int = 256, min: int = 0, max: int = 255, - mask: Optional[NdarrayOrTensor] = None, - mask_key: Optional[str] = None, + mask: NdarrayOrTensor | None = None, + mask_key: str | None = None, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: @@ -1715,7 +1705,7 @@ def __init__( self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, mask=mask, dtype=dtype) self.mask_key = mask_key if mask is None else None - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.transform(d[key], d[self.mask_key]) if self.mask_key is not None else self.transform(d[key]) @@ -1750,17 +1740,17 @@ class ForegroundMaskd(MapTransform): def __init__( self, keys: KeysCollection, - threshold: Union[Dict, Callable, str, float] = "otsu", - hsv_threshold: Optional[Union[Dict, Callable, str, float, int]] = None, + threshold: dict | Callable | str | float = "otsu", + hsv_threshold: dict | Callable | str | float | int | None = None, invert: bool = False, - new_key_prefix: Optional[str] = None, + new_key_prefix: str | None = None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.transform = ForegroundMask(threshold=threshold, hsv_threshold=hsv_threshold, invert=invert) self.new_key_prefix = new_key_prefix - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): new_key = key if self.new_key_prefix is None else self.new_key_prefix + key @@ -1793,7 +1783,7 @@ def __init__( self.transform = ComputeHoVerMaps(dtype=dtype) self.new_key_prefix = new_key_prefix - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): new_key = key if self.new_key_prefix is None else self.new_key_prefix + key diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index db4653ce93..f2a88be481 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -9,17 +9,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import warnings +from collections.abc import Hashable, Mapping from contextlib import contextmanager -from typing import Any, Hashable, Mapping, Optional, Tuple +from typing import Any import torch from monai import transforms +from monai.data.meta_obj import MetaObj, get_track_meta from monai.data.meta_tensor import MetaTensor -from monai.transforms.transform import Transform -from monai.utils.enums import TraceKeys +from monai.data.utils import to_affine_nd +from monai.transforms.transform import LazyTransform, Transform +from monai.utils import LazyAttr, MetaKeys, TraceKeys, convert_to_dst_type, convert_to_numpy, convert_to_tensor __all__ = ["TraceableTransform", "InvertibleTransform"] @@ -69,76 +74,166 @@ def trace_key(key: Hashable = None): return f"{TraceKeys.KEY_SUFFIX}" return f"{key}{TraceKeys.KEY_SUFFIX}" - def get_transform_info( - self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None - ) -> dict: + @staticmethod + def transform_info_keys(): + """The keys to store necessary info of an applied transform.""" + return ( + TraceKeys.CLASS_NAME, + TraceKeys.ID, + TraceKeys.TRACING, + TraceKeys.LAZY_EVALUATION, + TraceKeys.DO_TRANSFORM, + ) + + def get_transform_info(self) -> dict: """ Return a dictionary with the relevant information pertaining to an applied transform. + """ + vals = ( + self.__class__.__name__, + id(self), + self.tracing, + self.lazy_evaluation if isinstance(self, LazyTransform) else False, + self._do_transform if hasattr(self, "_do_transform") else True, + ) + return dict(zip(self.transform_info_keys(), vals)) - Args: - data: input data. Can be dictionary or MetaTensor. We can use `shape` to - determine the original size of the object (unless that has been given - explicitly, see `orig_size`). - key: if data is a dictionary, data[key] will be modified. - extra_info: if desired, any extra information pertaining to the applied - transform can be stored in this dictionary. These are often needed for - computing the inverse transformation. - orig_size: sometimes during the inverse it is useful to know what the size - of the original image was, in which case it can be supplied here. + def push_transform(self, data, *args, **kwargs): + """ + Push to a stack of applied transforms of ``data``. - Returns: - Dictionary of data pertaining to the applied transformation. + Args: + data: dictionary of data or `MetaTensor`. + args: additional positional arguments to track_transform_meta. + kwargs: additional keyword arguments to track_transform_meta, + set ``replace=True`` (default False) to rewrite the last transform infor in + applied_operation/pending_operation based on ``self.get_transform_info()``. """ - info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)} - if orig_size is not None: - info[TraceKeys.ORIG_SIZE] = orig_size - elif isinstance(data, Mapping) and key in data and hasattr(data[key], "shape"): - info[TraceKeys.ORIG_SIZE] = data[key].shape[1:] - elif hasattr(data, "shape"): - info[TraceKeys.ORIG_SIZE] = data.shape[1:] - if extra_info is not None: - info[TraceKeys.EXTRA_INFO] = extra_info - # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) - if hasattr(self, "_do_transform"): # RandomizableTransform - info[TraceKeys.DO_TRANSFORM] = self._do_transform - return info - - def push_transform( - self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None - ) -> None: + transform_info = self.get_transform_info() + lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False) + do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True) + kwargs = kwargs or {} + replace = kwargs.pop("replace", False) # whether to rewrite the most recently pushed transform info + if replace and get_track_meta() and isinstance(data, MetaTensor): + if not lazy_eval: + xform = self.pop_transform(data, check=False) if do_transform else {} + meta_obj = self.push_transform(data, orig_size=xform.get(TraceKeys.ORIG_SIZE), extra_info=xform) + return data.copy_meta_from(meta_obj) + if do_transform: + xform = data.pending_operations.pop() + extra = xform.copy() + xform.update(transform_info) + meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=lazy_eval, extra_info=extra) + return data.copy_meta_from(meta_obj) + return data + kwargs["lazy_evaluation"] = lazy_eval + if "transform_info" in kwargs and isinstance(kwargs["transform_info"], dict): + kwargs["transform_info"].update(transform_info) + else: + kwargs["transform_info"] = transform_info + meta_obj = TraceableTransform.track_transform_meta(data, *args, **kwargs) + return data.copy_meta_from(meta_obj) if isinstance(data, MetaTensor) else data + + @classmethod + def track_transform_meta( + cls, + data, + key: Hashable = None, + sp_size=None, + affine=None, + extra_info: dict | None = None, + orig_size: tuple | None = None, + transform_info=None, + lazy_evaluation=False, + ): """ - Push to a stack of applied transforms. + Update a stack of applied/pending transforms metadata of ``data``. Args: data: dictionary of data or `MetaTensor`. key: if data is a dictionary, data[key] will be modified. + sp_size: the expected output spatial size when the transform is applied. + it can be tensor or numpy, but will be converted to a list of integers. + affine: the affine representation of the (spatial) transform in the image space. + When the transform is applied, meta_tensor.affine will be updated to ``meta_tensor.affine @ affine``. extra_info: if desired, any extra information pertaining to the applied transform can be stored in this dictionary. These are often needed for computing the inverse transformation. orig_size: sometimes during the inverse it is useful to know what the size of the original image was, in which case it can be supplied here. + transform_info: info from self.get_transform_info(). + lazy_evaluation: whether to push the transform to pending_operations or applied_operations. Returns: - None, but data has been updated to store the applied transformation. + + For backward compatibility, if ``data`` is a dictionary, it returns the dictionary with + updated ``data[key]``. Otherwise, this function returns a MetaObj with updated transform metadata. """ - if not self.tracing: - return - info = self.get_transform_info(data, key, extra_info, orig_size) + data_t = data[key] if key is not None else data # compatible with the dict data representation + out_obj = MetaObj() + # after deprecating metadict, we should always convert data_t to metatensor here + if isinstance(data_t, MetaTensor): + out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys()) + + if lazy_evaluation and (not get_track_meta()): + warnings.warn("metadata is not tracked, please call 'set_track_meta(True)' if doing lazy evaluation.") + + if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor): + # not lazy evaluation, directly update the metatensor affine (don't push to the stack) + orig_affine = data_t.peek_pending_affine() + orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0] + affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64) + out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64) + + if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)): + if isinstance(data, Mapping): + if not isinstance(data, dict): + data = dict(data) + data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t + return data + return out_obj # return with data_t as tensor if get_track_meta() is False + + info = transform_info.copy() + # track the current spatial shape + if orig_size is not None: + info[TraceKeys.ORIG_SIZE] = orig_size + elif isinstance(data_t, MetaTensor): + info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape() + elif hasattr(data_t, "shape"): + info[TraceKeys.ORIG_SIZE] = data_t.shape[1:] + # include extra_info + if extra_info is not None: + extra_info.pop(LazyAttr.SHAPE, None) + extra_info.pop(LazyAttr.AFFINE, None) + info[TraceKeys.EXTRA_INFO] = extra_info - if isinstance(data, MetaTensor): - data.push_applied_operation(info) - elif isinstance(data, Mapping): - if key in data and isinstance(data[key], MetaTensor): - data[key].push_applied_operation(info) + # push the transform info to the applied_operation or pending_operation stack + if lazy_evaluation: + if sp_size is None: + if LazyAttr.SHAPE not in info: + warnings.warn("spatial size is None in push transform.") + else: + info[LazyAttr.SHAPE] = tuple(convert_to_numpy(sp_size, wrap_sequence=True).tolist()) + if affine is None: + if LazyAttr.AFFINE not in info: + warnings.warn("affine is None in push transform.") else: - # If this is the first, create list - if self.trace_key(key) not in data: - if not isinstance(data, dict): - data = dict(data) - data[self.trace_key(key)] = [] - data[self.trace_key(key)].append(info) + info[LazyAttr.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu")) + out_obj.push_pending_operation(info) else: - warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.") + out_obj.push_applied_operation(info) + if isinstance(data, Mapping): + if not isinstance(data, dict): + data = dict(data) + if isinstance(data_t, MetaTensor): + data[key] = data_t.copy_meta_from(out_obj) + else: + x_k = TraceableTransform.trace_key(key) + if x_k not in data: + data[x_k] = [] # If this is the first, create list + data[x_k].append(info) + return data + return out_obj def check_transforms_match(self, transform: Mapping) -> None: """Check transforms are of same instance.""" diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index 3bfbb961a8..73149f1be5 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from collections.abc import Callable, Sequence +from typing import Any from torch.utils.data import Dataset from torch.utils.data.dataloader import DataLoader as TorchDataLoader @@ -57,8 +60,8 @@ def __init__( self, transform: InvertibleTransform, loader: TorchDataLoader, - collate_fn: Optional[Callable] = no_collation, - num_workers: Optional[int] = 0, + collate_fn: Callable | None = no_collation, + num_workers: int | None = 0, detach: bool = True, pad_batch: bool = True, fill_value=None, @@ -92,7 +95,7 @@ def __init__( loader.collate_fn, PadListDataCollate ) - def __call__(self, data: Dict[str, Any]) -> Any: + def __call__(self, data: dict[str, Any]) -> Any: decollated_data = decollate_batch(data, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value) inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used) inv_loader = DataLoader( @@ -129,7 +132,7 @@ class Decollated(MapTransform): def __init__( self, - keys: Optional[KeysCollection] = None, + keys: KeysCollection | None = None, detach: bool = True, pad_batch: bool = True, fill_value=None, @@ -140,8 +143,8 @@ def __init__( self.pad_batch = pad_batch self.fill_value = fill_value - def __call__(self, data: Union[Dict, List]): - d: Union[Dict, List] + def __call__(self, data: dict | list): + d: dict | list if len(self.keys) == 1 and self.keys[0] is None: # it doesn't support `None` as the key d = data diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 5115ace4f0..56d284d46f 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -13,14 +13,16 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +from __future__ import annotations + import inspect import logging import sys import traceback import warnings +from collections.abc import Sequence from pathlib import Path from pydoc import locate -from typing import Dict, List, Optional, Sequence, Type, Union import numpy as np import torch @@ -42,7 +44,14 @@ from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import +from monai.utils import ( + OptionalImportError, + convert_to_dst_type, + deprecated_arg_default, + ensure_tuple, + look_up_option, + optional_import, +) nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -107,9 +116,9 @@ class LoadImage(Transform): - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), (npz, npy -> NumpyReader), (nrrd -> NrrdReader), (DICOM file -> ITKReader). - Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after - loading the array because the `HW` definition for non-medical specific file formats is different - from other common medical packages. + Please note that for png, jpg, bmp, and other 2D formats, readers by default swap axis 0 and 1 after + loading the array with ``reverse_indexing`` set to ``True`` because the spatial axes definition + for non-medical specific file formats is different from other common medical packages. See also: @@ -117,15 +126,17 @@ class LoadImage(Transform): """ + @deprecated_arg_default("image_only", False, True, since="1.1", replaced="1.3") def __init__( self, reader=None, image_only: bool = False, - dtype: Optional[DtypeLike] = np.float32, + dtype: DtypeLike | None = np.float32, ensure_channel_first: bool = False, simple_keys: bool = False, - prune_meta_pattern: Optional[str] = None, + prune_meta_pattern: str | None = None, prune_meta_sep: str = ".", + expanduser: bool = True, *args, **kwargs, ) -> None: @@ -148,6 +159,7 @@ def __init__( prune_meta_sep: combined with `prune_meta_pattern`, used to match and prune keys in the metadata (nested dictionary). default is ".", see also :py:class:`monai.transforms.DeleteItemsd`. e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``. + expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. @@ -169,8 +181,9 @@ def __init__( self.simple_keys = simple_keys self.pattern = prune_meta_pattern self.sep = prune_meta_sep + self.expanduser = expanduser - self.readers: List[ImageReader] = [] + self.readers: list[ImageReader] = [] for r in SUPPORTED_READERS: # set predefined readers as default try: self.register(SUPPORTED_READERS[r](*args, **kwargs)) @@ -220,7 +233,7 @@ def register(self, reader: ImageReader): warnings.warn(f"Preferably the reader should inherit ImageReader, but got {type(reader)}.") self.readers.append(reader) - def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Optional[ImageReader] = None): + def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader | None = None): """ Load image file and metadata from the given filename(s). If `reader` is not specified, this class automatically chooses readers based on the @@ -236,7 +249,9 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option reader: runtime reader to load image file and metadata. """ - filename = tuple(f"{Path(s).expanduser()}" for s in ensure_tuple(filename)) # allow Path objects + filename = tuple( + f"{Path(s).expanduser()}" if self.expanduser else s for s in ensure_tuple(filename) # allow Path objects + ) img, err = None, [] if reader is not None: img = reader.read(filename) # runtime specified reader @@ -274,7 +289,7 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option img_array, meta_data = reader.get_data(img) img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0] if not isinstance(meta_data, dict): - raise ValueError("`meta_data` must be a dict.") + raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.") # make sure all elements in metadata are little endian meta_data = switch_endianness(meta_data, "<") @@ -352,24 +367,25 @@ class SaveImage(Transform): see also: :py:func:`monai.data.folder_layout.default_name_formatter`. """ + @deprecated_arg_default("resample", True, False, since="1.1", replaced="1.3") def __init__( self, output_dir: PathLike = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", - output_dtype: Optional[DtypeLike] = np.float32, + output_dtype: DtypeLike | None = np.float32, resample: bool = True, mode: str = "nearest", padding_mode: str = GridSamplePadMode.BORDER, - scale: Optional[int] = None, + scale: int | None = None, dtype: DtypeLike = np.float64, squeeze_end_dims: bool = True, data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, output_format: str = "", - writer: Union[Type[image_writer.ImageWriter], str, None] = None, - channel_dim: Optional[int] = 0, + writer: type[image_writer.ImageWriter] | str | None = None, + channel_dim: int | None = 0, output_name_formatter=None, ) -> None: self.folder_layout = FolderLayout( @@ -425,7 +441,7 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ if write_kwargs is not None: self.write_kwargs.update(write_kwargs) - def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): + def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None): """ Args: img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`. diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 9817a92807..ec2e5e403d 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -15,8 +15,9 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from __future__ import annotations + from pathlib import Path -from typing import Optional, Type, Union import numpy as np @@ -26,6 +27,7 @@ from monai.transforms.io.array import LoadImage, SaveImage from monai.transforms.transform import MapTransform from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep +from monai.utils.deprecate_utils import deprecated_arg_default from monai.utils.enums import PostFix __all__ = ["LoadImaged", "LoadImageD", "LoadImageDict", "SaveImaged", "SaveImageD", "SaveImageDict"] @@ -51,9 +53,9 @@ class LoadImaged(MapTransform): - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), (npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader). - Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after - loading the array because the `HW` definition for non-medical specific file formats is different - from other common medical packages. + Please note that for png, jpg, bmp, and other 2D formats, readers by default swap axis 0 and 1 after + loading the array with ``reverse_indexing`` set to ``True`` because the spatial axes definition + for non-medical specific file formats is different from other common medical packages. Note: @@ -68,20 +70,22 @@ class LoadImaged(MapTransform): """ + @deprecated_arg_default("image_only", False, True, since="1.1", replaced="1.3") def __init__( self, keys: KeysCollection, - reader: Union[Type[ImageReader], str, None] = None, + reader: type[ImageReader] | str | None = None, dtype: DtypeLike = np.float32, - meta_keys: Optional[KeysCollection] = None, + meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, overwriting: bool = False, image_only: bool = False, ensure_channel_first: bool = False, simple_keys: bool = False, - prune_meta_pattern: Optional[str] = None, + prune_meta_pattern: str | None = None, prune_meta_sep: str = ".", allow_missing_keys: bool = False, + expanduser: bool = True, *args, **kwargs, ) -> None: @@ -117,6 +121,7 @@ def __init__( in the metadata (nested dictionary). default is ".", see also :py:class:`monai.transforms.DeleteItemsd`. e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``. allow_missing_keys: don't raise exception if key is missing. + expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. """ @@ -129,6 +134,7 @@ def __init__( simple_keys, prune_meta_pattern, prune_meta_sep, + expanduser, *args, **kwargs, ) @@ -136,14 +142,16 @@ def __init__( raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") + raise ValueError( + f"meta_keys should have the same length as keys, got {len(self.keys)} and {len(self.meta_keys)}." + ) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) self.overwriting = overwriting def register(self, reader: ImageReader): self._loader.register(reader) - def __call__(self, data, reader: Optional[ImageReader] = None): + def __call__(self, data, reader: ImageReader | None = None): """ Raises: KeyError: When not ``self.overwriting`` and key already exists in ``data``. @@ -156,10 +164,12 @@ def __call__(self, data, reader: Optional[ImageReader] = None): d[key] = data else: if not isinstance(data, (tuple, list)): - raise ValueError("loader must return a tuple or list (because image_only=False was used).") + raise ValueError( + f"loader must return a tuple or list (because image_only=False was used), got {type(data)}." + ) d[key] = data[0] if not isinstance(data[1], dict): - raise ValueError("metadata must be a dict.") + raise ValueError(f"metadata must be a dict, got {type(data[1])}.") meta_key = meta_key or f"{key}_{meta_key_postfix}" if meta_key in d and not self.overwriting: raise KeyError(f"Metadata with key {meta_key} already exists and overwriting=False.") @@ -238,27 +248,28 @@ class SaveImaged(MapTransform): """ + @deprecated_arg_default("resample", True, False, since="1.1", replaced="1.3") def __init__( self, keys: KeysCollection, - meta_keys: Optional[KeysCollection] = None, + meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, - output_dir: Union[Path, str] = "./", + output_dir: Path | str = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", resample: bool = True, mode: str = "nearest", padding_mode: str = GridSamplePadMode.BORDER, - scale: Optional[int] = None, + scale: int | None = None, dtype: DtypeLike = np.float64, - output_dtype: Optional[DtypeLike] = np.float32, + output_dtype: DtypeLike | None = np.float32, allow_missing_keys: bool = False, squeeze_end_dims: bool = True, data_root_dir: str = "", separate_folder: bool = True, print_log: bool = True, output_format: str = "", - writer: Union[Type[image_writer.ImageWriter], str, None] = None, + writer: type[image_writer.ImageWriter] | str | None = None, output_name_formatter=None, ) -> None: super().__init__(keys, allow_missing_keys) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 5f79f7954e..0a2517cf87 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -9,12 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from __future__ import annotations +from typing import Any + +import numpy as np import torch from monai.data.meta_tensor import MetaTensor -from monai.data.utils import to_affine_nd from monai.transforms.lazy.utils import ( affine_from_pending, combine_transforms, @@ -22,40 +24,83 @@ kwargs_from_pending, resample, ) +from monai.utils import LazyAttr __all__ = ["apply_transforms"] -def apply_transforms(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None): +def apply_transforms( + data: torch.Tensor | MetaTensor, + pending: list | None = None, + mode: str | int | None = None, + padding_mode: str | None = None, + dtype=np.float64, + align_corners: bool = False, + resample_mode: str | None = None, +): """ This method applies pending transforms to `data` tensors. + Currently, only 2d and 3d input are supported. Args: data: A torch Tensor or a monai MetaTensor. pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor. + mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers). + Interpolation mode to calculate output values. Defaults to None. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used + and the value represents the order of the spline interpolation. + See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to None. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `mode` is an integer, using numpy/cupy backends, this argument accepts + {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}. + See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html + dtype: data type for resampling computation. Defaults to ``float64``. + If ``None``, use the data type of input data`. + align_corners: Geometrically, we consider the pixels of the input as squares rather than points, when using + the PyTorch resampling backend. Defaults to ``False``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + resample_mode: the mode of resampling, currently support ``"auto"``. Setting to other values will use the + `monai.transforms.SpatialResample` for resampling (instead of potentially crop/pad). """ if isinstance(data, MetaTensor) and pending is None: - pending = data.pending_operations + pending = data.pending_operations.copy() + data.clear_pending_operations() pending = [] if pending is None else pending if not pending: - return data + return data, [] cumulative_xform = affine_from_pending(pending[0]) cur_kwargs = kwargs_from_pending(pending[0]) + override_kwargs: dict[str, Any] = {} + if mode is not None: + override_kwargs[LazyAttr.INTERP_MODE] = mode + if padding_mode is not None: + override_kwargs[LazyAttr.PADDING_MODE] = padding_mode + if align_corners is not None: + override_kwargs[LazyAttr.ALIGN_CORNERS] = align_corners + if resample_mode is not None: + override_kwargs["resample_mode"] = resample_mode + override_kwargs[LazyAttr.DTYPE] = data.dtype if dtype is None else dtype for p in pending[1:]: new_kwargs = kwargs_from_pending(p) if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs): # carry out an intermediate resample here due to incompatibility between arguments - data = resample(data, cumulative_xform, cur_kwargs) + _cur_kwargs = cur_kwargs.copy() + _cur_kwargs.update(override_kwargs) + sp_size = _cur_kwargs.pop(LazyAttr.SHAPE, None) + data = resample(data, cumulative_xform, sp_size, _cur_kwargs) next_matrix = affine_from_pending(p) cumulative_xform = combine_transforms(cumulative_xform, next_matrix) cur_kwargs.update(new_kwargs) - data = resample(data, cumulative_xform, cur_kwargs) + cur_kwargs.update(override_kwargs) + sp_size = cur_kwargs.pop(LazyAttr.SHAPE, None) + data = resample(data, cumulative_xform, sp_size, cur_kwargs) if isinstance(data, MetaTensor): - data.clear_pending_operations() - data.affine = data.affine @ to_affine_nd(3, cumulative_xform) for p in pending: data.push_applied_operation(p) diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py index 4e37e78833..1cdd406635 100644 --- a/monai/transforms/lazy/utils.py +++ b/monai/transforms/lazy/utils.py @@ -9,14 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from __future__ import annotations + +import warnings import numpy as np import torch import monai from monai.config import NdarrayOrTensor -from monai.utils import LazyAttr, convert_to_tensor +from monai.data.utils import AFFINE_TOL +from monai.transforms.utils_pytorch_numpy_unification import allclose +from monai.utils import LazyAttr, convert_to_numpy, convert_to_tensor __all__ = ["resample", "combine_transforms"] @@ -38,7 +42,7 @@ def is_affine_shaped(data): return False if not hasattr(data, "shape") or len(data.shape) < 2: return False - return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2] + return data.shape[-1] in (3, 4) and data.shape[-1] == data.shape[-2] class DisplacementField: @@ -97,7 +101,7 @@ def kwargs_from_pending(pending_item): ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE] if LazyAttr.DTYPE in pending_item: ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE] - return ret + return ret # adding support of pending_item['extra_info']?? def is_compatible_apply_kwargs(kwargs_1, kwargs_2): @@ -105,21 +109,113 @@ def is_compatible_apply_kwargs(kwargs_1, kwargs_2): return True -def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: Optional[dict] = None): +def requires_interp(matrix, atol=AFFINE_TOL): + """ + Check whether the transformation matrix suggests voxel-wise interpolation. + + Returns None if the affine matrix suggests interpolation. + Otherwise, the matrix suggests that the resampling could be achieved by simple array operations + such as flip/permute/pad_nd/slice; in this case this function returns axes information about simple axes + operations. + + Args: + matrix: the affine matrix to check. + atol: absolute tolerance for checking if the matrix is close to an integer. """ - This is a minimal implementation of resample that always uses Affine. + matrix = convert_to_numpy(matrix, wrap_sequence=True) + s = matrix[:, -1] + if not np.allclose(s, np.round(s), atol=atol): + return None + + ndim = len(matrix) - 1 + ox, oy = [], [0] + for x, r in enumerate(matrix[:ndim, :ndim]): + for y, c in enumerate(r): + if np.isclose(c, -1, atol=atol) or np.isclose(c, 1, atol=atol): + y_channel = y + 1 # the returned axis index starting with channel dim + if x in ox or y_channel in oy: + return None + else: + ox.append(x) + oy.append(y_channel) + elif not np.isclose(c, 0.0, atol=atol): + return None + return oy + + +def resample(data: torch.Tensor, matrix: NdarrayOrTensor, spatial_size, kwargs: dict | None = None): + """ + Resample `data` using the affine transformation defined by ``matrix`` and output spatial size ``spatial_size``. + + Args: + data: input data to be resampled. + matrix: affine transformation matrix. + spatial_size: output spatial size. + kwargs: currently supports (see also: ``monai.utils.enums.LazyAttr``) + - "lazy_dtype" + - "lazy_padding_mode" + - "lazy_interpolation_mode" (this option might be ignored when ``mode="auto"``.) + - "lazy_align_corners" + - "atol" for tolerance for matrix floating point comparison. + - "resample_mode" for resampling backend, default to `"auto"`. Setting to other values will use the + `monai.transforms.SpatialResample` for resampling. + + See Also: + :py:class:`monai.transforms.SpatialResample` """ if not Affine.is_affine_shaped(matrix): - raise NotImplementedError("calling dense grid resample API not implemented") + raise NotImplementedError(f"Calling the dense grid resample API directly not implemented, {matrix.shape}.") + if isinstance(data, monai.data.MetaTensor) and data.pending_operations: + warnings.warn("data.pending_operations is not empty, the resampling output may be incorrect.") kwargs = {} if kwargs is None else kwargs + atol = kwargs.pop("atol", AFFINE_TOL) + mode = kwargs.pop("resample_mode", "auto") + init_kwargs = { - "spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:], "dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype), + "align_corners": kwargs.pop(LazyAttr.ALIGN_CORNERS, False), } + ndim = len(matrix) - 1 + img = convert_to_tensor(data=data, track_meta=monai.data.get_track_meta()) + init_affine = monai.data.to_affine_nd(ndim, img.affine) + out_spatial_size = img.peek_pending_shape() if spatial_size is None else spatial_size + out_spatial_size = convert_to_numpy(out_spatial_size, wrap_sequence=True) call_kwargs = { + "spatial_size": out_spatial_size, + "dst_affine": init_affine @ monai.utils.convert_to_dst_type(matrix, init_affine)[0], "mode": kwargs.pop(LazyAttr.INTERP_MODE, None), "padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None), } - resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs) - with resampler.trace_transform(False): # don't track this transform in `data` - return resampler(img=data, **call_kwargs) + + axes = requires_interp(matrix, atol=atol) + if axes is not None and mode == "auto" and not init_kwargs["align_corners"]: + matrix_np = np.round(convert_to_numpy(matrix, wrap_sequence=True)) + full_transpose = np.argsort(axes).tolist() + if not np.allclose(full_transpose, np.arange(len(full_transpose))): + img = img.permute(full_transpose[: len(img.shape)]) + in_shape = img.shape[1 : ndim + 1] # requires that ``img`` has empty pending operations + matrix_np[:ndim] = matrix_np[[x - 1 for x in full_transpose[1:]]] + flip = [idx + 1 for idx, val in enumerate(matrix_np[:ndim]) if val[idx] == -1] + if flip: + img = torch.flip(img, dims=flip) # todo: if on cpu, using the np.flip is faster? + for f in flip: + ind_f = f - 1 + matrix_np[ind_f, ind_f] = 1 + matrix_np[ind_f, -1] = in_shape[ind_f] - 1 - matrix_np[ind_f, -1] + if not np.all(out_spatial_size > 0): + raise ValueError(f"Resampling out_spatial_size should be positive, got {out_spatial_size}.") + if ( + allclose(matrix_np, np.eye(len(matrix_np)), atol=atol) + and len(in_shape) == len(out_spatial_size) + and allclose(convert_to_numpy(in_shape, wrap_sequence=True), out_spatial_size) + ): + img.affine = call_kwargs["dst_affine"] + return img + img = monai.transforms.crop_or_pad_nd(img, matrix_np, out_spatial_size, mode=call_kwargs["padding_mode"]) + img.affine = call_kwargs["dst_affine"] + return img + + resampler = monai.transforms.SpatialResample(**init_kwargs) + resampler.lazy_evaluation = False # resampler is a lazytransform + with resampler.trace_transform(False): # don't track this transform in `img` + return resampler(img=img, **call_kwargs) diff --git a/monai/transforms/meta_utility/dictionary.py b/monai/transforms/meta_utility/dictionary.py index bef228f423..ed752bb2d7 100644 --- a/monai/transforms/meta_utility/dictionary.py +++ b/monai/transforms/meta_utility/dictionary.py @@ -15,7 +15,9 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from typing import Dict, Hashable, Mapping, Sequence, Union +from __future__ import annotations + +from collections.abc import Hashable, Mapping, Sequence import numpy as np import torch @@ -48,7 +50,7 @@ class FromMetaTensord(MapTransform, InvertibleTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY, TransformBackends.CUPY] def __init__( - self, keys: KeysCollection, data_type: Union[Sequence[str], str] = "tensor", allow_missing_keys: bool = False + self, keys: KeysCollection, data_type: Sequence[str] | str = "tensor", allow_missing_keys: bool = False ): """ Args: @@ -60,7 +62,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.as_tensor_output = tuple(d == "tensor" for d in ensure_tuple_rep(data_type, len(self.keys))) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, t in self.key_iterator(d, self.as_tensor_output): im: MetaTensor = d[key] # type: ignore @@ -68,7 +70,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N self.push_transform(d, key) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): # check transform @@ -94,7 +96,7 @@ class ToMetaTensord(MapTransform, InvertibleTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY, TransformBackends.CUPY] - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) @@ -105,7 +107,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d[key] = im return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): # check transform diff --git a/monai/transforms/nvtx.py b/monai/transforms/nvtx.py index b6e0b0d465..69b38e588b 100644 --- a/monai/transforms/nvtx.py +++ b/monai/transforms/nvtx.py @@ -12,7 +12,10 @@ Wrapper around NVIDIA Tools Extension for profiling MONAI transformations """ -from monai.transforms.transform import RandomizableTrait, Transform +from __future__ import annotations + +from monai.transforms.traits import RandomizableTrait +from monai.transforms.transform import Transform from monai.utils import optional_import _nvtx, _ = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?") diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index b05f0b9d61..8e0c642d8b 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -13,8 +13,10 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +from __future__ import annotations + import warnings -from typing import Callable, Iterable, Optional, Sequence, Tuple, Union +from collections.abc import Callable, Iterable, Sequence import numpy as np import torch @@ -76,9 +78,7 @@ class Activations(Transform): backend = [TransformBackends.TORCH] - def __init__( - self, sigmoid: bool = False, softmax: bool = False, other: Optional[Callable] = None, **kwargs - ) -> None: + def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Callable | None = None, **kwargs) -> None: self.sigmoid = sigmoid self.softmax = softmax self.kwargs = kwargs @@ -89,9 +89,9 @@ def __init__( def __call__( self, img: NdarrayOrTensor, - sigmoid: Optional[bool] = None, - softmax: Optional[bool] = None, - other: Optional[Callable] = None, + sigmoid: bool | None = None, + softmax: bool | None = None, + other: Callable | None = None, ) -> NdarrayOrTensor: """ Args: @@ -171,9 +171,9 @@ class AsDiscrete(Transform): def __init__( self, argmax: bool = False, - to_onehot: Optional[int] = None, - threshold: Optional[float] = None, - rounding: Optional[str] = None, + to_onehot: int | None = None, + threshold: float | None = None, + rounding: str | None = None, **kwargs, ) -> None: self.argmax = argmax @@ -187,10 +187,10 @@ def __init__( def __call__( self, img: NdarrayOrTensor, - argmax: Optional[bool] = None, - to_onehot: Optional[int] = None, - threshold: Optional[float] = None, - rounding: Optional[str] = None, + argmax: bool | None = None, + to_onehot: int | None = None, + threshold: float | None = None, + rounding: str | None = None, ) -> NdarrayOrTensor: """ Args: @@ -216,7 +216,7 @@ def __call__( to_onehot = self.to_onehot if to_onehot is None else to_onehot if to_onehot is not None: if not isinstance(to_onehot, int): - raise ValueError("the number of classes for One-Hot must be an integer.") + raise ValueError(f"the number of classes for One-Hot must be an integer, got {type(to_onehot)}.") img_t = one_hot( img_t, num_classes=to_onehot, dim=self.kwargs.get("dim", 0), dtype=self.kwargs.get("dtype", torch.float) ) @@ -282,10 +282,10 @@ class KeepLargestConnectedComponent(Transform): def __init__( self, - applied_labels: Optional[Union[Sequence[int], int]] = None, - is_onehot: Optional[bool] = None, + applied_labels: Sequence[int] | int | None = None, + is_onehot: bool | None = None, independent: bool = True, - connectivity: Optional[int] = None, + connectivity: int | None = None, num_components: int = 1, ) -> None: """ @@ -409,7 +409,7 @@ class LabelFilter(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, applied_labels: Union[Iterable[int], int]) -> None: + def __init__(self, applied_labels: Iterable[int] | int) -> None: """ Initialize the LabelFilter class with the labels to filter on. @@ -489,9 +489,7 @@ class FillHoles(Transform): backend = [TransformBackends.NUMPY] - def __init__( - self, applied_labels: Optional[Union[Iterable[int], int]] = None, connectivity: Optional[int] = None - ) -> None: + def __init__(self, applied_labels: Iterable[int] | int | None = None, connectivity: int | None = None) -> None: """ Initialize the connectivity and limit the labels for which holes are filled. @@ -582,7 +580,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: class Ensemble: @staticmethod - def get_stacked_torch(img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> torch.Tensor: + def get_stacked_torch(img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> torch.Tensor: """Get either a sequence or single instance of np.ndarray/torch.Tensor. Return single torch.Tensor.""" if isinstance(img, Sequence) and isinstance(img[0], np.ndarray): img = [torch.as_tensor(i) for i in img] @@ -592,7 +590,7 @@ def get_stacked_torch(img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> return out @staticmethod - def post_convert(img: torch.Tensor, orig_img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> NdarrayOrTensor: + def post_convert(img: torch.Tensor, orig_img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayOrTensor: orig_img_ = orig_img[0] if isinstance(orig_img, Sequence) else orig_img out, *_ = convert_to_dst_type(img, orig_img_) return out @@ -623,10 +621,10 @@ class MeanEnsemble(Ensemble, Transform): backend = [TransformBackends.TORCH] - def __init__(self, weights: Optional[Union[Sequence[float], NdarrayOrTensor]] = None) -> None: + def __init__(self, weights: Sequence[float] | NdarrayOrTensor | None = None) -> None: self.weights = torch.as_tensor(weights, dtype=torch.float) if weights is not None else None - def __call__(self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> NdarrayOrTensor: + def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayOrTensor: img_ = self.get_stacked_torch(img) if self.weights is not None: self.weights = self.weights.to(img_.device) @@ -663,10 +661,10 @@ class VoteEnsemble(Ensemble, Transform): backend = [TransformBackends.TORCH] - def __init__(self, num_classes: Optional[int] = None) -> None: + def __init__(self, num_classes: int | None = None) -> None: self.num_classes = num_classes - def __call__(self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> NdarrayOrTensor: + def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayOrTensor: img_ = self.get_stacked_torch(img) if self.num_classes is not None: @@ -726,9 +724,9 @@ class ProbNMS(Transform): def __init__( self, spatial_dims: int = 2, - sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, + sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor = 0.0, prob_threshold: float = 0.5, - box_size: Union[int, Sequence[int]] = 48, + box_size: int | Sequence[int] = 48, ) -> None: self.sigma = sigma self.spatial_dims = spatial_dims @@ -787,11 +785,11 @@ class Invert(Transform): def __init__( self, - transform: Optional[InvertibleTransform] = None, - nearest_interp: Union[bool, Sequence[bool]] = True, - device: Union[str, torch.device, None] = None, - post_func: Optional[Callable] = None, - to_tensor: Union[bool, Sequence[bool]] = True, + transform: InvertibleTransform | None = None, + nearest_interp: bool | Sequence[bool] = True, + device: str | torch.device | None = None, + post_func: Callable | None = None, + to_tensor: bool | Sequence[bool] = True, ) -> None: """ Args: @@ -852,7 +850,7 @@ class SobelGradients(Transform): def __init__( self, kernel_size: int = 3, - spatial_axes: Optional[Union[Sequence[int], int]] = None, + spatial_axes: Sequence[int] | int | None = None, normalize_kernels: bool = True, normalize_gradients: bool = False, padding_mode: str = "reflect", @@ -865,7 +863,7 @@ def __init__( self.normalize_gradients = normalize_gradients self.kernel_diff, self.kernel_smooth = self._get_kernel(kernel_size, dtype) - def _get_kernel(self, size, dtype) -> Tuple[torch.Tensor, torch.Tensor]: + def _get_kernel(self, size, dtype) -> tuple[torch.Tensor, torch.Tensor]: if size < 3: raise ValueError(f"Sobel kernel size should be at least three. {size} was given.") if size % 2 == 0: diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 9c554321ba..3fbfe46118 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -15,9 +15,12 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from __future__ import annotations + import warnings +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence from copy import deepcopy -from typing import Any, Callable, Dict, Hashable, Iterable, List, Mapping, Optional, Sequence, Union +from typing import Any import numpy as np import torch @@ -104,9 +107,9 @@ class Activationsd(MapTransform): def __init__( self, keys: KeysCollection, - sigmoid: Union[Sequence[bool], bool] = False, - softmax: Union[Sequence[bool], bool] = False, - other: Optional[Union[Sequence[Callable], Callable]] = None, + sigmoid: Sequence[bool] | bool = False, + softmax: Sequence[bool] | bool = False, + other: Sequence[Callable] | Callable | None = None, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -133,7 +136,7 @@ def __init__( self.converter = Activations() self.converter.kwargs = kwargs - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, sigmoid, softmax, other in self.key_iterator(d, self.sigmoid, self.softmax, self.other): d[key] = self.converter(d[key], sigmoid, softmax, other) @@ -150,10 +153,10 @@ class AsDiscreted(MapTransform): def __init__( self, keys: KeysCollection, - argmax: Union[Sequence[bool], bool] = False, - to_onehot: Union[Sequence[Optional[int]], Optional[int]] = None, - threshold: Union[Sequence[Optional[float]], Optional[float]] = None, - rounding: Union[Sequence[Optional[str]], Optional[str]] = None, + argmax: Sequence[bool] | bool = False, + to_onehot: Sequence[int | None] | int | None = None, + threshold: Sequence[float | None] | float | None = None, + rounding: Sequence[str | None] | str | None = None, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -194,7 +197,7 @@ def __init__( self.converter = AsDiscrete() self.converter.kwargs = kwargs - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, argmax, to_onehot, threshold, rounding in self.key_iterator( d, self.argmax, self.to_onehot, self.threshold, self.rounding @@ -213,10 +216,10 @@ class KeepLargestConnectedComponentd(MapTransform): def __init__( self, keys: KeysCollection, - applied_labels: Optional[Union[Sequence[int], int]] = None, - is_onehot: Optional[bool] = None, + applied_labels: Sequence[int] | int | None = None, + is_onehot: bool | None = None, independent: bool = True, - connectivity: Optional[int] = None, + connectivity: int | None = None, num_components: int = 1, allow_missing_keys: bool = False, ) -> None: @@ -251,7 +254,7 @@ def __init__( num_components=num_components, ) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -286,7 +289,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = RemoveSmallObjects(min_size, connectivity, independent_channels) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -301,7 +304,7 @@ class LabelFilterd(MapTransform): backend = LabelFilter.backend def __init__( - self, keys: KeysCollection, applied_labels: Union[Sequence[int], int], allow_missing_keys: bool = False + self, keys: KeysCollection, applied_labels: Sequence[int] | int, allow_missing_keys: bool = False ) -> None: """ Args: @@ -314,7 +317,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = LabelFilter(applied_labels) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -331,8 +334,8 @@ class FillHolesd(MapTransform): def __init__( self, keys: KeysCollection, - applied_labels: Optional[Union[Iterable[int], int]] = None, - connectivity: Optional[int] = None, + applied_labels: Iterable[int] | int | None = None, + connectivity: int | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -351,7 +354,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = FillHoles(applied_labels=applied_labels, connectivity=connectivity) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -377,7 +380,7 @@ def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace", allow_mis super().__init__(keys, allow_missing_keys) self.converter = LabelToContour(kernel_type=kernel_type) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -395,8 +398,8 @@ class Ensembled(MapTransform): def __init__( self, keys: KeysCollection, - ensemble: Callable[[Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]], NdarrayOrTensor], - output_key: Optional[str] = None, + ensemble: Callable[[Sequence[NdarrayOrTensor] | NdarrayOrTensor], NdarrayOrTensor], + output_key: str | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -421,9 +424,9 @@ def __init__( raise ValueError("Incompatible values: len(self.keys) > 1 and output_key=None.") self.output_key = output_key if output_key is not None else self.keys[0] - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) - items: Union[List[NdarrayOrTensor], NdarrayOrTensor] + items: list[NdarrayOrTensor] | NdarrayOrTensor if len(self.keys) == 1 and self.keys[0] in d: items = d[self.keys[0]] else: @@ -445,8 +448,8 @@ class MeanEnsembled(Ensembled): def __init__( self, keys: KeysCollection, - output_key: Optional[str] = None, - weights: Optional[Union[Sequence[float], NdarrayOrTensor]] = None, + output_key: str | None = None, + weights: Sequence[float] | NdarrayOrTensor | None = None, ) -> None: """ Args: @@ -477,9 +480,7 @@ class VoteEnsembled(Ensembled): backend = VoteEnsemble.backend - def __init__( - self, keys: KeysCollection, output_key: Optional[str] = None, num_classes: Optional[int] = None - ) -> None: + def __init__(self, keys: KeysCollection, output_key: str | None = None, num_classes: int | None = None) -> None: """ Args: keys: keys of the corresponding items to be stack and execute ensemble. @@ -531,9 +532,9 @@ def __init__( self, keys: KeysCollection, spatial_dims: int = 2, - sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, + sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor = 0.0, prob_threshold: float = 0.5, - box_size: Union[int, Sequence[int]] = 48, + box_size: int | Sequence[int] = 48, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -581,14 +582,14 @@ def __init__( self, keys: KeysCollection, transform: InvertibleTransform, - orig_keys: Optional[KeysCollection] = None, - meta_keys: Optional[KeysCollection] = None, - orig_meta_keys: Optional[KeysCollection] = None, + orig_keys: KeysCollection | None = None, + meta_keys: KeysCollection | None = None, + orig_meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, - nearest_interp: Union[bool, Sequence[bool]] = True, - to_tensor: Union[bool, Sequence[bool]] = True, - device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]], None] = None, - post_func: Union[Callable, Sequence[Callable], None] = None, + nearest_interp: bool | Sequence[bool] = True, + to_tensor: bool | Sequence[bool] = True, + device: str | torch.device | Sequence[str | torch.device] | None = None, + post_func: Callable | Sequence[Callable] | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -638,7 +639,7 @@ def __init__( self.post_func = ensure_tuple_rep(post_func, len(self.keys)) self._totensor = ToTensor() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: d = dict(data) for ( key, @@ -728,9 +729,9 @@ class SaveClassificationd(MapTransform): def __init__( self, keys: KeysCollection, - meta_keys: Optional[KeysCollection] = None, + meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, - saver: Optional[CSVSaver] = None, + saver: CSVSaver | None = None, output_dir: PathLike = "./", filename: str = "predictions.csv", delimiter: str = ",", @@ -824,12 +825,12 @@ def __init__( self, keys: KeysCollection, kernel_size: int = 3, - spatial_axes: Optional[Union[Sequence[int], int]] = None, + spatial_axes: Sequence[int] | int | None = None, normalize_kernels: bool = True, normalize_gradients: bool = False, padding_mode: str = "reflect", dtype: torch.dtype = torch.float32, - new_key_prefix: Optional[str] = None, + new_key_prefix: str | None = None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -845,7 +846,7 @@ def __init__( self.kernel_diff = self.transform.kernel_diff self.kernel_smooth = self.transform.kernel_smooth - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): new_key = key if self.new_key_prefix is None else self.new_key_prefix + key diff --git a/monai/transforms/signal/array.py b/monai/transforms/signal/array.py index 7b619a4b39..59b267c151 100644 --- a/monai/transforms/signal/array.py +++ b/monai/transforms/signal/array.py @@ -13,8 +13,11 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +from __future__ import annotations + import warnings -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any import numpy as np import torch @@ -57,7 +60,7 @@ class SignalRandShift(RandomizableTransform): backend = [TransformBackends.NUMPY, TransformBackends.TORCH] def __init__( - self, mode: Optional[str] = "wrap", filling: Optional[float] = 0.0, boundaries: Sequence[float] = (-1.0, 1.0) + self, mode: str | None = "wrap", filling: float | None = 0.0, boundaries: Sequence[float] = (-1.0, 1.0) ) -> None: """ Args: @@ -390,10 +393,7 @@ class SignalRemoveFrequency(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, - frequency: Optional[float] = None, - quality_factor: Optional[float] = None, - sampling_freq: Optional[float] = None, + self, frequency: float | None = None, quality_factor: float | None = None, sampling_freq: float | None = None ) -> None: """ Args: diff --git a/monai/transforms/smooth_field/array.py b/monai/transforms/smooth_field/array.py index 13507339e1..c9df5f1dbb 100644 --- a/monai/transforms/smooth_field/array.py +++ b/monai/transforms/smooth_field/array.py @@ -10,7 +10,10 @@ # limitations under the License. """Transforms using a smooth spatial field generated by interpolating from smaller randomized fields.""" -from typing import Any, Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any import numpy as np import torch @@ -61,10 +64,10 @@ def __init__( low: float = -1.0, high: float = 1.0, channels: int = 1, - spatial_size: Optional[Sequence[int]] = None, + spatial_size: Sequence[int] | None = None, mode: str = InterpolateMode.AREA, - align_corners: Optional[bool] = None, - device: Optional[torch.device] = None, + align_corners: bool | None = None, + device: torch.device | None = None, ): self.rand_size = tuple(rand_size) self.pad = pad @@ -75,8 +78,8 @@ def __init__( self.align_corners = align_corners self.device = device - self.spatial_size: Optional[Sequence[int]] = None - self.spatial_zoom: Optional[Sequence[float]] = None + self.spatial_size: Sequence[int] | None = None + self.spatial_zoom: Sequence[float] | None = None if low >= high: raise ValueError("Value for `low` must be less than `high` otherwise field will be zeros") @@ -92,10 +95,10 @@ def __init__( self.set_spatial_size(spatial_size) - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: self.field[self.rand_slices] = torch.from_numpy(self.R.uniform(self.low, self.high, self.crand_size)) - def set_spatial_size(self, spatial_size: Optional[Sequence[int]]) -> None: + def set_spatial_size(self, spatial_size: Sequence[int] | None) -> None: """ Set the `spatial_size` and `spatial_zoom` attributes used for interpolating the field to the given dimension, or not interpolate at all if None. @@ -169,10 +172,10 @@ def __init__( rand_size: Sequence[int], pad: int = 0, mode: str = InterpolateMode.AREA, - align_corners: Optional[bool] = None, + align_corners: bool | None = None, prob: float = 0.1, - gamma: Union[Sequence[float], float] = (0.5, 4.5), - device: Optional[torch.device] = None, + gamma: Sequence[float] | float = (0.5, 4.5), + device: torch.device | None = None, ): super().__init__(prob) @@ -198,13 +201,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandSmoothFieldAdjustContrast": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandSmoothFieldAdjustContrast: super().set_random_state(seed, state) self.sfield.set_random_state(seed, state) return self - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if self._do_transform: @@ -270,10 +273,10 @@ def __init__( rand_size: Sequence[int], pad: int = 0, mode: str = InterpolateMode.AREA, - align_corners: Optional[bool] = None, + align_corners: bool | None = None, prob: float = 0.1, - gamma: Union[Sequence[float], float] = (0.1, 1.0), - device: Optional[torch.device] = None, + gamma: Sequence[float] | float = (0.1, 1.0), + device: torch.device | None = None, ): super().__init__(prob) @@ -299,13 +302,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandSmoothFieldAdjustIntensity": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandSmoothFieldAdjustIntensity: super().set_random_state(seed, state) self.sfield.set_random_state(seed, state) return self - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if self._do_transform: @@ -367,14 +370,14 @@ def __init__( rand_size: Sequence[int], pad: int = 0, field_mode: str = InterpolateMode.AREA, - align_corners: Optional[bool] = None, + align_corners: bool | None = None, prob: float = 0.1, - def_range: Union[Sequence[float], float] = 1.0, + def_range: Sequence[float] | float = 1.0, grid_dtype=torch.float32, grid_mode: str = GridSampleMode.NEAREST, grid_padding_mode: str = GridSamplePadMode.BORDER, - grid_align_corners: Optional[bool] = False, - device: Optional[torch.device] = None, + grid_align_corners: bool | None = False, + device: torch.device | None = None, ): super().__init__(prob) @@ -412,14 +415,12 @@ def __init__( self.grid = torch.stack(grid).unsqueeze(0).to(self.device, self.grid_dtype) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Randomizable": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable: super().set_random_state(seed, state) self.sfield.set_random_state(seed, state) return self - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if self._do_transform: @@ -432,7 +433,7 @@ def set_grid_mode(self, mode: str) -> None: self.grid_mode = mode def __call__( - self, img: NdarrayOrTensor, randomize: bool = True, device: Optional[torch.device] = None + self, img: NdarrayOrTensor, randomize: bool = True, device: torch.device | None = None ) -> NdarrayOrTensor: img = convert_to_tensor(img, track_meta=get_track_meta()) if randomize: diff --git a/monai/transforms/smooth_field/dictionary.py b/monai/transforms/smooth_field/dictionary.py index 08fb71edb4..99d19064f8 100644 --- a/monai/transforms/smooth_field/dictionary.py +++ b/monai/transforms/smooth_field/dictionary.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Hashable, Mapping, Optional, Sequence, Union +from __future__ import annotations + +from collections.abc import Hashable, Mapping, Sequence +from typing import Any import numpy as np import torch @@ -67,10 +70,10 @@ def __init__( rand_size: Sequence[int], pad: int = 0, mode: SequenceStr = InterpolateMode.AREA, - align_corners: Optional[bool] = None, + align_corners: bool | None = None, prob: float = 0.1, - gamma: Union[Sequence[float], float] = (0.5, 4.5), - device: Optional[torch.device] = None, + gamma: Sequence[float] | float = (0.5, 4.5), + device: torch.device | None = None, ): RandomizableTransform.__init__(self, prob) MapTransform.__init__(self, keys) @@ -89,13 +92,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandSmoothFieldAdjustContrastd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandSmoothFieldAdjustContrastd: super().set_random_state(seed, state) self.trans.set_random_state(seed, state) return self - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if self._do_transform: @@ -145,10 +148,10 @@ def __init__( rand_size: Sequence[int], pad: int = 0, mode: SequenceStr = InterpolateMode.AREA, - align_corners: Optional[bool] = None, + align_corners: bool | None = None, prob: float = 0.1, - gamma: Union[Sequence[float], float] = (0.1, 1.0), - device: Optional[torch.device] = None, + gamma: Sequence[float] | float = (0.1, 1.0), + device: torch.device | None = None, ): RandomizableTransform.__init__(self, prob) MapTransform.__init__(self, keys) @@ -167,13 +170,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandSmoothFieldAdjustIntensityd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandSmoothFieldAdjustIntensityd: super().set_random_state(seed, state) self.trans.set_random_state(seed, state) return self - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) self.trans.randomize() @@ -226,14 +229,14 @@ def __init__( rand_size: Sequence[int], pad: int = 0, field_mode: SequenceStr = InterpolateMode.AREA, - align_corners: Optional[bool] = None, + align_corners: bool | None = None, prob: float = 0.1, - def_range: Union[Sequence[float], float] = 1.0, + def_range: Sequence[float] | float = 1.0, grid_dtype=torch.float32, grid_mode: SequenceStr = GridSampleMode.NEAREST, grid_padding_mode: str = GridSamplePadMode.BORDER, - grid_align_corners: Optional[bool] = False, - device: Optional[torch.device] = None, + grid_align_corners: bool | None = False, + device: torch.device | None = None, ): RandomizableTransform.__init__(self, prob) MapTransform.__init__(self, keys) @@ -257,13 +260,13 @@ def __init__( ) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandSmoothDeformd": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandSmoothDeformd: super().set_random_state(seed, state) self.trans.set_random_state(seed, state) return self - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) self.trans.randomize() diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 455b1c62ae..69816def97 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -12,11 +12,15 @@ A collection of "vanilla" transforms for spatial operations https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ + +from __future__ import annotations + +import functools import warnings +from collections.abc import Callable from copy import deepcopy -from enum import Enum from itertools import zip_longest -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union, cast import numpy as np import torch @@ -27,11 +31,21 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull -from monai.networks.utils import meshgrid_ij, normalize_transform +from monai.networks.utils import meshgrid_ij from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop -from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import Randomizable, RandomizableTransform, Transform +from monai.transforms.spatial.functional import ( + affine_func, + flip, + orientation, + resize, + rotate, + rotate90, + spatial_resample, + zoom, +) +from monai.transforms.traits import MultiSampleTrait +from monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( convert_pad_mode, create_control_grid, @@ -43,7 +57,7 @@ map_spatial_axes, scale_affine, ) -from monai.transforms.utils_pytorch_numpy_unification import allclose, linalg_inv, moveaxis, where +from monai.transforms.utils_pytorch_numpy_unification import linalg_inv, moveaxis from monai.utils import ( GridSampleMode, GridSamplePadMode, @@ -61,10 +75,9 @@ fall_back_tuple, issequenceiterable, optional_import, - pytorch_after, ) from monai.utils.deprecate_utils import deprecated_arg -from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys +from monai.utils.enums import GridPatchSort, PatchKeys, PytorchPadMode, TraceKeys, TransformBackends from monai.utils.misc import ImageMetaKey as Key from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string @@ -107,7 +120,7 @@ RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] -class SpatialResample(InvertibleTransform): +class SpatialResample(InvertibleTransform, LazyTransform): """ Resample input image from the orientation/spacing defined by ``src_affine`` affine matrix into the ones specified by ``dst_affine`` affine matrix. @@ -120,7 +133,7 @@ class SpatialResample(InvertibleTransform): def __init__( self, - mode: Union[str, int] = GridSampleMode.BILINEAR, + mode: str | int = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike = np.float64, @@ -148,55 +161,14 @@ def __init__( self.align_corners = align_corners self.dtype = dtype - def _post_process( - self, - img: torch.Tensor, - src_affine: torch.Tensor, - dst_affine: torch.Tensor, - mode, - padding_mode, - align_corners, - original_spatial_shape, - ) -> torch.Tensor: - """ - Small fn to simplify returning data. If `MetaTensor`, update affine. Elif - tracking metadata is desired, create `MetaTensor` with affine. Else, return - image as `torch.Tensor`. Output type is always `float32`. - - Also append the transform to the stack. - """ - dtype = img.dtype - img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) - if get_track_meta(): - self.update_meta(img, dst_affine) - self.push_transform( - img, - extra_info={ - "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - "src_affine": src_affine, - }, - orig_size=original_spatial_shape, - ) - return img - - def update_meta(self, img, dst_affine): - img.affine = dst_affine - - @deprecated_arg( - name="src_affine", since="0.9", msg_suffix="img should be `MetaTensor`, so affine can be extracted directly." - ) def __call__( self, img: torch.Tensor, - src_affine: Optional[NdarrayOrTensor] = None, - dst_affine: Optional[torch.Tensor] = None, - spatial_size: Optional[Union[Sequence[int], torch.Tensor, int]] = None, - mode: Union[str, int, None] = None, - padding_mode: Optional[str] = None, - align_corners: Optional[bool] = None, + dst_affine: torch.Tensor | None = None, + spatial_size: Sequence[int] | torch.Tensor | int | None = None, + mode: str | int | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, dtype: DtypeLike = None, ) -> torch.Tensor: """ @@ -237,88 +209,12 @@ def __call__( Set `dst_affine` and `spatial_size` to `None` to turn off the resampling step. """ # get dtype as torch (e.g., torch.float64) - _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) - align_corners = self.align_corners if align_corners is None else align_corners + dtype_pt = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) + align_corners = align_corners if align_corners is not None else self.align_corners mode = mode if mode is not None else self.mode padding_mode = padding_mode if padding_mode is not None else self.padding_mode - original_spatial_shape = img.shape[1:] - - src_affine_: torch.Tensor = img.affine if isinstance(img, MetaTensor) else torch.eye(4) - img = convert_to_tensor(data=img, track_meta=get_track_meta(), dtype=_dtype) - spatial_rank = min(len(img.shape) - 1, src_affine_.shape[0] - 1, 3) - if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None: - spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size - src_affine_ = to_affine_nd(spatial_rank, src_affine_).to(_dtype) - dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine_ - dst_affine = convert_to_dst_type(dst_affine, src_affine_)[0] - if not isinstance(dst_affine, torch.Tensor): - raise ValueError(f"dst_affine should be a torch.Tensor, got {type(dst_affine)}") - - in_spatial_size = torch.tensor(img.shape[1 : spatial_rank + 1]) - if isinstance(spatial_size, int) and (spatial_size == -1): # using the input spatial size - spatial_size = in_spatial_size - elif spatial_size is None and spatial_rank > 1: # auto spatial size - spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine_, dst_affine) # type: ignore - spatial_size = torch.tensor(fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size)) - - if ( - allclose(src_affine_, dst_affine, atol=AFFINE_TOL) - and allclose(spatial_size, in_spatial_size) - or spatial_rank == 1 - ): - # no significant change, return original image - return self._post_process( - img, src_affine_, src_affine_, mode, padding_mode, align_corners, original_spatial_shape - ) - - try: - _s = convert_to_tensor(src_affine_, track_meta=False, device=torch.device("cpu")) - _d = convert_to_tensor(dst_affine, track_meta=False, device=torch.device("cpu")) - xform = ( - torch.linalg.solve(_s, _d) if pytorch_after(1, 8, 0) else torch.solve(_d, _s).solution # type: ignore - ) - except (np.linalg.LinAlgError, RuntimeError) as e: - raise ValueError("src affine is not invertible.") from e - xform = to_affine_nd(spatial_rank, xform).to(device=img.device, dtype=_dtype) - # no resampling if it's identity transform - if allclose(xform, torch.eye(len(xform)), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): - return self._post_process( - img, src_affine_, src_affine_, mode, padding_mode, align_corners, original_spatial_shape - ) - - in_spatial_size = in_spatial_size.tolist() # type: ignore - chns, additional_dims = img.shape[0], img.shape[spatial_rank + 1 :] # beyond three spatial dims - - if additional_dims: - xform_shape = [-1] + in_spatial_size - img = img.reshape(xform_shape) # type: ignore - if isinstance(mode, int): - dst_xform_1 = normalize_transform(spatial_size, xform.device, xform.dtype, True, True)[0] # to (-1, 1) - if not align_corners: - norm = create_scale(spatial_rank, [(max(d, 2) - 1) / d for d in spatial_size], xform.device, "torch") - dst_xform_1 = norm.to(xform.dtype) @ dst_xform_1 # type: ignore # scaling (num_step - 1) / num_step - dst_xform_d = normalize_transform(spatial_size, xform.device, xform.dtype, align_corners, False)[0] - xform = xform @ torch.inverse(dst_xform_d) @ dst_xform_1 - affine_xform = Affine( - affine=xform, spatial_size=spatial_size, normalized=True, image_only=True, dtype=_dtype - ) - with affine_xform.trace_transform(False): - img = affine_xform(img, mode=mode, padding_mode=padding_mode) - else: - affine_xform = AffineTransform( - normalized=False, - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - reverse_indexing=True, - ) - img = affine_xform(img.unsqueeze(0), theta=xform, spatial_size=spatial_size).squeeze(0) - if additional_dims: - full_shape = (chns, *spatial_size, *additional_dims) - img = img.reshape(full_shape) - - return self._post_process( - img, src_affine_, dst_affine, mode, padding_mode, align_corners, original_spatial_shape + return spatial_resample( + img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, self.get_transform_info() ) def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -342,43 +238,19 @@ class ResampleToMatch(SpatialResample): """Resample an image to match given metadata. The affine matrix will be aligned, and the size of the output image will match.""" - def update_meta(self, img: torch.Tensor, dst_affine=None, img_dst=None): - if dst_affine is not None: - super().update_meta(img, dst_affine) - if isinstance(img_dst, MetaTensor) and isinstance(img, MetaTensor): - original_fname = img.meta[Key.FILENAME_OR_OBJ] - img.meta = deepcopy(img_dst.meta) - img.meta[Key.FILENAME_OR_OBJ] = original_fname # keep the original name, the others are overwritten - - @deprecated_arg( - name="src_meta", since="0.9", msg_suffix="img should be `MetaTensor`, so affine can be extracted directly." - ) - @deprecated_arg( - name="dst_meta", since="0.9", msg_suffix="img_dst should be `MetaTensor`, so affine can be extracted directly." - ) - def __call__( + def __call__( # type: ignore self, img: torch.Tensor, img_dst: torch.Tensor, - src_meta: Optional[Dict] = None, - dst_meta: Optional[Dict] = None, - mode: Union[str, int, None] = None, - padding_mode: Optional[str] = None, - align_corners: Optional[bool] = None, + mode: str | int | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, dtype: DtypeLike = None, ) -> torch.Tensor: """ Args: - img: input image to be resampled to match ``dst_meta``. It currently supports channel-first arrays with + img: input image to be resampled to match ``img_dst``. It currently supports channel-first arrays with at most three spatial dimensions. - src_meta: Dictionary containing the source affine matrix in the form ``{'affine':src_affine}``. - If ``affine`` is not specified, an identity matrix is assumed. Defaults to ``None``. - See also: https://docs.monai.io/en/stable/transforms.html#spatialresample - dst_meta: Dictionary containing the target affine matrix and target spatial shape in the form - ``{'affine':src_affine, 'spatial_shape':spatial_size}``. If ``affine`` is not - specified, ``src_affine`` is assumed. If ``spatial_shape`` is not specified, spatial size is - automatically computed, containing the previous field of view. Defaults to ``None``. - See also: https://docs.monai.io/en/stable/transforms.html#spatialresample mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers). Interpolation mode to calculate output values. Defaults to ``"bilinear"``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html @@ -398,49 +270,59 @@ def __call__( ``np.float64`` (for best precision). If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. Raises: - RuntimeError: When ``src_meta`` is missing. - RuntimeError: When ``dst_meta`` is missing. ValueError: When the affine matrix of the source image is not invertible. Returns: Resampled input tensor or MetaTensor. """ if img_dst is None: raise RuntimeError("`img_dst` is missing.") - dst_affine = img_dst.affine if isinstance(img_dst, MetaTensor) else torch.eye(4) + dst_affine = img_dst.peek_pending_affine() if isinstance(img_dst, MetaTensor) else torch.eye(4) img = super().__call__( img=img, dst_affine=dst_affine, - spatial_size=img_dst.shape[1:], # skip channel + spatial_size=img_dst.peek_pending_shape() if isinstance(img_dst, MetaTensor) else img_dst.shape[1:], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, ) - self.update_meta(img, dst_affine=dst_affine, img_dst=img_dst) + if not self.lazy_evaluation: + if isinstance(img, MetaTensor): + img.affine = dst_affine + if isinstance(img_dst, MetaTensor): + original_fname = img.meta.get(Key.FILENAME_OR_OBJ, "resample_to_match_source") + img.meta = deepcopy(img_dst.meta) + img.meta[Key.FILENAME_OR_OBJ] = original_fname # keep the original name, the others are overwritten + else: + if isinstance(img, MetaTensor) and isinstance(img_dst, MetaTensor): + original_fname = img.meta.get(Key.FILENAME_OR_OBJ, "resample_to_match_source") + meta_dict = deepcopy(img_dst.meta) + for k in ("affine", "spatial_shape"): # keys that don't copy from img_dst in lazy evaluation + meta_dict.pop(k, None) + img.meta.update(meta_dict) + img.meta[Key.FILENAME_OR_OBJ] = original_fname # keep the original name, the others are overwritten return img -class Spacing(InvertibleTransform): +class Spacing(InvertibleTransform, LazyTransform): """ Resample input image into the specified `pixdim`. """ backend = SpatialResample.backend - @deprecated_arg(name="image_only", since="0.9") def __init__( self, - pixdim: Union[Sequence[float], float, np.ndarray], + pixdim: Sequence[float] | float | np.ndarray, diagonal: bool = False, - mode: Union[str, int] = GridSampleMode.BILINEAR, + mode: str | int = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike = np.float64, scale_extent: bool = False, recompute_affine: bool = False, - min_pixdim: Union[Sequence[float], float, np.ndarray, None] = None, - max_pixdim: Union[Sequence[float], float, np.ndarray, None] = None, - image_only: bool = False, + min_pixdim: Sequence[float] | float | np.ndarray | None = None, + max_pixdim: Sequence[float] | float | np.ndarray | None = None, ) -> None: """ Args: @@ -510,17 +392,22 @@ def __init__( mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype ) + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self._lazy_evaluation = val + self.sp_resample.lazy_evaluation = val + @deprecated_arg(name="affine", since="0.9", msg_suffix="Not needed, input should be `MetaTensor`.") def __call__( self, data_array: torch.Tensor, - affine: Optional[NdarrayOrTensor] = None, - mode: Union[str, int, None] = None, - padding_mode: Optional[str] = None, - align_corners: Optional[bool] = None, + affine: NdarrayOrTensor | None = None, + mode: str | int | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, dtype: DtypeLike = None, - scale_extent: Optional[bool] = None, - output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None, + scale_extent: bool | None = None, + output_spatial_shape: Sequence[int] | np.ndarray | int | None = None, ) -> torch.Tensor: """ Args: @@ -559,21 +446,23 @@ def __call__( data tensor or MetaTensor (resampled into `self.pixdim`). """ - original_spatial_shape = data_array.shape[1:] + original_spatial_shape = ( + data_array.peek_pending_shape() if isinstance(data_array, MetaTensor) else data_array.shape[1:] + ) sr = len(original_spatial_shape) if sr <= 0: - raise ValueError("data_array must have at least one spatial dimension.") + raise ValueError(f"data_array must have at least one spatial dimension, got {original_spatial_shape}.") affine_: np.ndarray if affine is not None: warnings.warn("arg `affine` is deprecated, the affine of MetaTensor in data_array has higher priority.") - input_affine = data_array.affine if isinstance(data_array, MetaTensor) else affine + input_affine = data_array.peek_pending_affine() if isinstance(data_array, MetaTensor) else affine if input_affine is None: warnings.warn("`data_array` is not of type MetaTensor, assuming affine to be identity.") # default to identity input_affine = np.eye(sr + 1, dtype=np.float64) affine_ = to_affine_nd(sr, convert_data_type(input_affine, np.ndarray)[0]) - out_d = self.pixdim[:sr] + out_d = self.pixdim[:sr].copy() if out_d.size < sr: out_d = np.append(out_d, [out_d[-1]] * (sr - out_d.size)) @@ -594,46 +483,42 @@ def __call__( # compute output affine, shape and offset new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal) scale_extent = self.scale_extent if scale_extent is None else scale_extent - output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine, scale_extent) + output_shape, offset = compute_shape_offset(original_spatial_shape, affine_, new_affine, scale_extent) new_affine[:sr, -1] = offset[:sr] - # convert to MetaTensor if necessary - data_array = convert_to_tensor(data_array, track_meta=get_track_meta()) - if isinstance(data_array, MetaTensor): - data_array.affine = torch.as_tensor(affine_) - # we don't want to track the nested transform otherwise two will be appended actual_shape = list(output_shape) if output_spatial_shape is None else output_spatial_shape data_array = self.sp_resample( data_array, dst_affine=torch.as_tensor(new_affine), - spatial_size=actual_shape, + spatial_size=actual_shape, # type: ignore mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, ) if self.recompute_affine and isinstance(data_array, MetaTensor): - data_array.affine = scale_affine(affine_, original_spatial_shape, actual_shape) + if self.lazy_evaluation: + raise NotImplementedError("recompute_affine is not supported with lazy evaluation.") + a = scale_affine(original_spatial_shape, actual_shape) + data_array.affine = convert_to_dst_type(a, affine_)[0] # type: ignore return data_array def inverse(self, data: torch.Tensor) -> torch.Tensor: return self.sp_resample.inverse(data) -class Orientation(InvertibleTransform): +class Orientation(InvertibleTransform, LazyTransform): """ Change the input image's orientation into the specified based on `axcodes`. """ backend = [TransformBackends.NUMPY, TransformBackends.TORCH] - @deprecated_arg(name="image_only", since="0.9") def __init__( self, - axcodes: Optional[str] = None, + axcodes: str | None = None, as_closest_canonical: bool = False, - labels: Optional[Sequence[Tuple[str, str]]] = (("L", "R"), ("P", "A"), ("I", "S")), - image_only: bool = False, + labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")), ) -> None: """ Args: @@ -679,14 +564,14 @@ def __call__(self, data_array: torch.Tensor) -> torch.Tensor: `torch.Tensor`. """ - spatial_shape = data_array.shape[1:] + spatial_shape = data_array.peek_pending_shape() if isinstance(data_array, MetaTensor) else data_array.shape[1:] sr = len(spatial_shape) if sr <= 0: - raise ValueError("data_array must have at least one spatial dimension.") + raise ValueError(f"data_array must have at least one spatial dimension, got {spatial_shape}.") affine_: np.ndarray affine_np: np.ndarray if isinstance(data_array, MetaTensor): - affine_np, *_ = convert_data_type(data_array.affine, np.ndarray) + affine_np, *_ = convert_data_type(data_array.peek_pending_affine(), np.ndarray) affine_ = to_affine_nd(sr, affine_np) else: warnings.warn("`data_array` is not of type `MetaTensor, assuming affine to be identity.") @@ -702,8 +587,8 @@ def __call__(self, data_array: torch.Tensor) -> torch.Tensor: raise ValueError("Incompatible values: axcodes=None and as_closest_canonical=True.") if sr < len(self.axcodes): warnings.warn( - f"axcodes ('{self.axcodes}') length is smaller than the number of input spatial dimensions D={sr}.\n" - f"{self.__class__.__name__}: input spatial shape is {spatial_shape}, num. channels is {data_array.shape[0]}," + f"axcodes ('{self.axcodes}') length is smaller than number of input spatial dimensions D={sr}.\n" + f"{self.__class__.__name__}: spatial shape = {spatial_shape}, channels = {data_array.shape[0]}," "please make sure the input is in the channel-first format." ) dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels) @@ -712,31 +597,7 @@ def __call__(self, data_array: torch.Tensor) -> torch.Tensor: f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D" ) spatial_ornt = nib.orientations.ornt_transform(src, dst) - new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, spatial_shape) - - # convert to MetaTensor if necessary - data_array = convert_to_tensor(data_array, track_meta=get_track_meta()) - - spatial_ornt[:, 0] += 1 # skip channel dim - spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) - axes = [ax for ax, flip in enumerate(spatial_ornt[:, 1]) if flip == -1] - if axes: - data_array = torch.flip(data_array, dims=axes) - full_transpose = np.arange(len(data_array.shape)) - full_transpose[: len(spatial_ornt)] = np.argsort(spatial_ornt[:, 0]) - if not np.all(full_transpose == np.arange(len(data_array.shape))): - data_array = data_array.permute(full_transpose.tolist()) - - new_affine = to_affine_nd(affine_np, new_affine) - new_affine, *_ = convert_data_type(new_affine, torch.Tensor, dtype=torch.float32, device=data_array.device) - - if get_track_meta(): - self.update_meta(data_array, new_affine) - self.push_transform(data_array, extra_info={"original_affine": affine_np}) - return data_array - - def update_meta(self, img, new_affine): - img.affine = new_affine + return orientation(data_array, affine_np, spatial_ornt, self.get_transform_info()) # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -751,7 +612,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return data -class Flip(InvertibleTransform): +class Flip(InvertibleTransform, LazyTransform): """ Reverses the order of elements along the given spatial axis. Preserves shape. See `torch.flip` documentation for additional details: @@ -768,33 +629,16 @@ class Flip(InvertibleTransform): backend = [TransformBackends.TORCH] - def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: + def __init__(self, spatial_axis: Sequence[int] | int | None = None) -> None: self.spatial_axis = spatial_axis - def update_meta(self, img, shape, axes): - # shape and axes include the channel dim - affine = img.affine - mat = convert_to_dst_type(torch.eye(len(affine)), affine)[0] - for axis in axes: - sp = axis - 1 - mat[sp, sp], mat[sp, -1] = mat[sp, sp] * -1, shape[axis] - 1 - img.affine = affine @ mat - - def forward_image(self, img, axes) -> torch.Tensor: - return torch.flip(img, axes) - def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]) """ img = convert_to_tensor(img, track_meta=get_track_meta()) - axes = map_spatial_axes(img.ndim, self.spatial_axis) - out = self.forward_image(img, axes) - if get_track_meta(): - self.update_meta(out, out.shape, axes) - self.push_transform(out) - return out + return flip(img, self.spatial_axis, transform_info=self.get_transform_info()) # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: self.pop_transform(data) @@ -803,7 +647,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return flipper(data) -class Resize(InvertibleTransform): +class Resize(InvertibleTransform, LazyTransform): """ Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. @@ -833,18 +677,21 @@ class Resize(InvertibleTransform): By default, this value is chosen as (s - 1) / 2 where s is the downsampling factor, where s > 1. For the up-size case, s < 1, no anti-aliasing is performed prior to rescaling. + dtype: data type for resampling computation. Defaults to ``float32``. + If None, use the data type of input data. """ backend = [TransformBackends.TORCH] def __init__( self, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, size_mode: str = "all", mode: str = InterpolateMode.AREA, - align_corners: Optional[bool] = None, + align_corners: bool | None = None, anti_aliasing: bool = False, - anti_aliasing_sigma: Union[Sequence[float], float, None] = None, + anti_aliasing_sigma: Sequence[float] | float | None = None, + dtype: DtypeLike | torch.dtype = torch.float32, ) -> None: self.size_mode = look_up_option(size_mode, ["all", "longest"]) self.spatial_size = spatial_size @@ -852,14 +699,16 @@ def __init__( self.align_corners = align_corners self.anti_aliasing = anti_aliasing self.anti_aliasing_sigma = anti_aliasing_sigma + self.dtype = dtype def __call__( self, img: torch.Tensor, - mode: Optional[str] = None, - align_corners: Optional[bool] = None, - anti_aliasing: Optional[bool] = None, - anti_aliasing_sigma: Union[Sequence[float], float, None] = None, + mode: str | None = None, + align_corners: bool | None = None, + anti_aliasing: bool | None = None, + anti_aliasing_sigma: Sequence[float] | float | None = None, + dtype: DtypeLike | torch.dtype = None, ) -> torch.Tensor: """ Args: @@ -880,6 +729,8 @@ def __call__( By default, this value is chosen as (s - 1) / 2 where s is the downsampling factor, where s > 1. For the up-size case, s < 1, no anti-aliasing is performed prior to rescaling. + dtype: data type for resampling computation. Defaults to ``self.dtype``. + If None, use the data type of input data. Raises: ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. @@ -899,60 +750,29 @@ def __call__( "len(spatial_size) must be greater or equal to img spatial dimensions, " f"got spatial_size={output_ndim} img={input_ndim}." ) - spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) + _sp = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + sp_size = fall_back_tuple(self.spatial_size, _sp) else: # for the "longest" mode - img_size = img.shape[1:] + img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] if not isinstance(self.spatial_size, int): raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) - spatial_size_ = tuple(int(round(s * scale)) for s in img_size) + sp_size = tuple(int(round(s * scale)) for s in img_size) - original_sp_size = img.shape[1:] _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode) _align_corners = self.align_corners if align_corners is None else align_corners - if tuple(img.shape[1:]) == spatial_size_: # spatial shape is already the desired - img = convert_to_tensor(img, track_meta=get_track_meta()) - - return self._post_process(img, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim) - img_ = convert_to_tensor(img, dtype=torch.float, track_meta=False) - - if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])): - factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(spatial_size_)) - if anti_aliasing_sigma is None: - # if sigma is not given, use the default sigma in skimage.transform.resize - anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist() - else: - # if sigma is given, use the given value for downsampling axis - anti_aliasing_sigma = list(ensure_tuple_rep(anti_aliasing_sigma, len(spatial_size_))) - for axis in range(len(spatial_size_)): - anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1) - anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma) - img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False) - - img = convert_to_tensor(img, track_meta=get_track_meta()) - resized = torch.nn.functional.interpolate( - input=img_.unsqueeze(0), size=spatial_size_, mode=_mode, align_corners=_align_corners + _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) + return resize( # type: ignore + img, + sp_size, + _mode, + _align_corners, + _dtype, + input_ndim, + anti_aliasing, + anti_aliasing_sigma, + self.get_transform_info(), ) - out, *_ = convert_to_dst_type(resized.squeeze(0), img) - return self._post_process(out, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim) - - def _post_process(self, img: torch.Tensor, orig_size, sp_size, mode, align_corners, ndim) -> torch.Tensor: - if get_track_meta(): - self.update_meta(img, orig_size, sp_size) - self.push_transform( - img, - orig_size=orig_size, - extra_info={ - "mode": mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - "new_dim": len(orig_size) - ndim, # additional dims appended - }, - ) - return img - - def update_meta(self, img, spatial_size, new_spatial_size): - affine = convert_to_tensor(img.affine, track_meta=False) - img.affine = scale_affine(affine, spatial_size, new_spatial_size) def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -962,8 +782,12 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: orig_size = transform[TraceKeys.ORIG_SIZE] mode = transform[TraceKeys.EXTRA_INFO]["mode"] align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] xform = Resize( - spatial_size=orig_size, mode=mode, align_corners=None if align_corners == TraceKeys.NONE else align_corners + spatial_size=orig_size, + mode=mode, + align_corners=None if align_corners == TraceKeys.NONE else align_corners, + dtype=dtype, ) with xform.trace_transform(False): data = xform(data) @@ -972,7 +796,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: return data -class Rotate(InvertibleTransform): +class Rotate(InvertibleTransform, LazyTransform): """ Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`. @@ -998,12 +822,12 @@ class Rotate(InvertibleTransform): def __init__( self, - angle: Union[Sequence[float], float], + angle: Sequence[float] | float, keep_size: bool = True, mode: str = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: Union[DtypeLike, torch.dtype] = torch.float32, + dtype: DtypeLike | torch.dtype = torch.float32, ) -> None: self.angle = angle self.keep_size = keep_size @@ -1015,10 +839,10 @@ def __init__( def __call__( self, img: torch.Tensor, - mode: Optional[str] = None, - padding_mode: Optional[str] = None, - align_corners: Optional[bool] = None, - dtype: Union[DtypeLike, torch.dtype] = None, + mode: str | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, + dtype: DtypeLike | torch.dtype = None, ) -> torch.Tensor: """ Args: @@ -1043,58 +867,14 @@ def __call__( """ img = convert_to_tensor(img, track_meta=get_track_meta()) _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) - - im_shape = np.asarray(img.shape[1:]) # spatial dimensions - input_ndim = len(im_shape) - if input_ndim not in (2, 3): - raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") - _angle = ensure_tuple_rep(self.angle, 1 if input_ndim == 2 else 3) - transform = create_rotate(input_ndim, _angle) - shift = create_translate(input_ndim, ((im_shape - 1) / 2).tolist()) - if self.keep_size: - output_shape = im_shape - else: - corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape( - (len(im_shape), -1) - ) - corners = transform[:-1, :-1] @ corners # type: ignore - output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) - shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist()) - transform = shift @ transform @ shift_1 - - img_t = img.to(_dtype) - transform_t, *_ = convert_to_dst_type(transform, img_t) _mode = look_up_option(mode or self.mode, GridSampleMode) _padding_mode = look_up_option(padding_mode or self.padding_mode, GridSamplePadMode) _align_corners = self.align_corners if align_corners is None else align_corners - xform = AffineTransform( - normalized=False, - mode=_mode, - padding_mode=_padding_mode, - align_corners=_align_corners, - reverse_indexing=True, + im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + output_shape = im_shape if self.keep_size else None + return rotate( # type: ignore + img, self.angle, output_shape, _mode, _padding_mode, _align_corners, _dtype, self.get_transform_info() ) - output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).float().squeeze(0) - out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype) - if get_track_meta(): - self.update_meta(out, transform_t) - self.push_transform( - out, - orig_size=img_t.shape[1:], - extra_info={ - "rot_mat": transform, - "mode": _mode, - "padding_mode": _padding_mode, - "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE, - "dtype": str(_dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 - }, - ) - return out - - def update_meta(self, img, rotate_mat): - affine = convert_to_tensor(img.affine, track_meta=False) - mat = to_affine_nd(len(affine) - 1, rotate_mat) - img.affine = affine @ convert_to_dst_type(mat, affine)[0] def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -1106,7 +886,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] - inv_rot_mat = linalg_inv(fwd_rot_mat) + inv_rot_mat = linalg_inv(convert_to_numpy(fwd_rot_mat)) xform = AffineTransform( normalized=False, @@ -1120,12 +900,14 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: sp_size = transform[TraceKeys.ORIG_SIZE] out: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=sp_size).float().squeeze(0) out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0] - if isinstance(data, MetaTensor): - self.update_meta(out, transform_t) + if isinstance(out, MetaTensor): + affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False) + mat = to_affine_nd(len(affine) - 1, transform_t) + out.affine @= convert_to_dst_type(mat, affine)[0] return out -class Zoom(InvertibleTransform): +class Zoom(InvertibleTransform, LazyTransform): """ Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. For details, please see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html. @@ -1150,6 +932,8 @@ class Zoom(InvertibleTransform): align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + dtype: data type for resampling computation. Defaults to ``float32``. + If None, use the data type of input data. keep_size: Should keep original size (padding/slicing if needed), default is True. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -1160,10 +944,11 @@ class Zoom(InvertibleTransform): def __init__( self, - zoom: Union[Sequence[float], float], + zoom: Sequence[float] | float, mode: str = InterpolateMode.AREA, padding_mode: str = NumpyPadMode.EDGE, - align_corners: Optional[bool] = None, + align_corners: bool | None = None, + dtype: DtypeLike | torch.dtype = torch.float32, keep_size: bool = True, **kwargs, ) -> None: @@ -1171,15 +956,17 @@ def __init__( self.mode: InterpolateMode = InterpolateMode(mode) self.padding_mode = padding_mode self.align_corners = align_corners + self.dtype = dtype self.keep_size = keep_size self.kwargs = kwargs def __call__( self, img: torch.Tensor, - mode: Optional[str] = None, - padding_mode: Optional[str] = None, - align_corners: Optional[bool] = None, + mode: str | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, + dtype: DtypeLike | torch.dtype = None, ) -> torch.Tensor: """ Args: @@ -1198,50 +985,19 @@ def __call__( align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + dtype: data type for resampling computation. Defaults to ``self.dtype``. + If None, use the data type of input data. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = img.to(torch.float32) - _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode).value - _align_corners = self.align_corners if align_corners is None else align_corners _padding_mode = padding_mode or self.padding_mode - - zoomed: NdarrayOrTensor = torch.nn.functional.interpolate( - recompute_scale_factor=True, - input=img_t.unsqueeze(0), - scale_factor=list(_zoom), - mode=_mode, - align_corners=_align_corners, + _align_corners = self.align_corners if align_corners is None else align_corners + _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) + return zoom( # type: ignore + img, _zoom, self.keep_size, _mode, _padding_mode, _align_corners, _dtype, self.get_transform_info() ) - zoomed = zoomed.squeeze(0) - orig_size, z_size = img_t.shape, zoomed.shape - - out, *_ = convert_to_dst_type(zoomed, dst=img) - if get_track_meta(): - self.update_meta(out, orig_size[1:], z_size[1:]) - do_pad_crop = self.keep_size and not np.allclose(orig_size, z_size) - if do_pad_crop: - _pad_crop = ResizeWithPadOrCrop(spatial_size=img_t.shape[1:], mode=_padding_mode) - out = _pad_crop(out) - if get_track_meta(): - padcrop_xform = self.pop_transform(out, check=False) if do_pad_crop else {} - self.push_transform( - out, - orig_size=orig_size[1:], - extra_info={ - "mode": _mode, - "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE, - "do_padcrop": do_pad_crop, - "padcrop": padcrop_xform, - }, - ) - return out - - def update_meta(self, img, spatial_size, new_spatial_size): - affine = convert_to_tensor(img.affine, track_meta=False) - img.affine = scale_affine(affine, spatial_size, new_spatial_size) def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -1259,16 +1015,17 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: # Create inverse transform mode = transform[TraceKeys.EXTRA_INFO]["mode"] align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] inverse_transform = Resize(spatial_size=transform[TraceKeys.ORIG_SIZE]) # Apply inverse with inverse_transform.trace_transform(False): out = inverse_transform( - data, mode=mode, align_corners=None if align_corners == TraceKeys.NONE else align_corners + data, mode=mode, align_corners=None if align_corners == TraceKeys.NONE else align_corners, dtype=dtype ) return out -class Rotate90(InvertibleTransform): +class Rotate90(InvertibleTransform, LazyTransform): """ Rotate an array by 90 degrees in the plane specified by `axes`. See `torch.rot90` for additional details: @@ -1278,7 +1035,7 @@ class Rotate90(InvertibleTransform): backend = [TransformBackends.TORCH] - def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: + def __init__(self, k: int = 1, spatial_axes: tuple[int, int] = (0, 1)) -> None: """ Args: k: number of times to rotate by 90 degrees. @@ -1286,10 +1043,10 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: Default: (0, 1), this is the first two axis in spatial dimensions. If axis is negative it counts from the last to the first axis. """ - self.k = k - spatial_axes_: Tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore + self.k = (4 + (k % 4)) % 4 # 0, 1, 2, 3 + spatial_axes_: tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: - raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") + raise ValueError(f"spatial_axes must be 2 numbers to define the plane to rotate, got {spatial_axes_}.") self.spatial_axes = spatial_axes_ def __call__(self, img: torch.Tensor) -> torch.Tensor: @@ -1299,30 +1056,7 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) axes = map_spatial_axes(img.ndim, self.spatial_axes) - ori_shape = img.shape[1:] - out: NdarrayOrTensor = torch.rot90(img, self.k, axes) - out = convert_to_dst_type(out, img)[0] - if get_track_meta(): - self.update_meta(out, ori_shape, out.shape[1:], axes, self.k) - self.push_transform(out, extra_info={"axes": [d - 1 for d in axes], "k": self.k}) # compensate spatial dim - return out - - def update_meta(self, img, spatial_size, new_spatial_size, axes, k): - affine = convert_data_type(img.affine, torch.Tensor)[0] - r, sp_r = len(affine) - 1, len(spatial_size) - mat = to_affine_nd(r, create_translate(sp_r, [-float(d - 1) / 2 for d in new_spatial_size])) - s = -1.0 if int(axes[0]) - int(axes[1]) in (-1, 2) else 1.0 - if sp_r == 2: - rot90 = to_affine_nd(r, create_rotate(sp_r, [s * np.pi / 2])) - else: - idx = {1, 2, 3} - set(axes) - angle: List[float] = [0, 0, 0] - angle[idx.pop() - 1] = s * np.pi / 2 - rot90 = to_affine_nd(r, create_rotate(sp_r, angle)) - for _ in range(k): - mat = rot90 @ mat - mat = to_affine_nd(r, create_translate(sp_r, [float(d - 1) / 2 for d in spatial_size])) @ mat - img.affine = affine @ convert_to_dst_type(mat, affine)[0] + return rotate90(img, axes, self.k, self.get_transform_info()) # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -1337,7 +1071,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: return xform(data) -class RandRotate90(RandomizableTransform, InvertibleTransform): +class RandRotate90(RandomizableTransform, InvertibleTransform, LazyTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. @@ -1345,7 +1079,7 @@ class RandRotate90(RandomizableTransform, InvertibleTransform): backend = Rotate90.backend - def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, int] = (0, 1)) -> None: + def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: tuple[int, int] = (0, 1)) -> None: """ Args: prob: probability of rotating. @@ -1360,7 +1094,7 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i self._rand_k = 0 - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None @@ -1376,13 +1110,13 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: self.randomize() if self._do_transform: - out = Rotate90(self._rand_k, self.spatial_axes)(img) + xform = Rotate90(self._rand_k, self.spatial_axes) + xform.lazy_evaluation = self.lazy_evaluation + out = xform(img) else: out = convert_to_tensor(img, track_meta=get_track_meta()) - if get_track_meta(): - maybe_rot90_info = self.pop_transform(out, check=False) if self._do_transform else {} - self.push_transform(out, extra_info=maybe_rot90_info) + self.push_transform(out, replace=True) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1393,7 +1127,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return Rotate90().inverse_transform(data, rotate_xform) -class RandRotate(RandomizableTransform, InvertibleTransform): +class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly rotate the input arrays. @@ -1425,15 +1159,15 @@ class RandRotate(RandomizableTransform, InvertibleTransform): def __init__( self, - range_x: Union[Tuple[float, float], float] = 0.0, - range_y: Union[Tuple[float, float], float] = 0.0, - range_z: Union[Tuple[float, float], float] = 0.0, + range_x: tuple[float, float] | float = 0.0, + range_y: tuple[float, float] | float = 0.0, + range_z: tuple[float, float] | float = 0.0, prob: float = 0.1, keep_size: bool = True, mode: str = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: Union[DtypeLike, torch.dtype] = np.float32, + dtype: DtypeLike | torch.dtype = np.float32, ) -> None: RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) @@ -1456,7 +1190,7 @@ def __init__( self.y = 0.0 self.z = 0.0 - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None @@ -1464,16 +1198,14 @@ def randomize(self, data: Optional[Any] = None) -> None: self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) - @deprecated_arg(name="get_matrix", since="0.9", msg_suffix="please use `img.meta` instead.") def __call__( self, img: torch.Tensor, - mode: Optional[str] = None, - padding_mode: Optional[str] = None, - align_corners: Optional[bool] = None, - dtype: Union[DtypeLike, torch.dtype] = None, + mode: str | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, + dtype: DtypeLike | torch.dtype = None, randomize: bool = True, - get_matrix: bool = False, ): """ Args: @@ -1495,20 +1227,20 @@ def __call__( self.randomize() if self._do_transform: + ndim = len(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]) rotator = Rotate( - angle=self.x if img.ndim == 3 else (self.x, self.y, self.z), + angle=self.x if ndim == 2 else (self.x, self.y, self.z), keep_size=self.keep_size, mode=look_up_option(mode or self.mode, GridSampleMode), padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) + rotator.lazy_evaluation = self.lazy_evaluation out = rotator(img) else: out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) - if get_track_meta(): - rot_info = self.pop_transform(out, check=False) if self._do_transform else {} - self.push_transform(out, extra_info=rot_info) + self.push_transform(out, replace=True) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1518,7 +1250,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return Rotate(0).inverse_transform(data, xform_info[TraceKeys.EXTRA_INFO]) -class RandFlip(RandomizableTransform, InvertibleTransform): +class RandFlip(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly flips the image along axes. Preserves shape. See numpy.flip for additional details. @@ -1531,10 +1263,15 @@ class RandFlip(RandomizableTransform, InvertibleTransform): backend = Flip.backend - def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: + def __init__(self, prob: float = 0.1, spatial_axis: Sequence[int] | int | None = None) -> None: RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool): + self.flipper.lazy_evaluation = val + self._lazy_evaluation = val + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: """ Args: @@ -1545,9 +1282,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: self.randomize(None) out = self.flipper(img) if self._do_transform else img out = convert_to_tensor(out, track_meta=get_track_meta()) - if get_track_meta(): - xform_info = self.pop_transform(out, check=False) if self._do_transform else {} - self.push_transform(out, extra_info=xform_info) + self.push_transform(out, replace=True) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1558,7 +1293,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return self.flipper.inverse(data) -class RandAxisFlip(RandomizableTransform, InvertibleTransform): +class RandAxisFlip(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly select a spatial axis and flip along it. See numpy.flip for additional details. @@ -1573,9 +1308,14 @@ class RandAxisFlip(RandomizableTransform, InvertibleTransform): def __init__(self, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) - self._axis: Optional[int] = None + self._axis: int | None = None self.flipper = Flip(spatial_axis=self._axis) + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool): + self.flipper.lazy_evaluation = val + self._lazy_evaluation = val + def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) if not self._do_transform: @@ -1596,22 +1336,19 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: out = self.flipper(img) else: out = convert_to_tensor(img, track_meta=get_track_meta()) - if get_track_meta(): - xform = self.pop_transform(out, check=False) if self._do_transform else {} - xform["axes"] = self._axis - self.push_transform(out, extra_info=xform) + self.push_transform(out, replace=True) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) if not transform[TraceKeys.DO_TRANSFORM]: return data - flipper = Flip(spatial_axis=transform[TraceKeys.EXTRA_INFO]["axes"]) + flipper = Flip(spatial_axis=transform[TraceKeys.EXTRA_INFO][TraceKeys.EXTRA_INFO]["axes"]) with flipper.trace_transform(False): return flipper(data) -class RandZoom(RandomizableTransform, InvertibleTransform): +class RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly zooms input arrays with given probability within given zoom range. @@ -1640,6 +1377,8 @@ class RandZoom(RandomizableTransform, InvertibleTransform): align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + dtype: data type for resampling computation. Defaults to ``float32``. + If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -1651,11 +1390,12 @@ class RandZoom(RandomizableTransform, InvertibleTransform): def __init__( self, prob: float = 0.1, - min_zoom: Union[Sequence[float], float] = 0.9, - max_zoom: Union[Sequence[float], float] = 1.1, + min_zoom: Sequence[float] | float = 0.9, + max_zoom: Sequence[float] | float = 1.1, mode: str = InterpolateMode.AREA, padding_mode: str = NumpyPadMode.EDGE, - align_corners: Optional[bool] = None, + align_corners: bool | None = None, + dtype: DtypeLike | torch.dtype = torch.float32, keep_size: bool = True, **kwargs, ) -> None: @@ -1663,10 +1403,13 @@ def __init__( self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): - raise AssertionError("min_zoom and max_zoom must have same length.") + raise ValueError( + f"min_zoom and max_zoom must have same length, got {len(self.min_zoom)} and {len(self.max_zoom)}." + ) self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.padding_mode = padding_mode self.align_corners = align_corners + self.dtype = dtype self.keep_size = keep_size self.kwargs = kwargs @@ -1687,9 +1430,10 @@ def randomize(self, img: NdarrayOrTensor) -> None: def __call__( self, img: torch.Tensor, - mode: Optional[str] = None, - padding_mode: Optional[str] = None, - align_corners: Optional[bool] = None, + mode: str | None = None, + padding_mode: str | None = None, + align_corners: bool | None = None, + dtype: DtypeLike | torch.dtype = None, randomize: bool = True, ) -> torch.Tensor: """ @@ -1708,6 +1452,8 @@ def __call__( align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + dtype: data type for resampling computation. Defaults to ``self.dtype``. + If None, use the data type of input data. randomize: whether to execute `randomize()` function first, default to True. """ @@ -1718,17 +1464,18 @@ def __call__( if not self._do_transform: out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) else: - out = Zoom( + xform = Zoom( self._zoom, keep_size=self.keep_size, mode=look_up_option(mode or self.mode, InterpolateMode), padding_mode=padding_mode or self.padding_mode, align_corners=self.align_corners if align_corners is None else align_corners, + dtype=dtype or self.dtype, **self.kwargs, - )(img) - if get_track_meta(): - z_info = self.pop_transform(out, check=False) if self._do_transform else {} - self.push_transform(out, extra_info=z_info) + ) + xform.lazy_evaluation = self.lazy_evaluation + out = xform(img) + self.push_transform(out, replace=True) return out # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1738,7 +1485,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return Zoom(self._zoom).inverse_transform(data, xform_info[TraceKeys.EXTRA_INFO]) -class AffineGrid(Transform): +class AffineGrid(LazyTransform): """ Affine transforms on the coordinates. @@ -1762,6 +1509,8 @@ class AffineGrid(Transform): dtype: data type for the grid computation. Defaults to ``float32``. If ``None``, use the data type of input data (if `grid` is provided). device: device on which the tensor will be allocated, if a new grid is generated. + align_corners: Defaults to False. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. @@ -1772,25 +1521,28 @@ class AffineGrid(Transform): def __init__( self, - rotate_params: Optional[Union[Sequence[float], float]] = None, - shear_params: Optional[Union[Sequence[float], float]] = None, - translate_params: Optional[Union[Sequence[float], float]] = None, - scale_params: Optional[Union[Sequence[float], float]] = None, - device: Optional[torch.device] = None, + rotate_params: Sequence[float] | float | None = None, + shear_params: Sequence[float] | float | None = None, + translate_params: Sequence[float] | float | None = None, + scale_params: Sequence[float] | float | None = None, + device: torch.device | None = None, dtype: DtypeLike = np.float32, - affine: Optional[NdarrayOrTensor] = None, + align_corners: bool = False, + affine: NdarrayOrTensor | None = None, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params self.translate_params = translate_params self.scale_params = scale_params self.device = device - self.dtype = dtype + _dtype = get_equivalent_dtype(dtype, torch.Tensor) + self.dtype = _dtype if _dtype in (torch.float16, torch.float64, None) else torch.float32 + self.align_corners = align_corners self.affine = affine def __call__( - self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, spatial_size: Sequence[int] | None = None, grid: torch.Tensor | None = None + ) -> tuple[torch.Tensor | None, torch.Tensor]: """ The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`. Therefore, either `spatial_size` or `grid` must be provided. @@ -1804,38 +1556,49 @@ def __call__( ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. """ - if grid is None: # create grid from spatial_size - if spatial_size is None: - raise ValueError("Incompatible values: grid=None and spatial_size=None.") - grid_ = create_grid(spatial_size, device=self.device, backend="torch", dtype=self.dtype) + if not self.lazy_evaluation: + if grid is None: # create grid from spatial_size + if spatial_size is None: + raise ValueError("Incompatible values: grid=None and spatial_size=None.") + grid_ = create_grid(spatial_size, device=self.device, backend="torch", dtype=self.dtype) + else: + grid_ = grid + _dtype = self.dtype or grid_.dtype + grid_: torch.Tensor = convert_to_tensor(grid_, dtype=_dtype, track_meta=get_track_meta()) # type: ignore + _device = grid_.device # type: ignore + spatial_dims = len(grid_.shape) - 1 else: - grid_ = grid - _dtype = self.dtype or grid_.dtype - grid_: torch.Tensor = convert_to_tensor(grid_, dtype=_dtype, track_meta=get_track_meta()) # type: ignore + _device = self.device + spatial_dims = len(spatial_size) # type: ignore _b = TransformBackends.TORCH - _device = grid_.device # type: ignore - affine: NdarrayOrTensor + affine: torch.Tensor if self.affine is None: - spatial_dims = len(grid_.shape) - 1 affine = torch.eye(spatial_dims + 1, device=_device) if self.rotate_params: - affine = affine @ create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b) + affine @= create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b) if self.shear_params: - affine = affine @ create_shear(spatial_dims, self.shear_params, device=_device, backend=_b) + affine @= create_shear(spatial_dims, self.shear_params, device=_device, backend=_b) if self.translate_params: - affine = affine @ create_translate(spatial_dims, self.translate_params, device=_device, backend=_b) + affine @= create_translate(spatial_dims, self.translate_params, device=_device, backend=_b) if self.scale_params: - affine = affine @ create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) + affine @= create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) else: - affine = self.affine + affine = self.affine # type: ignore + affine = to_affine_nd(spatial_dims, affine) + if self.lazy_evaluation: + return None, affine - affine = to_affine_nd(len(grid_) - 1, affine) affine = convert_to_tensor(affine, device=grid_.device, dtype=grid_.dtype, track_meta=False) # type: ignore - grid_ = (affine @ grid_.reshape((grid_.shape[0], -1))).reshape([-1] + list(grid_.shape[1:])) - return grid_, affine # type: ignore + if self.align_corners: + sc = create_scale(spatial_dims, [d / (d - 1) for d in grid_.shape[1:]], device=_device, backend=_b) + sc = convert_to_dst_type(sc, affine)[0] + grid_ = ((affine @ sc) @ grid_.view((grid_.shape[0], -1))).view([-1] + list(grid_.shape[1:])) + else: + grid_ = (affine @ grid_.view((grid_.shape[0], -1))).view([-1] + list(grid_.shape[1:])) + return grid_, affine -class RandAffineGrid(Randomizable, Transform): +class RandAffineGrid(Randomizable, LazyTransform): """ Generate randomised affine grid. @@ -1849,7 +1612,8 @@ def __init__( shear_range: RandRange = None, translate_range: RandRange = None, scale_range: RandRange = None, - device: Optional[torch.device] = None, + device: torch.device | None = None, + dtype: DtypeLike = np.float32, ) -> None: """ Args: @@ -1876,6 +1640,8 @@ def __init__( the scale factor to translate for every spatial dims. A value of 1.0 is added to the result. This allows 0 to correspond to no change (i.e., a scaling of 1.0). device: device to store the output grid data. + dtype: data type for the grid computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data (if `grid` is provided). See also: - :py:meth:`monai.transforms.utils.create_rotate` @@ -1889,36 +1655,34 @@ def __init__( self.translate_range = ensure_tuple(translate_range) self.scale_range = ensure_tuple(scale_range) - self.rotate_params: Optional[List[float]] = None - self.shear_params: Optional[List[float]] = None - self.translate_params: Optional[List[float]] = None - self.scale_params: Optional[List[float]] = None + self.rotate_params: list[float] | None = None + self.shear_params: list[float] | None = None + self.translate_params: list[float] | None = None + self.scale_params: list[float] | None = None self.device = device - self.affine: Optional[torch.Tensor] = torch.eye(4, dtype=torch.float64) + self.dtype = dtype + self.affine: torch.Tensor | None = torch.eye(4, dtype=torch.float64) def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] for f in param_range: if issequenceiterable(f): if len(f) != 2: - raise ValueError("If giving range as [min,max], should only have two elements per dim.") + raise ValueError(f"If giving range as [min,max], should have 2 elements per dim, got {f}.") out_param.append(self.R.uniform(f[0], f[1]) + add_scalar) elif f is not None: out_param.append(self.R.uniform(-f, f) + add_scalar) return out_param - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: self.rotate_params = self._get_rand_param(self.rotate_range) self.shear_params = self._get_rand_param(self.shear_range) self.translate_params = self._get_rand_param(self.translate_range) self.scale_params = self._get_rand_param(self.scale_range, 1.0) def __call__( - self, - spatial_size: Optional[Sequence[int]] = None, - grid: Optional[NdarrayOrTensor] = None, - randomize: bool = True, + self, spatial_size: Sequence[int] | None = None, grid: NdarrayOrTensor | None = None, randomize: bool = True ) -> torch.Tensor: """ Args: @@ -1937,12 +1701,17 @@ def __call__( translate_params=self.translate_params, scale_params=self.scale_params, device=self.device, + dtype=self.dtype, ) + affine_grid.lazy_evaluation = self.lazy_evaluation + if self.lazy_evaluation: # return the affine only, don't construct the grid + self.affine = affine_grid(spatial_size, grid)[1] # type: ignore + return None # type: ignore _grid: torch.Tensor _grid, self.affine = affine_grid(spatial_size, grid) # type: ignore return _grid - def get_transformation_matrix(self) -> Optional[torch.Tensor]: + def get_transformation_matrix(self) -> torch.Tensor | None: """Get the most recently applied transformation matrix""" return self.affine @@ -1954,13 +1723,8 @@ class RandDeformGrid(Randomizable, Transform): backend = [TransformBackends.TORCH] - @deprecated_arg(name="as_tensor_output", since="0.8") def __init__( - self, - spacing: Union[Sequence[float], float], - magnitude_range: Tuple[float, float], - as_tensor_output: bool = True, - device: Optional[torch.device] = None, + self, spacing: Sequence[float] | float, magnitude_range: tuple[float, float], device: torch.device | None = None ) -> None: """ Args: @@ -1997,15 +1761,15 @@ def __call__(self, spatial_size: Sequence[int]) -> torch.Tensor: class Resample(Transform): - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, - mode: Union[str, int] = GridSampleMode.BILINEAR, + mode: str | int = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.BORDER, norm_coords: bool = True, - device: Optional[torch.device] = None, + device: torch.device | None = None, + align_corners: bool = False, dtype: DtypeLike = np.float64, ) -> None: """ @@ -2035,6 +1799,8 @@ def __init__( `[-1, 1]` (for torch ``grid_sample`` implementation) to be compatible with the underlying resampling API. device: device on which the tensor will be allocated. + align_corners: Defaults to False. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``float64`` for best precision. If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. @@ -2044,15 +1810,46 @@ def __init__( self.padding_mode = padding_mode self.norm_coords = norm_coords self.device = device + self.align_corners = align_corners self.dtype = dtype + @staticmethod + @functools.lru_cache(None) + def resolve_modes(interp_mode, padding_mode): + """compute the backend and the corresponding mode for the given interpolation mode and padding mode.""" + _interp_mode = None + _padding_mode = None + if look_up_option(str(interp_mode), SplineMode, default=None) is not None: + backend = TransformBackends.NUMPY + else: + backend = TransformBackends.TORCH + + if (not USE_COMPILED) and (backend == TransformBackends.TORCH): + if str(interp_mode).lower().endswith("linear"): + _interp_mode = GridSampleMode("bilinear") + _interp_mode = GridSampleMode(interp_mode) + _padding_mode = GridSamplePadMode(padding_mode) + elif USE_COMPILED and backend == TransformBackends.TORCH: # compiled is using torch backend param name + _padding_mode = 1 if padding_mode == "reflection" else padding_mode # type: ignore + if interp_mode == "bicubic": + _interp_mode = 3 # type: ignore + elif interp_mode == "bilinear": + _interp_mode = 1 # type: ignore + else: + _interp_mode = GridSampleMode(interp_mode) + else: # TransformBackends.NUMPY + _interp_mode = int(interp_mode) # type: ignore + _padding_mode = look_up_option(padding_mode, NdimageMode) + return backend, _interp_mode, _padding_mode + def __call__( self, img: torch.Tensor, - grid: Optional[torch.Tensor] = None, - mode: Union[str, int, None] = None, - padding_mode: Optional[str] = None, + grid: torch.Tensor | None = None, + mode: str | int | None = None, + padding_mode: str | None = None, dtype: DtypeLike = None, + align_corners: bool | None = None, ) -> torch.Tensor: """ Args: @@ -2080,6 +1877,8 @@ def __call__( See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html dtype: data type for resampling computation. Defaults to ``self.dtype``. To be compatible with other modules, the output data type is always `float32`. + align_corners: Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html See also: :py:const:`monai.config.USE_COMPILED` @@ -2087,74 +1886,65 @@ def __call__( img = convert_to_tensor(img, track_meta=get_track_meta()) if grid is None: return img + _device = img.device if isinstance(img, torch.Tensor) else self.device _dtype = dtype or self.dtype or img.dtype + _align_corners = self.align_corners if align_corners is None else align_corners img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype, device=_device) - grid_t, *_ = convert_to_dst_type(grid, img_t, dtype=grid.dtype, wrap_sequence=True) - grid_t = grid_t.clone(memory_format=torch.contiguous_format) - - if self.norm_coords: - grid_t[-1] = where(grid_t[-1] != 0, grid_t[-1], 1.0) # type: ignore - sr = min(len(img_t.shape[1:]), 3) - - _interp_mode = self.mode if mode is None else mode - _padding_mode = self.padding_mode if padding_mode is None else padding_mode - if look_up_option(str(_interp_mode), SplineMode, default=None) is not None: - self._backend = TransformBackends.NUMPY - else: - self._backend = TransformBackends.TORCH + sr = min(len(img_t.peek_pending_shape() if isinstance(img_t, MetaTensor) else img_t.shape[1:]), 3) + backend, _interp_mode, _padding_mode = Resample.resolve_modes( + self.mode if mode is None else mode, self.padding_mode if padding_mode is None else padding_mode + ) - if USE_COMPILED or self._backend == TransformBackends.NUMPY: - if self.norm_coords: - for i, dim in enumerate(img_t.shape[1 : 1 + sr]): - grid_t[i] = (max(dim, 2) / 2.0 - 0.5 + grid_t[i]) / grid_t[-1:] - grid_t = grid_t[:sr] - if USE_COMPILED and self._backend == TransformBackends.TORCH: # compiled is using torch backend param name + if USE_COMPILED or backend == TransformBackends.NUMPY: + grid_t, *_ = convert_to_dst_type(grid[:sr], img_t, dtype=grid.dtype, wrap_sequence=True) + if isinstance(grid, torch.Tensor) and grid_t.data_ptr() == grid.data_ptr(): + grid_t = grid_t.clone(memory_format=torch.contiguous_format) + for i, dim in enumerate(img_t.shape[1 : 1 + sr]): + _dim = max(2, dim) + t = (_dim - 1) / 2.0 + if self.norm_coords: + grid_t[i] = ((_dim - 1) / _dim) * grid_t[i] + t if _align_corners else grid_t[i] + t + elif _align_corners: + grid_t[i] = ((_dim - 1) / _dim) * (grid_t[i] + 0.5) + if USE_COMPILED and backend == TransformBackends.TORCH: # compiled is using torch backend param name grid_t = moveaxis(grid_t, 0, -1) # type: ignore - bound = 1 if _padding_mode == "reflection" else _padding_mode - if _interp_mode == "bicubic": - interp = 3 - elif _interp_mode == "bilinear": - interp = 1 - else: - interp = GridSampleMode(_interp_mode) # type: ignore out = grid_pull( img_t.unsqueeze(0), grid_t.unsqueeze(0).to(img_t), - bound=bound, + bound=_padding_mode, extrapolate=True, - interpolation=interp, + interpolation=_interp_mode, )[0] - elif self._backend == TransformBackends.NUMPY: + elif backend == TransformBackends.NUMPY: is_cuda = img_t.is_cuda img_np = (convert_to_cupy if is_cuda else convert_to_numpy)(img_t, wrap_sequence=True) - grid_np, *_ = convert_to_dst_type(grid_t, img_np, wrap_sequence=True) + grid_np, *_ = convert_to_dst_type(grid_t, img_np, dtype=grid_t.dtype, wrap_sequence=True) _map_coord = (cupy_ndi if is_cuda else np_ndi).map_coordinates out = (cupy if is_cuda else np).stack( - [ - _map_coord(c, grid_np, order=int(_interp_mode), mode=look_up_option(_padding_mode, NdimageMode)) - for c in img_np - ] + [_map_coord(c, grid_np, order=_interp_mode, mode=_padding_mode) for c in img_np] ) out = convert_to_dst_type(out, img_t)[0] else: + grid_t = moveaxis(grid[list(range(sr - 1, -1, -1))], 0, -1) # type: ignore + grid_t = convert_to_dst_type(grid_t, img_t, wrap_sequence=True)[0].unsqueeze(0) + if isinstance(grid, torch.Tensor) and grid_t.data_ptr() == grid.data_ptr(): + grid_t = grid_t.clone(memory_format=torch.contiguous_format) if self.norm_coords: - for i, dim in enumerate(img_t.shape[1 : 1 + sr]): - grid_t[i] = 2.0 / (max(2, dim) - 1.0) * grid_t[i] / grid_t[-1:] - index_ordering: List[int] = list(range(sr - 1, -1, -1)) - grid_t = moveaxis(grid_t[index_ordering], 0, -1) # type: ignore + for i, dim in enumerate(img_t.shape[sr + 1 : 0 : -1]): + grid_t[0, ..., i] *= 2.0 / max(2, dim) out = torch.nn.functional.grid_sample( img_t.unsqueeze(0), - grid_t.unsqueeze(0).to(img_t), - mode=GridSampleMode(_interp_mode), - padding_mode=GridSamplePadMode(_padding_mode), - align_corners=True, + grid_t, + mode=_interp_mode, + padding_mode=_padding_mode, + align_corners=None if _align_corners == TraceKeys.NONE else _align_corners, # type: ignore )[0] out_val, *_ = convert_to_dst_type(out, dst=img, dtype=np.float32) return out_val -class Affine(InvertibleTransform): +class Affine(InvertibleTransform, LazyTransform): """ Transform ``img`` given the affine parameters. A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb. @@ -2163,21 +1953,20 @@ class Affine(InvertibleTransform): backend = list(set(AffineGrid.backend) & set(Resample.backend)) - @deprecated_arg(name="norm_coords", since="0.8") def __init__( self, - rotate_params: Optional[Union[Sequence[float], float]] = None, - shear_params: Optional[Union[Sequence[float], float]] = None, - translate_params: Optional[Union[Sequence[float], float]] = None, - scale_params: Optional[Union[Sequence[float], float]] = None, - affine: Optional[NdarrayOrTensor] = None, - spatial_size: Optional[Union[Sequence[int], int]] = None, - mode: Union[str, int] = GridSampleMode.BILINEAR, + rotate_params: Sequence[float] | float | None = None, + shear_params: Sequence[float] | float | None = None, + translate_params: Sequence[float] | float | None = None, + scale_params: Sequence[float] | float | None = None, + affine: NdarrayOrTensor | None = None, + spatial_size: Sequence[int] | int | None = None, + mode: str | int = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.REFLECTION, normalized: bool = False, - norm_coords: bool = True, - device: Optional[torch.device] = None, + device: torch.device | None = None, dtype: DtypeLike = np.float32, + align_corners: bool = False, image_only: bool = False, ) -> None: """ @@ -2230,12 +2019,10 @@ def __init__( dtype: data type for resampling computation. Defaults to ``float32``. If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. + align_corners: Defaults to False. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html image_only: if True return only the image volume, otherwise return (image, affine). - .. deprecated:: 0.8.1 - ``norm_coords`` is deprecated, please use ``normalized`` instead - (the new flag is a negation, i.e., ``norm_coords == not normalized``). - """ self.affine_grid = AffineGrid( rotate_params=rotate_params, @@ -2244,22 +2031,28 @@ def __init__( scale_params=scale_params, affine=affine, dtype=dtype, + align_corners=align_corners, device=device, ) self.image_only = image_only self.norm_coord = not normalized - self.resampler = Resample(norm_coords=self.norm_coord, device=device, dtype=dtype) + self.resampler = Resample(norm_coords=self.norm_coord, device=device, dtype=dtype, align_corners=align_corners) self.spatial_size = spatial_size self.mode = mode self.padding_mode: str = padding_mode + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self.affine_grid.lazy_evaluation = val + self._lazy_evaluation = val + def __call__( self, img: torch.Tensor, - spatial_size: Optional[Union[Sequence[int], int]] = None, - mode: Union[str, int, None] = None, - padding_mode: Optional[str] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, NdarrayOrTensor]]: + spatial_size: Sequence[int] | int | None = None, + mode: str | int | None = None, + padding_mode: str | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, NdarrayOrTensor]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -2282,34 +2075,33 @@ def __call__( See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_size = img.shape[1:] + img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img_size) _mode = mode if mode is not None else self.mode _padding_mode = padding_mode if padding_mode is not None else self.padding_mode grid, affine = self.affine_grid(spatial_size=sp_size) - out = self.resampler(img, grid=grid, mode=_mode, padding_mode=_padding_mode) - if not isinstance(out, MetaTensor): - return out if self.image_only else (out, affine) - if get_track_meta(): - out.meta = img.meta # type: ignore - self.update_meta(out, affine, img_size, sp_size) - self.push_transform( - out, orig_size=img_size, extra_info={"affine": affine, "mode": _mode, "padding_mode": _padding_mode} - ) - return out if self.image_only else (out, affine) + + return affine_func( # type: ignore + img, + affine, + grid, + self.resampler, + sp_size, + _mode, + _padding_mode, + True, + self.image_only, + self.get_transform_info(), + ) @classmethod - def compute_w_affine(cls, affine, mat, img_size, sp_size): - r = len(affine) - 1 + def compute_w_affine(cls, spatial_rank, mat, img_size, sp_size): + r = int(spatial_rank) mat = to_affine_nd(r, mat) shift_1 = create_translate(r, [float(d - 1) / 2 for d in img_size[:r]]) shift_2 = create_translate(r, [-float(d - 1) / 2 for d in sp_size[:r]]) mat = shift_1 @ convert_data_type(mat, np.ndarray)[0] @ shift_2 - return affine @ convert_to_dst_type(mat, affine)[0] - - def update_meta(self, img, mat, img_size, sp_size): - affine = convert_data_type(img.affine, torch.Tensor)[0] - img.affine = Affine.compute_w_affine(affine, mat, img_size, sp_size) + return mat def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -2318,21 +2110,26 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] mode = transform[TraceKeys.EXTRA_INFO]["mode"] padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - inv_affine = linalg_inv(fwd_affine) + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + inv_affine = linalg_inv(convert_to_numpy(fwd_affine)) inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0] - affine_grid = AffineGrid(affine=inv_affine) + affine_grid = AffineGrid(affine=inv_affine, align_corners=align_corners) grid, _ = affine_grid(orig_size) # Apply inverse transform - out = self.resampler(data, grid, mode, padding_mode) + out = self.resampler(data, grid, mode, padding_mode, align_corners=align_corners) if not isinstance(out, MetaTensor): out = MetaTensor(out) out.meta = data.meta # type: ignore - self.update_meta(out, inv_affine, data.shape[1:], orig_size) + affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0] + xform, *_ = convert_to_dst_type( + Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine + ) + out.affine @= xform return out -class RandAffine(RandomizableTransform, InvertibleTransform): +class RandAffine(RandomizableTransform, InvertibleTransform, LazyTransform): """ Random affine transform. A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb. @@ -2348,11 +2145,11 @@ def __init__( shear_range: RandRange = None, translate_range: RandRange = None, scale_range: RandRange = None, - spatial_size: Optional[Union[Sequence[int], int]] = None, - mode: Union[str, int] = GridSampleMode.BILINEAR, + spatial_size: Sequence[int] | int | None = None, + mode: str | int = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.REFLECTION, cache_grid: bool = False, - device: Optional[torch.device] = None, + device: torch.device | None = None, ) -> None: """ Args: @@ -2425,10 +2222,17 @@ def __init__( self.mode = mode self.padding_mode: str = padding_mode + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self._lazy_evaluation = val + self.rand_affine_grid.lazy_evaluation = val + def _init_identity_cache(self): """ Create cache of the identity grid if cache_grid=True and spatial_size is known. """ + if self.lazy_evaluation: + return None if self.spatial_size is None: if self.cache_grid: warnings.warn( @@ -2454,6 +2258,8 @@ def get_identity_grid(self, spatial_size: Sequence[int]): Args: spatial_size: non-dynamic spatial size """ + if self.lazy_evaluation: + return None ndim = len(spatial_size) if spatial_size != fall_back_tuple(spatial_size, [1] * ndim) or spatial_size != fall_back_tuple( spatial_size, [2] * ndim @@ -2465,14 +2271,12 @@ def get_identity_grid(self, spatial_size: Sequence[int]): else self._cached_grid ) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandAffine": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAffine: self.rand_affine_grid.set_random_state(seed, state) super().set_random_state(seed, state) return self - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None @@ -2481,9 +2285,9 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, img: torch.Tensor, - spatial_size: Optional[Union[Sequence[int], int]] = None, - mode: Union[str, int, None] = None, - padding_mode: Optional[str] = None, + spatial_size: Sequence[int] | int | None = None, + mode: str | int | None = None, + padding_mode: str | None = None, randomize: bool = True, grid=None, ) -> torch.Tensor: @@ -2515,38 +2319,35 @@ def __call__( self.randomize() # if not doing transform and spatial size doesn't change, nothing to do # except convert to float and device - sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img.shape[1:]) - do_resampling = self._do_transform or (sp_size != ensure_tuple(img.shape[1:])) + ori_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, ori_size) + do_resampling = self._do_transform or (sp_size != ensure_tuple(ori_size)) _mode = mode if mode is not None else self.mode _padding_mode = padding_mode if padding_mode is not None else self.padding_mode img = convert_to_tensor(img, track_meta=get_track_meta()) - if not do_resampling: - out: torch.Tensor = convert_data_type(img, dtype=torch.float32, device=self.resampler.device)[0] + if self.lazy_evaluation: + if self._do_transform: + affine = self.rand_affine_grid.get_transformation_matrix() + else: + affine = convert_to_dst_type(torch.eye(len(sp_size) + 1), img, dtype=self.rand_affine_grid.dtype)[0] else: if grid is None: grid = self.get_identity_grid(sp_size) if self._do_transform: grid = self.rand_affine_grid(grid=grid, randomize=randomize) - out = self.resampler(img=img, grid=grid, mode=_mode, padding_mode=_padding_mode) - mat = self.rand_affine_grid.get_transformation_matrix() - out = convert_to_tensor(out, track_meta=get_track_meta()) - if get_track_meta(): - self.push_transform( - out, - orig_size=img.shape[1:], - extra_info={ - "affine": mat, - "mode": _mode, - "padding_mode": _padding_mode, - "do_resampling": do_resampling, - }, - ) - self.update_meta(out, mat, img.shape[1:], sp_size) - return out - - def update_meta(self, img, mat, img_size, sp_size): - affine = convert_data_type(img.affine, torch.Tensor)[0] - img.affine = Affine.compute_w_affine(affine, mat, img_size, sp_size) + affine = self.rand_affine_grid.get_transformation_matrix() + return affine_func( # type: ignore + img, + affine, + grid, + self.resampler, + sp_size, + _mode, + _padding_mode, + do_resampling, + True, + self.get_transform_info(), + ) def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -2559,7 +2360,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] mode = transform[TraceKeys.EXTRA_INFO]["mode"] padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - inv_affine = linalg_inv(fwd_affine) + inv_affine = linalg_inv(convert_to_numpy(fwd_affine)) inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0] affine_grid = AffineGrid(affine=inv_affine) grid, _ = affine_grid(orig_size) @@ -2569,7 +2370,11 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: if not isinstance(out, MetaTensor): out = MetaTensor(out) out.meta = data.meta # type: ignore - self.update_meta(out, inv_affine, data.shape[1:], orig_size) + affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0] + xform, *_ = convert_to_dst_type( + Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine + ) + out.affine @= xform return out @@ -2584,17 +2389,17 @@ class Rand2DElastic(RandomizableTransform): def __init__( self, - spacing: Union[Tuple[float, float], float], - magnitude_range: Tuple[float, float], + spacing: tuple[float, float] | float, + magnitude_range: tuple[float, float], prob: float = 0.1, rotate_range: RandRange = None, shear_range: RandRange = None, translate_range: RandRange = None, scale_range: RandRange = None, - spatial_size: Optional[Union[Tuple[int, int], int]] = None, - mode: Union[str, int] = GridSampleMode.BILINEAR, + spatial_size: tuple[int, int] | int | None = None, + mode: str | int = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.REFLECTION, - device: Optional[torch.device] = None, + device: torch.device | None = None, ) -> None: """ Args: @@ -2664,9 +2469,7 @@ def __init__( self.mode = mode self.padding_mode: str = padding_mode - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Rand2DElastic": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Rand2DElastic: self.deform_grid.set_random_state(seed, state) self.rand_affine_grid.set_random_state(seed, state) super().set_random_state(seed, state) @@ -2688,9 +2491,9 @@ def randomize(self, spatial_size: Sequence[int]) -> None: def __call__( self, img: torch.Tensor, - spatial_size: Optional[Union[Tuple[int, int], int]] = None, - mode: Union[str, int, None] = None, - padding_mode: Optional[str] = None, + spatial_size: tuple[int, int] | int | None = None, + mode: str | int | None = None, + padding_mode: str | None = None, randomize: bool = True, ) -> torch.Tensor: """ @@ -2730,7 +2533,7 @@ def __call__( grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: _device = img.device if isinstance(img, torch.Tensor) else self.device - grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") + grid = cast(torch.Tensor, create_grid(spatial_size=sp_size, device=_device, backend="torch")) out: torch.Tensor = self.resampler( img, grid, @@ -2751,17 +2554,17 @@ class Rand3DElastic(RandomizableTransform): def __init__( self, - sigma_range: Tuple[float, float], - magnitude_range: Tuple[float, float], + sigma_range: tuple[float, float], + magnitude_range: tuple[float, float], prob: float = 0.1, rotate_range: RandRange = None, shear_range: RandRange = None, translate_range: RandRange = None, scale_range: RandRange = None, - spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, - mode: Union[str, int] = GridSampleMode.BILINEAR, + spatial_size: tuple[int, int, int] | int | None = None, + mode: str | int = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.REFLECTION, - device: Optional[torch.device] = None, + device: torch.device | None = None, ) -> None: """ Args: @@ -2839,9 +2642,7 @@ def __init__( self.magnitude = 1.0 self.sigma = 1.0 - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Rand3DElastic": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Rand3DElastic: self.rand_affine_grid.set_random_state(seed, state) super().set_random_state(seed, state) return self @@ -2863,9 +2664,9 @@ def randomize(self, grid_size: Sequence[int]) -> None: def __call__( self, img: torch.Tensor, - spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, - mode: Union[str, int, None] = None, - padding_mode: Optional[str] = None, + spatial_size: tuple[int, int, int] | int | None = None, + mode: str | int | None = None, + padding_mode: str | None = None, randomize: bool = True, ) -> torch.Tensor: """ @@ -2911,16 +2712,15 @@ def __call__( class GridDistortion(Transform): - backend = [TransformBackends.TORCH] def __init__( self, - num_cells: Union[Tuple[int], int], + num_cells: tuple[int] | int, distort_steps: Sequence[Sequence[float]], - mode: Union[str, int] = GridSampleMode.BILINEAR, + mode: str | int = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.BORDER, - device: Optional[torch.device] = None, + device: torch.device | None = None, ) -> None: """ Grid distortion transform. Refer to: @@ -2954,9 +2754,9 @@ def __init__( def __call__( self, img: torch.Tensor, - distort_steps: Optional[Sequence[Sequence]] = None, - mode: Optional[str] = None, - padding_mode: Optional[str] = None, + distort_steps: Sequence[Sequence] | None = None, + mode: str | None = None, + padding_mode: str | None = None, ) -> torch.Tensor: """ Args: @@ -2984,6 +2784,8 @@ def __call__( all_ranges = [] num_cells = ensure_tuple_rep(self.num_cells, len(img.shape) - 1) + if isinstance(img, MetaTensor) and img.pending_operations: + warnings.warn("MetaTensor img has pending operations, transform may return incorrect results.") for dim_idx, dim_size in enumerate(img.shape[1:]): dim_distort_steps = distort_steps[dim_idx] ranges = torch.zeros(dim_size, dtype=torch.float32) @@ -3009,17 +2811,16 @@ def __call__( class RandGridDistortion(RandomizableTransform): - backend = [TransformBackends.TORCH] def __init__( self, - num_cells: Union[Tuple[int], int] = 5, + num_cells: tuple[int] | int = 5, prob: float = 0.1, - distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), - mode: Union[str, int] = GridSampleMode.BILINEAR, + distort_limit: tuple[float, float] | float = (-0.03, 0.03), + mode: str | int = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.BORDER, - device: Optional[torch.device] = None, + device: torch.device | None = None, ) -> None: """ Random grid distortion transform. Refer to: @@ -3067,7 +2868,7 @@ def randomize(self, spatial_shape: Sequence[int]) -> None: ) def __call__( - self, img: torch.Tensor, mode: Optional[str] = None, padding_mode: Optional[str] = None, randomize: bool = True + self, img: torch.Tensor, mode: str | None = None, padding_mode: str | None = None, randomize: bool = True ) -> torch.Tensor: """ Args: @@ -3087,13 +2888,15 @@ def __call__( randomize: whether to shuffle the random factors using `randomize()`, default to True. """ if randomize: + if isinstance(img, MetaTensor) and img.pending_operations: + warnings.warn("MetaTensor img has pending operations, transform may return incorrect results.") self.randomize(img.shape[1:]) if not self._do_transform: return convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode) -class GridSplit(Transform): +class GridSplit(Transform, MultiSampleTrait): """ Split the image into patches based on the provided grid in 2D. @@ -3113,7 +2916,7 @@ class GridSplit(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tuple[int, int]]] = None): + def __init__(self, grid: tuple[int, int] = (2, 2), size: int | tuple[int, int] | None = None): # Grid size self.grid = grid @@ -3121,15 +2924,16 @@ def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tup self.size = None if size is None else ensure_tuple_rep(size, len(self.grid)) def __call__( - self, image: NdarrayOrTensor, size: Optional[Union[int, Tuple[int, int], np.ndarray]] = None - ) -> List[NdarrayOrTensor]: + self, image: NdarrayOrTensor, size: int | tuple[int, int] | np.ndarray | None = None + ) -> list[NdarrayOrTensor]: input_size = self.size if size is None else ensure_tuple_rep(size, len(self.grid)) if self.grid == (1, 1) and input_size is None: return [image] - + if isinstance(image, MetaTensor) and image.pending_operations: + warnings.warn("MetaTensor img has pending operations, transform may return incorrect results.") split_size, steps = self._get_params(image.shape[1:], input_size) - patches: List[NdarrayOrTensor] + patches: list[NdarrayOrTensor] as_strided_func: Callable if isinstance(image, torch.Tensor): as_strided_func = torch.as_strided @@ -3157,9 +2961,7 @@ def __call__( return patches - def _get_params( - self, image_size: Union[Sequence[int], np.ndarray], size: Optional[Union[Sequence[int], np.ndarray]] = None - ): + def _get_params(self, image_size: Sequence[int] | np.ndarray, size: Sequence[int] | np.ndarray | None = None): """ Calculate the size and step required for splitting the image Args: @@ -3180,7 +2982,7 @@ def _get_params( return size, steps -class GridPatch(Transform): +class GridPatch(Transform, MultiSampleTrait): """ Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps. It can sort the patches and return all or a subset of them. @@ -3188,19 +2990,32 @@ class GridPatch(Transform): Args: patch_size: size of patches to generate slices for, 0 or None selects whole dimension offset: offset of starting position in the array, default is 0 for each dimension. - num_patches: number of patches to return. Defaults to None, which returns all the available patches. - If the required patches are more than the available patches, padding will be applied. + num_patches: number of patches (or maximum number of patches) to return. + If the requested number of patches is greater than the number of available patches, + padding will be applied to provide exactly `num_patches` patches unless `threshold` is set. + When `threshold` is set, this value is treated as the maximum number of patches. + Defaults to None, which does not limit number of the patches. overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), lowest values (`"min"`), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. - pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. + pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. + Defaults to None, which means no padding will be applied. + Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. Returns: - MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata + MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension), + with following metadata: + + - `PatchKeys.LOCATION`: the starting location of the patch in the image, + - `PatchKeys.COUNT`: total number of patches in the image, + - "spatial_shape": spatial size of the extracted patch, and + - "offset": the amount of offset for the patches in the image (starting position of the first patch) """ @@ -3209,17 +3024,17 @@ class GridPatch(Transform): def __init__( self, patch_size: Sequence[int], - offset: Optional[Sequence[int]] = None, - num_patches: Optional[int] = None, - overlap: Union[Sequence[float], float] = 0.0, - sort_fn: Optional[str] = None, - threshold: Optional[float] = None, - pad_mode: str = PytorchPadMode.CONSTANT, + offset: Sequence[int] | None = None, + num_patches: int | None = None, + overlap: Sequence[float] | float = 0.0, + sort_fn: str | None = None, + threshold: float | None = None, + pad_mode: str | None = None, **pad_kwargs, ): self.patch_size = ensure_tuple(patch_size) self.offset = ensure_tuple(offset) if offset else (0,) * len(self.patch_size) - self.pad_mode: Optional[NumpyPadMode] = convert_pad_mode(dst=np.zeros(1), mode=pad_mode) if pad_mode else None + self.pad_mode: NumpyPadMode | None = convert_pad_mode(dst=np.zeros(1), mode=pad_mode) if pad_mode else None self.pad_kwargs = pad_kwargs self.overlap = overlap self.num_patches = num_patches @@ -3228,24 +3043,26 @@ def __init__( def filter_threshold(self, image_np: np.ndarray, locations: np.ndarray): """ - Filter the patches and their locations according to a threshold + Filter the patches and their locations according to a threshold. + Args: - image_np: a numpy.ndarray representing a stack of patches - locations: a numpy.ndarray representing the stack of location of each patch + image_np: a numpy.ndarray representing a stack of patches. + locations: a numpy.ndarray representing the stack of location of each patch. + + Returns: + tuple[numpy.ndarray, numpy.ndarray]: tuple of filtered patches and locations. """ - if self.threshold is not None: - n_dims = len(image_np.shape) - idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1) - image_np = image_np[idx] - locations = locations[idx] - return image_np, locations + n_dims = len(image_np.shape) + idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1) + return image_np[idx], locations[idx] def filter_count(self, image_np: np.ndarray, locations: np.ndarray): """ Sort the patches based on the sum of their intensity, and just keep `self.num_patches` of them. + Args: - image_np: a numpy.ndarray representing a stack of patches - locations: a numpy.ndarray representing the stack of location of each patch + image_np: a numpy.ndarray representing a stack of patches. + locations: a numpy.ndarray representing the stack of location of each patch. """ if self.sort_fn is None: image_np = image_np[: self.num_patches] @@ -3263,7 +3080,17 @@ def filter_count(self, image_np: np.ndarray, locations: np.ndarray): locations = locations[idx] return image_np, locations - def __call__(self, array: NdarrayOrTensor): + def __call__(self, array: NdarrayOrTensor) -> MetaTensor: + """ + Extract the patches (sweeping the entire image in a row-major sliding-window manner with possible overlaps). + + Args: + array: a input image as `numpy.ndarray` or `torch.Tensor` + + Return: + MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension), + with defined `PatchKeys.LOCATION` and `PatchKeys.COUNT` metadata. + """ # create the patch iterator which sweeps the image row-by-row array_np, *_ = convert_data_type(array, np.ndarray) patch_iterator = iter_patch( @@ -3279,35 +3106,38 @@ def __call__(self, array: NdarrayOrTensor): patched_image = np.array(patches[0]) locations = np.array(patches[1])[:, 1:, 0] # only keep the starting location - # Filter patches - if self.num_patches: - patched_image, locations = self.filter_count(patched_image, locations) - elif self.threshold: + # Apply threshold filtering + if self.threshold is not None: patched_image, locations = self.filter_threshold(patched_image, locations) - # Pad the patch list to have the requested number of patches + # Apply count filtering if self.num_patches: - padding = self.num_patches - len(patched_image) - if padding > 0: - patched_image = np.pad( - patched_image, - [[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size), - constant_values=self.pad_kwargs.get("constant_values", 0), - ) - locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0) + # Limit number of patches + patched_image, locations = self.filter_count(patched_image, locations) + # Pad the patch list to have the requested number of patches + if self.threshold is None: + padding = self.num_patches - len(patched_image) + if padding > 0: + patched_image = np.pad( + patched_image, + [[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size), + constant_values=self.pad_kwargs.get("constant_values", 0), + ) + locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0) # Convert to MetaTensor metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta() - metadata[WSIPatchKeys.LOCATION] = locations.T - metadata[WSIPatchKeys.COUNT] = len(locations) + metadata[PatchKeys.LOCATION] = locations.T + metadata[PatchKeys.COUNT] = len(locations) metadata["spatial_shape"] = np.tile(np.array(self.patch_size), (len(locations), 1)).T + metadata["offset"] = self.offset output = MetaTensor(x=patched_image, meta=metadata) output.is_batch = True return output -class RandGridPatch(GridPatch, RandomizableTransform): +class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait): """ Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps, and with random offset for the minimal corner of the image, (0,0) for 2D and (0,0,0) for 3D. @@ -3318,18 +3148,32 @@ class RandGridPatch(GridPatch, RandomizableTransform): min_offset: the minimum range of offset to be selected randomly. Defaults to 0. max_offset: the maximum range of offset to be selected randomly. Defaults to image size modulo patch size. - num_patches: number of patches to return. Defaults to None, which returns all the available patches. + num_patches: number of patches (or maximum number of patches) to return. + If the requested number of patches is greater than the number of available patches, + padding will be applied to provide exactly `num_patches` patches unless `threshold` is set. + When `threshold` is set, this value is treated as the maximum number of patches. + Defaults to None, which does not limit number of the patches. overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), lowest values (`"min"`), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. - pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. + pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. + Defaults to None, which means no padding will be applied. + Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. Returns: - MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata + MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension), + with following metadata: + + - `PatchKeys.LOCATION`: the starting location of the patch in the image, + - `PatchKeys.COUNT`: total number of patches in the image, + - "spatial_shape": spatial size of the extracted patch, and + - "offset": the amount of offset for the patches in the image (starting position of the first patch) """ @@ -3338,12 +3182,12 @@ class RandGridPatch(GridPatch, RandomizableTransform): def __init__( self, patch_size: Sequence[int], - min_offset: Optional[Union[Sequence[int], int]] = None, - max_offset: Optional[Union[Sequence[int], int]] = None, - num_patches: Optional[int] = None, - overlap: Union[Sequence[float], float] = 0.0, - sort_fn: Optional[str] = None, - threshold: Optional[float] = None, + min_offset: Sequence[int] | int | None = None, + max_offset: Sequence[int] | int | None = None, + num_patches: int | None = None, + overlap: Sequence[float] | float = 0.0, + sort_fn: str | None = None, + threshold: float | None = None, pad_mode: str = PytorchPadMode.CONSTANT, **pad_kwargs, ): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 706e8d7f8b..36e86da903 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,7 +15,11 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from __future__ import annotations + +import warnings +from collections.abc import Hashable, Mapping, Sequence +from typing import Any, cast import numpy as np import torch @@ -50,7 +54,8 @@ SpatialResample, Zoom, ) -from monai.transforms.transform import MapTransform, RandomizableTransform +from monai.transforms.traits import MultiSampleTrait +from monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform from monai.transforms.utils import create_grid from monai.utils import ( GridSampleMode, @@ -62,7 +67,6 @@ ensure_tuple_rep, fall_back_tuple, ) -from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import PytorchPadMode, TraceKeys from monai.utils.module import optional_import @@ -139,7 +143,7 @@ ] -class SpatialResampled(MapTransform, InvertibleTransform): +class SpatialResampled(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialResample`. @@ -156,20 +160,14 @@ class SpatialResampled(MapTransform, InvertibleTransform): backend = SpatialResample.backend - @deprecated_arg(name="meta_keys", since="0.9") - @deprecated_arg(name="meta_key_postfix", since="0.9") - @deprecated_arg(name="meta_src_keys", since="0.9") def __init__( self, keys: KeysCollection, mode: SequenceStr = GridSampleMode.BILINEAR, padding_mode: SequenceStr = GridSamplePadMode.BORDER, - align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, - meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", - meta_src_keys: Optional[KeysCollection] = "src_affine", - dst_keys: Optional[KeysCollection] = "dst_affine", + align_corners: Sequence[bool] | bool = False, + dtype: Sequence[DtypeLike] | DtypeLike = np.float64, + dst_keys: KeysCollection | None = "dst_affine", allow_missing_keys: bool = False, ) -> None: """ @@ -207,9 +205,14 @@ def __init__( self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.dst_keys = ensure_tuple_rep(dst_keys, len(self.keys)) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: - d: Dict = dict(data) - for (key, mode, padding_mode, align_corners, dtype, dst_key) in self.key_iterator( + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self._lazy_evaluation = val + self.sp_transform.lazy_evaluation = val + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d: dict = dict(data) + for key, mode, padding_mode, align_corners, dtype, dst_key in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.dst_keys ): d[key] = self.sp_transform( @@ -223,28 +226,26 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc ) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.sp_transform.inverse(d[key]) return d -class ResampleToMatchd(MapTransform, InvertibleTransform): +class ResampleToMatchd(MapTransform, InvertibleTransform, LazyTransform): """Dictionary-based wrapper of :py:class:`monai.transforms.ResampleToMatch`.""" backend = ResampleToMatch.backend - @deprecated_arg(name="template_key", since="0.9") def __init__( self, keys: KeysCollection, key_dst: str, - template_key: Optional[str] = None, mode: SequenceStr = GridSampleMode.BILINEAR, padding_mode: SequenceStr = GridSamplePadMode.BORDER, - align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + align_corners: Sequence[bool] | bool = False, + dtype: Sequence[DtypeLike] | DtypeLike = np.float64, allow_missing_keys: bool = False, ): """ @@ -282,9 +283,14 @@ def __init__( self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.resampler = ResampleToMatch() - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self._lazy_evaluation = val + self.resampler.lazy_evaluation = val + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) - for (key, mode, padding_mode, align_corners, dtype) in self.key_iterator( + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): d[key] = self.resampler( @@ -297,14 +303,14 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc ) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.resampler.inverse(d[key]) return d -class Spacingd(MapTransform, InvertibleTransform): +class Spacingd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Spacing`. @@ -320,23 +326,20 @@ class Spacingd(MapTransform, InvertibleTransform): backend = Spacing.backend - @deprecated_arg(name="meta_keys", since="0.9") - @deprecated_arg(name="meta_key_postfix", since="0.9") def __init__( self, keys: KeysCollection, - pixdim: Union[Sequence[float], float], + pixdim: Sequence[float] | float, diagonal: bool = False, mode: SequenceStr = GridSampleMode.BILINEAR, padding_mode: SequenceStr = GridSamplePadMode.BORDER, - align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + align_corners: Sequence[bool] | bool = False, + dtype: Sequence[DtypeLike] | DtypeLike = np.float64, scale_extent: bool = False, recompute_affine: bool = False, - meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", - min_pixdim: Union[Sequence[float], float, None] = None, - max_pixdim: Union[Sequence[float], float, None] = None, + min_pixdim: Sequence[float] | float | None = None, + max_pixdim: Sequence[float] | float | None = None, + ensure_same_shape: bool = True, allow_missing_keys: bool = False, ) -> None: """ @@ -394,6 +397,8 @@ def __init__( max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the value of `pixdim`. Default to `None`. + ensure_same_shape: when the inputs have the same spatial shape, and almost the same pixdim, + whether to ensure exactly the same output spatial shape. Default to True. allow_missing_keys: don't raise exception if key is missing. """ @@ -406,13 +411,29 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.scale_extent = ensure_tuple_rep(scale_extent, len(self.keys)) + self.ensure_same_shape = ensure_same_shape + + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self._lazy_evaluation = val + self.spacing_transform.lazy_evaluation = val + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d: dict = dict(data) + + _init_shape, _pixdim, should_match = None, None, False + output_shape_k = None # tracking output shape - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: - d: Dict = dict(data) for key, mode, padding_mode, align_corners, dtype, scale_extent in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent ): - # resample array of each corresponding key + if self.ensure_same_shape and isinstance(d[key], MetaTensor): + if _init_shape is None and _pixdim is None: + _init_shape, _pixdim = d[key].peek_pending_shape(), d[key].pixdim + else: + should_match = np.allclose(_init_shape, d[key].peek_pending_shape()) and np.allclose( + _pixdim, d[key].pixdim, atol=1e-3 + ) d[key] = self.spacing_transform( data_array=d[key], mode=mode, @@ -420,17 +441,19 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc align_corners=align_corners, dtype=dtype, scale_extent=scale_extent, + output_spatial_shape=output_shape_k if should_match else None, ) + output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:] return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): - d[key] = self.spacing_transform.inverse(d[key]) + d[key] = self.spacing_transform.inverse(cast(torch.Tensor, d[key])) return d -class Orientationd(MapTransform, InvertibleTransform): +class Orientationd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`. @@ -441,16 +464,12 @@ class Orientationd(MapTransform, InvertibleTransform): backend = Orientation.backend - @deprecated_arg(name="meta_keys", since="0.9") - @deprecated_arg(name="meta_key_postfix", since="0.9") def __init__( self, keys: KeysCollection, - axcodes: Optional[str] = None, + axcodes: str | None = None, as_closest_canonical: bool = False, - labels: Optional[Sequence[Tuple[str, str]]] = (("L", "R"), ("P", "A"), ("I", "S")), - meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", + labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")), allow_missing_keys: bool = False, ) -> None: """ @@ -473,20 +492,25 @@ def __init__( super().__init__(keys, allow_missing_keys) self.ornt_transform = Orientation(axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: - d: Dict = dict(data) + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self._lazy_evaluation = val + self.ornt_transform.lazy_evaluation = val + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d: dict = dict(data) for key in self.key_iterator(d): d[key] = self.ornt_transform(d[key]) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.ornt_transform.inverse(d[key]) return d -class Rotate90d(MapTransform, InvertibleTransform): +class Rotate90d(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ @@ -494,7 +518,7 @@ class Rotate90d(MapTransform, InvertibleTransform): backend = Rotate90.backend def __init__( - self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1), allow_missing_keys: bool = False + self, keys: KeysCollection, k: int = 1, spatial_axes: tuple[int, int] = (0, 1), allow_missing_keys: bool = False ) -> None: """ Args: @@ -506,20 +530,25 @@ def __init__( super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self._lazy_evaluation = val + self.rotator.lazy_evaluation = val + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.rotator(d[key]) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.rotator.inverse(d[key]) return d -class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): +class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate90`. With probability `prob`, input arrays are rotated by 90 degrees @@ -533,7 +562,7 @@ def __init__( keys: KeysCollection, prob: float = 0.1, max_k: int = 3, - spatial_axes: Tuple[int, int] = (0, 1), + spatial_axes: tuple[int, int] = (0, 1), allow_missing_keys: bool = False, ) -> None: """ @@ -556,7 +585,7 @@ def __init__( self._rand_k = 0 - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, data: Any | None = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) @@ -567,14 +596,13 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need # to be compatible with the random status of some previous integration tests rotator = Rotate90(self._rand_k, self.spatial_axes) + rotator.lazy_evaluation = self.lazy_evaluation for key in self.key_iterator(d): d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta()) - if get_track_meta(): - xform = self.pop_transform(d[key], check=False) if self._do_transform else {} - self.push_transform(d[key], extra_info=xform) + self.push_transform(d[key], replace=True) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): if not isinstance(d[key], MetaTensor): @@ -585,7 +613,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Resized(MapTransform, InvertibleTransform): +class Resized(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Resize`. @@ -618,6 +646,8 @@ class Resized(MapTransform, InvertibleTransform): By default, this value is chosen as (s - 1) / 2 where s is the downsampling factor, where s > 1. For the up-size case, s < 1, no anti-aliasing is performed prior to rescaling. + dtype: data type for resampling computation. Defaults to ``float32``. + If None, use the data type of input data. allow_missing_keys: don't raise exception if key is missing. """ @@ -626,25 +656,32 @@ class Resized(MapTransform, InvertibleTransform): def __init__( self, keys: KeysCollection, - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, size_mode: str = "all", mode: SequenceStr = InterpolateMode.AREA, - align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, - anti_aliasing: Union[Sequence[bool], bool] = False, - anti_aliasing_sigma: Union[Sequence[Union[Sequence[float], float, None]], Sequence[float], float, None] = None, + align_corners: Sequence[bool | None] | bool | None = None, + anti_aliasing: Sequence[bool] | bool = False, + anti_aliasing_sigma: Sequence[Sequence[float] | float | None] | Sequence[float] | float | None = None, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.anti_aliasing = ensure_tuple_rep(anti_aliasing, len(self.keys)) self.anti_aliasing_sigma = ensure_tuple_rep(anti_aliasing_sigma, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self._lazy_evaluation = val + self.resizer.lazy_evaluation = val + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) - for key, mode, align_corners, anti_aliasing, anti_aliasing_sigma in self.key_iterator( - d, self.mode, self.align_corners, self.anti_aliasing, self.anti_aliasing_sigma + for key, mode, align_corners, anti_aliasing, anti_aliasing_sigma, dtype in self.key_iterator( + d, self.mode, self.align_corners, self.anti_aliasing, self.anti_aliasing_sigma, self.dtype ): d[key] = self.resizer( d[key], @@ -652,17 +689,18 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc align_corners=align_corners, anti_aliasing=anti_aliasing, anti_aliasing_sigma=anti_aliasing_sigma, + dtype=dtype, ) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.resizer.inverse(d[key]) return d -class Affined(MapTransform, InvertibleTransform): +class Affined(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Affine`. """ @@ -672,16 +710,17 @@ class Affined(MapTransform, InvertibleTransform): def __init__( self, keys: KeysCollection, - rotate_params: Optional[Union[Sequence[float], float]] = None, - shear_params: Optional[Union[Sequence[float], float]] = None, - translate_params: Optional[Union[Sequence[float], float]] = None, - scale_params: Optional[Union[Sequence[float], float]] = None, - affine: Optional[NdarrayOrTensor] = None, - spatial_size: Optional[Union[Sequence[int], int]] = None, + rotate_params: Sequence[float] | float | None = None, + shear_params: Sequence[float] | float | None = None, + translate_params: Sequence[float] | float | None = None, + scale_params: Sequence[float] | float | None = None, + affine: NdarrayOrTensor | None = None, + spatial_size: Sequence[int] | int | None = None, mode: SequenceStr = GridSampleMode.BILINEAR, padding_mode: SequenceStr = GridSamplePadMode.REFLECTION, - device: Optional[torch.device] = None, - dtype: Union[DtypeLike, torch.dtype] = np.float32, + device: torch.device | None = None, + dtype: DtypeLike | torch.dtype = np.float32, + align_corners: bool = False, allow_missing_keys: bool = False, ) -> None: """ @@ -730,6 +769,8 @@ def __init__( dtype: data type for resampling computation. Defaults to ``float32``. If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. + align_corners: Defaults to False. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html allow_missing_keys: don't raise exception if key is missing. See also: @@ -746,25 +787,31 @@ def __init__( affine=affine, spatial_size=spatial_size, device=device, - dtype=dtype, + dtype=dtype, # type: ignore + align_corners=align_corners, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self._lazy_evaluation = val + self.affine.lazy_evaluation = val + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key], _ = self.affine(d[key], mode=mode, padding_mode=padding_mode) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.affine.inverse(d[key]) return d -class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): +class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. """ @@ -774,16 +821,16 @@ class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): def __init__( self, keys: KeysCollection, - spatial_size: Optional[Union[Sequence[int], int]] = None, + spatial_size: Sequence[int] | int | None = None, prob: float = 0.1, - rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, - shear_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, - translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, - scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + rotate_range: Sequence[tuple[float, float] | float] | float | None = None, + shear_range: Sequence[tuple[float, float] | float] | float | None = None, + translate_range: Sequence[tuple[float, float] | float] | float | None = None, + scale_range: Sequence[tuple[float, float] | float] | float | None = None, mode: SequenceStr = GridSampleMode.BILINEAR, padding_mode: SequenceStr = GridSamplePadMode.REFLECTION, cache_grid: bool = False, - device: Optional[torch.device] = None, + device: torch.device | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -859,25 +906,29 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandAffined": + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool) -> None: + self._lazy_evaluation = val + self.rand_affine.lazy_evaluation = val + + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAffined: self.rand_affine.set_random_state(seed, state) super().set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): - out: Dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) + out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) return out self.randomize(None) # all the keys share the same random Affine factor self.rand_affine.randomize() - spatial_size = d[first_key].shape[1:] + item = d[first_key] + spatial_size = item.peek_pending_shape() if isinstance(item, MetaTensor) else item.shape[1:] sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform @@ -887,26 +938,27 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) if self._do_transform: # add some random factors - grid = self.rand_affine.rand_affine_grid(grid=grid) + grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): # do the transform if do_resampling: - d[key] = self.rand_affine(d[key], mode=mode, padding_mode=padding_mode, grid=grid) # type: ignore + d[key] = self.rand_affine(d[key], None, mode, padding_mode, True, grid) # type: ignore else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - if get_track_meta(): - xform = self.pop_transform(d[key], check=False) if do_resampling else {} - self.push_transform(d[key], extra_info={"do_resampling": do_resampling, "rand_affine_info": xform}) + self._do_transform = do_resampling # TODO: unify self._do_transform and do_resampling + self.push_transform(d[key], replace=True) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): tr = self.pop_transform(d[key]) - do_resampling = tr[TraceKeys.EXTRA_INFO]["do_resampling"] + if TraceKeys.EXTRA_INFO not in tr[TraceKeys.EXTRA_INFO]: + continue + do_resampling = tr[TraceKeys.EXTRA_INFO][TraceKeys.EXTRA_INFO]["do_resampling"] if do_resampling: - d[key].applied_operations.append(tr[TraceKeys.EXTRA_INFO]["rand_affine_info"]) # type: ignore + d[key].applied_operations.append(tr[TraceKeys.EXTRA_INFO]) # type: ignore d[key] = self.rand_affine.inverse(d[key]) # type: ignore return d @@ -922,17 +974,17 @@ class Rand2DElasticd(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - spacing: Union[Tuple[float, float], float], - magnitude_range: Tuple[float, float], - spatial_size: Optional[Union[Tuple[int, int], int]] = None, + spacing: tuple[float, float] | float, + magnitude_range: tuple[float, float], + spatial_size: tuple[int, int] | int | None = None, prob: float = 0.1, - rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, - shear_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, - translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, - scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + rotate_range: Sequence[tuple[float, float] | float] | float | None = None, + shear_range: Sequence[tuple[float, float] | float] | float | None = None, + translate_range: Sequence[tuple[float, float] | float] | float | None = None, + scale_range: Sequence[tuple[float, float] | float] | float | None = None, mode: SequenceStr = GridSampleMode.BILINEAR, padding_mode: SequenceStr = GridSamplePadMode.REFLECTION, - device: Optional[torch.device] = None, + device: torch.device | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -1008,19 +1060,17 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Rand2DElasticd": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Rand2DElasticd: self.rand_2d_elastic.set_random_state(seed, state) super().set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): - out: Dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) + out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) return out self.randomize(None) @@ -1028,6 +1078,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if device is None and isinstance(d[first_key], torch.Tensor): device = d[first_key].device # type: ignore self.rand_2d_elastic.set_device(device) + if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations: # type: ignore + warnings.warn(f"data['{first_key}'] has pending operations, transform may return incorrect results.") sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first_key].shape[1:]) # all the keys share the same random elastic factor @@ -1045,7 +1097,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N ) grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: - grid = create_grid(spatial_size=sp_size, device=device, backend="torch") + grid = cast(torch.Tensor, create_grid(spatial_size=sp_size, device=device, backend="torch")) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) # type: ignore @@ -1062,17 +1114,17 @@ class Rand3DElasticd(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - sigma_range: Tuple[float, float], - magnitude_range: Tuple[float, float], - spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, + sigma_range: tuple[float, float], + magnitude_range: tuple[float, float], + spatial_size: tuple[int, int, int] | int | None = None, prob: float = 0.1, - rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, - shear_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, - translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, - scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + rotate_range: Sequence[tuple[float, float] | float] | float | None = None, + shear_range: Sequence[tuple[float, float] | float] | float | None = None, + translate_range: Sequence[tuple[float, float] | float] | float | None = None, + scale_range: Sequence[tuple[float, float] | float] | float | None = None, mode: SequenceStr = GridSampleMode.BILINEAR, padding_mode: SequenceStr = GridSamplePadMode.REFLECTION, - device: Optional[torch.device] = None, + device: torch.device | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -1150,23 +1202,22 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Rand3DElasticd": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Rand3DElasticd: self.rand_3d_elastic.set_random_state(seed, state) super().set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): - out: Dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta()) + out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta()) return out self.randomize(None) - + if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations: # type: ignore + warnings.warn(f"data['{first_key}'] has pending operations, transform may return incorrect results.") sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first_key].shape[1:]) # all the keys share the same random elastic factor @@ -1188,7 +1239,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d -class Flipd(MapTransform, InvertibleTransform): +class Flipd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Flip`. @@ -1204,28 +1255,30 @@ class Flipd(MapTransform, InvertibleTransform): backend = Flip.backend def __init__( - self, - keys: KeysCollection, - spatial_axis: Optional[Union[Sequence[int], int]] = None, - allow_missing_keys: bool = False, + self, keys: KeysCollection, spatial_axis: Sequence[int] | int | None = None, allow_missing_keys: bool = False ) -> None: super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool): + self.flipper.lazy_evaluation = val + self._lazy_evaluation = val + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.flipper(d[key]) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.flipper.inverse(d[key]) return d -class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): +class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandFlip`. @@ -1245,20 +1298,23 @@ def __init__( self, keys: KeysCollection, prob: float = 0.1, - spatial_axis: Optional[Union[Sequence[int], int]] = None, + spatial_axis: Sequence[int] | int | None = None, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandFlipd": + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool): + self.flipper.lazy_evaluation = val + self._lazy_evaluation = val + + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandFlipd: super().set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) self.randomize(None) @@ -1267,12 +1323,10 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc d[key] = self.flipper(d[key]) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) - if get_track_meta(): - xform_info = self.pop_transform(d[key], check=False) if self._do_transform else {} - self.push_transform(d[key], extra_info=xform_info) + self.push_transform(d[key], replace=True) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): xform = self.pop_transform(d[key]) @@ -1283,7 +1337,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform): +class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandAxisFlip`. @@ -1304,14 +1358,17 @@ def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: RandomizableTransform.__init__(self, prob) self.flipper = RandAxisFlip(prob=1.0) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandAxisFlipd": + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool): + self.flipper.lazy_evaluation = val + self._lazy_evaluation = val + + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAxisFlipd: super().set_random_state(seed, state) self.flipper.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): @@ -1327,12 +1384,10 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc d[key] = self.flipper(d[key], randomize=False) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) - if get_track_meta(): - xform = self.pop_transform(d[key], check=False) if self._do_transform else {} - self.push_transform(d[key], extra_info=xform) + self.push_transform(d[key], replace=True) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): xform = self.pop_transform(d[key]) @@ -1342,7 +1397,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Rotated(MapTransform, InvertibleTransform): +class Rotated(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate`. @@ -1375,12 +1430,12 @@ class Rotated(MapTransform, InvertibleTransform): def __init__( self, keys: KeysCollection, - angle: Union[Sequence[float], float], + angle: Sequence[float] | float, keep_size: bool = True, mode: SequenceStr = GridSampleMode.BILINEAR, padding_mode: SequenceStr = GridSamplePadMode.BORDER, - align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32, + align_corners: Sequence[bool] | bool = False, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1391,7 +1446,12 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool): + self.rotator.lazy_evaluation = val + self._lazy_evaluation = val + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -1401,14 +1461,14 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc ) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.rotator.inverse(d[key]) return d -class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): +class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate` Randomly rotates the input arrays. @@ -1448,15 +1508,15 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): def __init__( self, keys: KeysCollection, - range_x: Union[Tuple[float, float], float] = 0.0, - range_y: Union[Tuple[float, float], float] = 0.0, - range_z: Union[Tuple[float, float], float] = 0.0, + range_x: tuple[float, float] | float = 0.0, + range_y: tuple[float, float] | float = 0.0, + range_z: tuple[float, float] | float = 0.0, prob: float = 0.1, keep_size: bool = True, mode: SequenceStr = GridSampleMode.BILINEAR, padding_mode: SequenceStr = GridSamplePadMode.BORDER, - align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32, + align_corners: Sequence[bool] | bool = False, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1467,14 +1527,17 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandRotated": + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool): + self.rand_rotate.lazy_evaluation = val + self._lazy_evaluation = val + + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandRotated: super().set_random_state(seed, state) self.rand_rotate.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) self.randomize(None) @@ -1494,12 +1557,10 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc ) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - if get_track_meta(): - rot_info = self.pop_transform(d[key], check=False) if self._do_transform else {} - self.push_transform(d[key], extra_info=rot_info) + self.push_transform(d[key], replace=True) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): xform = self.pop_transform(d[key]) @@ -1509,7 +1570,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Zoomd(MapTransform, InvertibleTransform): +class Zoomd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Zoom`. @@ -1533,6 +1594,8 @@ class Zoomd(MapTransform, InvertibleTransform): 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. + dtype: data type for resampling computation. Defaults to ``float32``. + If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. kwargs: other arguments for the `np.pad` or `torch.pad` function. @@ -1545,10 +1608,11 @@ class Zoomd(MapTransform, InvertibleTransform): def __init__( self, keys: KeysCollection, - zoom: Union[Sequence[float], float], + zoom: Sequence[float] | float, mode: SequenceStr = InterpolateMode.AREA, padding_mode: SequenceStr = NumpyPadMode.EDGE, - align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + align_corners: Sequence[bool | None] | bool | None = None, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, keep_size: bool = True, allow_missing_keys: bool = False, **kwargs, @@ -1557,24 +1621,30 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool): + self.zoomer.lazy_evaluation = val + self._lazy_evaluation = val + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) - for key, mode, padding_mode, align_corners in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype ): - d[key] = self.zoomer(d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners) + d[key] = self.zoomer(d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.zoomer.inverse(d[key]) return d -class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): +class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dict-based version :py:class:`monai.transforms.RandZoom`. @@ -1606,6 +1676,8 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. + dtype: data type for resampling computation. Defaults to ``float32``. + If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. @@ -1619,11 +1691,12 @@ def __init__( self, keys: KeysCollection, prob: float = 0.1, - min_zoom: Union[Sequence[float], float] = 0.9, - max_zoom: Union[Sequence[float], float] = 1.1, + min_zoom: Sequence[float] | float = 0.9, + max_zoom: Sequence[float] | float = 1.1, mode: SequenceStr = InterpolateMode.AREA, padding_mode: SequenceStr = NumpyPadMode.EDGE, - align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + align_corners: Sequence[bool | None] | bool | None = None, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, keep_size: bool = True, allow_missing_keys: bool = False, **kwargs, @@ -1634,19 +1707,23 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandZoomd": + @LazyTransform.lazy_evaluation.setter # type: ignore + def lazy_evaluation(self, val: bool): + self.rand_zoom.lazy_evaluation = val + self._lazy_evaluation = val + + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandZoomd: super().set_random_state(seed, state) self.rand_zoom.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): - out: Dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta()) + out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta()) return out self.randomize(None) @@ -1654,21 +1731,24 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc # all the keys share the same random zoom factor self.rand_zoom.randomize(d[first_key]) - for key, mode, padding_mode, align_corners in self.key_iterator( - d, self.mode, self.padding_mode, self.align_corners + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype ): if self._do_transform: d[key] = self.rand_zoom( - d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, randomize=False + d[key], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, + randomize=False, ) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - if get_track_meta(): - xform = self.pop_transform(d[key], check=False) if self._do_transform else {} - self.push_transform(d[key], extra_info=xform) + self.push_transform(d[key], replace=True) return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): xform = self.pop_transform(d[key]) @@ -1688,11 +1768,11 @@ class GridDistortiond(MapTransform): def __init__( self, keys: KeysCollection, - num_cells: Union[Tuple[int], int], - distort_steps: List[Tuple], + num_cells: tuple[int] | int, + distort_steps: list[tuple], mode: str = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.BORDER, - device: Optional[torch.device] = None, + device: torch.device | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -1725,7 +1805,7 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.grid_distortion(d[key], mode=mode, padding_mode=padding_mode) @@ -1742,12 +1822,12 @@ class RandGridDistortiond(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - num_cells: Union[Tuple[int], int] = 5, + num_cells: tuple[int] | int = 5, prob: float = 0.1, - distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), + distort_limit: tuple[float, float] | float = (-0.03, 0.03), mode: str = GridSampleMode.BILINEAR, padding_mode: str = GridSamplePadMode.BORDER, - device: Optional[torch.device] = None, + device: torch.device | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -1785,24 +1865,25 @@ def __init__( self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandGridDistortiond": + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandGridDistortiond: super().set_random_state(seed, state) self.rand_grid_distortion.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) self.randomize(None) if not self._do_transform: - out: Dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta()) + out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta()) return out first_key: Hashable = self.first_key(d) if first_key == (): out = convert_to_tensor(d, track_meta=get_track_meta()) return out - + if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations: # type: ignore + warnings.warn(f"data['{first_key}'] has pending operations, transform may return incorrect results.") self.rand_grid_distortion.randomize(d[first_key].shape[1:]) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): @@ -1810,7 +1891,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d -class GridSplitd(MapTransform): +class GridSplitd(MapTransform, MultiSampleTrait): """ Split the image into patches based on the provided grid in 2D. @@ -1831,8 +1912,8 @@ class GridSplitd(MapTransform): def __init__( self, keys: KeysCollection, - grid: Tuple[int, int] = (2, 2), - size: Optional[Union[int, Tuple[int, int], Dict[Hashable, Union[int, Tuple[int, int], None]]]] = None, + grid: tuple[int, int] = (2, 2), + size: int | tuple[int, int] | dict[Hashable, int | tuple[int, int] | None] | None = None, allow_missing_keys: bool = False, ): super().__init__(keys, allow_missing_keys) @@ -1840,10 +1921,10 @@ def __init__( self.size = size if isinstance(size, dict) else {key: size for key in self.keys} self.splitter = GridSplit(grid=grid) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> list[dict[Hashable, NdarrayOrTensor]]: d = dict(data) n_outputs = np.prod(self.grid) - output: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(n_outputs)] + output: list[dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(n_outputs)] for key in self.key_iterator(d): result = self.splitter(d[key], self.size[key]) for i in range(n_outputs): @@ -1851,7 +1932,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab return output -class GridPatchd(MapTransform): +class GridPatchd(MapTransform, MultiSampleTrait): """ Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps. It can sort the patches and return all or a subset of them. @@ -1861,24 +1942,32 @@ class GridPatchd(MapTransform): patch_size: size of patches to generate slices for, 0 or None selects whole dimension offset: starting position in the array, default is 0 for each dimension. np.random.randint(0, patch_size, 2) creates random start between 0 and `patch_size` for a 2D image. - num_patches: number of patches to return. Defaults to None, which returns all the available patches. + num_patches: number of patches (or maximum number of patches) to return. + If the requested number of patches is greater than the number of available patches, + padding will be applied to provide exactly `num_patches` patches unless `threshold` is set. + When `threshold` is set, this value is treated as the maximum number of patches. + Defaults to None, which does not limit number of the patches. overlap: amount of overlap between patches in each dimension. Default to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), lowest values (`"min"`), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. - pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. + pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. + Defaults to None, which means no padding will be applied. + Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html allow_missing_keys: don't raise exception if key is missing. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. Returns: - a list of dictionaries, each of which contains the all the original key/value with the values for `keys` - replaced by the patches. It also add the following new keys: + dictionary, contains the all the original key/value with the values for `keys` + replaced by the patches, a MetaTensor with following metadata: - "patch_location": the starting location of the patch in the image, - "patch_size": size of the extracted patch - "num_patches": total number of patches in the image - "offset": the amount of offset for the patches in the image (starting position of upper left patch) + - `PatchKeys.LOCATION`: the starting location of the patch in the image, + - `PatchKeys.COUNT`: total number of patches in the image, + - "spatial_shape": spatial size of the extracted patch, and + - "offset": the amount of offset for the patches in the image (starting position of the first patch) """ backend = GridPatch.backend @@ -1887,11 +1976,11 @@ def __init__( self, keys: KeysCollection, patch_size: Sequence[int], - offset: Optional[Sequence[int]] = None, - num_patches: Optional[int] = None, + offset: Sequence[int] | None = None, + num_patches: int | None = None, overlap: float = 0.0, - sort_fn: Optional[str] = None, - threshold: Optional[float] = None, + sort_fn: str | None = None, + threshold: float | None = None, pad_mode: str = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, **pad_kwargs, @@ -1908,14 +1997,14 @@ def __init__( **pad_kwargs, ) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.patcher(d[key]) return d -class RandGridPatchd(RandomizableTransform, MapTransform): +class RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait): """ Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps, and with random offset for the minimal corner of the image, (0,0) for 2D and (0,0,0) for 3D. @@ -1927,25 +2016,33 @@ class RandGridPatchd(RandomizableTransform, MapTransform): min_offset: the minimum range of starting position to be selected randomly. Defaults to 0. max_offset: the maximum range of starting position to be selected randomly. Defaults to image size modulo patch size. - num_patches: number of patches to return. Defaults to None, which returns all the available patches. + num_patches: number of patches (or maximum number of patches) to return. + If the requested number of patches is greater than the number of available patches, + padding will be applied to provide exactly `num_patches` patches unless `threshold` is set. + When `threshold` is set, this value is treated as the maximum number of patches. + Defaults to None, which does not limit number of the patches. overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), lowest values (`"min"`), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. - pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. + pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. + Defaults to None, which means no padding will be applied. + Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html allow_missing_keys: don't raise exception if key is missing. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. Returns: - a list of dictionaries, each of which contains the all the original key/value with the values for `keys` - replaced by the patches. It also add the following new keys: + dictionary, contains the all the original key/value with the values for `keys` + replaced by the patches, a MetaTensor with following metadata: - "patch_location": the starting location of the patch in the image, - "patch_size": size of the extracted patch - "num_patches": total number of patches in the image - "offset": the amount of offset for the patches in the image (starting position of the first patch) + - `PatchKeys.LOCATION`: the starting location of the patch in the image, + - `PatchKeys.COUNT`: total number of patches in the image, + - "spatial_shape": spatial size of the extracted patch, and + - "offset": the amount of offset for the patches in the image (starting position of the first patch) """ @@ -1955,12 +2052,12 @@ def __init__( self, keys: KeysCollection, patch_size: Sequence[int], - min_offset: Optional[Union[Sequence[int], int]] = None, - max_offset: Optional[Union[Sequence[int], int]] = None, - num_patches: Optional[int] = None, + min_offset: Sequence[int] | int | None = None, + max_offset: Sequence[int] | int | None = None, + num_patches: int | None = None, overlap: float = 0.0, - sort_fn: Optional[str] = None, - threshold: Optional[float] = None, + sort_fn: str | None = None, + threshold: float | None = None, pad_mode: str = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, **pad_kwargs, @@ -1978,14 +2075,12 @@ def __init__( **pad_kwargs, ) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandGridPatchd": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandGridPatchd: super().set_random_state(seed, state) self.patcher.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) # All the keys share the same random noise for key in self.key_iterator(d): diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py new file mode 100644 index 0000000000..96ec6be5d0 --- /dev/null +++ b/monai/transforms/spatial/functional.py @@ -0,0 +1,595 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of "functional" transforms for spatial operations +https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design +""" + +from __future__ import annotations + +import math +import warnings +from enum import Enum + +import numpy as np +import torch + +import monai +from monai.config.type_definitions import NdarrayOrTensor +from monai.data.meta_obj import get_track_meta +from monai.data.meta_tensor import MetaTensor +from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd +from monai.networks.layers import AffineTransform +from monai.transforms.croppad.array import ResizeWithPadOrCrop +from monai.transforms.intensity.array import GaussianSmooth +from monai.transforms.inverse import TraceableTransform +from monai.transforms.utils import create_rotate, create_translate, scale_affine +from monai.transforms.utils_pytorch_numpy_unification import allclose +from monai.utils import ( + TraceKeys, + convert_to_dst_type, + convert_to_numpy, + convert_to_tensor, + ensure_tuple, + ensure_tuple_rep, + fall_back_tuple, + optional_import, +) + +nib, has_nib = optional_import("nibabel") +cupy, _ = optional_import("cupy") +cupy_ndi, _ = optional_import("cupyx.scipy.ndimage") +np_ndi, _ = optional_import("scipy.ndimage") + +__all__ = ["spatial_resample", "orientation", "flip", "resize", "rotate", "zoom", "rotate90", "affine_func"] + + +def _maybe_new_metatensor(img, dtype=None, device=None): + """create a metatensor with fresh metadata if track_meta is True otherwise convert img into a torch tensor""" + return convert_to_tensor( + img.as_tensor() if isinstance(img, MetaTensor) else img, + dtype=dtype, + device=device, + track_meta=get_track_meta(), + wrap_sequence=True, + ) + + +def spatial_resample( + img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, transform_info +) -> torch.Tensor: + """ + Functional implementation of resampling the input image to the specified ``dst_affine`` matrix and ``spatial_size``. + This function operates eagerly or lazily according to + ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + + Args: + img: data to be resampled, assuming `img` is channel-first. + dst_affine: target affine matrix, if None, use the input affine matrix, effectively no resampling. + spatial_size: output spatial size, if the component is ``-1``, use the corresponding input spatial size. + mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers). + Interpolation mode to calculate output values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used + and the value represents the order of the spline interpolation. + See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `mode` is an integer, using numpy/cupy backends, this argument accepts + {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}. + See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype_pt: data `dtype` for resampling computation. + transform_info: a dictionary with the relevant information pertaining to an applied transform. + """ + original_spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + src_affine: torch.Tensor = img.peek_pending_affine() if isinstance(img, MetaTensor) else torch.eye(4) + img = convert_to_tensor(data=img, track_meta=get_track_meta()) + # ensure spatial rank is <= 3 + spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3) + if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None: + spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size + src_affine = to_affine_nd(spatial_rank, src_affine).to(torch.float64) + dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine + dst_affine = convert_to_dst_type(dst_affine, src_affine)[0] + if not isinstance(dst_affine, torch.Tensor): + raise ValueError(f"dst_affine should be a torch.Tensor, got {type(dst_affine)}") + + in_spatial_size = torch.tensor(original_spatial_shape[:spatial_rank]) + if isinstance(spatial_size, int) and (spatial_size == -1): # using the input spatial size + spatial_size = in_spatial_size + elif spatial_size is None and spatial_rank > 1: # auto spatial size + spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine) # type: ignore + spatial_size = torch.tensor( + fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size, lambda x: x >= 0) + ) + extra_info = { + "dtype": str(dtype_pt)[6:], # remove "torch": torch.float32 -> float32 + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + "src_affine": src_affine, + } + try: + _s = convert_to_numpy(src_affine) + _d = convert_to_numpy(dst_affine) + xform = np.eye(spatial_rank + 1) if spatial_rank < 2 else np.linalg.solve(_s, _d) + except (np.linalg.LinAlgError, RuntimeError) as e: + raise ValueError(f"src affine is not invertible {_s}, {_d}.") from e + xform = convert_to_tensor(to_affine_nd(spatial_rank, xform)).to(device=img.device, dtype=torch.float64) + affine_unchanged = ( + allclose(src_affine, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size) + ) or (allclose(xform, np.eye(len(xform)), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size)) + lazy_evaluation = transform_info.get(TraceKeys.LAZY_EVALUATION, False) + meta_info = TraceableTransform.track_transform_meta( + img, + sp_size=spatial_size, + affine=None if affine_unchanged and not lazy_evaluation else xform, + extra_info=extra_info, + orig_size=original_spatial_shape, + transform_info=transform_info, + lazy_evaluation=lazy_evaluation, + ) + if lazy_evaluation: + out = _maybe_new_metatensor(img) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore + if affine_unchanged: + # no significant change or lazy change, return original image + out = _maybe_new_metatensor(img, dtype=torch.float32) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore + # drop current meta first + img = img.as_tensor() if isinstance(img, MetaTensor) else img + im_size = list(img.shape) + chns, in_sp_size, additional_dims = im_size[0], im_size[1 : spatial_rank + 1], im_size[spatial_rank + 1 :] + + if additional_dims: + xform_shape = [-1] + in_sp_size + img = img.reshape(xform_shape) + img = img.to(dtype_pt) + if isinstance(mode, int): + dst_xform = create_translate(spatial_rank, [float(d - 1) / 2 for d in spatial_size]) + xform = xform @ convert_to_dst_type(dst_xform, xform)[0] + affine_xform = monai.transforms.Affine( + affine=xform, + spatial_size=spatial_size, + normalized=True, + image_only=True, + dtype=dtype_pt, + align_corners=align_corners, + ) + with affine_xform.trace_transform(False): + img = affine_xform(img, mode=mode, padding_mode=padding_mode) + else: + affine_xform = AffineTransform( # type: ignore + normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True + ) + img = affine_xform(img.unsqueeze(0), theta=xform.to(img), spatial_size=spatial_size).squeeze(0) # type: ignore + if additional_dims: + full_shape = (chns, *spatial_size, *additional_dims) + img = img.reshape(full_shape) + out = _maybe_new_metatensor(img, dtype=torch.float32) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore + + +def orientation(img, original_affine, spatial_ornt, transform_info): + """ + Functional implementation of changing the input image's orientation into the specified based on `spatial_ornt`. + This function operates eagerly or lazily according to + ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + + Args: + img: data to be changed, assuming `img` is channel-first. + original_affine: original affine of the input image. + spatial_ornt: orientations of the spatial axes, + see also https://nipy.org/nibabel/reference/nibabel.orientations.html + transform_info: a dictionary with the relevant information pertaining to an applied transform. + """ + spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + xform = nib.orientations.inv_ornt_aff(spatial_ornt, spatial_shape) + img = convert_to_tensor(img, track_meta=get_track_meta()) + + spatial_ornt[:, 0] += 1 # skip channel dim + spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) + axes = [ax for ax, flip in enumerate(spatial_ornt[:, 1]) if flip == -1] + full_transpose = np.arange(len(spatial_shape) + 1) # channel-first array + full_transpose[: len(spatial_ornt)] = np.argsort(spatial_ornt[:, 0]) + extra_info = {"original_affine": original_affine} + + shape_np = convert_to_numpy(spatial_shape, wrap_sequence=True) + shape_np = shape_np[[i - 1 for i in full_transpose if i > 0]] + meta_info = TraceableTransform.track_transform_meta( + img, + sp_size=shape_np, + affine=xform, + extra_info=extra_info, + orig_size=spatial_shape, + transform_info=transform_info, + lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + ) + out = _maybe_new_metatensor(img) + if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + if axes: + out = torch.flip(out, dims=axes) + if not np.all(full_transpose == np.arange(len(out.shape))): + out = out.permute(full_transpose.tolist()) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def flip(img, sp_axes, transform_info): + """ + Functional implementation of flip. + This function operates eagerly or lazily according to + ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + + Args: + img: data to be changed, assuming `img` is channel-first. + sp_axes: spatial axes along which to flip over. + If None, will flip over all of the axes of the input array. + If axis is negative it counts from the last to the first axis. + If axis is a tuple of ints, flipping is performed on all of the axes + specified in the tuple. + transform_info: a dictionary with the relevant information pertaining to an applied transform. + """ + sp_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + sp_size = convert_to_numpy(sp_size, wrap_sequence=True).tolist() + extra_info = {"axes": sp_axes} # track the spatial axes + axes = monai.transforms.utils.map_spatial_axes(img.ndim, sp_axes) # use the axes with channel dim + rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else torch.tensor(3.0, dtype=torch.double) + # axes include the channel dim + xform = torch.eye(int(rank) + 1, dtype=torch.double) + for axis in axes: + sp = axis - 1 + xform[sp, sp], xform[sp, -1] = xform[sp, sp] * -1, sp_size[sp] - 1 + meta_info = TraceableTransform.track_transform_meta( + img, + sp_size=sp_size, + affine=xform, + extra_info=extra_info, + transform_info=transform_info, + lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + ) + out = _maybe_new_metatensor(img) + if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + out = torch.flip(out, axes) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, transform_info): + """ + Functional implementation of resize. + This function operates eagerly or lazily according to + ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + + Args: + img: data to be changed, assuming `img` is channel-first. + out_size: expected shape of spatial dimensions after resize operation. + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, + ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. + dtype: data type for resampling computation. If None, use the data type of input data. + input_ndim: number of spatial dimensions. + anti_aliasing: whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + transform_info: a dictionary with the relevant information pertaining to an applied transform. + """ + img = convert_to_tensor(img, track_meta=get_track_meta()) + orig_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + extra_info = { + "mode": mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 + "new_dim": len(orig_size) - input_ndim, + } + meta_info = TraceableTransform.track_transform_meta( + img, + sp_size=out_size, + affine=scale_affine(orig_size, out_size), + extra_info=extra_info, + orig_size=orig_size, + transform_info=transform_info, + lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + ) + if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if anti_aliasing and transform_info.get(TraceKeys.LAZY_EVALUATION, False): + warnings.warn("anti-aliasing is not compatible with lazy evaluation.") + out = _maybe_new_metatensor(img) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + if tuple(convert_to_numpy(orig_size)) == out_size: + out = _maybe_new_metatensor(img, dtype=torch.float32) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + out = _maybe_new_metatensor(img) + img_ = convert_to_tensor(out, dtype=dtype, track_meta=False) # convert to a regular tensor + if anti_aliasing and any(x < y for x, y in zip(out_size, img_.shape[1:])): + factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(out_size)) + if anti_aliasing_sigma is None: + # if sigma is not given, use the default sigma in skimage.transform.resize + anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist() + else: + # if sigma is given, use the given value for downsampling axis + anti_aliasing_sigma = list(ensure_tuple_rep(anti_aliasing_sigma, len(out_size))) + for axis in range(len(out_size)): + anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1) + anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma) + img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False) + resized = torch.nn.functional.interpolate( + input=img_.unsqueeze(0), size=out_size, mode=mode, align_corners=align_corners + ) + out, *_ = convert_to_dst_type(resized.squeeze(0), out, dtype=torch.float32) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, transform_info): + """ + Functional implementation of rotate. + This function operates eagerly or lazily according to + ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + + Args: + img: data to be changed, assuming `img` is channel-first. + angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D. + output_shape: output shape of the rotated data. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``float32``. + transform_info: a dictionary with the relevant information pertaining to an applied transform. + + """ + + im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + input_ndim = len(im_shape) + if input_ndim not in (2, 3): + raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") + _angle = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) + transform = create_rotate(input_ndim, _angle) + if output_shape is None: + corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape((len(im_shape), -1)) + corners = transform[:-1, :-1] @ corners # type: ignore + output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) + else: + output_shape = np.asarray(output_shape, dtype=int) + shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist()) + shift_1 = create_translate(input_ndim, (-(np.asarray(output_shape, dtype=int) - 1) / 2).tolist()) + transform = shift @ transform @ shift_1 + extra_info = { + "rot_mat": transform, + "mode": mode, + "padding_mode": padding_mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 + } + meta_info = TraceableTransform.track_transform_meta( + img, + sp_size=output_shape, + affine=transform, + extra_info=extra_info, + orig_size=im_shape, + transform_info=transform_info, + lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + ) + out = _maybe_new_metatensor(img) + if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + xform = AffineTransform( + normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True + ) + img_t = out.to(dtype) + transform_t, *_ = convert_to_dst_type(transform, img_t) + output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=tuple(int(i) for i in output_shape)) + output = output.float().squeeze(0) + out, *_ = convert_to_dst_type(output, dst=out, dtype=torch.float32) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, transform_info): + """ + Functional implementation of zoom. + This function operates eagerly or lazily according to + ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + + Args: + img: data to be changed, assuming `img` is channel-first. + scale_factor: The zoom factor along the spatial axes. + If a float, zoom is the same for each spatial axis. + If a sequence, zoom should contain one value for each spatial axis. + keep_size: Whether keep original size (padding/slicing if needed). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``float32``. + transform_info: a dictionary with the relevant information pertaining to an applied transform. + + """ + im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + output_size = [ + int(math.floor(float(i) * z)) + for i, z in zip(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:], scale_factor) + ] + xform = scale_affine(im_shape, output_size) + extra_info = { + "mode": mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 + "do_padcrop": False, + "padcrop": {}, + } + if keep_size: + if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + raise NotImplementedError("keep_size=True is not supported for lazy evaluation.") + output_size = [int(i) for i in img.shape[1:]] + meta_info = TraceableTransform.track_transform_meta( + img, + sp_size=output_size, + affine=xform, + extra_info=extra_info, + orig_size=im_shape, + transform_info=transform_info, + lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + ) + out = _maybe_new_metatensor(img) + if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + img_t = out.to(dtype) + zoomed: NdarrayOrTensor = torch.nn.functional.interpolate( + recompute_scale_factor=True, + input=img_t.unsqueeze(0), + scale_factor=list(scale_factor), + mode=mode, + align_corners=align_corners, + ).squeeze(0) + out, *_ = convert_to_dst_type(zoomed, dst=out, dtype=torch.float32) + if isinstance(out, MetaTensor): + out = out.copy_meta_from(meta_info) + do_pad_crop = not np.allclose(output_size, zoomed.shape[1:]) + if do_pad_crop: + _pad_crop = ResizeWithPadOrCrop(spatial_size=img_t.shape[1:], mode=padding_mode) + out = _pad_crop(out) + if get_track_meta() and do_pad_crop: + padcrop_xform = out.applied_operations.pop() + out.applied_operations[-1]["extra_info"]["do_padcrop"] = True + out.applied_operations[-1]["extra_info"]["padcrop"] = padcrop_xform + return out + + +def rotate90(img, axes, k, transform_info): + """ + Functional implementation of rotate90. + This function operates eagerly or lazily according to + ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + + Args: + img: data to be changed, assuming `img` is channel-first. + axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. + If axis is negative it counts from the last to the first axis. + k: number of times to rotate by 90 degrees. + transform_info: a dictionary with the relevant information pertaining to an applied transform. + """ + extra_info = {"axes": [d - 1 for d in axes], "k": k} + ori_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + sp_shape = list(ori_shape) + if k in (1, 3): + a_0, a_1 = axes[0] - 1, axes[1] - 1 + sp_shape[a_0], sp_shape[a_1] = ori_shape[a_1], ori_shape[a_0] + rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else torch.tensor(3.0, dtype=torch.double) + r, sp_r = int(rank), len(ori_shape) + xform = to_affine_nd(r, create_translate(sp_r, [-float(d - 1) / 2 for d in sp_shape])) + s = -1.0 if int(axes[0]) - int(axes[1]) in (-1, 2) else 1.0 + if sp_r == 2: + rot90 = to_affine_nd(r, create_rotate(sp_r, [s * np.pi / 2])) + else: + idx = {1, 2, 3} - set(axes) + angle: list[float] = [0, 0, 0] + angle[idx.pop() - 1] = s * np.pi / 2 + rot90 = to_affine_nd(r, create_rotate(sp_r, angle)) + for _ in range(k): + xform = rot90 @ xform + xform = to_affine_nd(r, create_translate(sp_r, [float(d - 1) / 2 for d in ori_shape])) @ xform + meta_info = TraceableTransform.track_transform_meta( + img, + sp_size=sp_shape, + affine=xform, + extra_info=extra_info, + orig_size=ori_shape, + transform_info=transform_info, + lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + ) + out = _maybe_new_metatensor(img) + if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + out = torch.rot90(out, k, axes) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def affine_func(img, affine, grid, resampler, sp_size, mode, padding_mode, do_resampling, image_only, transform_info): + """ + Functional implementation of affine. + This function operates eagerly or lazily according to + ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + + Args: + img: data to be changed, assuming `img` is channel-first. + affine: the affine transformation to be applied, it can be a 3x3 or 4x4 matrix. This should be defined + for the voxel space spatial centers (``float(size - 1)/2``). + grid: used in non-lazy mode to pre-compute the grid to do the resampling. + resampler: the resampler function, see also: :py:class:`monai.transforms.Resample`. + sp_size: output image spatial size. + mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers). + Interpolation mode to calculate output values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used + and the value represents the order of the spline interpolation. + See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `mode` is an integer, using numpy/cupy backends, this argument accepts + {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}. + See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html + do_resampling: whether to do the resampling, this is a flag for the use case of updating metadata but + skipping the actual (potentially heavy) resampling operation. + image_only: if True return only the image volume, otherwise return (image, affine). + transform_info: a dictionary with the relevant information pertaining to an applied transform. + + """ + + # resampler should carry the align_corners and type info + img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else torch.tensor(3.0, dtype=torch.double) + extra_info = { + "affine": affine, + "mode": mode, + "padding_mode": padding_mode, + "do_resampling": do_resampling, + "align_corners": resampler.align_corners, + } + affine = monai.transforms.Affine.compute_w_affine(rank, affine, img_size, sp_size) + meta_info = TraceableTransform.track_transform_meta( + img, + sp_size=sp_size, + affine=affine, + extra_info=extra_info, + orig_size=img_size, + transform_info=transform_info, + lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + ) + if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + out = _maybe_new_metatensor(img) + out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + return out if image_only else (out, affine) + if do_resampling: + out = resampler(img=img, grid=grid, mode=mode, padding_mode=padding_mode) + out = _maybe_new_metatensor(out) + else: + out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device) + out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + return out if image_only else (out, affine) diff --git a/monai/transforms/traits.py b/monai/transforms/traits.py new file mode 100644 index 0000000000..0193065562 --- /dev/null +++ b/monai/transforms/traits.py @@ -0,0 +1,80 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of generic traits for MONAI transforms. +""" + +from __future__ import annotations + +__all__ = ["LazyTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"] + + +class LazyTrait: + """ + An interface to indicate that the transform has the capability to execute using + MONAI's lazy resampling feature. In order to do this, the implementing class needs + to be able to describe its operation as an affine matrix or grid with accompanying metadata. + This interface can be extended from by people adapting transforms to the MONAI framework as + well as by implementors of MONAI transforms. + """ + + @property + def lazy_evaluation(self): + """ + Get whether lazy_evaluation is enabled for this transform instance. + Returns: + True if the transform is operating in a lazy fashion, False if not. + """ + raise NotImplementedError() + + @lazy_evaluation.setter + def lazy_evaluation(self, enabled: bool): + """ + Set whether lazy_evaluation is enabled for this transform instance. + Args: + enabled: True if the transform should operate in a lazy fashion, False if not. + """ + raise NotImplementedError() + + +class RandomizableTrait: + """ + An interface to indicate that the transform has the capability to perform + randomized transforms to the data that it is called upon. This interface + can be extended from by people adapting transforms to the MONAI framework as well as by + implementors of MONAI transforms. + """ + + pass + + +class MultiSampleTrait: + """ + An interface to indicate that the transform has the capability to return multiple samples + given an input, such as when performing random crops of a sample. This interface can be + extended from by people adapting transforms to the MONAI framework as well as by implementors + of MONAI transforms. + """ + + pass + + +class ThreadUnsafe: + """ + A class to denote that the transform will mutate its member variables, + when being applied. Transforms inheriting this class should be used + cautiously in a multi-thread context. + + This type is typically used by :py:class:`monai.data.CacheDataset` and + its extensions, where the transform cache is built with multiple threads. + """ + + pass diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index c5ab29133b..3e66431bbc 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -12,9 +12,12 @@ A collection of generic interfaces for MONAI transforms. """ +from __future__ import annotations + import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union +from collections.abc import Callable, Generator, Hashable, Iterable, Mapping +from typing import Any, TypeVar import numpy as np import torch @@ -22,6 +25,7 @@ from monai import config, transforms from monai.config import KeysCollection from monai.data.meta_tensor import MetaTensor +from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe from monai.utils import MAX_SEED, ensure_tuple, first from monai.utils.enums import TransformBackends from monai.utils.misc import MONAIEnvVars @@ -29,9 +33,6 @@ __all__ = [ "ThreadUnsafe", "apply_transform", - "LazyTrait", - "RandomizableTrait", - "MultiSampleTrait", "Randomizable", "LazyTransform", "RandomizableTransform", @@ -72,7 +73,7 @@ def apply_transform( map_items: bool = True, unpack_items: bool = False, log_stats: bool = False, -) -> Union[List[ReturnType], ReturnType]: +) -> list[ReturnType] | ReturnType: """ Transform `data` with `transform`. @@ -113,7 +114,7 @@ def apply_transform( if isinstance(data, (list, tuple)): data = data[0] - def _log_stats(data, prefix: Optional[str] = "Data"): + def _log_stats(data, prefix: str | None = "Data"): if isinstance(data, (np.ndarray, torch.Tensor)): # log data type, shape, range for array datastats(img=data, data_shape=True, value_range=True, prefix=prefix) @@ -129,69 +130,6 @@ def _log_stats(data, prefix: Optional[str] = "Data"): raise RuntimeError(f"applying transform {transform}") from e -class LazyTrait: - """ - An interface to indicate that the transform has the capability to execute using - MONAI's lazy resampling feature. In order to do this, the implementing class needs - to be able to describe its operation as an affine matrix or grid with accompanying metadata. - This interface can be extended from by people adapting transforms to the MONAI framework as - well as by implementors of MONAI transforms. - """ - - @property - def lazy_evaluation(self): - """ - Get whether lazy_evaluation is enabled for this transform instance. - Returns: - True if the transform is operating in a lazy fashion, False if not. - """ - raise NotImplementedError() - - @lazy_evaluation.setter - def lazy_evaluation(self, enabled: bool): - """ - Set whether lazy_evaluation is enabled for this transform instance. - Args: - enabled: True if the transform should operate in a lazy fashion, False if not. - """ - raise NotImplementedError() - - -class RandomizableTrait: - """ - An interface to indicate that the transform has the capability to perform - randomized transforms to the data that it is called upon. This interface - can be extended from by people adapting transforms to the MONAI framework as well as by - implementors of MONAI transforms. - """ - - pass - - -class MultiSampleTrait: - """ - An interface to indicate that the transform has the capability to return multiple samples - given an input, such as when performing random crops of a sample. This interface can be - extended from by people adapting transforms to the MONAI framework as well as by implementors - of MONAI transforms. - """ - - pass - - -class ThreadUnsafe: - """ - A class to denote that the transform will mutate its member variables, - when being applied. Transforms inheriting this class should be used - cautiously in a multi-thread context. - - This type is typically used by :py:class:`monai.data.CacheDataset` and - its extensions, where the transform cache is built with multiple threads. - """ - - pass - - class Randomizable(ThreadUnsafe, RandomizableTrait): """ An interface for handling random state locally, currently based on a class @@ -206,9 +144,7 @@ class Randomizable(ThreadUnsafe, RandomizableTrait): R: np.random.RandomState = np.random.RandomState() - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Randomizable": + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable: """ Set the random state locally, to control the randomness, the derived classes should use :py:attr:`self.R` instead of `np.random` to introduce random @@ -283,7 +219,7 @@ class Transform(ABC): # to other data types during the transformation. Note that not all `dtype` (such as float32, uint8) are supported # by all the data types, the `dtype` during the conversion is determined automatically by each transform, # please refer to the transform's docstring. - backend: List[TransformBackends] = [] + backend: list[TransformBackends] = [] @abstractmethod def __call__(self, data: Any): @@ -318,18 +254,17 @@ class LazyTransform(Transform, LazyTrait): dictionary transforms to simplify implementation of new lazy transforms. """ - def __init__(self, lazy_evaluation: Optional[bool] = True): - self.lazy_evaluation = lazy_evaluation + _lazy_evaluation: bool = False @property def lazy_evaluation(self): - return self.lazy_evaluation + return self._lazy_evaluation @lazy_evaluation.setter def lazy_evaluation(self, lazy_evaluation: bool): if not isinstance(lazy_evaluation, bool): - raise TypeError("'lazy_evaluation must be a bool but is of " f"type {type(lazy_evaluation)}'") - self.lazy_evaluation = lazy_evaluation + raise TypeError(f"lazy_evaluation must be a bool but is of type {type(lazy_evaluation)}") + self._lazy_evaluation = lazy_evaluation class RandomizableTransform(Randomizable, Transform): @@ -412,7 +347,7 @@ def __new__(cls, *args, **kwargs): return Transform.__new__(cls) def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: - self.keys: Tuple[Hashable, ...] = ensure_tuple(keys) + self.keys: tuple[Hashable, ...] = ensure_tuple(keys) self.allow_missing_keys = allow_missing_keys if not self.keys: raise ValueError("keys must be non empty.") @@ -469,7 +404,7 @@ def __call__(self, data): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Optional[Iterable]) -> Generator: + def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. @@ -482,7 +417,7 @@ def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Optional[ ex_iters = extra_iterables or [[None] * len(self.keys)] # loop over keys and any extra iterables - _ex_iters: List[Any] + _ex_iters: list[Any] for key, *_ex_iters in zip(self.keys, *ex_iters): # all normal, yield (what we yield depends on whether extra iterables were given) if key in data: @@ -493,7 +428,7 @@ def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Optional[ " and allow_missing_keys==False." ) - def first_key(self, data: Dict[Hashable, Any]): + def first_key(self, data: dict[Hashable, Any]): """ Get the first available key of `self.keys` in the input `data` dictionary. If no available key, return an empty tuple `()`. diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 27314ab91c..3cbaab1430 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -13,22 +13,38 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +from __future__ import annotations + import logging import sys import time import warnings +from collections.abc import Mapping, Sequence from copy import deepcopy -from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from functools import partial +from typing import Any, Callable import numpy as np import torch +import torch.nn as nn from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor -from monai.data.utils import no_collation +from monai.data.utils import is_no_channel, no_collation +from monai.networks.layers.simplelayers import ( + ApplyFilter, + EllipticalFilter, + GaussianFilter, + LaplaceFilter, + MeanFilter, + SavitzkyGolayFilter, + SharpenFilter, + median_filter, +) from monai.transforms.inverse import InvertibleTransform +from monai.transforms.traits import MultiSampleTrait from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform from monai.transforms.utils import ( extreme_points_to_image, @@ -38,13 +54,13 @@ ) from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices from monai.utils import ( + MetaKeys, TraceKeys, convert_data_type, convert_to_cupy, convert_to_numpy, convert_to_tensor, deprecated, - deprecated_arg, ensure_tuple, look_up_option, min_version, @@ -60,6 +76,7 @@ __all__ = [ "Identity", + "RandIdentity", "AsChannelFirst", "AsChannelLast", "AddChannel", @@ -92,6 +109,8 @@ "CuCIM", "RandCuCIM", "ToCupy", + "ImageFilter", + "RandImageFilter", ] @@ -111,6 +130,18 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: return img +class RandIdentity(RandomizableTrait): + """ + Do nothing to the data. This transform is random, so can be used to stop the caching of any + subsequent transforms. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, data: Any) -> Any: + return data + + @deprecated(since="0.8", msg_suffix="please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.") class AsChannelFirst(Transform): """ @@ -132,7 +163,7 @@ class AsChannelFirst(Transform): def __init__(self, channel_dim: int = -1) -> None: if not (isinstance(channel_dim, int) and channel_dim >= -1): - raise AssertionError("invalid channel dimension.") + raise ValueError(f"invalid channel dimension ({channel_dim}).") self.channel_dim = channel_dim def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: @@ -162,7 +193,7 @@ class AsChannelLast(Transform): def __init__(self, channel_dim: int = 0) -> None: if not (isinstance(channel_dim, int) and channel_dim >= -1): - raise AssertionError("invalid channel dimension.") + raise ValueError(f"invalid channel dimension ({channel_dim}).") self.channel_dim = channel_dim def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: @@ -216,11 +247,11 @@ class EnsureChannelFirst(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, strict_check: bool = True, channel_dim: Union[None, str, int] = None): + def __init__(self, strict_check: bool = True, channel_dim: None | str | int = None): self.strict_check = strict_check self.input_channel_dim = channel_dim - def __call__(self, img: torch.Tensor, meta_dict: Optional[Mapping] = None) -> torch.Tensor: + def __call__(self, img: torch.Tensor, meta_dict: Mapping | None = None) -> torch.Tensor: """ Apply the transform to `img`. """ @@ -237,9 +268,9 @@ def __call__(self, img: torch.Tensor, meta_dict: Optional[Mapping] = None) -> to if isinstance(img, MetaTensor): meta_dict = img.meta - channel_dim = meta_dict.get("original_channel_dim", None) if isinstance(meta_dict, Mapping) else None + channel_dim = meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None) if isinstance(meta_dict, Mapping) else None if self.input_channel_dim is not None: - channel_dim = self.input_channel_dim + channel_dim = float("nan") if self.input_channel_dim == "no_channel" else self.input_channel_dim if channel_dim is None: msg = "Unknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`." @@ -250,12 +281,12 @@ def __call__(self, img: torch.Tensor, meta_dict: Optional[Mapping] = None) -> to # track the original channel dim if isinstance(meta_dict, dict): - meta_dict["original_channel_dim"] = channel_dim + meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = channel_dim - if channel_dim == "no_channel": + if is_no_channel(channel_dim): result = img[None] else: - result = moveaxis(img, channel_dim, 0) # type: ignore + result = moveaxis(img, int(channel_dim), 0) # type: ignore return convert_to_tensor(result, track_meta=get_track_meta()) # type: ignore @@ -274,7 +305,7 @@ class RepeatChannel(Transform): def __init__(self, repeats: int) -> None: if repeats <= 0: - raise AssertionError("repeats count must be greater than 0.") + raise ValueError(f"repeats count must be greater than 0, got {repeats}.") self.repeats = repeats def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: @@ -299,7 +330,7 @@ class RemoveRepeatedChannel(Transform): def __init__(self, repeats: int) -> None: if repeats <= 0: - raise AssertionError("repeats count must be greater than 0.") + raise ValueError(f"repeats count must be greater than 0, got {repeats}.") self.repeats = repeats @@ -308,13 +339,13 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: Apply the transform to `img`, assuming `img` is a "channel-first" array. """ if img.shape[0] < 2: - raise AssertionError("Image must have more than one channel") + raise ValueError(f"Image must have more than one channel, got {img.shape[0]} channels.") out: NdarrayOrTensor = convert_to_tensor(img[:: self.repeats, :], track_meta=get_track_meta()) return out -class SplitDim(Transform): +class SplitDim(Transform, MultiSampleTrait): """ Given an image of size X along a certain dimension, return a list of length X containing images. Useful for converting 3D images into a stack of 2D images, splitting multichannel inputs into @@ -336,13 +367,11 @@ def __init__(self, dim: int = -1, keepdim: bool = True, update_meta=True) -> Non self.keepdim = keepdim self.update_meta = update_meta - def __call__(self, img: torch.Tensor) -> List[torch.Tensor]: + def __call__(self, img: torch.Tensor) -> list[torch.Tensor]: """ Apply the transform to `img`. """ n_out = img.shape[self.dim] - if n_out <= 1: - raise RuntimeError(f"Input image is singleton along dimension to be split, got shape {img.shape}.") if isinstance(img, torch.Tensor): outputs = list(torch.split(img, 1, self.dim)) else: @@ -394,7 +423,7 @@ def __init__(self, dtype=np.float32) -> None: """ self.dtype = dtype - def __call__(self, img: NdarrayOrTensor, dtype: Union[DtypeLike, torch.dtype] = None) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor. @@ -429,10 +458,10 @@ class ToTensor(Transform): def __init__( self, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, wrap_sequence: bool = True, - track_meta: Optional[bool] = None, + track_meta: bool | None = None, ) -> None: super().__init__() self.dtype = dtype @@ -474,10 +503,10 @@ class EnsureType(Transform): def __init__( self, data_type: str = "tensor", - dtype: Optional[Union[DtypeLike, torch.dtype]] = None, - device: Optional[torch.device] = None, + dtype: DtypeLike | torch.dtype | None = None, + device: torch.device | None = None, wrap_sequence: bool = True, - track_meta: Optional[bool] = None, + track_meta: bool | None = None, ) -> None: self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"}) self.dtype = dtype @@ -549,7 +578,7 @@ class ToCupy(Transform): backend = [TransformBackends.CUPY] - def __init__(self, dtype: Optional[np.dtype] = None, wrap_sequence: bool = True) -> None: + def __init__(self, dtype: np.dtype | None = None, wrap_sequence: bool = True) -> None: super().__init__() self.dtype = dtype self.wrap_sequence = wrap_sequence @@ -586,7 +615,7 @@ class Transpose(Transform): backend = [TransformBackends.TORCH] - def __init__(self, indices: Optional[Sequence[int]]) -> None: + def __init__(self, indices: Sequence[int] | None) -> None: self.indices = None if indices is None else tuple(indices) def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: @@ -604,7 +633,7 @@ class SqueezeDim(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, dim: Optional[int] = 0, update_meta=True) -> None: + def __init__(self, dim: int | None = 0, update_meta=True) -> None: """ Args: dim: dimension to be squeezed. Default = 0 @@ -670,7 +699,7 @@ def __init__( data_shape: bool = True, value_range: bool = True, data_value: bool = False, - additional_info: Optional[Callable] = None, + additional_info: Callable | None = None, name: str = "DataStats", ) -> None: """ @@ -689,7 +718,7 @@ def __init__( """ if not isinstance(prefix, str): - raise AssertionError("prefix must be a string.") + raise ValueError(f"prefix must be a string, got {type(prefix)}.") self.prefix = prefix self.data_type = data_type self.data_shape = data_shape @@ -716,12 +745,12 @@ def __init__( def __call__( self, img: NdarrayOrTensor, - prefix: Optional[str] = None, - data_type: Optional[bool] = None, - data_shape: Optional[bool] = None, - value_range: Optional[bool] = None, - data_value: Optional[bool] = None, - additional_info: Optional[Callable] = None, + prefix: str | None = None, + data_type: bool | None = None, + data_shape: bool | None = None, + value_range: bool | None = None, + data_value: bool | None = None, + additional_info: Callable | None = None, ) -> NdarrayOrTensor: """ Apply the transform to `img`, optionally take arguments similar to the class constructor. @@ -773,7 +802,7 @@ def __init__(self, delay_time: float = 0.0) -> None: super().__init__() self.delay_time: float = delay_time - def __call__(self, img: NdarrayOrTensor, delay_time: Optional[float] = None) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, delay_time: float | None = None) -> NdarrayOrTensor: """ Args: img: data remain unchanged throughout this transform. @@ -809,13 +838,13 @@ class Lambda(InvertibleTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, func: Optional[Callable] = None, inv_func: Callable = no_collation) -> None: + def __init__(self, func: Callable | None = None, inv_func: Callable = no_collation) -> None: if func is not None and not callable(func): raise TypeError(f"func must be None or callable but is {type(func).__name__}.") self.func = func self.inv_func = inv_func - def __call__(self, img: NdarrayOrTensor, func: Optional[Callable] = None): + def __call__(self, img: NdarrayOrTensor, func: Callable | None = None): """ Apply `self.func` to `img`. @@ -858,11 +887,11 @@ class RandLambda(Lambda, RandomizableTransform): backend = Lambda.backend - def __init__(self, func: Optional[Callable] = None, prob: float = 1.0, inv_func: Callable = no_collation) -> None: + def __init__(self, func: Callable | None = None, prob: float = 1.0, inv_func: Callable = no_collation) -> None: Lambda.__init__(self=self, func=func, inv_func=inv_func) RandomizableTransform.__init__(self=self, prob=prob) - def __call__(self, img: NdarrayOrTensor, func: Optional[Callable] = None): + def __call__(self, img: NdarrayOrTensor, func: Callable | None = None): self.randomize(img) out = deepcopy(super().__call__(img, func) if self._do_transform else img) # convert to MetaTensor if necessary @@ -904,16 +933,13 @@ class LabelToMask(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( # pytype: disable=annotation-type-mismatch - self, select_labels: Union[Sequence[int], int], merge_channels: bool = False + self, select_labels: Sequence[int] | int, merge_channels: bool = False ) -> None: # pytype: disable=annotation-type-mismatch self.select_labels = ensure_tuple(select_labels) self.merge_channels = merge_channels def __call__( - self, - img: NdarrayOrTensor, - select_labels: Optional[Union[Sequence[int], int]] = None, - merge_channels: bool = False, + self, img: NdarrayOrTensor, select_labels: Sequence[int] | int | None = None, merge_channels: bool = False ) -> NdarrayOrTensor: """ Args: @@ -950,7 +976,7 @@ def __call__( return data -class FgBgToIndices(Transform): +class FgBgToIndices(Transform, MultiSampleTrait): """ Compute foreground and background of the input label data, return the indices. If no output_shape specified, output data will be 1 dim indices after flattening. @@ -967,16 +993,13 @@ class FgBgToIndices(Transform): backend = [TransformBackends.NUMPY, TransformBackends.TORCH] - def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None: + def __init__(self, image_threshold: float = 0.0, output_shape: Sequence[int] | None = None) -> None: self.image_threshold = image_threshold self.output_shape = output_shape def __call__( - self, - label: NdarrayOrTensor, - image: Optional[NdarrayOrTensor] = None, - output_shape: Optional[Sequence[int]] = None, - ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: + self, label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, output_shape: Sequence[int] | None = None + ) -> tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Args: label: input data to compute foreground and background indices. @@ -994,15 +1017,11 @@ def __call__( return fg_indices, bg_indices -class ClassesToIndices(Transform): - +class ClassesToIndices(Transform, MultiSampleTrait): backend = [TransformBackends.NUMPY, TransformBackends.TORCH] def __init__( - self, - num_classes: Optional[int] = None, - image_threshold: float = 0.0, - output_shape: Optional[Sequence[int]] = None, + self, num_classes: int | None = None, image_threshold: float = 0.0, output_shape: Sequence[int] | None = None ) -> None: """ Compute indices of every class of the input label data, return a list of indices. @@ -1023,11 +1042,8 @@ def __init__( self.output_shape = output_shape def __call__( - self, - label: NdarrayOrTensor, - image: Optional[NdarrayOrTensor] = None, - output_shape: Optional[Sequence[int]] = None, - ) -> List[NdarrayOrTensor]: + self, label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, output_shape: Sequence[int] | None = None + ) -> list[NdarrayOrTensor]: """ Args: label: input data to compute the indices of every class. @@ -1039,7 +1055,7 @@ def __call__( if output_shape is None: output_shape = self.output_shape - indices: List[NdarrayOrTensor] + indices: list[NdarrayOrTensor] indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) if output_shape is not None: indices = [unravel_indices(cls_indices, output_shape) for cls_indices in indices] @@ -1095,7 +1111,7 @@ class AddExtremePointsChannel(Randomizable, Transform): def __init__(self, background: int = 0, pert: float = 0.0) -> None: self._background = background self._pert = pert - self._points: List[Tuple[int, ...]] = [] + self._points: list[tuple[int, ...]] = [] def randomize(self, label: NdarrayOrTensor) -> None: self._points = get_extreme_points(label, rand_state=self.R, background=self._background, pert=self._pert) @@ -1103,8 +1119,8 @@ def randomize(self, label: NdarrayOrTensor) -> None: def __call__( self, img: NdarrayOrTensor, - label: Optional[NdarrayOrTensor] = None, - sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0, + label: NdarrayOrTensor | None = None, + sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor = 3.0, rescale_min: float = -1.0, rescale_max: float = 1.0, ) -> NdarrayOrTensor: @@ -1238,14 +1254,14 @@ class IntensityStats(Transform): backend = [TransformBackends.NUMPY] - def __init__(self, ops: Sequence[Union[str, Callable]], key_prefix: str, channel_wise: bool = False) -> None: + def __init__(self, ops: Sequence[str | Callable], key_prefix: str, channel_wise: bool = False) -> None: self.ops = ensure_tuple(ops) self.key_prefix = key_prefix self.channel_wise = channel_wise def __call__( - self, img: NdarrayOrTensor, meta_data: Optional[Dict] = None, mask: Optional[np.ndarray] = None - ) -> Tuple[NdarrayOrTensor, Dict]: + self, img: NdarrayOrTensor, meta_data: dict | None = None, mask: np.ndarray | None = None + ) -> tuple[NdarrayOrTensor, dict]: """ Compute statistics for the intensity of input image. @@ -1309,7 +1325,7 @@ class ToDevice(Transform): backend = [TransformBackends.TORCH] - def __init__(self, device: Union[torch.device, str], **kwargs) -> None: + def __init__(self, device: torch.device | str, **kwargs) -> None: """ Args: device: target device to move the Tensor, for example: "cuda:1". @@ -1403,9 +1419,6 @@ class AddCoordinateChannels(Transform): backend = [TransformBackends.NUMPY] - @deprecated_arg( - name="spatial_channels", new_name="spatial_dims", since="0.8", msg_suffix="please use `spatial_dims` instead." - ) def __init__(self, spatial_dims: Sequence[int]) -> None: self.spatial_dims = spatial_dims @@ -1422,3 +1435,282 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: coord_channels, *_ = convert_to_dst_type(coord_channels, img) # type: ignore coord_channels = coord_channels[list(self.spatial_dims)] return concatenate((img, coord_channels), axis=0) + + +class ImageFilter(Transform): + """ + Applies a convolution filter to the input image. + + Args: + filter: + A string specifying the filter, a custom filter as ``torch.Tenor`` or ``np.ndarray`` or a ``nn.Module``. + Available options for string are: ``mean``, ``laplace``, ``elliptical``, ``sobel``, ``sharpen``, ``median``, ``gauss`` + See below for short explanations on every filter. + filter_size: + A single integer value specifying the size of the quadratic or cubic filter. + Computational complexity scales to the power of 2 (2D filter) or 3 (3D filter), which + should be considered when choosing filter size. + kwargs: + Additional arguments passed to filter function, required by ``sobel`` and ``gauss``. + See below for details. + + Raises: + ValueError: When ``filter_size`` is not an uneven integer + ValueError: When ``filter`` is an array and ``ndim`` is not in [1,2,3] + ValueError: When ``filter`` is an array and any dimension has an even shape + NotImplementedError: When ``filter`` is a string and not in ``self.supported_filters`` + KeyError: When necessary ``kwargs`` are not passed to a filter that requires additional arguments. + + + **Mean Filtering:** ``filter='mean'`` + + Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image. + See also py:func:`monai.networks.layers.simplelayers.MeanFilter` + Example 2D filter (5 x 5):: + + [[1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1]] + + If smoothing labels with this filter, ensure they are in one-hot format. + + **Outline Detection:** ``filter='laplace'`` + + Laplacian filtering for outline detection in images. Can be used to transform labels to contours. + See also py:func:`monai.networks.layers.simplelayers.LaplaceFilter` + + Example 2D filter (5x5):: + + [[-1., -1., -1., -1., -1.], + [-1., -1., -1., -1., -1.], + [-1., -1., 24., -1., -1.], + [-1., -1., -1., -1., -1.], + [-1., -1., -1., -1., -1.]] + + + **Dilation:** ``filter='elliptical'`` + + An elliptical filter can be used to dilate labels or label-contours. + Example 2D filter (5x5):: + + [[0., 0., 1., 0., 0.], + [1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [0., 0., 1., 0., 0.]] + + + **Edge Detection:** ``filter='sobel'`` + + This filter allows for additional arguments passed as ``kwargs`` during initialization. + See also py:func:`monai.transforms.post.SobelGradients` + + *kwargs* + + * ``spatial_axes``: the axes that define the direction of the gradient to be calculated. + It calculates the gradient along each of the provide axis. + By default it calculate the gradient for all spatial axes. + * ``normalize_kernels``: if normalize the Sobel kernel to provide proper gradients. Defaults to True. + * ``normalize_gradients``: if normalize the output gradient to 0 and 1. Defaults to False. + * ``padding_mode``: the padding mode of the image when convolving with Sobel kernels. Defaults to ``"reflect"``. + Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + See ``torch.nn.Conv1d()`` for more information. + * ``dtype``: kernel data type (torch.dtype). Defaults to ``torch.float32``. + + + **Sharpening:** ``filter='sharpen'`` + + Sharpen an image with a 2D or 3D filter. + Example 2D filter (5x5):: + + [[ 0., 0., -1., 0., 0.], + [-1., -1., -1., -1., -1.], + [-1., -1., 17., -1., -1.], + [-1., -1., -1., -1., -1.], + [ 0., 0., -1., 0., 0.]] + + + **Gaussian Smooth:** ``filter='gauss'`` + + Blur/smooth an image with 2D or 3D gaussian filter. + This filter requires additional arguments passed as ``kwargs`` during initialization. + See also py:func:`monai.networks.layers.simplelayers.GaussianFilter` + + *kwargs* + + * ``sigma``: std. could be a single value, or spatial_dims number of values. + * ``truncated``: spreads how many stds. + * ``approx``: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". + + + **Median Filter:** ``filter='median'`` + + Blur an image with 2D or 3D median filter to remove noise. + Useful in image preprocessing to improve results of later processing. + See also py:func:`monai.networks.layers.simplelayers.MedianFilter` + + + **Savitzky Golay Filter:** ``filter = 'savitzky_golay'`` + + Convolve a Tensor along a particular axis with a Savitzky-Golay kernel. + This filter requires additional arguments passed as ``kwargs`` during initialization. + See also py:func:`monai.networks.layers.simplelayers.SavitzkyGolayFilter` + + *kwargs* + + * ``order``: Order of the polynomial to fit to each window, must be less than ``window_length``. + * ``axis``: (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension). + * ``mode``: (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or + ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information. + + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + supported_filters = sorted( + ["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"] + ) + + def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None, **kwargs) -> None: + self._check_filter_format(filter, filter_size) + self._check_kwargs_are_present(filter, **kwargs) + self.filter = filter + self.filter_size = filter_size + self.additional_args_for_filter = kwargs + + def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> NdarrayOrTensor: + """ + Args: + img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] + meta_dict: An optional dictionary with metadata + + Returns: + A MetaTensor with the same shape as `img` and identical metadata + """ + if isinstance(img, MetaTensor): + meta_dict = img.meta + img_, prev_type, device = convert_data_type(img, torch.Tensor) + ndim = img_.ndim - 1 # assumes channel first format + + if isinstance(self.filter, str): + self.filter = self._get_filter_from_string(self.filter, self.filter_size, ndim) # type: ignore + elif isinstance(self.filter, (torch.Tensor, np.ndarray)): + self.filter = ApplyFilter(self.filter) + + img_ = self._apply_filter(img_) + if meta_dict: + img_ = MetaTensor(img_, meta=meta_dict) + else: + img_, *_ = convert_data_type(img_, prev_type, device) + return img_ + + def _check_all_values_uneven(self, x: tuple) -> None: + for value in x: + if value % 2 == 0: + raise ValueError(f"Only uneven filters are supported, but filter size is {x}") + + def _check_filter_format(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None) -> None: + if isinstance(filter, str): + if not filter_size: + raise ValueError("`filter_size` must be specified when specifying filters by string.") + if filter_size % 2 == 0: + raise ValueError("`filter_size` should be a single uneven integer.") + if filter not in self.supported_filters: + raise NotImplementedError(f"{filter}. Supported filters are {self.supported_filters}.") + elif isinstance(filter, torch.Tensor) or isinstance(filter, np.ndarray): + if filter.ndim not in [1, 2, 3]: + raise ValueError("Only 1D, 2D, and 3D filters are supported.") + self._check_all_values_uneven(filter.shape) + elif isinstance(filter, (nn.Module, Transform)): + pass + else: + raise TypeError( + f"{type(filter)} is not supported." + "Supported types are `class 'str'`, `class 'torch.Tensor'`, `class 'np.ndarray'`, " + "`class 'torch.nn.modules.module.Module'`, `class 'monai.transforms.Transform'`" + ) + + def _check_kwargs_are_present(self, filter, **kwargs): + if filter == "gauss" and "sigma" not in kwargs.keys(): + raise KeyError("`filter='gauss', requires the additonal keyword argument `sigma`") + if filter == "savitzky_golay" and "order" not in kwargs.keys(): + raise KeyError("`filter='savitzky_golay', requires the additonal keyword argument `order`") + + def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> nn.Module | Callable: + if filter == "mean": + return MeanFilter(ndim, size) + elif filter == "laplace": + return LaplaceFilter(ndim, size) + elif filter == "elliptical": + return EllipticalFilter(ndim, size) + elif filter == "sobel": + from monai.transforms.post.array import SobelGradients # cannot import on top because of circular imports + + allowed_keys = SobelGradients.__init__.__annotations__.keys() + kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys} + return SobelGradients(size, **kwargs) + elif filter == "sharpen": + return SharpenFilter(ndim, size) + elif filter == "gauss": + allowed_keys = GaussianFilter.__init__.__annotations__.keys() + kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys} + return GaussianFilter(ndim, **kwargs) + elif filter == "median": + return partial(median_filter, kernel_size=size, spatial_dims=ndim) + elif filter == "savitzky_golay": + allowed_keys = SavitzkyGolayFilter.__init__.__annotations__.keys() + kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys} + return SavitzkyGolayFilter(size, **kwargs) + else: + raise NotImplementedError(f"Filter {filter} not implemented") + + def _apply_filter(self, img: torch.Tensor) -> torch.Tensor: + if isinstance(self.filter, Transform): + img = self.filter(img) + else: + img = self.filter(img.unsqueeze(0)) # type: ignore + img = img[0] # add and remove batch dim + return img + + +class RandImageFilter(RandomizableTransform): + """ + Randomly apply a convolutional filter to the input data. + + Args: + filter: + A string specifying the filter or a custom filter as `torch.Tenor` or `np.ndarray`. + Available options are: `mean`, `laplace`, `elliptical`, `gaussian`` + See below for short explanations on every filter. + filter_size: + A single integer value specifying the size of the quadratic or cubic filter. + Computational complexity scales to the power of 2 (2D filter) or 3 (3D filter), which + should be considered when choosing filter size. + prob: + Probability the transform is applied to the data + """ + + backend = ImageFilter.backend + + def __init__( + self, filter: str | NdarrayOrTensor, filter_size: int | None = None, prob: float = 0.1, **kwargs + ) -> None: + super().__init__(prob) + self.filter = ImageFilter(filter, filter_size, **kwargs) + + def __call__(self, img: NdarrayOrTensor, meta_dict: Mapping | None = None) -> NdarrayOrTensor: + """ + Args: + img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] + meta_dict: An optional dictionary with metadata + kwargs: optional arguments required by specific filters. E.g. `sigma`if filter is `gauss`. + see py:func:`monai.transforms.utility.array.ImageFilter` for more details + + Returns: + A MetaTensor with the same shape as `img` and identical metadata + """ + self.randomize(None) + if self._do_transform: + img = self.filter(img) + return img diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index d52fdbe251..344f75ddc8 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -15,9 +15,12 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from __future__ import annotations + import re +from collections.abc import Callable, Hashable, Mapping from copy import deepcopy -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union, cast +from typing import Any, Sequence, cast import numpy as np import torch @@ -27,7 +30,8 @@ from monai.data.meta_tensor import MetaObj, MetaTensor from monai.data.utils import no_collation from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform, Randomizable, RandomizableTrait, RandomizableTransform +from monai.transforms.traits import MultiSampleTrait, RandomizableTrait +from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utility.array import ( AddChannel, AddCoordinateChannels, @@ -43,6 +47,7 @@ EnsureType, FgBgToIndices, Identity, + ImageFilter, IntensityStats, LabelToMask, Lambda, @@ -118,6 +123,7 @@ "IntensityStatsd", "IntensityStatsD", "IntensityStatsDict", + "ImageFilterd", "LabelToMaskD", "LabelToMaskDict", "LabelToMaskd", @@ -133,6 +139,7 @@ "RandCuCIMd", "RandCuCIMD", "RandCuCIMDict", + "RandImageFilterd", "RandLambdaD", "RandLambdaDict", "RandLambdad", @@ -207,7 +214,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.identity = Identity() - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.identity(d[key]) @@ -232,7 +239,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_ke super().__init__(keys, allow_missing_keys) self.converter = AsChannelFirst(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -257,7 +264,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_key super().__init__(keys, allow_missing_keys) self.converter = AsChannelLast(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -281,7 +288,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.adder = AddChannel() - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.adder(d[key]) @@ -300,7 +307,7 @@ class EnsureChannelFirstd(MapTransform): def __init__( self, keys: KeysCollection, - meta_keys: Optional[KeysCollection] = None, + meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, strict_check: bool = True, allow_missing_keys: bool = False, @@ -322,7 +329,7 @@ def __init__( self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): d[key] = self.adjuster(d[key], d.get(meta_key or f"{key}_{meta_key_postfix}")) # type: ignore @@ -347,7 +354,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RepeatChannel(repeats) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) @@ -372,21 +379,20 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RemoveRepeatedChannel(repeats) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) return d -class SplitDimd(MapTransform): - +class SplitDimd(MapTransform, MultiSampleTrait): backend = SplitDim.backend def __init__( self, keys: KeysCollection, - output_postfixes: Optional[Sequence[str]] = None, + output_postfixes: Sequence[str] | None = None, dim: int = 0, keepdim: bool = True, update_meta: bool = True, @@ -418,7 +424,7 @@ def __init__( def __call__( self, data: Mapping[Hashable, torch.Tensor] - ) -> Union[Dict[Hashable, torch.Tensor], List[Dict[Hashable, torch.Tensor]]]: + ) -> dict[Hashable, torch.Tensor] | list[dict[Hashable, torch.Tensor]]: d = dict(data) all_keys = list(set(self.key_iterator(d))) @@ -456,7 +462,7 @@ class SplitChanneld(SplitDimd): def __init__( self, keys: KeysCollection, - output_postfixes: Optional[Sequence[str]] = None, + output_postfixes: Sequence[str] | None = None, channel_dim: int = 0, allow_missing_keys: bool = False, ) -> None: @@ -479,7 +485,7 @@ class CastToTyped(MapTransform): def __init__( self, keys: KeysCollection, - dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32, + dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -496,7 +502,7 @@ def __init__( self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.converter = CastToType() - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, dtype in self.key_iterator(d, self.dtype): d[key] = self.converter(d[key], dtype=dtype) @@ -514,10 +520,10 @@ class ToTensord(MapTransform, InvertibleTransform): def __init__( self, keys: KeysCollection, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, wrap_sequence: bool = True, - track_meta: Optional[bool] = None, + track_meta: bool | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -536,14 +542,14 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = ToTensor(dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) self.push_transform(d, key) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): # Remove the applied transform @@ -574,10 +580,10 @@ def __init__( self, keys: KeysCollection, data_type: str = "tensor", - dtype: Union[DtypeLike, torch.dtype] = None, - device: Optional[torch.device] = None, + dtype: DtypeLike | torch.dtype = None, + device: torch.device | None = None, wrap_sequence: bool = True, - track_meta: Optional[bool] = None, + track_meta: bool | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -598,7 +604,7 @@ def __init__( data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta ) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -631,7 +637,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = ToNumpy(dtype=dtype, wrap_sequence=wrap_sequence) - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -658,14 +664,14 @@ class ToCupyd(MapTransform): def __init__( self, keys: KeysCollection, - dtype: Optional[np.dtype] = None, + dtype: np.dtype | None = None, wrap_sequence: bool = True, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.converter = ToCupy(dtype=dtype, wrap_sequence=wrap_sequence) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -689,7 +695,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToPIL() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -703,13 +709,11 @@ class Transposed(MapTransform, InvertibleTransform): backend = Transpose.backend - def __init__( - self, keys: KeysCollection, indices: Optional[Sequence[int]], allow_missing_keys: bool = False - ) -> None: + def __init__(self, keys: KeysCollection, indices: Sequence[int] | None, allow_missing_keys: bool = False) -> None: super().__init__(keys, allow_missing_keys) self.transform = Transpose(indices) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.transform(d[key]) @@ -718,7 +722,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N self.push_transform(d, key, extra_info={"indices": indices}) return d - def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def inverse(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: d = dict(data) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -741,7 +745,7 @@ class DeleteItemsd(MapTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, keys: KeysCollection, sep: str = ".", use_re: Union[Sequence[bool], bool] = False) -> None: + def __init__(self, keys: KeysCollection, sep: str = ".", use_re: Sequence[bool] | bool = False) -> None: """ Args: keys: keys of the corresponding items to delete, can be "A{sep}B{sep}C" @@ -800,9 +804,9 @@ class FlattenSubKeysd(MapTransform): def __init__( self, keys: KeysCollection, - sub_keys: Optional[KeysCollection] = None, + sub_keys: KeysCollection | None = None, delete_keys: bool = True, - prefix: Optional[str] = None, + prefix: str | None = None, ) -> None: super().__init__(keys) self.sub_keys = sub_keys @@ -852,7 +856,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = SqueezeDim(dim=dim, update_meta=update_meta) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -869,12 +873,12 @@ class DataStatsd(MapTransform): def __init__( self, keys: KeysCollection, - prefix: Union[Sequence[str], str] = "Data", - data_type: Union[Sequence[bool], bool] = True, - data_shape: Union[Sequence[bool], bool] = True, - value_range: Union[Sequence[bool], bool] = True, - data_value: Union[Sequence[bool], bool] = False, - additional_info: Optional[Union[Sequence[Callable], Callable]] = None, + prefix: Sequence[str] | str = "Data", + data_type: Sequence[bool] | bool = True, + data_shape: Sequence[bool] | bool = True, + value_range: Sequence[bool] | bool = True, + data_value: Sequence[bool] | bool = False, + additional_info: Sequence[Callable] | Callable | None = None, name: str = "DataStats", allow_missing_keys: bool = False, ) -> None: @@ -909,7 +913,7 @@ def __init__( self.additional_info = ensure_tuple_rep(additional_info, len(self.keys)) self.printer = DataStats(name=name) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info @@ -926,7 +930,7 @@ class SimulateDelayd(MapTransform): backend = SimulateDelay.backend def __init__( - self, keys: KeysCollection, delay_time: Union[Sequence[float], float] = 0.0, allow_missing_keys: bool = False + self, keys: KeysCollection, delay_time: Sequence[float] | float = 0.0, allow_missing_keys: bool = False ) -> None: """ Args: @@ -941,7 +945,7 @@ def __init__( self.delay_time = ensure_tuple_rep(delay_time, len(self.keys)) self.delayer = SimulateDelay() - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, delay_time in self.key_iterator(d, self.delay_time): d[key] = self.delayer(d[key], delay_time=delay_time) @@ -960,7 +964,7 @@ def __init__( self, keys: KeysCollection, times: int = 1, - names: Optional[KeysCollection] = None, + names: KeysCollection | None = None, allow_missing_keys: bool = False, ) -> None: """ @@ -992,7 +996,7 @@ def __init__( ) self.names = names - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: """ Raises: KeyError: When a key in ``self.names`` already exists in ``data``. @@ -1031,7 +1035,7 @@ def __init__(self, keys: KeysCollection, name: str, dim: int = 0, allow_missing_ self.name = name self.dim = dim - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: """ Raises: TypeError: When items in ``data`` differ in type. @@ -1100,9 +1104,9 @@ class Lambdad(MapTransform, InvertibleTransform): def __init__( self, keys: KeysCollection, - func: Union[Sequence[Callable], Callable], - inv_func: Union[Sequence[Callable], Callable] = no_collation, - overwrite: Union[Sequence[bool], bool, Sequence[str], str] = True, + func: Sequence[Callable] | Callable, + inv_func: Sequence[Callable] | Callable = no_collation, + overwrite: Sequence[bool] | bool | Sequence[str] | str = True, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1111,7 +1115,7 @@ def __init__( self.overwrite = ensure_tuple_rep(overwrite, len(self.keys)) self._lambd = Lambda() - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite): ret = self._lambd(img=d[key], func=func) @@ -1159,9 +1163,9 @@ class RandLambdad(Lambdad, RandomizableTransform): def __init__( self, keys: KeysCollection, - func: Union[Sequence[Callable], Callable], - inv_func: Union[Sequence[Callable], Callable] = no_collation, - overwrite: Union[Sequence[bool], bool] = True, + func: Sequence[Callable] | Callable, + inv_func: Sequence[Callable] | Callable = no_collation, + overwrite: Sequence[bool] | bool = True, prob: float = 1.0, allow_missing_keys: bool = False, ) -> None: @@ -1191,7 +1195,7 @@ def __call__(self, data): d[key] = ret return d - def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key, overwrite in self.key_iterator(d, self.overwrite): if isinstance(d[key], MetaTensor): @@ -1225,14 +1229,14 @@ class LabelToMaskd(MapTransform): def __init__( # pytype: disable=annotation-type-mismatch self, keys: KeysCollection, - select_labels: Union[Sequence[int], int], + select_labels: Sequence[int] | int, merge_channels: bool = False, allow_missing_keys: bool = False, ) -> None: # pytype: disable=annotation-type-mismatch super().__init__(keys, allow_missing_keys) self.converter = LabelToMask(select_labels=select_labels, merge_channels=merge_channels) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1240,7 +1244,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d -class FgBgToIndicesd(MapTransform): +class FgBgToIndicesd(MapTransform, MultiSampleTrait): """ Dictionary-based wrapper of :py:class:`monai.transforms.FgBgToIndices`. @@ -1267,9 +1271,9 @@ def __init__( keys: KeysCollection, fg_postfix: str = "_fg_indices", bg_postfix: str = "_bg_indices", - image_key: Optional[str] = None, + image_key: str | None = None, image_threshold: float = 0.0, - output_shape: Optional[Sequence[int]] = None, + output_shape: Sequence[int] | None = None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1278,7 +1282,7 @@ def __init__( self.image_key = image_key self.converter = FgBgToIndices(image_threshold, output_shape) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) image = d[self.image_key] if self.image_key else None for key in self.key_iterator(d): @@ -1287,7 +1291,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d -class ClassesToIndicesd(MapTransform): +class ClassesToIndicesd(MapTransform, MultiSampleTrait): """ Dictionary-based wrapper of :py:class:`monai.transforms.ClassesToIndices`. @@ -1312,10 +1316,10 @@ def __init__( self, keys: KeysCollection, indices_postfix: str = "_cls_indices", - num_classes: Optional[int] = None, - image_key: Optional[str] = None, + num_classes: int | None = None, + image_key: str | None = None, image_threshold: float = 0.0, - output_shape: Optional[Sequence[int]] = None, + output_shape: Sequence[int] | None = None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1349,7 +1353,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) self.converter = ConvertToMultiChannelBasedOnBratsClasses() - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1383,7 +1387,7 @@ def __init__( label_key: str, background: int = 0, pert: float = 0.0, - sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0, + sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor = 3.0, rescale_min: float = -1.0, rescale_max: float = 1.0, allow_missing_keys: bool = False, @@ -1391,7 +1395,7 @@ def __init__( MapTransform.__init__(self, keys, allow_missing_keys) self.background = background self.pert = pert - self.points: List[Tuple[int, ...]] = [] + self.points: list[tuple[int, ...]] = [] self.label_key = label_key self.sigma = sigma self.rescale_min = rescale_min @@ -1400,7 +1404,7 @@ def __init__( def randomize(self, label: NdarrayOrTensor) -> None: self.points = get_extreme_points(label, rand_state=self.R, background=self.background, pert=self.pert) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) label = d[self.label_key] if label.shape[0] != 1: @@ -1450,7 +1454,7 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F self.name = name self.trans = TorchVision(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.trans(d[key]) @@ -1487,7 +1491,7 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F self.name = name self.trans = TorchVision(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.trans(d[key]) @@ -1522,7 +1526,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.mapper(d[key]) @@ -1569,11 +1573,11 @@ class IntensityStatsd(MapTransform): def __init__( self, keys: KeysCollection, - ops: Sequence[Union[str, Callable]], + ops: Sequence[str | Callable], key_prefix: str, - mask_keys: Optional[KeysCollection] = None, + mask_keys: KeysCollection | None = None, channel_wise: bool = False, - meta_keys: Optional[KeysCollection] = None, + meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, ) -> None: @@ -1585,7 +1589,7 @@ def __init__( raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mask_key, meta_key, meta_key_postfix in self.key_iterator( d, self.mask_keys, self.meta_keys, self.meta_key_postfix @@ -1605,7 +1609,7 @@ class ToDeviced(MapTransform): backend = ToDevice.backend def __init__( - self, keys: KeysCollection, device: Union[torch.device, str], allow_missing_keys: bool = False, **kwargs + self, keys: KeysCollection, device: torch.device | str, allow_missing_keys: bool = False, **kwargs ) -> None: """ Args: @@ -1619,7 +1623,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = ToDevice(device=device, **kwargs) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1724,20 +1728,101 @@ class AddCoordinateChannelsd(MapTransform): backend = AddCoordinateChannels.backend - @deprecated_arg( - name="spatial_channels", new_name="spatial_dims", since="0.8", msg_suffix="please use `spatial_dims` instead." - ) def __init__(self, keys: KeysCollection, spatial_dims: Sequence[int], allow_missing_keys: bool = False) -> None: super().__init__(keys, allow_missing_keys) self.add_coordinate_channels = AddCoordinateChannels(spatial_dims=spatial_dims) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.add_coordinate_channels(d[key]) return d +class ImageFilterd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ImageFilter`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + kernel: + A string specifying the kernel or a custom kernel as `torch.Tenor` or `np.ndarray`. + Available options are: `mean`, `laplacian`, `elliptical`, `sobel_{w,h,d}`` + kernel_size: + A single integer value specifying the size of the quadratic or cubic kernel. + Computational complexity increases exponentially with kernel_size, which + should be considered when choosing the kernel size. + allow_missing_keys: + Don't raise exception if key is missing. + """ + + backend = ImageFilter.backend + + def __init__( + self, + keys: KeysCollection, + kernel: str | NdarrayOrTensor, + kernel_size: int | None = None, + allow_missing_keys: bool = False, + **kwargs, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.filter = ImageFilter(kernel, kernel_size, **kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.filter(d[key]) + return d + + +class RandImageFilterd(MapTransform, RandomizableTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandomFilterKernel`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + kernel: + A string specifying the kernel or a custom kernel as `torch.Tenor` or `np.ndarray`. + Available options are: `mean`, `laplacian`, `elliptical`, `sobel_{w,h,d}`` + kernel_size: + A single integer value specifying the size of the quadratic or cubic kernel. + Computational complexity increases exponentially with kernel_size, which + should be considered when choosing the kernel size. + prob: + Probability the transform is applied to the data + allow_missing_keys: + Don't raise exception if key is missing. + """ + + backend = ImageFilter.backend + + def __init__( + self, + keys: KeysCollection, + kernel: str | NdarrayOrTensor, + kernel_size: int | None = None, + prob: float = 0.1, + allow_missing_keys: bool = False, + **kwargs, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.filter = ImageFilter(kernel, kernel_size, **kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if self._do_transform: + for key in self.key_iterator(d): + d[key] = self.filter(d[key]) + return d + + +RandImageFilterD = RandImageFilterDict = RandImageFilterd +ImageFilterD = ImageFilterDict = ImageFilterd IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 8c7f475972..86f57d41aa 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -9,13 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import itertools import random import warnings +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence from contextlib import contextmanager from functools import wraps from inspect import getmembers, isclass -from typing import Any, Callable, Hashable, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union +from typing import Any import numpy as np import torch @@ -64,7 +67,6 @@ ndimage, _ = optional_import("scipy.ndimage") cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") -cucim, has_cucim = optional_import("cucim") exposure, has_skimage = optional_import("skimage.exposure") __all__ = [ @@ -139,7 +141,7 @@ def in_bounds(x: float, y: float, margin: float, maxx: float, maxy: float) -> bo return bool(margin <= x < (maxx - margin) and margin <= y < (maxy - margin)) -def is_empty(img: Union[np.ndarray, torch.Tensor]) -> bool: +def is_empty(img: np.ndarray | torch.Tensor) -> bool: """ Returns True if `img` is empty, that is its maximum value is not greater than its minimum. """ @@ -165,9 +167,9 @@ def zero_margins(img: np.ndarray, margin: int) -> bool: def rescale_array( arr: NdarrayOrTensor, - minv: Optional[float] = 0.0, - maxv: Optional[float] = 1.0, - dtype: Union[DtypeLike, torch.dtype] = np.float32, + minv: float | None = 0.0, + maxv: float | None = 1.0, + dtype: DtypeLike | torch.dtype = np.float32, ) -> NdarrayOrTensor: """ Rescale the values of numpy array `arr` to be from `minv` to `maxv`. @@ -195,7 +197,7 @@ def rescale_array( def rescale_instance_array( - arr: np.ndarray, minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, dtype: DtypeLike = np.float32 + arr: np.ndarray, minv: float | None = 0.0, maxv: float | None = 1.0, dtype: DtypeLike = np.float32 ) -> np.ndarray: """ Rescale each array slice along the first dimension of `arr` independently. @@ -216,8 +218,8 @@ def rescale_array_int_max(arr: np.ndarray, dtype: DtypeLike = np.uint16) -> np.n def copypaste_arrays( - src_shape, dest_shape, srccenter: Sequence[int], destcenter: Sequence[int], dims: Sequence[Optional[int]] -) -> Tuple[Tuple[slice, ...], Tuple[slice, ...]]: + src_shape, dest_shape, srccenter: Sequence[int], destcenter: Sequence[int], dims: Sequence[int | None] +) -> tuple[tuple[slice, ...], tuple[slice, ...]]: """ Calculate the slices to copy a sliced area of array in `src_shape` into array in `dest_shape`. @@ -271,7 +273,7 @@ def copypaste_arrays( return tuple(srcslices), tuple(destslices) -def resize_center(img: np.ndarray, *resize_dims: Optional[int], fill_value: float = 0.0, inplace: bool = True): +def resize_center(img: np.ndarray, *resize_dims: int | None, fill_value: float = 0.0, inplace: bool = True): """ Resize `img` by cropping or expanding the image from the center. The `resize_dims` values are the output dimensions (or None to use original dimension of `img`). If a dimension is smaller than that of `img` then the result will be @@ -293,8 +295,8 @@ def resize_center(img: np.ndarray, *resize_dims: Optional[int], fill_value: floa def map_binary_to_indices( - label: NdarrayOrTensor, image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0 -) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: + label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, image_threshold: float = 0.0 +) -> tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Compute the foreground and background of input label data, return the indices after fattening. For example: @@ -329,10 +331,10 @@ def map_binary_to_indices( def map_classes_to_indices( label: NdarrayOrTensor, - num_classes: Optional[int] = None, - image: Optional[NdarrayOrTensor] = None, + num_classes: int | None = None, + image: NdarrayOrTensor | None = None, image_threshold: float = 0.0, -) -> List[NdarrayOrTensor]: +) -> list[NdarrayOrTensor]: """ Filter out indices of every class of the input label data, return the indices after fattening. It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for @@ -352,11 +354,11 @@ def map_classes_to_indices( determine the valid image content area and select class indices only in this area. """ - img_flat: Optional[NdarrayOrTensor] = None + img_flat: NdarrayOrTensor | None = None if image is not None: img_flat = ravel((image > image_threshold).any(0)) - indices: List[NdarrayOrTensor] = [] + indices: list[NdarrayOrTensor] = [] # assuming the first dimension is channel channels = len(label) @@ -377,11 +379,11 @@ def map_classes_to_indices( def weighted_patch_samples( - spatial_size: Union[int, Sequence[int]], + spatial_size: int | Sequence[int], w: NdarrayOrTensor, n_samples: int = 1, - r_state: Optional[np.random.RandomState] = None, -) -> List: + r_state: np.random.RandomState | None = None, +) -> list: """ Computes `n_samples` of random patch sampling locations, given the sampling weight map `w` and patch `spatial_size`. @@ -398,7 +400,7 @@ def weighted_patch_samples( """ if w is None: - raise ValueError("w must be an ND array.") + raise ValueError("w must be an ND array, got None.") if r_state is None: r_state = np.random.RandomState() img_size = np.asarray(w.shape, dtype=int) @@ -407,7 +409,7 @@ def weighted_patch_samples( s = tuple(slice(w // 2, m - w + w // 2) if m > w else slice(m // 2, m // 2 + 1) for w, m in zip(win_size, img_size)) v = w[s] # weight map in the 'valid' mode v_size = v.shape - v = ravel(v) + v = ravel(v) # always copy if (v < 0).any(): v -= v.min() # shifting to non-negative v = cumsum(v) @@ -424,11 +426,11 @@ def weighted_patch_samples( def correct_crop_centers( - centers: List[int], - spatial_size: Union[Sequence[int], int], + centers: list[int], + spatial_size: Sequence[int] | int, label_spatial_shape: Sequence[int], allow_smaller: bool = False, -): +) -> tuple[Any]: """ Utility to correct the crop center if the crop size and centers are not compatible with the image size. @@ -444,7 +446,10 @@ def correct_crop_centers( spatial_size = fall_back_tuple(spatial_size, default=label_spatial_shape) if any(np.subtract(label_spatial_shape, spatial_size) < 0): if not allow_smaller: - raise ValueError("The size of the proposed random crop ROI is larger than the image size.") + raise ValueError( + "The size of the proposed random crop ROI is larger than the image size, " + f"got ROI size {spatial_size} and label image size {label_spatial_shape} respectively." + ) spatial_size = tuple(min(l, s) for l, s in zip(label_spatial_shape, spatial_size)) # Select subregion to assure valid roi @@ -461,19 +466,19 @@ def correct_crop_centers( for c, v_s, v_e in zip(centers, valid_start, valid_end): center_i = min(max(c, v_s), v_e - 1) valid_centers.append(int(center_i)) - return valid_centers + return ensure_tuple(valid_centers) # type: ignore def generate_pos_neg_label_crop_centers( - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, num_samples: int, pos_ratio: float, label_spatial_shape: Sequence[int], fg_indices: NdarrayOrTensor, bg_indices: NdarrayOrTensor, - rand_state: Optional[np.random.RandomState] = None, + rand_state: np.random.RandomState | None = None, allow_smaller: bool = False, -) -> List[List[int]]: +) -> tuple[tuple]: """ Generate valid sample locations based on the label with option for specifying foreground ratio Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] @@ -519,18 +524,19 @@ def generate_pos_neg_label_crop_centers( # shift center to range of valid centers centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller)) - return centers + return ensure_tuple(centers) # type: ignore def generate_label_classes_crop_centers( - spatial_size: Union[Sequence[int], int], + spatial_size: Sequence[int] | int, num_samples: int, label_spatial_shape: Sequence[int], indices: Sequence[NdarrayOrTensor], - ratios: Optional[List[Union[float, int]]] = None, - rand_state: Optional[np.random.RandomState] = None, + ratios: list[float | int] | None = None, + rand_state: np.random.RandomState | None = None, allow_smaller: bool = False, -) -> List[List[int]]: + warn: bool = True, +) -> tuple[tuple]: """ Generate valid sample locations based on the specified ratios of label classes. Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] @@ -546,23 +552,27 @@ def generate_label_classes_crop_centers( allow_smaller: if `False`, an exception will be raised if the image is smaller than the requested ROI in any dimension. If `True`, any smaller dimensions will be set to match the cropped size (i.e., no cropping in that dimension). + warn: if `True` prints a warning if a class is not present in the label. """ if rand_state is None: rand_state = np.random.random.__self__ # type: ignore if num_samples < 1: - raise ValueError("num_samples must be an int number and greater than 0.") - ratios_: List[Union[float, int]] = ([1] * len(indices)) if ratios is None else ratios + raise ValueError(f"num_samples must be an int number and greater than 0, got {num_samples}.") + ratios_: list[float | int] = list(ensure_tuple([1] * len(indices) if ratios is None else ratios)) if len(ratios_) != len(indices): - raise ValueError("random crop ratios must match the number of indices of classes.") + raise ValueError( + f"random crop ratios must match the number of indices of classes, got {len(ratios_)} and {len(indices)}." + ) if any(i < 0 for i in ratios_): - raise ValueError("ratios should not contain negative number.") + raise ValueError(f"ratios should not contain negative number, got {ratios_}.") for i, array in enumerate(indices): if len(array) == 0: - warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.") ratios_[i] = 0 + if warn: + warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.") centers = [] classes = rand_state.choice(len(ratios_), size=num_samples, p=np.asarray(ratios_) / np.sum(ratios_)) @@ -574,15 +584,15 @@ def generate_label_classes_crop_centers( # shift center to range of valid centers centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller)) - return centers + return ensure_tuple(centers) # type: ignore def create_grid( spatial_size: Sequence[int], - spacing: Optional[Sequence[float]] = None, + spacing: Sequence[float] | None = None, homogeneous: bool = True, - dtype: Union[DtypeLike, torch.dtype] = float, - device: Optional[torch.device] = None, + dtype: DtypeLike | torch.dtype = float, + device: torch.device | None = None, backend=TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ @@ -611,9 +621,9 @@ def create_grid( def _create_grid_numpy( spatial_size: Sequence[int], - spacing: Optional[Sequence[float]] = None, + spacing: Sequence[float] | None = None, homogeneous: bool = True, - dtype: Union[DtypeLike, torch.dtype] = float, + dtype: DtypeLike | torch.dtype = float, ): """ compute a `spatial_size` mesh with the numpy API. @@ -628,10 +638,10 @@ def _create_grid_numpy( def _create_grid_torch( spatial_size: Sequence[int], - spacing: Optional[Sequence[float]] = None, + spacing: Sequence[float] | None = None, homogeneous: bool = True, dtype=torch.float32, - device: Optional[torch.device] = None, + device: torch.device | None = None, ): """ compute a `spatial_size` mesh with the torch API. @@ -658,7 +668,7 @@ def create_control_grid( spacing: Sequence[float], homogeneous: bool = True, dtype: DtypeLike = float, - device: Optional[torch.device] = None, + device: torch.device | None = None, backend=TransformBackends.NUMPY, ): """ @@ -680,8 +690,8 @@ def create_control_grid( def create_rotate( spatial_dims: int, - radians: Union[Sequence[float], float], - device: Optional[torch.device] = None, + radians: Sequence[float] | float, + device: torch.device | None = None, backend: str = TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ @@ -718,7 +728,7 @@ def create_rotate( def _create_rotate( spatial_dims: int, - radians: Union[Sequence[float], float], + radians: Sequence[float] | float, sin_func: Callable = np.sin, cos_func: Callable = np.cos, eye_func: Callable = np.eye, @@ -765,8 +775,8 @@ def _create_rotate( def create_shear( spatial_dims: int, - coefs: Union[Sequence[float], float], - device: Optional[torch.device] = None, + coefs: Sequence[float] | float, + device: torch.device | None = None, backend=TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ @@ -801,7 +811,7 @@ def create_shear( raise ValueError(f"backend {backend} is not supported") -def _create_shear(spatial_dims: int, coefs: Union[Sequence[float], float], eye_func=np.eye) -> NdarrayOrTensor: +def _create_shear(spatial_dims: int, coefs: Sequence[float] | float, eye_func=np.eye) -> NdarrayOrTensor: if spatial_dims == 2: coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0) out = eye_func(3) @@ -819,8 +829,8 @@ def _create_shear(spatial_dims: int, coefs: Union[Sequence[float], float], eye_f def create_scale( spatial_dims: int, - scaling_factor: Union[Sequence[float], float], - device: Optional[torch.device] = None, + scaling_factor: Sequence[float] | float, + device: torch.device | str | None = None, backend=TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ @@ -844,17 +854,15 @@ def create_scale( raise ValueError(f"backend {backend} is not supported") -def _create_scale( - spatial_dims: int, scaling_factor: Union[Sequence[float], float], array_func=np.diag -) -> NdarrayOrTensor: +def _create_scale(spatial_dims: int, scaling_factor: Sequence[float] | float, array_func=np.diag) -> NdarrayOrTensor: scaling_factor = ensure_tuple_size(scaling_factor, dim=spatial_dims, pad_val=1.0) return array_func(scaling_factor[:spatial_dims] + (1.0,)) # type: ignore def create_translate( spatial_dims: int, - shift: Union[Sequence[float], float], - device: Optional[torch.device] = None, + shift: Sequence[float] | float, + device: torch.device | None = None, backend=TransformBackends.NUMPY, ) -> NdarrayOrTensor: """ @@ -867,6 +875,7 @@ def create_translate( backend: APIs to use, ``numpy`` or ``torch``. """ _backend = look_up_option(backend, TransformBackends) + spatial_dims = int(spatial_dims) if _backend == TransformBackends.NUMPY: return _create_translate(spatial_dims=spatial_dims, shift=shift, eye_func=np.eye, array_func=np.asarray) if _backend == TransformBackends.TORCH: @@ -880,7 +889,7 @@ def create_translate( def _create_translate( - spatial_dims: int, shift: Union[Sequence[float], float], eye_func=np.eye, array_func=np.asarray + spatial_dims: int, shift: Sequence[float] | float, eye_func=np.eye, array_func=np.asarray ) -> NdarrayOrTensor: shift = ensure_tuple(shift) affine = eye_func(spatial_dims + 1) @@ -892,10 +901,10 @@ def _create_translate( def generate_spatial_bounding_box( img: NdarrayOrTensor, select_fn: Callable = is_positive, - channel_indices: Optional[IndexSelection] = None, - margin: Union[Sequence[int], int] = 0, + channel_indices: IndexSelection | None = None, + margin: Sequence[int] | int = 0, allow_smaller: bool = True, -) -> Tuple[List[int], List[int]]: +) -> tuple[list[int], list[int]]: """ Generate the spatial bounding box of foreground in the image with start-end positions (inclusive). Users can define arbitrary function to select expected foreground from the whole image or specified channels. @@ -924,7 +933,7 @@ def generate_spatial_bounding_box( margin = ensure_tuple_rep(margin, ndim) for m in margin: if m < 0: - raise ValueError("margin value should not be negative number.") + raise ValueError(f"margin value should not be negative number, got {margin}.") box_start = [0] * ndim box_end = [0] * ndim @@ -952,7 +961,7 @@ def generate_spatial_bounding_box( def get_largest_connected_component_mask( - img: NdarrayTensor, connectivity: Optional[int] = None, num_components: int = 1 + img: NdarrayTensor, connectivity: int | None = None, num_components: int = 1 ) -> NdarrayTensor: """ Gets the largest connected component mask of an image. @@ -967,6 +976,7 @@ def get_largest_connected_component_mask( """ # use skimage/cucim.skimage and np/cp depending on whether packages are # available and input is non-cpu torch.tensor + cucim, has_cucim = optional_import("cucim") use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device != torch.device("cpu") if use_cp: img_ = convert_to_cupy(img.short()) # type: ignore @@ -1046,9 +1056,7 @@ def remove_small_objects( return out -def get_unique_labels( - img: NdarrayOrTensor, is_onehot: bool, discard: Optional[Union[int, Iterable[int]]] = None -) -> Set[int]: +def get_unique_labels(img: NdarrayOrTensor, is_onehot: bool, discard: int | Iterable[int] | None = None) -> set[int]: """Get list of non-background labels in an image. Args: @@ -1060,13 +1068,13 @@ def get_unique_labels( Returns: Set of labels """ - applied_labels: Set[int] + applied_labels: set[int] n_channels = img.shape[0] if is_onehot: applied_labels = {i for i, s in enumerate(img) if s.sum() > 0} else: if n_channels != 1: - raise ValueError("If input not one-hotted, should only be 1 channel.") + raise ValueError(f"If input not one-hotted, should only be 1 channel, got {n_channels}.") applied_labels = set(unique(img).tolist()) if discard is not None: for i in ensure_tuple(discard): @@ -1075,7 +1083,7 @@ def get_unique_labels( def fill_holes( - img_arr: np.ndarray, applied_labels: Optional[Iterable[int]] = None, connectivity: Optional[int] = None + img_arr: np.ndarray, applied_labels: Iterable[int] | None = None, connectivity: int | None = None ) -> np.ndarray: """ Fill the holes in the provided image. @@ -1134,8 +1142,8 @@ def fill_holes( def get_extreme_points( - img: NdarrayOrTensor, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 -) -> List[Tuple[int, ...]]: + img: NdarrayOrTensor, rand_state: np.random.RandomState | None = None, background: int = 0, pert: float = 0.0 +) -> list[tuple[int, ...]]: """ Generate extreme points from an image. These are used to generate initial segmentation for annotation models. An optional perturbation can be passed to simulate user clicks. @@ -1191,9 +1199,9 @@ def _get_point(val, dim): def extreme_points_to_image( - points: List[Tuple[int, ...]], + points: list[tuple[int, ...]], label: NdarrayOrTensor, - sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, + sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor = 0.0, rescale_min: float = -1.0, rescale_max: float = 1.0, ) -> torch.Tensor: @@ -1237,8 +1245,8 @@ def extreme_points_to_image( def map_spatial_axes( - img_ndim: int, spatial_axes: Optional[Union[Sequence[int], int]] = None, channel_first: bool = True -) -> List[int]: + img_ndim: int, spatial_axes: Sequence[int] | int | None = None, channel_first: bool = True +) -> list[int]: """ Utility to map the spatial axes to real axes in channel first/last shape. For example: @@ -1272,7 +1280,7 @@ def map_spatial_axes( @contextmanager -def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTransform], Tuple[Compose]]): +def allow_missing_keys_mode(transform: MapTransform | Compose | tuple[MapTransform] | tuple[Compose]): """Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states. Args: @@ -1321,7 +1329,7 @@ def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTra _interp_modes = list(InterpolateMode) + list(GridSampleMode) -def convert_applied_interp_mode(trans_info, mode: str = "nearest", align_corners: Optional[bool] = None): +def convert_applied_interp_mode(trans_info, mode: str = "nearest", align_corners: bool | None = None): """ Recursively change the interpolation mode in the applied operation stacks, default to "nearest". @@ -1372,7 +1380,7 @@ def reset_ops_id(data): return {k: reset_ops_id(v) for k, v in data.items()} -def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Union[Sequence[int], int]): +def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Sequence[int] | int): """ Compute the target spatial size which should be divisible by `k`. @@ -1389,11 +1397,11 @@ def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Union[Sequen new_dim = int(np.ceil(dim / k_d) * k_d) if k_d > 0 else dim new_size.append(new_dim) - return new_size + return tuple(new_size) def equalize_hist( - img: np.ndarray, mask: Optional[np.ndarray] = None, num_bins: int = 256, min: int = 0, max: int = 255 + img: np.ndarray, mask: np.ndarray | None = None, num_bins: int = 256, min: int = 0, max: int = 255 ) -> np.ndarray: """ Utility to equalize input image based on the histogram. @@ -1459,7 +1467,7 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor: return k @staticmethod - def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] = None) -> NdarrayOrTensor: + def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None) -> NdarrayOrTensor: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. @@ -1484,7 +1492,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[in return out -def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Optional[Hashable] = None) -> int: +def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Hashable | None = None) -> int: """ Get the number of times that the data need to be converted (e.g., numpy to torch). Conversions between different devices are also counted (e.g., CPU to GPU). @@ -1545,6 +1553,8 @@ def get_transform_backends(): not in [ "BatchInverseTransform", "Compose", + "CuCIM", + "CuCIMD", "Decollated", "InvertD", "InvertibleTransform", @@ -1552,6 +1562,8 @@ def get_transform_backends(): "LambdaD", "MapTransform", "OneOf", + "RandCuCIM", + "RandCuCIMD", "RandomOrder", "PadListDataCollate", "RandLambda", @@ -1607,7 +1619,7 @@ def print_table_column(name, torch, numpy, color=Colors.none): print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) -def convert_pad_mode(dst: NdarrayOrTensor, mode: Optional[str]): +def convert_pad_mode(dst: NdarrayOrTensor, mode: str | None): """ Utility to convert padding mode between numpy array and PyTorch Tensor. @@ -1619,21 +1631,21 @@ def convert_pad_mode(dst: NdarrayOrTensor, mode: Optional[str]): if isinstance(dst, torch.Tensor): if mode == "wrap": mode = "circular" - if mode == "edge": + elif mode == "edge": mode = "replicate" return look_up_option(mode, PytorchPadMode) if isinstance(dst, np.ndarray): if mode == "circular": mode = "wrap" - if mode == "replicate": + elif mode == "replicate": mode = "edge" return look_up_option(mode, NumpyPadMode) raise ValueError(f"unsupported data type: {type(dst)}.") def convert_to_contiguous( - data: Union[NdarrayOrTensor, str, bytes, Mapping, Sequence[Any]], **kwargs -) -> Union[NdarrayOrTensor, Mapping, Sequence[Any]]: + data: NdarrayOrTensor | str | bytes | Mapping | Sequence[Any], **kwargs +) -> NdarrayOrTensor | Mapping | Sequence[Any]: """ Check and ensure the numpy array or PyTorch Tensor in data to be contiguous in memory. @@ -1653,29 +1665,27 @@ def convert_to_contiguous( return data -def scale_affine(affine, spatial_size, new_spatial_size, centered: bool = True): +def scale_affine(spatial_size, new_spatial_size, centered: bool = True): """ - Scale the affine matrix according to the new spatial size. + Compute the scaling matrix according to the new spatial size Args: - affine: affine matrix to scale. spatial_size: original spatial size. new_spatial_size: new spatial size. - centered: whether the scaling is with respect to - the image center (True, default) or corner (False). + centered: whether the scaling is with respect to the image center (True, default) or corner (False). Returns: - Scaled affine matrix. + the scaling matrix. """ + r = max(len(new_spatial_size), len(spatial_size)) if spatial_size == new_spatial_size: - return affine - r = len(affine) - 1 - s = np.array([float(o) / float(max(n, 1)) for o, n in zip(spatial_size, new_spatial_size)]) + return np.eye(r + 1) + s = np.array([float(o) / float(max(n, 1)) for o, n in zip(spatial_size, new_spatial_size)], dtype=float) scale = create_scale(r, s.tolist()) if centered: - scale[:r, -1] = (np.diag(scale)[:r] - 1) / 2 # type: ignore - return affine @ convert_to_dst_type(scale, affine)[0] + scale[:r, -1] = (np.diag(scale)[:r] - 1) / 2.0 # type: ignore + return scale def attach_hook(func, hook, mode="pre"): diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py index 12236e1bb1..a98cdfe936 100644 --- a/monai/transforms/utils_create_transform_ims.py +++ b/monai/transforms/utils_create_transform_ims.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import pathlib import tempfile @@ -464,7 +466,6 @@ def create_transform_im( if __name__ == "__main__": - keys = [CommonKeys.IMAGE, CommonKeys.LABEL] data = get_data(keys) create_transform_im(RandFlip, dict(prob=1, spatial_axis=1), data) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index d3f7855649..cad15df181 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, TypeVar, Union +from __future__ import annotations + +from collections.abc import Sequence +from typing import TypeVar import numpy as np import torch @@ -52,13 +55,13 @@ def allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_nan=False) -> bool: """`np.allclose` with equivalent implementation for torch.""" - b, *_ = convert_to_dst_type(b, a) + b, *_ = convert_to_dst_type(b, a, wrap_sequence=True) if isinstance(a, np.ndarray): return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) # type: ignore -def moveaxis(x: NdarrayOrTensor, src: Union[int, Sequence[int]], dst: Union[int, Sequence[int]]) -> NdarrayOrTensor: +def moveaxis(x: NdarrayOrTensor, src: int | Sequence[int], dst: int | Sequence[int]) -> NdarrayOrTensor: """`moveaxis` for pytorch and numpy""" if isinstance(x, torch.Tensor): return torch.movedim(x, src, dst) # type: ignore @@ -83,8 +86,8 @@ def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor: def percentile( - x: NdarrayOrTensor, q, dim: Optional[int] = None, keepdim: bool = False, **kwargs -) -> Union[NdarrayOrTensor, float, int]: + x: NdarrayOrTensor, q, dim: int | None = None, keepdim: bool = False, **kwargs +) -> NdarrayOrTensor | float | int: """`np.percentile` with equivalent implementation for torch. Pytorch uses `quantile`. For more details please refer to: @@ -106,7 +109,7 @@ def percentile( q_np = convert_data_type(q, output_type=np.ndarray, wrap_sequence=True)[0] if ((q_np < 0) | (q_np > 100)).any(): raise ValueError(f"q values must be in [0, 100], got values: {q}.") - result: Union[NdarrayOrTensor, float, int] + result: NdarrayOrTensor | float | int if isinstance(x, np.ndarray) or (isinstance(x, torch.Tensor) and torch.numel(x) > 1_000_000): # pytorch#64947 _x = convert_data_type(x, output_type=np.ndarray)[0] result = np.percentile(_x, q_np, axis=dim, keepdims=keepdim, **kwargs) @@ -220,7 +223,7 @@ def ravel(x: NdarrayOrTensor) -> NdarrayOrTensor: return np.ravel(x) -def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]) -> NdarrayOrTensor: +def any_np_pt(x: NdarrayOrTensor, axis: int | Sequence[int]) -> NdarrayOrTensor: """`np.any` with equivalent implementation for torch. For pytorch, convert to boolean for compatibility with older versions. @@ -313,7 +316,7 @@ def searchsorted(a: NdarrayTensor, v: NdarrayOrTensor, right=False, sorter=None, return torch.searchsorted(a, v, right=right, **kwargs) # type: ignore -def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None, **kwargs) -> NdarrayOrTensor: +def repeat(a: NdarrayOrTensor, repeats: int, axis: int | None = None, **kwargs) -> NdarrayOrTensor: """ `np.repeat` with equivalent implementation for torch (`repeat_interleave`). @@ -345,7 +348,7 @@ def isnan(x: NdarrayOrTensor) -> NdarrayOrTensor: T = TypeVar("T") -def ascontiguousarray(x: Union[NdarrayTensor, T], **kwargs) -> Union[NdarrayOrTensor, T]: +def ascontiguousarray(x: NdarrayTensor | T, **kwargs) -> NdarrayOrTensor | T: """`np.ascontiguousarray` with equivalent implementation for torch (`contiguous`). Args: @@ -410,7 +413,7 @@ def linalg_inv(x: NdarrayTensor) -> NdarrayTensor: return torch.linalg.inv(x) if isinstance(x, torch.Tensor) else np.linalg.inv(x) # type: ignore -def max(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> NdarrayTensor: +def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor: """`torch.max` with equivalent implementation for numpy Args: @@ -433,7 +436,7 @@ def max(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> return ret -def mean(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> NdarrayTensor: +def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor: """`torch.mean` with equivalent implementation for numpy Args: @@ -455,7 +458,7 @@ def mean(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> return ret -def median(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> NdarrayTensor: +def median(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor: """`torch.median` with equivalent implementation for numpy Args: @@ -477,7 +480,7 @@ def median(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) return ret -def min(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> NdarrayTensor: +def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor: """`torch.min` with equivalent implementation for numpy Args: @@ -499,7 +502,7 @@ def min(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> return ret -def std(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, unbiased: bool = False) -> NdarrayTensor: +def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor: """`torch.std` with equivalent implementation for numpy Args: @@ -521,7 +524,7 @@ def std(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, unbiased: boo return ret -def sum(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> NdarrayTensor: +def sum(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor: """`torch.sum` with equivalent implementation for numpy Args: diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 6dc12a0254..8210ec924c 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -9,18 +9,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + # have to explicitly bring these in here to resolve circular import issues from .aliases import alias, resolve_name from .decorators import MethodReplacer, RestartGenerator -from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg +from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default from .dist import evenly_divisible_all_gather, get_dist_device, string_list_all_gather from .enums import ( Average, BlendMode, BoxModeName, + BundleProperty, + BundlePropertyConfig, ChannelMatching, ColorOrder, CommonKeys, + CompInitMode, DiceCEReduction, EngineStatsKeys, FastMRIKeys, @@ -32,7 +37,6 @@ HoVerNetBranch, HoVerNetMode, InterpolateMode, - InverseKeys, JITMetadataKeys, LazyAttr, LossReduction, @@ -41,6 +45,7 @@ MetricReduction, NdimageMode, NumpyPadMode, + PatchKeys, PostFix, ProbMapKeys, PytorchPadMode, @@ -68,12 +73,14 @@ first, get_seed, has_option, + is_immutable, is_module_ver_at_least, is_scalar, is_scalar_tensor, issequenceiterable, list_to_dict, path_to_uri, + pprint_edges, progress_bar, sample_slices, save_obj, diff --git a/monai/utils/aliases.py b/monai/utils/aliases.py index 967f1c5670..38a3fa6d31 100644 --- a/monai/utils/aliases.py +++ b/monai/utils/aliases.py @@ -12,6 +12,8 @@ This module is written for configurable workflow, not currently in use. """ +from __future__ import annotations + import importlib import inspect import sys diff --git a/monai/utils/decorators.py b/monai/utils/decorators.py index 0856c0fc1a..1c064468e8 100644 --- a/monai/utils/decorators.py +++ b/monai/utils/decorators.py @@ -9,10 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from functools import wraps __all__ = ["RestartGenerator", "MethodReplacer"] +from typing import Callable, Generator + class RestartGenerator: """ @@ -20,10 +24,10 @@ class RestartGenerator: used to create an iterator which can start iteration over the given generator multiple times. """ - def __init__(self, create_gen) -> None: + def __init__(self, create_gen: Callable[[], Generator]) -> None: self.create_gen = create_gen - def __iter__(self): + def __iter__(self) -> Generator: return self.create_gen() @@ -34,7 +38,7 @@ class MethodReplacer: replace_list_name = "__replacemethods__" - def __init__(self, meth) -> None: + def __init__(self, meth: Callable) -> None: self.meth = meth def replace_method(self, meth): diff --git a/monai/utils/deprecate_utils.py b/monai/utils/deprecate_utils.py index 68f2d6e46d..d4f239cd23 100644 --- a/monai/utils/deprecate_utils.py +++ b/monai/utils/deprecate_utils.py @@ -9,18 +9,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import inspect import sys import warnings +from collections.abc import Callable from functools import wraps from types import FunctionType -from typing import Optional +from typing import Any, TypeVar from monai.utils.module import version_leq from .. import __version__ -__all__ = ["deprecated", "deprecated_arg", "DeprecatedError"] +__all__ = ["deprecated", "deprecated_arg", "DeprecatedError", "deprecated_arg_default"] +T = TypeVar("T", type, Callable) class DeprecatedError(Exception): @@ -35,12 +39,12 @@ def warn_deprecated(obj, msg, warning_category=FutureWarning): def deprecated( - since: Optional[str] = None, - removed: Optional[str] = None, + since: str | None = None, + removed: str | None = None, msg_suffix: str = "", version_val: str = __version__, - warning_category=FutureWarning, -): + warning_category: type[FutureWarning] = FutureWarning, +) -> Callable[[T], T]: """ Marks a function or class as deprecated. If `since` is given this should be a version at or earlier than the current version and states at what version of the definition was marked as deprecated. If `removed` is given @@ -117,14 +121,14 @@ def _wrapper(*args, **kwargs): def deprecated_arg( - name, - since: Optional[str] = None, - removed: Optional[str] = None, + name: str, + since: str | None = None, + removed: str | None = None, msg_suffix: str = "", version_val: str = __version__, - new_name: Optional[str] = None, - warning_category=FutureWarning, -): + new_name: str | None = None, + warning_category: type[FutureWarning] = FutureWarning, +) -> Callable[[T], T]: """ Marks a particular named argument of a callable as deprecated. The same conditions for `since` and `removed` as described in the `deprecated` decorator. @@ -138,8 +142,6 @@ def deprecated_arg( using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`. https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded - In the current implementation type annotations are not preserved. - Args: name: name of position or keyword argument to mark as deprecated. @@ -172,7 +174,7 @@ def deprecated_arg( else: # compare the numbers is_deprecated = since is not None and version_leq(since, version_val) - is_removed = removed is not None and version_leq(removed, version_val) + is_removed = removed is not None and version_val != f"{sys.maxsize}" and version_leq(removed, version_val) def _decorator(func): argname = f"{func.__module__} {func.__qualname__}:{name}" @@ -223,3 +225,103 @@ def _wrapper(*args, **kwargs): return _wrapper return _decorator + + +def deprecated_arg_default( + name: str, + old_default: Any, + new_default: Any, + since: str | None = None, + replaced: str | None = None, + msg_suffix: str = "", + version_val: str = __version__, + warning_category: type[FutureWarning] = FutureWarning, +) -> Callable[[T], T]: + """ + Marks a particular arguments default of a callable as deprecated. It is changed from `old_default` to `new_default` + in version `changed`. + + When the decorated definition is called, a `warning_category` is issued if `since` is given, + the default is not explicitly set by the caller and the current version is at or later than that given. + Another warning with the same category is issued if `changed` is given and the current version is at or later. + + The relevant docstring of the deprecating function should also be updated accordingly, + using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`. + https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded + + + Args: + name: name of position or keyword argument where the default is deprecated/changed. + old_default: name of the old default. This is only for the warning message, it will not be validated. + new_default: name of the new default. + It is validated that this value is not present as the default before version `replaced`. + This means, that you can also use this if the actual default value is `None` and set later in the function. + You can also set this to any string representation, e.g. `"calculate_default_value()"` + if the default is calculated from another function. + since: version at which the argument default was marked deprecated but not replaced. + replaced: version at which the argument default was/will be replaced. + msg_suffix: message appended to warning/exception detailing reasons for deprecation. + version_val: (used for testing) version to compare since and removed against, default is MONAI version. + warning_category: a warning category class, defaults to `FutureWarning`. + + Returns: + Decorated callable which warns when deprecated default argument is not explicitly specified. + """ + + if version_val.startswith("0+") or not f"{version_val}".strip()[0].isdigit(): + # version unknown, set version_val to a large value (assuming the latest version) + version_val = f"{sys.maxsize}" + if since is not None and replaced is not None and not version_leq(since, replaced): + raise ValueError(f"since must be less or equal to replaced, got since={since}, replaced={replaced}.") + is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) + if is_not_yet_deprecated: + # smaller than `since`, do nothing + return lambda obj: obj + if since is None and replaced is None: + # raise a DeprecatedError directly + is_replaced = True + is_deprecated = True + else: + # compare the numbers + is_deprecated = since is not None and version_leq(since, version_val) + is_replaced = replaced is not None and version_val != f"{sys.maxsize}" and version_leq(replaced, version_val) + + def _decorator(func): + argname = f"{func.__module__} {func.__qualname__}:{name}" + + msg_prefix = f" Current default value of argument `{name}={old_default}`" + + if is_replaced: + msg_infix = f"was changed in version {replaced} from `{name}={old_default}` to `{name}={new_default}`." + elif is_deprecated: + msg_infix = f"has been deprecated since version {since}." + if replaced is not None: + msg_infix += f" It will be changed to `{name}={new_default}` in version {replaced}." + else: + msg_infix = f"has been deprecated from `{name}={old_default}` to `{name}={new_default}`." + + msg = f"{msg_prefix} {msg_infix} {msg_suffix}".strip() + + sig = inspect.signature(func) + if name not in sig.parameters: + raise ValueError(f"Argument `{name}` not found in signature of {func.__qualname__}.") + param = sig.parameters[name] + if param.default is inspect.Parameter.empty: + raise ValueError(f"Argument `{name}` has no default value.") + + if param.default == new_default and not is_replaced: + raise ValueError( + f"Argument `{name}` was replaced to the new default value `{new_default}` before the specified version {replaced}." + ) + + @wraps(func) + def _wrapper(*args, **kwargs): + if name not in sig.bind(*args, **kwargs).arguments and is_deprecated: + # arg was not found so the default value is used + warn_deprecated(argname, msg, warning_category) + + return func(*args, **kwargs) + + return _wrapper + + return _decorator diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 37536bfe83..546058c93e 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -9,7 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from __future__ import annotations + +import sys + +if sys.version_info >= (3, 8): + from typing import Literal + +from typing import overload import torch import torch.distributed as dist @@ -39,7 +46,22 @@ def get_dist_device(): return None -def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True): +@overload +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: + ... + + +@overload +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: + ... + + +@overload +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: + ... + + +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]: """ Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. The input data of every rank should have the same number of dimensions, only the first dim can be different. @@ -62,7 +84,7 @@ def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True): ndims = data.ndimension() length: int = data.shape[0] if ndims > 0 else 1 - def _torch_all_gather(data: torch.Tensor) -> List[torch.Tensor]: + def _torch_all_gather(data: torch.Tensor) -> list[torch.Tensor]: """ Implementation based on native PyTorch distributed data parallel APIs. @@ -76,7 +98,7 @@ def _torch_all_gather(data: torch.Tensor) -> List[torch.Tensor]: length_tensor = torch.as_tensor([length], device=device) all_lens = [torch.zeros_like(length_tensor) for _ in range(dist.get_world_size())] dist.all_gather(all_lens, length_tensor) - all_lens_: List[int] = [int(i.item()) for i in all_lens] + all_lens_: list[int] = [int(i.item()) for i in all_lens] max_len: int = max(all_lens_) if length < max_len: @@ -88,14 +110,14 @@ def _torch_all_gather(data: torch.Tensor) -> List[torch.Tensor]: # remove the padding items, if all the input data doesn't have batch dim, squeeze the first dim return [(o.squeeze(0) if ndims == 0 else o[:l, ...]).to(orig_device) for o, l in zip(output, all_lens_)] - def _ignite_all_gather(data: torch.Tensor) -> List[torch.Tensor]: + def _ignite_all_gather(data: torch.Tensor) -> list[torch.Tensor]: """ Implementation based on PyTorch ignite package, it can support more kinds of backends. """ data = data.unsqueeze(0) if ndims == 0 else data # make sure the data is evenly-divisible on multi-GPUs - all_lens: List[int] = idist.all_gather(length) + all_lens: list[int] = idist.all_gather(length) max_len: int = max(all_lens) if length < max_len: size = [max_len - length] + list(data.shape[1:]) @@ -108,7 +130,7 @@ def _ignite_all_gather(data: torch.Tensor) -> List[torch.Tensor]: return list(torch.unbind(output, dim=0)) return [output[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)] - output: List[torch.Tensor] + output: list[torch.Tensor] if has_ignite: if idist.get_world_size() <= 1: return data @@ -123,7 +145,7 @@ def _ignite_all_gather(data: torch.Tensor) -> List[torch.Tensor]: return torch.cat(output, dim=0) if concat else output -def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]: +def string_list_all_gather(strings: list[str], delimiter: str = "\t") -> list[str]: """ Utility function for distributed data parallel to all gather a list of strings. Refer to the idea of ignite `all_gather(string)`: @@ -149,6 +171,6 @@ def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[st joined = delimiter.join(strings) gathered = evenly_divisible_all_gather(torch.tensor(bytearray(joined, "utf-8"), dtype=torch.long), concat=False) - gathered = [bytearray(g.tolist()).decode("utf-8").split(delimiter) for g in gathered] + _gathered = [bytearray(g.tolist()).decode("utf-8").split(delimiter) for g in gathered] - return [i for k in gathered for i in k] + return [i for k in _gathered for i in k] diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 4352d83473..8fd79a24da 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -9,11 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import random from enum import Enum -from typing import Optional - -from monai.utils.deprecate_utils import deprecated __all__ = [ "StrEnum", @@ -35,12 +34,12 @@ "SkipMode", "Method", "TraceKeys", - "InverseKeys", "CommonKeys", "GanKeys", "PostFix", "ForwardMode", "TransformBackends", + "CompInitMode", "BoxModeName", "GridPatchSort", "FastMRIKeys", @@ -55,6 +54,8 @@ "HoVerNetMode", "HoVerNetBranch", "LazyAttr", + "BundleProperty", + "BundlePropertyConfig", ] @@ -311,25 +312,8 @@ class TraceKeys(StrEnum): DO_TRANSFORM: str = "do_transforms" KEY_SUFFIX: str = "_transforms" NONE: str = "none" - - -@deprecated(since="0.8.0", msg_suffix="use monai.utils.enums.TraceKeys instead.") -class InverseKeys: - """ - Extra metadata keys used for inverse transforms. - - .. deprecated:: 0.8.0 - Use :class:`monai.utils.enums.TraceKeys` instead. - - """ - - CLASS_NAME = "class" - ID = "id" - ORIG_SIZE = "orig_size" - EXTRA_INFO = "extra_info" - DO_TRANSFORM = "do_transforms" - KEY_SUFFIX = "_transforms" - NONE = "none" + TRACING: str = "tracing" + LAZY_EVALUATION: str = "lazy_evaluation" class CommonKeys(StrEnum): @@ -367,19 +351,19 @@ class PostFix(StrEnum): """Post-fixes.""" @staticmethod - def _get_str(prefix, suffix): + def _get_str(prefix: str | None, suffix: str) -> str: return suffix if prefix is None else f"{prefix}_{suffix}" @staticmethod - def meta(key: Optional[str] = None): + def meta(key: str | None = None) -> str: return PostFix._get_str(key, "meta_dict") @staticmethod - def orig_meta(key: Optional[str] = None): + def orig_meta(key: str | None = None) -> str: return PostFix._get_str(key, "orig_meta_dict") @staticmethod - def transforms(key: Optional[str] = None): + def transforms(key: str | None = None) -> str: return PostFix._get_str(key, TraceKeys.KEY_SUFFIX[1:]) @@ -397,6 +381,18 @@ class TransformBackends(StrEnum): CUPY = "cupy" +class CompInitMode(StrEnum): + """ + Mode names for instantiating a class or calling a callable. + + See also: :py:func:`monai.utils.module.instantiate` + """ + + DEFAULT = "default" + PARTIAL = "partial" + DEBUG = "debug" + + class JITMetadataKeys(StrEnum): """ Keys stored in the metadata file for saved Torchscript models. Some of these are generated by the routines @@ -468,15 +464,25 @@ def get_sort_fn(sort_fn): ) -class WSIPatchKeys(StrEnum): +class PatchKeys(StrEnum): """ - The keys to be used for metadata of patches extracted from whole slide images + The keys to be used for metadata of patches extracted from any kind of image """ LOCATION = "location" - LEVEL = "level" SIZE = "size" COUNT = "count" + + +class WSIPatchKeys(StrEnum): + """ + The keys to be used for metadata of patches extracted from whole slide images + """ + + LOCATION = PatchKeys.LOCATION + SIZE = PatchKeys.SIZE + COUNT = PatchKeys.COUNT + LEVEL = "level" PATH = "path" @@ -514,7 +520,7 @@ class MetaKeys(StrEnum): ORIGINAL_AFFINE = "original_affine" # the affine after image loading before any data processing SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension SPACE = "space" # possible values of space type are defined in `SpaceKeys` - ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or "no_channel" + ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan") class ColorOrder(StrEnum): @@ -567,6 +573,7 @@ class ImageStatsKeys(StrEnum): CHANNELS = "channels" CROPPED_SHAPE = "cropped_shape" SPACING = "spacing" + SIZEMM = "sizemm" INTENSITY = "intensity" HISTOGRAM = "histogram" @@ -631,3 +638,27 @@ class LazyAttr(StrEnum): PADDING_MODE = "lazy_padding_mode" INTERP_MODE = "lazy_interpolation_mode" DTYPE = "lazy_dtype" + ALIGN_CORNERS = "lazy_align_corners" + + +class BundleProperty(StrEnum): + """ + Bundle property fields: + `DESC` is the description of the property. + `REQUIRED` is flag to indicate whether the property is required or optional. + """ + + DESC = "description" + REQUIRED = "required" + + +class BundlePropertyConfig(StrEnum): + """ + additional bundle property fields for config based bundle workflow: + `ID` is the config item ID of the property. + `REF_ID` is the ID of config item which is supposed to refer to this property. + this field is only useful to check the optional property ID. + """ + + ID = "id" + REF_ID = "refer_id" diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index 38b8e0a7e7..876f3e5d48 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -13,10 +13,13 @@ Matplotlib produce common plots for metrics and images. """ +from __future__ import annotations + import copy +from collections.abc import Callable, Mapping from enum import Enum from threading import RLock, Thread -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any import numpy as np import torch @@ -41,13 +44,13 @@ def plot_metric_graph( - ax, + ax: plt.Axes, title: str, - graphmap: Dict[str, Union[List[float], Tuple[List[float], List[float]]]], + graphmap: Mapping[str, list[float] | tuple[list[float], list[float]]], yscale: str = "log", - avg_keys: Tuple[str] = (LOSS_NAME,), + avg_keys: tuple[str] = (LOSS_NAME,), window_fraction: int = 20, -): +) -> None: """ Plot metrics on a single graph with running averages plotted for selected keys. The values in `graphmap` should be lists of (timepoint, value) pairs as stored in MetricLogger objects. @@ -88,14 +91,14 @@ def plot_metric_graph( def plot_metric_images( - fig, + fig: plt.Figure, title: str, - graphmap: Dict[str, Union[List[float], Tuple[List[float], List[float]]]], - imagemap: Dict[str, np.ndarray], + graphmap: Mapping[str, list[float] | tuple[list[float], list[float]]], + imagemap: dict[str, np.ndarray], yscale: str = "log", - avg_keys: Tuple[str] = (LOSS_NAME,), + avg_keys: tuple[str] = (LOSS_NAME,), window_fraction: int = 20, -) -> List: +) -> list: """ Plot metric graph data with images below into figure `fig`. The intended use is for the graph data to be metrics from a training run and the images to be the batch and output from the last iteration. This uses @@ -135,7 +138,7 @@ def plot_metric_images( return axes -def tensor_to_images(name: str, tensor: torch.Tensor): +def tensor_to_images(name: str, tensor: torch.Tensor) -> np.ndarray | None: """ Return an tuple of images derived from the given tensor. The `name` value indices which key from the output or batch value the tensor was stored as, or is "Batch" or "Output" if these were single tensors @@ -144,25 +147,25 @@ def tensor_to_images(name: str, tensor: torch.Tensor): each channel separately. """ if tensor.ndim == 3 and tensor.shape[1] > 2 and tensor.shape[2] > 2: - return tensor.cpu().data.numpy() + return tensor.cpu().data.numpy() # type: ignore[no-any-return] if tensor.ndim == 4 and tensor.shape[2] > 2 and tensor.shape[3] > 2: dmid = tensor.shape[1] // 2 - return tensor[:, dmid].cpu().data.numpy() + return tensor[:, dmid].cpu().data.numpy() # type: ignore[no-any-return] return None def plot_engine_status( engine: Engine, - logger, + logger: Any, title: str = "Training Log", yscale: str = "log", - avg_keys: Tuple[str] = (LOSS_NAME,), + avg_keys: tuple[str] = (LOSS_NAME,), window_fraction: int = 20, - image_fn: Optional[Callable] = tensor_to_images, - fig=None, + image_fn: Callable[[str, torch.Tensor], Any] | None = tensor_to_images, + fig: plt.Figure = None, selected_inst: int = 0, -) -> Tuple: +) -> tuple[plt.Figure, list]: """ Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics taken from the logger, and images taken from the `output` and `batch` members of `engine.state`. The images are @@ -188,10 +191,10 @@ def plot_engine_status( else: fig = plt.Figure(figsize=(20, 10), tight_layout=True, facecolor="white") - graphmap = {LOSS_NAME: logger.loss} + graphmap: dict[str, list[float]] = {LOSS_NAME: logger.loss} graphmap.update(logger.metrics) - imagemap: Dict = {} + imagemap: dict = {} if image_fn is not None and engine.state is not None and engine.state.batch is not None: for src in (engine.state.batch, engine.state.output): label = "Batch" if src is engine.state.batch else "Output" @@ -230,10 +233,12 @@ def plot_engine_status( return fig, axes -def _get_loss_from_output(output: Union[Dict[str, torch.Tensor], torch.Tensor]): +def _get_loss_from_output( + output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor +) -> torch.Tensor: """Returns a single value from the network output, which is a dict or tensor.""" - def _get_loss(data): + def _get_loss(data: torch.Tensor | dict[str, torch.Tensor]) -> torch.Tensor: if isinstance(data, dict): return data["loss"] return data @@ -280,10 +285,10 @@ def __init__( super().__init__() self.lock = RLock() self.engine = engine - self._status_dict: Dict[str, Any] = {} + self._status_dict: dict[str, Any] = {} self.loss_transform = loss_transform self.metric_transform = metric_transform - self.fig = None + self.fig: plt.Figure | None = None self.status_format = status_format self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._update_status) @@ -301,7 +306,7 @@ def _update_status(self): """Called as an event, updates the internal status dict at the end of iterations.""" with self.lock: state = self.engine.state - stats: Dict[str, Any] = { + stats: dict[str, Any] = { StatusMembers.EPOCHS.value: 0, StatusMembers.ITERS.value: 0, StatusMembers.LOSS.value: float("nan"), @@ -331,7 +336,7 @@ def _update_status(self): self._status_dict.update(stats) @property - def status_dict(self) -> Dict[str, str]: + def status_dict(self) -> dict[str, str]: """A dictionary containing status information, current loss, and current metric values.""" with self.lock: stats = {StatusMembers.STATUS.value: "Running" if self.is_alive() else "Stopped"} @@ -354,7 +359,7 @@ def status(self) -> str: return ", ".join(msgs) - def plot_status(self, logger, plot_func: Callable = plot_engine_status): + def plot_status(self, logger: Any, plot_func: Callable = plot_engine_status) -> plt.Figure: """ Generate a plot of the current status of the contained engine whose loss and metrics were tracked by `logger`. The function `plot_func` must accept arguments `title`, `engine`, `logger`, and `fig` which are the plot title, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index c751ad3b49..a729688209 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -9,19 +9,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import inspect import itertools import os +import pprint import random import shutil import tempfile import types import warnings from ast import literal_eval -from collections.abc import Iterable +from collections.abc import Callable, Iterable, Sequence from distutils.util import strtobool from pathlib import Path -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, TypeVar, cast, overload import numpy as np import torch @@ -34,6 +37,7 @@ "star_zip_with", "first", "issequenceiterable", + "is_immutable", "ensure_tuple", "ensure_tuple_size", "ensure_tuple_rep", @@ -57,6 +61,7 @@ "save_obj", "label_union", "path_to_uri", + "pprint_edges", ] _seed = None @@ -80,7 +85,20 @@ def star_zip_with(op, *vals): return zip_with(op, *vals, mapfunc=itertools.starmap) -def first(iterable, default=None): +T = TypeVar("T") + + +@overload +def first(iterable: Iterable[T], default: T) -> T: + ... + + +@overload +def first(iterable: Iterable[T]) -> T | None: + ... + + +def first(iterable: Iterable[T], default: T | None = None) -> T | None: """ Returns the first item in the given iterable or `default` if empty, meaningful mostly with 'for' expressions. """ @@ -101,7 +119,16 @@ def issequenceiterable(obj: Any) -> bool: return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)) -def ensure_tuple(vals: Any, wrap_array: bool = False) -> Tuple[Any, ...]: +def is_immutable(obj: Any) -> bool: + """ + Determine if the object is an immutable object. + + see also https://github.com/python/cpython/blob/3.11/Lib/copy.py#L109 + """ + return isinstance(obj, (type(None), int, float, bool, complex, str, tuple, bytes, type, range, slice)) + + +def ensure_tuple(vals: Any, wrap_array: bool = False) -> tuple: """ Returns a tuple of `vals`. @@ -116,15 +143,20 @@ def ensure_tuple(vals: Any, wrap_array: bool = False) -> Tuple[Any, ...]: return tuple(vals) if issequenceiterable(vals) else (vals,) -def ensure_tuple_size(tup: Any, dim: int, pad_val: Any = 0) -> Tuple[Any, ...]: +def ensure_tuple_size(vals: Any, dim: int, pad_val: Any = 0, pad_from_start: bool = False) -> tuple: """ Returns a copy of `tup` with `dim` values by either shortened or padded with `pad_val` as necessary. """ - new_tup = ensure_tuple(tup) + (pad_val,) * dim - return new_tup[:dim] + tup = ensure_tuple(vals) + pad_dim = dim - len(tup) + if pad_dim <= 0: + return tup[:dim] + if pad_from_start: + return (pad_val,) * pad_dim + tup + return tup + (pad_val,) * pad_dim -def ensure_tuple_rep(tup: Any, dim: int) -> Tuple[Any, ...]: +def ensure_tuple_rep(tup: Any, dim: int) -> tuple[Any, ...]: """ Returns a copy of `tup` with `dim` values by either shortened or duplicated input. @@ -160,8 +192,8 @@ def ensure_tuple_rep(tup: Any, dim: int) -> Tuple[Any, ...]: def fall_back_tuple( - user_provided: Any, default: Union[Sequence, NdarrayTensor], func: Callable = lambda x: x and x > 0 -) -> Tuple[Any, ...]: + user_provided: Any, default: Sequence | NdarrayTensor, func: Callable = lambda x: x and x > 0 +) -> tuple[Any, ...]: """ Refine `user_provided` according to the `default`, and returns as a validated tuple. @@ -215,7 +247,7 @@ def is_scalar(val: Any) -> bool: return bool(np.isscalar(val)) -def progress_bar(index: int, count: int, desc: Optional[str] = None, bar_len: int = 30, newline: bool = False) -> None: +def progress_bar(index: int, count: int, desc: str | None = None, bar_len: int = 30, newline: bool = False) -> None: """print a progress bar to track some time consuming task. Args: @@ -234,14 +266,14 @@ def progress_bar(index: int, count: int, desc: Optional[str] = None, bar_len: in print("") -def get_seed() -> Optional[int]: +def get_seed() -> int | None: return _seed def set_determinism( - seed: Optional[int] = NP_MAX, - use_deterministic_algorithms: Optional[bool] = None, - additional_settings: Optional[Union[Sequence[Callable[[int], Any]], Callable[[int], Any]]] = None, + seed: int | None = NP_MAX, + use_deterministic_algorithms: bool | None = None, + additional_settings: Sequence[Callable[[int], Any]] | Callable[[int], Any] | None = None, ) -> None: """ Set random seed for modules to enable or disable deterministic training. @@ -332,7 +364,7 @@ def _parse_var(s): def copy_to_device( - obj: Any, device: Optional[Union[str, torch.device]], non_blocking: bool = True, verbose: bool = False + obj: Any, device: str | torch.device | None, non_blocking: bool = True, verbose: bool = False ) -> Any: """ Copy object or tuple/list/dictionary of objects to ``device``. @@ -365,7 +397,7 @@ def copy_to_device( return obj -def str2bool(value: Union[str, bool], default: bool = False, raise_exc: bool = True) -> bool: +def str2bool(value: str | bool, default: bool = False, raise_exc: bool = True) -> bool: """ Convert a string to a boolean. Case insensitive. True: yes, true, t, y, 1. False: no, false, f, n, 0. @@ -400,7 +432,7 @@ def str2bool(value: Union[str, bool], default: bool = False, raise_exc: bool = T return default -def str2list(value: Optional[Union[str, list]], raise_exc: bool = True) -> Optional[list]: +def str2list(value: str | list | None, raise_exc: bool = True) -> list | None: """ Convert a string to a list. Useful with argparse commandline arguments: parser.add_argument("--blocks", default=[1,2,3], type=str2list) @@ -438,7 +470,7 @@ class MONAIEnvVars: """ @staticmethod - def data_dir() -> Optional[str]: + def data_dir() -> str | None: return os.environ.get("MONAI_DATA_DIRECTORY") @staticmethod @@ -447,7 +479,7 @@ def debug() -> bool: return val if isinstance(val, bool) else str2bool(val) @staticmethod - def doc_images() -> Optional[str]: + def doc_images() -> str | None: return os.environ.get("MONAI_DOC_IMAGES") @@ -461,7 +493,7 @@ class ImageMetaKey: SPATIAL_SHAPE = "spatial_shape" -def has_option(obj, keywords: Union[str, Sequence[str]]) -> bool: +def has_option(obj: Callable, keywords: str | Sequence[str]) -> bool: """ Return a boolean indicating whether the given callable `obj` has the `keywords` in its signature. """ @@ -502,7 +534,7 @@ def sample_slices(data: NdarrayOrTensor, dim: int = 1, as_indices: bool = True, return data[tuple(slices)] -def check_parent_dir(path: PathLike, create_dir: bool = True): +def check_parent_dir(path: PathLike, create_dir: bool = True) -> None: """ Utility to check whether the parent directory of the `path` exists. @@ -522,8 +554,13 @@ def check_parent_dir(path: PathLike, create_dir: bool = True): def save_obj( - obj, path: PathLike, create_dir: bool = True, atomic: bool = True, func: Optional[Callable] = None, **kwargs -): + obj: object, + path: PathLike, + create_dir: bool = True, + atomic: bool = True, + func: Callable | None = None, + **kwargs: Any, +) -> None: """ Save an object to file with specified path. Support to serialize to a temporary file first, then move to final destination, @@ -564,7 +601,7 @@ def save_obj( pass -def label_union(x: List) -> List: +def label_union(x: list | np.ndarray) -> list: """ Compute the union of class IDs in label and generate a list to include all class IDs Args: @@ -576,7 +613,7 @@ def label_union(x: List) -> List: return list(set.union(set(np.array(x).tolist()))) -def prob2class(x, sigmoid: bool = False, threshold: float = 0.5, **kwargs): +def prob2class(x: torch.Tensor, sigmoid: bool = False, threshold: float = 0.5, **kwargs: Any) -> torch.Tensor: """ Compute the lab from the probability of predicted feature maps @@ -596,3 +633,17 @@ def path_to_uri(path: PathLike) -> str: """ return Path(path).absolute().as_uri() + + +def pprint_edges(val: Any, n_lines: int = 20) -> str: + """ + Pretty print the head and tail ``n_lines`` of ``val``, and omit the middle part if the part has more than 3 lines. + + Returns: the formatted string. + """ + val_str = pprint.pformat(val).splitlines(True) + n_lines = max(n_lines, 1) + if len(val_str) > n_lines * 2 + 3: + hidden_n = len(val_str) - n_lines * 2 + val_str = val_str[:n_lines] + [f"\n ... omitted {hidden_n} line(s)\n\n"] + val_str[-n_lines:] + return "".join(val_str) diff --git a/monai/utils/module.py b/monai/utils/module.py index 435b07fcac..b72e3ff139 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -9,19 +9,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import enum import os +import pdb import re import sys import warnings +from collections.abc import Callable, Collection, Hashable, Mapping from functools import partial, wraps from importlib import import_module -from inspect import isclass, isfunction, ismethod from pkgutil import walk_packages from pydoc import locate from re import match -from types import FunctionType -from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, Union, cast +from types import FunctionType, ModuleType +from typing import Any, cast import torch @@ -55,7 +58,12 @@ ] -def look_up_option(opt_str, supported: Union[Collection, enum.EnumMeta], default="no_default", print_all_options=True): +def look_up_option( + opt_str: Hashable, + supported: Collection | enum.EnumMeta, + default: Any = "no_default", + print_all_options: bool = True, +) -> Any: """ Look up the option in the supported collection and return the matched item. Raise a value error possibly with a guess of the closest match. @@ -92,7 +100,7 @@ class Color(Enum): if isinstance(opt_str, str): opt_str = opt_str.strip() if isinstance(supported, enum.EnumMeta): - if isinstance(opt_str, str) and opt_str in {item.value for item in cast(Iterable[enum.Enum], supported)}: + if isinstance(opt_str, str) and opt_str in {item.value for item in supported}: # type: ignore # such as: "example" in MyEnum return supported(opt_str) if isinstance(opt_str, enum.Enum) and opt_str in supported: @@ -110,7 +118,7 @@ class Color(Enum): # find a close match set_to_check: set if isinstance(supported, enum.EnumMeta): - set_to_check = {item.value for item in cast(Iterable[enum.Enum], supported)} + set_to_check = {item.value for item in supported} # type: ignore else: set_to_check = set(supported) if supported is not None else set() if not set_to_check: @@ -133,7 +141,7 @@ class Color(Enum): raise ValueError(f"Unsupported option '{opt_str}', " + supported_msg) -def damerau_levenshtein_distance(s1: str, s2: str): +def damerau_levenshtein_distance(s1: str, s2: str) -> int: """ Calculates the Damerau–Levenshtein distance between two strings for spelling correction. https://en.wikipedia.org/wiki/Damerau–Levenshtein_distance @@ -184,13 +192,15 @@ def _inner(obj): return _inner -def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[tT]est.*)|(_.*)"): +def load_submodules( + basemod: ModuleType, load_all: bool = True, exclude_pattern: str = "(.*[tT]est.*)|(_.*)" +) -> tuple[list[ModuleType], list[str]]: """ Traverse the source of the module structure starting with module `basemod`, loading all packages plus all files if `load_all` is True, excluding anything whose name matches `exclude_pattern`. """ submodules = [] - err_mod: List[str] = [] + err_mod: list[str] = [] for importer, name, is_pkg in walk_packages( basemod.__path__, prefix=basemod.__name__ + ".", onerror=err_mod.append ): @@ -211,40 +221,57 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ return submodules, err_mod -def instantiate(path: str, **kwargs): +def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any: """ - Create an object instance or partial function from a class or function represented by string. + Create an object instance or call a callable object from a class or function represented by ``_path``. `kwargs` will be part of the input arguments to the class constructor or function. The target component must be a class or a function, if not, return the component directly. Args: - path: full path of the target class or function component. - kwargs: arguments to initialize the class instance or set default args - for `partial` function. + __path: if a string is provided, it's interpreted as the full path of the target class or function component. + If a callable is provided, ``__path(**kwargs)`` or ``functools.partial(__path, **kwargs)`` will be returned. + __mode: the operating mode for invoking the (callable) ``component`` represented by ``__path``: + + - ``"default"``: returns ``component(**kwargs)`` + - ``"partial"``: returns ``functools.partial(component, **kwargs)`` + - ``"debug"``: returns ``pdb.runcall(component, **kwargs)`` + + kwargs: keyword arguments to the callable represented by ``__path``. """ + from monai.utils.enums import CompInitMode - component = locate(path) + component = locate(__path) if isinstance(__path, str) else __path if component is None: - raise ModuleNotFoundError(f"Cannot locate class or function path: '{path}'.") + raise ModuleNotFoundError(f"Cannot locate class or function path: '{__path}'.") + m = look_up_option(__mode, CompInitMode) try: if kwargs.pop("_debug_", False) or run_debug: warnings.warn( - f"\n\npdb: instantiating component={component}\n" + f"\n\npdb: instantiating component={component}, mode={m}\n" f"See also Debugger commands documentation: https://docs.python.org/3/library/pdb.html\n" ) - import pdb - - pdb.set_trace() - if isclass(component): + breakpoint() + if not callable(component): + warnings.warn(f"Component {component} is not callable when mode={m}.") + return component + if m == CompInitMode.DEFAULT: return component(**kwargs) - # support regular function, static method and class method - if isfunction(component) or (ismethod(component) and isclass(getattr(component, "__self__", None))): + if m == CompInitMode.PARTIAL: return partial(component, **kwargs) + if m == CompInitMode.DEBUG: + warnings.warn( + f"\n\npdb: instantiating component={component}, mode={m}\n" + f"See also Debugger commands documentation: https://docs.python.org/3/library/pdb.html\n" + ) + return pdb.runcall(component, **kwargs) except Exception as e: - raise RuntimeError(f"Failed to instantiate '{path}' with kwargs: {kwargs}") from e + raise RuntimeError( + f"Failed to instantiate component '{__path}' with kwargs: {kwargs}" + f"\n set '_mode_={CompInitMode.DEBUG}' to enter the debugging mode." + ) from e - warnings.warn(f"Component to instantiate must represent a valid class or function, but got {path}.") + warnings.warn(f"Component to instantiate must represent a valid class or function, but got {__path}.") return component @@ -259,7 +286,7 @@ def get_full_type_name(typeobj): return module + "." + typeobj.__name__ -def min_version(the_module, min_version_str: str = "", *_args) -> bool: +def min_version(the_module: Any, min_version_str: str = "", *_args: Any) -> bool: """ Convert version strings into tuples of int and compare them. @@ -274,7 +301,7 @@ def min_version(the_module, min_version_str: str = "", *_args) -> bool: return mod_version >= required -def exact_version(the_module, version_str: str = "", *_args) -> bool: +def exact_version(the_module: Any, version_str: str = "", *_args: Any) -> bool: """ Returns True if the module's __version__ matches version_str """ @@ -307,10 +334,10 @@ def optional_import( version_checker: Callable[..., bool] = min_version, name: str = "", descriptor: str = OPTIONAL_IMPORT_MSG_FMT, - version_args=None, + version_args: Any = None, allow_namespace_pkg: bool = False, as_type: str = "default", -) -> Tuple[Any, bool]: +) -> tuple[Any, bool]: """ Imports an optional module specified by `module` string. Any importing related exceptions will be stored, and exceptions raise lazily @@ -434,7 +461,7 @@ def __init__(self, *_args, **kwargs): def require_pkg( pkg_name: str, version: str = "", version_checker: Callable[..., bool] = min_version, raise_error: bool = True -): +) -> Callable: """ Decorator function to check the required package installation. @@ -450,10 +477,10 @@ def require_pkg( def _decorator(obj): is_func = isinstance(obj, FunctionType) call_obj = obj if is_func else obj.__init__ - _, has = optional_import(module=pkg_name, version=version, version_checker=version_checker) @wraps(call_obj) def _wrapper(*args, **kwargs): + _, has = optional_import(module=pkg_name, version=version, version_checker=version_checker) if not has: err_msg = f"required package `{pkg_name}` is not installed or the version doesn't match requirement." if raise_error: @@ -489,7 +516,7 @@ def get_torch_version_tuple(): return tuple(int(x) for x in torch.__version__.split(".")[:2]) -def version_leq(lhs: str, rhs: str): +def version_leq(lhs: str, rhs: str) -> bool: """ Returns True if version `lhs` is earlier or equal to `rhs`. @@ -503,11 +530,11 @@ def version_leq(lhs: str, rhs: str): pkging, has_ver = optional_import("pkg_resources", name="packaging") if has_ver: try: - return pkging.version.Version(lhs) <= pkging.version.Version(rhs) + return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs)) except pkging.version.InvalidVersion: return True - def _try_cast(val: str): + def _try_cast(val: str) -> int | str: val = val.strip() try: m = match("(\\d+)(.*)", val) @@ -535,7 +562,7 @@ def _try_cast(val: str): return True -def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool: +def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: str | None = None) -> bool: """ Compute whether the current pytorch version is after or equal to the specified version. The current system pytorch version is determined by `torch.__version__` or @@ -579,7 +606,7 @@ def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool: is_prerelease = True patch = int(patch) if c_p != patch: - return c_p > patch # type: ignore + return c_p > patch if is_prerelease: return False return True diff --git a/monai/utils/nvtx.py b/monai/utils/nvtx.py index 7b16d1e2b3..2cea9aabd1 100644 --- a/monai/utils/nvtx.py +++ b/monai/utils/nvtx.py @@ -12,9 +12,11 @@ Decorators and context managers for NVIDIA Tools Extension to profile MONAI components """ +from __future__ import annotations + from collections import defaultdict from functools import wraps -from typing import Any, Optional, Tuple, Union +from typing import Any from torch.autograd import Function from torch.nn import Module @@ -52,9 +54,9 @@ class Range: def __init__( self, - name: Optional[str] = None, - methods: Optional[Union[str, Tuple[str, ...]]] = None, - append_method_name: Optional[bool] = None, + name: str | None = None, + methods: str | tuple[str, ...] | None = None, + append_method_name: bool | None = None, recursive: bool = False, ) -> None: self.name = name @@ -62,7 +64,7 @@ def __init__( self.append_method_name = append_method_name self.recursive = recursive - def __call__(self, obj: Any): + def __call__(self, obj: Any) -> Any: if self.recursive is True: if isinstance(obj, (list, tuple)): return type(obj)(Range(recursive=True)(t) for t in obj) diff --git a/monai/utils/profiling.py b/monai/utils/profiling.py index 291dfbbe16..da5c0ac05c 100644 --- a/monai/utils/profiling.py +++ b/monai/utils/profiling.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import csv import datetime import multiprocessing @@ -21,13 +23,18 @@ from inspect import getframeinfo, stack from queue import Empty from time import perf_counter, perf_counter_ns -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np import torch from monai.utils import optional_import +if TYPE_CHECKING: + from ignite.engine import Events +else: + Events = optional_import("ignite.engine", name="Events") + pd, has_pandas = optional_import("pandas") __all__ = [ @@ -50,7 +57,6 @@ def torch_profiler_full(func): @wraps(func) def wrapper(*args, **kwargs): - with torch.autograd.profiler.profile(use_cuda=True) as prof: result = func(*args, **kwargs) @@ -70,7 +76,6 @@ def torch_profiler_time_cpu_gpu(func): @wraps(func) def wrapper(*args, **kwargs): - with torch.autograd.profiler.profile(use_cuda=True) as prof: result = func(*args, **kwargs) @@ -96,7 +101,6 @@ def torch_profiler_time_end_to_end(func): @wraps(func) def wrapper(*args, **kwargs): - torch.cuda.synchronize() start = perf_counter() @@ -123,7 +127,7 @@ class PerfContext: def __init__(self): self.total_time: float = 0 - self.start_time: Optional[float] = None + self.start_time: float | None = None def __enter__(self): self.start_time = perf_counter() @@ -198,7 +202,7 @@ def foo(): pass def __init__(self, call_selector=select_transform_call): self.results = defaultdict(list) self.parent_pid = os.getpid() - self.read_thread: Optional[threading.Thread] = None + self.read_thread: threading.Thread | None = None self.lock = threading.RLock() self.queue: multiprocessing.SimpleQueue = multiprocessing.SimpleQueue() self.queue_timeout = 0.1 @@ -290,7 +294,7 @@ def __exit__(self, exc_type, exc_value, traceback): threading.settrace(None) # type: ignore sys.settrace(None) - def add_result(self, result: ProfileResult): + def add_result(self, result: ProfileResult) -> None: """Add a result in a thread-safe manner to the internal results dictionary.""" with self.lock: self.results[result.name].append(result) @@ -406,7 +410,7 @@ class ProfileHandler: end_event: item in `ignite.engine.Events` stating event at which to stop timing """ - def __init__(self, name: str, profiler: WorkflowProfiler, start_event, end_event): + def __init__(self, name: str, profiler: WorkflowProfiler, start_event: Events, end_event: Events): self.name = name self.profiler = profiler self.start_event = start_event diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index 3e392ab979..d37e7abde4 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -9,11 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import copy import os import pickle import tempfile -from typing import Dict, Optional +from types import ModuleType +from typing import Any, Hashable import torch from torch.serialization import DEFAULT_PROTOCOL @@ -41,9 +44,9 @@ class StateCacher: def __init__( self, in_memory: bool, - cache_dir: Optional[PathLike] = None, + cache_dir: PathLike | None = None, allow_overwrite: bool = True, - pickle_module=pickle, + pickle_module: ModuleType = pickle, pickle_protocol: int = DEFAULT_PROTOCOL, ) -> None: """Constructor. @@ -73,9 +76,11 @@ def __init__( self.allow_overwrite = allow_overwrite self.pickle_module = pickle_module self.pickle_protocol = pickle_protocol - self.cached: Dict = {} + self.cached: dict = {} - def store(self, key, data_obj, pickle_module=None, pickle_protocol: Optional[int] = None): + def store( + self, key: Hashable, data_obj: Any, pickle_module: ModuleType | None = None, pickle_protocol: int | None = None + ) -> None: """ Store a given object with the given key name. @@ -107,7 +112,7 @@ def store(self, key, data_obj, pickle_module=None, pickle_protocol: Optional[int if hasattr(data_obj, "device"): self.cached[key]["device"] = data_obj.device - def retrieve(self, key): + def retrieve(self, key: Hashable) -> Any: """Retrieve the object stored under a given key name.""" if key not in self.cached: raise KeyError(f"Target {key} was not cached.") diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 2cb9bfd8c4..c5dd3a797c 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import re -from typing import Any, Optional, Sequence, Tuple, Type, Union +from collections.abc import Sequence +from typing import Any import numpy as np import torch @@ -85,13 +88,13 @@ def get_equivalent_dtype(dtype, data_type): return dtype_torch_to_numpy(dtype) -def get_dtype(data: Any): +def get_dtype(data: Any) -> DtypeLike | torch.dtype: """Get the dtype of an image, or if there is a sequence, recursively call the method on the 0th element. This therefore assumes that in a `Sequence`, all types are the same. """ if hasattr(data, "dtype"): - return data.dtype + return data.dtype # type: ignore # need recursion if isinstance(data, Sequence): return get_dtype(data[0]) @@ -100,13 +103,13 @@ def get_dtype(data: Any): def convert_to_tensor( - data, - dtype: Union[DtypeLike, torch.dtype] = None, - device: Union[None, str, torch.device] = None, + data: Any, + dtype: DtypeLike | torch.dtype = None, + device: None | str | torch.device = None, wrap_sequence: bool = False, track_meta: bool = False, safe: bool = False, -): +) -> Any: """ Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`, otherwise, the output will be a regular torch Tensor. @@ -128,7 +131,7 @@ def convert_to_tensor( """ - def _convert_tensor(tensor, **kwargs): + def _convert_tensor(tensor: Any, **kwargs: Any) -> Any: if not isinstance(tensor, torch.Tensor): # certain numpy types are not supported as being directly convertible to Pytorch tensors if isinstance(tensor, np.ndarray) and tensor.dtype in UNSUPPORTED_TYPES: @@ -170,7 +173,7 @@ def _convert_tensor(tensor, **kwargs): return data -def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False, safe: bool = False): +def convert_to_numpy(data: Any, dtype: DtypeLike = None, wrap_sequence: bool = False, safe: bool = False) -> Any: """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array. @@ -215,7 +218,7 @@ def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False, return data -def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool = False, safe: bool = False): +def convert_to_cupy(data: Any, dtype: np.dtype | None = None, wrap_sequence: bool = False, safe: bool = False) -> Any: """ Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple, recursively check every item and convert it to cupy array. @@ -255,12 +258,12 @@ def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool def convert_data_type( data: Any, - output_type: Optional[Type[NdarrayTensor]] = None, - device: Union[None, str, torch.device] = None, - dtype: Union[DtypeLike, torch.dtype] = None, + output_type: type[NdarrayTensor] | None = None, + device: None | str | torch.device = None, + dtype: DtypeLike | torch.dtype = None, wrap_sequence: bool = False, safe: bool = False, -) -> Tuple[NdarrayTensor, type, Optional[torch.device]]: +) -> tuple[NdarrayTensor, type, torch.device | None]: """ Convert to `MetaTensor`, `torch.Tensor` or `np.ndarray` from `MetaTensor`, `torch.Tensor`, `np.ndarray`, `float`, `int`, etc. @@ -325,11 +328,11 @@ def convert_data_type( def convert_to_dst_type( src: Any, dst: NdarrayTensor, - dtype: Union[DtypeLike, torch.dtype, None] = None, + dtype: DtypeLike | torch.dtype | None = None, wrap_sequence: bool = False, - device: Union[None, str, torch.device] = None, + device: None | str | torch.device = None, safe: bool = False, -) -> Tuple[NdarrayTensor, type, Optional[torch.device]]: +) -> tuple[NdarrayTensor, type, torch.device | None]: """ Convert source data to the same data type and device as the destination data. If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, @@ -352,7 +355,7 @@ def convert_to_dst_type( device = dst.device if device is None and isinstance(dst, torch.Tensor) else device if dtype is None: - dtype = dst.dtype + dtype = getattr(dst, "dtype", None) # sequence has no dtype copy_meta = False output_type: Any @@ -375,7 +378,7 @@ def convert_to_dst_type( return output, _type, _device -def convert_to_list(data: Union[Sequence, torch.Tensor, np.ndarray]) -> list: +def convert_to_list(data: Sequence | torch.Tensor | np.ndarray) -> list: """ Convert to list from `torch.Tensor`/`np.ndarray`/`list`/`tuple` etc. Args: @@ -387,7 +390,7 @@ def convert_to_list(data: Union[Sequence, torch.Tensor, np.ndarray]) -> list: return data.tolist() if isinstance(data, (torch.Tensor, np.ndarray)) else list(data) -def get_dtype_bound_value(dtype: Union[DtypeLike, torch.dtype]): +def get_dtype_bound_value(dtype: DtypeLike | torch.dtype) -> tuple[float, float]: """ Get dtype bound value Args: @@ -406,7 +409,7 @@ def get_dtype_bound_value(dtype: Union[DtypeLike, torch.dtype]): return (np.iinfo(dtype).min, np.iinfo(dtype).max) -def safe_dtype_range(data: Any, dtype: Union[DtypeLike, torch.dtype] = None): +def safe_dtype_range(data: Any, dtype: DtypeLike | torch.dtype = None) -> Any: """ Utility to safely convert the input data to target dtype. diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py index 49a628b66f..1f2bb7d024 100644 --- a/monai/visualize/__init__.py +++ b/monai/visualize/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer from .gradient_based import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad from .img2tensorboard import add_animated_gif, make_animated_gif_summary, plot_2d_or_3d_image diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 8654fbcc71..81d0bb32c4 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings -from typing import Callable, Dict, Optional, Sequence, Union, cast +from collections.abc import Callable, Sequence +from typing import cast import numpy as np import torch @@ -51,8 +54,8 @@ class ModelWithHooks: def __init__( self, - nn_module, - target_layer_names: Union[str, Sequence[str]], + nn_module: nn.Module, + target_layer_names: str | Sequence[str], register_forward: bool = False, register_backward: bool = False, ): @@ -67,10 +70,10 @@ def __init__( self.model = nn_module self.target_layers = ensure_tuple(target_layer_names) - self.gradients: Dict[str, torch.Tensor] = {} - self.activations: Dict[str, torch.Tensor] = {} - self.score: Optional[torch.Tensor] = None - self.class_idx: Optional[int] = None + self.gradients: dict[str, torch.Tensor] = {} + self.activations: dict[str, torch.Tensor] = {} + self.score: torch.Tensor | None = None + self.class_idx: int | None = None self.register_backward = register_backward self.register_forward = register_forward @@ -104,7 +107,7 @@ def _hook(_module, _input, output): return _hook - def get_layer(self, layer_id: Union[str, Callable]): + def get_layer(self, layer_id: str | Callable[[nn.Module], nn.Module]) -> nn.Module: """ Args: @@ -119,7 +122,7 @@ def get_layer(self, layer_id: Union[str, Callable]): if isinstance(layer_id, str): for name, mod in self.model.named_modules(): if name == layer_id: - return mod + return cast(nn.Module, mod) raise NotImplementedError(f"Could not find {layer_id}.") def class_score(self, logits: torch.Tensor, class_idx: int) -> torch.Tensor: @@ -259,7 +262,7 @@ def __init__( self, nn_module: nn.Module, target_layers: str, - fc_layers: Union[str, Callable] = "fc", + fc_layers: str | Callable = "fc", upsampler: Callable = default_upsampler, postprocessing: Callable = default_normalizer, ) -> None: diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index 7ab6ef260d..3b427c0dee 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -12,7 +12,7 @@ from __future__ import annotations from functools import partial -from typing import Callable +from typing import Any, Callable import torch @@ -84,7 +84,9 @@ def model(self, m): else: self._model = m # replace the ModelWithHooks - def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph=True, **kwargs) -> torch.Tensor: + def get_grad( + self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph: bool = True, **kwargs: Any + ) -> torch.Tensor: if x.shape[0] != 1: raise ValueError("expect batch size of 1") x.requires_grad = True @@ -93,7 +95,7 @@ def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_gra grad: torch.Tensor = x.grad.detach() # type: ignore return grad - def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs: Any) -> torch.Tensor: return self.get_grad(x, index, **kwargs) @@ -125,7 +127,7 @@ def __init__( else: self.range = range - def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs: Any) -> torch.Tensor: stdev = (self.stdev_spread * (x.max() - x.min())).item() total_gradients = torch.zeros_like(x) for _ in self.range(self.n_samples): @@ -156,7 +158,7 @@ class GuidedBackpropGrad(VanillaGrad): (https://arxiv.org/abs/1412.6806) """ - def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs: Any) -> torch.Tensor: with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False): return super().__call__(x, index, **kwargs) @@ -166,6 +168,6 @@ class GuidedBackpropSmoothGrad(SmoothGrad): Compute gradient-based saliency maps based on both ``GuidedBackpropGrad`` and ``SmoothGrad``. """ - def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs: Any) -> torch.Tensor: with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False): return super().__call__(x, index, **kwargs) diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 6dd7abcbf1..e7884e9b1f 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import numpy as np import torch @@ -20,21 +22,30 @@ PIL, _ = optional_import("PIL") GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image") -SummaryX, _ = optional_import("tensorboardX.proto.summary_pb2", name="Summary") -SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") -SummaryWriterX, has_tensorboardx = optional_import("tensorboardX", name="SummaryWriter") if TYPE_CHECKING: from tensorboard.compat.proto.summary_pb2 import Summary + from tensorboardX import SummaryWriter as SummaryWriterX + from tensorboardX.proto.summary_pb2 import Summary as SummaryX + from torch.utils.tensorboard import SummaryWriter + + has_tensorboardx = True else: Summary, _ = optional_import("tensorboard.compat.proto.summary_pb2", name="Summary") + SummaryX, _ = optional_import("tensorboardX.proto.summary_pb2", name="Summary") + SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") + SummaryWriterX, has_tensorboardx = optional_import("tensorboardX", name="SummaryWriter") __all__ = ["make_animated_gif_summary", "add_animated_gif", "plot_2d_or_3d_image"] def _image3_animated_gif( - tag: str, image: Union[np.ndarray, torch.Tensor], writer, frame_dim: int = 0, scale_factor: float = 1.0 -): + tag: str, + image: np.ndarray | torch.Tensor, + writer: SummaryWriter | SummaryWriterX | None, + frame_dim: int = 0, + scale_factor: float = 1.0, +) -> Any: """Function to actually create the animated gif. Args: @@ -68,8 +79,8 @@ def _image3_animated_gif( def make_animated_gif_summary( tag: str, - image: Union[np.ndarray, torch.Tensor], - writer=None, + image: np.ndarray | torch.Tensor, + writer: SummaryWriter | SummaryWriterX | None = None, max_out: int = 3, frame_dim: int = -3, scale_factor: float = 1.0, @@ -93,7 +104,7 @@ def make_animated_gif_summary( summary_op = [] for it_i in range(min(max_out, list(image.shape)[0])): - one_channel_img: Union[torch.Tensor, np.ndarray] = ( + one_channel_img: torch.Tensor | np.ndarray = ( image[it_i, :, :, :].squeeze(dim=0) if isinstance(image, torch.Tensor) else image[it_i, :, :, :] ) summary_op.append( @@ -103,13 +114,13 @@ def make_animated_gif_summary( def add_animated_gif( - writer, + writer: SummaryWriter | SummaryWriterX, tag: str, - image_tensor: Union[np.ndarray, torch.Tensor], + image_tensor: np.ndarray | torch.Tensor, max_out: int = 3, frame_dim: int = -3, scale_factor: float = 1.0, - global_step: Optional[int] = None, + global_step: int | None = None, ) -> None: """Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter. @@ -133,9 +144,9 @@ def add_animated_gif( def plot_2d_or_3d_image( - data: Union[NdarrayTensor, List[NdarrayTensor]], + data: NdarrayTensor | list[NdarrayTensor], step: int, - writer, + writer: SummaryWriter | SummaryWriterX, index: int = 0, max_channels: int = 1, frame_dim: int = -3, diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 61413c038c..b61a132147 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sequence -from typing import Callable, Optional, Tuple, Union +from __future__ import annotations + +from collections.abc import Callable, Mapping, Sequence +from typing import Any import numpy as np import torch @@ -84,16 +86,16 @@ class OcclusionSensitivity: def __init__( self, nn_module: nn.Module, - pad_val: Optional[float] = None, - mask_size: Union[int, Sequence] = 16, + pad_val: float | None = None, + mask_size: int | Sequence = 16, n_batch: int = 16, - stride: Union[int, Sequence] = 1, + stride: int | Sequence = 1, per_channel: bool = True, - upsampler: Optional[Callable] = default_upsampler, + upsampler: Callable | None = default_upsampler, verbose: bool = True, - mode: Union[str, float, Callable] = "gaussian", + mode: str | float | Callable = "gaussian", overlap: float = 0.25, - activate: Union[bool, Callable] = True, + activate: bool | Callable = True, ) -> None: """ Occlusion sensitivity constructor. @@ -132,13 +134,13 @@ def __init__( self.mode = mode @staticmethod - def constant_occlusion(x: torch.Tensor, val: float, mask_size: Sequence) -> Tuple[float, torch.Tensor]: + def constant_occlusion(x: torch.Tensor, val: float, mask_size: Sequence) -> tuple[float, torch.Tensor]: """Occlude with a constant occlusion. Multiplicative is zero, additive is constant value.""" ones = torch.ones((*x.shape[:2], *mask_size), device=x.device, dtype=x.dtype) return 0, ones * val @staticmethod - def gaussian_occlusion(x: torch.Tensor, mask_size, sigma=0.25) -> Tuple[torch.Tensor, float]: + def gaussian_occlusion(x: torch.Tensor, mask_size: Sequence, sigma: float = 0.25) -> tuple[torch.Tensor, float]: """ For Gaussian occlusion, Multiplicative is 1-Gaussian, additive is zero. Default sigma of 0.25 empirically shown to give reasonable kernel, see here: @@ -165,12 +167,12 @@ def predictor( cropped_grid: torch.Tensor, nn_module: nn.Module, x: torch.Tensor, - mul: Union[torch.Tensor, float], - add: Union[torch.Tensor, float], + mul: torch.Tensor | float, + add: torch.Tensor | float, mask_size: Sequence, occ_mode: str, - activate: Union[bool, Callable], - module_kwargs, + activate: bool | Callable, + module_kwargs: Mapping[str, Any], ) -> torch.Tensor: """ Predictor function to be passed to the sliding window inferer. Takes a cropped meshgrid, @@ -239,7 +241,7 @@ def predictor( @staticmethod def crop_meshgrid( grid: MetaTensor, b_box: Sequence, mask_size: Sequence - ) -> Tuple[MetaTensor, SpatialCrop, Sequence]: + ) -> tuple[MetaTensor, SpatialCrop, Sequence]: """Crop the meshgrid so we only perform occlusion sensitivity on a subsection of the image.""" # distance from center of mask to edge is -1 // 2. mask_edge = [(m - 1) // 2 for m in mask_size] @@ -264,8 +266,8 @@ def crop_meshgrid( return cropped, cropper, mask_size def __call__( - self, x: torch.Tensor, b_box: Optional[Sequence] = None, **kwargs - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, x: torch.Tensor, b_box: Sequence | None = None, **kwargs: Any + ) -> tuple[torch.Tensor, torch.Tensor]: """ Args: x: Image to use for inference. Should be a tensor consisting of 1 batch. @@ -311,8 +313,8 @@ def __call__( raise ValueError(f"Image (spatial shape) {grid.shape[2:]} should be bigger than mask {mask_size}.") # get additive and multiplicative factors if they are unchanged for all patches (i.e., not mean_patch) - add: Optional[Union[float, torch.Tensor]] - mul: Optional[Union[float, torch.Tensor]] + add: float | torch.Tensor | None + mul: float | torch.Tensor | None # multiply by 0, add value if isinstance(self.mode, float): mul, add = self.constant_occlusion(x, self.mode, mask_size) diff --git a/monai/visualize/utils.py b/monai/visualize/utils.py index e722a1f0c5..f6718fe7a5 100644 --- a/monai/visualize/utils.py +++ b/monai/visualize/utils.py @@ -9,42 +9,48 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import numpy as np import torch -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import DtypeLike, NdarrayOrTensor from monai.transforms.croppad.array import SpatialPad from monai.transforms.utils import rescale_array from monai.transforms.utils_pytorch_numpy_unification import repeat from monai.utils.module import optional_import from monai.utils.type_conversion import convert_data_type, convert_to_dst_type -plt, _ = optional_import("matplotlib", name="pyplot") -cm, _ = optional_import("matplotlib", name="cm") +if TYPE_CHECKING: + from matplotlib import cm + from matplotlib import pyplot as plt +else: + plt, _ = optional_import("matplotlib", name="pyplot") + cm, _ = optional_import("matplotlib", name="cm") __all__ = ["matshow3d", "blend_images"] def matshow3d( - volume, - fig=None, - title: Optional[str] = None, - figsize=(10, 10), - frames_per_row: Optional[int] = None, + volume: NdarrayOrTensor, + fig: Any = None, + title: str | None = None, + figsize: tuple[int, int] = (10, 10), + frames_per_row: int | None = None, frame_dim: int = -3, - channel_dim: Optional[int] = None, - vmin=None, - vmax=None, + channel_dim: int | None = None, + vmin: float | None = None, + vmax: float | None = None, every_n: int = 1, interpolation: str = "none", - show=False, - fill_value=np.nan, + show: bool = False, + fill_value: Any = np.nan, margin: int = 1, - dtype=np.float32, - **kwargs, -): + dtype: DtypeLike = np.float32, + **kwargs: Any, +) -> tuple[Any, np.ndarray]: """ Create a 3D volume figure as a grid of images. @@ -160,11 +166,11 @@ def matshow3d( def blend_images( image: NdarrayOrTensor, label: NdarrayOrTensor, - alpha: Union[float, NdarrayOrTensor] = 0.5, + alpha: float | NdarrayOrTensor = 0.5, cmap: str = "hsv", rescale_arrays: bool = True, transparent_background: bool = True, -): +) -> NdarrayOrTensor: """ Blend an image and a label. Both should have the shape CHW[D]. The image may have C==1 or 3 channels (greyscale or RGB). @@ -193,7 +199,6 @@ def blend_images( raise ValueError("image and label should have matching spatial sizes.") if isinstance(alpha, (np.ndarray, torch.Tensor)): if image.shape[1:] != alpha.shape[1:]: # pytype: disable=attribute-error,invalid-directive - raise ValueError("if alpha is image, size should match input image and label.") # rescale arrays to [0, 1] if desired @@ -204,7 +209,7 @@ def blend_images( if image.shape[0] == 1: image = repeat(image, 3, axis=0) - def get_label_rgb(cmap: str, label: NdarrayOrTensor): + def get_label_rgb(cmap: str, label: NdarrayOrTensor) -> NdarrayOrTensor: _cmap = cm.get_cmap(cmap) label_np, *_ = convert_data_type(label, np.ndarray) label_rgb_np = _cmap(label_np[0]) diff --git a/monai/visualize/visualizer.py b/monai/visualize/visualizer.py index 05ebb2e280..e7f5d9bbbe 100644 --- a/monai/visualize/visualizer.py +++ b/monai/visualize/visualizer.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from __future__ import annotations + +from collections.abc import Callable, Sized import torch import torch.nn.functional as F @@ -19,7 +21,7 @@ __all__ = ["default_upsampler"] -def default_upsampler(spatial_size, align_corners=False) -> Callable[[torch.Tensor], torch.Tensor]: +def default_upsampler(spatial_size: Sized, align_corners: bool = False) -> Callable[[torch.Tensor], torch.Tensor]: """ A linear interpolation method for upsampling the feature map. The output of this function is a callable `func`, @@ -27,7 +29,6 @@ def default_upsampler(spatial_size, align_corners=False) -> Callable[[torch.Tens """ def up(x): - linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] interp_mode = linear_mode[len(spatial_size) - 1] return F.interpolate(x, size=spatial_size, mode=str(interp_mode.value), align_corners=align_corners) diff --git a/pyproject.toml b/pyproject.toml index 4a28aa10a0..d71613fb43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,3 +35,35 @@ exclude = ''' [tool.pycln] all = true exclude = "monai/bundle/__main__.py" + +[tool.pytype] +# Space-separated list of files or directories to exclude. +exclude = ["versioneer.py", "_version.py"] +# Space-separated list of files or directories to process. +inputs = ["monai"] +# Keep going past errors to analyze as many files as possible. +keep_going = true +# Run N jobs in parallel. +jobs = 8 +# All pytype output goes here. +output = ".pytype" +# Paths to source code directories, separated by ':'. +pythonpath = "." +# Check attribute values against their annotations. +check_attribute_types = true +# Check container mutations against their annotations. +check_container_types = true +# Check parameter defaults and assignments against their annotations. +check_parameter_types = true +# Check variable values against their annotations. +check_variable_types = true +# Comma or space separated list of error names to ignore. +disable = ["pyi-error"] +# Report errors. +report_errors = true +# Experimental: Infer precise return types even for invalid function calls. +precise_return = true +# Experimental: solve unknown types to label with structural types. +protocols = true +# Experimental: Only load submodules that are explicitly imported. +strict_import = false diff --git a/requirements-dev.txt b/requirements-dev.txt index 84655dc828..8b4433b39d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,12 +1,12 @@ # Full requirements for developments -r requirements-min.txt -pytorch-ignite==0.4.10 +pytorch-ignite==0.4.11 gdown>=4.4.0 scipy itk>=5.2 nibabel pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571 -tensorboard +tensorboard>=2.6 # https://github.com/Project-MONAI/MONAI/issues/5776 scikit-image>=0.19.0 tqdm>=4.47.0 lmdb @@ -27,10 +27,6 @@ mypy>=0.790 ninja torchvision psutil -Sphinx==3.5.3 -recommonmark==0.6.0 -sphinx-autodoc-typehints==1.11.1 -sphinx-rtd-theme==0.5.2 cucim==22.8.1; platform_system == "Linux" openslide-python==1.1.2 imagecodecs; platform_system == "Linux" or platform_system == "Darwin" @@ -39,7 +35,8 @@ pandas requests einops transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 -mlflow +mlflow>=1.28.0 +clearml>=1.10.0rc0 matplotlib!=3.5.0 tensorboardX types-PyYAML @@ -52,3 +49,4 @@ pydicom h5py nni optuna +git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded diff --git a/requirements-min.txt b/requirements-min.txt index 63906b4a94..ad0bb1ef20 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -1,5 +1,5 @@ # Requirements for minimal tests -r requirements.txt -setuptools>=50.3.0,!=60.0.0,!=60.6.0 +setuptools>=50.3.0,<66.0.0,!=60.6.0 coverage>=5.5 parameterized diff --git a/requirements.txt b/requirements.txt index ba7d7be6d7..5a704330de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ torch>=1.8 -numpy>=1.17 +numpy>=1.20 diff --git a/runtests.sh b/runtests.sh index a752d16d6d..64d78db682 100755 --- a/runtests.sh +++ b/runtests.sh @@ -159,9 +159,16 @@ function clang_format { while read i; do $clang_format_tool -style=file -i $i; done } +function is_pip_installed() { + return $(${PY_EXE} -c "import sys, pkgutil; sys.exit(0 if pkgutil.find_loader(sys.argv[1]) else 1)" $1) +} + function clean_py { - # remove coverage history - ${cmdPrefix}${PY_EXE} -m coverage erase + if is_pip_installed coverage + then + # remove coverage history + ${cmdPrefix}${PY_EXE} -m coverage erase + fi # uninstall the development package echo "Uninstalling MONAI development files..." @@ -200,10 +207,6 @@ function print_style_fail_msg() { echo "Please run auto style fixes: ${green}./runtests.sh --autofix${noColor}" } -function is_pip_installed() { - return $(${PY_EXE} -c "import sys, pkgutil; sys.exit(0 if pkgutil.find_loader(sys.argv[1]) else 1)" $1) -} - function list_unittests() { ${PY_EXE} - << END import unittest @@ -255,8 +258,6 @@ do doIsortFormat=true doFlake8Format=true doPylintFormat=true - doPytypeFormat=true - doMypyFormat=true doCopyRight=true ;; --disttests) @@ -567,7 +568,7 @@ then else ${cmdPrefix}${PY_EXE} -m pytype --version - ${cmdPrefix}${PY_EXE} -m pytype -j ${NUM_PARALLEL} --python-version="$(${PY_EXE} -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")" + ${cmdPrefix}${PY_EXE} -m pytype -j ${NUM_PARALLEL} --python-version="$(${PY_EXE} -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")" "$(pwd)" pytype_status=$? if [ ${pytype_status} -ne 0 ] @@ -629,6 +630,11 @@ fi if [ $doCoverage = true ] then echo "${separator}${blue}coverage${noColor}" + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed coverage + then + install_deps + fi cmd="${PY_EXE} -m coverage run --append" fi @@ -680,6 +686,11 @@ fi if [ $doCoverage = true ] then echo "${separator}${blue}coverage${noColor}" + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed coverage + then + install_deps + fi ${cmdPrefix}${PY_EXE} -m coverage combine --append .coverage/ ${cmdPrefix}${PY_EXE} -m coverage report --ignore-errors fi diff --git a/setup.cfg b/setup.cfg index 30352d50db..379a031b89 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,7 +43,7 @@ setup_requires = ninja install_requires = torch>=1.8 - numpy>=1.17 + numpy>=1.20 [options.extras_require] all = @@ -53,7 +53,7 @@ all = pillow tensorboard gdown>=4.4.0 - pytorch-ignite==0.4.10 + pytorch-ignite==0.4.11 torchvision itk>=5.2 tqdm>=4.47.0 @@ -66,7 +66,8 @@ all = pandas einops transformers<4.22 - mlflow + mlflow>=1.28.0 + clearml>=1.10.0rc0 matplotlib tensorboardX pyyaml @@ -90,7 +91,7 @@ tensorboard = gdown = gdown>=4.4.0 ignite = - pytorch-ignite==0.4.10 + pytorch-ignite==0.4.11 torchvision = torchvision itk = @@ -133,6 +134,9 @@ pydicom = pydicom h5py = h5py +# # workaround https://github.com/Project-MONAI/MONAI/issues/5882 +# MetricsReloaded = +# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded [flake8] select = B,C,E,F,N,P,T4,W,B9 @@ -141,6 +145,8 @@ max_line_length = 120 # E501 is not flexible enough, we're using B950 instead # N812 lowercase 'torch.nn.functional' imported as non lowercase 'F' # B023 https://github.com/Project-MONAI/MONAI/issues/4627 +# B028 https://github.com/Project-MONAI/MONAI/issues/5855 +# B907 https://github.com/Project-MONAI/MONAI/issues/5868 ignore = E203 E501 @@ -151,6 +157,8 @@ ignore = N812 B023 B905 + B028 + B907 per_file_ignores = __init__.py: F401, __main__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py @@ -160,6 +168,8 @@ profile = black line_length = 120 skip = .git, .eggs, venv, .venv, versioneer.py, _version.py, conf.py, monai/__init__.py skip_glob = *.pyi +add_imports = from __future__ import annotations +append_only = true [versioneer] VCS = git @@ -176,8 +186,8 @@ ignore_missing_imports = True no_implicit_optional = True # Warns about casting an expression to its inferred type. warn_redundant_casts = True -# Warns about unneeded # type: ignore comments. -# warn_unused_ignores = True +# No error on unneeded # type: ignore comments. +warn_unused_ignores = False # Shows a warning when returning a value with type Any from a function declared with a non-Any return type. warn_return_any = True # Prohibit equality checks, identity checks, and container checks between non-overlapping types. @@ -192,6 +202,9 @@ pretty = False warn_unused_configs = True # Make arguments prepended via Concatenate be truly positional-only. strict_concatenate = True +# Allows variables to be redefined with an arbitrary type, +# as long as the redefinition is in the same block and nesting level as the original definition. +# allow_redefinition = True exclude = venv/ @@ -208,39 +221,13 @@ ignore_errors = True ignore_errors = True [mypy-monai.*] +# Also check the body of functions with no types in their type signature. check_untyped_defs = True +# Warns about usage of untyped decorators. +disallow_untyped_decorators = True -[pytype] -# Space-separated list of files or directories to exclude. -exclude = versioneer.py _version.py -# Space-separated list of files or directories to process. -inputs = monai -# Keep going past errors to analyze as many files as possible. -keep_going = True -# Run N jobs in parallel. -jobs = 8 -# All pytype output goes here. -output = .pytype -# Paths to source code directories, separated by ':'. -pythonpath = . -# Check attribute values against their annotations. -check_attribute_types = True -# Check container mutations against their annotations. -check_container_types = True -# Check parameter defaults and assignments against their annotations. -check_parameter_types = True -# Check variable values against their annotations. -check_variable_types = True -# Comma or space separated list of error names to ignore. -disable = pyi-error -# Report errors. -report_errors = True -# Experimental: Infer precise return types even for invalid function calls. -precise_return = True -# Experimental: solve unknown types to label with structural types. -protocols = True -# Experimental: Only load submodules that are explicitly imported. -strict_import = False +[mypy-monai.visualize.*,monai.utils.*,monai.optimizers.*,monai.losses.*,monai.inferers.*,monai.config.*,monai._extensions.*,monai.fl.*,monai.engines.*,monai.handlers.*,monai.auto3dseg.*,monai.bundle.*,monai.metrics.*,monai.apps.*] +disallow_incomplete_defs = True [coverage:run] concurrency = multiprocessing diff --git a/setup.py b/setup.py index faf63a9246..b90d9d0976 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import glob import os import re diff --git a/tests/__init__.py b/tests/__init__.py index 0d6e28a679..58422f803e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest import warnings diff --git a/tests/clang_format_utils.py b/tests/clang_format_utils.py index 1b13ce0ac3..11483e957d 100644 --- a/tests/clang_format_utils.py +++ b/tests/clang_format_utils.py @@ -12,6 +12,8 @@ # this file is adapted from # github/pytorch/pytorch/blob/63d62d3e44a0a4ec09d94f30381d49b78cc5b095/tools/clang_format_utils.py +from __future__ import annotations + import os import platform import stat diff --git a/tests/croppers.py b/tests/croppers.py index 8f78249d90..6b5933458e 100644 --- a/tests/croppers.py +++ b/tests/croppers.py @@ -9,12 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy import numpy as np from monai.data.meta_tensor import MetaTensor +from monai.transforms import Randomizable +from monai.transforms.lazy.functional import apply_transforms from monai.transforms.transform import MapTransform from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -101,3 +105,61 @@ def multi_inverse(self, input_shape, init_params): # there should be as many zeros as elements missing from uniques missing = input_data.size - len(uniques) self.assertEqual((inv_np == 0).sum(), missing) + + def crop_test_pending_ops(self, input_param, input_shape, align_corners=False): + crop_fn = self.Cropper(**input_param) + data = self.get_arr(input_shape) + is_map = isinstance(crop_fn, MapTransform) + im = MetaTensor(data, meta={"a": "b", "affine": np.eye(len(input_shape))}) + input_data = {"img": im} if is_map else im + # non-lazy + result_non_lazy = crop_fn(input_data) + expected = result_non_lazy["img"] if is_map else result_non_lazy + self.assertIsInstance(expected, MetaTensor) + # lazy + crop_fn.lazy_evaluation = True + pending_result = crop_fn(input_data) + pending_result = pending_result["img"] if is_map else pending_result + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + # only support nearest + result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + # compare + assert_allclose(result, expected, rtol=1e-5) + + def crop_test_combine_ops(self, funcs, input_shape): + _funcs = [] + for func in funcs: + for _func, _params in func.items(): + _funcs.append(_func(**_params)) + is_map = isinstance(_funcs[0], MapTransform) + data = self.get_arr(input_shape) + im = MetaTensor(data, meta={"a": "b", "affine": np.eye(len(input_shape))}) + input_data = {"img": im} if is_map else im + + # non-lazy + non_lazy_result = input_data + for _func in _funcs: + if isinstance(_func, Randomizable): + _func.set_random_state(seed=123) + non_lazy_result = _func(non_lazy_result) + expected = non_lazy_result["img"] if is_map else non_lazy_result + self.assertIsInstance(expected, MetaTensor) + + # lazy + pending_result = input_data + for _func in _funcs: + _func.lazy_evaluation = True + if isinstance(_func, Randomizable): + _func.set_random_state(seed=123) + pending_result = _func(pending_result) + pending_result = pending_result["img"] if is_map else pending_result + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + # TODO: mode="bilinear" may report error + result = apply_transforms(pending_result, mode="nearest", align_corners=False)[0] + + # compare + assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/hvd_evenly_divisible_all_gather.py b/tests/hvd_evenly_divisible_all_gather.py index a79038f4be..c7baac2bc9 100644 --- a/tests/hvd_evenly_divisible_all_gather.py +++ b/tests/hvd_evenly_divisible_all_gather.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import torch from monai.utils import evenly_divisible_all_gather diff --git a/tests/lazy_transforms_utils.py b/tests/lazy_transforms_utils.py new file mode 100644 index 0000000000..012b39dceb --- /dev/null +++ b/tests/lazy_transforms_utils.py @@ -0,0 +1,77 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from monai.data import set_track_meta +from monai.transforms import Randomizable +from monai.transforms.lazy.functional import apply_transforms +from tests.utils import assert_allclose + +apply_transforms_kwargs = ("pending", "mode", "padding_mode", "dtype", "align_corners") + + +def get_apply_param(init_param=None, call_param=None, params=apply_transforms_kwargs): + apply_param = {} + for key in apply_transforms_kwargs: + if init_param: + if key in init_param.keys(): + apply_param[key] = init_param[key] + if call_param: + if key in call_param.keys(): + apply_param[key] = call_param[key] + return apply_param + + +def test_resampler_lazy( + resampler, + expected_output, + init_param=None, + call_param=None, + output_key=None, + output_idx=None, + rtol=1e-5, + atol=1e-7, + skip_shape_check=False, + seed=None, +): + """ + This test function is used to test the consistency between non-lazy and lazy transforms. + Args: + resampler: instance of a resampling transform. + expected_output: output of non-lazy transform. + init_param: parameters that are used to initialize the transform. + call_param: parameters that are used when calling the transform. + output_key: key to get the output of the transform. This argument is used for dictionary based transforms. + output_idx: index to get the expected output from multiple outputs of the transform. + rtol: relative tolerance. This argument is only used to compare the output. + atol: absolute tolerance. This argument is only used to compare the output. + skip_shape_check: skip the check of shapes. + seed: set the random state with an integer seed. This argument is used for randomizable transforms. + + """ + if isinstance(resampler, Randomizable): + resampler.set_random_state(seed=seed) + set_track_meta(True) + resampler.lazy_evaluation = True + pending_output = resampler(**call_param) + if output_idx is not None: + expected_output, pending_output = expected_output[output_idx], pending_output[output_idx] + if output_key is not None: + non_lazy_out, lazy_out = expected_output[output_key], pending_output[output_key] + else: + non_lazy_out, lazy_out = expected_output, pending_output + assert_allclose(lazy_out.peek_pending_affine(), non_lazy_out.affine) + if not skip_shape_check: + assert_allclose(lazy_out.peek_pending_shape(), non_lazy_out.shape[1:4]) + apply_param = get_apply_param(init_param, call_param) + lazy_out = apply_transforms(lazy_out, **apply_param)[0] + assert_allclose(lazy_out, non_lazy_out, rtol=rtol, atol=atol) diff --git a/tests/min_tests.py b/tests/min_tests.py index e5d5dac41d..1b1f4f450a 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import glob import os import sys @@ -65,6 +67,7 @@ def run_testsuit(): "test_global_mutual_information_loss", "test_grid_patch", "test_gmm", + "test_handler_metrics_reloaded", "test_handler_checkpoint_loader", "test_handler_checkpoint_saver", "test_handler_classification_saver", @@ -129,7 +132,6 @@ def run_testsuit(): "test_mlp", "test_nifti_header_revise", "test_nifti_rw", - "test_nifti_saver", "test_nuclick_transforms", "test_nrrd_reader", "test_occlusion_sensitivity", @@ -141,7 +143,6 @@ def run_testsuit(): "test_pil_reader", "test_plot_2d_or_3d_image", "test_png_rw", - "test_png_saver", "test_prepare_batch_default", "test_prepare_batch_extra_input", "test_prepare_batch_hovernet", @@ -190,6 +191,9 @@ def run_testsuit(): "test_bundle_utils", "test_bundle_init_bundle", "test_fastmri_reader", + "test_metrics_reloaded", + "test_spatial_combine_transforms", + "test_bundle_workflow", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" @@ -209,7 +213,6 @@ def run_testsuit(): if __name__ == "__main__": - # testing import submodules from monai.utils.module import load_submodules diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index 2b376c3c2d..f380626d73 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import sys import tempfile diff --git a/tests/padders.py b/tests/padders.py index 367d2059b9..ded427e5a1 100644 --- a/tests/padders.py +++ b/tests/padders.py @@ -9,20 +9,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import List import numpy as np import torch from monai.data.meta_tensor import MetaTensor +from monai.transforms import Compose +from monai.transforms.lazy.functional import apply_transforms from monai.transforms.transform import MapTransform from monai.utils.enums import NumpyPadMode, PytorchPadMode from tests.utils import TEST_NDARRAYS_ALL, assert_allclose MODES = [] # Test modes -NP_MODES: List = [ +NP_MODES: list = [ "constant", "edge", # `reflect` mode is not supported in some PyTorch versions, skip the test @@ -44,6 +47,8 @@ MODES += PT_MODES MODES += [PytorchPadMode(i) for i in PT_MODES] +TESTS_PENDING_MODE = [["constant", "zeros"], ["edge", "border"]] + class PadTest(unittest.TestCase): @staticmethod @@ -108,3 +113,56 @@ def pad_test_kwargs(self, unchanged_slices, **input_param): inv = padder.inverse(result) assert_allclose(im, inv, type_test=False) self.assertEqual(inv.applied_operations, []) + + def pad_test_pending_ops(self, input_param, input_shape): + for mode in TESTS_PENDING_MODE: + # TODO: One of the dim in the input data contains 1 report error. + pad_fn = self.Padder(mode=mode[0], **input_param) + data = self.get_arr(input_shape) + is_map = isinstance(pad_fn, MapTransform) + im = MetaTensor(data, meta={"a": "b", "affine": np.eye(len(input_shape))}) + input_data = {"img": im} if is_map else im + # non-lazy + result_non_lazy = pad_fn(input_data) + expected = result_non_lazy["img"] if is_map else result_non_lazy + self.assertIsInstance(expected, MetaTensor) + # lazy + pad_fn.lazy_evaluation = True + pending_result = pad_fn(input_data) + pending_result = pending_result["img"] if is_map else pending_result + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + # TODO: mode="bilinear" may report error + result = apply_transforms(pending_result, mode="nearest", padding_mode=mode[1], align_corners=False)[0] + # compare + assert_allclose(result, expected, rtol=1e-5) + + def pad_test_combine_ops(self, funcs, input_shape, expected_shape): + for mode in TESTS_PENDING_MODE: + # non-lazy + _funcs = [] + for func in funcs: + for _func, _params in func.items(): + _funcs.append(_func(mode=mode[0], **_params)) + trans = Compose(_funcs) + data = self.get_arr(input_shape) + is_map = isinstance(_funcs[0], MapTransform) + im = MetaTensor(data, meta={"a": "b", "affine": np.eye(len(input_shape))}) + input_data = {"img": im} if is_map else im + result_non_lazy = trans(input_data) + expected = result_non_lazy["img"] if is_map else result_non_lazy + self.assertIsInstance(expected, MetaTensor) + # lazy + pending_result = input_data + for _func in _funcs: + _func.lazy_evaluation = True + pending_result = _func(pending_result) + pending_result = pending_result["img"] if is_map else pending_result + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + # TODO: mode="bilinear" may report error + result = apply_transforms(pending_result, mode="nearest", padding_mode=mode[1], align_corners=False)[0] + # compare + assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/profile_subclass/cprofile_profiling.py b/tests/profile_subclass/cprofile_profiling.py index 0befa0f450..a40b54def6 100644 --- a/tests/profile_subclass/cprofile_profiling.py +++ b/tests/profile_subclass/cprofile_profiling.py @@ -12,6 +12,8 @@ Profiling MetaTensor """ +from __future__ import annotations + import cProfile import torch diff --git a/tests/profile_subclass/min_classes.py b/tests/profile_subclass/min_classes.py index 702ba73e21..7104ffcd59 100644 --- a/tests/profile_subclass/min_classes.py +++ b/tests/profile_subclass/min_classes.py @@ -13,6 +13,8 @@ Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark """ +from __future__ import annotations + import torch __all__ = ["SubTensor", "SubWithTorchFunc"] diff --git a/tests/profile_subclass/profiling.py b/tests/profile_subclass/profiling.py index 46047b619c..ffa6a8b17d 100644 --- a/tests/profile_subclass/profiling.py +++ b/tests/profile_subclass/profiling.py @@ -12,6 +12,8 @@ Comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark """ +from __future__ import annotations + import argparse import torch diff --git a/tests/profile_subclass/pyspy_profiling.py b/tests/profile_subclass/pyspy_profiling.py index 1caeee69e7..c1b0963ba9 100644 --- a/tests/profile_subclass/pyspy_profiling.py +++ b/tests/profile_subclass/pyspy_profiling.py @@ -12,6 +12,8 @@ To be used with py-spy, comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark """ +from __future__ import annotations + import argparse import torch diff --git a/tests/runner.py b/tests/runner.py index 7356581365..7a7cc9f28f 100644 --- a/tests/runner.py +++ b/tests/runner.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import argparse import glob import inspect @@ -112,7 +114,6 @@ def get_default_pattern(loader): if __name__ == "__main__": - # Parse input arguments args = parse_args() diff --git a/tests/test_acn_block.py b/tests/test_acn_block.py index 4c12155fd8..2f3783cbb8 100644 --- a/tests/test_acn_block.py +++ b/tests/test_acn_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_activations.py b/tests/test_activations.py index 503ca0a350..0e83c73304 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch @@ -83,6 +85,13 @@ (1, 2, 5), ] +TEST_CASE_7 = [ + "geglu", + torch.tensor([[[-10, -8, -6, -4, -2, 0], [0, 2, 4, 6, 8, 10]]], dtype=torch.float32), + torch.tensor([[[1.27e-03, 3.64e-01, 0.00e00], [0.00e00, 1.60e01, 4.00e01]]]), + (1, 2, 3), +] + class TestActivations(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -99,7 +108,7 @@ def _compare(ret, out, shape): else: _compare(result, out, expected_shape) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_monai_activations_value_shape(self, input_param, img, out, expected_shape): act = Act[input_param]() result = act(img) diff --git a/tests/test_activationsd.py b/tests/test_activationsd.py index a8f8f600a4..22a275997c 100644 --- a/tests/test_activationsd.py +++ b/tests/test_activationsd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_adaptors.py b/tests/test_adaptors.py index f59bdaa15e..257c4346ad 100644 --- a/tests/test_adaptors.py +++ b/tests/test_adaptors.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import itertools import unittest diff --git a/tests/test_add_channeld.py b/tests/test_add_channeld.py index 9dc984aff3..0c0a1ffa49 100644 --- a/tests/test_add_channeld.py +++ b/tests/test_add_channeld.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_add_coordinate_channels.py b/tests/test_add_coordinate_channels.py index 5a483e25b9..cd33f98fd5 100644 --- a/tests/test_add_coordinate_channels.py +++ b/tests/test_add_coordinate_channels.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_add_coordinate_channelsd.py b/tests/test_add_coordinate_channelsd.py index c14ff0ba64..f5784928fd 100644 --- a/tests/test_add_coordinate_channelsd.py +++ b/tests/test_add_coordinate_channelsd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py index 116f96126f..140caa34ba 100644 --- a/tests/test_add_extreme_points_channel.py +++ b/tests/test_add_extreme_points_channel.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py index f9837e9ef4..5640e696fc 100644 --- a/tests/test_add_extreme_points_channeld.py +++ b/tests/test_add_extreme_points_channeld.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_adjust_contrast.py b/tests/test_adjust_contrast.py index 1c38d0edf3..c239f43346 100644 --- a/tests/test_adjust_contrast.py +++ b/tests/test_adjust_contrast.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_adjust_contrastd.py b/tests/test_adjust_contrastd.py index 2d674c6003..6de2658a5b 100644 --- a/tests/test_adjust_contrastd.py +++ b/tests/test_adjust_contrastd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_adn.py b/tests/test_adn.py index ffd28f8fc8..27e23a08d3 100644 --- a/tests/test_adn.py +++ b/tests/test_adn.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_affine.py b/tests/test_affine.py index 019a8f59a4..e8f7f33b17 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy @@ -17,7 +19,10 @@ from parameterized import parameterized from monai.data import MetaTensor, set_track_meta -from monai.transforms import Affine +from monai.transforms import Affine, Resize +from monai.transforms.lazy.functional import apply_transforms +from monai.utils import optional_import +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, assert_allclose, test_local_inversion TESTS = [] @@ -58,6 +63,13 @@ p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), ] ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device, align_corners=False), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) TESTS.append( [ dict( @@ -160,8 +172,11 @@ def test_affine(self, input_param, input_data, expected_val): input_copy = deepcopy(input_data["img"]) g = Affine(**input_param) result = g(**input_data) + output_idx = None if isinstance(result, tuple): - result = result[0] + output_idx = 0 + result = result[output_idx] + test_local_inversion(g, result, input_copy) assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4, type_test=False) @@ -173,6 +188,56 @@ def test_affine(self, input_param, input_data, expected_val): self.assertIsInstance(result, torch.Tensor) set_track_meta(True) + # test lazy + lazy_input_param = input_param.copy() + for align_corners in [True, False]: + lazy_input_param["align_corners"] = align_corners + resampler = Affine(**lazy_input_param) + non_lazy_result = resampler(**input_data) + test_resampler_lazy(resampler, non_lazy_result, lazy_input_param, input_data, output_idx=output_idx) + + +@unittest.skipUnless(optional_import("scipy")[1], "Requires scipy library.") +class TestAffineConsistency(unittest.TestCase): + @parameterized.expand([[7], [8], [9]]) + def test_affine_resize(self, s): + """s""" + im = np.arange(4).reshape(1, 2, 2).astype(float) + mat = np.array([[1 / s, 0, 0], [0, 1 / s, 0], [0, 0, 1]]) + sp_size = 2 * s + + def method_0(im, ac): + xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=sp_size) + xform.lazy_evaluation = True + out = xform(im) + out = apply_transforms(out, padding_mode="border", align_corners=ac)[0] + return out + + def method_1(im, ac): + xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=sp_size) + xform.lazy_evaluation = True + out = xform(im) + out = apply_transforms(out, mode=1, padding_mode="nearest", align_corners=ac)[0] + return out + + def method_2(im, ac): + xform = Affine(align_corners=ac, affine=mat, padding_mode="border", image_only=True, spatial_size=sp_size) + out = xform(im) + return out + + def method_3(im, ac): + xform = Affine( + align_corners=ac, affine=mat, mode=1, padding_mode="nearest", image_only=True, spatial_size=sp_size + ) + out = xform(im) + return out + + for call in (method_0, method_1, method_2, method_3): + for ac in (False, True): + out = call(im, ac) + ref = Resize(align_corners=ac, spatial_size=(sp_size, sp_size), mode="bilinear")(im) + assert_allclose(out, ref, rtol=1e-4, atol=1e-4, type_test=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 23651c8b6b..f3febbe0f3 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_affine_transform.py b/tests/test_affine_transform.py index 902437350a..39dc609167 100644 --- a/tests/test_affine_transform.py +++ b/tests/test_affine_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -23,14 +25,14 @@ TEST_NORM_CASES = [ [(4, 5), True, [[[0.666667, 0, -1], [0, 0.5, -1], [0, 0, 1]]]], - [(4, 5), True, [[[0.666667, 0, 0], [0, 0.5, 0], [0, 0, 1]]], True], + [(4, 5), True, [[[0.5, 0, 0], [0, 0.4, 0], [0, 0, 1]]], True], [ (2, 4, 5), True, [[[2.0, 0.0, 0.0, -1.0], [0.0, 0.6666667, 0.0, -1.0], [0.0, 0.0, 0.5, -1.0], [0.0, 0.0, 0.0, 1.0]]], ], [(4, 5), False, [[[0.5, 0.0, -0.75], [0.0, 0.4, -0.8], [0.0, 0.0, 1.0]]]], - [(4, 5), False, [[[0.5, 0.0, 0.25], [0.0, 0.4, 0.2], [0.0, 0.0, 1.0]]], True], + [(4, 5), False, [[[0.6666667, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 1.0]]], True], [(2, 4, 5), False, [[[1.0, 0.0, 0.0, -0.5], [0.0, 0.5, 0.0, -0.75], [0.0, 0.0, 0.4, -0.8], [0.0, 0.0, 0.0, 1.0]]]], ] @@ -68,7 +70,7 @@ (2, 4, 6), (3, 5, 3), False, - [[[1.5, 0.0, 0.0, 0.0], [0.0, 1.25, 0.0, 0.0], [0.0, 0.0, 0.5, 0.0], [0.0, 0.0, 0.0, 1.0]]], + [[[2.0, 0.0, 0.0, 0.0], [0.0, 1.3333334, 0.0, 0.0], [0.0, 0.0, 0.4, 0.0], [0.0, 0.0, 0.0, 1.0]]], True, ], ] @@ -131,7 +133,7 @@ class TestAffineTransform(unittest.TestCase): def test_affine_shift(self): affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]]) image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]]) - out = AffineTransform()(image, affine) + out = AffineTransform(align_corners=False)(image, affine) out = out.detach().cpu().numpy() expected = [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]] np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) @@ -139,7 +141,7 @@ def test_affine_shift(self): def test_affine_shift_1(self): affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]) image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]]) - out = AffineTransform()(image, affine) + out = AffineTransform(align_corners=False)(image, affine) out = out.detach().cpu().numpy() expected = [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]] np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) @@ -147,7 +149,7 @@ def test_affine_shift_1(self): def test_affine_shift_2(self): affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]) image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]]) - out = AffineTransform()(image, affine) + out = AffineTransform(align_corners=False)(image, affine) out = out.detach().cpu().numpy() expected = [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]] np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) @@ -155,7 +157,7 @@ def test_affine_shift_2(self): def test_zoom(self): affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]) image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0")) - out = AffineTransform((3, 2))(image, affine) + out = AffineTransform((3, 2), align_corners=False)(image, affine) expected = [[[[1, 3], [5, 7], [9, 11]]]] np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) @@ -163,21 +165,21 @@ def test_zoom_1(self): affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0")) out = AffineTransform()(image, affine, (1, 4)) - expected = [[[[1, 2, 3, 4]]]] + expected = [[[[2.333333, 3.333333, 4.333333, 5.333333]]]] np.testing.assert_allclose(out, expected, atol=_rtol) def test_zoom_2(self): affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32) image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0")) out = AffineTransform((1, 2))(image, affine) - expected = [[[[1, 3]]]] + expected = [[[[1.458333, 4.958333]]]] np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_zoom_zero_center(self): affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32) image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0")) out = AffineTransform((1, 2), zero_centered=True)(image, affine) - expected = [[[[3, 5]]]] + expected = [[[[5.5, 7.5]]]] np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol) def test_affine_transform_minimum(self): @@ -185,7 +187,7 @@ def test_affine_transform_minimum(self): affine = [[np.cos(t), -np.sin(t), 0], [np.sin(t), np.cos(t), 0], [0, 0, 1]] affine = torch.as_tensor(affine, device=torch.device("cpu:0"), dtype=torch.float32) image = torch.arange(24.0).view(1, 1, 4, 6).to(device=torch.device("cpu:0")) - out = AffineTransform()(image, affine) + out = AffineTransform(align_corners=False)(image, affine) out = out.detach().cpu().numpy() expected = [ [ @@ -204,7 +206,7 @@ def test_affine_transform_2d(self): affine = [[np.cos(t), -np.sin(t), 0], [np.sin(t), np.cos(t), 0], [0, 0, 1]] affine = torch.as_tensor(affine, device=torch.device("cpu:0"), dtype=torch.float32) image = torch.arange(24.0).view(1, 1, 4, 6).to(device=torch.device("cpu:0")) - xform = AffineTransform((3, 4), padding_mode="border", align_corners=True, mode="bilinear") + xform = AffineTransform((3, 4), padding_mode="border", align_corners=False, mode="bilinear") out = xform(image, affine) out = out.detach().cpu().numpy() expected = [ @@ -221,7 +223,7 @@ def test_affine_transform_2d(self): if torch.cuda.is_available(): affine = torch.as_tensor(affine, device=torch.device("cuda:0"), dtype=torch.float32) image = torch.arange(24.0).view(1, 1, 4, 6).to(device=torch.device("cuda:0")) - xform = AffineTransform(padding_mode="border", align_corners=True, mode="bilinear") + xform = AffineTransform(padding_mode="border", align_corners=False, mode="bilinear") out = xform(image, affine, (3, 4)) out = out.detach().cpu().numpy() expected = [ @@ -348,19 +350,19 @@ def test_forward_2d(self): expected = torch.nn.functional.grid_sample(x, grid, align_corners=False) expected = expected.detach().cpu().numpy() - actual = AffineTransform(normalized=True, reverse_indexing=False)(x, theta) + actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta) actual = actual.detach().cpu().numpy() np.testing.assert_allclose(actual, expected) np.testing.assert_allclose(list(theta.shape), [2, 2, 3]) theta = torch.Tensor([[0, -1, 0], [1, 0, 0]]) - actual = AffineTransform(normalized=True, reverse_indexing=False)(x, theta) + actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta) actual = actual.detach().cpu().numpy() np.testing.assert_allclose(actual, expected) np.testing.assert_allclose(list(theta.shape), [2, 3]) theta = torch.Tensor([[[0, -1, 0], [1, 0, 0]]]) - actual = AffineTransform(normalized=True, reverse_indexing=False)(x, theta) + actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta) actual = actual.detach().cpu().numpy() np.testing.assert_allclose(actual, expected) np.testing.assert_allclose(list(theta.shape), [1, 2, 3]) @@ -372,19 +374,19 @@ def test_forward_3d(self): expected = torch.nn.functional.grid_sample(x, grid, align_corners=False) expected = expected.detach().cpu().numpy() - actual = AffineTransform(normalized=True, reverse_indexing=False)(x, theta) + actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta) actual = actual.detach().cpu().numpy() np.testing.assert_allclose(actual, expected) np.testing.assert_allclose(list(theta.shape), [2, 3, 4]) theta = torch.Tensor([[0, 0, -1, 0], [1, 0, 0, 0], [0, 0, 1, 0]]) - actual = AffineTransform(normalized=True, reverse_indexing=False)(x, theta) + actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta) actual = actual.detach().cpu().numpy() np.testing.assert_allclose(actual, expected) np.testing.assert_allclose(list(theta.shape), [3, 4]) theta = torch.Tensor([[[0, 0, -1, 0], [1, 0, 0, 0], [0, 0, 1, 0]]]) - actual = AffineTransform(normalized=True, reverse_indexing=False)(x, theta) + actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta) actual = actual.detach().cpu().numpy() np.testing.assert_allclose(actual, expected) np.testing.assert_allclose(list(theta.shape), [1, 3, 4]) diff --git a/tests/test_affined.py b/tests/test_affined.py index b922d80fb5..a35b35758a 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy @@ -17,6 +19,7 @@ from parameterized import parameterized from monai.transforms import Affined +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, assert_allclose, test_local_inversion TESTS = [] @@ -77,6 +80,13 @@ p(np.arange(27).reshape(1, 3, 3, 3)), ] ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), device=device, align_corners=False), + {"img": p(np.arange(27).reshape((1, 3, 3, 3)))}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) TESTS.append( [ dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), device=device), @@ -166,6 +176,15 @@ def test_affine(self, input_param, input_data, expected_val): test_local_inversion(g, result, input_copy, dict_key="img") assert_allclose(result["img"], expected_val, rtol=1e-4, atol=1e-4, type_test="tensor") + # test lazy + lazy_input_param = input_param.copy() + for align_corners in [True, False]: + lazy_input_param["align_corners"] = align_corners + resampler = Affined(**lazy_input_param) + call_param = {"data": input_data} + non_lazy_result = resampler(**call_param) + test_resampler_lazy(resampler, non_lazy_result, lazy_input_param, call_param, output_key="img") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index dba8eaf72b..5707cf0452 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_alias.py b/tests/test_alias.py index 49f9fa56fe..e2dd8bcf26 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import glob import inspect import os diff --git a/tests/test_anchor_box.py b/tests/test_anchor_box.py index a6abbc0200..c29296e8ae 100644 --- a/tests/test_anchor_box.py +++ b/tests/test_anchor_box.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch @@ -42,7 +44,6 @@ class TestAnchorGenerator(unittest.TestCase): @parameterized.expand(TEST_CASES_2D) def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes): - torch_anchor_utils, _ = optional_import("torchvision.models.detection.anchor_utils") image_list, _ = optional_import("torchvision.models.detection.image_list") diff --git a/tests/test_apply.py b/tests/test_apply.py index f9e8a4a1eb..cf74721267 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -30,7 +32,7 @@ def single_2d_transform_cases(): (torch.as_tensor(get_arange_img((32, 32))), [create_rotate(2, np.pi / 2)], (1, 32, 32)), ( torch.as_tensor(get_arange_img((16, 16))), - [{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (1, 45, 45)}], + [{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (45, 45)}], (1, 45, 45), ), ] @@ -49,6 +51,8 @@ def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape else: for p in pending_transforms: tensor_.push_pending_operation(p) + if not isinstance(p, dict): + return result, transforms = apply_transforms(tensor_) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_apply_filter.py b/tests/test_apply_filter.py index 62372516a5..0de77bfb4d 100644 --- a/tests/test_apply_filter.py +++ b/tests/test_apply_filter.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index 380a6372b2..72d62008c4 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import sys import tempfile diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index 732c559a1a..3bf4e877ab 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index 91086f9299..bd50a96f5d 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index e6446ab7a6..8f88fb2928 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index a6d94d216a..16086b769c 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 014e439fe1..2802c7d9ff 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index dc96a4218b..ec394fc3af 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_atss_box_matcher.py b/tests/test_atss_box_matcher.py index 093641bb2f..a614497bc9 100644 --- a/tests/test_atss_box_matcher.py +++ b/tests/test_atss_box_matcher.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py index 1ecd75c166..d5c67cee38 100644 --- a/tests/test_attentionunet.py +++ b/tests/test_attentionunet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_auto3dseg.py b/tests/test_auto3dseg.py index 74c45f6fec..deb54b7a4b 100644 --- a/tests/test_auto3dseg.py +++ b/tests/test_auto3dseg.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest @@ -148,7 +150,6 @@ class TestImageAnalyzer(Analyzer): """ def __init__(self, image_key="image", stats_name="test_image"): - self.image_key = image_key report_format = {"test_stats": None} @@ -169,12 +170,11 @@ def setUp(self): work_dir = self.test_dir.name self.dataroot_dir = os.path.join(work_dir, "sim_dataroot") self.datalist_file = os.path.join(work_dir, "sim_datalist.json") - self.datastat_file = os.path.join(work_dir, "data_stats.yaml") + self.datastat_file = os.path.join(work_dir, "datastats.yaml") ConfigParser.export_config_file(sim_datalist, self.datalist_file) @parameterized.expand(SIM_CPU_TEST_CASES) def test_data_analyzer_cpu(self, input_params): - sim_dim = input_params["sim_dim"] label_key = input_params["label_key"] image_only = not bool(label_key) diff --git a/tests/test_auto3dseg_ensemble.py b/tests/test_auto3dseg_ensemble.py index 24cf37201e..979ebf744b 100644 --- a/tests/test_auto3dseg_ensemble.py +++ b/tests/test_auto3dseg_ensemble.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest -from typing import Dict, List import nibabel as nib import numpy as np @@ -21,13 +22,19 @@ from monai.apps.auto3dseg import AlgoEnsembleBestByFold, AlgoEnsembleBestN, AlgoEnsembleBuilder, BundleGen, DataAnalyzer from monai.bundle.config_parser import ConfigParser from monai.data import create_test_image_3d -from monai.utils import optional_import +from monai.utils import optional_import, set_determinism from monai.utils.enums import AlgoEnsembleKeys -from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick +from tests.utils import ( + SkipIfBeforePyTorchVersion, + get_testing_algo_template_path, + skip_if_downloading_fails, + skip_if_no_cuda, + skip_if_quick, +) _, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter") -fake_datalist: Dict[str, List[Dict]] = { +fake_datalist: dict[str, list[dict]] = { "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}], "training": [ {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"}, @@ -45,16 +52,15 @@ ], } -num_gpus = 4 if torch.cuda.device_count() > 4 else torch.cuda.device_count() train_param = ( { - "CUDA_VISIBLE_DEVICES": list(range(num_gpus)), "num_images_per_batch": 2, "num_epochs": 2, "num_epochs_per_validation": 1, "num_warmup_epochs": 1, "use_pretrain": False, "pretrained_path": "", + "determ": True, } if torch.cuda.is_available() else {} @@ -64,10 +70,11 @@ @skip_if_quick -@SkipIfBeforePyTorchVersion((1, 9, 1)) +@SkipIfBeforePyTorchVersion((1, 10, 0)) @unittest.skipIf(not has_tb, "no tensorboard summary writer") class TestEnsembleBuilder(unittest.TestCase): def setUp(self) -> None: + set_determinism(0) self.test_dir = tempfile.TemporaryDirectory() @skip_if_no_cuda @@ -119,19 +126,33 @@ def test_ensemble(self) -> None: with skip_if_downloading_fails(): bundle_generator = BundleGen( - algo_path=work_dir, data_stats_filename=da_output_yaml, data_src_cfg_name=data_src_cfg + algo_path=work_dir, + data_stats_filename=da_output_yaml, + data_src_cfg_name=data_src_cfg, + templates_path_or_url=get_testing_algo_template_path(), ) bundle_generator.generate(work_dir, num_fold=1) history = bundle_generator.get_history() for h in history: self.assertEqual(len(h.keys()), 1, "each record should have one model") - for _, algo in h.items(): - algo.train(train_param) + for name, algo in h.items(): + _train_param = train_param.copy() + if name.startswith("segresnet"): + _train_param["network#init_filters"] = 8 + _train_param["pretrained_ckpt_name"] = "" + elif name.startswith("swinunetr"): + _train_param["network#feature_size"] = 12 + algo.train(_train_param) builder = AlgoEnsembleBuilder(history, data_src_cfg) - builder.set_ensemble_method(AlgoEnsembleBestN(n_best=2)) + builder.set_ensemble_method(AlgoEnsembleBestN(n_best=1)) ensemble = builder.get_ensemble() + name = ensemble.get_algo_ensemble()[0][AlgoEnsembleKeys.ID] + if name.startswith("segresnet"): + pred_param["network#init_filters"] = 8 + elif name.startswith("swinunetr"): + pred_param["network#feature_size"] = 12 preds = ensemble(pred_param) self.assertTupleEqual(preds[0].shape, (2, 24, 24, 24)) @@ -141,6 +162,7 @@ def test_ensemble(self) -> None: print(algo[AlgoEnsembleKeys.ID]) def tearDown(self) -> None: + set_determinism(None) self.test_dir.cleanup() diff --git a/tests/test_auto3dseg_hpo.py b/tests/test_auto3dseg_hpo.py index 30c2361ef9..6501bb9363 100644 --- a/tests/test_auto3dseg_hpo.py +++ b/tests/test_auto3dseg_hpo.py @@ -9,11 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest from functools import partial -from typing import Dict, List import nibabel as nib import numpy as np @@ -23,16 +24,18 @@ from monai.bundle.config_parser import ConfigParser from monai.data import create_test_image_3d from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda +from tests.utils import ( + SkipIfBeforePyTorchVersion, + get_testing_algo_template_path, + skip_if_downloading_fails, + skip_if_no_cuda, +) _, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter") optuna, has_optuna = optional_import("optuna") -num_gpus = 4 if torch.cuda.device_count() > 4 else torch.cuda.device_count() - override_param = ( { - "CUDA_VISIBLE_DEVICES": list(range(num_gpus)), "num_images_per_batch": 2, "num_epochs": 2, "num_epochs_per_validation": 1, @@ -52,7 +55,7 @@ def skip_if_no_optuna(obj): return unittest.skipUnless(has_optuna, "Skipping optuna tests")(obj) -fake_datalist: Dict[str, List[Dict]] = { +fake_datalist: dict[str, list[dict]] = { "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}], "training": [ {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"}, @@ -122,7 +125,10 @@ def setUp(self) -> None: ConfigParser.export_config_file(data_src, data_src_cfg) with skip_if_downloading_fails(): bundle_generator = BundleGen( - algo_path=work_dir, data_stats_filename=da_output_yaml, data_src_cfg_name=data_src_cfg + algo_path=work_dir, + data_stats_filename=da_output_yaml, + data_src_cfg_name=data_src_cfg, + templates_path_or_url=get_testing_algo_template_path(), ) bundle_generator.generate(work_dir, num_fold=1) @@ -132,7 +138,6 @@ def setUp(self) -> None: @skip_if_no_cuda def test_run_algo(self) -> None: - algo_dict = self.history[0] algo_name = list(algo_dict.keys())[0] algo = algo_dict[algo_name] diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py index bed5a198ff..485049c2d1 100644 --- a/tests/test_autoencoder.py +++ b/tests/test_autoencoder.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_avg_merger.py b/tests/test_avg_merger.py new file mode 100644 index 0000000000..0b0a461faf --- /dev/null +++ b/tests/test_avg_merger.py @@ -0,0 +1,165 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import AvgMerger +from tests.utils import assert_allclose + +TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32) +TENSOR_4x4_WITH_NAN = TENSOR_4x4.clone() +TENSOR_4x4_WITH_NAN[..., 2:, 2:] = float("nan") + +# no-overlapping 2x2 +TEST_CASE_0_DEFAULT_DTYPE = [ + dict(output_shape=TENSOR_4x4.shape), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + +# overlapping 2x2 +TEST_CASE_1_DEFAULT_DTYPE = [ + dict(output_shape=TENSOR_4x4.shape), + [ + (TENSOR_4x4[..., 0:2, 0:2], (0, 0)), + (TENSOR_4x4[..., 0:2, 1:3], (0, 1)), + (TENSOR_4x4[..., 0:2, 2:4], (0, 2)), + (TENSOR_4x4[..., 1:3, 0:2], (1, 0)), + (TENSOR_4x4[..., 1:3, 1:3], (1, 1)), + (TENSOR_4x4[..., 1:3, 2:4], (1, 2)), + (TENSOR_4x4[..., 2:4, 0:2], (2, 0)), + (TENSOR_4x4[..., 2:4, 1:3], (2, 1)), + (TENSOR_4x4[..., 2:4, 2:4], (2, 2)), + ], + TENSOR_4x4, +] + +# overlapping 3x3 (non-divisible) +TEST_CASE_2_DEFAULT_DTYPE = [ + dict(output_shape=TENSOR_4x4.shape), + [ + (TENSOR_4x4[..., :3, :3], (0, 0)), + (TENSOR_4x4[..., :3, 1:], (0, 1)), + (TENSOR_4x4[..., 1:, :3], (1, 0)), + (TENSOR_4x4[..., 1:, 1:], (1, 1)), + ], + TENSOR_4x4, +] + +# overlapping 2x2 with NaN values +TEST_CASE_3_DEFAULT_DTYPE = [ + dict(output_shape=TENSOR_4x4_WITH_NAN.shape), + [ + (TENSOR_4x4_WITH_NAN[..., 0:2, 0:2], (0, 0)), + (TENSOR_4x4_WITH_NAN[..., 0:2, 1:3], (0, 1)), + (TENSOR_4x4_WITH_NAN[..., 0:2, 2:4], (0, 2)), + (TENSOR_4x4_WITH_NAN[..., 1:3, 0:2], (1, 0)), + (TENSOR_4x4_WITH_NAN[..., 1:3, 1:3], (1, 1)), + (TENSOR_4x4_WITH_NAN[..., 1:3, 2:4], (1, 2)), + (TENSOR_4x4_WITH_NAN[..., 2:4, 0:2], (2, 0)), + (TENSOR_4x4_WITH_NAN[..., 2:4, 1:3], (2, 1)), + (TENSOR_4x4_WITH_NAN[..., 2:4, 2:4], (2, 2)), + ], + TENSOR_4x4_WITH_NAN, +] + +# non-overlapping 2x2 with missing patch +TEST_CASE_4_DEFAULT_DTYPE = [ + dict(output_shape=TENSOR_4x4.shape), + [(TENSOR_4x4[..., :2, :2], (0, 0)), (TENSOR_4x4[..., :2, 2:], (0, 2)), (TENSOR_4x4[..., 2:, :2], (2, 0))], + TENSOR_4x4_WITH_NAN, +] + +# with value_dtype set to half precision +TEST_CASE_5_VALUE_DTYPE = [ + dict(output_shape=TENSOR_4x4.shape, value_dtype=torch.float16), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] +# with count_dtype set to int32 +TEST_CASE_6_COUNT_DTYPE = [ + dict(output_shape=TENSOR_4x4.shape, count_dtype=torch.int32), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] +# with both value_dtype, count_dtype set to double precision +TEST_CASE_7_COUNT_VALUE_DTYPE = [ + dict(output_shape=TENSOR_4x4.shape, value_dtype=torch.float64, count_dtype=torch.float64), + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + TENSOR_4x4, +] + + +class AvgMergerTests(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_0_DEFAULT_DTYPE, + TEST_CASE_1_DEFAULT_DTYPE, + TEST_CASE_2_DEFAULT_DTYPE, + TEST_CASE_3_DEFAULT_DTYPE, + TEST_CASE_4_DEFAULT_DTYPE, + TEST_CASE_5_VALUE_DTYPE, + TEST_CASE_6_COUNT_DTYPE, + TEST_CASE_7_COUNT_VALUE_DTYPE, + ] + ) + def test_avg_merger_patches(self, arguments, patch_locations, expected): + merger = AvgMerger(**arguments) + for pl in patch_locations: + merger.aggregate(pl[0], pl[1]) + output = merger.finalize() + if "value_dtype" in arguments: + self.assertTrue(merger.get_values().dtype, arguments["value_dtype"]) + if "count_dtype" in arguments: + self.assertTrue(merger.get_counts().dtype, arguments["count_dtype"]) + # check for multiple call of finalize + self.assertIs(output, merger.finalize()) + # check if the result is matching the expectation + assert_allclose(output, expected) + + def test_avg_merger_finalized_error(self): + with self.assertRaises(ValueError): + merger = AvgMerger(output_shape=(1, 3, 2, 3)) + merger.finalize() + merger.aggregate(torch.zeros(1, 3, 2, 2), (3, 3)) + + def test_avg_merge_none_output_shape_error(self): + with self.assertRaises(ValueError): + AvgMerger(output_shape=None) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_basic_unet.py b/tests/test_basic_unet.py index a4f88367dd..23e19dd536 100644 --- a/tests/test_basic_unet.py +++ b/tests/test_basic_unet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_basic_unetplusplus.py b/tests/test_basic_unetplusplus.py index 3bca65676a..19ed5977fd 100644 --- a/tests/test_basic_unetplusplus.py +++ b/tests/test_basic_unetplusplus.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py index 18f88ba759..f29d4f256b 100644 --- a/tests/test_bending_energy.py +++ b/tests/test_bending_energy.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_bilateral_approx_cpu.py b/tests/test_bilateral_approx_cpu.py index b3f4f5c3be..da30d5d7de 100644 --- a/tests/test_bilateral_approx_cpu.py +++ b/tests/test_bilateral_approx_cpu.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -365,7 +367,6 @@ class BilateralFilterTestCaseCpuApprox(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_cpu_approx(self, test_case_description, sigmas, input, expected): - # Params to determine the implementation to test device = torch.device("cpu") fast_approx = True @@ -379,7 +380,6 @@ def test_cpu_approx(self, test_case_description, sigmas, input, expected): @parameterized.expand(TEST_CASES) def test_cpu_approx_backwards(self, test_case_description, sigmas, input, expected): - # Params to determine the implementation to test device = torch.device("cpu") fast_approx = True diff --git a/tests/test_bilateral_approx_cuda.py b/tests/test_bilateral_approx_cuda.py index db34b0ff71..924ff3253e 100644 --- a/tests/test_bilateral_approx_cuda.py +++ b/tests/test_bilateral_approx_cuda.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -366,7 +368,6 @@ class BilateralFilterTestCaseCudaApprox(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_cuda_approx(self, test_case_description, sigmas, input, expected): - # Skip this test if not torch.cuda.is_available(): return @@ -384,7 +385,6 @@ def test_cuda_approx(self, test_case_description, sigmas, input, expected): @parameterized.expand(TEST_CASES) def test_cpu_approx_backwards(self, test_case_description, sigmas, input, expected): - # Params to determine the implementation to test device = torch.device("cuda") fast_approx = True diff --git a/tests/test_bilateral_precise.py b/tests/test_bilateral_precise.py index b19369d758..7fc7e06726 100644 --- a/tests/test_bilateral_precise.py +++ b/tests/test_bilateral_precise.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -365,7 +367,6 @@ class BilateralFilterTestCaseCpuPrecise(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_cpu_precise(self, test_case_description, sigmas, input, expected): - # Params to determine the implementation to test device = torch.device("cpu") fast_approx = False @@ -379,7 +380,6 @@ def test_cpu_precise(self, test_case_description, sigmas, input, expected): @parameterized.expand(TEST_CASES) def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expected): - # Params to determine the implementation to test device = torch.device("cpu") fast_approx = False @@ -400,7 +400,6 @@ def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expec class BilateralFilterTestCaseCudaPrecise(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_cuda_precise(self, test_case_description, sigmas, input, expected): - # Skip this test if not torch.cuda.is_available(): return @@ -418,7 +417,6 @@ def test_cuda_precise(self, test_case_description, sigmas, input, expected): @parameterized.expand(TEST_CASES) def test_cuda_precise_backwards(self, test_case_description, sigmas, input, expected): - # Params to determine the implementation to test device = torch.device("cuda") fast_approx = False diff --git a/tests/test_blend_images.py b/tests/test_blend_images.py index 341b79b949..9814a5a3f8 100644 --- a/tests/test_blend_images.py +++ b/tests/test_blend_images.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from unittest.case import skipUnless diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py index 1194ae49a6..d0ea112d3a 100644 --- a/tests/test_border_pad.py +++ b/tests/test_border_pad.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized @@ -38,6 +40,10 @@ def test_pad_kwargs(self): unchanged_slices = [slice(None), slice(2, -2), slice(2, -2)] self.pad_test_kwargs(unchanged_slices, **kwargs) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_shape, _): + self.pad_test_pending_ops(input_param, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_border_padd.py b/tests/test_border_padd.py index ca55c8b09d..c7eb3da762 100644 --- a/tests/test_border_padd.py +++ b/tests/test_border_padd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized @@ -34,6 +36,10 @@ def test_pad(self, input_param, input_shape, expected_shape): modes = ["constant", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT, "edge", NumpyPadMode.EDGE] self.pad_test(input_param, input_shape, expected_shape, modes) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_shape, _): + self.pad_test_pending_ops(input_param, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py index a7c2648f1e..b9c232e2d2 100644 --- a/tests/test_bounding_rect.py +++ b/tests/test_bounding_rect.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_bounding_rectd.py b/tests/test_bounding_rectd.py index 47ed854263..248a0a8e47 100644 --- a/tests/test_bounding_rectd.py +++ b/tests/test_bounding_rectd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_box_coder.py b/tests/test_box_coder.py index 86ca7a98c2..5835341139 100644 --- a/tests/test_box_coder.py +++ b/tests/test_box_coder.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_box_transform.py b/tests/test_box_transform.py index d597ec76bb..94bd6ade52 100644 --- a/tests/test_box_transform.py +++ b/tests/test_box_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_box_utils.py b/tests/test_box_utils.py index 8c56783c3b..c4fefb5a98 100644 --- a/tests/test_box_utils.py +++ b/tests/test_box_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py index e5847c57ab..a3130cefbd 100644 --- a/tests/test_bundle_ckpt_export.py +++ b/tests/test_bundle_ckpt_export.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import os import tempfile diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 09cd0128f9..52aa515111 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import os import tempfile @@ -26,7 +28,6 @@ command_line_tests, skip_if_downloading_fails, skip_if_quick, - skip_if_windows, ) TEST_CASE_1 = ["test_bundle", None] @@ -58,7 +59,6 @@ ] -@skip_if_windows class TestDownload(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @skip_if_quick diff --git a/tests/test_bundle_get_data.py b/tests/test_bundle_get_data.py index c36409f724..a560f3945f 100644 --- a/tests/test_bundle_get_data.py +++ b/tests/test_bundle_get_data.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_bundle_init_bundle.py b/tests/test_bundle_init_bundle.py index f702401481..08f921da01 100644 --- a/tests/test_bundle_init_bundle.py +++ b/tests/test_bundle_init_bundle.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_bundle_utils.py b/tests/test_bundle_utils.py index 0fbfae5094..d92f6e517f 100644 --- a/tests/test_bundle_utils.py +++ b/tests/test_bundle_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile @@ -18,6 +20,7 @@ from monai.bundle.utils import load_bundle_config from monai.networks.nets import UNet +from monai.utils import pprint_edges from tests.utils import command_line_tests, skip_if_windows metadata = """ @@ -115,5 +118,16 @@ def test_load_config_ts(self): self.assertEqual(p["test_dict"]["b"], "c") +class TestPPrintEdges(unittest.TestCase): + def test_str(self): + self.assertEqual(pprint_edges("", 0), "''") + self.assertEqual(pprint_edges({"a": 1, "b": 2}, 0), "{'a': 1, 'b': 2}") + self.assertEqual( + pprint_edges([{"a": 1, "b": 2}] * 20, 1), + "[{'a': 1, 'b': 2},\n\n ... omitted 18 line(s)\n\n {'a': 1, 'b': 2}]", + ) + self.assertEqual(pprint_edges([{"a": 1, "b": 2}] * 8, 4), pprint_edges([{"a": 1, "b": 2}] * 8, 3)) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py index 773cb40888..0701e905b9 100644 --- a/tests/test_bundle_verify_metadata.py +++ b/tests/test_bundle_verify_metadata.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import os import tempfile diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py index a978f825b9..bd611323ac 100644 --- a/tests/test_bundle_verify_net.py +++ b/tests/test_bundle_verify_net.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py new file mode 100644 index 0000000000..948c351a1c --- /dev/null +++ b/tests/test_bundle_workflow.py @@ -0,0 +1,251 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import shutil +import tempfile +import unittest +from copy import deepcopy + +import nibabel as nib +import numpy as np +import torch +from parameterized import parameterized + +from monai.bundle import BundleWorkflow, ConfigWorkflow +from monai.data import DataLoader, Dataset +from monai.engines import SupervisedEvaluator +from monai.inferers import SimpleInferer, SlidingWindowInferer +from monai.networks.nets import UNet +from monai.transforms import ( + Activationsd, + AsDiscreted, + Compose, + EnsureChannelFirstd, + LoadImage, + LoadImaged, + SaveImaged, + ScaleIntensityd, +) +from monai.utils import BundleProperty, set_determinism + +TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")] + +TEST_CASE_2 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.yaml")] + +TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")] + + +class NonConfigWorkflow(BundleWorkflow): + """ + Test class simulates the bundle workflow defined by Python script directly. + + """ + + def __init__(self, filename, output_dir): + super().__init__(workflow="inference") + self.filename = filename + self.output_dir = output_dir + self._bundle_root = "will override" + self._device = torch.device("cpu") + self._network_def = None + self._inferer = None + self._preprocessing = None + self._postprocessing = None + self._evaluator = None + + def initialize(self): + set_determinism(0) + if self._preprocessing is None: + self._preprocessing = Compose( + [LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")] + ) + dataset = Dataset(data=[{"image": self.filename}], transform=self._preprocessing) + dataloader = DataLoader(dataset, batch_size=1, num_workers=4) + + if self._network_def is None: + self._network_def = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=[2, 2, 4, 8, 4], + strides=[2, 2, 2, 2], + num_res_units=2, + norm="batch", + ) + if self._inferer is None: + self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25) + + if self._postprocessing is None: + self._postprocessing = Compose( + [ + Activationsd(keys="pred", softmax=True), + AsDiscreted(keys="pred", argmax=True), + SaveImaged(keys="pred", output_dir=self.output_dir, output_postfix="seg"), + ] + ) + + self._evaluator = SupervisedEvaluator( + device=self._device, + val_data_loader=dataloader, + network=self._network_def.to(self._device), + inferer=self._inferer, + postprocessing=self._postprocessing, + amp=False, + ) + + def run(self): + self._evaluator.run() + + def finalize(self): + return True + + def _get_property(self, name, property): + if name == "bundle_root": + return self._bundle_root + if name == "device": + return self._device + if name == "network_def": + return self._network_def + if name == "inferer": + return self._inferer + if name == "preprocessing": + return self._preprocessing + if name == "postprocessing": + return self._postprocessing + if property[BundleProperty.REQUIRED]: + raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") + + def _set_property(self, name, property, value): + if name == "bundle_root": + self._bundle_root = value + elif name == "device": + self._device = value + elif name == "network_def": + self._network_def = value + elif name == "inferer": + self._inferer = value + elif name == "preprocessing": + self._preprocessing = value + elif name == "postprocessing": + self._postprocessing = value + elif property[BundleProperty.REQUIRED]: + raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") + + +class TestBundleWorkflow(unittest.TestCase): + def setUp(self): + self.data_dir = tempfile.mkdtemp() + self.expected_shape = (128, 128, 128) + test_image = np.random.rand(*self.expected_shape) + self.filename = os.path.join(self.data_dir, "image.nii") + nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename) + + def tearDown(self): + shutil.rmtree(self.data_dir) + + def _test_inferer(self, inferer): + # should initialize before parsing any bundle content + inferer.initialize() + # test required and optional properties + self.assertListEqual(inferer.check_properties(), []) + # test read / write the properties, note that we don't assume it as JSON or YAML config here + self.assertEqual(inferer.bundle_root, "will override") + self.assertEqual(inferer.device, torch.device("cpu")) + net = inferer.network_def + self.assertTrue(isinstance(net, UNet)) + sliding_window = inferer.inferer + self.assertTrue(isinstance(sliding_window, SlidingWindowInferer)) + preprocessing = inferer.preprocessing + self.assertTrue(isinstance(preprocessing, Compose)) + postprocessing = inferer.postprocessing + self.assertTrue(isinstance(postprocessing, Compose)) + # test optional properties get + self.assertTrue(inferer.key_metric is None) + inferer.bundle_root = "/workspace/data/spleen_ct_segmentation" + inferer.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + inferer.network_def = deepcopy(net) + inferer.inferer = deepcopy(sliding_window) + inferer.preprocessing = deepcopy(preprocessing) + inferer.postprocessing = deepcopy(postprocessing) + # test optional properties set + inferer.key_metric = "set optional properties" + + # should initialize and parse again as changed the bundle content + inferer.initialize() + inferer.run() + inferer.finalize() + # verify inference output + loader = LoadImage(image_only=True) + pred_file = os.path.join(self.data_dir, "image", "image_seg.nii.gz") + self.assertTupleEqual(loader(pred_file).shape, self.expected_shape) + os.remove(pred_file) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_inference_config(self, config_file): + override = { + "network": "$@network_def.to(@device)", + "dataset#_target_": "Dataset", + "dataset#data": [{"image": self.filename}], + "postprocessing#transforms#2#output_postfix": "seg", + "output_dir": self.data_dir, + } + # test standard MONAI model-zoo config workflow + inferer = ConfigWorkflow( + workflow="infer", + config_file=config_file, + logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), + **override, + ) + self._test_inferer(inferer) + + @parameterized.expand([TEST_CASE_3]) + def test_train_config(self, config_file): + # test standard MONAI model-zoo config workflow + trainer = ConfigWorkflow( + workflow="train", + config_file=config_file, + logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), + init_id="initialize", + run_id="run", + final_id="finalize", + ) + # should initialize before parsing any bundle content + trainer.initialize() + # test required and optional properties + self.assertListEqual(trainer.check_properties(), []) + # test read / write the properties + dataset = trainer.train_dataset + self.assertTrue(isinstance(dataset, Dataset)) + inferer = trainer.train_inferer + self.assertTrue(isinstance(inferer, SimpleInferer)) + # test optional properties get + self.assertTrue(trainer.train_key_metric is None) + trainer.train_dataset = deepcopy(dataset) + trainer.train_inferer = deepcopy(inferer) + # test optional properties set + trainer.train_key_metric = "set optional properties" + + # should initialize and parse again as changed the bundle content + trainer.initialize() + trainer.run() + trainer.finalize() + + def test_non_config(self): + # test user defined python style workflow + inferer = NonConfigWorkflow(self.filename, self.data_dir) + self._test_inferer(inferer) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index 86ebced9f3..dcae5fdce1 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import sys import tempfile @@ -39,7 +41,7 @@ class TestCacheDataset(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, transform, expected_shape): - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: test_data = [] for i in ["1", "2"]: @@ -190,7 +192,7 @@ def test_thread_safe(self, persistent_workers, cache_workers, loader_workers): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_hash_as_key(self, transform, expected_shape): - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: test_data = [] for i in ["1", "2", "2", "3", "3"]: diff --git a/tests/test_cachedataset_parallel.py b/tests/test_cachedataset_parallel.py index f409e17787..c3fc2cc362 100644 --- a/tests/test_cachedataset_parallel.py +++ b/tests/test_cachedataset_parallel.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest @@ -30,7 +32,7 @@ class TestCacheDatasetParallel(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, num_workers, dataset_size, transform): - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]), np.eye(4)) + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]).astype(float), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) diff --git a/tests/test_cachedataset_persistent_workers.py b/tests/test_cachedataset_persistent_workers.py index 7f241899eb..e60862238d 100644 --- a/tests/test_cachedataset_persistent_workers.py +++ b/tests/test_cachedataset_persistent_workers.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.data import CacheDataset, DataLoader, create_test_image_2d diff --git a/tests/test_cachentransdataset.py b/tests/test_cachentransdataset.py index 99ca0e0c3d..d50fe4f8dd 100644 --- a/tests/test_cachentransdataset.py +++ b/tests/test_cachentransdataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest @@ -34,7 +36,7 @@ class TestCacheNTransDataset(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) def test_n_trans(self, transform, expected_shape): - data_array = np.random.randint(0, 2, size=[128, 128, 128]) + data_array = np.random.randint(0, 2, size=[128, 128, 128]).astype(float) test_image = nib.Nifti1Image(data_array, np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: nib.save(test_image, os.path.join(tempdir, "test_image.nii.gz")) diff --git a/tests/test_call_dist.py b/tests/test_call_dist.py index bed8289506..0621824b65 100644 --- a/tests/test_call_dist.py +++ b/tests/test_call_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from tests.utils import DistCall, DistTestCase diff --git a/tests/test_cast_to_type.py b/tests/test_cast_to_type.py index 82daabc4e7..6dd994120c 100644 --- a/tests/test_cast_to_type.py +++ b/tests/test_cast_to_type.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -37,7 +39,6 @@ class TestCastToType(unittest.TestCase): @parameterized.expand(TESTS) def test_type(self, out_dtype, input_data, expected_type): - result = CastToType(dtype=out_dtype)(input_data) self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result))) diff --git a/tests/test_cast_to_typed.py b/tests/test_cast_to_typed.py index 1ac23314a5..687deeda1d 100644 --- a/tests/test_cast_to_typed.py +++ b/tests/test_cast_to_typed.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_center_scale_crop.py b/tests/test_center_scale_crop.py index 3fe7a453d3..4a4efb3b76 100644 --- a/tests/test_center_scale_crop.py +++ b/tests/test_center_scale_crop.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -18,9 +20,10 @@ from tests.croppers import CropTest TEST_SHAPES = [ - [{"roi_scale": [0.6, 0.3, -1]}, (3, 3, 3, 3), (3, 2, 1, 3)], - [{"roi_scale": 0.6}, (3, 3, 3, 3), (3, 2, 2, 2)], - [{"roi_scale": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"roi_scale": [0.6, 0.3, -1]}, (3, 3, 3, 3), (3, 2, 1, 3), False], + [{"roi_scale": [0.6, 0.3, -1]}, (3, 3, 4, 3), (3, 2, 2, 3), True], + [{"roi_scale": 0.6}, (3, 3, 3, 3), (3, 2, 2, 2), True], + [{"roi_scale": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2), True], ] TEST_VALUES = [ @@ -32,17 +35,21 @@ ] -class TestCenterSpatialCrop(CropTest): +class TestCenterScaleCrop(CropTest): Cropper = CenterScaleCrop @parameterized.expand(TEST_SHAPES) - def test_shape(self, input_param, input_shape, expected_shape): + def test_shape(self, input_param, input_shape, expected_shape, _): self.crop_test(input_param, input_shape, expected_shape) @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_arr, expected_arr): self.crop_test_value(input_param, input_arr, expected_arr) + @parameterized.expand(TEST_SHAPES) + def test_pending_ops(self, input_param, input_shape, _, align_corners): + self.crop_test_pending_ops(input_param, input_shape, align_corners) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_center_scale_cropd.py b/tests/test_center_scale_cropd.py index 088c1c70e7..b53f56b93e 100644 --- a/tests/test_center_scale_cropd.py +++ b/tests/test_center_scale_cropd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -18,9 +20,11 @@ from tests.croppers import CropTest TESTS = [ - [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, (3, 3, 3, 3), (3, 2, 1, 3)], - [{"keys": "img", "roi_scale": 0.6}, (3, 3, 3, 3), (3, 2, 2, 2)], - [{"keys": "img", "roi_scale": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, (3, 3, 3, 3), (3, 2, 1, 3), False], + [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, (3, 3, 1, 3), (3, 2, 1, 3), True], + [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, (3, 3, 4, 3), (3, 2, 2, 3), True], + [{"keys": "img", "roi_scale": 0.6}, (3, 3, 3, 3), (3, 2, 2, 2), True], + [{"keys": "img", "roi_scale": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2), True], ] TEST_VALUES = [ @@ -36,13 +40,17 @@ class TestCenterScaleCropd(CropTest): Cropper = CenterScaleCropd @parameterized.expand(TESTS) - def test_shape(self, input_param, input_shape, expected_shape): + def test_shape(self, input_param, input_shape, expected_shape, _): self.crop_test(input_param, input_shape, expected_shape) @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_arr, expected_arr): self.crop_test_value(input_param, input_arr, expected_arr) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_shape, _, align_corners): + self.crop_test_pending_ops(input_param, input_shape, align_corners) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index 7b5b19107d..c0da043ecb 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -18,9 +20,11 @@ from tests.croppers import CropTest TEST_SHAPES = [ - [{"roi_size": [2, 2, -1]}, (3, 3, 3, 3), (3, 2, 2, 3)], - [{"roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], - [{"roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], + [{"roi_size": [2, 2, -1]}, (3, 3, 3, 3), (3, 2, 2, 3), True], + [{"roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2), True], + [{"roi_size": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2), True], + [{"roi_size": [2, 1, 2]}, (3, 3, 3, 3), (3, 2, 1, 2), False], + [{"roi_size": [2, 1, 3]}, (3, 3, 1, 3), (3, 2, 1, 3), True], ] TEST_VALUES = [ @@ -36,13 +40,17 @@ class TestCenterSpatialCrop(CropTest): Cropper = CenterSpatialCrop @parameterized.expand(TEST_SHAPES) - def test_shape(self, input_param, input_shape, expected_shape): + def test_shape(self, input_param, input_shape, expected_shape, _): self.crop_test(input_param, input_shape, expected_shape) @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_arr, expected_arr): self.crop_test_value(input_param, input_arr, expected_arr) + @parameterized.expand(TEST_SHAPES) + def test_pending_ops(self, input_param, input_shape, _, align_corners): + self.crop_test_pending_ops(input_param, input_shape, align_corners) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_center_spatial_cropd.py b/tests/test_center_spatial_cropd.py index fa7bc8c8fa..0c11a7828a 100644 --- a/tests/test_center_spatial_cropd.py +++ b/tests/test_center_spatial_cropd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -52,6 +54,10 @@ def test_shape(self, input_param, input_shape, expected_shape, same_area): def test_value(self, input_param, input_data, expected_value): self.crop_test_value(input_param, input_data, expected_value) + @parameterized.expand(TEST_SHAPES) + def test_pending_ops(self, input_param, input_shape, _expected_shape, _same_area): + self.crop_test_pending_ops(input_param, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_channel_pad.py b/tests/test_channel_pad.py index bde0f18d83..2d8c57fd68 100644 --- a/tests/test_channel_pad.py +++ b/tests/test_channel_pad.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_check_hash.py b/tests/test_check_hash.py index 5297021540..bb3d0ff12e 100644 --- a/tests/test_check_hash.py +++ b/tests/test_check_hash.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_check_missing_files.py b/tests/test_check_missing_files.py index ecd0b52a63..efbe5a95fb 100644 --- a/tests/test_check_missing_files.py +++ b/tests/test_check_missing_files.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest @@ -22,7 +24,7 @@ class TestCheckMissingFiles(unittest.TestCase): def test_content(self): - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) diff --git a/tests/test_classes_to_indices.py b/tests/test_classes_to_indices.py index 1f39e0f480..e7dd7abfe5 100644 --- a/tests/test_classes_to_indices.py +++ b/tests/test_classes_to_indices.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_classes_to_indicesd.py b/tests/test_classes_to_indicesd.py index 398620d304..7da321c7d9 100644 --- a/tests/test_classes_to_indicesd.py +++ b/tests/test_classes_to_indicesd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_complex_utils.py b/tests/test_complex_utils.py index 8ba45c294c..77eaa924a2 100644 --- a/tests/test_complex_utils.py +++ b/tests/test_complex_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py index ebb2cca7b3..3b54a13706 100644 --- a/tests/test_component_locator.py +++ b/tests/test_component_locator.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from pydoc import locate diff --git a/tests/test_compose.py b/tests/test_compose.py index e322e216ad..ddb7ce25d8 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest diff --git a/tests/test_compose_get_number_conversions.py b/tests/test_compose_get_number_conversions.py index fca5bc727d..664558d9cd 100644 --- a/tests/test_compose_get_number_conversions.py +++ b/tests/test_compose_get_number_conversions.py @@ -9,9 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy -from typing import List, Tuple import numpy as np import torch @@ -69,7 +70,7 @@ def __call__(self, x): return _apply(x, lambda x: convert_to_tensor(x).cuda()) -TESTS: List[Tuple] = [] +TESTS: list[tuple] = [] for is_dict in (False, True): # same type depends on input TESTS.append(((N(), N()), is_dict, NP_ARR, 0)) diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py index 56d619a289..a886d8b7e4 100644 --- a/tests/test_compute_confusion_matrix.py +++ b/tests/test_compute_confusion_matrix.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Any, Dict, List +from typing import Any +import numpy as np import torch from parameterized import parameterized @@ -25,7 +28,7 @@ _device = "cuda:0" if torch.cuda.is_available() else "cpu" # input data -data: Dict[Any, Any] = { +data: dict[Any, Any] = { "y_pred": torch.tensor( [ [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]], @@ -42,7 +45,7 @@ ), } -data_nan: Dict[Any, Any] = { +data_nan: dict[Any, Any] = { # confusion matrix:[[[0,1,2,1],[1,1,1,1],[0,1,2,1]], # [[0,0,0,4],[0,0,4,0],[0,4,0,0]], # [[0,0,2,2],[0,0,2,2],[0,4,0,0]]] @@ -62,7 +65,7 @@ ), } -data_clf: Dict[Any, Any] = { +data_clf: dict[Any, Any] = { "y_pred": torch.tensor([[1, 0, 0], [0, 0, 1]]), "y": torch.tensor([[1, 0, 0], [0, 1, 0]]), "compute_sample": False, @@ -146,7 +149,7 @@ result: Any = None for idx, item in enumerate(metric_names): for reduction in ["mean", "mean_batch"]: - TEST_CASE: List[Any] = [data.copy()] + TEST_CASE: list[Any] = [data.copy()] TEST_CASE[0]["compute_sample"] = True TEST_CASE[0]["include_background"] = True TEST_CASE[0]["metric_name"] = item @@ -161,7 +164,7 @@ # one input to compute multiple metrics for reduction in ["mean", "mean_batch"]: - TEST_CASE_MULTIPLE: List[Any] = [data.copy()] + TEST_CASE_MULTIPLE: list[Any] = [data.copy()] TEST_CASE_MULTIPLE[0]["compute_sample"] = True TEST_CASE_MULTIPLE[0]["include_background"] = True TEST_CASE_MULTIPLE[0]["metric_name"] = metric_names @@ -218,6 +221,7 @@ def test_value(self, input_data, expected_value): input_data["include_background"] = False result = get_confusion_matrix(**input_data) assert_allclose(result, expected_value[:, 1:, :], atol=1e-4, rtol=1e-4) + np.testing.assert_equal(result.device, input_data["y_pred"].device) @parameterized.expand(TEST_CASES_COMPUTE_SAMPLE) def test_compute_sample(self, input_data, expected_value): diff --git a/tests/test_compute_f_beta.py b/tests/test_compute_f_beta.py index 62641a52f5..c8ed5aa887 100644 --- a/tests/test_compute_f_beta.py +++ b/tests/test_compute_f_beta.py @@ -9,22 +9,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -import numpy +import numpy as np import torch from monai.metrics import FBetaScore from tests.utils import assert_allclose +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + class TestFBetaScore(unittest.TestCase): - def test_expecting_success(self): + def test_expecting_success_and_device(self): metric = FBetaScore() - metric( - y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]) - ) - assert_allclose(metric.aggregate()[0], torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6) + y_pred = torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]], device=_device) + y = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], device=_device) + metric(y_pred=y_pred, y=y) + result = metric.aggregate()[0] + assert_allclose(result, torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6) + np.testing.assert_equal(result.device, y_pred.device) def test_expecting_success2(self): metric = FBetaScore(beta=0.5) @@ -56,7 +62,7 @@ def test_with_nan_values(self): metric = FBetaScore(get_not_nans=True) metric( y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), - y=torch.Tensor([[1, 0, 1], [numpy.NaN, numpy.NaN, numpy.NaN], [1, 0, 1]]), + y=torch.Tensor([[1, 0, 1], [np.NaN, np.NaN, np.NaN], [1, 0, 1]]), ) assert_allclose(metric.aggregate()[0][0], torch.Tensor([0.727273]), atol=1e-6, rtol=1e-6) diff --git a/tests/test_compute_froc.py b/tests/test_compute_froc.py index 91c1ea1977..1724c469d5 100644 --- a/tests/test_compute_froc.py +++ b/tests/test_compute_froc.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py index 38f6e57d32..ab3d012c97 100644 --- a/tests/test_compute_generalized_dice.py +++ b/tests/test_compute_generalized_dice.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -17,11 +19,13 @@ from monai.metrics import GeneralizedDiceScore, compute_generalized_dice +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + # keep background TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1) { - "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]), - "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), + "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device), + "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device), "include_background": True, }, [0.8], @@ -114,7 +118,12 @@ TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]] -class TestComputeMeanDice(unittest.TestCase): +class TestComputeGeneralizedDiceScore(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_device(self, input_data, _expected_value): + result = compute_generalized_dice(**input_data) + np.testing.assert_equal(result.device, input_data["y_pred"].device) + # Functional part tests @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) def test_value(self, input_data, expected_value): @@ -130,7 +139,6 @@ def test_nans(self, input_data, expected_value): # Samplewise tests @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_value_class(self, input_data, expected_value): - # same test as for compute_meandice vals = {} vals["y_pred"] = input_data.pop("y_pred") @@ -143,7 +151,6 @@ def test_value_class(self, input_data, expected_value): # Aggregation tests @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_nans_class(self, params, input_data, expected_value): - generalized_dice_score = GeneralizedDiceScore(**params) generalized_dice_score(**input_data) result = generalized_dice_score.aggregate() diff --git a/tests/test_compute_ho_ver_maps.py b/tests/test_compute_ho_ver_maps.py index 5c4674dd04..50598cb57b 100644 --- a/tests/test_compute_ho_ver_maps.py +++ b/tests/test_compute_ho_ver_maps.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_compute_ho_ver_maps_d.py b/tests/test_compute_ho_ver_maps_d.py index 475e50bc70..27bb57988c 100644 --- a/tests/test_compute_ho_ver_maps_d.py +++ b/tests/test_compute_ho_ver_maps_d.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index 4dd5d77c4f..3526a5b413 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -189,6 +191,7 @@ class TestComputeMeanDice(unittest.TestCase): def test_value(self, input_data, expected_value): result = compute_dice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + np.testing.assert_equal(result.device, input_data["y_pred"].device) @parameterized.expand([TEST_CASE_3]) def test_nans(self, input_data, expected_value): @@ -198,7 +201,6 @@ def test_nans(self, input_data, expected_value): # DiceMetric class tests @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_10]) def test_value_class(self, input_data, expected_value): - # same test as for compute_meandice vals = {} vals["y_pred"] = input_data.pop("y_pred") @@ -210,7 +212,6 @@ def test_value_class(self, input_data, expected_value): @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]) def test_nans_class(self, params, input_data, expected_value): - dice_metric = DiceMetric(**params) dice_metric(**input_data) result, _ = dice_metric.aggregate() diff --git a/tests/test_compute_meaniou.py b/tests/test_compute_meaniou.py index 52a0223a2d..be6f1a039f 100644 --- a/tests/test_compute_meaniou.py +++ b/tests/test_compute_meaniou.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -189,6 +191,7 @@ class TestComputeMeanIoU(unittest.TestCase): def test_value(self, input_data, expected_value): result = compute_meaniou(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + np.testing.assert_equal(result.device, input_data["y_pred"].device) @parameterized.expand([TEST_CASE_3]) def test_nans(self, input_data, expected_value): @@ -198,7 +201,6 @@ def test_nans(self, input_data, expected_value): # MeanIoU class tests @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_10]) def test_value_class(self, input_data, expected_value): - # same test as for compute_meaniou vals = {} vals["y_pred"] = input_data.pop("y_pred") @@ -210,7 +212,6 @@ def test_value_class(self, input_data, expected_value): @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]) def test_nans_class(self, params, input_data, expected_value): - iou_metric = MeanIoU(**params) iou_metric(**input_data) result, _ = iou_metric.aggregate() diff --git a/tests/test_compute_panoptic_quality.py b/tests/test_compute_panoptic_quality.py index cf5d0deb2a..a5858e91d1 100644 --- a/tests/test_compute_panoptic_quality.py +++ b/tests/test_compute_panoptic_quality.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from typing import List @@ -94,6 +96,7 @@ class TestPanopticQualityMetric(unittest.TestCase): def test_value(self, input_params, expected_value): result = compute_panoptic_quality(**input_params) np.testing.assert_allclose(result.cpu().detach().item(), expected_value, atol=1e-4) + np.testing.assert_equal(result.device, input_params["pred"].device) @parameterized.expand([TEST_CLS_CASE_1, TEST_CLS_CASE_2, TEST_CLS_CASE_3, TEST_CLS_CASE_4, TEST_CLS_CASE_5]) def test_value_class(self, input_params, y_pred, y_gt, expected_value): diff --git a/tests/test_compute_regression_metrics.py b/tests/test_compute_regression_metrics.py index cab1184812..b0fde3afe9 100644 --- a/tests/test_compute_regression_metrics.py +++ b/tests/test_compute_regression_metrics.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from functools import partial @@ -59,7 +61,6 @@ def test_shape_reduction(self): for batch in batch_dims: for spatial in spatial_dims: for base in base_dims: - # create random tensors in_tensor = torch.rand((batch,) + (base,) * (spatial - 1)).to(device) @@ -102,7 +103,6 @@ def test_compare_numpy(self): for batch in batch_dims: for spatial in spatial_dims: for base in base_dims: - # create random tensors in_tensor_a = torch.rand((batch,) + (base,) * (spatial - 1)).to(device) in_tensor_b = torch.rand((batch,) + (base,) * (spatial - 1)).to(device) @@ -152,7 +152,6 @@ def test_same_input(self): for batch in batch_dims: for spatial in spatial_dims: for base in base_dims: - # create random tensors in_tensor = torch.rand((batch,) + (base,) * (spatial - 1)).to(device) @@ -178,7 +177,6 @@ def test_diff_input(self): for batch in batch_dims: for spatial in spatial_dims: for base in base_dims: - # create random tensors in_tensor_a = torch.zeros((batch,) + (base,) * (spatial - 1)).to(device) in_tensor_b = torch.ones((batch,) + (base,) * (spatial - 1)).to(device) diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index 0e57f1fe4a..2f080c76cb 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_compute_variance.py b/tests/test_compute_variance.py index 2743fcdc79..8eaac10a6c 100644 --- a/tests/test_compute_variance.py +++ b/tests/test_compute_variance.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -111,6 +113,7 @@ class TestComputeVariance(unittest.TestCase): def test_value(self, input_data, expected_value): result = compute_variance(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + np.testing.assert_equal(result.device, input_data["y_pred"].device) @parameterized.expand([TEST_CASE_5, TEST_CASE_6]) def test_spatial_case(self, input_data, expected_value): diff --git a/tests/test_concat_itemsd.py b/tests/test_concat_itemsd.py index c0a058a0dd..322a95d7df 100644 --- a/tests/test_concat_itemsd.py +++ b/tests/test_concat_itemsd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 817175e1e3..3d1272719a 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from functools import partial from typing import Callable @@ -35,7 +37,7 @@ TEST_CASE_5 = [{"_target_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict] # test non-monai modules and excludes TEST_CASE_6 = [{"_target_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam] -TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True}, partial] +TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "partial"}, partial] # test args contains "name" field TEST_CASE_8 = [ {"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25}, diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index d02a05c914..c4b50daed1 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest @@ -34,6 +36,13 @@ def case_pdb(sarg=None): parser.get_parsed_content() +@TimedCall(seconds=100, force_quit=True) +def case_pdb_inst(sarg=None): + config = {"transform": {"_target_": "Compose", "transforms": [], "_mode_": "debug"}} + parser = ConfigParser(config=config) + return parser.transform + + # test the resolved and parsed instances TEST_CASE_1 = [ { @@ -98,6 +107,8 @@ def __call__(self, a, b): TEST_CASE_4 = [{"A": 1, "B": "@A", "C": "@D", "E": "$'test' + '@F'"}] +TEST_CASE_5 = [{"training": {"A": 1, "A_B": 2}, "total": "$@training#A + @training#A_B + 1"}, 4] + class TestConfigParser(unittest.TestCase): def test_config_content(self): @@ -153,6 +164,8 @@ def test_parse(self, config, expected_ids, output_types): def test_function(self, config): parser = ConfigParser(config=config, globals={"TestClass": TestClass}) for id in config: + if id in ("compute", "cls_compute"): + parser[f"{id}#_mode_"] = "partial" func = parser.get_parsed_content(id=id) self.assertTrue(id in parser.ref_resolver.resolved_content) if id == "error_func": @@ -246,6 +259,14 @@ def test_lambda_reference(self): result = trans(np.ones(64)) self.assertTupleEqual(result.shape, (1, 8, 8)) + def test_non_str_target(self): + configs = { + "fwd": {"_target_": "$@model.forward", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "partial"}, + "model": {"_target_": "monai.networks.nets.resnet.resnet18", "pretrained": False, "spatial_dims": 2}, + } + self.assertTrue(callable(ConfigParser(config=configs).fwd)) + self.assertTupleEqual(tuple(ConfigParser(config=configs).fwd().shape), (1, 400)) + def test_error_instance(self): config = {"transform": {"_target_": "Compose", "transforms_wrong_key": []}} parser = ConfigParser(config=config) @@ -255,6 +276,32 @@ def test_error_instance(self): def test_pdb(self): with self.assertRaisesRegex(RuntimeError, ".*bdb.BdbQuit.*"): case_pdb() + self.assertEqual(case_pdb_inst(), None) # pdb.runcall without input is None + + def test_get_via_attributes(self): + config = { + "A": {"B": {"C": 1}}, + "my_dims": 2, + "dims_1": "$@my_dims + 1", + "patch_size": [8, 8], + "transform": {"_target_": "Lambda", "func": "$lambda x: x.reshape((1, *@patch_size))"}, + } + parser = ConfigParser(config=config) + self.assertEqual(parser.A, {"B": {"C": 1}}) + self.assertEqual(parser.dims_1, 3) + + trans = parser.transform + result = trans(np.ones(64)) + self.assertTupleEqual(result.shape, (1, 8, 8)) + + def test_builtin(self): + config = {"import statements": "$import math", "calc": {"_target_": "math.isclose", "a": 0.001, "b": 0.001}} + self.assertEqual(ConfigParser(config).calc, True) + + @parameterized.expand([TEST_CASE_5]) + def test_substring_reference(self, config, expected): + parser = ConfigParser(config=config) + self.assertEqual(parser.get_parsed_content("total"), expected) if __name__ == "__main__": diff --git a/tests/test_contrastive_loss.py b/tests/test_contrastive_loss.py index d0eb7d86f2..4cafa0a905 100644 --- a/tests/test_contrastive_loss.py +++ b/tests/test_contrastive_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py index 139d2e3f87..c3e4490ffe 100644 --- a/tests/test_convert_data_type.py +++ b/tests/test_convert_data_type.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import List, Tuple import numpy as np import torch @@ -20,14 +21,14 @@ from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, get_equivalent_dtype from tests.utils import TEST_NDARRAYS_ALL, assert_allclose -TESTS: List[Tuple] = [] +TESTS: list[tuple] = [] for in_type in TEST_NDARRAYS_ALL + (int, float): for out_type in TEST_NDARRAYS_ALL: TESTS.append((in_type(np.array(1.0)), out_type(np.array(1.0)), None, False)) # type: ignore if in_type is not float: TESTS.append((in_type(np.array(256)), out_type(np.array(255)), np.uint8, True)) # type: ignore -TESTS_LIST: List[Tuple] = [] +TESTS_LIST: list[tuple] = [] for in_type in TEST_NDARRAYS_ALL + (int, float): for out_type in TEST_NDARRAYS_ALL: TESTS_LIST.append( diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py index b606fee04f..78c3c90688 100644 --- a/tests/test_convert_to_multi_channel.py +++ b/tests/test_convert_to_multi_channel.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_convert_to_multi_channeld.py b/tests/test_convert_to_multi_channeld.py index 7525f8d7e2..351adddb13 100644 --- a/tests/test_convert_to_multi_channeld.py +++ b/tests/test_convert_to_multi_channeld.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_convert_to_torchscript.py b/tests/test_convert_to_torchscript.py index a1c1471463..0b8e9a8141 100644 --- a/tests/test_convert_to_torchscript.py +++ b/tests/test_convert_to_torchscript.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_convolutions.py b/tests/test_convolutions.py index dc018248c1..1311401f1d 100644 --- a/tests/test_convolutions.py +++ b/tests/test_convolutions.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.networks.blocks import Convolution, ResidualUnit diff --git a/tests/test_copy_itemsd.py b/tests/test_copy_itemsd.py index 8354f45bb5..ff4799a094 100644 --- a/tests/test_copy_itemsd.py +++ b/tests/test_copy_itemsd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_copy_model_state.py b/tests/test_copy_model_state.py index bc7b116e1f..2e7513b234 100644 --- a/tests/test_copy_model_state.py +++ b/tests/test_copy_model_state.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_correct_crop_centers.py b/tests/test_correct_crop_centers.py index 50478c7d5d..d2a95bf684 100644 --- a/tests/test_correct_crop_centers.py +++ b/tests/test_correct_crop_centers.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_create_cross_validation_datalist.py b/tests/test_create_cross_validation_datalist.py index 3a3e8481ea..d05a94f59e 100644 --- a/tests/test_create_cross_validation_datalist.py +++ b/tests/test_create_cross_validation_datalist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py index d70db45468..2b5890a777 100644 --- a/tests/test_create_grid_and_affine.py +++ b/tests/test_create_grid_and_affine.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_crf_cpu.py b/tests/test_crf_cpu.py index 46da3298bc..e29a4d69eb 100644 --- a/tests/test_crf_cpu.py +++ b/tests/test_crf_cpu.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -495,7 +497,6 @@ class CRFTestCaseCpu(unittest.TestCase): @parameterized.expand(TEST_CASES) def test(self, test_case_description, params, input, features, expected): - # Create input tensors input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cpu")) feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cpu")) diff --git a/tests/test_crf_cuda.py b/tests/test_crf_cuda.py index ca25fe2de9..8529e2e6de 100644 --- a/tests/test_crf_cuda.py +++ b/tests/test_crf_cuda.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -496,7 +498,6 @@ class CRFTestCaseCuda(unittest.TestCase): @parameterized.expand(TEST_CASES) def test(self, test_case_description, params, input, features, expected): - # Create input tensors input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cuda")) feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cuda")) diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 2f46eba67c..b8a10722a4 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -16,9 +18,10 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import CropForeground +from monai.transforms.lazy.functional import apply_transforms from tests.utils import TEST_NDARRAYS_ALL, assert_allclose -TEST_COORDS, TESTS = [], [] +TEST_COORDS, TESTS, TEST_LAZY_ERROR = [], [], [] for p in TEST_NDARRAYS_ALL: TEST_COORDS.append( @@ -26,6 +29,7 @@ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + True, ] ) @@ -34,6 +38,7 @@ {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, p([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), p([[[3]]]), + False, ] ) @@ -42,6 +47,7 @@ {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + True, ] ) @@ -50,6 +56,7 @@ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), + True, ] ) @@ -58,6 +65,7 @@ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1], "allow_smaller": True}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + True, ] ) @@ -66,6 +74,7 @@ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1], "allow_smaller": False}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + True, ] ) @@ -74,21 +83,23 @@ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), p([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), + True, ] ) - TESTS.append( + TEST_LAZY_ERROR.append( [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), p(np.zeros((1, 0, 0), dtype=np.int64)), + True, ] ) class TestCropForeground(unittest.TestCase): @parameterized.expand(TEST_COORDS + TESTS) - def test_value(self, arguments, image, expected_data): + def test_value(self, arguments, image, expected_data, _): cropper = CropForeground(**arguments) result = cropper(image) assert_allclose(result, expected_data, type_test=False) @@ -100,13 +111,39 @@ def test_value(self, arguments, image, expected_data): self.assertTupleEqual(inv.shape, image.shape) @parameterized.expand(TEST_COORDS) - def test_return_coords(self, arguments, image, _): + def test_return_coords(self, arguments, image, _expected_data, _align_corners): arguments["return_coords"] = True _, start_coord, end_coord = CropForeground(**arguments)(image) arguments["return_coords"] = False np.testing.assert_allclose(start_coord, np.asarray([1, 1])) np.testing.assert_allclose(end_coord, np.asarray([4, 4])) + @parameterized.expand(TEST_COORDS + TESTS) + def test_pending_ops(self, input_param, image, _expected_data, align_corners): + crop_fn = CropForeground(**input_param) + # non-lazy + expected = crop_fn(image) + self.assertIsInstance(expected, MetaTensor) + # lazy + crop_fn.lazy_evaluation = True + pending_result = crop_fn(image) + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + # only support nearest + result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + # compare + assert_allclose(result, expected, rtol=1e-5) + + @parameterized.expand(TEST_LAZY_ERROR) + def test_lazy_error(self, input_param, image, _expected_data, align_corners): + with self.assertRaises(ValueError): + crop_fn = CropForeground(**input_param) + # lazy + crop_fn.lazy_evaluation = True + pending_result = crop_fn(image) + return apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + if __name__ == "__main__": unittest.main() diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index fa36d2e065..d2604ef9cf 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -9,17 +9,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import CropForegroundd +from monai.transforms.lazy.functional import apply_transforms from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_POSITION, TESTS = [], [] for p in TEST_NDARRAYS_ALL: - TEST_POSITION.append( [ { @@ -38,6 +41,7 @@ ), }, p(np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]])), + True, ] ) TESTS.append( @@ -49,6 +53,7 @@ ) }, p(np.array([[[3]]])), + False, ] ) TESTS.append( @@ -60,6 +65,7 @@ ) }, p(np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]])), + True, ] ) TESTS.append( @@ -71,6 +77,7 @@ ) }, p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]])), + True, ] ) TESTS.append( @@ -89,6 +96,7 @@ ) }, p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])), + True, ] ) TESTS.append( @@ -121,6 +129,7 @@ ] ) ), + True, ] ) TESTS.append( @@ -132,7 +141,7 @@ "channel_indices": 0, "margin": 0, "k_divisible": [4, 6], - "mode": "edge", + "mode": "constant", }, { "img": p( @@ -142,14 +151,15 @@ ) ) }, - p(np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]])), + p(np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 0], [2, 2, 3, 2, 2, 0], [1, 1, 2, 1, 1, 0]]])), + False, ] ) class TestCropForegroundd(unittest.TestCase): @parameterized.expand(TEST_POSITION + TESTS) - def test_value(self, arguments, input_data, expected_data): + def test_value(self, arguments, input_data, expected_data, _): cropper = CropForegroundd(**arguments) result = cropper(input_data) assert_allclose(result["img"], expected_data, type_test="tensor") @@ -161,7 +171,7 @@ def test_value(self, arguments, input_data, expected_data): self.assertTupleEqual(inv["label"].shape, input_data["label"].shape) @parameterized.expand(TEST_POSITION) - def test_foreground_position(self, arguments, input_data, _): + def test_foreground_position(self, arguments, input_data, _expected_data, _align_corners): result = CropForegroundd(**arguments)(input_data) np.testing.assert_allclose(result["foreground_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["foreground_end_coord"], np.array([4, 4])) @@ -172,6 +182,23 @@ def test_foreground_position(self, arguments, input_data, _): np.testing.assert_allclose(result["test_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["test_end_coord"], np.array([4, 4])) + @parameterized.expand(TEST_POSITION + TESTS) + def test_pending_ops(self, input_param, image, _expected_data, align_corners): + crop_fn = CropForegroundd(**input_param) + # non-lazy + expected = crop_fn(image)["img"] + self.assertIsInstance(expected, MetaTensor) + # lazy + crop_fn.lazy_evaluation = True + pending_result = crop_fn(image)["img"] + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + # only support nearest + result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + # compare + assert_allclose(result, expected, rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_cross_validation.py b/tests/test_cross_validation.py index 811dcea026..33d060560c 100644 --- a/tests/test_cross_validation.py +++ b/tests/test_cross_validation.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_csv_dataset.py b/tests/test_csv_dataset.py index f288ac4b95..82a0f7afbd 100644 --- a/tests/test_csv_dataset.py +++ b/tests/test_csv_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_csv_iterable_dataset.py b/tests/test_csv_iterable_dataset.py index d6b84074ba..65a0a420a5 100644 --- a/tests/test_csv_iterable_dataset.py +++ b/tests/test_csv_iterable_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import sys import tempfile diff --git a/tests/test_csv_saver.py b/tests/test_csv_saver.py index 01796da00c..833d1134cf 100644 --- a/tests/test_csv_saver.py +++ b/tests/test_csv_saver.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import csv import os import tempfile diff --git a/tests/test_cucim_dict_transform.py b/tests/test_cucim_dict_transform.py index 4a6d2f9d51..6ebfd8bac7 100644 --- a/tests/test_cucim_dict_transform.py +++ b/tests/test_cucim_dict_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_cucim_transform.py b/tests/test_cucim_transform.py index dd73ad94c0..5884358a74 100644 --- a/tests/test_cucim_transform.py +++ b/tests/test_cucim_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_cumulative.py b/tests/test_cumulative.py index 16f5c1d1f5..3377fa815c 100644 --- a/tests/test_cumulative.py +++ b/tests/test_cumulative.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_cumulative_average.py b/tests/test_cumulative_average.py index 543433a6d3..d815d9be77 100644 --- a/tests/test_cumulative_average.py +++ b/tests/test_cumulative_average.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -32,7 +34,6 @@ class TestAverageMeter(unittest.TestCase): @parameterized.expand(TEST_CASE_1) def test_value_all(self, data): - # test orig self.run_test(data) diff --git a/tests/test_cumulative_average_dist.py b/tests/test_cumulative_average_dist.py index a5ee2fed15..17f4164838 100644 --- a/tests/test_cumulative_average_dist.py +++ b/tests/test_cumulative_average_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -23,7 +25,6 @@ class DistributedCumulativeAverage(DistTestCase): @DistCall(nnodes=1, nproc_per_node=2) def test_value(self): - rank = dist.get_rank() nprocs = dist.get_world_size() is_cuda = dist.get_backend() == dist.Backend.NCCL diff --git a/tests/test_cv2_dist.py b/tests/test_cv2_dist.py index 552c26443a..edd2e1ec42 100644 --- a/tests/test_cv2_dist.py +++ b/tests/test_cv2_dist.py @@ -9,8 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest +import numpy as np import torch import torch.distributed as dist from torch.cuda.amp import autocast @@ -20,8 +23,8 @@ from tests.utils import skip_if_no_cuda -def main_worker(rank, ngpus_per_node): - dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:12345", world_size=ngpus_per_node, rank=rank) +def main_worker(rank, ngpus_per_node, port): + dist.init_process_group(backend="nccl", init_method=f"tcp://127.0.0.1:{port}", world_size=ngpus_per_node, rank=rank) # `benchmark = True` is not compatible with openCV in PyTorch 22.09 docker for multi-gpu training torch.backends.cudnn.benchmark = True @@ -42,7 +45,8 @@ class TestCV2Dist(unittest.TestCase): def test_cv2_cuda_ops(self): print_config() ngpus_per_node = torch.cuda.device_count() - torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,)) + port = np.random.randint(10000, 20000) + torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, port)) if __name__ == "__main__": diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 2b652f8f62..6ef51bef92 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import sys diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index 9c878addf5..374bc815ac 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import sys diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index b75c2a4ed8..2ee69687a6 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest diff --git a/tests/test_dataset.py b/tests/test_dataset.py index f8d4ed2104..667595caa4 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest @@ -26,7 +28,7 @@ class TestDataset(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) def test_shape(self, expected_shape): - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) diff --git a/tests/test_dataset_func.py b/tests/test_dataset_func.py index b5871d7de1..afccd129fe 100644 --- a/tests/test_dataset_func.py +++ b/tests/test_dataset_func.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import os import tempfile diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py index a5b5eee28f..b1cc578f32 100644 --- a/tests/test_dataset_summary.py +++ b/tests/test_dataset_summary.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import glob import os import tempfile @@ -37,7 +39,6 @@ class TestDatasetSummary(unittest.TestCase): def test_spacing_intensity(self): set_determinism(seed=0) with tempfile.TemporaryDirectory() as tempdir: - for i in range(5): im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0) n = nib.Nifti1Image(im, np.eye(4)) @@ -73,7 +74,6 @@ def test_spacing_intensity(self): def test_anisotropic_spacing(self): with tempfile.TemporaryDirectory() as tempdir: - pixdims = [[1.0, 1.0, 5.0], [1.0, 1.0, 4.0], [1.0, 1.0, 4.5], [1.0, 1.0, 2.0], [1.0, 1.0, 1.0]] for i in range(5): im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0) diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index b3844484ce..2e3a3cf541 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 538eb38311..ba7d74eb6c 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest from enum import Enum -from typing import List, Tuple import numpy as np import torch @@ -45,12 +46,12 @@ KEYS = ["image"] -TESTS_DICT: List[Tuple] = [] +TESTS_DICT: list[tuple] = [] TESTS_DICT.append((SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1))) TESTS_DICT.append((RandRotate90d(KEYS, prob=0.0, max_k=1),)) TESTS_DICT.append((RandAffined(KEYS, prob=0.0, translate_range=10),)) -TESTS_LIST: List[Tuple] = [] +TESTS_LIST: list[tuple] = [] TESTS_LIST.append((SpatialPad(150), RandFlip(prob=1.0, spatial_axis=1))) TESTS_LIST.append((RandRotate90(prob=0.0, max_k=1),)) TESTS_LIST.append((RandAffine(prob=0.0, translate_range=10),)) diff --git a/tests/test_deepedit_interaction.py b/tests/test_deepedit_interaction.py index 42fd87607d..5dcc6205f7 100644 --- a/tests/test_deepedit_interaction.py +++ b/tests/test_deepedit_interaction.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_deepedit_transforms.py b/tests/test_deepedit_transforms.py index f608a4342f..7f4d4eee1e 100644 --- a/tests/test_deepedit_transforms.py +++ b/tests/test_deepedit_transforms.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py index 5b3e40b1ee..d8a412ade9 100644 --- a/tests/test_deepgrow_dataset.py +++ b/tests/test_deepgrow_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile @@ -62,7 +64,7 @@ def _create_data(self, length=1, image_channel=1, with_label=True): else: image = np.random.randint(0, 2, size=(128, 128, 40, image_channel)) image_file = os.path.join(self.tempdir, f"image{i}.nii.gz") - nib.save(nib.Nifti1Image(image, affine), image_file) + nib.save(nib.Nifti1Image(image.astype(float), affine), image_file) if with_label: # 3 slices has label @@ -72,7 +74,7 @@ def _create_data(self, length=1, image_channel=1, with_label=True): label[0][0][2] = 1 label[0][1][2] = 1 label_file = os.path.join(self.tempdir, f"label{i}.nii.gz") - nib.save(nib.Nifti1Image(label, affine), label_file) + nib.save(nib.Nifti1Image(label.astype(float), affine), label_file) datalist.append({"image": image_file, "label": label_file}) else: datalist.append({"image": image_file}) diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py index b040348b62..7cdbeed9f9 100644 --- a/tests/test_deepgrow_interaction.py +++ b/tests/test_deepgrow_interaction.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index bd20b45b6d..1328e13439 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_delete_itemsd.py b/tests/test_delete_itemsd.py index 99d05fe787..1ec77f29fd 100644 --- a/tests/test_delete_itemsd.py +++ b/tests/test_delete_itemsd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import time import unittest diff --git a/tests/test_denseblock.py b/tests/test_denseblock.py index 3a3f61860e..c14ca2ae7a 100644 --- a/tests/test_denseblock.py +++ b/tests/test_denseblock.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch.nn as nn diff --git a/tests/test_densenet.py b/tests/test_densenet.py index 66f27cba51..8354237a25 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from typing import TYPE_CHECKING from unittest import skipUnless diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index c94c300175..5d511f3821 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -9,10 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import warnings -from monai.utils import DeprecatedError, deprecated, deprecated_arg +from monai.utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default class TestDeprecatedRC(unittest.TestCase): @@ -232,7 +234,7 @@ def test_arg_except2_unknown(self): def afoo4(a, b=None): pass - self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2)) + afoo4(1, b=2) def test_arg_except3_unknown(self): """ @@ -244,8 +246,8 @@ def test_arg_except3_unknown(self): def afoo4(a, b=None, **kwargs): pass - self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2)) - self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2, c=3)) + afoo4(1, b=2) + afoo4(1, b=2, c=3) def test_replacement_arg(self): """ @@ -287,6 +289,159 @@ def afoo4(a, b=None, **kwargs): self.assertEqual(afoo4(a=1, b=2, c=3), (1, {"c": 3})) # prefers the new arg self.assertEqual(afoo4(1, 2, c=3), (1, {"c": 3})) # prefers the new positional arg + def test_deprecated_arg_default_explicit_default(self): + """ + Test deprecated arg default, where the default is explicitly set (no warning). + """ + + @deprecated_arg_default( + "b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version + ) + def foo(a, b="a"): + return a, b + + with self.assertWarns(FutureWarning) as aw: + self.assertEqual(foo("a", "a"), ("a", "a")) + self.assertEqual(foo("a", "b"), ("a", "b")) + self.assertEqual(foo("a", "c"), ("a", "c")) + warnings.warn("fake warning", FutureWarning) + + self.assertEqual(aw.warning.args[0], "fake warning") + + def test_deprecated_arg_default_version_less_than_since(self): + """ + Test deprecated arg default, where the current version is less than `since` (no warning). + """ + + @deprecated_arg_default( + "b", old_default="a", new_default="b", since=self.test_version, version_val=self.prev_version + ) + def foo(a, b="a"): + return a, b + + with self.assertWarns(FutureWarning) as aw: + self.assertEqual(foo("a"), ("a", "a")) + self.assertEqual(foo("a", "a"), ("a", "a")) + warnings.warn("fake warning", FutureWarning) + + self.assertEqual(aw.warning.args[0], "fake warning") + + def test_deprecated_arg_default_warning_deprecated(self): + """ + Test deprecated arg default, where the default is used. + """ + + @deprecated_arg_default( + "b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version + ) + def foo(a, b="a"): + return a, b + + self.assertWarns(FutureWarning, lambda: foo("a")) + + def test_deprecated_arg_default_warning_replaced(self): + """ + Test deprecated arg default, where the default is used. + """ + + @deprecated_arg_default( + "b", + old_default="a", + new_default="b", + since=self.prev_version, + replaced=self.prev_version, + version_val=self.test_version, + ) + def foo(a, b="a"): + return a, b + + self.assertWarns(FutureWarning, lambda: foo("a")) + + def test_deprecated_arg_default_warning_with_none_as_placeholder(self): + """ + Test deprecated arg default, where the default is used. + """ + + @deprecated_arg_default( + "b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version + ) + def foo(a, b=None): + if b is None: + b = "a" + return a, b + + self.assertWarns(FutureWarning, lambda: foo("a")) + + @deprecated_arg_default( + "b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version + ) + def foo2(a, b=None): + if b is None: + b = "b" + return a, b + + self.assertWarns(FutureWarning, lambda: foo2("a")) + + def test_deprecated_arg_default_errors(self): + """ + Test deprecated arg default, where the decorator is wrongly used. + """ + + # since > replaced + def since_grater_than_replaced(): + @deprecated_arg_default( + "b", + old_default="a", + new_default="b", + since=self.test_version, + replaced=self.prev_version, + version_val=self.test_version, + ) + def foo(a, b=None): + return a, b + + self.assertRaises(ValueError, since_grater_than_replaced) + + # argname doesnt exist + def argname_doesnt_exist(): + @deprecated_arg_default( + "other", old_default="a", new_default="b", since=self.test_version, version_val=self.test_version + ) + def foo(a, b=None): + return a, b + + self.assertRaises(ValueError, argname_doesnt_exist) + + # argname has no default + def argname_has_no_default(): + @deprecated_arg_default( + "a", + old_default="a", + new_default="b", + since=self.prev_version, + replaced=self.test_version, + version_val=self.test_version, + ) + def foo(a): + return a + + self.assertRaises(ValueError, argname_has_no_default) + + # new default is used but version < replaced + def argname_was_replaced_before_specified_version(): + @deprecated_arg_default( + "a", + old_default="a", + new_default="b", + since=self.prev_version, + replaced=self.next_version, + version_val=self.test_version, + ) + def foo(a, b="b"): + return a, b + + self.assertRaises(ValueError, argname_was_replaced_before_specified_version) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index 5eea0c8653..105d3a4ace 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_detection_coco_metrics.py b/tests/test_detection_coco_metrics.py index b139377511..780031ee0c 100644 --- a/tests/test_detection_coco_metrics.py +++ b/tests/test_detection_coco_metrics.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import random import unittest diff --git a/tests/test_detector_boxselector.py b/tests/test_detector_boxselector.py index 6e22a7833a..8cc9b15911 100644 --- a/tests/test_detector_boxselector.py +++ b/tests/test_detector_boxselector.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_detector_utils.py b/tests/test_detector_utils.py index b8ae390016..41716934b5 100644 --- a/tests/test_detector_utils.py +++ b/tests/test_detector_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import random import unittest diff --git a/tests/test_dev_collate.py b/tests/test_dev_collate.py index 83dbd71d28..97028f2597 100644 --- a/tests/test_dev_collate.py +++ b/tests/test_dev_collate.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import unittest diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 1f43dd8c9a..13b4952ab3 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index af3e868654..ee5b49f456 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 223b09e624..e7f64ccfb3 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_dints_cell.py b/tests/test_dints_cell.py index a5da39bae9..21cef39d68 100644 --- a/tests/test_dints_cell.py +++ b/tests/test_dints_cell.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_dints_mixop.py b/tests/test_dints_mixop.py index b686069173..09d2e7a423 100644 --- a/tests/test_dints_mixop.py +++ b/tests/test_dints_mixop.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_dints_network.py b/tests/test_dints_network.py index 08e75fab98..09059144be 100644 --- a/tests/test_dints_network.py +++ b/tests/test_dints_network.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py index aa9b9720c4..62635e286e 100644 --- a/tests/test_discriminator.py +++ b/tests/test_discriminator.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index df610c4939..892ffc4cd2 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized @@ -38,6 +40,10 @@ def test_pad_kwargs(self): unchanged_slices = [slice(None), slice(None, 8), slice(None, 4)] self.pad_test_kwargs(unchanged_slices, **kwargs) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_shape, _): + self.pad_test_pending_ops(input_param, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_divisible_padd.py b/tests/test_divisible_padd.py index 93e5a879f0..409dd9cb17 100644 --- a/tests/test_divisible_padd.py +++ b/tests/test_divisible_padd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized @@ -31,6 +33,10 @@ def test_pad(self, input_param, input_shape, expected_shape): modes = ["constant", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT, "edge", NumpyPadMode.EDGE] self.pad_test(input_param, input_shape, expected_shape, modes) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_shape, _): + self.pad_test_pending_ops(input_param, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index e6045cada9..696bcfc78f 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index ac2acb0845..cd40be4306 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_drop_path.py b/tests/test_drop_path.py index f8ea454228..ab2150e548 100644 --- a/tests/test_drop_path.py +++ b/tests/test_drop_path.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_ds_loss.py b/tests/test_ds_loss.py index dc67b651a3..51200d9584 100644 --- a/tests/test_ds_loss.py +++ b/tests/test_ds_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_dvf2ddf.py b/tests/test_dvf2ddf.py index d061cca7ff..f18b5b7297 100644 --- a/tests/test_dvf2ddf.py +++ b/tests/test_dvf2ddf.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index e14d427640..68694d8983 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -9,19 +9,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Any, Sequence, Union +from typing import Any, Sequence import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets import DynUNet -from tests.utils import assert_allclose, skip_if_no_cuda, skip_if_windows, test_script_save +from tests.utils import assert_allclose, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" -strides: Sequence[Union[Sequence[int], int]] +strides: Sequence[Sequence[int] | int] kernel_size: Sequence[Any] expected_shape: Sequence[Any] @@ -120,8 +122,9 @@ def test_script(self): test_script_save(net, test_data) -@skip_if_no_cuda -@skip_if_windows +# @skip_if_no_cuda +# @skip_if_windows +@unittest.skip("temporary skip for 22.12") class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase): @parameterized.expand([TEST_CASE_DYNUNET_3D[0]]) def test_consistency(self, input_param, input_shape, _): diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py index 1c83552766..b34ccb31ba 100644 --- a/tests/test_dynunet_block.py +++ b/tests/test_dynunet_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py index d56f901af7..5bdad5a568 100644 --- a/tests/test_efficientnet.py +++ b/tests/test_efficientnet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from typing import TYPE_CHECKING @@ -319,7 +321,6 @@ def test_drop_connect_layer(self): # testing 1D, 2D and 3D shape for rand_tensor_shape in [(512, 16, 4), (384, 16, 4, 4), (256, 16, 4, 4, 4)]: - # test validation mode, out tensor == in tensor training = False for p in p_list: diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py index 15ea6a0952..40a4a72dd5 100644 --- a/tests/test_ensemble_evaluator.py +++ b/tests/test_ensemble_evaluator.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index d8dba562bb..027b18b7dd 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index 44bb7e40f4..08e2709641 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_ensure_tuple.py b/tests/test_ensure_tuple.py index ea580871da..dc6649ec4c 100644 --- a/tests/test_ensure_tuple.py +++ b/tests/test_ensure_tuple.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -48,5 +50,4 @@ def test_value(self, input, expected_value, wrap_array=False): if __name__ == "__main__": - unittest.main() diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 9325e0b601..7d6b7ca586 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 789afd1a46..98a41b5430 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_enum_bound_interp.py b/tests/test_enum_bound_interp.py index 7607619e7a..5a63fc05af 100644 --- a/tests/test_enum_bound_interp.py +++ b/tests/test_enum_bound_interp.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.utils import optional_import diff --git a/tests/test_eval_mode.py b/tests/test_eval_mode.py index bc9c97d238..8458753e1f 100644 --- a/tests/test_eval_mode.py +++ b/tests/test_eval_mode.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py index 5e4b0b3b5d..f338944daa 100644 --- a/tests/test_evenly_divisible_all_gather_dist.py +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_factorized_increase.py b/tests/test_factorized_increase.py index a86f5a2db9..f7642ff357 100644 --- a/tests/test_factorized_increase.py +++ b/tests/test_factorized_increase.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_factorized_reduce.py b/tests/test_factorized_reduce.py index d14418233e..224a0cb351 100644 --- a/tests/test_factorized_reduce.py +++ b/tests/test_factorized_reduce.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_fastmri_reader.py b/tests/test_fastmri_reader.py index 30393ffc56..b15bd4b6a2 100644 --- a/tests/test_fastmri_reader.py +++ b/tests/test_fastmri_reader.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_fft_utils.py b/tests/test_fft_utils.py index d5e3a22eaa..971df2b411 100644 --- a/tests/test_fft_utils.py +++ b/tests/test_fft_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_fg_bg_to_indices.py b/tests/test_fg_bg_to_indices.py index 03eb770d6d..7d88bb7ee9 100644 --- a/tests/test_fg_bg_to_indices.py +++ b/tests/test_fg_bg_to_indices.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_fg_bg_to_indicesd.py b/tests/test_fg_bg_to_indicesd.py index 3be795919f..d0d1ae5fb6 100644 --- a/tests/test_fg_bg_to_indicesd.py +++ b/tests/test_fg_bg_to_indicesd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized @@ -18,7 +20,6 @@ TEST_CASES = [] for p in TEST_NDARRAYS: - TEST_CASES.append( [ {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, diff --git a/tests/test_file_basename.py b/tests/test_file_basename.py index dc8b1316a2..93e2027575 100644 --- a/tests/test_file_basename.py +++ b/tests/test_file_basename.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py index 688c65005e..65c59d49eb 100644 --- a/tests/test_fill_holes.py +++ b/tests/test_fill_holes.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_fill_holesd.py b/tests/test_fill_holesd.py index 7711df36b3..3f98dab1bf 100644 --- a/tests/test_fill_holesd.py +++ b/tests/test_fill_holesd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_fl_exchange_object.py b/tests/test_fl_exchange_object.py index bb2d0372db..293f9d518b 100644 --- a/tests/test_fl_exchange_object.py +++ b/tests/test_fl_exchange_object.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_fl_monai_algo.py b/tests/test_fl_monai_algo.py index cce01a169f..c4c5da00bb 100644 --- a/tests/test_fl_monai_algo.py +++ b/tests/test_fl_monai_algo.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile diff --git a/tests/test_fl_monai_algo_dist.py b/tests/test_fl_monai_algo_dist.py index 11f64ea318..36c2f419b3 100644 --- a/tests/test_fl_monai_algo_dist.py +++ b/tests/test_fl_monai_algo_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from os.path import join as pathjoin diff --git a/tests/test_fl_monai_algo_stats.py b/tests/test_fl_monai_algo_stats.py index fd2b73ea85..1955c35b36 100644 --- a/tests/test_fl_monai_algo_stats.py +++ b/tests/test_fl_monai_algo_stats.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_flatten_sub_keysd.py b/tests/test_flatten_sub_keysd.py index 336d0c296e..997f203870 100644 --- a/tests/test_flatten_sub_keysd.py +++ b/tests/test_flatten_sub_keysd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 123d494e9a..1218ce6e85 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Dict, List, Type, Union import torch from parameterized import parameterized @@ -46,12 +47,10 @@ def get_encoder_parameters(cls): @classmethod def num_channels_per_output(cls): - return [(32, 64, 128, 256, 512, 1024), (32, 64, 128, 256), (32, 64, 128, 256), (32, 64, 128, 256)] @classmethod def num_outputs(cls): - return [6, 4, 4, 4] @classmethod @@ -88,14 +87,14 @@ def get_inplanes(): return [64, 128, 256, 512] @classmethod - def get_encoder_parameters(cls) -> List[Dict]: + def get_encoder_parameters(cls) -> list[dict]: """ Get parameter list to initialize encoder networks. Each parameter dict must have `spatial_dims`, `in_channels` and `pretrained` parameters. """ parameter_list = [] - res_type: Union[Type[ResNetBlock], Type[ResNetBottleneck]] + res_type: type[ResNetBlock] | type[ResNetBottleneck] for backbone in range(len(cls.backbone_names)): if backbone < 3: res_type = ResNetBlock @@ -174,29 +173,32 @@ def make_shape_cases( num_classes=10, input_shape=64, norm=("batch", {"eps": 1e-3, "momentum": 0.01}), + upsample=("nontrainable", "deconv", "pixelshuffle"), ): ret_tests = [] for spatial_dim in spatial_dims: # selected spatial_dims for batch in batches: # check single batch as well as multiple batch input for model in models: # selected models for is_pretrained in pretrained: # pretrained or not pretrained - if ("resnet" in model) and is_pretrained: - continue - kwargs = { - "in_channels": in_channels, - "out_channels": num_classes, - "backbone": model, - "pretrained": is_pretrained, - "spatial_dims": spatial_dim, - "norm": norm, - } - ret_tests.append( - [ - kwargs, - (batch, in_channels) + (input_shape,) * spatial_dim, - (batch, num_classes) + (input_shape,) * spatial_dim, - ] - ) + for upsample_method in upsample: + if ("resnet" in model) and is_pretrained: + continue + kwargs = { + "in_channels": in_channels, + "out_channels": num_classes, + "backbone": model, + "pretrained": is_pretrained, + "spatial_dims": spatial_dim, + "norm": norm, + "upsample": upsample_method, + } + ret_tests.append( + [ + kwargs, + (batch, in_channels) + (input_shape,) * spatial_dim, + (batch, num_classes) + (input_shape,) * spatial_dim, + ] + ) return ret_tests diff --git a/tests/test_flip.py b/tests/test_flip.py index c5a281b127..287852c2c1 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -18,6 +20,7 @@ from monai.data.meta_obj import set_track_meta from monai.data.meta_tensor import MetaTensor from monai.transforms import Flip +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -41,21 +44,27 @@ def test_invalid_inputs(self, _, spatial_axis, raises): def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - flip = Flip(spatial_axis=spatial_axis) + init_param = {"spatial_axis": spatial_axis} + flip = Flip(**init_param) expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) - result = flip(im) + call_param = {"img": im} + result = flip(**call_param) + test_resampler_lazy(flip, result, init_param, call_param) assert_allclose(result, p(expected), type_test="tensor") test_local_inversion(flip, result, im) @parameterized.expand(TORCH_CASES) - def test_torch(self, init_param, img: torch.Tensor, track_meta: bool, device): + def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device): set_track_meta(track_meta) img = img.to(device) - xform = Flip(init_param) - res = xform(img) + init_param = {"spatial_axis": spatial_axis} + xform = Flip(**init_param) + call_param = {"img": img} + res = xform(**call_param) self.assertEqual(img.shape, res.shape) if track_meta: + test_resampler_lazy(xform, res, init_param, call_param) self.assertIsInstance(res, MetaTensor) else: self.assertNotIsInstance(res, MetaTensor) diff --git a/tests/test_flipd.py b/tests/test_flipd.py index b0e656b83f..19f9ed0882 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -19,6 +21,7 @@ from monai.data.meta_obj import set_track_meta from monai.data.meta_tensor import MetaTensor from monai.transforms import Flipd +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -41,22 +44,28 @@ def test_invalid_cases(self, _, spatial_axis, raises): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS_ALL: - flip = Flipd(keys="img", spatial_axis=spatial_axis) + init_param = {"keys": "img", "spatial_axis": spatial_axis} + flip = Flipd(**init_param) expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) im = p(self.imt[0]) - result = flip({"img": im})["img"] - assert_allclose(result, p(expected), type_test="tensor") - test_local_inversion(flip, {"img": result}, {"img": im}, "img") + call_param = {"data": {"img": im}} + result = flip(**call_param) + test_resampler_lazy(flip, result, init_param, call_param, output_key="img") + assert_allclose(result["img"], p(expected), type_test="tensor") + test_local_inversion(flip, {"img": result["img"]}, {"img": im}, "img") @parameterized.expand(TORCH_CASES) - def test_torch(self, init_param, img: torch.Tensor, track_meta: bool, device): + def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device): set_track_meta(track_meta) img = img.to(device) - xform = Flipd("image", init_param) - res = xform({"image": img}) + init_param = {"keys": "image", "spatial_axis": spatial_axis} + xform = Flipd(**init_param) + call_param = {"data": {"image": img}} + res = xform(**call_param) # type: ignore self.assertEqual(img.shape, res["image"].shape) if track_meta: + test_resampler_lazy(xform, res, init_param, call_param, output_key="image") self.assertIsInstance(res["image"], MetaTensor) else: self.assertNotIsInstance(res["image"], MetaTensor) diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 6ac23fef36..5f30b7b07d 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_folder_layout.py b/tests/test_folder_layout.py index f7291933a3..d6d4bdf679 100644 --- a/tests/test_folder_layout.py +++ b/tests/test_folder_layout.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_foreground_mask.py b/tests/test_foreground_mask.py index 160db5bae3..eb59ae2db6 100644 --- a/tests/test_foreground_mask.py +++ b/tests/test_foreground_mask.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_foreground_maskd.py b/tests/test_foreground_maskd.py index 3c8aa08d7f..24cb233c30 100644 --- a/tests/test_foreground_maskd.py +++ b/tests/test_foreground_maskd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_fourier.py b/tests/test_fourier.py index b500f266d7..3613db989f 100644 --- a/tests/test_fourier.py +++ b/tests/test_fourier.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_fpn_block.py b/tests/test_fpn_block.py index a86cd22a19..c6121c5b98 100644 --- a/tests/test_fpn_block.py +++ b/tests/test_fpn_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from collections import OrderedDict diff --git a/tests/test_from_engine_hovernet.py b/tests/test_from_engine_hovernet.py index 201fc356cd..227fa66baa 100644 --- a/tests/test_from_engine_hovernet.py +++ b/tests/test_from_engine_hovernet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_fullyconnectednet.py b/tests/test_fullyconnectednet.py index 6378ec9718..94fc4caa6e 100644 --- a/tests/test_fullyconnectednet.py +++ b/tests/test_fullyconnectednet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_gaussian.py b/tests/test_gaussian.py index 461b11d076..b98507b793 100644 --- a/tests/test_gaussian.py +++ b/tests/test_gaussian.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_gaussian_filter.py b/tests/test_gaussian_filter.py index c4ffe56896..1beee579e8 100644 --- a/tests/test_gaussian_filter.py +++ b/tests/test_gaussian_filter.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py index 248e3df4d5..2509a4fc26 100644 --- a/tests/test_gaussian_sharpen.py +++ b/tests/test_gaussian_sharpen.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_gaussian_sharpend.py b/tests/test_gaussian_sharpend.py index 0478007809..75ea915d96 100644 --- a/tests/test_gaussian_sharpend.py +++ b/tests/test_gaussian_sharpend.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index d5e0875f05..38b29bbd17 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_gaussian_smoothd.py b/tests/test_gaussian_smoothd.py index 2e968461e8..8702c073c8 100644 --- a/tests/test_gaussian_smoothd.py +++ b/tests/test_gaussian_smoothd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generalized_dice_focal_loss.py b/tests/test_generalized_dice_focal_loss.py index ef8661c88d..8905da8106 100644 --- a/tests/test_generalized_dice_focal_loss.py +++ b/tests/test_generalized_dice_focal_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 619814037b..b8256a41a9 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generalized_wasserstein_dice_loss.py b/tests/test_generalized_wasserstein_dice_loss.py index 49a5aa0556..7b85fdc5b6 100644 --- a/tests/test_generalized_wasserstein_dice_loss.py +++ b/tests/test_generalized_wasserstein_dice_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_distance_map.py b/tests/test_generate_distance_map.py index ffd432af8c..724a335e1a 100644 --- a/tests/test_generate_distance_map.py +++ b/tests/test_generate_distance_map.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_distance_mapd.py b/tests/test_generate_distance_mapd.py index 98bd5007a9..17c5aa782b 100644 --- a/tests/test_generate_distance_mapd.py +++ b/tests/test_generate_distance_mapd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_instance_border.py b/tests/test_generate_instance_border.py index ceff4a915e..8634bb7d77 100644 --- a/tests/test_generate_instance_border.py +++ b/tests/test_generate_instance_border.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_instance_borderd.py b/tests/test_generate_instance_borderd.py index 800ed92ed2..fc81e8f87c 100644 --- a/tests/test_generate_instance_borderd.py +++ b/tests/test_generate_instance_borderd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_instance_centroid.py b/tests/test_generate_instance_centroid.py index b7293df0ee..f9fdc602a9 100644 --- a/tests/test_generate_instance_centroid.py +++ b/tests/test_generate_instance_centroid.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_instance_centroidd.py b/tests/test_generate_instance_centroidd.py index f989de5ff2..92e45cdf84 100644 --- a/tests/test_generate_instance_centroidd.py +++ b/tests/test_generate_instance_centroidd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_instance_contour.py b/tests/test_generate_instance_contour.py index 8c43bf5bc5..9058855e62 100644 --- a/tests/test_generate_instance_contour.py +++ b/tests/test_generate_instance_contour.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -46,7 +48,6 @@ class TestGenerateInstanceContour(unittest.TestCase): @parameterized.expand(TEST_CASE) def test_shape(self, in_type, test_data, min_num_points, offset, expected): - inst_bbox = get_bbox(test_data[None]) inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] result = GenerateInstanceContour(min_num_points=min_num_points)(in_type(inst_map[None]), offset=offset) diff --git a/tests/test_generate_instance_contourd.py b/tests/test_generate_instance_contourd.py index e92020c6bc..22e3669850 100644 --- a/tests/test_generate_instance_contourd.py +++ b/tests/test_generate_instance_contourd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_instance_type.py b/tests/test_generate_instance_type.py index 8a083d19b7..354f8640ae 100644 --- a/tests/test_generate_instance_type.py +++ b/tests/test_generate_instance_type.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_instance_typed.py b/tests/test_generate_instance_typed.py index 08d9f550a9..84a5344503 100644 --- a/tests/test_generate_instance_typed.py +++ b/tests/test_generate_instance_typed.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_label_classes_crop_centers.py b/tests/test_generate_label_classes_crop_centers.py index 4f64aadc26..c276171bd5 100644 --- a/tests/test_generate_label_classes_crop_centers.py +++ b/tests/test_generate_label_classes_crop_centers.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy @@ -26,7 +28,7 @@ "label_spatial_shape": [3, 3, 3], "indices": [[3, 12, 21], [1, 9, 18]], }, - list, + tuple, 2, 3, ] @@ -39,7 +41,7 @@ "label_spatial_shape": [3, 3, 3], "indices": [[3, 12, 21], [1, 9, 18]], }, - list, + tuple, 1, 3, ] diff --git a/tests/test_generate_param_groups.py b/tests/test_generate_param_groups.py index 7ae42b8ec6..8301e40188 100644 --- a/tests/test_generate_param_groups.py +++ b/tests/test_generate_param_groups.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py index d1a208770f..13b7b728b4 100644 --- a/tests/test_generate_pos_neg_label_crop_centers.py +++ b/tests/test_generate_pos_neg_label_crop_centers.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy @@ -28,7 +30,7 @@ "fg_indices": [1, 9, 18], "bg_indices": [3, 12, 21], }, - list, + tuple, 2, 3, ], @@ -41,7 +43,7 @@ "fg_indices": [], "bg_indices": [3, 12, 21], }, - list, + tuple, 2, 3, ], diff --git a/tests/test_generate_spatial_bounding_box.py b/tests/test_generate_spatial_bounding_box.py index d27d5a570f..a67e7d0175 100644 --- a/tests/test_generate_spatial_bounding_box.py +++ b/tests/test_generate_spatial_bounding_box.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_succinct_contour.py b/tests/test_generate_succinct_contour.py index 478c23b522..1c60e99546 100644 --- a/tests/test_generate_succinct_contour.py +++ b/tests/test_generate_succinct_contour.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_succinct_contourd.py b/tests/test_generate_succinct_contourd.py index b34142ec0d..e94a02fed5 100644 --- a/tests/test_generate_succinct_contourd.py +++ b/tests/test_generate_succinct_contourd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_watershed_markers.py b/tests/test_generate_watershed_markers.py index 92b48eeef1..a763361913 100644 --- a/tests/test_generate_watershed_markers.py +++ b/tests/test_generate_watershed_markers.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_watershed_markersd.py b/tests/test_generate_watershed_markersd.py index 22ce7fae0a..76d4ec1ae6 100644 --- a/tests/test_generate_watershed_markersd.py +++ b/tests/test_generate_watershed_markersd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_watershed_mask.py b/tests/test_generate_watershed_mask.py index 0fdb4fb428..1cc35dca5c 100644 --- a/tests/test_generate_watershed_mask.py +++ b/tests/test_generate_watershed_mask.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generate_watershed_maskd.py b/tests/test_generate_watershed_maskd.py index 65b9353fcc..aa6d5bf03a 100644 --- a/tests/test_generate_watershed_maskd.py +++ b/tests/test_generate_watershed_maskd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_generator.py b/tests/test_generator.py index 617655f86e..c336acf7ef 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_get_equivalent_dtype.py b/tests/test_get_equivalent_dtype.py index a4df3ac2ac..01f8adca73 100644 --- a/tests/test_get_equivalent_dtype.py +++ b/tests/test_get_equivalent_dtype.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_get_extreme_points.py b/tests/test_get_extreme_points.py index 457351b98c..1338ba0e2c 100644 --- a/tests/test_get_extreme_points.py +++ b/tests/test_get_extreme_points.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_get_layers.py b/tests/test_get_layers.py index 6109052d1f..ad0be1a5c4 100644 --- a/tests/test_get_layers.py +++ b/tests/test_get_layers.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_get_package_version.py b/tests/test_get_package_version.py index c4e15c9d09..1881d79602 100644 --- a/tests/test_get_package_version.py +++ b/tests/test_get_package_version.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.utils.module import get_package_version diff --git a/tests/test_get_unique_labels.py b/tests/test_get_unique_labels.py index 67953a3205..e550882243 100644 --- a/tests/test_get_unique_labels.py +++ b/tests/test_get_unique_labels.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index e40eda38db..aad5d6fea6 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index 6662e9e17c..3aa69b7280 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_giou_loss.py b/tests/test_giou_loss.py index 25cc258054..e794ddab30 100644 --- a/tests/test_giou_loss.py +++ b/tests/test_giou_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index d53d6c9711..b67ed71725 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -8,6 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest @@ -24,28 +26,28 @@ EXPECTED_VALUE = { "xyz_translation": [ - -1.5860259532928467, - -0.5957175493240356, - -0.3855515122413635, - -0.28728482127189636, - -0.23416118323802948, - -0.19534644484519958, - -0.17001715302467346, - -0.15043553709983826, - -0.1366637945175171, - -0.12534910440444946, + -1.5860257, + -0.62433463, + -0.38217825, + -0.2905613, + -0.23233329, + -0.1961407, + -0.16905619, + -0.15100679, + -0.13666219, + -0.12635908, ], "xyz_rotation": [ - -1.5860259532928467, - -0.29977330565452576, - -0.18411292135715485, - -0.1582011878490448, - -0.16107326745986938, - -0.165723517537117, - -0.1970357596874237, - -0.1755618453025818, - -0.17100191116333008, - -0.17264796793460846, + -1.5860257, + -0.30265224, + -0.18666176, + -0.15887907, + -0.1625064, + -0.16603896, + -0.19222091, + -0.18158069, + -0.167644, + -0.16698098, ], } @@ -82,9 +84,13 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. numpy array of shape HWD """ transform_list = [ - transforms.LoadImaged(keys="img"), + transforms.LoadImaged(keys="img", image_only=True), transforms.Affined( - keys="img", translate_params=translate_params, rotate_params=rotate_params, device=None + keys="img", + translate_params=translate_params, + rotate_params=rotate_params, + device=None, + padding_mode="border", ), transforms.NormalizeIntensityd(keys=["img"]), ] @@ -92,7 +98,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. return transformation({"img": FILE_PATH})["img"] a1 = transformation() - a1 = torch.tensor(a1).unsqueeze(0).unsqueeze(0).to(device) + a1 = a1.clone().unsqueeze(0).unsqueeze(0).to(device) for mode in transform_params_dict: transform_params_list = transform_params_dict[mode] @@ -102,9 +108,9 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. translate_params=transform_params if "translation" in mode else (0.0, 0.0, 0.0), rotate_params=transform_params if "rotation" in mode else (0.0, 0.0, 0.0), ) - a2 = torch.tensor(a2).unsqueeze(0).unsqueeze(0).to(device) + a2 = a2.clone().unsqueeze(0).unsqueeze(0).to(device) result = loss_fn(a2, a1).detach().cpu().numpy() - np.testing.assert_allclose(result, expected_value, rtol=1e-3, atol=5e-3) + np.testing.assert_allclose(result, expected_value, rtol=0.08, atol=0.05) class TestGlobalMutualInformationLossIll(unittest.TestCase): diff --git a/tests/test_globalnet.py b/tests/test_globalnet.py index 4a3e9c124c..d4496858f2 100644 --- a/tests/test_globalnet.py +++ b/tests/test_globalnet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_gmm.py b/tests/test_gmm.py index ad5e383a6a..4ed3b956ff 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile @@ -273,7 +275,6 @@ def tearDown(self) -> None: @parameterized.expand(TEST_CASES) @skip_if_no_cuda def test_cuda(self, test_case_description, mixture_count, class_count, features, labels, expected): - # Device to run on device = torch.device("cuda") diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 9863b50df5..937dda344b 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest diff --git a/tests/test_grid_distortion.py b/tests/test_grid_distortion.py index d71642aae8..d776d49f4d 100644 --- a/tests/test_grid_distortion.py +++ b/tests/test_grid_distortion.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -61,16 +63,16 @@ [2.25, 2.25, 2.25, 2.25, 2.25, 2.25], [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], - [3.25, 3.25, 3.25, 3.25, 3.25, 3.25], - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [4.2500, 4.2500, 4.2500, 4.2500, 4.2500, 4.2500], + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0], ], [ - [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], - [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], - [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], - [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], - [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], - [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 5.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 5.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 5.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 5.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 5.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 5.0], ], ] ).astype(np.float32) @@ -101,7 +103,10 @@ class TestGridDistortion(unittest.TestCase): def test_grid_distortion(self, input_param, input_data, expected_val): g = GridDistortion(**input_param) result = g(input_data) - assert_allclose(result, expected_val, type_test=False, rtol=1e-4, atol=1e-4) + if input_param["padding_mode"] != "reflection": + assert_allclose(result, expected_val, type_test=False, rtol=1e-4, atol=1e-4) + else: + assert_allclose(result.shape, expected_val.shape, type_test=False, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_grid_distortiond.py b/tests/test_grid_distortiond.py index 2cf8bc7ff9..a645eb4f87 100644 --- a/tests/test_grid_distortiond.py +++ b/tests/test_grid_distortiond.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -40,16 +42,16 @@ [2.25, 2.25, 2.25, 2.25, 2.25, 2.25], [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], - [3.25, 3.25, 3.25, 3.25, 3.25, 3.25], - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [4.2500, 4.2500, 4.2500, 4.2500, 4.2500, 4.2500], + [2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000], ], [ - [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], - [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], - [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], - [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], - [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], - [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000], + [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000], + [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000], + [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000], + [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000], + [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000], ], ] ).astype(np.float32) @@ -77,8 +79,8 @@ class TestGridDistortiond(unittest.TestCase): def test_grid_distortiond(self, input_param, input_data, expected_val_img, expected_val_mask): g = GridDistortiond(**input_param) result = g(input_data) - assert_allclose(result["img"], expected_val_img, type_test=False, rtol=1e-4, atol=1e-4) assert_allclose(result["mask"], expected_val_mask, type_test=False, rtol=1e-4, atol=1e-4) + assert_allclose(result["img"].shape, expected_val_img.shape, type_test=False, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 03b33147dd..766b37cf31 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -36,25 +38,31 @@ TEST_CASE_9 = [{"patch_size": (2, 2), "num_patches": 4, "sort_fn": "min"}, A, [A11, A12, A21, A22]] TEST_CASE_10 = [{"patch_size": (2, 2), "overlap": 0.5, "num_patches": 3}, A, [A11, A[:, :2, 1:3], A12]] TEST_CASE_11 = [ - {"patch_size": (3, 3), "num_patches": 2, "constant_values": 255}, + {"patch_size": (3, 3), "num_patches": 2, "constant_values": 255, "pad_mode": "constant"}, A, [A[:, :3, :3], np.pad(A[:, :3, 3:], ((0, 0), (0, 0), (0, 2)), mode="constant", constant_values=255)], ] TEST_CASE_12 = [ - {"patch_size": (3, 3), "offset": (-2, -2), "num_patches": 2}, + {"patch_size": (3, 3), "offset": (-2, -2), "num_patches": 2, "pad_mode": "constant"}, A, [np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode="constant")], ] +# Only threshold filtering TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, A, [A11]] +TEST_CASE_14 = [{"patch_size": (2, 2), "threshold": 150.0}, A, [A11, A12, A21]] +# threshold filtering with num_patches more than available patches (no effect) +TEST_CASE_15 = [{"patch_size": (2, 2), "num_patches": 3, "threshold": 50.0}, A, [A11]] +# threshold filtering with num_patches less than available patches (count filtering) +TEST_CASE_16 = [{"patch_size": (2, 2), "num_patches": 2, "threshold": 150.0}, A, [A11, A12]] -TEST_CASE_MEAT_0 = [ +TEST_CASE_META_0 = [ {"patch_size": (2, 2)}, A, [A11, A12, A21, A22], [{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}], ] -TEST_CASE_MEAT_1 = [ +TEST_CASE_META_1 = [ {"patch_size": (2, 2)}, MetaTensor(x=A, meta={"path": "path/to/file"}), [A11, A12, A21, A22], @@ -82,6 +90,9 @@ TEST_CASES.append([p, *TEST_CASE_11]) TEST_CASES.append([p, *TEST_CASE_12]) TEST_CASES.append([p, *TEST_CASE_13]) + TEST_CASES.append([p, *TEST_CASE_14]) + TEST_CASES.append([p, *TEST_CASE_15]) + TEST_CASES.append([p, *TEST_CASE_16]) class TestGridPatch(unittest.TestCase): @@ -94,7 +105,7 @@ def test_grid_patch(self, in_type, input_parameters, image, expected): for output_patch, expected_patch in zip(output, expected): assert_allclose(output_patch, expected_patch, type_test=False) - @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) + @parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1]) @SkipIfBeforePyTorchVersion((1, 9, 1)) def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta): set_track_meta(True) diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py index 0f1bea5f8a..46928150cd 100644 --- a/tests/test_grid_patchd.py +++ b/tests/test_grid_patchd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -35,16 +37,22 @@ TEST_CASE_9 = [{"patch_size": (2, 2), "num_patches": 4, "sort_fn": "min"}, {"image": A}, [A11, A12, A21, A22]] TEST_CASE_10 = [{"patch_size": (2, 2), "overlap": 0.5, "num_patches": 3}, {"image": A}, [A11, A[:, :2, 1:3], A12]] TEST_CASE_11 = [ - {"patch_size": (3, 3), "num_patches": 2, "constant_values": 255}, + {"patch_size": (3, 3), "num_patches": 2, "constant_values": 255, "pad_mode": "constant"}, {"image": A}, [A[:, :3, :3], np.pad(A[:, :3, 3:], ((0, 0), (0, 0), (0, 2)), mode="constant", constant_values=255)], ] TEST_CASE_12 = [ - {"patch_size": (3, 3), "offset": (-2, -2), "num_patches": 2}, + {"patch_size": (3, 3), "offset": (-2, -2), "num_patches": 2, "pad_mode": "constant"}, {"image": A}, [np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode="constant")], ] +# Only threshold filtering TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, {"image": A}, [A11]] +TEST_CASE_14 = [{"patch_size": (2, 2), "threshold": 150.0}, {"image": A}, [A11, A12, A21]] +# threshold filtering with num_patches more than available patches (no effect) +TEST_CASE_15 = [{"patch_size": (2, 2), "threshold": 50.0, "num_patches": 3}, {"image": A}, [A11]] +# threshold filtering with num_patches less than available patches (count filtering) +TEST_CASE_16 = [{"patch_size": (2, 2), "threshold": 150.0, "num_patches": 2}, {"image": A}, [A11, A12]] TEST_SINGLE = [] for p in TEST_NDARRAYS: @@ -62,6 +70,9 @@ TEST_SINGLE.append([p, *TEST_CASE_11]) TEST_SINGLE.append([p, *TEST_CASE_12]) TEST_SINGLE.append([p, *TEST_CASE_13]) + TEST_SINGLE.append([p, *TEST_CASE_14]) + TEST_SINGLE.append([p, *TEST_CASE_15]) + TEST_SINGLE.append([p, *TEST_CASE_16]) class TestGridPatchd(unittest.TestCase): diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py index 561b231498..8877b0c121 100644 --- a/tests/test_grid_pull.py +++ b/tests/test_grid_pull.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_grid_split.py b/tests/test_grid_split.py index 82734ffd93..3ccf6e75a8 100644 --- a/tests/test_grid_split.py +++ b/tests/test_grid_split.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_grid_splitd.py b/tests/test_grid_splitd.py index 086dd2691d..d8519b2121 100644 --- a/tests/test_grid_splitd.py +++ b/tests/test_grid_splitd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index bdd4499687..7dfb802bba 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import tempfile import unittest diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py index c87866490c..70810e018f 100644 --- a/tests/test_handler_checkpoint_saver.py +++ b/tests/test_handler_checkpoint_saver.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py index 313842b443..905e326a66 100644 --- a/tests/test_handler_classification_saver.py +++ b/tests/test_handler_classification_saver.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import csv import os import tempfile @@ -26,7 +28,6 @@ class TestHandlerClassificationSaver(unittest.TestCase): def test_saved_content(self): with tempfile.TemporaryDirectory() as tempdir: - # set up engine def _train_func(engine, batch): engine.state.batch = decollate_batch(batch) diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index e92009d37f..ef06b69683 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import csv import os import tempfile diff --git a/tests/test_handler_clearml_image.py b/tests/test_handler_clearml_image.py new file mode 100644 index 0000000000..781931327f --- /dev/null +++ b/tests/test_handler_clearml_image.py @@ -0,0 +1,49 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from monai.handlers import ClearMLImageHandler +from monai.utils import optional_import + +Task, has_clearml = optional_import("clearml", name="Task") +_, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter") + + +@unittest.skipUnless(has_clearml, "Requires 'clearml' installation") +@unittest.skipUnless(has_tb, "Requires SummaryWriter installation") +@unittest.skip("temp mute clearml tests https://github.com/Project-MONAI/MONAI/issues/6148") +class TestHandlerClearMLImageHandler(unittest.TestCase): + def test_task_init(self): + Task.set_offline(offline_mode=True) + try: + ClearMLImageHandler( + project_name="MONAI", + task_name="monai_experiment", + output_uri=True, + tags=None, + reuse_last_task_id=True, + continue_last_task=False, + auto_connect_frameworks=True, + auto_connect_arg_parser=False, + ) + except Exception as exc: + self.fail(exc) + self.assertEqual(Task.current_task().name, "monai_experiment") + self.assertEqual(Task.current_task()._project_name[1], "MONAI") + # Close ClearML Task + Task.current_task().close() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_clearml_stats.py b/tests/test_handler_clearml_stats.py new file mode 100644 index 0000000000..fad847ca1d --- /dev/null +++ b/tests/test_handler_clearml_stats.py @@ -0,0 +1,49 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from monai.handlers import ClearMLStatsHandler +from monai.utils import optional_import + +Task, has_clearml = optional_import("clearml", name="Task") +_, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter") + + +@unittest.skipUnless(has_clearml, "Requires 'clearml' installation") +@unittest.skipUnless(has_tb, "Requires SummaryWriter installation") +@unittest.skip("temp mute clearml tests https://github.com/Project-MONAI/MONAI/issues/6148") +class TestHandlerClearMLStatsHandler(unittest.TestCase): + def test_task_init(self): + Task.set_offline(offline_mode=True) + try: + ClearMLStatsHandler( + project_name="MONAI", + task_name="monai_experiment", + output_uri=True, + tags=None, + reuse_last_task_id=True, + continue_last_task=False, + auto_connect_frameworks=True, + auto_connect_arg_parser=False, + ) + except Exception as exc: + self.fail(exc) + self.assertEqual(Task.current_task().name, "monai_experiment") + self.assertEqual(Task.current_task()._project_name[1], "MONAI") + # Close ClearML Task + Task.current_task().close() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_confusion_matrix.py b/tests/test_handler_confusion_matrix.py index ee6f3cd681..5f3ee3ae97 100644 --- a/tests/test_handler_confusion_matrix.py +++ b/tests/test_handler_confusion_matrix.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Any, Dict +from typing import Any import torch from ignite.engine import Engine @@ -24,7 +26,7 @@ TEST_CASE_3 = [{"save_details": False, "metric_name": "f1", "reduction": "mean_batch"}, torch.tensor([0.6667, 0.8000])] TEST_CASE_SEG_1 = [{"include_background": True, "metric_name": "tpr"}, 0.7] -data_1: Dict[Any, Any] = { +data_1: dict[Any, Any] = { "y_pred": torch.tensor( [ [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]], @@ -39,7 +41,7 @@ ), } -data_2: Dict[Any, Any] = { +data_2: dict[Any, Any] = { "y_pred": torch.tensor([[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]]), "y": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]), } diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py index 511e84d22a..b74b7e57c4 100644 --- a/tests/test_handler_confusion_matrix_dist.py +++ b/tests/test_handler_confusion_matrix_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py index 757708ea2b..5bc5584515 100644 --- a/tests/test_handler_decollate_batch.py +++ b/tests/test_handler_decollate_batch.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_handler_early_stop.py b/tests/test_handler_early_stop.py index 36604e5735..675a804472 100644 --- a/tests/test_handler_early_stop.py +++ b/tests/test_handler_early_stop.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from ignite.engine import Engine, Events diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py index e3bc3411b9..f64039b6fb 100644 --- a/tests/test_handler_garbage_collector.py +++ b/tests/test_handler_garbage_collector.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import gc import unittest from unittest import skipUnless diff --git a/tests/test_handler_hausdorff_distance.py b/tests/test_handler_hausdorff_distance.py index 1fe9b2e4a3..906db86d62 100644 --- a/tests/test_handler_hausdorff_distance.py +++ b/tests/test_handler_hausdorff_distance.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Tuple import numpy as np import torch @@ -21,7 +22,7 @@ def create_spherical_seg_3d( - radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) + radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be diff --git a/tests/test_handler_logfile.py b/tests/test_handler_logfile.py index b67bf63a5b..f09876ab0a 100644 --- a/tests/test_handler_logfile.py +++ b/tests/test_handler_logfile.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest @@ -58,7 +60,6 @@ def test_filename(self): filename = "something_else.txt" with tempfile.TemporaryDirectory() as tempdir: - handler = LogfileHandler(output_dir=tempdir, filename=filename) handler.attach(self.engine) diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py index 15401fe1b2..f1d3f45f06 100644 --- a/tests/test_handler_lr_scheduler.py +++ b/tests/test_handler_lr_scheduler.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import re diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py index 88eb4fbdcd..10cf981f02 100644 --- a/tests/test_handler_mean_dice.py +++ b/tests/test_handler_mean_dice.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_handler_mean_iou.py b/tests/test_handler_mean_iou.py index fdd4a5d04d..89dae3af58 100644 --- a/tests/test_handler_mean_iou.py +++ b/tests/test_handler_mean_iou.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_handler_metric_logger.py b/tests/test_handler_metric_logger.py index c3de866c5c..016af1e8b5 100644 --- a/tests/test_handler_metric_logger.py +++ b/tests/test_handler_metric_logger.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_handler_metrics_reloaded.py b/tests/test_handler_metrics_reloaded.py new file mode 100644 index 0000000000..e080204d6f --- /dev/null +++ b/tests/test_handler_metrics_reloaded.py @@ -0,0 +1,149 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from ignite.engine import Engine, Events +from parameterized import parameterized + +from monai.handlers import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler, from_engine +from monai.utils import optional_import +from tests.utils import assert_allclose + +_, has_metrics = optional_import("MetricsReloaded") + +TEST_CASE_BIN_1 = [ + {"metric_name": "Volume Difference"}, + [torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])], + [torch.tensor([[[1.0, 0.0], [1.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])], + 0.3333, +] + +TEST_CASE_BIN_2 = [ + {"metric_name": "Boundary IoU"}, + [torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])], + [torch.tensor([[[1.0, 0.0], [1.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])], + 0.6667, +] + +TEST_CASE_BIN_3 = [ + {"metric_name": "xTh Percentile Hausdorff Distance"}, + [torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])], + [torch.tensor([[[1.0, 0.0], [1.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])], + 0.9, +] + +TEST_CASE_CAT_1 = [ + {"metric_name": "Weighted Cohens Kappa"}, + [ + torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]), + torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]), + ], + [ + torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]), + torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]), + ], + 0.272727, +] + +TEST_CASE_CAT_2 = [ + {"metric_name": "Matthews Correlation Coefficient"}, + [ + torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]), + torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]), + ], + [ + torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]), + torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]), + ], + 0.387298, +] + + +@unittest.skipIf(not has_metrics, "MetricsReloaded not available.") +class TestHandlerMetricsReloadedBinary(unittest.TestCase): + @parameterized.expand([TEST_CASE_BIN_1, TEST_CASE_BIN_2, TEST_CASE_BIN_3]) + def test_compute(self, input_params, y_pred, y, expected_value): + input_params["output_transform"] = from_engine(["pred", "label"]) + metric = MetricsReloadedBinaryHandler(**input_params) + + # set up engine + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + metric.attach(engine=engine, name=input_params["metric_name"]) + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.fire_event(Events.EPOCH_COMPLETED) + assert_allclose( + engine.state.metrics[input_params["metric_name"]], expected_value, atol=1e-4, rtol=1e-4, type_test=False + ) + + @parameterized.expand([TEST_CASE_BIN_1, TEST_CASE_BIN_2, TEST_CASE_BIN_3]) + def test_shape_mismatch(self, input_params, _y_pred, _y, _expected_value): + input_params["output_transform"] = from_engine(["pred", "label"]) + metric = MetricsReloadedBinaryHandler(**input_params) + with self.assertRaises((AssertionError, ValueError)): + y_pred = torch.Tensor([[0, 1], [1, 0]]) + y = torch.ones((2, 3)) + metric.update([y_pred, y]) + + with self.assertRaises((AssertionError, ValueError)): + y_pred = [torch.ones((2, 1, 1)), torch.ones((1, 1, 1))] + y = [torch.ones((2, 1, 1)), torch.ones((1, 1, 1))] + metric.update([y_pred, y]) + + +@unittest.skipIf(not has_metrics, "MetricsReloaded not available.") +class TestMetricsReloadedCategorical(unittest.TestCase): + @parameterized.expand([TEST_CASE_CAT_1, TEST_CASE_CAT_2]) + def test_compute(self, input_params, y_pred, y, expected_value): + input_params["output_transform"] = from_engine(["pred", "label"]) + metric = MetricsReloadedCategoricalHandler(**input_params) + + # set up engine + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + metric.attach(engine=engine, name=input_params["metric_name"]) + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.fire_event(Events.EPOCH_COMPLETED) + assert_allclose( + engine.state.metrics[input_params["metric_name"]], expected_value, atol=1e-4, rtol=1e-4, type_test=False + ) + + @parameterized.expand([TEST_CASE_CAT_1, TEST_CASE_CAT_2]) + def test_shape_mismatch(self, input_params, y_pred, y, _expected_value): + input_params["output_transform"] = from_engine(["pred", "label"]) + metric = MetricsReloadedCategoricalHandler(**input_params) + with self.assertRaises((AssertionError, ValueError)): + y_pred[0] = torch.zeros([3, 2, 1]) + metric.update([y_pred, y]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_metrics_saver.py b/tests/test_handler_metrics_saver.py index 27f107605b..9888a73e5f 100644 --- a/tests/test_handler_metrics_saver.py +++ b/tests/test_handler_metrics_saver.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import csv import os import tempfile diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 426d99c223..11d7db168b 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import csv import os import tempfile diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py index f41957840f..d9474a9a72 100644 --- a/tests/test_handler_mlflow.py +++ b/tests/test_handler_mlflow.py @@ -9,20 +9,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import glob import os import shutil import tempfile import unittest from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock import numpy as np from ignite.engine import Engine, Events +from parameterized import parameterized from monai.handlers import MLFlowHandler from monai.utils import path_to_uri +def get_event_filter(e): + def event_filter(_, event): + if event in e: + return True + return False + + return event_filter + + def dummy_train(tracking_folder): tempdir = tempfile.mkdtemp() @@ -58,7 +71,6 @@ def tearDown(self): def test_metrics_track(self): experiment_param = {"backbone": "efficientnet_b0"} with tempfile.TemporaryDirectory() as tempdir: - # set up engine def _train_func(engine, batch): return [batch + 1.0] @@ -94,6 +106,85 @@ def _update_metric(engine): # check logging output self.assertTrue(len(glob.glob(test_path)) > 0) + @parameterized.expand([[True], [get_event_filter([1, 2])]]) + def test_metrics_track_mock(self, epoch_log): + experiment_param = {"backbone": "efficientnet_b0"} + with tempfile.TemporaryDirectory() as tempdir: + # set up engine + def _train_func(engine, batch): + return [batch + 1.0] + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get("acc", 0.1) + engine.state.metrics["acc"] = current_metric + 0.1 + engine.state.test = current_metric + + # set up testing handler + test_path = os.path.join(tempdir, "mlflow_test") + handler = MLFlowHandler( + iteration_log=False, + epoch_log=epoch_log, + tracking_uri=path_to_uri(test_path), + state_attributes=["test"], + experiment_param=experiment_param, + close_on_complete=True, + ) + handler._default_epoch_log = MagicMock() + handler.attach(engine) + + max_epochs = 4 + engine.run(range(3), max_epochs=max_epochs) + handler.close() + # check logging output + if epoch_log is True: + self.assertEqual(handler._default_epoch_log.call_count, max_epochs) + else: + self.assertEqual(handler._default_epoch_log.call_count, 2) # 2 = len([1, 2]) from event_filter + + @parameterized.expand([[True], [get_event_filter([1, 3])]]) + def test_metrics_track_iters_mock(self, iteration_log): + experiment_param = {"backbone": "efficientnet_b0"} + with tempfile.TemporaryDirectory() as tempdir: + # set up engine + def _train_func(engine, batch): + return [batch + 1.0] + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get("acc", 0.1) + engine.state.metrics["acc"] = current_metric + 0.1 + engine.state.test = current_metric + + # set up testing handler + test_path = os.path.join(tempdir, "mlflow_test") + handler = MLFlowHandler( + iteration_log=iteration_log, + epoch_log=False, + tracking_uri=path_to_uri(test_path), + state_attributes=["test"], + experiment_param=experiment_param, + close_on_complete=True, + ) + handler._default_iteration_log = MagicMock() + handler.attach(engine) + + num_iters = 3 + max_epochs = 2 + engine.run(range(num_iters), max_epochs=max_epochs) + handler.close() + # check logging output + if iteration_log is True: + self.assertEqual(handler._default_iteration_log.call_count, num_iters * max_epochs) + else: + self.assertEqual(handler._default_iteration_log.call_count, 2) # 2 = len([1, 3]) from event_filter + def test_multi_thread(self): test_uri_list = ["monai_mlflow_test1", "monai_mlflow_test2"] with ThreadPoolExecutor(2, "Training") as executor: diff --git a/tests/test_handler_nvtx.py b/tests/test_handler_nvtx.py index 29a1b8e4fb..75cc5bc1f4 100644 --- a/tests/test_handler_nvtx.py +++ b/tests/test_handler_nvtx.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_handler_panoptic_quality.py b/tests/test_handler_panoptic_quality.py index 9e24c52d9e..1595b5ad2c 100644 --- a/tests/test_handler_panoptic_quality.py +++ b/tests/test_handler_panoptic_quality.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py index 22bf046b83..1e7bbb7588 100644 --- a/tests/test_handler_parameter_scheduler.py +++ b/tests/test_handler_parameter_scheduler.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from ignite.engine import Engine, Events diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index 10c7ac4a8b..c449665c1e 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_handler_prob_map_producer.py b/tests/test_handler_prob_map_producer.py index a968e7dea0..153a00b1ac 100644 --- a/tests/test_handler_prob_map_producer.py +++ b/tests/test_handler_prob_map_producer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest @@ -49,7 +51,6 @@ def __init__(self, name, size): ] def __getitem__(self, index): - image = np.ones((3, 2, 2)) * index metadata = { ProbMapKeys.COUNT.value: self.data[index][ProbMapKeys.COUNT.value], diff --git a/tests/test_handler_regression_metrics.py b/tests/test_handler_regression_metrics.py index c6af76a4db..a06452c54d 100644 --- a/tests/test_handler_regression_metrics.py +++ b/tests/test_handler_regression_metrics.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from functools import partial @@ -64,7 +66,6 @@ def test_compute(self): # iterate over all variations and check shapes for different reduction functions for mt_fn, mt_fn_np in zip(metrics, metrics_np): - for batch in batch_dims: for spatial in spatial_dims: for base in base_dims: diff --git a/tests/test_handler_regression_metrics_dist.py b/tests/test_handler_regression_metrics_dist.py index a8b644d550..a2e96b97d9 100644 --- a/tests/test_handler_regression_metrics_dist.py +++ b/tests/test_handler_regression_metrics_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py index 6e2d6be27e..ce2351a9f5 100644 --- a/tests/test_handler_rocauc.py +++ b/tests/test_handler_rocauc.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py index 994cbe139b..5b6ea045c7 100644 --- a/tests/test_handler_rocauc_dist.py +++ b/tests/test_handler_rocauc_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_handler_smartcache.py b/tests/test_handler_smartcache.py index 7bc9011c2d..c3b4d72cb4 100644 --- a/tests/test_handler_smartcache.py +++ b/tests/test_handler_smartcache.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 7fe07d974b..1842e08635 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import re @@ -18,12 +20,23 @@ import torch from ignite.engine import Engine, Events +from parameterized import parameterized from monai.handlers import StatsHandler +def get_event_filter(e): + def event_filter(_, event): + if event in e: + return True + return False + + return event_filter + + class TestHandlerStats(unittest.TestCase): - def test_metrics_print(self): + @parameterized.expand([[True], [get_event_filter([1, 2])]]) + def test_metrics_print(self, epoch_log): log_stream = StringIO() log_handler = logging.StreamHandler(log_stream) log_handler.setLevel(logging.INFO) @@ -46,10 +59,11 @@ def _update_metric(engine): logger = logging.getLogger(key_to_handler) logger.setLevel(logging.INFO) logger.addHandler(log_handler) - stats_handler = StatsHandler(iteration_log=False, epoch_log=True, name=key_to_handler) + stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler) stats_handler.attach(engine) - engine.run(range(3), max_epochs=2) + max_epochs = 4 + engine.run(range(3), max_epochs=max_epochs) # check logging output output_str = log_stream.getvalue() @@ -59,9 +73,13 @@ def _update_metric(engine): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + if epoch_log is True: + self.assertTrue(content_count == max_epochs) + else: + self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter - def test_loss_print(self): + @parameterized.expand([[True], [get_event_filter([1, 3])]]) + def test_loss_print(self, iteration_log): log_stream = StringIO() log_handler = logging.StreamHandler(log_stream) log_handler.setLevel(logging.INFO) @@ -78,10 +96,14 @@ def _train_func(engine, batch): logger = logging.getLogger(key_to_handler) logger.setLevel(logging.INFO) logger.addHandler(log_handler) - stats_handler = StatsHandler(iteration_log=True, epoch_log=False, name=key_to_handler, tag_name=key_to_print) + stats_handler = StatsHandler( + iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print + ) stats_handler.attach(engine) - engine.run(range(3), max_epochs=2) + num_iters = 3 + max_epochs = 2 + engine.run(range(num_iters), max_epochs=max_epochs) # check logging output output_str = log_stream.getvalue() @@ -91,7 +113,10 @@ def _train_func(engine, batch): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + if iteration_log is True: + self.assertTrue(content_count == num_iters * max_epochs) + else: + self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter def test_loss_dict(self): log_stream = StringIO() @@ -229,7 +254,9 @@ def _train_func(engine, batch): # set up testing handler stats_handler = StatsHandler(name=None, tag_name=key_to_print) - stats_handler.attach(engine) + engine.logger.setLevel(logging.WARNING) + with self.assertWarns(Warning): # engine logging level warn + stats_handler.attach(engine) # leverage `engine.logger` to print info engine.logger.setLevel(logging.INFO) level = logging.root.getEffectiveLevel() diff --git a/tests/test_handler_surface_distance.py b/tests/test_handler_surface_distance.py index 6d245693ac..736f7e251a 100644 --- a/tests/test_handler_surface_distance.py +++ b/tests/test_handler_surface_distance.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Tuple import numpy as np import torch @@ -21,7 +22,7 @@ def create_spherical_seg_3d( - radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) + radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py index 749480e279..8657e552f1 100644 --- a/tests/test_handler_tb_image.py +++ b/tests/test_handler_tb_image.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import glob import tempfile import unittest @@ -32,7 +34,6 @@ class TestHandlerTBImage(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_tb_image_shape(self, shape): with tempfile.TemporaryDirectory() as tempdir: - # set up engine def _train_func(engine, batch): engine.state.batch = decollate_batch(list(batch)) diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py index eef77e5e2b..f4ccfc9780 100644 --- a/tests/test_handler_tb_stats.py +++ b/tests/test_handler_tb_stats.py @@ -9,11 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import glob import tempfile import unittest +from unittest.mock import MagicMock from ignite.engine import Engine, Events +from parameterized import parameterized from monai.handlers import TensorBoardStatsHandler from monai.utils import optional_import @@ -21,11 +25,27 @@ SummaryWriter, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter") +def get_event_filter(e): + def event_filter(_, event): + if event in e: + return True + return False + + return event_filter + + @unittest.skipUnless(has_tb, "Requires SummaryWriter installation") class TestHandlerTBStats(unittest.TestCase): + def test_args_validation(self): + with self.assertWarns(FutureWarning): + with self.assertRaisesRegex(ValueError, expected_regex="iteration_interval should be 1"): + TensorBoardStatsHandler(log_dir=".", iteration_log=get_event_filter([1, 2]), iteration_interval=2) + + with self.assertRaisesRegex(ValueError, expected_regex="epoch_interval should be 1"): + TensorBoardStatsHandler(log_dir=".", epoch_log=get_event_filter([1, 2]), epoch_interval=2) + def test_metrics_print(self): with tempfile.TemporaryDirectory() as tempdir: - # set up engine def _train_func(engine, batch): return [batch + 1.0] @@ -46,9 +66,37 @@ def _update_metric(engine): # check logging output self.assertTrue(len(glob.glob(tempdir)) > 0) - def test_metrics_writer(self): + @parameterized.expand([[True], [get_event_filter([1, 2])]]) + def test_metrics_print_mock(self, epoch_log): with tempfile.TemporaryDirectory() as tempdir: + # set up engine + def _train_func(engine, batch): + return [batch + 1.0] + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get("acc", 0.1) + engine.state.metrics["acc"] = current_metric + 0.1 + + # set up testing handler + stats_handler = TensorBoardStatsHandler(log_dir=tempdir, iteration_log=False, epoch_log=epoch_log) + stats_handler._default_epoch_writer = MagicMock() + stats_handler.attach(engine) + + max_epochs = 4 + engine.run(range(3), max_epochs=max_epochs) + stats_handler.close() + # check logging output + if epoch_log is True: + self.assertEqual(stats_handler._default_epoch_writer.call_count, max_epochs) + else: + self.assertEqual(stats_handler._default_epoch_writer.call_count, 2) # 2 = len([1, 2]) from event_filter + def test_metrics_writer(self): + with tempfile.TemporaryDirectory() as tempdir: # set up engine def _train_func(engine, batch): return [batch + 1.0] @@ -78,6 +126,47 @@ def _update_metric(engine): # check logging output self.assertTrue(len(glob.glob(tempdir)) > 0) + @parameterized.expand([[True], [get_event_filter([1, 3])]]) + def test_metrics_writer_mock(self, iteration_log): + with tempfile.TemporaryDirectory() as tempdir: + # set up engine + def _train_func(engine, batch): + return [batch + 1.0] + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get("acc", 0.1) + engine.state.metrics["acc"] = current_metric + 0.1 + engine.state.test = current_metric + + # set up testing handler + writer = SummaryWriter(log_dir=tempdir) + stats_handler = TensorBoardStatsHandler( + summary_writer=writer, + iteration_log=iteration_log, + epoch_log=False, + output_transform=lambda x: {"loss": x[0] * 2.0}, + global_epoch_transform=lambda x: x * 3.0, + state_attributes=["test"], + ) + stats_handler._default_iteration_writer = MagicMock() + stats_handler.attach(engine) + + num_iters = 3 + max_epochs = 2 + engine.run(range(num_iters), max_epochs=max_epochs) + writer.close() + + if iteration_log is True: + self.assertEqual(stats_handler._default_iteration_writer.call_count, num_iters * max_epochs) + else: + self.assertEqual( + stats_handler._default_iteration_writer.call_count, 2 + ) # 2 = len([1, 3]) from event_filter + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py index 42ffc8b9eb..deb762917d 100644 --- a/tests/test_handler_validation.py +++ b/tests/test_handler_validation.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_hardnegsampler.py b/tests/test_hardnegsampler.py index f4eff81810..b33cea1537 100644 --- a/tests/test_hardnegsampler.py +++ b/tests/test_hardnegsampler.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_hashing.py b/tests/test_hashing.py index 5a1265bd48..093de47cf9 100644 --- a/tests/test_hashing.py +++ b/tests/test_hashing.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index 44c011fe13..40f5b187d0 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Tuple import numpy as np import torch @@ -18,9 +19,11 @@ from monai.metrics import HausdorffDistanceMetric +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + def create_spherical_seg_3d( - radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) + radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -115,8 +118,8 @@ def test_value(self, input_data, expected_value): else: [seg_1, seg_2] = input_data ct = 0 - seg_1 = torch.tensor(seg_1) - seg_2 = torch.tensor(seg_2) + seg_1 = torch.tensor(seg_1, device=_device) + seg_2 = torch.tensor(seg_2, device=_device) for metric in ["euclidean", "chessboard", "taxicab"]: for directed in [True, False]: hd_metric = HausdorffDistanceMetric( @@ -129,7 +132,8 @@ def test_value(self, input_data, expected_value): hd_metric(batch_seg_1, batch_seg_2) result = hd_metric.aggregate(reduction="mean") expected_value_curr = expected_value[ct] - np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) + np.testing.assert_allclose(expected_value_curr, result.cpu(), rtol=1e-7) + np.testing.assert_equal(result.device, seg_1.device) ct += 1 @parameterized.expand(TEST_CASES_NANS) diff --git a/tests/test_header_correct.py b/tests/test_header_correct.py index aa0a4dde08..71fed1e35d 100644 --- a/tests/test_header_correct.py +++ b/tests/test_header_correct.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import nibabel as nib diff --git a/tests/test_highresnet.py b/tests/test_highresnet.py index cb3a923f14..04520419b7 100644 --- a/tests/test_highresnet.py +++ b/tests/test_highresnet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index f7954d6b24..4c49aecd8b 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -21,7 +23,6 @@ def create_expected_numpy_output(input_datum, **kwargs): - x = np.fft.fft(input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), **kwargs) f = np.fft.fftfreq(x.shape[kwargs["axis"]]) u = np.heaviside(f, 0.5) diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py index 586f7a59d3..3a340db52a 100644 --- a/tests/test_histogram_normalize.py +++ b/tests/test_histogram_normalize.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py index 64d5d514a6..24f27d225e 100644 --- a/tests/test_histogram_normalized.py +++ b/tests/test_histogram_normalized.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_hovernet.py b/tests/test_hovernet.py index 61bf85c50b..d768895bdc 100644 --- a/tests/test_hovernet.py +++ b/tests/test_hovernet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_hovernet_instance_map_post_processing.py b/tests/test_hovernet_instance_map_post_processing.py index db87332c13..990e2d9a10 100644 --- a/tests/test_hovernet_instance_map_post_processing.py +++ b/tests/test_hovernet_instance_map_post_processing.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_hovernet_instance_map_post_processingd.py b/tests/test_hovernet_instance_map_post_processingd.py index 2aa9f7ae64..69e42d3495 100644 --- a/tests/test_hovernet_instance_map_post_processingd.py +++ b/tests/test_hovernet_instance_map_post_processingd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_hovernet_loss.py b/tests/test_hovernet_loss.py index 653b8f8c88..10db4518fa 100644 --- a/tests/test_hovernet_loss.py +++ b/tests/test_hovernet_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import random import unittest diff --git a/tests/test_hovernet_nuclear_type_post_processing.py b/tests/test_hovernet_nuclear_type_post_processing.py index ae2b6dcdeb..f2b33c96ae 100644 --- a/tests/test_hovernet_nuclear_type_post_processing.py +++ b/tests/test_hovernet_nuclear_type_post_processing.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_hovernet_nuclear_type_post_processingd.py b/tests/test_hovernet_nuclear_type_post_processingd.py index 7f3462dddb..01478b7961 100644 --- a/tests/test_hovernet_nuclear_type_post_processingd.py +++ b/tests/test_hovernet_nuclear_type_post_processingd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_identity.py b/tests/test_identity.py index 60134c24a4..19116cbb8f 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.transforms.utility.array import Identity diff --git a/tests/test_identityd.py b/tests/test_identityd.py index f1d27d61d4..98499def01 100644 --- a/tests/test_identityd.py +++ b/tests/test_identityd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.transforms.utility.dictionary import Identityd diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index a89759323d..ddafce9eac 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py new file mode 100644 index 0000000000..841a5d5cd5 --- /dev/null +++ b/tests/test_image_filter.py @@ -0,0 +1,231 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.simplelayers import GaussianFilter +from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd + +EXPECTED_FILTERS = { + "mean": torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).float(), + "laplace": torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]).float(), + "elliptical": torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]]).float(), + "sharpen": torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]).float(), +} + +SUPPORTED_FILTERS = ["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"] +SAMPLE_IMAGE_2D = torch.randn(1, 10, 10) +SAMPLE_IMAGE_3D = torch.randn(1, 10, 10, 10) +SAMPLE_DICT = {"image_2d": SAMPLE_IMAGE_2D, "image_3d": SAMPLE_IMAGE_3D} + +# Sobel filter uses reflect pad as default which is not implemented for 3d in torch 1.8.1 or 1.9.1 +ADDITIONAL_ARGUMENTS = {"order": 1, "sigma": 1, "padding_mode": "zeros"} + + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + 1 + + +class TestNotAModuleOrTransform: + pass + + +class TestImageFilter(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string(self, filter_name): + "Test init from string" + _ = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) + + def test_init_raises(self): + with self.assertRaises(Exception) as context: + _ = ImageFilter("mean") + self.assertTrue("`filter_size` must be specified when specifying filters by string." in str(context.output)) + with self.assertRaises(Exception) as context: + _ = ImageFilter("mean") + self.assertTrue("`filter_size` should be a single uneven integer." in str(context.output)) + with self.assertRaises(Exception) as context: + _ = ImageFilter("gauss", 3) + self.assertTrue("`filter='gauss', requires the additonal keyword argument `sigma`" in str(context.output)) + with self.assertRaises(Exception) as context: + _ = ImageFilter("savitzky_golay", 3) + self.assertTrue( + "`filter='savitzky_golay', requires the additonal keyword argument `order`" in str(context.output) + ) + + def test_init_from_array(self): + "Test init with custom filter and assert wrong filter shape throws an error" + _ = ImageFilter(torch.ones(3, 3)) + _ = ImageFilter(torch.ones(3, 3, 3)) + _ = ImageFilter(np.ones((3, 3))) + _ = ImageFilter(np.ones((3, 3, 3))) + + with self.assertRaises(Exception) as context: + _ = ImageFilter(torch.ones(3, 3, 3, 3)) + self.assertTrue("Only 1D, 2D, and 3D filters are supported." in str(context.output)) + + def test_init_from_module(self): + filter = ImageFilter(TestModule()) + out = filter(torch.zeros(1, 3, 3, 3)) + torch.testing.assert_allclose(torch.ones(1, 3, 3, 3), out) + + def test_init_from_transform(self): + _ = ImageFilter(GaussianFilter(3, sigma=2)) + + def test_init_from_wrong_type_fails(self): + with self.assertRaises(Exception) as context: + _ = ImageFilter(TestNotAModuleOrTransform()) + self.assertTrue(" is not supported." in str(context.output)) + + @parameterized.expand(EXPECTED_FILTERS.keys()) + def test_2d_filter_correctness(self, filter_name): + "Test correctness of filters (2d only)" + tfm = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) + filter = tfm._get_filter_from_string(filter_name, size=3, ndim=2).filter.squeeze() + torch.testing.assert_allclose(filter, EXPECTED_FILTERS[filter_name]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d(self, filter_name): + "Text function `__call__` for 2d images" + filter = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_2D) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d(self, filter_name): + "Text function `__call__` for 3d images" + filter = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_3D) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + + +class TestImageFilterDict(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string_dict(self, filter_name): + "Test init from string and assert an error is thrown if no size is passed" + _ = ImageFilterd("image", filter_name, 3, **ADDITIONAL_ARGUMENTS) + with self.assertRaises(Exception) as _: + _ = ImageFilterd(self.image_key, filter_name) + + def test_init_from_array_dict(self): + "Test init with custom filter and assert wrong filter shape throws an error" + _ = ImageFilterd("image", torch.ones(3, 3)) + with self.assertRaises(Exception) as _: + _ = ImageFilterd(self.image_key, torch.ones(3, 3, 3, 3)) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d(self, filter_name): + "Text function `__call__` for 2d images" + filter = ImageFilterd("image_2d", filter_name, 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_2d"].shape[1:], SAMPLE_IMAGE_2D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d(self, filter_name): + "Text function `__call__` for 3d images" + filter = ImageFilterd("image_3d", filter_name, 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_3d"].shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + + +class TestRandImageFilter(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string(self, filter_name): + "Test init from string and assert an error is thrown if no size is passed" + _ = RandImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) + with self.assertRaises(Exception) as _: + _ = RandImageFilter(filter_name) + + def test_init_from_array(self): + "Test init with custom filter and assert wrong filter shape throws an error" + _ = RandImageFilter(torch.ones(3, 3)) + with self.assertRaises(Exception) as _: + _ = RandImageFilter(torch.ones(3, 3, 3, 3)) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_1(self, filter_name): + "Text function `__call__` for 2d images" + filter = RandImageFilter(filter_name, 3, 1, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_2D) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_1(self, filter_name): + "Text function `__call__` for 3d images" + filter = RandImageFilter(filter_name, 3, 1, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_3D) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_0(self, filter_name): + "Text function `__call__` for 2d images" + filter = RandImageFilter(filter_name, 3, 0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_2D) + torch.testing.assert_allclose(out_tensor, SAMPLE_IMAGE_2D) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_0(self, filter_name): + "Text function `__call__` for 3d images" + filter = RandImageFilter(filter_name, 3, 0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_3D) + torch.testing.assert_allclose(out_tensor, SAMPLE_IMAGE_3D) + + +class TestRandImageFilterDict(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string_dict(self, filter_name): + "Test init from string and assert an error is thrown if no size is passed" + _ = RandImageFilterd("image", filter_name, 3, **ADDITIONAL_ARGUMENTS) + with self.assertRaises(Exception) as _: + _ = RandImageFilterd("image", filter_name) + + def test_init_from_array_dict(self): + "Test init with custom filter and assert wrong filter shape throws an error" + _ = RandImageFilterd("image", torch.ones(3, 3)) + with self.assertRaises(Exception) as _: + _ = RandImageFilterd("image", torch.ones(3, 3, 3, 3)) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_1(self, filter_name): + filter = RandImageFilterd("image_2d", filter_name, 3, 1.0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_2d"].shape[1:], SAMPLE_IMAGE_2D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_1(self, filter_name): + filter = RandImageFilterd("image_3d", filter_name, 3, 1.0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_3d"].shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_0(self, filter_name): + filter = RandImageFilterd("image_2d", filter_name, 3, 0.0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + torch.testing.assert_allclose(out_tensor["image_2d"].shape[1:], SAMPLE_IMAGE_2D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_0(self, filter_name): + filter = RandImageFilterd("image_3d", filter_name, 3, 0.0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + torch.testing.assert_allclose(out_tensor["image_3d"].shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py index a120607224..42f9b0c4e7 100644 --- a/tests/test_image_rw.py +++ b/tests/test_image_rw.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import itertools import os import shutil diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py index 58c4d3cfab..7825f9b4d7 100644 --- a/tests/test_img2tensorboard.py +++ b/tests/test_img2tensorboard.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index e9a8ec96f7..1350146220 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.data import ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader diff --git a/tests/test_integration_autorunner.py b/tests/test_integration_autorunner.py index 237045fda7..84583852ef 100644 --- a/tests/test_integration_autorunner.py +++ b/tests/test_integration_autorunner.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest -from typing import Dict, List import nibabel as nib import numpy as np @@ -22,12 +23,18 @@ from monai.bundle.config_parser import ConfigParser from monai.data import create_test_image_3d from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick +from tests.utils import ( + SkipIfBeforePyTorchVersion, + get_testing_algo_template_path, + skip_if_downloading_fails, + skip_if_no_cuda, + skip_if_quick, +) _, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter") _, has_nni = optional_import("nni") -sim_datalist: Dict[str, List[Dict]] = { +sim_datalist: dict[str, list[dict]] = { "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}], "training": [ {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"}, @@ -45,10 +52,8 @@ ], } -num_gpus = 4 if torch.cuda.device_count() > 4 else torch.cuda.device_count() train_param = ( { - "CUDA_VISIBLE_DEVICES": list(range(num_gpus)), "num_images_per_batch": 2, "num_epochs": 2, "num_epochs_per_validation": 1, @@ -108,7 +113,9 @@ def setUp(self) -> None: @skip_if_no_cuda def test_autorunner(self) -> None: work_dir = os.path.join(self.test_path, "work_dir") - runner = AutoRunner(work_dir=work_dir, input=self.data_src_cfg) + runner = AutoRunner( + work_dir=work_dir, input=self.data_src_cfg, templates_path_or_url=get_testing_algo_template_path() + ) runner.set_training_params(train_param) # 2 epochs runner.set_num_fold(1) with skip_if_downloading_fails(): @@ -117,7 +124,9 @@ def test_autorunner(self) -> None: @skip_if_no_cuda def test_autorunner_ensemble(self) -> None: work_dir = os.path.join(self.test_path, "work_dir") - runner = AutoRunner(work_dir=work_dir, input=self.data_src_cfg) + runner = AutoRunner( + work_dir=work_dir, input=self.data_src_cfg, templates_path_or_url=get_testing_algo_template_path() + ) runner.set_training_params(train_param) # 2 epochs runner.set_ensemble_method("AlgoEnsembleBestByFold") runner.set_num_fold(1) @@ -127,7 +136,9 @@ def test_autorunner_ensemble(self) -> None: @skip_if_no_cuda def test_autorunner_gpu_customization(self) -> None: work_dir = os.path.join(self.test_path, "work_dir") - runner = AutoRunner(work_dir=work_dir, input=self.data_src_cfg) + runner = AutoRunner( + work_dir=work_dir, input=self.data_src_cfg, templates_path_or_url=get_testing_algo_template_path() + ) gpu_customization_specs = { "universal": {"num_trials": 1, "range_num_images_per_batch": [1, 2], "range_num_sw_batch_size": [1, 2]} } @@ -141,9 +152,14 @@ def test_autorunner_gpu_customization(self) -> None: @unittest.skipIf(not has_nni, "nni required") def test_autorunner_hpo(self) -> None: work_dir = os.path.join(self.test_path, "work_dir") - runner = AutoRunner(work_dir=work_dir, input=self.data_src_cfg, hpo=True, ensemble=False) + runner = AutoRunner( + work_dir=work_dir, + input=self.data_src_cfg, + hpo=True, + ensemble=False, + templates_path_or_url=get_testing_algo_template_path(), + ) hpo_param = { - "CUDA_VISIBLE_DEVICES": train_param["CUDA_VISIBLE_DEVICES"], "num_epochs_per_validation": train_param["num_epochs_per_validation"], "num_images_per_batch": train_param["num_images_per_batch"], "num_epochs": train_param["num_epochs"], diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index 1b3583a911..24c0286133 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import os import shutil @@ -19,6 +21,7 @@ import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.bundle import ConfigParser @@ -54,6 +57,7 @@ def test_tiny(self): json.dump( { "trainer": {"_target_": "tests.test_integration_bundle_run._Runnable42", "val": 42}, + # keep this test case to cover the "runner_id" arg "training": "$@trainer.run()", }, f, @@ -109,9 +113,10 @@ def test_shape(self, config_file, expected_shape): override = "--network $@network_def.to(@device) --dataset#_target_ Dataset" else: override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}" + device = "$torch.device('cuda:0')" if torch.cuda.is_available() else "$torch.device('cpu')" # test with `monai.bundle` as CLI entry directly - cmd = "-m monai.bundle run evaluating --postprocessing#transforms#2#output_postfix seg" - cmd += f" {override} --no_epoch False --output_dir {tempdir}" + cmd = "-m monai.bundle run --postprocessing#transforms#2#output_postfix seg" + cmd += f" {override} --no_epoch False --output_dir {tempdir} --device {device}" la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] test_env = os.environ.copy() print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) @@ -122,8 +127,8 @@ def test_shape(self, config_file, expected_shape): tracking_uri = path_to_uri(tempdir) + "/mlflow_override2" # test override experiment management configs # here test the script with `google fire` tool as CLI - cmd = "-m fire monai.bundle.scripts run --runner_id evaluating --tracking mlflow --evaluator#amp False" - cmd += f" --tracking_uri {tracking_uri} {override} --output_dir {tempdir}" + cmd = "-m fire monai.bundle.scripts run --tracking mlflow --evaluator#amp False" + cmd += f" --tracking_uri {tracking_uri} {override} --output_dir {tempdir} --device {device}" la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] command_line_tests(la) self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape) diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 78f0bf9f36..da883724a0 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest import warnings @@ -56,7 +58,6 @@ def __getitem__(self, index): def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", num_workers=10): - monai.config.print_config() # define transforms for image and classification train_transforms = Compose( diff --git a/tests/test_integration_determinism.py b/tests/test_integration_determinism.py index 94d2325514..26d95dfb47 100644 --- a/tests/test_integration_determinism.py +++ b/tests/test_integration_determinism.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_integration_fast_train.py b/tests/test_integration_fast_train.py index bb50ddf7b6..497fe22dab 100644 --- a/tests/test_integration_fast_train.py +++ b/tests/test_integration_fast_train.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math import os import shutil diff --git a/tests/test_integration_gpu_customization.py b/tests/test_integration_gpu_customization.py index 787c222c9f..4c8772ea50 100644 --- a/tests/test_integration_gpu_customization.py +++ b/tests/test_integration_gpu_customization.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest -from typing import Dict, List import nibabel as nib import numpy as np @@ -23,11 +24,17 @@ from monai.data import create_test_image_3d from monai.utils import optional_import from monai.utils.enums import AlgoEnsembleKeys -from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick +from tests.utils import ( + SkipIfBeforePyTorchVersion, + get_testing_algo_template_path, + skip_if_downloading_fails, + skip_if_no_cuda, + skip_if_quick, +) _, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter") -fake_datalist: Dict[str, List[Dict]] = { +fake_datalist: dict[str, list[dict]] = { "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}], "training": [ {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"}, @@ -45,10 +52,8 @@ ], } -num_gpus = 4 if torch.cuda.device_count() > 4 else torch.cuda.device_count() train_param = ( { - "CUDA_VISIBLE_DEVICES": list(range(num_gpus)), "num_images_per_batch": 2, "num_epochs": 2, "num_epochs_per_validation": 1, @@ -119,7 +124,10 @@ def test_ensemble_gpu_customization(self) -> None: with skip_if_downloading_fails(): bundle_generator = BundleGen( - algo_path=work_dir, data_stats_filename=da_output_yaml, data_src_cfg_name=data_src_cfg + algo_path=work_dir, + data_stats_filename=da_output_yaml, + data_src_cfg_name=data_src_cfg, + templates_path_or_url=get_testing_algo_template_path(), ) gpu_customization_specs = { diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 27e42b6bf5..2b57ef84d6 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index 38cab0c2f0..4ae894956b 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_integration_stn.py b/tests/test_integration_stn.py index 5b9b22668a..c858060c31 100644 --- a/tests/test_integration_stn.py +++ b/tests/test_integration_stn.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -45,7 +47,7 @@ def __init__(self, is_ref=True, reverse_indexing=False): self.fc_loc[2].weight.data.zero_() self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) if not self.is_ref: - self.xform = AffineTransform(normalized=True, reverse_indexing=reverse_indexing) + self.xform = AffineTransform(align_corners=False, normalized=True, reverse_indexing=reverse_indexing) # Spatial transformer network forward function def stn_ref(self, x): diff --git a/tests/test_integration_unet_2d.py b/tests/test_integration_unet_2d.py index e60c91968a..90c0098d36 100644 --- a/tests/test_integration_unet_2d.py +++ b/tests/test_integration_unet_2d.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_integration_workers.py b/tests/test_integration_workers.py index 654af5a89c..33c26cedf8 100644 --- a/tests/test_integration_workers.py +++ b/tests/test_integration_workers.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 342e70cc8e..9393f0bb42 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py index 7dd05848bb..db841fca7e 100644 --- a/tests/test_integration_workflows_gan.py +++ b/tests/test_integration_workflows_gan.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile diff --git a/tests/test_intensity_stats.py b/tests/test_intensity_stats.py index 3479306180..243fcd0dd4 100644 --- a/tests/test_intensity_stats.py +++ b/tests/test_intensity_stats.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_intensity_statsd.py b/tests/test_intensity_statsd.py index d6aeac61b0..3fe82b1df7 100644 --- a/tests/test_intensity_statsd.py +++ b/tests/test_intensity_statsd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 192a8d345e..0423c80d6b 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,12 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import random import sys import unittest from copy import deepcopy from functools import partial -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING from unittest.case import skipUnless import numpy as np @@ -70,14 +72,13 @@ from tests.utils import make_nifti_image, make_rand_affine if TYPE_CHECKING: - has_nib = True else: _, has_nib = optional_import("nibabel") KEYS = ["image", "label"] -TESTS: List[Tuple] = [] +TESTS: list[tuple] = [] # For pad, start with odd/even images and add odd/even amounts for name in ("1D even", "1D odd"): @@ -452,7 +453,6 @@ def test_inverse(self, _, data_name, acceptable_diff, is_meta, *transforms): # skip this test if multiprocessing uses 'spawn', as the check is only basic anyway @skipUnless(torch.multiprocessing.get_start_method() == "spawn", "requires spawn") def test_fail(self): - t1 = SpatialPadd("image", [10, 5]) data = t1(self.all_data["2D"]) @@ -463,7 +463,6 @@ def test_fail(self): @parameterized.expand(N_SAMPLES_TESTS) def test_inverse_inferred_seg(self, extra_transform): - test_data = [] for _ in range(20): image, label = create_test_image_2d(100, 101) diff --git a/tests/test_inverse_array.py b/tests/test_inverse_array.py index c86bf92bef..ea021f2566 100644 --- a/tests/test_inverse_array.py +++ b/tests/test_inverse_array.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index 4614432808..05e296e6b9 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest from typing import TYPE_CHECKING @@ -44,7 +46,6 @@ from tests.utils import make_nifti_image if TYPE_CHECKING: - has_nib = True else: _, has_nib = optional_import("nibabel") diff --git a/tests/test_invert.py b/tests/test_invert.py index 5d4926ecd2..c260193f75 100644 --- a/tests/test_invert.py +++ b/tests/test_invert.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest from copy import deepcopy diff --git a/tests/test_invertd.py b/tests/test_invertd.py index f3b504b39d..5ad8735fdc 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest diff --git a/tests/test_is_supported_format.py b/tests/test_is_supported_format.py index 6cb887c125..591772bb3a 100644 --- a/tests/test_is_supported_format.py +++ b/tests/test_is_supported_format.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 2c47a2181e..38be9ec30c 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import sys import tempfile @@ -32,7 +34,7 @@ def __iter__(self): class TestIterableDataset(unittest.TestCase): def test_shape(self): expected_shape = (128, 128, 128) - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) test_data = [] with tempfile.TemporaryDirectory() as tempdir: for i in range(6): diff --git a/tests/test_itk_torch_bridge.py b/tests/test_itk_torch_bridge.py new file mode 100644 index 0000000000..c08db89198 --- /dev/null +++ b/tests/test_itk_torch_bridge.py @@ -0,0 +1,486 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.apps import download_url +from monai.data import ITKReader +from monai.data.itk_torch_bridge import ( + get_itk_image_center, + itk_image_to_metatensor, + itk_to_monai_affine, + metatensor_to_itk_image, + monai_to_itk_affine, + monai_to_itk_ddf, +) +from monai.networks.blocks import Warp +from monai.transforms import Affine +from monai.utils import optional_import, set_determinism +from tests.utils import skip_if_downloading_fails, skip_if_quick, test_is_quick, testing_data_config + +itk, has_itk = optional_import("itk") + +TESTS = ["CT_2D_head_fixed.mha", "CT_2D_head_moving.mha"] +if not test_is_quick(): + TESTS += ["copd1_highres_INSP_STD_COPD_img.nii.gz", "copd1_highres_EXP_STD_COPD_img.nii.gz"] + + +@unittest.skipUnless(has_itk, "Requires `itk` package.") +class TestITKTorchAffineMatrixBridge(unittest.TestCase): + def setUp(self): + set_determinism(seed=0) + self.data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + self.reader = ITKReader(pixel_type=itk.F) + + for file_name in TESTS: + path = os.path.join(self.data_dir, file_name) + if not os.path.exists(path): + with skip_if_downloading_fails(): + data_spec = testing_data_config("images", f"{file_name.split('.', 1)[0]}") + download_url( + data_spec["url"], path, hash_val=data_spec["hash_val"], hash_type=data_spec["hash_type"] + ) + + def tearDown(self): + set_determinism(seed=None) + + def create_itk_affine_from_parameters( + self, image, translation=None, rotation=None, scale=None, shear=None, center_of_rotation=None + ): + """ + Creates an affine transformation for an ITK image based on the provided parameters. + + Args: + image: The ITK image. + translation: The translation (shift) to apply to the image. + rotation: The rotation to apply to the image, specified as angles in radians around the x, y, and z axes. + scale: The scaling factor to apply to the image. + shear: The shear to apply to the image. + center_of_rotation: The center of rotation for the image. If not specified, + the center of the image is used. + + Returns: + A tuple containing the affine transformation matrix and the translation vector. + """ + itk_transform = itk.AffineTransform[itk.D, image.ndim].New() + + # Set center + if center_of_rotation: + itk_transform.SetCenter(center_of_rotation) + else: + itk_transform.SetCenter(get_itk_image_center(image)) + + # Set parameters + if rotation: + if image.ndim == 2: + itk_transform.Rotate2D(rotation[0]) + else: + for i, angle_in_rads in enumerate(rotation): + if angle_in_rads != 0: + axis = [0, 0, 0] + axis[i] = 1 + itk_transform.Rotate3D(axis, angle_in_rads) + + if scale: + itk_transform.Scale(scale) + + if shear: + itk_transform.Shear(*shear) + + if translation: + itk_transform.Translate(translation) + + matrix = np.asarray(itk_transform.GetMatrix(), dtype=np.float64) + + return matrix, translation + + def itk_affine_resample(self, image, matrix, translation, center_of_rotation=None, reference_image=None): + # Translation transform + itk_transform = itk.AffineTransform[itk.D, image.ndim].New() + + # Set center + if center_of_rotation: + itk_transform.SetCenter(center_of_rotation) + else: + itk_transform.SetCenter(get_itk_image_center(image)) + + # Set matrix and translation + itk_transform.SetMatrix(itk.matrix_from_array(matrix)) + itk_transform.Translate(translation) + + # Interpolator + image = image.astype(itk.D) + interpolator = itk.LinearInterpolateImageFunction.New(image) + + if not reference_image: + reference_image = image + + # Resample with ITK + output_image = itk.resample_image_filter( + image, interpolator=interpolator, transform=itk_transform, output_parameters_from_image=reference_image + ) + + return np.asarray(output_image, dtype=np.float32) + + def monai_affine_resample(self, metatensor, affine_matrix): + affine = Affine( + affine=affine_matrix, padding_mode="zeros", mode="bilinear", dtype=torch.float64, image_only=True + ) + output_tensor = affine(metatensor) + + return output_tensor.squeeze().permute(*torch.arange(output_tensor.ndim - 2, -1, -1)).array + + def remove_border(self, image): + """ + MONAI seems to have different behavior in the borders of the image than ITK. + This helper function sets the border of the ITK image as 0 (padding but keeping + the same image size) in order to allow numerical comparison between the + result from resampling with ITK/Elastix and resampling with MONAI. + To use: image[:] = remove_border(image) + Args: + image: The ITK image to be padded. + + Returns: + The padded array of data. + """ + return np.pad(image[1:-1, 1:-1, 1:-1] if image.ndim == 3 else image[1:-1, 1:-1], pad_width=1) + + def itk_warp(self, image, ddf): + """ + Warping with python itk + Args: + image: itk image of array shape 2D: (H, W) or 3D: (D, H, W) + ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W) + Returns: + warped_image: numpy array of shape (H, W) or (D, H, W) + """ + # MONAI -> ITK ddf + displacement_field = monai_to_itk_ddf(image, ddf) + + # Resample using the ddf + interpolator = itk.LinearInterpolateImageFunction.New(image) + warped_image = itk.warp_image_filter( + image, interpolator=interpolator, displacement_field=displacement_field, output_parameters_from_image=image + ) + + return np.asarray(warped_image) + + def monai_warp(self, image_tensor, ddf_tensor): + """ + Warping with MONAI + Args: + image_tensor: torch tensor of shape 2D: (1, 1, H, W) and 3D: (1, 1, D, H, W) + ddf_tensor: torch tensor of shape 2D: (1, 2, H, W) and 3D: (1, 3, D, H, W) + Returns: + warped_image: numpy array of shape (H, W) or (D, H, W) + """ + warp = Warp(mode="bilinear", padding_mode="zeros") + warped_image = warp(image_tensor.to(torch.float64), ddf_tensor.to(torch.float64)) + + return warped_image.to(torch.float32).squeeze().numpy() + + @parameterized.expand(TESTS) + def test_setting_affine_parameters(self, filepath): + # Read image + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + # Affine parameters + translation = [65.2, -50.2, 33.9][:ndim] + rotation = [0.78539816339, 1.0, -0.66][:ndim] + scale = [2.0, 1.5, 3.2][:ndim] + shear = [0, 1, 1.6] # axis1, axis2, coeff + + # Spacing + spacing = np.array([1.2, 1.5, 2.0])[:ndim] + image.SetSpacing(spacing) + + # ITK + matrix, translation = self.create_itk_affine_from_parameters(image, translation, rotation, scale, shear) + output_array_itk = self.itk_affine_resample(image, matrix=matrix, translation=translation) + + # MONAI + metatensor = itk_image_to_metatensor(image) + affine_matrix_for_monai = itk_to_monai_affine(image, matrix, translation) + output_array_monai = self.monai_affine_resample(metatensor, affine_matrix=affine_matrix_for_monai) + + # Make sure that the array conversion of the inputs is the same + input_array_monai = metatensor.squeeze().permute(*torch.arange(metatensor.ndim - 2, -1, -1)).array + np.testing.assert_array_equal(input_array_monai, np.asarray(image)) + + # Compare outputs + percentage = ( + 100 * np.isclose(output_array_monai, output_array_itk).sum(dtype=np.float64) / output_array_itk.size + ) + self.assertGreaterEqual(percentage, 99.0) + + @parameterized.expand(TESTS) + def test_arbitary_center_of_rotation(self, filepath): + # Read image + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + # ITK matrix (3x3 affine matrix) + matrix = np.array( + [ + [0.55915995, 0.50344867, 0.43208387], + [0.01133669, 0.82088571, 0.86841365], + [0.30478496, 0.94998986, 0.32742505], + ] + )[:ndim, :ndim] + translation = [54.0, 2.7, -11.9][:ndim] + + # Spatial properties + center_of_rotation = [-32.3, 125.1, 0.7][:ndim] + origin = [1.6, 0.5, 2.0][:ndim] + spacing = np.array([1.2, 1.5, 0.6])[:ndim] + + image.SetSpacing(spacing) + image.SetOrigin(origin) + + # ITK + output_array_itk = self.itk_affine_resample(image, matrix, translation, center_of_rotation) + + # MONAI + metatensor = itk_image_to_metatensor(image) + affine_matrix_for_monai = itk_to_monai_affine(image, matrix, translation, center_of_rotation) + output_array_monai = self.monai_affine_resample(metatensor, affine_matrix=affine_matrix_for_monai) + + # Make sure that the array conversion of the inputs is the same + input_array_monai = metatensor.squeeze().permute(*torch.arange(metatensor.ndim - 2, -1, -1)).array + np.testing.assert_array_equal(input_array_monai, np.asarray(image)) + + # Compare outputs + percentage = ( + 100 * np.isclose(output_array_monai, output_array_itk).sum(dtype=np.float64) / output_array_itk.size + ) + self.assertGreaterEqual(percentage, 99.0) + + @parameterized.expand(TESTS) + def test_monai_to_itk(self, filepath): + # Read image + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + # MONAI affine matrix + affine_matrix = torch.eye(ndim + 1, dtype=torch.float64) + affine_matrix[:ndim, :ndim] = torch.tensor( + [ + [0.55915995, 0.50344867, 0.43208387], + [0.01133669, 0.82088571, 0.86841365], + [0.30478496, 0.94998986, 0.32742505], + ], + dtype=torch.float64, + )[:ndim, :ndim] + + affine_matrix[:ndim, ndim] = torch.tensor([54.0, 2.7, -11.9], dtype=torch.float64)[:ndim] + + # Spatial properties + center_of_rotation = [-32.3, 125.1, 0.7][:ndim] + origin = [1.6, 0.5, 2.0][:ndim] + spacing = np.array([1.2, 1.5, 0.6])[:ndim] + + image.SetSpacing(spacing) + image.SetOrigin(origin) + + # ITK + matrix, translation = monai_to_itk_affine(image, affine_matrix, center_of_rotation) + output_array_itk = self.itk_affine_resample(image, matrix, translation, center_of_rotation) + + # MONAI + metatensor = itk_image_to_metatensor(image) + output_array_monai = self.monai_affine_resample(metatensor, affine_matrix) + + # Make sure that the array conversion of the inputs is the same + input_array_monai = metatensor.squeeze().permute(*torch.arange(metatensor.ndim - 2, -1, -1)).array + np.testing.assert_array_equal(input_array_monai, np.asarray(image)) + + # Compare outputs + percentage = ( + 100 * np.isclose(output_array_monai, output_array_itk).sum(dtype=np.float64) / output_array_itk.size + ) + self.assertGreaterEqual(percentage, 99.0) + + @parameterized.expand(TESTS) + def test_cyclic_conversion(self, filepath): + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + # ITK matrix (3x3 affine matrix) + matrix = np.array( + [ + [2.90971094, 1.18297296, 2.60008784], + [0.29416137, 0.10294283, 2.82302616], + [1.70578374, 1.39706003, 2.54652029], + ] + )[:ndim, :ndim] + + translation = [-29.05463245, 35.27116398, 48.58759597][:ndim] + + # Spatial properties + center_of_rotation = [-27.84789587, -60.7871084, 42.73501932][:ndim] + origin = [8.10416794, 5.4831944, 0.49211025][:ndim] + spacing = np.array([0.7, 3.2, 1.3])[:ndim] + + direction = np.array( + [ + [1.02895588, 0.22791448, 0.02429561], + [0.21927512, 1.28632268, -0.14932226], + [0.47455613, 0.38534345, 0.98505633], + ], + dtype=np.float64, + ) + image.SetDirection(direction[:ndim, :ndim]) + + image.SetSpacing(spacing) + image.SetOrigin(origin) + + affine_matrix = itk_to_monai_affine(image, matrix, translation, center_of_rotation) + matrix_result, translation_result = monai_to_itk_affine(image, affine_matrix, center_of_rotation) + + meta_tensor = itk_image_to_metatensor(image) + image_result = metatensor_to_itk_image(meta_tensor) + + np.testing.assert_allclose(matrix, matrix_result) + np.testing.assert_allclose(translation, translation_result) + np.testing.assert_array_equal(image.shape, image_result.shape) + np.testing.assert_array_equal(image, image_result) + + @parameterized.expand([(2,), (3,)]) + def test_random_array(self, ndim): + # Create image/array with random size and pixel intensities + s = torch.randint(low=2, high=20, size=(ndim,)) + img = 100 * torch.rand((1, 1, *s.tolist()), dtype=torch.float32) + + # Pad at the edges because ITK and MONAI have different behavior there + # during resampling + img = torch.nn.functional.pad(img, pad=ndim * (1, 1)) + ddf = 5 * torch.rand((1, ndim, *img.shape[-ndim:]), dtype=torch.float32) - 2.5 + + # Warp with MONAI + img_resampled = self.monai_warp(img, ddf) + + # Create ITK image + itk_img = itk.GetImageFromArray(img.squeeze().numpy()) + + # Set random spacing + spacing = 3 * np.random.rand(ndim) + itk_img.SetSpacing(spacing) + + # Set random direction + direction = 5 * np.random.rand(ndim, ndim) - 5 + direction = itk.matrix_from_array(direction) + itk_img.SetDirection(direction) + + # Set random origin + origin = 100 * np.random.rand(ndim) - 100 + itk_img.SetOrigin(origin) + + # Warp with ITK + itk_img_resampled = self.itk_warp(itk_img, ddf.squeeze().numpy()) + + # Compare + np.testing.assert_allclose(img_resampled, itk_img_resampled, rtol=1e-2, atol=1e-2) + + @parameterized.expand(TESTS) + @skip_if_quick + def test_real_data(self, filepath): + # Read image + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + # Random ddf + ddf = 10 * torch.rand((1, ndim, *image.shape), dtype=torch.float32) - 10 + + # Warp with MONAI + image_tensor = torch.tensor(itk.GetArrayFromImage(image), dtype=torch.float32).unsqueeze(0).unsqueeze(0) + img_resampled = self.monai_warp(image_tensor, ddf) + + # Warp with ITK + itk_img_resampled = self.itk_warp(image, ddf.squeeze().numpy()) + + # Compare + np.testing.assert_allclose(img_resampled, itk_img_resampled, rtol=1e-3, atol=1e-3) + + @parameterized.expand(zip(TESTS[::2], TESTS[1::2])) + @skip_if_quick + def test_use_reference_space(self, ref_filepath, filepath): + # Read the images + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + ref_image = self.reader.read(os.path.join(self.data_dir, ref_filepath)) + + # Set arbitary origin, spacing, direction for both of the images + image.SetSpacing([1.2, 2.0, 1.7][:ndim]) + ref_image.SetSpacing([1.9, 1.5, 1.3][:ndim]) + + direction = np.array( + [ + [1.02895588, 0.22791448, 0.02429561], + [0.21927512, 1.28632268, -0.14932226], + [0.47455613, 0.38534345, 0.98505633], + ], + dtype=np.float64, + ) + image.SetDirection(direction[:ndim, :ndim]) + + ref_direction = np.array( + [ + [1.26032417, -0.19243174, 0.54877414], + [0.31958275, 0.9543068, 0.2720827], + [-0.24106769, -0.22344502, 0.9143302], + ], + dtype=np.float64, + ) + ref_image.SetDirection(ref_direction[:ndim, :ndim]) + + image.SetOrigin([57.3, 102.0, -20.9][:ndim]) + ref_image.SetOrigin([23.3, -0.5, 23.7][:ndim]) + + # Set affine parameters + matrix = np.array( + [ + [0.55915995, 0.50344867, 0.43208387], + [0.01133669, 0.82088571, 0.86841365], + [0.30478496, 0.94998986, 0.32742505], + ] + )[:ndim, :ndim] + translation = [54.0, 2.7, -11.9][:ndim] + center_of_rotation = [-32.3, 125.1, 0.7][:ndim] + + # Resample using ITK + output_array_itk = self.itk_affine_resample(image, matrix, translation, center_of_rotation, ref_image) + + # MONAI + metatensor = itk_image_to_metatensor(image) + affine_matrix_for_monai = itk_to_monai_affine(image, matrix, translation, center_of_rotation, ref_image) + output_array_monai = self.monai_affine_resample(metatensor, affine_matrix_for_monai) + + # Compare outputs + np.testing.assert_allclose(output_array_monai, output_array_itk, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py index 163fead76e..869ec7b947 100644 --- a/tests/test_itk_writer.py +++ b/tests/test_itk_writer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_k_space_spike_noise.py b/tests/test_k_space_spike_noise.py index 537655b3e5..4d820573a6 100644 --- a/tests/test_k_space_spike_noise.py +++ b/tests/test_k_space_spike_noise.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy @@ -45,7 +47,6 @@ def get_data(im_shape, im_type): @parameterized.expand(TESTS) def test_same_result(self, im_shape, im_type, k_intensity): - im = self.get_data(im_shape, im_type) loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0] t = KSpaceSpikeNoise(loc, k_intensity) @@ -61,7 +62,6 @@ def test_same_result(self, im_shape, im_type, k_intensity): @parameterized.expand(TESTS) def test_highlighted_kspace_pixel(self, im_shape, as_tensor_input, k_intensity): - im = self.get_data(im_shape, as_tensor_input) loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0] t = KSpaceSpikeNoise(loc, k_intensity) diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py index 03c99d1533..76a79d4b12 100644 --- a/tests/test_k_space_spike_noised.py +++ b/tests/test_k_space_spike_noised.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy @@ -47,7 +49,6 @@ def get_data(im_shape, im_type): @parameterized.expand(TESTS) def test_same_result(self, im_shape, im_type): - data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 @@ -64,7 +65,6 @@ def test_same_result(self, im_shape, im_type): @parameterized.expand(TESTS) def test_highlighted_kspace_pixel(self, im_shape, im_type): - data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index a0e309f2d7..7da3c4b21f 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_keep_largest_connected_componentd.py b/tests/test_keep_largest_connected_componentd.py index 544b7f1773..aac91a2de9 100644 --- a/tests/test_keep_largest_connected_componentd.py +++ b/tests/test_keep_largest_connected_componentd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_kspace_mask.py b/tests/test_kspace_mask.py index 42b90d1675..5d6d9c18ea 100644 --- a/tests/test_kspace_mask.py +++ b/tests/test_kspace_mask.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_label_filter.py b/tests/test_label_filter.py index 42aa419b1d..47a8706491 100644 --- a/tests/test_label_filter.py +++ b/tests/test_label_filter.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_label_filterd.py b/tests/test_label_filterd.py index eea18d0278..f27df08c2a 100644 --- a/tests/test_label_filterd.py +++ b/tests/test_label_filterd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_label_quality_score.py b/tests/test_label_quality_score.py index db31624a95..aa243b4236 100644 --- a/tests/test_label_quality_score.py +++ b/tests/test_label_quality_score.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -101,6 +103,7 @@ class TestLabelQualityScore(unittest.TestCase): def test_value(self, input_data, expected_value): result = label_quality_score(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + np.testing.assert_equal(result.device, input_data["y_pred"].device) @parameterized.expand([TEST_CASE_6, TEST_CASE_7]) def test_spatial_case(self, input_data, expected_value): diff --git a/tests/test_label_to_contour.py b/tests/test_label_to_contour.py index cab116afbe..590fd5d4e4 100644 --- a/tests/test_label_to_contour.py +++ b/tests/test_label_to_contour.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_label_to_contourd.py b/tests/test_label_to_contourd.py index e11b59130a..6fcec72dd8 100644 --- a/tests/test_label_to_contourd.py +++ b/tests/test_label_to_contourd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py index e374018836..2eba825cf3 100644 --- a/tests/test_label_to_mask.py +++ b/tests/test_label_to_mask.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py index 3b6f527c9a..35f54ca5b9 100644 --- a/tests/test_label_to_maskd.py +++ b/tests/test_label_to_maskd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_lambda.py b/tests/test_lambda.py index 3fa080f794..91678c0b81 100644 --- a/tests/test_lambda.py +++ b/tests/test_lambda.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.data.meta_tensor import MetaTensor diff --git a/tests/test_lambdad.py b/tests/test_lambdad.py index 1e333da12b..55df819fa9 100644 --- a/tests/test_lambdad.py +++ b/tests/test_lambdad.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.data.meta_tensor import MetaTensor diff --git a/tests/test_lesion_froc.py b/tests/test_lesion_froc.py index 8c9f751b1e..10682c2bb7 100644 --- a/tests/test_lesion_froc.py +++ b/tests/test_lesion_froc.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from unittest import skipUnless diff --git a/tests/test_list_data_collate.py b/tests/test_list_data_collate.py index 7b6417820c..226d9e1a55 100644 --- a/tests/test_list_data_collate.py +++ b/tests/test_list_data_collate.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_list_to_dict.py b/tests/test_list_to_dict.py index ec81310c9f..4e6bb8cdf7 100644 --- a/tests/test_list_to_dict.py +++ b/tests/test_list_to_dict.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_lltm.py b/tests/test_lltm.py index 877d211767..6ee716e1ef 100644 --- a/tests/test_lltm.py +++ b/tests/test_lltm.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index 33f27ee4bc..080868b0dd 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile @@ -124,7 +126,7 @@ def test_cache(self): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_shape(self, transform, expected_shape, kwargs=None): kwargs = kwargs or {} - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) diff --git a/tests/test_lmdbdataset_dist.py b/tests/test_lmdbdataset_dist.py index cad2949dde..0b4c7c35fa 100644 --- a/tests/test_lmdbdataset_dist.py +++ b/tests/test_lmdbdataset_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import shutil import tempfile import unittest diff --git a/tests/test_load_decathlon_datalist.py b/tests/test_load_decathlon_datalist.py index a58ba73ece..b0e390cd73 100644 --- a/tests/test_load_decathlon_datalist.py +++ b/tests/test_load_decathlon_datalist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import os import tempfile diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 1db39a310b..ec748f9951 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index 8210d2f0d1..534cbb6618 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 8257b9965f..e6ff5f8317 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import time import unittest diff --git a/tests/test_loader_semaphore.py b/tests/test_loader_semaphore.py index 85cf5593f8..859ee1f8d5 100644 --- a/tests/test_loader_semaphore.py +++ b/tests/test_loader_semaphore.py @@ -10,6 +10,8 @@ # limitations under the License. """this test should not generate errors or UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores""" +from __future__ import annotations + import multiprocessing as mp import unittest diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index e6052824a9..21fe8b973f 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_localnet.py b/tests/test_localnet.py index 9ad50b9be8..f557147960 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch @@ -32,6 +34,8 @@ "extract_levels": (0, 1), "pooling": False, "concat_skip": True, + "mode": "bilinear", + "align_corners": True, }, (1, 2, 16, 16), (1, 2, 16, 16), diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index d85509344e..27ea4cd1a6 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch @@ -25,7 +27,17 @@ [{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 4, "kernel_size": 3}] for spatial_dims in [2, 3] ] -TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]] +TEST_CASE_UP_SAMPLE = [ + [ + { + "spatial_dims": spatial_dims, + "in_channels": 4, + "out_channels": 2, + "mode": "bilinear" if spatial_dims == 2 else "trilinear", + } + ] + for spatial_dims in [2, 3] +] TEST_CASE_EXTRACT = [ [{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 3, "act": act, "initializer": initializer}] diff --git a/tests/test_look_up_option.py b/tests/test_look_up_option.py index 700f7b9691..5f81fb8d43 100644 --- a/tests/test_look_up_option.py +++ b/tests/test_look_up_option.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from enum import Enum diff --git a/tests/test_loss_metric.py b/tests/test_loss_metric.py index 74b6bd14a0..682221f5f5 100644 --- a/tests/test_loss_metric.py +++ b/tests/test_loss_metric.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index aed7976feb..c10016eeff 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import pickle import random @@ -47,7 +49,6 @@ @unittest.skipUnless(has_pil, "requires PIL") class TestLRFinder(unittest.TestCase): def setUp(self): - self.root_dir = MONAIEnvVars.data_dir() if not self.root_dir: self.root_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index 44f4c50c0f..bcddf7627e 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_make_nifti.py b/tests/test_make_nifti.py index 951f079764..4560507c6c 100644 --- a/tests/test_make_nifti.py +++ b/tests/test_make_nifti.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_map_binary_to_indices.py b/tests/test_map_binary_to_indices.py index bc96231160..1080c2a513 100644 --- a/tests/test_map_binary_to_indices.py +++ b/tests/test_map_binary_to_indices.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_map_classes_to_indices.py b/tests/test_map_classes_to_indices.py index 2f32382f6b..9576dd0f61 100644 --- a/tests/test_map_classes_to_indices.py +++ b/tests/test_map_classes_to_indices.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py index ef08f7eae3..32f5fccdb6 100644 --- a/tests/test_map_label_value.py +++ b/tests/test_map_label_value.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_map_label_valued.py b/tests/test_map_label_valued.py index cf8ca6c8e2..8c91adaa49 100644 --- a/tests/test_map_label_valued.py +++ b/tests/test_map_label_valued.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_map_transform.py b/tests/test_map_transform.py index dd77ccb099..7430cf09c7 100644 --- a/tests/test_map_transform.py +++ b/tests/test_map_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index da4b2a8192..2b831ba415 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_mask_intensityd.py b/tests/test_mask_intensityd.py index 186ed6f2d8..6a39416de4 100644 --- a/tests/test_mask_intensityd.py +++ b/tests/test_mask_intensityd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_masked_dice_loss.py b/tests/test_masked_dice_loss.py index 317da3a316..b868f4d3a1 100644 --- a/tests/test_masked_dice_loss.py +++ b/tests/test_masked_dice_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_masked_inference_wsi_dataset.py b/tests/test_masked_inference_wsi_dataset.py index c424edd897..bb90f7900b 100644 --- a/tests/test_masked_inference_wsi_dataset.py +++ b/tests/test_masked_inference_wsi_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest @@ -38,7 +40,6 @@ def prepare_data(*masks): - mask = np.zeros((HEIGHT // 2, WIDTH // 2)) mask[100, 100] = 1 np.save(masks[0], mask) diff --git a/tests/test_masked_loss.py b/tests/test_masked_loss.py index a00b3ae7e7..a5f507ff97 100644 --- a/tests/test_masked_loss.py +++ b/tests/test_masked_loss.py @@ -8,6 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_masked_patch_wsi_dataset.py b/tests/test_masked_patch_wsi_dataset.py index 9783c2d7cf..730ce97bdb 100644 --- a/tests/test_masked_patch_wsi_dataset.py +++ b/tests/test_masked_patch_wsi_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from unittest import skipUnless diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py index 18ed16dd2c..a6781fc72e 100644 --- a/tests/test_matshow3d.py +++ b/tests/test_matshow3d.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py index 060170e3bf..09b7f94dc4 100644 --- a/tests/test_mean_ensemble.py +++ b/tests/test_mean_ensemble.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_mean_ensembled.py b/tests/test_mean_ensembled.py index f6d6286d35..01123b0729 100644 --- a/tests/test_mean_ensembled.py +++ b/tests/test_mean_ensembled.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_median_filter.py b/tests/test_median_filter.py index cf6286bdfe..9f27adff4c 100644 --- a/tests/test_median_filter.py +++ b/tests/test_median_filter.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_median_smooth.py b/tests/test_median_smooth.py index 314327835b..21cd45f28e 100644 --- a/tests/test_median_smooth.py +++ b/tests/test_median_smooth.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_median_smoothd.py b/tests/test_median_smoothd.py index 811e833a90..b8d3452c86 100644 --- a/tests/test_median_smoothd.py +++ b/tests/test_median_smoothd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index c46d846903..7011405e4a 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_meta_affine.py b/tests/test_meta_affine.py index 269db33ef4..b95ea3f1ac 100644 --- a/tests/test_meta_affine.py +++ b/tests/test_meta_affine.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from copy import deepcopy diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index eb0bfdb12f..4f2cb9636a 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import io import os import random @@ -18,7 +20,6 @@ import warnings from copy import deepcopy from multiprocessing.reduction import ForkingPickler -from typing import Optional, Union import numpy as np import torch @@ -86,7 +87,7 @@ def check( shape: bool = True, vals: bool = True, ids: bool = True, - device: Optional[Union[str, torch.device]] = None, + device: str | torch.device | None = None, meta: bool = True, check_ids: bool = True, **kwargs, @@ -246,7 +247,7 @@ def test_torchscript(self, device): "your pytorch version if this is important to you." ) im_conv = im_conv.as_tensor() - self.check(out, im_conv, ids=False) + self.check(out, im_conv, ids=False) def test_pickling(self): m, _ = self.get_im() @@ -257,7 +258,7 @@ def test_pickling(self): if not isinstance(m2, MetaTensor) and not pytorch_after(1, 8, 1): warnings.warn("Old version of pytorch. pickling converts `MetaTensor` to `torch.Tensor`.") m = m.as_tensor() - self.check(m2, m, ids=False) + self.check(m2, m, ids=False) @skip_if_no_cuda def test_amp(self): @@ -405,6 +406,13 @@ def test_indexing(self): for _d in d: self.check_meta(_d, data) + def test_slicing(self): + x = MetaTensor(np.zeros((10, 3, 4))) + self.assertEqual(x[slice(4, 1)].shape[0], 0) + x.is_batch = True + with self.assertRaises(ValueError): + x[slice(0, 8)] + @parameterized.expand(DTYPES) @SkipIfBeforePyTorchVersion((1, 8)) def test_decollate(self, dtype): @@ -507,9 +515,11 @@ def test_pending_ops(self): self.assertEqual(m.pending_operations, []) self.assertEqual(m.peek_pending_shape(), (10, 8)) self.assertIsInstance(m.peek_pending_affine(), torch.Tensor) + self.assertTrue(m.peek_pending_rank() >= 1) m.push_pending_operation({}) self.assertEqual(m.peek_pending_shape(), (10, 8)) self.assertIsInstance(m.peek_pending_affine(), torch.Tensor) + self.assertTrue(m.peek_pending_rank() >= 1) @parameterized.expand(TESTS) def test_multiprocessing(self, device=None, dtype=None): diff --git a/tests/test_metatensor_integration.py b/tests/test_metatensor_integration.py index 6e8d5f40a3..6a4c67d160 100644 --- a/tests/test_metatensor_integration.py +++ b/tests/test_metatensor_integration.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_metrics_reloaded.py b/tests/test_metrics_reloaded.py new file mode 100644 index 0000000000..010326b87d --- /dev/null +++ b/tests/test_metrics_reloaded.py @@ -0,0 +1,95 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical +from monai.utils import optional_import + +_, has_metrics = optional_import("MetricsReloaded") + +# shape: (1, 1, 2, 2) +y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) +y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]) +TEST_CASES_BINARY = [ + [{"metric_name": "False Positives"}, [y_pred, y], 0.0], + [{"metric_name": "False Negatives"}, [y_pred, y], 1.0], + [{"metric_name": "True Positives"}, [y_pred, y], 2.0], + [{"metric_name": "True Negatives"}, [y_pred, y], 1.0], + [{"metric_name": "Youden Index"}, [y_pred, y], 0.666654], + [{"metric_name": "Sensitivity"}, [y_pred, y], 0.666664], + [{"metric_name": "Specificity"}, [y_pred, y], 0.99999], + [{"metric_name": "Balanced Accuracy"}, [y_pred, y], 0.833327], + [{"metric_name": "Accuracy"}, [y_pred, y], 0.75], + [{"metric_name": "False Positive Rate"}, [y_pred, y], 0.0], + [{"metric_name": "Normalised Expected Cost"}, [y_pred, y], 0.333333], + [{"metric_name": "Matthews Correlation Coefficient"}, [y_pred, y], 0.57735], + [{"metric_name": "Cohens Kappa"}, [y_pred, y], 0.5], + [{"metric_name": "Positive Likelihood Ratio"}, [y_pred, y], 66576.03], + [{"metric_name": "Prediction Overlaps Reference"}, [y_pred, y], 1.0], + [{"metric_name": "Positive Predictive Value"}, [y_pred, y], 0.999995], + [{"metric_name": "Recall"}, [y_pred, y], 0.666664], + [{"metric_name": "FBeta"}, [y_pred, y], 0.799992], + [{"metric_name": "Net Benefit Treated"}, [y_pred, y], 0.5], + [{"metric_name": "Negative Predictive Values"}, [y_pred, y], 0.5], + [{"metric_name": "Dice Score"}, [y_pred, y], 0.799992], + [{"metric_name": "False Positives Per Image"}, [y_pred, y], 0.0], + [{"metric_name": "Intersection Over Reference"}, [y_pred, y], 0.666664], + [{"metric_name": "Intersection Over Union"}, [y_pred, y], 0.666664], + [{"metric_name": "Volume Difference"}, [y_pred, y], 0.333333], + [{"metric_name": "Topology Precision"}, [y_pred, y], 1.0], + [{"metric_name": "Topology Sensitivity"}, [y_pred, y], 1.0], + [{"metric_name": "Centreline Dice Score"}, [y_pred, y], 1.0], + [{"metric_name": "Boundary IoU"}, [y_pred, y], 0.666667], + [{"metric_name": "Normalised Surface Distance"}, [y_pred, y], 1.0], + [{"metric_name": "Average Symmetric Surface Distance"}, [y_pred, y], 0.2], + [{"metric_name": "Mean Average Surfance Distance"}, [y_pred, y], 0.166666], + [{"metric_name": "Hausdorff Distance"}, [y_pred, y], 1.0], + [{"metric_name": "xTh Percentile Hausdorff Distance"}, [y_pred, y], 0.9], +] + +# shape: (1, 3, 2, 2) +y_pred = torch.tensor([[[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]]) +y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]]) +TEST_CASES_CATEGORICAL = [ + [{"metric_name": "Balanced Accuracy"}, [y_pred, y], 0.5], + [{"metric_name": "Weighted Cohens Kappa"}, [y_pred, y], 0.272727], + [{"metric_name": "Matthews Correlation Coefficient"}, [y_pred, y], 0.387298], + [{"metric_name": "Expected Cost"}, [y_pred, y], 0.5], + [{"metric_name": "Normalised Expected Cost"}, [y_pred, y], 0.75], +] + + +@unittest.skipIf(not has_metrics, "MetricsReloaded not available.") +class TestMetricsReloaded(unittest.TestCase): + @parameterized.expand(TEST_CASES_BINARY) + def test_binary(self, input_param, input_data, expected_val): + metric = MetricsReloadedBinary(**input_param) + result = metric(*input_data) + np.testing.assert_allclose( + result.detach().cpu().numpy(), expected_val, rtol=1e-5, err_msg=input_param["metric_name"] + ) + + @parameterized.expand(TEST_CASES_CATEGORICAL) + def test_categorical(self, input_param, input_data, expected_val): + metric = MetricsReloadedCategorical(**input_param) + result = metric(*input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_milmodel.py b/tests/test_milmodel.py index 2d58af4a2b..9178e0bccb 100644 --- a/tests/test_milmodel.py +++ b/tests/test_milmodel.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 737762cfb1..8ad66ebc6e 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -22,7 +24,6 @@ for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [128, 256, 512, 768]: for mlp_dim in [0, 1028, 2048, 3072]: - test_case = [ {"hidden_size": hidden_size, "mlp_dim": mlp_dim, "dropout_rate": dropout_rate}, (2, 512, hidden_size), diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index a87d927ed2..66fca6bb7f 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_module_list.py b/tests/test_module_list.py index acd574d463..293da95d5a 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import glob import inspect import os diff --git a/tests/test_monai_env_vars.py b/tests/test_monai_env_vars.py index 663dcdd98d..b9285b58d8 100644 --- a/tests/test_monai_env_vars.py +++ b/tests/test_monai_env_vars.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_mri_utils.py b/tests/test_mri_utils.py index c1e2e788bf..2f67816e2e 100644 --- a/tests/test_mri_utils.py +++ b/tests/test_mri_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index f348c09512..8b8acb2503 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -8,6 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_net_adapter.py b/tests/test_net_adapter.py index c3abc8d142..74a2daab9d 100644 --- a/tests/test_net_adapter.py +++ b/tests/test_net_adapter.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py index 327f0cfbf0..948e4d0615 100644 --- a/tests/test_network_consistency.py +++ b/tests/test_network_consistency.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import os import unittest @@ -47,7 +49,6 @@ def tearDown(self): ) @parameterized.expand(TESTS, skip_on_empty=True) def test_network_consistency(self, net_name, data_path, json_path): - print("Net name: " + net_name) print("Data path: " + data_path) print("JSON path: " + json_path) diff --git a/tests/test_nifti_endianness.py b/tests/test_nifti_endianness.py index 7f179d3bde..2539d95fd5 100644 --- a/tests/test_nifti_endianness.py +++ b/tests/test_nifti_endianness.py @@ -9,11 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest from pathlib import Path -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING from unittest.case import skipUnless import numpy as np @@ -36,7 +38,7 @@ nib, has_nib = optional_import("nibabel") PILImage, has_pil = optional_import("PIL.Image") -TESTS: List[Tuple] = [] +TESTS: list[tuple] = [] for endianness in ["<", ">"]: for use_array in [True, False]: for image_only in [True, False]: @@ -51,7 +53,6 @@ def setUp(self): @parameterized.expand(TESTS) @skipUnless(has_nib, "Requires NiBabel") def test_endianness(self, endianness, use_array, image_only): - hdr = nib.Nifti1Header(endianness=endianness) nii = nib.Nifti1Image(self.im, np.eye(4), header=hdr) nib.save(nii, self.fname) diff --git a/tests/test_nifti_header_revise.py b/tests/test_nifti_header_revise.py index 7f917cb0e9..3d000160e1 100644 --- a/tests/test_nifti_header_revise.py +++ b/tests/test_nifti_header_revise.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import nibabel as nib diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index 7da50617d9..cac18cf9e3 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py deleted file mode 100644 index 54489904df..0000000000 --- a/tests/test_nifti_saver.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest -from pathlib import Path - -import numpy as np -import torch - -from monai.data import NiftiSaver -from monai.transforms import LoadImage - - -class TestNiftiSaver(unittest.TestCase): - def test_saved_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".nii.gz") - - meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)]} - saver.save_batch(torch.zeros(8, 1, 2, 2), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_resize_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)], - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_3d_resize_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "spatial_shape": [(10, 10, 2)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_3d_no_resize_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver( - output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32, resample=False - ) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "spatial_shape": [(10, 10, 2)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - img = LoadImage("nibabelreader", image_only=True)(filepath) - self.assertEqual(img.shape, (1, 2, 2, 8)) - - def test_squeeze_end_dims(self): - with tempfile.TemporaryDirectory() as tempdir: - - for squeeze_end_dims in [False, True]: - - saver = NiftiSaver( - output_dir=tempdir, - output_postfix="", - output_ext=".nii.gz", - dtype=np.float32, - squeeze_end_dims=squeeze_end_dims, - ) - - fname = "testfile_squeeze" - meta_data = {"filename_or_obj": fname} - - # 2d image w channel - saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data) - - im = LoadImage(image_only=True)(os.path.join(tempdir, fname, fname + ".nii.gz")) - self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index 4d06a80c1d..193b5cc4b2 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_normalize_intensityd.py b/tests/test_normalize_intensityd.py index a8167a1e93..451269b1c4 100644 --- a/tests/test_normalize_intensityd.py +++ b/tests/test_normalize_intensityd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_npzdictitemdataset.py b/tests/test_npzdictitemdataset.py index e24a2cfc1f..4ff4577b72 100644 --- a/tests/test_npzdictitemdataset.py +++ b/tests/test_npzdictitemdataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import tempfile import unittest from io import BytesIO diff --git a/tests/test_nrrd_reader.py b/tests/test_nrrd_reader.py index 03bfbfe156..01fabe65a8 100644 --- a/tests/test_nrrd_reader.py +++ b/tests/test_nrrd_reader.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_nuclick_transforms.py b/tests/test_nuclick_transforms.py index 4071dd2a07..fcdd362b01 100644 --- a/tests/test_nuclick_transforms.py +++ b/tests/test_nuclick_transforms.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py index d220a67c92..393f613163 100644 --- a/tests/test_numpy_reader.py +++ b/tests/test_numpy_reader.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import sys import tempfile diff --git a/tests/test_nvtx_decorator.py b/tests/test_nvtx_decorator.py index 318c27f1e4..1e7bea17d4 100644 --- a/tests/test_nvtx_decorator.py +++ b/tests/test_nvtx_decorator.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_nvtx_transform.py b/tests/test_nvtx_transform.py index 5310270bc0..3a5314c35f 100644 --- a/tests/test_nvtx_transform.py +++ b/tests/test_nvtx_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py index 02e34704f1..c7ac5ef533 100644 --- a/tests/test_occlusion_sensitivity.py +++ b/tests/test_occlusion_sensitivity.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Any, List +from typing import Any import torch from parameterized import parameterized @@ -40,8 +42,8 @@ def __call__(self, x, adjoint_info): model_3d.eval() model_2d_adjoint.eval() -TESTS: List[Any] = [] -TESTS_FAIL: List[Any] = [] +TESTS: list[Any] = [] +TESTS_FAIL: list[Any] = [] # 2D w/ bounding box with all modes for mode in ("gaussian", "mean_patch", "mean_img"): diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 2ea41c6e50..687ec71aad 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_optim_novograd.py b/tests/test_optim_novograd.py index c1e63182e6..8808db432d 100644 --- a/tests/test_optim_novograd.py +++ b/tests/test_optim_novograd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_optional_import.py b/tests/test_optional_import.py index b87ebf8909..03db7b3fc6 100644 --- a/tests/test_optional_import.py +++ b/tests/test_optional_import.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.utils import OptionalImportError, exact_version, optional_import diff --git a/tests/test_ori_ras_lps.py b/tests/test_ori_ras_lps.py index d0a9b034e4..824793f927 100644 --- a/tests/test_ori_ras_lps.py +++ b/tests/test_ori_ras_lps.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_orientation.py b/tests/test_orientation.py index 979f6ae485..6e89d085d2 100644 --- a/tests/test_orientation.py +++ b/tests/test_orientation.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import nibabel as nib @@ -19,6 +21,7 @@ from monai.data.meta_obj import set_track_meta from monai.data.meta_tensor import MetaTensor from monai.transforms import Orientation, create_rotate, create_translate +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES, assert_allclose TESTS = [] @@ -186,9 +189,13 @@ def test_ornt_meta( ): img = MetaTensor(img, affine=affine).to(device) ornt = Orientation(**init_param) - res: MetaTensor = ornt(img) + call_param = {"data_array": img} + res = ornt(**call_param) + if img.ndim in (3, 4): + test_resampler_lazy(ornt, res, init_param, call_param) + assert_allclose(res, expected_data.to(device)) - new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels) + new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels) # type: ignore self.assertEqual("".join(new_code), expected_code) @parameterized.expand(TESTS_TORCH) @@ -204,6 +211,7 @@ def test_ornt_torch(self, init_param, img: torch.Tensor, track_meta: bool, devic assert_allclose(res, expected_data) if track_meta: self.assertIsInstance(res, MetaTensor) + assert isinstance(res, MetaTensor) # for mypy type narrowing new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels) self.assertEqual("".join(new_code), expected_code) else: diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py index 1b4660a60a..ddb5dc3e98 100644 --- a/tests/test_orientationd.py +++ b/tests/test_orientationd.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Optional import nibabel as nib import numpy as np @@ -20,6 +21,7 @@ from monai.data.meta_obj import set_track_meta from monai.data.meta_tensor import MetaTensor from monai.transforms import Orientationd +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES TESTS = [] @@ -65,19 +67,21 @@ class TestOrientationdCase(unittest.TestCase): @parameterized.expand(TESTS) def test_orntd( - self, init_param, img: torch.Tensor, affine: Optional[torch.Tensor], expected_shape, expected_code, device + self, init_param, img: torch.Tensor, affine: torch.Tensor | None, expected_shape, expected_code, device ): ornt = Orientationd(**init_param) if affine is not None: img = MetaTensor(img, affine=affine) img = img.to(device) - data = {k: img.clone() for k in ornt.keys} - res = ornt(data) + call_param = {"data": {k: img.clone() for k in ornt.keys}} + res = ornt(**call_param) for k in ornt.keys: + if img.ndim in (3, 4): + test_resampler_lazy(ornt, res, init_param, call_param, output_key=k) _im = res[k] self.assertIsInstance(_im, MetaTensor) np.testing.assert_allclose(_im.shape, expected_shape) - code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels) + code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels) # type: ignore self.assertEqual("".join(code), expected_code) @parameterized.expand(TESTS_TORCH) @@ -87,13 +91,16 @@ def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, devi img = img.to(device) expected_shape = img.shape expected_code = ornt.ornt_transform.axcodes - data = {k: img.clone() for k in ornt.keys} - res = ornt(data) + call_param = {"data": {k: img.clone() for k in ornt.keys}} + res = ornt(**call_param) for k in ornt.keys: _im = res[k] np.testing.assert_allclose(_im.shape, expected_shape) if track_meta: + if img.ndim in (3, 4): + test_resampler_lazy(ornt, res, init_param, call_param, output_key=k) self.assertIsInstance(_im, MetaTensor) + assert isinstance(_im, MetaTensor) # for mypy type narrowing code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels) self.assertEqual("".join(code), expected_code) else: diff --git a/tests/test_p3d_block.py b/tests/test_p3d_block.py index 62b7098dcd..db9e9c284d 100644 --- a/tests/test_p3d_block.py +++ b/tests/test_p3d_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index 89a8508967..cd98f29abf 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import random import unittest from functools import wraps -from typing import List, Tuple import numpy as np import torch @@ -42,7 +43,7 @@ def _testing_collate(x): return pad_list_data_collate(batch=x, method="end", mode="constant") -TESTS: List[Tuple] = [] +TESTS: list[tuple] = [] for pad_collate in [_testing_collate, PadListDataCollate(method="end", mode="constant")]: TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) @@ -86,7 +87,6 @@ def tearDown(self) -> None: @parameterized.expand(TESTS) def test_pad_collation(self, t_type, collate_method, transform): - if t_type == dict: dataset = CacheDataset(self.dict_data, transform, progress=False) else: diff --git a/tests/test_pad_mode.py b/tests/test_pad_mode.py index ae2e94b60c..722d5b573f 100644 --- a/tests/test_pad_mode.py +++ b/tests/test_pad_mode.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_parallel_execution.py b/tests/test_parallel_execution.py index 6186e73f68..26ec873950 100644 --- a/tests/test_parallel_execution.py +++ b/tests/test_parallel_execution.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import warnings diff --git a/tests/test_parallel_execution_dist.py b/tests/test_parallel_execution_dist.py index f067b71d14..b6f6695be4 100644 --- a/tests/test_parallel_execution_dist.py +++ b/tests/test_parallel_execution_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_partition_dataset.py b/tests/test_partition_dataset.py index 687cf8df34..8640d8cc73 100644 --- a/tests/test_partition_dataset.py +++ b/tests/test_partition_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_partition_dataset_classes.py b/tests/test_partition_dataset_classes.py index 4ed283bdd7..c4fa5ed199 100644 --- a/tests/test_partition_dataset_classes.py +++ b/tests/test_partition_dataset_classes.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 9574afccea..7d66bdccbb 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest diff --git a/tests/test_patch_inferer.py b/tests/test_patch_inferer.py new file mode 100644 index 0000000000..c7ae2c6244 --- /dev/null +++ b/tests/test_patch_inferer.py @@ -0,0 +1,237 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized +from torch.nn.functional import avg_pool2d + +from monai.data.meta_tensor import MetaTensor +from monai.inferers import AvgMerger, PatchInferer, SlidingWindowSplitter +from tests.utils import assert_allclose + +TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32) +TENSOR_2x2 = avg_pool2d(TENSOR_4x4, 2, 2) + +# no-overlapping 2x2 patches +TEST_CASE_0_TENSOR = [ + TENSOR_4x4, + dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger), + lambda x: x, + TENSOR_4x4, +] + +# no-overlapping 2x2 patches using all default parameters (except for splitter) +TEST_CASE_1_TENSOR = [TENSOR_4x4, dict(splitter=SlidingWindowSplitter(patch_size=(2, 2))), lambda x: x, TENSOR_4x4] + +# divisible batch_size +TEST_CASE_2_TENSOR = [ + TENSOR_4x4, + dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=2), + lambda x: x, + TENSOR_4x4, +] + +# non-divisible batch_size +TEST_CASE_3_TENSOR = [ + TENSOR_4x4, + dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=3), + lambda x: x, + TENSOR_4x4, +] + +# patches that are already split (Splitter should be None) +TEST_CASE_4_SPLIT_LIST = [ + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + dict(splitter=None, merger_cls=AvgMerger, output_shape=(2, 3, 4, 4)), + lambda x: x, + TENSOR_4x4, +] + +# using all default parameters (patches are already split) +TEST_CASE_5_SPLIT_LIST = [ + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + dict(merger_cls=AvgMerger, output_shape=(2, 3, 4, 4)), + lambda x: x, + TENSOR_4x4, +] + +# output smaller than input patches +TEST_CASE_6_SMALLER = [ + TENSOR_4x4, + dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger), + lambda x: torch.mean(x, dim=(-1, -2), keepdim=True), + TENSOR_2x2, +] + +# preprocess patches +TEST_CASE_7_PREPROCESS = [ + TENSOR_4x4, + dict( + splitter=SlidingWindowSplitter(patch_size=(2, 2)), + merger_cls=AvgMerger, + preprocessing=lambda x: 2 * x, + postprocessing=None, + ), + lambda x: x, + 2 * TENSOR_4x4, +] + +# preprocess patches +TEST_CASE_8_POSTPROCESS = [ + TENSOR_4x4, + dict( + splitter=SlidingWindowSplitter(patch_size=(2, 2)), + merger_cls=AvgMerger, + preprocessing=None, + postprocessing=lambda x: 4 * x, + ), + lambda x: x, + 4 * TENSOR_4x4, +] + +# str merger as the class name +TEST_CASE_9_STR_MERGER = [ + TENSOR_4x4, + dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls="AvgMerger"), + lambda x: x, + TENSOR_4x4, +] + +# str merger as dotted patch +TEST_CASE_10_STR_MERGER = [ + TENSOR_4x4, + dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls="monai.inferers.merger.AvgMerger"), + lambda x: x, + TENSOR_4x4, +] + +# list of tensor output +TEST_CASE_0_LIST_TENSOR = [ + TENSOR_4x4, + dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger), + lambda x: (x, x), + (TENSOR_4x4, TENSOR_4x4), +] + +# list of tensor output +TEST_CASE_0_DICT = [ + TENSOR_4x4, + dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger), + lambda x: {"model_output": x}, + {"model_output": TENSOR_4x4}, +] + +# ---------------------------------------------------------------------------- +# Error test cases +# ---------------------------------------------------------------------------- +# invalid splitter: not callable +TEST_CASE_ERROR_0 = [None, dict(splitter=1), TypeError] +# invalid merger: non-existent merger +TEST_CASE_ERROR_1 = [None, dict(splitter=lambda x: x, merger_cls="NonExistent"), ValueError] +# invalid merger: callable +TEST_CASE_ERROR_2 = [None, dict(splitter=lambda x: x, merger_cls=lambda x: x), TypeError] +# invalid merger: Merger object +TEST_CASE_ERROR_3 = [None, dict(splitter=lambda x: x, merger_cls=AvgMerger(output_shape=(1, 1))), TypeError] +# invalid merger: list of Merger class +TEST_CASE_ERROR_4 = [None, dict(splitter=lambda x: x, merger_cls=[AvgMerger, AvgMerger]), TypeError] +# invalid preprocessing +TEST_CASE_ERROR_5 = [None, dict(splitter=lambda x: x, preprocessing=1), TypeError] +# invalid postprocessing +TEST_CASE_ERROR_6 = [None, dict(splitter=lambda x: x, postprocessing=1), TypeError] +# provide splitter when data is already split (splitter is not None) +TEST_CASE_ERROR_7 = [ + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], + dict(splitter=lambda x: x), + AttributeError, +] +# invalid inputs: split patches tensor without location +TEST_CASE_ERROR_8 = [torch.zeros(2, 2), dict(splitter=None), ValueError] +# invalid inputs: split patches MetaTensor without location metadata +TEST_CASE_ERROR_9 = [MetaTensor(torch.zeros(2, 2)), dict(splitter=None), ValueError] + + +class PatchInfererTests(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_0_TENSOR, + TEST_CASE_1_TENSOR, + TEST_CASE_2_TENSOR, + TEST_CASE_3_TENSOR, + TEST_CASE_4_SPLIT_LIST, + TEST_CASE_5_SPLIT_LIST, + TEST_CASE_6_SMALLER, + TEST_CASE_7_PREPROCESS, + TEST_CASE_8_POSTPROCESS, + TEST_CASE_9_STR_MERGER, + TEST_CASE_10_STR_MERGER, + ] + ) + def test_patch_inferer_tensor(self, inputs, arguments, network, expected): + inferer = PatchInferer(**arguments) + output = inferer(inputs=inputs, network=network) + assert_allclose(output, expected) + + @parameterized.expand([TEST_CASE_0_LIST_TENSOR]) + def test_patch_inferer_list_tensor(self, inputs, arguments, network, expected): + inferer = PatchInferer(**arguments) + output = inferer(inputs=inputs, network=network) + for out, exp in zip(output, expected): + assert_allclose(out, exp) + + @parameterized.expand([TEST_CASE_0_DICT]) + def test_patch_inferer_dict(self, inputs, arguments, network, expected): + inferer = PatchInferer(**arguments) + output = inferer(inputs=inputs, network=network) + for k in expected: + assert_allclose(output[k], expected[k]) + + @parameterized.expand( + [ + TEST_CASE_ERROR_0, + TEST_CASE_ERROR_1, + TEST_CASE_ERROR_2, + TEST_CASE_ERROR_3, + TEST_CASE_ERROR_4, + TEST_CASE_ERROR_5, + TEST_CASE_ERROR_6, + TEST_CASE_ERROR_7, + TEST_CASE_ERROR_8, + TEST_CASE_ERROR_9, + ] + ) + def test_patch_inferer_errors(self, inputs, arguments, expected_error): + with self.assertRaises(expected_error): + PatchInferer(**arguments) + inferer = PatchInferer(**arguments) + inferer(inputs=inputs, network=lambda x: x) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_patch_wsi_dataset.py b/tests/test_patch_wsi_dataset.py index 0ba1a4a649..d2cc139ebc 100644 --- a/tests/test_patch_wsi_dataset.py +++ b/tests/test_patch_wsi_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from unittest import skipUnless diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 6971eb0463..bfc1fdbbc9 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from unittest import skipUnless diff --git a/tests/test_pathology_he_stain.py b/tests/test_pathology_he_stain.py index 48b3c88543..7ddad4ad6f 100644 --- a/tests/test_pathology_he_stain.py +++ b/tests/test_pathology_he_stain.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_pathology_he_stain_dict.py b/tests/test_pathology_he_stain_dict.py index 2115ce9a99..07db1c3e48 100644 --- a/tests/test_pathology_he_stain_dict.py +++ b/tests/test_pathology_he_stain_dict.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_pathology_prob_nms.py b/tests/test_pathology_prob_nms.py index 3399e33afa..0053500437 100644 --- a/tests/test_pathology_prob_nms.py +++ b/tests/test_pathology_prob_nms.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index bc52ed1102..1b8245e318 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import pickle import tempfile diff --git a/tests/test_persistentdataset_dist.py b/tests/test_persistentdataset_dist.py index 20dcb2c264..e69c32b1eb 100644 --- a/tests/test_persistentdataset_dist.py +++ b/tests/test_persistentdataset_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile diff --git a/tests/test_phl_cpu.py b/tests/test_phl_cpu.py index d479f554b4..98a5018d8e 100644 --- a/tests/test_phl_cpu.py +++ b/tests/test_phl_cpu.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -242,7 +244,6 @@ class PHLFilterTestCaseCpu(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_cpu(self, test_case_description, sigmas, input, features, expected): - # Create input tensors input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cpu")) feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cpu")) diff --git a/tests/test_phl_cuda.py b/tests/test_phl_cuda.py index d49a60ecd9..0ddfd5eaae 100644 --- a/tests/test_phl_cuda.py +++ b/tests/test_phl_cuda.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -150,7 +152,6 @@ class PHLFilterTestCaseCuda(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_cuda(self, test_case_description, sigmas, input, features, expected): - # Create input tensors input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cuda")) feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cuda")) diff --git a/tests/test_pil_reader.py b/tests/test_pil_reader.py index 4f0b891b72..dfa5eb725d 100644 --- a/tests/test_pil_reader.py +++ b/tests/test_pil_reader.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest @@ -23,7 +25,7 @@ TEST_CASE_2 = [(128, 128, 3), ["test_image.png"], (128, 128, 3), (128, 128)] -TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128)] +TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128), False] TEST_CASE_4 = [(128, 128), ["test_image1.png", "test_image2.png", "test_image3.png"], (3, 128, 128), (128, 128)] @@ -36,20 +38,21 @@ class TestPNGReader(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape): + def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape, reverse=True): test_image = np.random.randint(0, 256, size=data_shape) with tempfile.TemporaryDirectory() as tempdir: for i, name in enumerate(filenames): filenames[i] = os.path.join(tempdir, name) Image.fromarray(test_image.astype("uint8")).save(filenames[i]) - reader = PILReader(mode="r") + reader = PILReader(mode="r", reverse_indexing=reverse) result = reader.get_data(reader.read(filenames)) # load image by PIL and compare the result test_image = np.asarray(Image.open(filenames[0])) self.assertTupleEqual(tuple(result[1]["spatial_shape"]), meta_shape) self.assertTupleEqual(result[0].shape, expected_shape) - test_image = np.moveaxis(test_image, 0, 1) + if reverse: + test_image = np.moveaxis(test_image, 0, 1) if result[0].shape == test_image.shape: np.testing.assert_allclose(result[0], test_image) else: diff --git a/tests/test_plot_2d_or_3d_image.py b/tests/test_plot_2d_or_3d_image.py index 5325c6a294..5add74c260 100644 --- a/tests/test_plot_2d_or_3d_image.py +++ b/tests/test_plot_2d_or_3d_image.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import glob import tempfile import unittest diff --git a/tests/test_png_rw.py b/tests/test_png_rw.py index 47b5571ac0..0b6e8184ea 100644 --- a/tests/test_png_rw.py +++ b/tests/test_png_rw.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_png_saver.py b/tests/test_png_saver.py deleted file mode 100644 index d832718643..0000000000 --- a/tests/test_png_saver.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest -from pathlib import Path - -import torch - -from monai.data import PNGSaver - - -class TestPNGSaver(unittest.TestCase): - def test_saved_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) - - meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_content_three_channel(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".png", scale=255) - - meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} - saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_content_spatial_size(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)], - "spatial_shape": [(4, 4) for i in range(8)], - } - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_specified_root(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver( - output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255, data_root_dir="test" - ) - - meta_data = { - "filename_or_obj": [os.path.join("test", "testfile" + str(i), "image" + ".jpg") for i in range(8)] - } - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "image", "image" + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_polyval.py b/tests/test_polyval.py index db3bcaca53..113c862cb3 100644 --- a/tests/test_polyval.py +++ b/tests/test_polyval.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_prepare_batch_default.py b/tests/test_prepare_batch_default.py index e3836ed86f..e440f5cfe3 100644 --- a/tests/test_prepare_batch_default.py +++ b/tests/test_prepare_batch_default.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_prepare_batch_default_dist.py b/tests/test_prepare_batch_default_dist.py index 3c7532e916..d015cf4b2f 100644 --- a/tests/test_prepare_batch_default_dist.py +++ b/tests/test_prepare_batch_default_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_prepare_batch_extra_input.py b/tests/test_prepare_batch_extra_input.py index 79c9a13679..1769a19e4a 100644 --- a/tests/test_prepare_batch_extra_input.py +++ b/tests/test_prepare_batch_extra_input.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_prepare_batch_hovernet.py b/tests/test_prepare_batch_hovernet.py index 9aed8e94c7..5a7080a225 100644 --- a/tests/test_prepare_batch_hovernet.py +++ b/tests/test_prepare_batch_hovernet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_preset_filters.py b/tests/test_preset_filters.py new file mode 100644 index 0000000000..9bca24cef3 --- /dev/null +++ b/tests/test_preset_filters.py @@ -0,0 +1,130 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import ApplyFilter, EllipticalFilter, LaplaceFilter, MeanFilter, SharpenFilter + +TEST_CASES_MEAN = [(3, 3, torch.ones(3, 3, 3)), (2, 5, torch.ones(5, 5))] + +TEST_CASES_LAPLACE = [ + ( + 3, + 3, + torch.Tensor( + [ + [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]], + [[-1, -1, -1], [-1, 26, -1], [-1, -1, -1]], + [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]], + ] + ), + ), + (2, 3, torch.Tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])), +] + +TEST_CASES_ELLIPTICAL = [ + ( + 3, + 3, + torch.Tensor( + [[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [1, 1, 1], [0, 1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]] + ), + ), + (2, 3, torch.Tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]])), +] + +TEST_CASES_SHARPEN = [ + ( + 3, + 3, + torch.Tensor( + [ + [[0, 0, 0], [0, -1, 0], [0, 0, 0]], + [[0, -1, 0], [-1, 7, -1], [0, -1, 0]], + [[0, 0, 0], [0, -1, 0], [0, 0, 0]], + ] + ), + ), + (2, 3, torch.Tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])), +] + + +class _TestFilter: + def test_init(self, spatial_dims, size, expected): + test_filter = self.filter_class(spatial_dims=spatial_dims, size=size) + torch.testing.assert_allclose(expected, test_filter.filter) + self.assertIsInstance(test_filter, torch.nn.Module) + + def test_forward(self): + test_filter = self.filter_class(spatial_dims=2, size=3) + input = torch.ones(1, 1, 5, 5) + _ = test_filter(input) + + +class TestApplyFilter(unittest.TestCase): + def test_init_and_forward_2d(self): + filter_2d = torch.ones(3, 3) + image_2d = torch.ones(1, 3, 3) + apply_filter_2d = ApplyFilter(filter_2d) + out = apply_filter_2d(image_2d) + self.assertEqual(out.shape, image_2d.shape) + + def test_init_and_forward_3d(self): + filter_2d = torch.ones(3, 3, 3) + image_2d = torch.ones(1, 3, 3, 3) + apply_filter_2d = ApplyFilter(filter_2d) + out = apply_filter_2d(image_2d) + self.assertEqual(out.shape, image_2d.shape) + + +class MeanFilterTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = MeanFilter + + @parameterized.expand(TEST_CASES_MEAN) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +class LaplaceFilterTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = LaplaceFilter + + @parameterized.expand(TEST_CASES_LAPLACE) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +class EllipticalTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = EllipticalFilter + + @parameterized.expand(TEST_CASES_ELLIPTICAL) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +class SharpenTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = SharpenFilter + + @parameterized.expand(TEST_CASES_SHARPEN) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_print_info.py b/tests/test_print_info.py index 591316884c..bb748c3f7b 100644 --- a/tests/test_print_info.py +++ b/tests/test_print_info.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.config import print_debug_info diff --git a/tests/test_print_transform_backends.py b/tests/test_print_transform_backends.py index e714003769..4cd93c3fb2 100644 --- a/tests/test_print_transform_backends.py +++ b/tests/test_print_transform_backends.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.transforms.utils import get_transform_backends, print_transform_backends diff --git a/tests/test_probnms.py b/tests/test_probnms.py index aab312c1db..8da5396fac 100644 --- a/tests/test_probnms.py +++ b/tests/test_probnms.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_probnmsd.py b/tests/test_probnmsd.py index bb2315487b..1f0288811e 100644 --- a/tests/test_probnmsd.py +++ b/tests/test_probnmsd.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Any, List +from typing import Any import numpy as np import torch @@ -19,7 +21,7 @@ from monai.transforms.post.dictionary import ProbNMSD from tests.utils import TEST_NDARRAYS -TESTS: List[Any] = [] +TESTS: list[Any] = [] for p in TEST_NDARRAYS: probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5)) TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, {"prob_map": probs_map_1}, []]) diff --git a/tests/test_profiling.py b/tests/test_profiling.py index 40522b07c5..2b93fae196 100644 --- a/tests/test_profiling.py +++ b/tests/test_profiling.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import datetime import os import unittest diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py index be43e49f82..4c8c032c80 100644 --- a/tests/test_pytorch_version_after.py +++ b/tests/test_pytorch_version_after.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_query_memory.py b/tests/test_query_memory.py index cdb44d3eb1..5e57913acb 100644 --- a/tests/test_query_memory.py +++ b/tests/test_query_memory.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from tests.utils import query_memory diff --git a/tests/test_rand_adjust_contrast.py b/tests/test_rand_adjust_contrast.py index 5dc800793e..bfeedc2fcf 100644 --- a/tests/test_rand_adjust_contrast.py +++ b/tests/test_rand_adjust_contrast.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_adjust_contrastd.py b/tests/test_rand_adjust_contrastd.py index b355ac3e4f..4037266da4 100644 --- a/tests/test_rand_adjust_contrastd.py +++ b/tests/test_rand_adjust_contrastd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py index b5bc67ffb1..83aafe9773 100644 --- a/tests/test_rand_affine.py +++ b/tests/test_rand_affine.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -17,6 +19,7 @@ from monai.transforms import RandAffine from monai.utils.type_conversion import convert_data_type +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, assert_allclose, is_tf32_env _rtol = 1e-3 if is_tf32_env() else 1e-4 @@ -142,6 +145,7 @@ def test_rand_affine(self, input_param, input_data, expected_val): g = RandAffine(**input_param) g.set_random_state(123) result = g(**input_data) + test_resampler_lazy(g, result, input_param, input_data, seed=123) if input_param.get("cache_grid", False): self.assertTrue(g._cached_grid is not None) assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4, type_test="tensor") diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 6a40d39e4e..113987a85c 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index 566eed68ef..5c1e2359e8 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import itertools import unittest @@ -18,7 +20,8 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import RandAffined -from monai.utils import GridSampleMode +from monai.utils import GridSampleMode, ensure_tuple_rep +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import assert_allclose, is_tf32_env _rtol = 1e-3 if is_tf32_env() else 1e-4 @@ -217,7 +220,22 @@ class TestRandAffined(unittest.TestCase): def test_rand_affined(self, input_param, input_data, expected_val, track_meta): set_track_meta(track_meta) g = RandAffined(**input_param).set_random_state(123) - res = g(input_data) + call_param = {"data": input_data} + res = g(**call_param) + # test lazy + if track_meta and input_data["img"].ndim in (3, 4): + if "mode" not in input_param.keys(): + input_param["mode"] = "bilinear" + if not isinstance(input_param["keys"], str): + input_param["mode"] = ensure_tuple_rep(input_param["mode"], len(input_param["keys"])) + lazy_init_param = input_param.copy() + for key, mode in zip(input_param["keys"], input_param["mode"]): + lazy_init_param["keys"], lazy_init_param["mode"] = key, mode + resampler = RandAffined(**lazy_init_param).set_random_state(123) + expected_output = resampler(**call_param) + test_resampler_lazy(resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key) + resampler.lazy_evaluation = False + if input_param.get("cache_grid", False): self.assertTrue(g.rand_affine._cached_grid is not None) for key in res: @@ -231,16 +249,18 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta): assert_allclose(result, expected, rtol=_rtol, atol=1e-3, type_test=False) g.set_random_state(4) - res = g(input_data) + res = g(**call_param) if not track_meta: return # affine should be tensor because the resampler only supports pytorch backend if isinstance(res["img"], MetaTensor) and "extra_info" in res["img"].applied_operations[0]: - if not res["img"].applied_operations[-1]["extra_info"]["do_resampling"]: + if not res["img"].applied_operations[-1]["extra_info"]: + return + if not res["img"].applied_operations[-1]["extra_info"]["extra_info"]["do_resampling"]: return - affine_img = res["img"].applied_operations[0]["extra_info"]["rand_affine_info"]["extra_info"]["affine"] - affine_seg = res["seg"].applied_operations[0]["extra_info"]["rand_affine_info"]["extra_info"]["affine"] + affine_img = res["img"].applied_operations[0]["extra_info"]["extra_info"]["affine"] + affine_seg = res["seg"].applied_operations[0]["extra_info"]["extra_info"]["affine"] assert_allclose(affine_img, affine_seg, rtol=_rtol, atol=1e-3) res_inv = g.inverse(res) diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index 7458b9d6dd..457617fc19 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -16,6 +18,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import RandAxisFlip +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion @@ -23,8 +26,15 @@ class TestRandAxisFlip(NumpyImageTestCase2D): def test_correct_results(self): for p in TEST_NDARRAYS_ALL: flip = RandAxisFlip(prob=1.0) + flip.set_random_state(seed=321) im = p(self.imt[0]) - result = flip(im) + call_param = {"img": im} + result = flip(**call_param) + + # test lazy + test_resampler_lazy(flip, result, call_param=call_param, seed=321) + flip.lazy_evaluation = False + expected = [np.flip(channel, flip._axis) for channel in self.imt[0]] assert_allclose(result, p(np.stack(expected)), type_test="tensor") test_local_inversion(flip, result, im) diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index a62da88af3..e6fac5637f 100644 --- a/tests/test_rand_axis_flipd.py +++ b/tests/test_rand_axis_flipd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -16,6 +18,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import RandAxisFlipd +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase3D, assert_allclose, test_local_inversion @@ -23,8 +26,15 @@ class TestRandAxisFlip(NumpyImageTestCase3D): def test_correct_results(self): for p in TEST_NDARRAYS_ALL: flip = RandAxisFlipd(keys="img", prob=1.0) + flip.set_random_state(seed=1234) im = p(self.imt[0]) - result = flip({"img": im}) + call_param = {"data": {"img": im}} + result = flip(**call_param) + + # test lazy + test_resampler_lazy(flip, result, call_param=call_param, output_key="img", seed=1234) + flip.lazy_evaluation = False + test_local_inversion(flip, result, {"img": im}, "img") expected = [np.flip(channel, flip.flipper._axis) for channel in self.imt[0]] assert_allclose(result["img"], p(np.stack(expected)), type_test="tensor") diff --git a/tests/test_rand_bias_field.py b/tests/test_rand_bias_field.py index 690c4022eb..16f615146f 100644 --- a/tests/test_rand_bias_field.py +++ b/tests/test_rand_bias_field.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_bias_fieldd.py b/tests/test_rand_bias_fieldd.py index 05a5a1b636..2b8a60289d 100644 --- a/tests/test_rand_bias_fieldd.py +++ b/tests/test_rand_bias_fieldd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_coarse_dropout.py b/tests/test_rand_coarse_dropout.py index cc05edbf02..8c3876f10b 100644 --- a/tests/test_rand_coarse_dropout.py +++ b/tests/test_rand_coarse_dropout.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_coarse_dropoutd.py b/tests/test_rand_coarse_dropoutd.py index e54db130a5..7b16f992b7 100644 --- a/tests/test_rand_coarse_dropoutd.py +++ b/tests/test_rand_coarse_dropoutd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_coarse_shuffle.py b/tests/test_rand_coarse_shuffle.py index fb7311e5a3..adfb722b42 100644 --- a/tests/test_rand_coarse_shuffle.py +++ b/tests/test_rand_coarse_shuffle.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_coarse_shuffled.py b/tests/test_rand_coarse_shuffled.py index fa9c17286d..3b5a1434f4 100644 --- a/tests/test_rand_coarse_shuffled.py +++ b/tests/test_rand_coarse_shuffled.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index b1165e8986..6723dfc4c6 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -9,13 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import ClassesToIndices, RandCropByLabelClasses -from tests.utils import TEST_NDARRAYS_ALL +from monai.transforms.lazy.functional import apply_transforms +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS_INDICES, TESTS_SHAPE = [], [] for p in TEST_NDARRAYS_ALL: @@ -105,7 +109,7 @@ "label": None, "num_classes": 2, "spatial_size": [4, 4, 4], - "ratios": [1, 1], + "ratios": (1, 1), # test no assignment "num_samples": 2, "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), "image_threshold": 0, @@ -113,7 +117,7 @@ }, { "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), - "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "label": p(np.random.randint(0, 1, size=[1, 3, 3, 3])), "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), }, list, @@ -141,6 +145,26 @@ def test_indices(self, input_param, input_data, expected_type, expected_shape): self.assertIsInstance(result, expected_type) self.assertTupleEqual(result[0].shape, expected_shape) + @parameterized.expand(TESTS_INDICES + TESTS_SHAPE) + def test_pending_ops(self, input_param, input_data, _expected_type, _expected_shape): + cropper = RandCropByLabelClasses(**input_param) + # non-lazy + cropper.set_random_state(0) + expected = cropper(**input_data) + self.assertIsInstance(expected[0], MetaTensor) + # lazy + cropper.set_random_state(0) + cropper.lazy_evaluation = True + pending_result = cropper(**input_data) + for i, _pending_result in enumerate(pending_result): + self.assertIsInstance(_pending_result, MetaTensor) + assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) + assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) + # only support nearest + result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + # compare + assert_allclose(result, expected[i], rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 24600a41ef..77221e67cd 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -9,13 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd -from tests.utils import TEST_NDARRAYS_ALL +from monai.transforms.lazy.functional import apply_transforms +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS = [] for p in TEST_NDARRAYS_ALL: @@ -73,7 +77,7 @@ "label_key": "label", "num_classes": 2, "spatial_size": [4, 4, 2], - "ratios": [1, 1], + "ratios": (1, 1), # test no assignment "num_samples": 2, "image_key": "image", "image_threshold": 0, @@ -82,7 +86,7 @@ { "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), - "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "label": p(np.random.randint(0, 1, size=[1, 3, 3, 3])), }, list, (3, 3, 3, 2), @@ -129,6 +133,26 @@ def test_type_shape(self, input_param, input_data, expected_type, expected_shape _len = len(tuple(input_data.keys())) - 1 # except for the indices_key self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())[:-1]) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_data, _expected_type, _expected_shape): + cropper = RandCropByLabelClassesd(**input_param) + # non-lazy + cropper.set_random_state(0) + expected = cropper(input_data) + self.assertIsInstance(expected[0]["img"], MetaTensor) + # lazy + cropper.set_random_state(0) + cropper.lazy_evaluation = True + pending_result = cropper(input_data) + for i, _pending_result in enumerate(pending_result): + self.assertIsInstance(_pending_result["img"], MetaTensor) + assert_allclose(_pending_result["img"].peek_pending_affine(), expected[i]["img"].affine) + assert_allclose(_pending_result["img"].peek_pending_shape(), expected[i]["img"].shape[1:]) + # only support nearest + result = apply_transforms(_pending_result["img"], mode="nearest", align_corners=False)[0] + # compare + assert_allclose(result, expected[i]["img"], rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index f6da393ab9..e1c4cdff58 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -9,14 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy import numpy as np from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import RandCropByPosNegLabel -from tests.utils import TEST_NDARRAYS_ALL +from monai.transforms.lazy.functional import apply_transforms +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS = [ [ @@ -120,6 +124,29 @@ def test_type_shape(self, input_param, input_data, expected_shape): if len(results) > 1: np.testing.assert_allclose(results[0], results[-1]) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_data, _expected_shape): + for p in TEST_NDARRAYS_ALL: + input_param_mod = self.convert_data_type(p, input_param) + input_data_mod = self.convert_data_type(p, input_data) + cropper = RandCropByPosNegLabel(**input_param_mod) + # non-lazy + cropper.set_random_state(0) + expected = cropper(**input_data_mod) + self.assertIsInstance(expected[0], MetaTensor) + # lazy + cropper.set_random_state(0) + cropper.lazy_evaluation = True + pending_result = cropper(**input_data_mod) + for i, _pending_result in enumerate(pending_result): + self.assertIsInstance(_pending_result, MetaTensor) + assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) + assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) + # only support nearest + result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + # compare + assert_allclose(result, expected[i], rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 64673bf4bf..11b7960617 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -9,14 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy import numpy as np from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import RandCropByPosNegLabeld -from tests.utils import TEST_NDARRAYS_ALL +from monai.transforms.lazy.functional import apply_transforms +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS = [ [ @@ -137,6 +141,31 @@ def test_correct_center(self): result = cropper(test_image) np.testing.assert_allclose(result[0]["label"], np.asarray([[[0, 0, 1], [0, 0, 0], [0, 0, 0]]])) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_data, _expected_shape): + for p in TEST_NDARRAYS_ALL: + input_param_mod = self.convert_data_type(p, input_param) + input_data_mod = self.convert_data_type(p, input_data) + cropper = RandCropByPosNegLabeld(**input_param_mod) + # non-lazy + cropper.set_random_state(0) + expected = cropper(input_data_mod) + self.assertIsInstance(expected[0]["image"], MetaTensor) + # lazy + cropper.set_random_state(0) + cropper.lazy_evaluation = True + pending_result = cropper(input_data_mod) + for i, _pending_result in enumerate(pending_result): + self.assertIsInstance(_pending_result["image"], MetaTensor) + assert_allclose(_pending_result["image"].peek_pending_affine(), expected[i]["image"].affine) + assert_allclose(_pending_result["image"].peek_pending_shape(), expected[i]["image"].shape[1:]) + # only support nearest + result_image = apply_transforms(_pending_result["image"], mode="nearest", align_corners=False)[0] + result_extra = apply_transforms(_pending_result["extra"], mode="nearest", align_corners=False)[0] + # compare + assert_allclose(result_image, expected[i]["image"], rtol=1e-5) + assert_allclose(result_extra, expected[i]["extra"], rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_cucim_dict_transform.py b/tests/test_rand_cucim_dict_transform.py index 4a2fdfe77e..33e0667723 100644 --- a/tests/test_rand_cucim_dict_transform.py +++ b/tests/test_rand_cucim_dict_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_cucim_transform.py b/tests/test_rand_cucim_transform.py index 8f81730c59..37d8e29f1d 100644 --- a/tests/test_rand_cucim_transform.py +++ b/tests/test_rand_cucim_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_deform_grid.py b/tests/test_rand_deform_grid.py index 3e59e3207b..58b64ae596 100644 --- a/tests/test_rand_deform_grid.py +++ b/tests/test_rand_deform_grid.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py index 125da74528..c59052854f 100644 --- a/tests/test_rand_elastic_2d.py +++ b/tests/test_rand_elastic_2d.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index 76c9e9024d..0ff3ef6129 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index d6f7a0cbba..d0fbd5aa88 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 9db474861e..e058293584 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index cdd51dd77e..c3b0bfdede 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -17,6 +19,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import RandFlip +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -35,7 +38,8 @@ def test_invalid_inputs(self, _, spatial_axis, raises): def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - flip = RandFlip(prob=1.0, spatial_axis=spatial_axis) + init_param = {"prob": 1.0, "spatial_axis": spatial_axis} + flip = RandFlip(**init_param) set_track_meta(False) result = flip(im) self.assertNotIsInstance(result, MetaTensor) @@ -43,10 +47,14 @@ def test_correct_results(self, _, spatial_axis): set_track_meta(True) expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) - result = flip(im) + call_param = {"img": im} + result = flip(**call_param) assert_allclose(result, p(expected), type_test="tensor") test_local_inversion(flip, result, im) + # test lazy + test_resampler_lazy(flip, result, init_param, call_param) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index 92b070fd0a..d67b4ca31b 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -17,6 +19,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import RandFlipd +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1])] @@ -26,13 +29,21 @@ class TestRandFlipd(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS_ALL: - flip = RandFlipd(keys="img", prob=1.0, spatial_axis=spatial_axis) + init_param = {"keys": "img", "prob": 1.0, "spatial_axis": spatial_axis} + flip = RandFlipd(**init_param) im = p(self.imt[0]) - result = flip({"img": im})["img"] + call_param = {"data": {"img": im}} + result = flip(**call_param) + + # test lazy + test_resampler_lazy(flip, result, init_param, call_param, output_key="img") + flip.lazy_evaluation = False + expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) - assert_allclose(result, p(expected), type_test="tensor") - test_local_inversion(flip, {"img": result}, {"img": im}, "img") + assert_allclose(result["img"], p(expected), type_test="tensor") + test_local_inversion(flip, {"img": result["img"]}, {"img": im}, "img") + set_track_meta(False) result = flip({"img": im})["img"] self.assertNotIsInstance(result, MetaTensor) diff --git a/tests/test_rand_gaussian_noise.py b/tests/test_rand_gaussian_noise.py index faa2b143f5..7d4d04ff3f 100644 --- a/tests/test_rand_gaussian_noise.py +++ b/tests/test_rand_gaussian_noise.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_gaussian_noised.py b/tests/test_rand_gaussian_noised.py index a927761186..24fc19f226 100644 --- a/tests/test_rand_gaussian_noised.py +++ b/tests/test_rand_gaussian_noised.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py index 01d1a1ecec..8dff69cd4c 100644 --- a/tests/test_rand_gaussian_sharpen.py +++ b/tests/test_rand_gaussian_sharpen.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_rand_gaussian_sharpend.py b/tests/test_rand_gaussian_sharpend.py index 02961fc7ec..4c32880053 100644 --- a/tests/test_rand_gaussian_sharpend.py +++ b/tests/test_rand_gaussian_sharpend.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_gaussian_smooth.py b/tests/test_rand_gaussian_smooth.py index f6d75305a5..9fb91a38a1 100644 --- a/tests/test_rand_gaussian_smooth.py +++ b/tests/test_rand_gaussian_smooth.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_gaussian_smoothd.py b/tests/test_rand_gaussian_smoothd.py index c9be44d3c9..d312494e46 100644 --- a/tests/test_rand_gaussian_smoothd.py +++ b/tests/test_rand_gaussian_smoothd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index b87b839eb9..a0d18ae7f3 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index 23a7dd5fdb..4120f967e2 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_rand_grid_distortion.py b/tests/test_rand_grid_distortion.py index 88b4989cd5..8131a2382a 100644 --- a/tests/test_rand_grid_distortion.py +++ b/tests/test_rand_grid_distortion.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -64,15 +66,15 @@ [3.132195, 3.132195, 3.132195, 3.132195, 3.132195, 3.132195], [3.132195, 3.132195, 3.132195, 3.132195, 3.132195, 3.132195], [4.482229, 4.482229, 4.482229, 4.482229, 4.482229, 4.482229], - [4.167737, 4.167737, 4.167737, 4.167737, 4.167737, 4.167737], + [5.0, 5.0, 5.0, 5.0, 5.0, 5.0], ], [ - [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], - [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], - [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], - [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], - [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], - [0.0, 1.3940266, 2.7880538, 2.7880538, 4.1657557, 4.456543], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 5.0], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 5.0], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 5.0], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 5.0], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 5.0], + [0.0, 1.3940266, 2.7880538, 2.7880538, 4.1657557, 5.0], ], ] ).astype(np.float32) @@ -87,7 +89,10 @@ def test_rand_grid_distortion(self, input_param, seed, input_data, expected_val) g = RandGridDistortion(**input_param) g.set_random_state(seed=seed) result = g(input_data) - assert_allclose(result, expected_val, type_test="tensor", rtol=1e-4, atol=1e-4) + if input_param["padding_mode"] != "reflection": + assert_allclose(result, expected_val, type_test="tensor", rtol=1e-4, atol=1e-4) + else: + assert_allclose(result.shape, expected_val.shape, type_test=False, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_grid_distortiond.py b/tests/test_rand_grid_distortiond.py index a7b64e5980..9f8ed3b9e6 100644 --- a/tests/test_rand_grid_distortiond.py +++ b/tests/test_rand_grid_distortiond.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index 417915fbab..cb66276a8c 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py index 4f3ec3bb6a..15f4d5447f 100644 --- a/tests/test_rand_grid_patchd.py +++ b/tests/test_rand_grid_patchd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py index 89198549cd..318dad9dfa 100644 --- a/tests/test_rand_histogram_shift.py +++ b/tests/test_rand_histogram_shift.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -42,6 +44,16 @@ ] ) +WARN_TESTS = [] +for p in TEST_NDARRAYS: + WARN_TESTS.append( + [ + {"num_control_points": 5, "prob": 1.0}, + {"img": p(np.zeros(8).reshape((1, 2, 2, 2)))}, + np.zeros(8).reshape((1, 2, 2, 2)), + ] + ) + class TestRandHistogramShift(unittest.TestCase): @parameterized.expand(TESTS) @@ -69,6 +81,12 @@ def test_interp(self): self.assertEqual(yi.shape, (3, 2)) assert_allclose(yi, array_type([[1.0, 5.0], [0.5, -0.5], [4.0, 5.0]])) + @parameterized.expand(WARN_TESTS) + def test_warn(self, input_param, input_data, expected_val): + with self.assertWarns(Warning): + result = RandHistogramShift(**input_param)(**input_data) + assert_allclose(result, expected_val, type_test="tensor") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_histogram_shiftd.py b/tests/test_rand_histogram_shiftd.py index 7c94379e0e..45e81ab012 100644 --- a/tests/test_rand_histogram_shiftd.py +++ b/tests/test_rand_histogram_shiftd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py index 176699ddd1..4e7d59329b 100644 --- a/tests/test_rand_k_space_spike_noise.py +++ b/tests/test_rand_k_space_spike_noise.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 7f493ef276..3e1c11b2d9 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy @@ -44,7 +46,6 @@ def get_data(im_shape, im_type): @parameterized.expand(TESTS) def test_same_result(self, im_shape, im_type): - data = self.get_data(im_shape, im_type) t = RandKSpaceSpikeNoised(KEYS, prob=1.0, intensity_range=(13, 15), channel_wise=True) diff --git a/tests/test_rand_lambda.py b/tests/test_rand_lambda.py index cb5c57e9e4..1f14499bc0 100644 --- a/tests/test_rand_lambda.py +++ b/tests/test_rand_lambda.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index 8bd7bbbfc8..6b60a3fe70 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py index 9ee1a6ce82..fe7135835e 100644 --- a/tests/test_rand_rician_noise.py +++ b/tests/test_rand_rician_noise.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_rician_noised.py b/tests/test_rand_rician_noised.py index 05707059bc..ae0acab4eb 100644 --- a/tests/test_rand_rician_noised.py +++ b/tests/test_rand_rician_noised.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 7cba4f99c9..fe0d432fd4 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import List, Tuple import numpy as np import scipy.ndimage @@ -19,6 +20,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import RandRotate +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import ( TEST_NDARRAYS_ALL, NumpyImageTestCase2D, @@ -27,14 +29,14 @@ test_local_inversion, ) -TEST_CASES_2D: List[Tuple] = [] +TEST_CASES_2D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -TEST_CASES_3D: List[Tuple] = [] +TEST_CASES_3D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: TEST_CASES_3D.append( (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 81, 110, 112)) @@ -71,17 +73,23 @@ class TestRandRotate2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): - rotate_fn = RandRotate( - range_x=degrees, - prob=1.0, - keep_size=keep_size, - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - dtype=np.float64, - ) + init_param = { + "range_x": degrees, + "prob": 1.0, + "keep_size": keep_size, + "mode": mode, + "padding_mode": padding_mode, + "align_corners": align_corners, + "dtype": np.float64, + } + rotate_fn = RandRotate(**init_param) rotate_fn.set_random_state(243) - rotated = rotate_fn(im_type(self.imt[0])) + call_param = {"img": im_type(self.imt[0])} + rotated = rotate_fn(**call_param) + + # test lazy + test_resampler_lazy(rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243) + rotate_fn.lazy_evaluation = False _order = 0 if mode == "nearest" else 1 if mode == "border": @@ -97,26 +105,33 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, expected = np.stack(expected).astype(np.float32) rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated[0], atol=1e-3)) - self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") + self.assertLessEqual(np.abs(good - expected.size), 40, "diff at most 40 pixels") class TestRandRotate3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): - rotate_fn = RandRotate( - range_x=x, - range_y=y, - range_z=z, - prob=1.0, - keep_size=keep_size, - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - dtype=np.float64, - ) + init_param = { + "range_x": x, + "range_y": y, + "range_z": z, + "prob": 1.0, + "keep_size": keep_size, + "mode": mode, + "padding_mode": padding_mode, + "align_corners": align_corners, + "dtype": np.float64, + } + rotate_fn = RandRotate(**init_param) rotate_fn.set_random_state(243) im = im_type(self.imt[0]) - rotated = rotate_fn(im) + call_param = {"img": im} + rotated = rotate_fn(**call_param) + + # test lazy + test_resampler_lazy(rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243) + rotate_fn.lazy_evaluation = False + assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0) test_local_inversion(rotate_fn, rotated, im) diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 30ad906ac2..2504c0f01b 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -16,6 +18,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import RandRotate90 +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion @@ -25,14 +28,20 @@ def test_default(self): for p in TEST_NDARRAYS_ALL: rotate.set_random_state(123) im = p(self.imt[0]) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test="tensor") + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, seed=123) + rotate.lazy_evaluation = False + def test_k(self): - rotate = RandRotate90(max_k=2) + init_param = {"max_k": 2} + rotate = RandRotate90(**init_param) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) set_track_meta(False) @@ -42,18 +51,28 @@ def test_k(self): set_track_meta(True) rotate.set_random_state(123) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test="tensor") + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, seed=123) + rotate.lazy_evaluation = False + def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1), prob=1.0) for p in TEST_NDARRAYS_ALL: rotate.set_random_state(1234) im = p(self.imt[0]) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, seed=1234) + rotate.lazy_evaluation = False + self.assertEqual(len(rotated.applied_operations), 1) expected = [np.rot90(channel, rotate._rand_k, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -65,7 +84,12 @@ def test_prob_k_spatial_axes(self): for p in TEST_NDARRAYS_ALL: rotate.set_random_state(234) im = p(self.imt[0]) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index ec0e5ac92e..f811f1a6a6 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -16,17 +18,24 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import RandRotate90d +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion class TestRandRotate90d(NumpyImageTestCase2D): def test_default(self): - key = None + key = "test" rotate = RandRotate90d(keys=key) for p in TEST_NDARRAYS_ALL: rotate.set_random_state(1323) im = {key: p(self.imt[0])} - rotated = rotate(im) + call_param = {"data": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, seed=1323, output_key=key) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im, key) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -44,7 +53,13 @@ def test_k(self): for p in TEST_NDARRAYS_ALL: rotate.set_random_state(234) im = {key: p(self.imt[0])} - rotated = rotate(im) + call_param = {"data": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im, key) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -56,7 +71,13 @@ def test_spatial_axes(self): for p in TEST_NDARRAYS_ALL: rotate.set_random_state(234) im = {key: p(self.imt[0])} - rotated = rotate(im) + call_param = {"data": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im, key) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -68,7 +89,13 @@ def test_prob_k_spatial_axes(self): for p in TEST_NDARRAYS_ALL: rotate.set_random_state(234) im = {key: p(self.imt[0])} - rotated = rotate(im) + call_param = {"data": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key) + rotate.lazy_evaluation = False + expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) assert_allclose(rotated[key], p(expected), type_test="tensor") diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py index fe57c641a5..e45286f50d 100644 --- a/tests/test_rand_rotated.py +++ b/tests/test_rand_rotated.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import List, Tuple import numpy as np import scipy.ndimage @@ -19,16 +20,17 @@ from monai.transforms import RandRotated from monai.utils import GridSampleMode, GridSamplePadMode +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion -TEST_CASES_2D: List[Tuple] = [] +TEST_CASES_2D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -TEST_CASES_3D: List[Tuple] = [] +TEST_CASES_3D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: TEST_CASES_3D.append( (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 81, 110, 112)) @@ -107,19 +109,26 @@ class TestRandRotated2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): - rotate_fn = RandRotated( - "img", - range_x=degrees, - prob=1.0, - keep_size=keep_size, - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - dtype=np.float64, - ) + init_param = { + "keys": "img", + "range_x": degrees, + "prob": 1.0, + "keep_size": keep_size, + "mode": mode, + "padding_mode": padding_mode, + "align_corners": align_corners, + "dtype": np.float64, + } + rotate_fn = RandRotated(**init_param) im = im_type(self.imt[0]) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": im, "seg": im_type(self.segn[0])}) + call_param = {"data": {"img": im, "seg": im_type(self.segn[0])}} + rotated = rotate_fn(**call_param) + + # test lazy + test_resampler_lazy( + rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243, output_key="img" + ) _order = 0 if mode == "nearest" else 1 if padding_mode == "border": @@ -137,26 +146,33 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v expected = np.stack(expected).astype(np.float32) good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) - self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") + self.assertLessEqual(np.abs(good - expected.size), 40, "diff at most 40 pixels") class TestRandRotated3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): - rotate_fn = RandRotated( - ("img", "seg"), - range_x=x, - range_y=y, - range_z=z, - prob=1.0, - keep_size=keep_size, - mode=mode, - padding_mode=padding_mode, - align_corners=align_corners, - dtype=np.float64, - ) + init_param = { + "keys": ("img", "seg"), + "range_x": x, + "range_y": y, + "range_z": z, + "prob": 1.0, + "keep_size": keep_size, + "mode": mode, + "padding_mode": padding_mode, + "align_corners": align_corners, + "dtype": np.float64, + } + rotate_fn = RandRotated(**init_param) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) + call_param = {"data": {"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}} + rotated = rotate_fn(**call_param) + + # test lazy + test_resampler_lazy( + rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243, output_key="img" + ) np.testing.assert_allclose(rotated["img"].shape, expected) rotate_fn.prob = 0.0 diff --git a/tests/test_rand_scale_crop.py b/tests/test_rand_scale_crop.py index a97a77a8e6..bf43273fcf 100644 --- a/tests/test_rand_scale_crop.py +++ b/tests/test_rand_scale_crop.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -67,6 +69,10 @@ def test_random_shape(self, input_param, input_shape, expected_shape): result = cropper(input_data) self.assertTupleEqual(result.shape, expected_shape) + @parameterized.expand(TEST_SHAPES) + def test_pending_ops(self, input_param, input_shape, _): + self.crop_test_pending_ops(input_param, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_scale_cropd.py b/tests/test_rand_scale_cropd.py index dd92783766..15a48a55d7 100644 --- a/tests/test_rand_scale_cropd.py +++ b/tests/test_rand_scale_cropd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -87,6 +89,10 @@ def test_random_shape(self, input_param, input_shape, expected_shape): result = cropper(input_data)["img"] self.assertTupleEqual(result.shape, expected_shape) + @parameterized.expand(TEST_SHAPES) + def test_pending_ops(self, input_param, input_shape, _): + self.crop_test_pending_ops(input_param, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index b0999a82a5..5f5ca076a8 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_scale_intensityd.py b/tests/test_rand_scale_intensityd.py index d548ee34d6..6b5a04a8f3 100644 --- a/tests/test_rand_scale_intensityd.py +++ b/tests/test_rand_scale_intensityd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_shift_intensity.py b/tests/test_rand_shift_intensity.py index d5ad083d33..12b7ccf526 100644 --- a/tests/test_rand_shift_intensity.py +++ b/tests/test_rand_shift_intensity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_shift_intensityd.py b/tests/test_rand_shift_intensityd.py index 1a8356c2c9..92bc39dd20 100644 --- a/tests/test_rand_shift_intensityd.py +++ b/tests/test_rand_shift_intensityd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index 383ea8a1cb..a0d56bcaf3 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -9,12 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np from parameterized import parameterized -from monai.transforms import RandSpatialCrop +from monai.data.meta_tensor import MetaTensor +from monai.transforms import RandScaleCrop, RandSpatialCrop +from monai.transforms.lazy.functional import apply_transforms from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -22,6 +26,7 @@ [{"roi_size": [3, 3, -1], "random_center": True}, (3, 3, 3, 4), (3, 3, 3, 4)], [{"roi_size": [3, 3, 3], "random_center": True}, (3, 3, 3, 3), (3, 3, 3, 3)], [{"roi_size": [3, 3, 3], "random_center": False}, (3, 3, 3, 3), (3, 3, 3, 3)], + [{"roi_size": [3, 3, 2], "random_center": False, "random_size": False}, (3, 3, 3, 3), (3, 3, 3, 2)], ] TEST_VALUES = [ @@ -40,6 +45,15 @@ [{"roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, (1, 4, 5, 6), (1, 3, 4, 3)], ] +func1 = {RandSpatialCrop: {"roi_size": [8, 7, -1], "random_center": True, "random_size": False}} +func2 = {RandScaleCrop: {"roi_scale": [0.5, 0.6, -1.0], "random_center": True, "random_size": True}} +func3 = {RandScaleCrop: {"roi_scale": [1.0, 0.5, -1.0], "random_center": False, "random_size": False}} + +TESTS_COMBINE = [] +TESTS_COMBINE.append([[func1, func2, func3], (3, 10, 10, 8)]) +TESTS_COMBINE.append([[func1, func2], (3, 8, 8, 4)]) +TESTS_COMBINE.append([[func2, func2], (3, 8, 8, 4)]) + class TestRandSpatialCrop(CropTest): Cropper = RandSpatialCrop @@ -64,8 +78,29 @@ def test_random_shape(self, input_param, input_shape, expected_shape): cropper = RandSpatialCrop(**input_param) cropper.set_random_state(seed=123) input_data = im_type(np.random.randint(0, 2, input_shape)) - result = cropper(input_data) - self.assertTupleEqual(result.shape, expected_shape) + expected = cropper(input_data) + self.assertTupleEqual(expected.shape, expected_shape) + + # lazy + # reset random seed to ensure the same results + cropper.set_random_state(seed=123) + cropper.lazy_evaluation = True + pending_result = cropper(input_data) + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + # only support nearest + result = apply_transforms(pending_result, mode="nearest", align_corners=False)[0] + # compare + assert_allclose(result, expected, rtol=1e-5) + + @parameterized.expand(TEST_SHAPES) + def test_pending_ops(self, input_param, input_shape, _): + self.crop_test_pending_ops(input_param, input_shape) + + @parameterized.expand(TESTS_COMBINE) + def test_combine_ops(self, funcs, input_shape): + self.crop_test_combine_ops(funcs, input_shape) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index fd905a6dae..69d2e5af5d 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -9,12 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import RandSpatialCropSamples +from monai.transforms.lazy.functional import apply_transforms from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -95,6 +99,30 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_last_ite self.assertEqual(item.meta["patch_index"], i) assert_allclose(result[-1], expected_last_item, type_test="tensor") + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_pending_ops(self, input_param, input_shape, _expected_shape, _expected_last_item): + input_data = np.arange(192).reshape(*input_shape) + + for p in TEST_NDARRAYS_ALL: + xform = RandSpatialCropSamples(**input_param) + image = p(input_data) + # non-lazy + xform.set_random_state(1234) + expected = xform(image) + self.assertIsInstance(expected[0], MetaTensor) + # lazy + xform.set_random_state(1234) + xform.lazy_evaluation = True + pending_result = xform(image) + for i, _pending_result in enumerate(pending_result): + self.assertIsInstance(_pending_result, MetaTensor) + assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) + assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) + # only support nearest + result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + # compare + assert_allclose(result, expected[i], rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index c860cbea4f..fc6e6c8c43 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -9,12 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import Compose, DivisiblePadd, RandSpatialCropSamplesd +from monai.transforms.lazy.functional import apply_transforms from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_1 = [ @@ -108,6 +112,29 @@ def test_deep_copy(self): for sample in samples: self.assertEqual(len(sample["img"].applied_operations), len(transform)) + @parameterized.expand([TEST_CASE_1, *TEST_CASE_2]) + def test_pending_ops(self, input_param, input_data, _expected_shape, _expected_last): + xform = RandSpatialCropSamplesd(**input_param) + # non-lazy + xform.set_random_state(1234) + expected = xform(input_data) + self.assertIsInstance(expected[0]["img"], MetaTensor) + + # lazy + xform.set_random_state(1234) + xform.lazy_evaluation = True + pending_result = xform(input_data) + for i, _pending_result in enumerate(pending_result): + self.assertIsInstance(_pending_result["img"], MetaTensor) + assert_allclose(_pending_result["img"].peek_pending_affine(), expected[i]["img"].affine) + assert_allclose(_pending_result["img"].peek_pending_shape(), expected[i]["img"].shape[1:]) + # only support nearest + result_img = apply_transforms(_pending_result["img"], mode="nearest", align_corners=False)[0] + result_seg = apply_transforms(_pending_result["seg"], mode="nearest", align_corners=False)[0] + # compare + assert_allclose(result_img, expected[i]["img"], rtol=1e-5) + assert_allclose(result_seg, expected[i]["seg"], rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_spatial_cropd.py b/tests/test_rand_spatial_cropd.py index 1b256959c6..5114a45159 100644 --- a/tests/test_rand_spatial_cropd.py +++ b/tests/test_rand_spatial_cropd.py @@ -9,12 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np from parameterized import parameterized -from monai.transforms import RandSpatialCropd +from monai.data.meta_tensor import MetaTensor +from monai.transforms import RandScaleCropd, RandSpatialCropd +from monai.transforms.lazy.functional import apply_transforms from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -22,6 +26,7 @@ [{"keys": "img", "roi_size": [3, 3, -1], "random_center": True}, (3, 3, 3, 5), (3, 3, 3, 5)], [{"keys": "img", "roi_size": [3, 3, 3], "random_center": True}, (3, 3, 3, 3), (3, 3, 3, 3)], [{"keys": "img", "roi_size": [3, 3, 3], "random_center": False}, (3, 3, 3, 3), (3, 3, 3, 3)], + [{"keys": "img", "roi_size": [3, 2, 3], "random_center": False, "random_size": False}, (3, 3, 3, 3), (3, 3, 2, 3)], ] TEST_VALUES = [ @@ -44,6 +49,15 @@ ], ] +func1 = {RandSpatialCropd: {"keys": "img", "roi_size": [8, 7, -1], "random_center": True, "random_size": False}} +func2 = {RandScaleCropd: {"keys": "img", "roi_scale": [0.5, 0.6, -1.0], "random_center": True, "random_size": True}} +func3 = {RandScaleCropd: {"keys": "img", "roi_scale": [1.0, 0.5, -1.0], "random_center": False, "random_size": False}} + +TESTS_COMBINE = [] +TESTS_COMBINE.append([[func1, func2, func3], (3, 10, 10, 8)]) +TESTS_COMBINE.append([[func1, func2], (3, 8, 8, 4)]) +TESTS_COMBINE.append([[func2, func2], (3, 8, 8, 4)]) + class TestRandSpatialCropd(CropTest): Cropper = RandSpatialCropd @@ -69,8 +83,29 @@ def test_random_shape(self, input_param, input_shape, expected_shape): cropper = self.Cropper(**input_param) cropper.set_random_state(seed=123) input_data = {"img": im_type(np.random.randint(0, 2, input_shape))} - result = cropper(input_data)["img"] - self.assertTupleEqual(result.shape, expected_shape) + expected = cropper(input_data)["img"] + self.assertTupleEqual(expected.shape, expected_shape) + + # lazy + # reset random seed to ensure the same results + cropper.set_random_state(seed=123) + cropper.lazy_evaluation = True + pending_result = cropper(input_data)["img"] + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + # only support nearest + result = apply_transforms(pending_result, mode="nearest", align_corners=False)[0] + # compare + assert_allclose(result, expected, rtol=1e-5) + + @parameterized.expand(TEST_SHAPES) + def test_pending_ops(self, input_param, input_shape, _): + self.crop_test_pending_ops(input_param, input_shape) + + @parameterized.expand(TESTS_COMBINE) + def test_combine_ops(self, funcs, input_shape): + self.crop_test_combine_ops(funcs, input_shape) if __name__ == "__main__": diff --git a/tests/test_rand_std_shift_intensity.py b/tests/test_rand_std_shift_intensity.py index a2345dca1d..535fb7cb20 100644 --- a/tests/test_rand_std_shift_intensity.py +++ b/tests/test_rand_std_shift_intensity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_std_shift_intensityd.py b/tests/test_rand_std_shift_intensityd.py index bbbed053ad..31209ee754 100644 --- a/tests/test_rand_std_shift_intensityd.py +++ b/tests/test_rand_std_shift_intensityd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index 2e5554402d..e279f29f68 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -9,12 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np from parameterized.parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.array import RandWeightedCrop +from monai.transforms.lazy.functional import apply_transforms from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose @@ -165,6 +169,26 @@ def test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, assert_allclose(res, img, type_test="tensor") self.assertEqual(len(res.applied_operations), 1) + @parameterized.expand(TESTS) + def test_pending_ops(self, _, input_param, img, weight, expected_shape, expected_vals): + crop = RandWeightedCrop(**input_param) + # non-lazy + crop.set_random_state(10) + expected = crop(img, weight) + self.assertIsInstance(expected[0], MetaTensor) + # lazy + crop.set_random_state(10) + crop.lazy_evaluation = True + pending_result = crop(img, weight) + for i, _pending_result in enumerate(pending_result): + self.assertIsInstance(_pending_result, MetaTensor) + assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) + assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) + # only support nearest + result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + # compare + assert_allclose(result, expected[i], rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index ee5fa5f083..51e1b15c2c 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -9,13 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.dictionary import RandWeightedCropd -from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D +from monai.transforms.lazy.functional import apply_transforms +from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose def get_data(ndim): @@ -153,6 +157,26 @@ def test_rand_weighted_cropd(self, _, init_params, input_data, expected_shape, e _len = len(tuple(input_data.keys())) self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) + @parameterized.expand(TESTS) + def test_pending_ops(self, _, input_param, input_data, expected_shape, expected_centers): + crop = RandWeightedCropd(**input_param) + # non-lazy + crop.set_random_state(10) + expected = crop(input_data) + self.assertIsInstance(expected[0]["img"], MetaTensor) + # lazy + crop.set_random_state(10) + crop.lazy_evaluation = True + pending_result = crop(input_data) + for i, _pending_result in enumerate(pending_result): + self.assertIsInstance(_pending_result["img"], MetaTensor) + assert_allclose(_pending_result["img"].peek_pending_affine(), expected[i]["img"].affine) + assert_allclose(_pending_result["img"].peek_pending_shape(), expected[i]["img"].shape[1:]) + # only support nearest + result = apply_transforms(_pending_result["img"], mode="nearest", align_corners=False)[0] + # compare + assert_allclose(result, expected[i]["img"], rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py index b34d3b0419..76d05da5e3 100644 --- a/tests/test_rand_zoom.py +++ b/tests/test_rand_zoom.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -18,19 +20,41 @@ from monai.transforms import RandZoom from monai.utils import InterpolateMode +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion -VALID_CASES = [(0.8, 1.2, "nearest", False), (0.8, 1.2, InterpolateMode.NEAREST, False)] +VALID_CASES = [ + (0.8, 1.2, "nearest", False), + (0.8, 1.2, InterpolateMode.NEAREST, False), + (0.8, 1.2, InterpolateMode.BILINEAR, False, True), + (0.8, 1.2, InterpolateMode.BILINEAR, False, False), +] class TestRandZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) - def test_correct_results(self, min_zoom, max_zoom, mode, keep_size): + def test_correct_results(self, min_zoom, max_zoom, mode, keep_size, align_corners=None): for p in TEST_NDARRAYS_ALL: - random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode, keep_size=keep_size) + init_param = { + "prob": 1.0, + "min_zoom": min_zoom, + "max_zoom": max_zoom, + "mode": mode, + "keep_size": keep_size, + "dtype": torch.float64, + "align_corners": align_corners, + } + random_zoom = RandZoom(**init_param) random_zoom.set_random_state(1234) im = p(self.imt[0]) - zoomed = random_zoom(im) + call_param = {"img": im} + zoomed = random_zoom(**call_param) + + # test lazy + # TODO: temporarily skip "nearest" test + if mode == InterpolateMode.BILINEAR: + test_resampler_lazy(random_zoom, zoomed, init_param, call_param, seed=1234) + test_local_inversion(random_zoom, zoomed, im) expected = [ zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False) diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py index 3a067e8a90..367c99a3e8 100644 --- a/tests/test_rand_zoomd.py +++ b/tests/test_rand_zoomd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -17,29 +19,44 @@ from scipy.ndimage import zoom as zoom_scipy from monai.transforms import RandZoomd +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion -VALID_CASES = [(0.8, 1.2, "nearest", None, False)] +VALID_CASES = [ + (0.8, 1.2, "nearest", None, False), + (0.8, 1.2, "bilinear", None, False), + (0.8, 1.2, "bilinear", False, False), +] class TestRandZoomd(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_size): key = "img" - random_zoom = RandZoomd( - key, - prob=1.0, - min_zoom=min_zoom, - max_zoom=max_zoom, - mode=mode, - align_corners=align_corners, - keep_size=keep_size, - ) + init_param = { + "keys": key, + "prob": 1.0, + "min_zoom": min_zoom, + "max_zoom": max_zoom, + "mode": mode, + "align_corners": align_corners, + "keep_size": keep_size, + "dtype": torch.float64, + } + random_zoom = RandZoomd(**init_param) for p in TEST_NDARRAYS_ALL: random_zoom.set_random_state(1234) im = p(self.imt[0]) - zoomed = random_zoom({key: im}) + call_param = {"data": {key: im}} + zoomed = random_zoom(**call_param) + + # test lazy + # TODO: temporarily skip "nearest" test + if mode == "bilinear": + test_resampler_lazy(random_zoom, zoomed, init_param, call_param, key, seed=1234) + random_zoom.lazy_evaluation = False + test_local_inversion(random_zoom, zoomed, {key: im}, key) expected = [ zoom_scipy(channel, zoom=random_zoom.rand_zoom._zoom, mode="nearest", order=0, prefilter=False) diff --git a/tests/test_randidentity.py b/tests/test_randidentity.py new file mode 100644 index 0000000000..09dc055b4e --- /dev/null +++ b/tests/test_randidentity.py @@ -0,0 +1,49 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import monai.transforms as mt +from monai.data import CacheDataset +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose + + +class T(mt.Transform): + def __call__(self, x): + return x * 2 + + +class TestIdentity(NumpyImageTestCase2D): + def test_identity(self): + for p in TEST_NDARRAYS: + img = p(self.imt) + identity = mt.RandIdentity() + assert_allclose(img, identity(img)) + + def test_caching(self, init=1, expect=4, expect_pre_cache=2): + # check that we get the correct result (two lots of T so should get 4) + x = init + transforms = mt.Compose([T(), mt.RandIdentity(), T()]) + self.assertEqual(transforms(x), expect) + + # check we get correct result with CacheDataset + x = [init] + ds = CacheDataset(x, transforms) + self.assertEqual(ds[0], expect) + + # check that the cached value is correct + self.assertEqual(ds._cache[0], expect_pre_cache) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_random_order.py b/tests/test_random_order.py index 0ed46fbb0a..a60202dd78 100644 --- a/tests/test_random_order.py +++ b/tests/test_random_order.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_randomizable.py b/tests/test_randomizable.py index 7445287a12..96854a6db8 100644 --- a/tests/test_randomizable.py +++ b/tests/test_randomizable.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_randomizable_transform_type.py b/tests/test_randomizable_transform_type.py index 9f77d2cd5a..3a0995be68 100644 --- a/tests/test_randomizable_transform_type.py +++ b/tests/test_randomizable_transform_type.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.transforms.transform import RandomizableTrait, RandomizableTransform diff --git a/tests/test_randtorchvisiond.py b/tests/test_randtorchvisiond.py index d7d37a7abc..82f9adf473 100644 --- a/tests/test_randtorchvisiond.py +++ b/tests/test_randtorchvisiond.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_recon_net_utils.py b/tests/test_recon_net_utils.py index 6621bf735e..38adb9617b 100644 --- a/tests/test_recon_net_utils.py +++ b/tests/test_recon_net_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_reference_based_normalize_intensity.py b/tests/test_reference_based_normalize_intensity.py index 01811e5907..8d2715f983 100644 --- a/tests/test_reference_based_normalize_intensity.py +++ b/tests/test_reference_based_normalize_intensity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_reference_based_spatial_cropd.py b/tests/test_reference_based_spatial_cropd.py index ab5573044d..d5777482c0 100644 --- a/tests/test_reference_based_spatial_cropd.py +++ b/tests/test_reference_based_spatial_cropd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_reference_resolver.py b/tests/test_reference_resolver.py index e6b01c05f4..2ec5425258 100644 --- a/tests/test_reference_resolver.py +++ b/tests/test_reference_resolver.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 822a056879..6cd973c32e 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_regunet.py b/tests/test_regunet.py index 04f971d2eb..04ff60ef30 100644 --- a/tests/test_regunet.py +++ b/tests/test_regunet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_regunet_block.py b/tests/test_regunet_block.py index 3be02ea377..eebe9d8694 100644 --- a/tests/test_regunet_block.py +++ b/tests/test_regunet_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch @@ -53,6 +55,7 @@ "out_channels": 1, "kernel_initializer": "zeros", "activation": "sigmoid", + "mode": "trilinear", }, [(1, 3, 2, 2, 2), (1, 2, 4, 4, 4), (1, 1, 8, 8, 8)], (3, 3, 3), diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py index e4b707ce42..90b1b79b03 100644 --- a/tests/test_remove_repeated_channel.py +++ b/tests/test_remove_repeated_channel.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_remove_repeated_channeld.py b/tests/test_remove_repeated_channeld.py index 9db66a6aa0..6d36d32f6f 100644 --- a/tests/test_remove_repeated_channeld.py +++ b/tests/test_remove_repeated_channeld.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_remove_small_objects.py b/tests/test_remove_small_objects.py index 7130d60739..27dd648e24 100644 --- a/tests/test_remove_small_objects.py +++ b/tests/test_remove_small_objects.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import List, Tuple import numpy as np from parameterized import parameterized @@ -30,7 +31,7 @@ TEST_OUTPUT1 = np.array([[[0, 0, 2, 1, 0], [1, 1, 1, 2, 0], [1, 1, 1, 0, 0]]]) -TESTS: List[Tuple] = [] +TESTS: list[tuple] = [] for dtype in (int, float): for p in TEST_NDARRAYS: TESTS.append((dtype, p, TEST_ZEROS, None)) diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py index 3d74b6479c..0ae5743836 100644 --- a/tests/test_repeat_channel.py +++ b/tests/test_repeat_channel.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_repeat_channeld.py b/tests/test_repeat_channeld.py index a348f3eea9..9f7872135d 100644 --- a/tests/test_repeat_channeld.py +++ b/tests/test_repeat_channeld.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_replace_module.py b/tests/test_replace_module.py index 4cb4443410..cac3fd39e5 100644 --- a/tests/test_replace_module.py +++ b/tests/test_replace_module.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Optional, Type import torch from parameterized import parameterized @@ -37,7 +38,7 @@ def setUp(self): self.total = self.get_num_modules() self.assertGreater(self.num_relus, 0) - def get_num_modules(self, mod: Optional[Type[torch.nn.Module]] = None) -> int: + def get_num_modules(self, mod: type[torch.nn.Module] | None = None) -> int: m = [m for _, m in self.net.named_modules()] if mod is not None: m = [_m for _m in m if isinstance(_m, mod)] diff --git a/tests/test_require_pkg.py b/tests/test_require_pkg.py index dbe63fff3f..b1a3d82a17 100644 --- a/tests/test_require_pkg.py +++ b/tests/test_require_pkg.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.utils import OptionalImportError, min_version, require_pkg diff --git a/tests/test_resample.py b/tests/test_resample.py index 0136552334..2df1b7a3ff 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch @@ -26,13 +28,13 @@ def rotate_90_2d(): return t -RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[2, 5, 8], [1, 4, 7], [0, 3, 6]])] +RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[0, 3, 6], [0, 3, 6], [0, 3, 6]])] class TestResampleFunction(unittest.TestCase): @parameterized.expand(RESAMPLE_FUNCTION_CASES) def test_resample_function_impl(self, img, matrix, expected): - out = resample(convert_to_tensor(img), matrix) + out = resample(convert_to_tensor(img), matrix, img.shape[1:], {"lazy_padding_mode": "border"}) assert_allclose(out[0], expected, type_test=False) diff --git a/tests/test_resample_backends.py b/tests/test_resample_backends.py index 6d231183a9..97ee0731e8 100644 --- a/tests/test_resample_backends.py +++ b/tests/test_resample_backends.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_resample_datalist.py b/tests/test_resample_datalist.py index fa120b261e..ae52492953 100644 --- a/tests/test_resample_datalist.py +++ b/tests/test_resample_datalist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py index 30df565a26..d27897d1a3 100644 --- a/tests/test_resample_to_match.py +++ b/tests/test_resample_to_match.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import itertools import os import random @@ -19,12 +21,15 @@ import nibabel as nib import numpy as np +import torch from parameterized import parameterized +from monai.data import MetaTensor from monai.data.image_reader import ITKReader, NibabelReader from monai.data.image_writer import ITKWriter -from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImaged +from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImage, SaveImaged from monai.utils import optional_import +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config _, has_itk = optional_import("itk", allow_namespace_pkg=True) @@ -63,8 +68,14 @@ def tearDownClass(cls): def test_correct(self, reader, writer): loader = Compose([LoadImaged(("im1", "im2"), reader=reader), EnsureChannelFirstd(("im1", "im2"))]) data = loader({"im1": self.fnames[0], "im2": self.fnames[1]}) + tr = ResampleToMatch() + im_mod = tr(data["im2"], data["im1"]) + + # check lazy resample + tr_lazy = ResampleToMatch() + call_param = {"img": data["im2"], "img_dst": data["im1"]} + test_resampler_lazy(tr_lazy, im_mod, init_param={}, call_param=call_param) - im_mod = ResampleToMatch()(data["im2"], data["im1"]) saver = SaveImaged( "im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False, writer=writer, resample=False ) @@ -88,6 +99,13 @@ def test_inverse(self): self.assertLess(((im_mod2.affine - data["im2"].affine) ** 2).sum() ** 0.5, 1e-2) self.assertEqual(im_mod2.applied_operations, []) + def test_no_name(self): + img_1 = MetaTensor(torch.zeros(1, 2, 2, 2)) + img_2 = MetaTensor(torch.zeros(1, 3, 3, 3)) + im_mod = ResampleToMatch()(img_1, img_2) + self.assertEqual(im_mod.meta["filename_or_obj"], "resample_to_match_source") + SaveImage(output_dir=self.tmpdir, output_postfix="", separate_folder=False, resample=False)(im_mod) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_resample_to_matchd.py b/tests/test_resample_to_matchd.py index 566ef4ada9..748e830bdd 100644 --- a/tests/test_resample_to_matchd.py +++ b/tests/test_resample_to_matchd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile @@ -24,6 +26,7 @@ ResampleToMatchd, SaveImaged, ) +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config @@ -74,6 +77,17 @@ def test_correct(self): data = Invertd("im3", transforms)(data) assert_allclose(data["im2"].shape, data["im3"].shape) + def test_lazy(self): + pre_transforms = Compose( + [LoadImaged(("im1", "im2")), EnsureChannelFirstd(("im1", "im2")), CopyItemsd(("im2"), names=("im3"))] + ) + data = pre_transforms({"im1": self.fnames[0], "im2": self.fnames[1]}) + init_param = {"keys": "im3", "key_dst": "im1"} + resampler = ResampleToMatchd(**init_param) + call_param = {"data": data} + non_lazy_out = resampler(**call_param) + test_resampler_lazy(resampler, non_lazy_out, init_param, call_param, output_key="im3") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_resampler.py b/tests/test_resampler.py index 5c8ef24c0e..50ea344090 100644 --- a/tests/test_resampler.py +++ b/tests/test_resampler.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -52,17 +54,17 @@ ), ] ) - TESTS.append( - [ - dict(padding_mode="reflection", device=device), - {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2))), "mode": "nearest"}, - q( - np.array( - [[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]] - ) - ), - ] - ) + # TESTS.append( # not well defined nearest + reflection resampling + # [ + # dict(padding_mode="reflection", device=device), + # {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2))), "mode": "nearest"}, + # q( + # np.array( + # [[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]] + # ) + # ), + # ] + # ) TESTS.append( [ dict(padding_mode="zeros", device=device), diff --git a/tests/test_resize.py b/tests/test_resize.py index b755bb3faf..4925c441de 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -18,6 +20,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Resize +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, is_tf32_env, pytorch_after TEST_CASE_0 = [{"spatial_size": 15}, (6, 10, 15)] @@ -43,6 +46,15 @@ def test_invalid_inputs(self): resize = Resize(spatial_size=(128,), mode="order") resize(self.imt[0]) + def test_unchange(self): + resize = Resize(spatial_size=(128, 64), mode="bilinear") + set_track_meta(False) + for p in TEST_NDARRAYS_ALL: + im = p(self.imt[0]) + result = resize(im) + assert_allclose(im, result, type_test=False) + set_track_meta(True) + @parameterized.expand( [ ((32, -1), "area", True), @@ -50,12 +62,14 @@ def test_invalid_inputs(self): ((32, 32, 32), "trilinear", True), ((256, 256), "bilinear", False), ((256, 256), "nearest-exact" if pytorch_after(1, 11) else "nearest", False), + ((128, 128), "nearest", False), ((128, 64), "area", True), # already in a good shape ] ) def test_correct_results(self, spatial_size, mode, anti_aliasing): """resize 'spatial_size' and 'mode'""" - resize = Resize(spatial_size, mode=mode, anti_aliasing=anti_aliasing) + init_param = {"spatial_size": spatial_size, "mode": mode, "anti_aliasing": anti_aliasing, "dtype": np.float64} + resize = Resize(**init_param) _order = 0 if mode.endswith("linear"): _order = 1 @@ -72,7 +86,10 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing): expected = np.stack(expected).astype(np.float32) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - out = resize(im) + call_param = {"img": im} + out = resize(**call_param) + if init_param["mode"] in ("bilinear", "nearest") and anti_aliasing is False: + test_resampler_lazy(resize, out, init_param, call_param) if isinstance(im, MetaTensor): im_inv = resize.inverse(out) self.assertTrue(not im_inv.applied_operations) diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 4e097cd3d4..5529ec698a 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -17,32 +19,37 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import ResizeWithPadOrCrop -from tests.utils import TEST_NDARRAYS_ALL, pytorch_after +from monai.transforms.lazy.functional import apply_transforms +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose, pytorch_after TEST_CASES = [ - [{"spatial_size": [15, 8, 8], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 8, 8)], - [ - {"spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, - (3, 8, 8, 4), - (3, 15, 4, 8), - ], - [{"spatial_size": [15, 4, -1], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 4, 4)], + [{"spatial_size": [15, 8, 8], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 8, 8), True], + [{"spatial_size": [15, 4, -1], "mode": "constant"}, (3, 8, 8, 4), (3, 15, 4, 4), True], [ {"spatial_size": [15, 4, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, (3, 8, 8, 4), (3, 15, 4, 4), + True, ], [ {"spatial_size": [-1, -1, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, (3, 8, 8, 4), (3, 8, 8, 4), + True, + ], + [ + {"spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, + (3, 8, 8, 4), + (3, 15, 4, 8), + True, ], ] +TESTS_PENDING_MODE = {"constant": "zeros", "edge": "border", "reflect": "reflection"} class TestResizeWithPadOrCrop(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_pad_shape(self, input_param, input_shape, expected_shape): + def test_pad_shape(self, input_param, input_shape, expected_shape, _): for p in TEST_NDARRAYS_ALL: if isinstance(p(0), torch.Tensor) and ( "constant_values" in input_param or input_param["mode"] == "reflect" @@ -60,6 +67,33 @@ def test_pad_shape(self, input_param, input_shape, expected_shape): self.assertIsInstance(inv, MetaTensor) self.assertEqual(inv.applied_operations, []) + @parameterized.expand(TEST_CASES) + def test_pending_ops(self, input_param, input_shape, _expected_data, align_corners): + for p in TEST_NDARRAYS_ALL: + # grid sample only support constant value to be zero + if "constant_values" in input_param and input_param["constant_values"] != 0: + continue + padcropper = ResizeWithPadOrCrop(**input_param) + image = p(np.zeros(input_shape)) + # non-lazy + expected = padcropper(image) + self.assertIsInstance(expected, MetaTensor) + # lazy + padcropper.lazy_evaluation = True + pending_result = padcropper(image) + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + # only support nearest + result = apply_transforms( + pending_result, + mode="nearest", + padding_mode=TESTS_PENDING_MODE[input_param["mode"]], + align_corners=align_corners, + )[0] + # compare + assert_allclose(result, expected, rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index eb4e5f09cc..a71652375b 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -9,22 +9,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest +from copy import deepcopy import numpy as np import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import ResizeWithPadOrCropd -from tests.utils import TEST_NDARRAYS_ALL, pytorch_after +from monai.transforms.lazy.functional import apply_transforms +from tests.test_resize_with_pad_or_crop import TESTS_PENDING_MODE +from tests.utils import TEST_NDARRAYS_ALL, assert_allclose, pytorch_after TEST_CASES = [ [{"keys": "img", "spatial_size": [15, 8, 8], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 8, 8)], - [ - {"keys": "img", "spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, - {"img": np.zeros((3, 8, 8, 4))}, - (3, 15, 4, 8), - ], [{"keys": "img", "spatial_size": [15, 4, -1], "mode": "constant"}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)], [ {"keys": "img", "spatial_size": [15, 4, -1], "mode": "reflect" if pytorch_after(1, 11) else "constant"}, @@ -36,6 +37,11 @@ {"img": np.zeros((3, 8, 8, 4))}, (3, 8, 8, 4), ], + [ + {"keys": "img", "spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, + {"img": np.zeros((3, 8, 8, 4))}, + (3, 15, 4, 8), + ], ] @@ -48,12 +54,37 @@ def test_pad_shape(self, input_param, input_data, expected_val): ): continue padcropper = ResizeWithPadOrCropd(**input_param) - input_data["img"] = p(input_data["img"]) - result = padcropper(input_data) + input_data_ = deepcopy(input_data) + input_data_["img"] = p(input_data["img"]) + result = padcropper(input_data_) np.testing.assert_allclose(result["img"].shape, expected_val) inv = padcropper.inverse(result) - for k in input_data: - self.assertTupleEqual(inv[k].shape, input_data[k].shape) + for k in input_data_: + self.assertTupleEqual(inv[k].shape, input_data_[k].shape) + + @parameterized.expand(TEST_CASES) + def test_pending_ops(self, input_param, input_data, _expected_data): + for p in TEST_NDARRAYS_ALL: + # grid sample only support constant value to be zero + if "constant_values" in input_param and input_param["constant_values"] != 0: + continue + padcropper = ResizeWithPadOrCropd(**input_param) + input_data["img"] = p(input_data["img"]) + # non-lazy + expected = padcropper(input_data)["img"] + self.assertIsInstance(expected, MetaTensor) + # lazy + padcropper.lazy_evaluation = True + pending_result = padcropper(input_data)["img"] + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + # only support nearest + result = apply_transforms( + pending_result, mode="nearest", padding_mode=TESTS_PENDING_MODE[input_param["mode"]], align_corners=True + )[0] + # compare + assert_allclose(result, expected, rtol=1e-5) if __name__ == "__main__": diff --git a/tests/test_resized.py b/tests/test_resized.py index a9da604b15..bd711b33d8 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -18,6 +20,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Invertd, Resize, Resized +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 10, 15)] @@ -49,6 +52,9 @@ ((64, 64), "area", True), ((32, 32, 32), "area", True), ((256, 256), "bilinear", False), + ((256, 256), "bilinear", True), + ((128, 128), "nearest", False), + ((128, 128), "nearest", True), ] @@ -62,9 +68,25 @@ def test_invalid_inputs(self): resize = Resized(keys="img", spatial_size=(128,), mode="order") resize({"img": self.imt[0]}) + def test_unchange(self): + resize = Resized(keys="img", spatial_size=(128, 64), mode="bilinear") + set_track_meta(False) + for p in TEST_NDARRAYS_ALL: + im = p(self.imt[0]) + result = resize({"img": im})["img"] + assert_allclose(im, result, type_test=False) + set_track_meta(True) + @parameterized.expand(TEST_CORRECT_CASES) def test_correct_results(self, spatial_size, mode, anti_aliasing): - resize = Resized("img", spatial_size, mode=mode, anti_aliasing=anti_aliasing) + init_param = { + "keys": "img", + "spatial_size": spatial_size, + "mode": mode, + "anti_aliasing": anti_aliasing, + "dtype": np.float32, + } + resize = Resized(**init_param) _order = 0 if mode.endswith("linear"): _order = 1 @@ -80,7 +102,11 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing): expected = np.stack(expected).astype(np.float32) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - out = resize({"img": im}) + call_param = {"data": {"img": im}} + out = resize(**call_param) + lazy_resize = Resized(**init_param) + if init_param["mode"] in ("bilinear", "nearest"): + test_resampler_lazy(lazy_resize, out, init_param, call_param, output_key="img", atol=1e-5) test_local_inversion(resize, out, {"img": im}, "img") assert_allclose(out["img"], expected, type_test=False, atol=1.0) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index b09b97a450..cc24106373 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from typing import TYPE_CHECKING diff --git a/tests/test_retinanet.py b/tests/test_retinanet.py index f067e82962..3925cca2f4 100644 --- a/tests/test_retinanet.py +++ b/tests/test_retinanet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_retinanet_detector.py b/tests/test_retinanet_detector.py index 243828432d..a5a4001f5c 100644 --- a/tests/test_retinanet_detector.py +++ b/tests/test_retinanet_detector.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import random import unittest diff --git a/tests/test_retinanet_predict_utils.py b/tests/test_retinanet_predict_utils.py index 5157691696..d97806e91c 100644 --- a/tests/test_retinanet_predict_utils.py +++ b/tests/test_retinanet_predict_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_rotate.py b/tests/test_rotate.py index d039738b21..6ecdfa6182 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import List, Tuple import numpy as np import scipy.ndimage @@ -19,9 +20,10 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Rotate +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion -TEST_CASES_2D: List[Tuple] = [] +TEST_CASES_2D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: TEST_CASES_2D.append((p, np.pi / 6, False, "bilinear", "border", False)) TEST_CASES_2D.append((p, np.pi / 4, True, "bilinear", "border", False)) @@ -29,7 +31,7 @@ TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", False)) TEST_CASES_2D.append((p, -np.pi / 2, False, "bilinear", "zeros", True)) -TEST_CASES_3D: List[Tuple] = [] +TEST_CASES_3D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: TEST_CASES_3D.append((p, -np.pi / 2, True, "nearest", "border", False)) TEST_CASES_3D.append((p, np.pi / 4, True, "bilinear", "border", False)) @@ -37,7 +39,7 @@ TEST_CASES_3D.append((p, np.pi, False, "nearest", "zeros", False)) TEST_CASES_3D.append((p, -np.pi / 2, False, "bilinear", "zeros", False)) -TEST_CASES_SHAPE_3D: List[Tuple] = [] +TEST_CASES_SHAPE_3D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: TEST_CASES_SHAPE_3D.append((p, [-np.pi / 2, 1.0, 2.0], "nearest", "border", False)) TEST_CASES_SHAPE_3D.append((p, [np.pi / 4, 0, 0], "bilinear", "border", False)) @@ -47,8 +49,18 @@ class TestRotate2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotate(angle, keep_size, mode, padding_mode, align_corners, dtype=np.float64) - rotated = rotate_fn(im_type(self.imt[0])) + init_param = { + "angle": angle, + "keep_size": keep_size, + "mode": mode, + "padding_mode": padding_mode, + "align_corners": align_corners, + "dtype": np.float64, + } + rotate_fn = Rotate(**init_param) + call_param = {"img": im_type(self.imt[0])} + rotated = rotate_fn(**call_param) + test_resampler_lazy(rotate_fn, rotated, init_param, call_param) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -75,8 +87,18 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al class TestRotate3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotate([angle, 0, 0], keep_size, mode, padding_mode, align_corners, dtype=np.float64) - rotated = rotate_fn(im_type(self.imt[0])) + init_param = { + "angle": [angle, 0, 0], + "keep_size": keep_size, + "mode": mode, + "padding_mode": padding_mode, + "align_corners": align_corners, + "dtype": np.float64, + } + rotate_fn = Rotate(**init_param) + call_param = {"img": im_type(self.imt[0])} + rotated = rotate_fn(**call_param) + test_resampler_lazy(rotate_fn, rotated, init_param, call_param) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index 69414430c2..fd54e7639f 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -9,12 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np +from parameterized import parameterized from monai.data import MetaTensor, set_track_meta -from monai.transforms import Rotate90 +from monai.transforms import Affine, Rotate90 +from monai.transforms.lazy.functional import apply_transforms +from monai.utils import optional_import +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import ( TEST_NDARRAYS_ALL, NumpyImageTestCase2D, @@ -30,7 +36,13 @@ def test_rotate90_default(self): for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) set_track_meta(True) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -44,7 +56,13 @@ def test_k(self): rotate = Rotate90(k=2) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -54,7 +72,13 @@ def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -64,8 +88,13 @@ def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) + call_param = {"img": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param) + rotate.lazy_evaluation = False - rotated = rotate(im) test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -77,7 +106,13 @@ def test_rotate90_default(self): rotate = Rotate90() for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -87,7 +122,13 @@ def test_k(self): rotate = Rotate90(k=2) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -97,7 +138,13 @@ def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -107,12 +154,59 @@ def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - rotated = rotate(im) + call_param = {"img": im} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test="tensor") +@unittest.skipUnless(optional_import("scipy")[1], "Requires scipy library.") +class TestRot90Consistency(unittest.TestCase): + @parameterized.expand([[2], [3], [4]]) + def test_affine_rot90(self, s): + """s""" + im = np.arange(int(s * s)).reshape(1, s, s).astype(float) + mat = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + + def method_0(im, ac): + xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=s) + xform.lazy_evaluation = True + out = xform(im) + out = apply_transforms(out, padding_mode="border", align_corners=ac)[0] + return out + + def method_1(im, ac): + xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=s) + xform.lazy_evaluation = True + out = xform(im) + out = apply_transforms(out, mode=1, padding_mode="nearest", align_corners=ac)[0] + return out + + def method_2(im, ac): + xform = Affine(align_corners=ac, affine=mat, padding_mode="border", image_only=True, spatial_size=s) + out = xform(im) + return out + + def method_3(im, ac): + xform = Affine( + align_corners=ac, affine=mat, mode=1, padding_mode="nearest", image_only=True, spatial_size=s + ) + out = xform(im) + return out + + for call in (method_0, method_1, method_2, method_3): + for ac in (False, True): + out = call(im, ac) + ref = Rotate90()(im) + assert_allclose(out, ref, rtol=1e-4, atol=1e-4, type_test=False) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py index f88e8937e8..95d475d480 100644 --- a/tests/test_rotate90d.py +++ b/tests/test_rotate90d.py @@ -9,12 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np from monai.data import MetaTensor, set_track_meta from monai.transforms import Rotate90d +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion @@ -25,7 +28,13 @@ def test_rotate90_default(self): for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) set_track_meta(True) - rotated = rotate({key: im}) + call_param = {"data": {key: im}} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, {key: im}, key) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -36,11 +45,17 @@ def test_rotate90_default(self): set_track_meta(True) def test_k(self): - key = None + key = "test" rotate = Rotate90d(keys=key, k=2) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - rotated = rotate({key: im}) + call_param = {"data": {key: im}} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, {key: im}, key) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -51,7 +66,13 @@ def test_spatial_axes(self): rotate = Rotate90d(keys=key, spatial_axes=(0, 1)) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - rotated = rotate({key: im}) + call_param = {"data": {key: im}} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, {key: im}, key) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -62,7 +83,13 @@ def test_prob_k_spatial_axes(self): rotate = Rotate90d(keys=key, k=2, spatial_axes=(0, 1)) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - rotated = rotate({key: im}) + call_param = {"data": {key: im}} + rotated = rotate(**call_param) + + # test lazy + test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key) + rotate.lazy_evaluation = False + test_local_inversion(rotate, rotated, {key: im}, key) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) diff --git a/tests/test_rotated.py b/tests/test_rotated.py index 48b2e8a3c7..5c51594c6c 100644 --- a/tests/test_rotated.py +++ b/tests/test_rotated.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import List, Tuple import numpy as np import scipy.ndimage @@ -19,9 +20,10 @@ from monai.data import MetaTensor from monai.transforms import Rotated +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion -TEST_CASES_2D: List[Tuple] = [] +TEST_CASES_2D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: TEST_CASES_2D.append((p, -np.pi / 6, False, "bilinear", "border", False)) TEST_CASES_2D.append((p, -np.pi / 4, True, "bilinear", "border", False)) @@ -29,7 +31,7 @@ TEST_CASES_2D.append((p, -np.pi, False, "nearest", "zeros", False)) TEST_CASES_2D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) -TEST_CASES_3D: List[Tuple] = [] +TEST_CASES_3D: list[tuple] = [] for p in TEST_NDARRAYS_ALL: TEST_CASES_3D.append((p, -np.pi / 6, False, "bilinear", "border", False)) TEST_CASES_3D.append((p, -np.pi / 4, True, "bilinear", "border", False)) @@ -41,11 +43,24 @@ class TestRotated2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotated( - ("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners, dtype=np.float64 - ) + init_param = { + "keys": ("img", "seg"), + "angle": angle, + "keep_size": keep_size, + "mode": (mode, "nearest"), + "padding_mode": padding_mode, + "align_corners": align_corners, + "dtype": np.float64, + } + rotate_fn = Rotated(**init_param) im = im_type(self.imt[0]) - rotated = rotate_fn({"img": im, "seg": im_type(self.segn[0])}) + call_param = {"data": {"img": im, "seg": im_type(self.segn[0])}} + rotated = rotate_fn(**call_param) + # test lazy + lazy_init_param = init_param.copy() + for k, m in zip(init_param["keys"], init_param["mode"]): + lazy_init_param["keys"], lazy_init_param["mode"] = k, m + test_resampler_lazy(rotate_fn, rotated, lazy_init_param, call_param, output_key=k) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -76,10 +91,23 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al class TestRotated3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotated( - ("img", "seg"), [0, angle, 0], keep_size, (mode, "nearest"), padding_mode, align_corners, dtype=np.float64 - ) - rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) + init_param = { + "keys": ("img", "seg"), + "angle": [0, angle, 0], + "keep_size": keep_size, + "mode": (mode, "nearest"), + "padding_mode": padding_mode, + "align_corners": align_corners, + "dtype": np.float64, + } + rotate_fn = Rotated(**init_param) + call_param = {"data": {"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}} + rotated = rotate_fn(**call_param) + # test lazy + lazy_init_param = init_param.copy() + for k, m in zip(init_param["keys"], init_param["mode"]): + lazy_init_param["keys"], lazy_init_param["mode"] = k, m + test_resampler_lazy(rotate_fn, rotated, lazy_init_param, call_param, output_key=k) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 diff --git a/tests/test_safe_dtype_range.py b/tests/test_safe_dtype_range.py index 2f0b1bcefc..73f9607d7d 100644 --- a/tests/test_safe_dtype_range.py +++ b/tests/test_safe_dtype_range.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import List, Tuple import numpy as np import torch @@ -22,7 +23,7 @@ cp, _ = optional_import("cupy") -TESTS: List[Tuple] = [] +TESTS: list[tuple] = [] for in_type in TEST_NDARRAYS_ALL + (int, float): TESTS.append((in_type(np.array(1.0)), in_type(np.array(1.0)), None)) # type: ignore if in_type is not float: @@ -31,7 +32,7 @@ for in_type in TEST_NDARRAYS_ALL: TESTS.append((in_type(np.array([[256, 255], [-12, 0]])), in_type(np.array([[255, 255], [0, 0]])), np.uint8)) -TESTS_LIST: List[Tuple] = [] +TESTS_LIST: list[tuple] = [] for in_type in TEST_NDARRAYS_ALL + (int, float): TESTS_LIST.append( ( diff --git a/tests/test_saliency_inferer.py b/tests/test_saliency_inferer.py index c97bcb7811..4efe30d7a6 100644 --- a/tests/test_saliency_inferer.py +++ b/tests/test_saliency_inferer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_sample_slices.py b/tests/test_sample_slices.py index 117d39b486..02b7926392 100644 --- a/tests/test_sample_slices.py +++ b/tests/test_sample_slices.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_sampler_dist.py b/tests/test_sampler_dist.py index c40ede414a..b2f86c54cc 100644 --- a/tests/test_sampler_dist.py +++ b/tests/test_sampler_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_save_classificationd.py b/tests/test_save_classificationd.py index 10c65c2044..dd0b213bd6 100644 --- a/tests/test_save_classificationd.py +++ b/tests/test_save_classificationd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import csv import os import tempfile diff --git a/tests/test_save_image.py b/tests/test_save_image.py index 4f0c7b5a67..ba94ab5087 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py index 4b079b73fd..676eb74678 100644 --- a/tests/test_save_imaged.py +++ b/tests/test_save_imaged.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_save_state.py b/tests/test_save_state.py index c48b12ebdc..8ab7080700 100644 --- a/tests/test_save_state.py +++ b/tests/test_save_state.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_savitzky_golay_filter.py b/tests/test_savitzky_golay_filter.py index 0e54276533..b7f89cdfde 100644 --- a/tests/test_savitzky_golay_filter.py +++ b/tests/test_savitzky_golay_filter.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index b296372986..6da4f24c62 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_savitzky_golay_smoothd.py b/tests/test_savitzky_golay_smoothd.py index 730fdeeef2..7e7176e2bb 100644 --- a/tests/test_savitzky_golay_smoothd.py +++ b/tests/test_savitzky_golay_smoothd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index 9941172b0e..57a7da1780 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_scale_intensity_range.py b/tests/test_scale_intensity_range.py index 958881f790..898f4dfb45 100644 --- a/tests/test_scale_intensity_range.py +++ b/tests/test_scale_intensity_range.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 181243b0fe..583dcec07e 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index b35f937b95..8e2511d9e4 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_scale_intensity_ranged.py b/tests/test_scale_intensity_ranged.py index 4ac4910e37..724acf1c73 100644 --- a/tests/test_scale_intensity_ranged.py +++ b/tests/test_scale_intensity_ranged.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.transforms import ScaleIntensityRanged diff --git a/tests/test_scale_intensityd.py b/tests/test_scale_intensityd.py index d560523214..6705cfda9d 100644 --- a/tests/test_scale_intensityd.py +++ b/tests/test_scale_intensityd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_se_block.py b/tests/test_se_block.py index 88983a7746..de129f4d55 100644 --- a/tests/test_se_block.py +++ b/tests/test_se_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_se_blocks.py b/tests/test_se_blocks.py index 400ee85e7f..c97e459f50 100644 --- a/tests/test_se_blocks.py +++ b/tests/test_se_blocks.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_seg_loss_integration.py b/tests/test_seg_loss_integration.py index f4a0f25267..23bc63fbf6 100644 --- a/tests/test_seg_loss_integration.py +++ b/tests/test_seg_loss_integration.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_segresnet.py b/tests/test_segresnet.py index b7c37f87b9..cb34445efa 100644 --- a/tests/test_segresnet.py +++ b/tests/test_segresnet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_segresnet_block.py b/tests/test_segresnet_block.py index 9bb435ac1d..dd455aec13 100644 --- a/tests/test_segresnet_block.py +++ b/tests/test_segresnet_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_segresnet_ds.py b/tests/test_segresnet_ds.py index b9a5d873dc..a5b88f9724 100644 --- a/tests/test_segresnet_ds.py +++ b/tests/test_segresnet_ds.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch @@ -79,7 +81,6 @@ def test_shape(self, input_param, input_shape, expected_shape): @parameterized.expand(TEST_CASE_SEGRESNET_DS2) def test_shape2(self, input_param, input_shape, expected_shape): - dsdepth = input_param.get("dsdepth", 1) net = SegResNetDS(**input_param).to(device) @@ -105,7 +106,6 @@ def test_shape2(self, input_param, input_shape, expected_shape): @parameterized.expand(TEST_CASE_SEGRESNET_DS3) def test_shape3(self, input_param, input_shape, expected_shapes): - dsdepth = input_param.get("dsdepth", 1) net = SegResNetDS(**input_param).to(device) diff --git a/tests/test_select_cross_validation_folds.py b/tests/test_select_cross_validation_folds.py index 7693baca80..3ab6c0a9c5 100644 --- a/tests/test_select_cross_validation_folds.py +++ b/tests/test_select_cross_validation_folds.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_select_itemsd.py b/tests/test_select_itemsd.py index ba75a27cff..5eb4a1c51b 100644 --- a/tests/test_select_itemsd.py +++ b/tests/test_select_itemsd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import time import unittest diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 407bee341c..926ef7da55 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from unittest import skipUnless @@ -26,7 +28,6 @@ for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [360, 480, 600, 768]: for num_heads in [4, 6, 8, 12]: - test_case = [ {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate}, (2, 512, hidden_size), diff --git a/tests/test_senet.py b/tests/test_senet.py index b0d8ac0c0a..92b5f39ace 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from typing import TYPE_CHECKING diff --git a/tests/test_separable_filter.py b/tests/test_separable_filter.py index e6838e2f9b..1797a649e0 100644 --- a/tests/test_separable_filter.py +++ b/tests/test_separable_filter.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_set_determinism.py b/tests/test_set_determinism.py index 7d6c54909d..aab7af1079 100644 --- a/tests/test_set_determinism.py +++ b/tests/test_set_determinism.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py index 75cbd6fb0d..53703e107a 100644 --- a/tests/test_set_visible_devices.py +++ b/tests/test_set_visible_devices.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_shift_intensity.py b/tests/test_shift_intensity.py index ecded268ab..f1bc36036e 100644 --- a/tests/test_shift_intensity.py +++ b/tests/test_shift_intensity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_shift_intensityd.py b/tests/test_shift_intensityd.py index b5a2a3218d..e8d163b34a 100644 --- a/tests/test_shift_intensityd.py +++ b/tests/test_shift_intensityd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_shuffle_buffer.py b/tests/test_shuffle_buffer.py index 306bc636d2..9fcd3a23f6 100644 --- a/tests/test_shuffle_buffer.py +++ b/tests/test_shuffle_buffer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import unittest diff --git a/tests/test_signal_continuouswavelet.py b/tests/test_signal_continuouswavelet.py index f8f028aec9..4886168a00 100644 --- a/tests/test_signal_continuouswavelet.py +++ b/tests/test_signal_continuouswavelet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from unittest import skipUnless diff --git a/tests/test_signal_fillempty.py b/tests/test_signal_fillempty.py index 388426bc95..f44e4ba29a 100644 --- a/tests/test_signal_fillempty.py +++ b/tests/test_signal_fillempty.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_signal_rand_add_gaussiannoise.py b/tests/test_signal_rand_add_gaussiannoise.py index dbaf716c4b..2090df876f 100644 --- a/tests/test_signal_rand_add_gaussiannoise.py +++ b/tests/test_signal_rand_add_gaussiannoise.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_signal_rand_add_sine.py b/tests/test_signal_rand_add_sine.py index 5cb63f1496..ae0684d608 100644 --- a/tests/test_signal_rand_add_sine.py +++ b/tests/test_signal_rand_add_sine.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_signal_rand_add_sine_partial.py b/tests/test_signal_rand_add_sine_partial.py index c04e6b138c..109fb006ea 100644 --- a/tests/test_signal_rand_add_sine_partial.py +++ b/tests/test_signal_rand_add_sine_partial.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_signal_rand_add_squarepulse.py b/tests/test_signal_rand_add_squarepulse.py index 6c96f69577..efbdc9af09 100644 --- a/tests/test_signal_rand_add_squarepulse.py +++ b/tests/test_signal_rand_add_squarepulse.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from unittest import skipUnless diff --git a/tests/test_signal_rand_add_squarepulse_partial.py b/tests/test_signal_rand_add_squarepulse_partial.py index dd7aeae793..eee3f5596d 100644 --- a/tests/test_signal_rand_add_squarepulse_partial.py +++ b/tests/test_signal_rand_add_squarepulse_partial.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from unittest import skipUnless diff --git a/tests/test_signal_rand_drop.py b/tests/test_signal_rand_drop.py index 4235ae6d87..5dcd466481 100644 --- a/tests/test_signal_rand_drop.py +++ b/tests/test_signal_rand_drop.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_signal_rand_scale.py b/tests/test_signal_rand_scale.py index 2ac708ef19..126d7cca65 100644 --- a/tests/test_signal_rand_scale.py +++ b/tests/test_signal_rand_scale.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_signal_rand_shift.py b/tests/test_signal_rand_shift.py index 402cd433f8..ed25cc8b1f 100644 --- a/tests/test_signal_rand_shift.py +++ b/tests/test_signal_rand_shift.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from unittest import skipUnless diff --git a/tests/test_signal_remove_frequency.py b/tests/test_signal_remove_frequency.py index fa70c4f795..b18de36c08 100644 --- a/tests/test_signal_remove_frequency.py +++ b/tests/test_signal_remove_frequency.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from unittest import skipUnless diff --git a/tests/test_simple_aspp.py b/tests/test_simple_aspp.py index 9c952bd791..f18b208e9c 100644 --- a/tests/test_simple_aspp.py +++ b/tests/test_simple_aspp.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_simulatedelay.py b/tests/test_simulatedelay.py index 3a0507dae7..5cf47b245e 100644 --- a/tests/test_simulatedelay.py +++ b/tests/test_simulatedelay.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import time import unittest diff --git a/tests/test_simulatedelayd.py b/tests/test_simulatedelayd.py index cbabb68e0f..827fe69510 100644 --- a/tests/test_simulatedelayd.py +++ b/tests/test_simulatedelayd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import time import unittest diff --git a/tests/test_skip_connection.py b/tests/test_skip_connection.py index f523891084..0ac8ef0d7a 100644 --- a/tests/test_skip_connection.py +++ b/tests/test_skip_connection.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_slice_inferer.py b/tests/test_slice_inferer.py index 0f33385f42..4d7dea026f 100644 --- a/tests/test_slice_inferer.py +++ b/tests/test_slice_inferer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_sliding_patch_wsi_dataset.py b/tests/test_sliding_patch_wsi_dataset.py index 06395cf26c..e6d11de739 100644 --- a/tests/test_sliding_patch_wsi_dataset.py +++ b/tests/test_sliding_patch_wsi_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from unittest import skipUnless diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py index 0106f1e5b4..0dc2216c22 100644 --- a/tests/test_sliding_window_hovernet_inference.py +++ b/tests/test_sliding_window_hovernet_inference.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index 519daee6f5..5f07084927 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import itertools import unittest @@ -81,10 +83,12 @@ def test_default_device(self, data_type): def compute(data): return data + 1 + inputs.requires_grad = True result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute) + self.assertTrue(result.requires_grad) np.testing.assert_string_equal(inputs.device.type, result.device.type) expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1 - np.testing.assert_allclose(result.cpu().numpy(), expected_val) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val) @parameterized.expand(list(itertools.product(TEST_TORCH_AND_META_TENSORS, ("cpu", "cuda"), ("cpu", "cuda", None)))) @skip_if_no_cuda diff --git a/tests/test_sliding_window_splitter.py b/tests/test_sliding_window_splitter.py new file mode 100644 index 0000000000..53192a9720 --- /dev/null +++ b/tests/test_sliding_window_splitter.py @@ -0,0 +1,244 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized +from torch.nn.functional import pad + +from monai.inferers import SlidingWindowSplitter +from tests.utils import assert_allclose + +# random int tensor (0, 255) +TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32) + +# random int tensor (0, 255) with artifacts at [..., :2, 2:] +TENSOR_4x4_artifact = TENSOR_4x4.clone() +TENSOR_4x4_artifact[..., :2, 2:] = 512.0 + +# ---------------------------------------------------------------------------- +# Primary use test cases +# ---------------------------------------------------------------------------- +# no-overlapping 2x2 +TEST_CASE_0 = [ + TENSOR_4x4, + {"patch_size": (2, 2), "overlap": 0.0}, + [ + (TENSOR_4x4[..., :2, :2], (0, 0)), + (TENSOR_4x4[..., :2, 2:], (0, 2)), + (TENSOR_4x4[..., 2:, :2], (2, 0)), + (TENSOR_4x4[..., 2:, 2:], (2, 2)), + ], +] + +# no-overlapping 3x3 with pad +TEST_CASE_1 = [ + TENSOR_4x4, + {"patch_size": (3, 3), "overlap": 0.0}, + [ + (TENSOR_4x4[..., :3, :3], (0, 0)), + (pad(TENSOR_4x4[..., :3, 3:], (0, 2)), (0, 3)), + (pad(TENSOR_4x4[..., 3:, :3], (0, 0, 0, 2)), (3, 0)), + (pad(TENSOR_4x4[..., 3:, 3:], (0, 2, 0, 2)), (3, 3)), + ], +] + +# overlapping 2x2 +TEST_CASE_2 = [ + TENSOR_4x4, + {"patch_size": (2, 2), "overlap": (0.5, 0.5)}, + [ + (TENSOR_4x4[..., 0:2, 0:2], (0, 0)), + (TENSOR_4x4[..., 0:2, 1:3], (0, 1)), + (TENSOR_4x4[..., 0:2, 2:4], (0, 2)), + (TENSOR_4x4[..., 1:3, 0:2], (1, 0)), + (TENSOR_4x4[..., 1:3, 1:3], (1, 1)), + (TENSOR_4x4[..., 1:3, 2:4], (1, 2)), + (TENSOR_4x4[..., 2:4, 0:2], (2, 0)), + (TENSOR_4x4[..., 2:4, 1:3], (2, 1)), + (TENSOR_4x4[..., 2:4, 2:4], (2, 2)), + ], +] + +# overlapping 3x3 (non-divisible) +TEST_CASE_3 = [ + TENSOR_4x4, + {"patch_size": (3, 3), "overlap": 2.0 / 3.0}, + [ + (TENSOR_4x4[..., :3, :3], (0, 0)), + (TENSOR_4x4[..., :3, 1:], (0, 1)), + (TENSOR_4x4[..., 1:, :3], (1, 0)), + (TENSOR_4x4[..., 1:, 1:], (1, 1)), + ], +] + +# non-overlapping 2x2 with positive offset +TEST_CASE_4 = [ + TENSOR_4x4, + {"patch_size": (2, 2), "offset": 1}, + [ + (TENSOR_4x4[..., 1:3, 1:3], (1, 1)), + (pad(TENSOR_4x4[..., 1:3, 3:], (0, 1)), (1, 3)), + (pad(TENSOR_4x4[..., 3:, 1:3], (0, 0, 0, 1)), (3, 1)), + (pad(TENSOR_4x4[..., 3:, 3:], (0, 1, 0, 1)), (3, 3)), + ], +] + +# non-overlapping 2x2 with negative offset +TEST_CASE_5 = [ + TENSOR_4x4, + {"patch_size": (2, 2), "offset": -1}, + [ + (pad(TENSOR_4x4[..., :1, :1], (1, 0, 1, 0)), (-1, -1)), + (pad(TENSOR_4x4[..., :1, 1:3], (0, 0, 1, 0)), (-1, 1)), + (pad(TENSOR_4x4[..., :1, 3:], (0, 1, 1, 0)), (-1, 3)), + (pad(TENSOR_4x4[..., 1:3, :1], (1, 0)), (1, -1)), + (TENSOR_4x4[..., 1:3, 1:3], (1, 1)), + (pad(TENSOR_4x4[..., 1:3, 3:], (0, 1)), (1, 3)), + (pad(TENSOR_4x4[..., 3:, :1], (1, 0, 0, 1)), (3, -1)), + (pad(TENSOR_4x4[..., 3:, 1:3], (0, 0, 0, 1)), (3, 1)), + (pad(TENSOR_4x4[..., 3:, 3:], (0, 1, 0, 1)), (3, 3)), + ], +] + +# non-overlapping 2x2 with positive offset and no padding +TEST_CASE_6 = [TENSOR_4x4, {"patch_size": (2, 2), "offset": 1, "pad_mode": None}, [(TENSOR_4x4[..., 1:3, 1:3], (1, 1))]] + + +# ---------------------------------------------------------------------------- +# Filtering function test cases +# ---------------------------------------------------------------------------- +def gen_filter(filter_type, value=None): + """ "Generate patch filtering function for testing""" + if filter_type.lower() == "high": + + def my_filter(patch, location): + if torch.any(patch > value): + return True + return False + + elif filter_type.lower() == "low": + + def my_filter(patch, location): + if torch.any(patch < value): + return True + return False + + elif filter_type.lower() == "location": + + def my_filter(patch, location): + if location in value: + return True + return False + + return my_filter + + +TEST_CASE_FILTER_FN_0 = [ + TENSOR_4x4_artifact, + {"patch_size": (2, 2), "filter_fn": gen_filter("low", 256)}, + [ + (TENSOR_4x4_artifact[..., :2, :2], (0, 0)), + (TENSOR_4x4_artifact[..., 2:, :2], (2, 0)), + (TENSOR_4x4_artifact[..., 2:, 2:], (2, 2)), + ], +] + +TEST_CASE_FILTER_FN_1 = [ + TENSOR_4x4_artifact, + {"patch_size": (2, 2), "filter_fn": gen_filter("high", 256)}, + [(TENSOR_4x4_artifact[..., :2, 2:], (0, 2))], +] + +TEST_CASE_FILTER_FN_2 = [ + TENSOR_4x4_artifact, + {"patch_size": (2, 2), "filter_fn": gen_filter("location", [(2, 2), (2, 0)])}, + [(TENSOR_4x4_artifact[..., 2:, :2], (2, 0)), (TENSOR_4x4_artifact[..., 2:, 2:], (2, 2))], +] + + +# ---------------------------------------------------------------------------- +# Error test cases +# ---------------------------------------------------------------------------- +def extra_parameter_filter(patch, location, extra): + return + + +def missing_parameter_filter(patch): + return + + +# invalid overlap: 1.0 +TEST_CASE_ERROR_0 = [TENSOR_4x4, {"patch_size": (2, 2), "overlap": 1.0}, ValueError] +# invalid overlap: negative +TEST_CASE_ERROR_1 = [TENSOR_4x4, {"patch_size": (2, 2), "overlap": -0.1}, ValueError] + +# invalid offset: positive and larger than image size +TEST_CASE_ERROR_2 = [TENSOR_4x4, {"patch_size": (2, 2), "offset": 4}, ValueError] +# invalid offset: negative and larger than patch size (in magnitude) +TEST_CASE_ERROR_3 = [TENSOR_4x4, {"patch_size": (2, 2), "offset": -3}, ValueError] +# invalid offset: negative and no padding +TEST_CASE_ERROR_4 = [TENSOR_4x4, {"patch_size": (2, 2), "offset": -1, "pad_mode": None}, ValueError] + +# invalid filter function: with more than two positional parameters +TEST_CASE_ERROR_5 = [TENSOR_4x4, {"patch_size": (2, 2), "filter_fn": extra_parameter_filter}, ValueError] +# invalid filter function: with less than two positional parameters +TEST_CASE_ERROR_6 = [TENSOR_4x4, {"patch_size": (2, 2), "filter_fn": missing_parameter_filter}, ValueError] +# invalid filter function: non-callable +TEST_CASE_ERROR_7 = [TENSOR_4x4, {"patch_size": (2, 2), "filter_fn": 1}, ValueError] + + +class SlidingWindowSplitterTests(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_FILTER_FN_0, + TEST_CASE_FILTER_FN_1, + TEST_CASE_FILTER_FN_2, + ] + ) + def test_split_patches(self, image, arguments, expected): + patches = SlidingWindowSplitter(**arguments)(image) + patches = list(patches) + self.assertEqual(len(patches), len(expected)) + for p, e in zip(patches, expected): + assert_allclose(p[0], e[0]) + self.assertTupleEqual(p[1], e[1]) + + @parameterized.expand( + [ + TEST_CASE_ERROR_0, + TEST_CASE_ERROR_1, + TEST_CASE_ERROR_2, + TEST_CASE_ERROR_3, + TEST_CASE_ERROR_4, + TEST_CASE_ERROR_5, + TEST_CASE_ERROR_6, + TEST_CASE_ERROR_7, + ] + ) + def test_split_patches_errors(self, image, arguments, expected_error): + with self.assertRaises(expected_error): + patches = SlidingWindowSplitter(**arguments)(image) + patches = list(patches) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_smartcache_patch_wsi_dataset.py b/tests/test_smartcache_patch_wsi_dataset.py index 5760264a7b..e51033bb4e 100644 --- a/tests/test_smartcache_patch_wsi_dataset.py +++ b/tests/test_smartcache_patch_wsi_dataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest from unittest import skipUnless diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index 63dc1534bc..0e2a79fef3 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import copy import os import sys @@ -38,7 +40,7 @@ class TestSmartCacheDataset(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_shape(self, replace_rate, num_replace_workers, transform): - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]), np.eye(4)) + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]).astype(float), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) diff --git a/tests/test_smooth_field.py b/tests/test_smooth_field.py index b731af36f4..c525311478 100644 --- a/tests/test_smooth_field.py +++ b/tests/test_smooth_field.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from itertools import product diff --git a/tests/test_sobel_gradient.py b/tests/test_sobel_gradient.py index c464062571..3d995a60c9 100644 --- a/tests/test_sobel_gradient.py +++ b/tests/test_sobel_gradient.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_sobel_gradientd.py b/tests/test_sobel_gradientd.py index 30df091ea0..7499a0410b 100644 --- a/tests/test_sobel_gradientd.py +++ b/tests/test_sobel_gradientd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_spacing.py b/tests/test_spacing.py index a1e289d19b..659e1d88da 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Dict, List import numpy as np import torch @@ -21,9 +22,10 @@ from monai.data.utils import affine_to_spacing from monai.transforms import Spacing from monai.utils import fall_back_tuple +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose, skip_if_quick -TESTS: List[List] = [] +TESTS: list[list] = [] for device in TEST_DEVICES: TESTS.append( [ @@ -267,17 +269,22 @@ class TestSpacingCase(unittest.TestCase): @parameterized.expand(TESTS) def test_spacing( self, - init_param: Dict, + init_param: dict, img: torch.Tensor, affine: torch.Tensor, - data_param: Dict, + data_param: dict, expected_output: torch.Tensor, device: torch.device, ): img = MetaTensor(img, affine=affine).to(device) - res: MetaTensor = Spacing(**init_param)(img, **data_param) + tr = Spacing(**init_param) + call_param = data_param.copy() + call_param["data_array"] = img + res: MetaTensor = tr(**call_param) self.assertEqual(img.device, res.device) + test_resampler_lazy(tr, res, init_param=init_param, call_param=call_param) + assert_allclose(res, expected_output, atol=1e-1, rtol=1e-1) sr = min(len(res.shape) - 1, 3) if isinstance(init_param["pixdim"], float): @@ -289,13 +296,17 @@ def test_spacing( @parameterized.expand(TESTS_TORCH) def test_spacing_torch(self, pixdim, img, track_meta: bool): set_track_meta(track_meta) - tr = Spacing(pixdim=pixdim) - res = tr(img) + init_param = {"pixdim": pixdim} + tr = Spacing(**init_param) + call_param = {"data_array": img} + res = tr(**call_param) + if track_meta: self.assertIsInstance(res, MetaTensor) new_spacing = affine_to_spacing(res.affine, 3) assert_allclose(new_spacing, pixdim, type_test=False) self.assertNotEqual(img.shape, res.shape) + test_resampler_lazy(tr, res, init_param=init_param, call_param=call_param) else: self.assertIsInstance(res, torch.Tensor) self.assertNotIsInstance(res, MetaTensor) @@ -350,6 +361,16 @@ def test_inverse_mn_mx(self, device, recompute, align, scale_extent): self.assertEqual(img_out.shape, img_t.shape) self.assertLess(((affine - img_out.affine) ** 2).sum() ** 0.5, 5e-2) + def test_property_no_change(self): + affine = torch.tensor( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu" + ) + affine[:3] *= 1 - 1e-4 # make sure it's not exactly target but close to the target + img = MetaTensor(torch.rand((1, 10, 9, 8), dtype=torch.float32), affine=affine, meta={"fname": "somewhere"}) + tr = Spacing(pixdim=[1.0, 1.0, 1.0]) + tr(img) + assert_allclose(tr.pixdim, [1.0, 1.0, 1.0], type_test=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index 22729fd1b2..6fe232d813 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import List, Tuple import numpy as np import torch @@ -20,9 +21,11 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import affine_to_spacing from monai.transforms import Spacingd +from monai.utils import ensure_tuple_rep +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES, assert_allclose -TESTS: List[Tuple] = [] +TESTS: list[tuple] = [] for device in TEST_DEVICES: TESTS.append( ( @@ -50,7 +53,7 @@ {"image": MetaTensor(torch.ones((2, 10, 20)))}, dict(keys="image", pixdim=(1, 2)), (2, 10, 10), - torch.as_tensor(np.diag((1, 2, 1))), + torch.as_tensor(np.diag((1, 2, 1, 1))), *device, ) ) @@ -63,7 +66,7 @@ }, dict(keys=("image", "seg"), mode="nearest", pixdim=(1, 0.2)), (2, 1, 46), - torch.as_tensor(np.diag((1, 0.2, 1))), + torch.as_tensor(np.diag((1, 0.2, 1, 1))), *device, ) ) @@ -76,7 +79,7 @@ }, dict(keys=("image", "seg"), mode=("bilinear", "nearest"), pixdim=(1, 0.2)), (2, 1, 46), - torch.as_tensor(np.diag((1, 0.2, 1))), + torch.as_tensor(np.diag((1, 0.2, 1, 1))), *device, ) ) @@ -91,7 +94,18 @@ class TestSpacingDCase(unittest.TestCase): @parameterized.expand(TESTS) def test_spacingd(self, _, data, kw_args, expected_shape, expected_affine, device): data = {k: v.to(device) for k, v in data.items()} - res = Spacingd(**kw_args)(data) + tr = Spacingd(**kw_args) + call_param = {"data": data} + res = tr(**call_param) + # test lazy + if not isinstance(kw_args["keys"], str): # multiple keys + kw_args["mode"] = ensure_tuple_rep(kw_args["mode"], len(kw_args["keys"])) + init_param = kw_args.copy() + for key, mode in zip(kw_args["keys"], kw_args["mode"]): + init_param["keys"], init_param["mode"] = key, mode + test_resampler_lazy(tr, res, init_param, call_param, output_key=key) + else: + test_resampler_lazy(tr, res, kw_args, call_param, output_key=kw_args["keys"]) in_img = data["image"] out_img = res["image"] self.assertEqual(in_img.device, out_img.device) @@ -104,11 +118,14 @@ def test_spacingd(self, _, data, kw_args, expected_shape, expected_affine, devic def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, device): set_track_meta(track_meta) tr = Spacingd(**init_param) - data = {"seg": img.to(device)} - res = tr(data)["seg"] + call_param = {"data": {"seg": img.to(device)}} + res_data = tr(**call_param) # type: ignore + res = res_data["seg"] if track_meta: + test_resampler_lazy(tr, res_data, init_param, call_param, output_key="seg") self.assertIsInstance(res, MetaTensor) + assert isinstance(res, MetaTensor) # for mypy type narrowing new_spacing = affine_to_spacing(res.affine, 3) assert_allclose(new_spacing, init_param["pixdim"], type_test=False) self.assertNotEqual(img.shape, res.shape) @@ -117,6 +134,30 @@ def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, devi self.assertNotIsInstance(res, MetaTensor) self.assertNotEqual(img.shape, res.shape) + def test_space_same_shape(self): + affine_1 = np.array( + [ + [1.499277e00, 2.699563e-02, 3.805804e-02, -1.948635e02], + [-2.685805e-02, 1.499757e00, -2.635604e-12, 4.438188e01], + [-3.805194e-02, -5.999028e-04, 1.499517e00, 4.036536e01], + [0.000000e00, 0.000000e00, 0.000000e00, 1.000000e00], + ] + ) + affine_2 = np.array( + [ + [1.499275e00, 2.692252e-02, 3.805728e-02, -1.948635e02], + [-2.693010e-02, 1.499758e00, -4.260525e-05, 4.438188e01], + [-3.805190e-02, -6.406730e-04, 1.499517e00, 4.036536e01], + [0.000000e00, 0.000000e00, 0.000000e00, 1.000000e00], + ] + ) + img_1 = MetaTensor(np.zeros((1, 238, 145, 315)), affine=affine_1) + img_2 = MetaTensor(np.zeros((1, 238, 145, 315)), affine=affine_2) + out = Spacingd(("img_1", "img_2"), pixdim=1)({"img_1": img_1, "img_2": img_2}) + self.assertEqual(out["img_1"].shape, out["img_2"].shape) # ensure_same_shape True + out = Spacingd(("img_1", "img_2"), pixdim=1, ensure_same_shape=False)({"img_1": img_1, "img_2": img_2}) + self.assertNotEqual(out["img_1"].shape, out["img_2"].shape) # ensure_same_shape False + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spatial_combine_transforms.py b/tests/test_spatial_combine_transforms.py new file mode 100644 index 0000000000..74c03fc4ff --- /dev/null +++ b/tests/test_spatial_combine_transforms.py @@ -0,0 +1,185 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +import monai.transforms as mt +from monai.data import create_test_image_2d, create_test_image_3d +from monai.data.meta_tensor import MetaTensor +from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.transform import MapTransform +from monai.utils import set_determinism +from tests.lazy_transforms_utils import get_apply_param +from tests.utils import assert_allclose + +TEST_2D = [ + [ + (2, 62, 61), + [ + (mt.Spacing, {"pixdim": (1.2, 1.5), "padding_mode": "zeros", "dtype": torch.float32}), + (mt.Orientation, {"axcodes": "RA"}), + (mt.Resize, {"spatial_size": (64, 48), "mode": "bilinear"}), + (mt.RandSpatialCrop, {"roi_size": (32, 32)}), + ( + mt.RandAffine, + { + "prob": 0.9, + "rotate_range": (np.pi / 2,), + "shear_range": [1, 2], + "translate_range": [2, 1], + "mode": "bilinear", + "padding_mode": "reflection", + }, + ), + (mt.RandFlip, {"prob": 0.9}), + (mt.RandRotate, {"prob": 0.9, "range_x": np.pi / 4, "mode": "bilinear", "padding_mode": "reflection"}), + (mt.CenterScaleCrop, {"roi_scale": (0.96, 0.8)}), + (mt.RandZoom, {"prob": 0.9, "mode": "bilinear", "keep_size": False, "align_corners": False}), + ], + ], + [ + (2, 63, 64), + [ + (mt.CenterScaleCropd, {"roi_scale": (0.96, 0.8), "keys": "img"}), + (mt.RandRotated, {"prob": 0.9, "range_x": np.pi / 4, "keys": "img"}), + ( + mt.RandZoomd, + {"prob": 0.9, "mode": "bilinear", "keep_size": False, "keys": "img", "align_corners": False}, + ), + (mt.Spacingd, {"pixdim": (1.2, 1.5), "padding_mode": "zeros", "dtype": torch.float32, "keys": "img"}), + (mt.RandFlipd, {"prob": 0.9, "keys": "img"}), + ( + mt.RandAffined, + { + "prob": 0.9, + "rotate_range": (np.pi / 2,), + "shear_range": [1, 2], + "translate_range": [2, 1], + "mode": "bilinear", + "keys": "img", + }, + ), + (mt.Orientationd, {"axcodes": "RA", "keys": "img"}), + (mt.Resized, {"spatial_size": (48, 48), "mode": "bilinear", "keys": "img"}), + (mt.RandScaleCropd, {"roi_scale": (0.4, 1.5), "random_size": False, "keys": "img"}), + ], + ], +] + +TEST_3D = [ + [ + (2, 83, 100, 67), + [ + (mt.Orientation, {"axcodes": "RAS"}), + (mt.CenterScaleCrop, {"roi_scale": (1.2, 0.8, 1.0)}), + ( + mt.RandAffine, + { + "prob": 0.9, + "rotate_range": (np.pi / 2,), + "shear_range": [1, 2], + "translate_range": [2, 1], + "mode": "bilinear", + }, + ), + (mt.Spacing, {"pixdim": (0.9, 1.2, 1.0), "padding_mode": "zeros", "dtype": torch.float32}), + (mt.RandSpatialCrop, {"roi_size": (36, 36, 38), "random_size": False}), + (mt.RandZoom, {"prob": 0.9, "mode": "nearest", "keep_size": False}), + (mt.Resize, {"spatial_size": (32, 32, 32), "mode": "nearest"}), + (mt.RandFlip, {"prob": 0.9}), + (mt.RandRotate, {"prob": 0.9, "range_x": np.pi / 4}), + ], + ], + [ + (2, 62, 64, 72), + [ + (mt.RandScaleCropd, {"roi_scale": (0.9, 0.7, 1.1), "random_size": False, "keys": "img"}), + (mt.Spacingd, {"pixdim": (1.2, 1.5, 0.9), "padding_mode": "zeros", "dtype": torch.float32, "keys": "img"}), + (mt.Orientationd, {"axcodes": "RAS", "keys": "img"}), + (mt.Resized, {"spatial_size": (32, 32, 32), "mode": "nearest", "keys": "img"}), + (mt.RandFlipd, {"prob": 0.9, "keys": "img"}), + (mt.CenterScaleCropd, {"roi_scale": (0.96, 0.8, 1.25), "keys": "img"}), + (mt.RandZoomd, {"prob": 0.9, "mode": "nearest", "keep_size": False, "keys": "img"}), + ( + mt.RandAffined, + { + "prob": 0.9, + "rotate_range": (np.pi / 2,), + "shear_range": [1, 2], + "translate_range": [2, 1], + "mode": "bilinear", + "keys": "img", + }, + ), + (mt.RandRotated, {"prob": 0.9, "range_x": np.pi / 4, "keys": "img"}), + ], + ], +] + + +class CombineLazyTest(unittest.TestCase): + @parameterized.expand(TEST_2D + TEST_3D) + def test_combine_transforms(self, input_shape, funcs): + for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: + for seed in [10, 100, 1000, 10000]: + set_determinism(seed=seed) + _funcs = [] + for _func, _params in funcs: + _funcs.append(_func(**_params)) + is_map = isinstance(_funcs[0], MapTransform) + chns, sp_size = input_shape[0], input_shape[1:] + imgs = [] + for _ in range(chns): + if len(sp_size) == 2: + imgs.append(create_test_image_2d(sp_size[0], sp_size[1])[0]) + else: + imgs.append(create_test_image_3d(sp_size[0], sp_size[1], sp_size[2])[0]) + data = np.stack(imgs).astype(float) + im = MetaTensor(data, meta={"a": "b", "affine": np.eye(len(input_shape))}).to(device) + input_data = {"img": im} if is_map else im + # non lazy + non_lazy_result = input_data + for _func in _funcs: + if isinstance(_func, mt.Randomizable): + _func.set_random_state(seed=seed) + non_lazy_result = _func(non_lazy_result) + expected = non_lazy_result["img"] if is_map else non_lazy_result + + # lazy + pending_result = input_data + for _func in _funcs: + _func.lazy_evaluation = True + if isinstance(_func, mt.Randomizable): + _func.set_random_state(seed=seed) + pending_result = _func(pending_result) + pending_result = pending_result["img"] if is_map else pending_result + + assert_allclose(pending_result.peek_pending_affine(), expected.affine, atol=1e-7) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:4]) + + # test final result + init_param = funcs[-1][1] + call_param = {} + apply_param = get_apply_param(init_param, call_param) + result = apply_transforms(pending_result, **apply_param)[0] + + match_ratio = np.sum(np.isclose(result.array, expected.array, atol=5e-1)) / np.prod(result.shape) + self.assertGreater(match_ratio, 0.5) # at least half of the images are very close + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index 6fdfbd3f70..9211270577 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -9,11 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized -from monai.transforms import SpatialCrop +from monai.transforms import CenterScaleCrop, CenterSpatialCrop, SpatialCrop from tests.croppers import CropTest TESTS = [ @@ -23,7 +25,6 @@ [{"roi_start": [0, 0], "roi_end": [2, 2]}, (3, 3, 3, 3), (3, 2, 2, 3)], [{"roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)], [{"roi_start": [0, 0, 0, 0, 0], "roi_end": [8, 8, 8, 2, 2]}, (3, 3, 3, 3), (3, 3, 3, 3)], - [{"roi_start": [1, 0, 0], "roi_end": [1, 8, 8]}, (3, 3, 3, 3), (3, 0, 3, 3)], [ {"roi_slices": [slice(s, e) for s, e in zip([None, None, None], [None, None, None])]}, (3, 11, 12, 15), @@ -39,6 +40,17 @@ TEST_ERRORS = [[{"roi_slices": [slice(s, e, 2) for s, e in zip([-1, -2, 0], [None, None, 2])]}]] +TEST_LAZY_ERRORS = [[{"roi_start": [1, 0, 0], "roi_end": [1, 8, 8]}, (3, 3, 3, 3), (3, 0, 3, 3)]] + +func1 = {CenterSpatialCrop: {"roi_size": [8, 8, 6]}} +func2 = {SpatialCrop: {"roi_center": [1, 1, 1], "roi_size": [3, 4, 3]}} +func3 = {CenterScaleCrop: {"roi_scale": [0.6, 0.3, -1]}} + +TESTS_COMBINE = [] +TESTS_COMBINE.append([[func1, func2, func3], (3, 10, 10, 8)]) +TESTS_COMBINE.append([[func1, func2], (3, 8, 8, 4)]) +TESTS_COMBINE.append([[func1, func3], (3, 8, 8, 4)]) + class TestSpatialCrop(CropTest): Cropper = SpatialCrop @@ -52,6 +64,19 @@ def test_error(self, input_param): with self.assertRaises(ValueError): SpatialCrop(**input_param) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_shape, _): + self.crop_test_pending_ops(input_param, input_shape) + + @parameterized.expand(TEST_LAZY_ERRORS) + def test_lazy_error(self, input_param, input_shape, _): + with self.assertRaises(ValueError): + return self.crop_test_pending_ops(input_param, input_shape) + + @parameterized.expand(TESTS_COMBINE) + def test_combine_ops(self, funcs, input_shape): + self.crop_test_combine_ops(funcs, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spatial_cropd.py b/tests/test_spatial_cropd.py index 11f6da0811..0705367568 100644 --- a/tests/test_spatial_cropd.py +++ b/tests/test_spatial_cropd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized @@ -63,6 +65,10 @@ class TestSpatialCropd(CropTest): def test_shape(self, input_param, input_shape, expected_shape, same_area): self.crop_test(input_param, input_shape, expected_shape, same_area) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_shape, _expected_shape, _same_area): + self.crop_test_pending_ops(input_param, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 5a70c10686..978059760e 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -9,17 +9,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized -from monai.transforms import SpatialPad +from monai.transforms import BorderPad, DivisiblePad, SpatialPad from tests.padders import PadTest TESTS = [] TESTS.append([{"spatial_size": [3, 4], "method": "end"}, (1, 2, 3), (1, 3, 4)]) TESTS.append([{"spatial_size": [15, 4, -1], "method": "symmetric"}, (3, 8, 8, 4), (3, 15, 8, 4)]) +func1 = {SpatialPad: {"spatial_size": [15, 4, -1], "method": "symmetric"}} +func2 = {BorderPad: {"spatial_border": 2}} +func3 = {DivisiblePad: {"k": 5, "method": "end"}} + +TESTS_COMBINE = [] +TESTS_COMBINE.append([[func1, func2, func3], (3, 8, 8, 4), (3, 20, 15, 10)]) +TESTS_COMBINE.append([[func1, func2], (3, 8, 8, 4), (3, 19, 12, 8)]) +TESTS_COMBINE.append([[func2, func2], (3, 8, 8, 4), (3, 16, 16, 12)]) + class TestSpatialPad(PadTest): Padder = SpatialPad @@ -33,6 +44,14 @@ def test_pad_kwargs(self): unchanged_slices = [slice(None), slice(None, 8), slice(None, 4)] self.pad_test_kwargs(unchanged_slices, **kwargs) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_shape, _): + self.pad_test_pending_ops(input_param, input_shape) + + @parameterized.expand(TESTS_COMBINE) + def test_combine_ops(self, funcs, input_shape, expected_shape): + self.pad_test_combine_ops(funcs, input_shape, expected_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spatial_padd.py b/tests/test_spatial_padd.py index 656a731de0..10bf958738 100644 --- a/tests/test_spatial_padd.py +++ b/tests/test_spatial_padd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized @@ -17,10 +19,10 @@ from tests.padders import PadTest TESTS = [ - [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "symmetric"}, (3, 8, 8, 4), (3, 15, 8, 8)], - [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end"}, (3, 8, 8, 4), (3, 15, 8, 8)], - [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end"}, (3, 8, 8, 4), (3, 15, 8, 8)], - [{"keys": ["img"], "spatial_size": [15, 8, -1], "method": "end"}, (3, 8, 4, 4), (3, 15, 8, 4)], + [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "symmetric"}, (3, 8, 8, 5), (3, 15, 8, 8)], + [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end"}, (3, 8, 8, 5), (3, 15, 8, 8)], + [{"keys": ["img"], "spatial_size": [15, 8, 8], "method": "end"}, (3, 8, 8, 5), (3, 15, 8, 8)], + [{"keys": ["img"], "spatial_size": [15, 8, -1], "method": "end"}, (3, 8, 5, 4), (3, 15, 8, 4)], ] @@ -32,6 +34,10 @@ def test_pad(self, input_param, input_shape, expected_shape): modes = ["constant", {"constant"}] self.pad_test(input_param, input_shape, expected_shape, modes) + @parameterized.expand(TESTS) + def test_pending_ops(self, input_param, input_shape, _): + self.pad_test_pending_ops(input_param, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spatial_resample.py b/tests/test_spatial_resample.py index 30bf33149b..b513bd0f05 100644 --- a/tests/test_spatial_resample.py +++ b/tests/test_spatial_resample.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy @@ -21,6 +23,7 @@ from monai.data.utils import to_affine_nd from monai.transforms import SpatialResample from monai.utils import optional_import +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose TESTS = [] @@ -138,9 +141,14 @@ def test_flips(self, img, device, data_param, expected_output): img.affine = torch.eye(4) if hasattr(img, "to"): img = img.to(device) - out = SpatialResample()(img=img, **data_param) + resampler = SpatialResample() + call_param = data_param.copy() + call_param["img"] = img + out = resampler(**call_param) assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2) - assert_allclose(out.affine, data_param["dst_affine"]) + assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), call_param["dst_affine"]) + + test_resampler_lazy(resampler, out, init_param=None, call_param=call_param) @parameterized.expand(TEST_4_5_D) def test_4d_5d(self, new_shape, tile, device, dtype, expected_data): @@ -150,10 +158,15 @@ def test_4d_5d(self, new_shape, tile, device, dtype, expected_data): dst = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]]) dst = dst.to(dtype) - out = SpatialResample(dtype=dtype, align_corners=True)(img=img, dst_affine=dst, align_corners=False) + init_param = {"dtype": dtype, "align_corners": True} + call_param = {"img": img, "dst_affine": dst, "align_corners": False} + resampler = SpatialResample(**init_param) + out = resampler(**call_param) assert_allclose(out, expected_data[None], rtol=1e-2, atol=1e-2) assert_allclose(out.affine, dst.to(torch.float32), rtol=1e-2, atol=1e-2) + test_resampler_lazy(resampler, out, init_param, call_param) + @parameterized.expand(TEST_DEVICES) def test_ill_affine(self, device): img = MetaTensor(torch.arange(12).reshape(1, 2, 2, 3)).to(device) @@ -180,9 +193,14 @@ def test_input_torch(self, new_shape, tile, device, dtype, expected_data, track_ img = torch.as_tensor(np.tile(img, tile)).to(device) dst = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]]) dst = dst.to(dtype).to(device) - - out = SpatialResample(dtype=dtype)(img=img, dst_affine=dst) + init_param = {"dtype": dtype} + call_param = {"img": img, "dst_affine": dst} + resampler = SpatialResample(**init_param) + out = resampler(**call_param) assert_allclose(out, expected_data[None], rtol=1e-2, atol=1e-2) + + test_resampler_lazy(resampler, out, init_param, call_param) + if track_meta: self.assertIsInstance(out, MetaTensor) assert_allclose(out.affine, dst.to(torch.float32), rtol=1e-2, atol=1e-2) @@ -196,7 +214,7 @@ def test_inverse(self, img, device, data_param, expected_output): tr = SpatialResample() out = tr(img=img, **data_param) assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2) - assert_allclose(out.affine, data_param["dst_affine"]) + assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), data_param["dst_affine"]) # inverse out = tr.inverse(out) @@ -204,6 +222,14 @@ def test_inverse(self, img, device, data_param, expected_output): expected_affine = to_affine_nd(len(out.affine) - 1, torch.eye(4)) assert_allclose(out.affine, expected_affine) + def test_unchange(self): + for i, p in enumerate(TEST_NDARRAYS_ALL): + set_track_meta(i % 2) + img = p(np.arange(12).reshape(1, 3, 4)) + result = SpatialResample()(img) + assert_allclose(result, img, type_test=False) + set_track_meta(True) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spatial_resampled.py b/tests/test_spatial_resampled.py index 5ace0b3774..ebe3eb6e4f 100644 --- a/tests/test_spatial_resampled.py +++ b/tests/test_spatial_resampled.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -18,6 +20,7 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import to_affine_nd from monai.transforms.spatial.dictionary import SpatialResampled +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES, assert_allclose TESTS = [] @@ -88,14 +91,21 @@ class TestSpatialResample(unittest.TestCase): def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output): img = MetaTensor(img, affine=torch.eye(4)).to(device) data = {"img": img, "dst_affine": dst_affine} - - xform = SpatialResampled(keys="img", **kwargs) - output_data = xform(data) + init_param = kwargs.copy() + init_param["keys"] = "img" + call_param = {"data": data} + xform = SpatialResampled(**init_param) + output_data = xform(**call_param) out = output_data["img"] assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2) - assert_allclose(out.affine, dst_affine, rtol=1e-2, atol=1e-2) + assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), dst_affine, rtol=1e-2, atol=1e-2) + + # check lazy + lazy_xform = SpatialResampled(**init_param) + test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img") + # check inverse inverted = xform.inverse(output_data)["img"] self.assertEqual(inverted.applied_operations, []) # no further invert after inverting expected_affine = to_affine_nd(len(out.affine) - 1, torch.eye(4)) diff --git a/tests/test_split_channel.py b/tests/test_split_channel.py index 4b41c334e8..12b336431d 100644 --- a/tests/test_split_channel.py +++ b/tests/test_split_channel.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_split_channeld.py b/tests/test_split_channeld.py index 7a34855676..63b4caf7a5 100644 --- a/tests/test_split_channeld.py +++ b/tests/test_split_channeld.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_split_on_grid.py b/tests/test_split_on_grid.py index b1d4cd93c5..1bf7dbf970 100644 --- a/tests/test_split_on_grid.py +++ b/tests/test_split_on_grid.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_split_on_grid_dict.py b/tests/test_split_on_grid_dict.py index 778a38da34..fdc16b3a0a 100644 --- a/tests/test_split_on_grid_dict.py +++ b/tests/test_split_on_grid_dict.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_splitdim.py b/tests/test_splitdim.py index d6ee4fc55e..6c678a6bc2 100644 --- a/tests/test_splitdim.py +++ b/tests/test_splitdim.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -38,13 +40,12 @@ def test_correct_shape(self, shape, keepdim, im_type): arr[0, 0, 0, 0] *= 2 self.assertEqual(arr.flatten()[0], out[0].flatten()[0]) - def test_error(self): - """Should fail because splitting along singleton dimension""" + def test_singleton(self): shape = (2, 1, 8, 7) for p in TEST_NDARRAYS: arr = p(np.random.rand(*shape)) - with self.assertRaises(RuntimeError): - _ = SplitDim(dim=1)(arr) + out = SplitDim(dim=1)(arr) + self.assertEqual(out[0].shape, shape) if __name__ == "__main__": diff --git a/tests/test_splitdimd.py b/tests/test_splitdimd.py index 85184b494a..b01913269d 100644 --- a/tests/test_splitdimd.py +++ b/tests/test_splitdimd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy @@ -38,7 +40,7 @@ def setUpClass(cls) -> None: affine = make_rand_affine() data = {"i": make_nifti_image(arr, affine)} - loader = LoadImaged("i") + loader = LoadImaged("i", image_only=True) cls.data = loader(data) @parameterized.expand(TESTS) @@ -82,13 +84,12 @@ def test_correct(self, keepdim, im_type, update_meta, list_output): arr[0, 0, 0, 0] *= 2 self.assertEqual(arr.flatten()[0], out.flatten()[0]) - def test_error(self): - """Should fail because splitting along singleton dimension""" + def test_singleton(self): shape = (2, 1, 8, 7) for p in TEST_NDARRAYS: arr = p(np.random.rand(*shape)) - with self.assertRaises(RuntimeError): - _ = SplitDimd("i", dim=1)({"i": arr}) + out = SplitDimd("i", dim=1)({"i": arr}) + self.assertEqual(out["i"].shape, shape) if __name__ == "__main__": diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py index a2b538d58c..6673fd25c1 100644 --- a/tests/test_squeezedim.py +++ b/tests/test_squeezedim.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -32,7 +34,6 @@ class TestSqueezeDim(unittest.TestCase): @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): - result = SqueezeDim(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) if "dim" in input_param and input_param["dim"] == 2 and isinstance(result, MetaTensor): @@ -40,7 +41,6 @@ def test_shape(self, input_param, test_data, expected_shape): @parameterized.expand(TESTS_FAIL) def test_invalid_inputs(self, exception, input_param, test_data): - with self.assertRaises(exception): SqueezeDim(**input_param)(test_data) diff --git a/tests/test_squeezedimd.py b/tests/test_squeezedimd.py index 5908e7673f..9fa9d84030 100644 --- a/tests/test_squeezedimd.py +++ b/tests/test_squeezedimd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_ssim_loss.py b/tests/test_ssim_loss.py index 3e1c085069..a4ba66300b 100644 --- a/tests/test_ssim_loss.py +++ b/tests/test_ssim_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch @@ -16,31 +18,34 @@ from monai.losses.ssim_loss import SSIMLoss -x = torch.ones([1, 1, 10, 10]) / 2 -y1 = torch.ones([1, 1, 10, 10]) / 2 -y2 = torch.zeros([1, 1, 10, 10]) -data_range = x.max().unsqueeze(0) TESTS2D = [] for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: - TESTS2D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device))) - TESTS2D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device))) + for batch_size in [1, 2, 16]: + x = torch.ones([batch_size, 1, 10, 10]) / 2 + y1 = torch.ones([batch_size, 1, 10, 10]) / 2 + y2 = torch.zeros([batch_size, 1, 10, 10]) + data_range = x.max().unsqueeze(0) + TESTS2D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device))) + TESTS2D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device))) -x = torch.ones([1, 1, 10, 10, 10]) / 2 -y1 = torch.ones([1, 1, 10, 10, 10]) / 2 -y2 = torch.zeros([1, 1, 10, 10, 10]) -data_range = x.max().unsqueeze(0) TESTS3D = [] for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: - TESTS3D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device))) - TESTS3D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device))) + for batch_size in [1, 2, 16]: + x = torch.ones([batch_size, 1, 10, 10, 10]) / 2 + y1 = torch.ones([batch_size, 1, 10, 10, 10]) / 2 + y2 = torch.zeros([batch_size, 1, 10, 10, 10]) + data_range = x.max().unsqueeze(0) + TESTS3D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device))) + TESTS3D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device))) -x = torch.ones([1, 1, 10, 10]) / 2 -y = torch.ones([1, 1, 10, 10]) / 2 -y.requires_grad_(True) -data_range = x.max().unsqueeze(0) TESTS2D_GRAD = [] for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: - TESTS2D_GRAD.append([x.to(device), y.to(device), data_range.to(device)]) + for batch_size in [1, 2, 16]: + x = torch.ones([batch_size, 1, 10, 10]) / 2 + y = torch.ones([batch_size, 1, 10, 10]) / 2 + y.requires_grad_(True) + data_range = x.max().unsqueeze(0) + TESTS2D_GRAD.append([x.to(device), y.to(device), data_range.to(device)]) class TestSSIMLoss(unittest.TestCase): diff --git a/tests/test_ssim_metric.py b/tests/test_ssim_metric.py index 01c48dd793..5505e5b750 100644 --- a/tests/test_ssim_metric.py +++ b/tests/test_ssim_metric.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_state_cacher.py b/tests/test_state_cacher.py index e4164be272..2037dc3951 100644 --- a/tests/test_state_cacher.py +++ b/tests/test_state_cacher.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pickle import unittest from os.path import exists, join @@ -36,7 +38,6 @@ class TestStateCacher(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_state_cacher(self, data_obj, params): - key = "data_obj" state_cacher = StateCacher(**params) diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py index b8306aa09c..af18c18aa2 100644 --- a/tests/test_std_shift_intensity.py +++ b/tests/test_std_shift_intensity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_std_shift_intensityd.py b/tests/test_std_shift_intensityd.py index b86f6bd5e6..6cb7d416c7 100644 --- a/tests/test_std_shift_intensityd.py +++ b/tests/test_std_shift_intensityd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_str2bool.py b/tests/test_str2bool.py index e1d9ca1ee3..36f99b4064 100644 --- a/tests/test_str2bool.py +++ b/tests/test_str2bool.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.utils.misc import str2bool diff --git a/tests/test_str2list.py b/tests/test_str2list.py index 95a4dcaef0..b442925fb3 100644 --- a/tests/test_str2list.py +++ b/tests/test_str2list.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.utils.misc import str2list diff --git a/tests/test_subpixel_upsample.py b/tests/test_subpixel_upsample.py index bd46aecb97..a6de8dd846 100644 --- a/tests/test_subpixel_upsample.py +++ b/tests/test_subpixel_upsample.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index ccc6242e1e..3ee54e5903 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -17,13 +19,15 @@ from monai.metrics.surface_dice import SurfaceDiceMetric +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + class TestAllSurfaceDiceMetrics(unittest.TestCase): def test_tolerance_euclidean_distance(self): batch_size = 2 n_class = 2 - predictions = torch.zeros((batch_size, 480, 640), dtype=torch.int64) - labels = torch.zeros((batch_size, 480, 640), dtype=torch.int64) + predictions = torch.zeros((batch_size, 480, 640), dtype=torch.int64, device=_device) + labels = torch.zeros((batch_size, 480, 640), dtype=torch.int64, device=_device) predictions[0, :, 50:] = 1 labels[0, :, 60:] = 1 # 10 px shift predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2) @@ -36,8 +40,10 @@ def test_tolerance_euclidean_distance(self): res0_nans = sd0_nans(predictions_hot, labels_hot) agg0_nans, not_nans = sd0_nans.aggregate() - np.testing.assert_array_equal(res0, res0_nans) - np.testing.assert_array_equal(agg0, agg0_nans) + np.testing.assert_array_equal(res0.cpu(), res0_nans.cpu()) + np.testing.assert_equal(res0.device, predictions.device) + np.testing.assert_array_equal(agg0.cpu(), agg0_nans.cpu()) + np.testing.assert_equal(agg0.device, predictions.device) res1 = SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels_hot) res9 = SurfaceDiceMetric(class_thresholds=[9, 9], include_background=True)(predictions_hot, labels_hot) @@ -49,7 +55,7 @@ def test_tolerance_euclidean_distance(self): assert res0[0, 0] < res1[0, 0] < res9[0, 0] < res10[0, 0] assert res0[0, 1] < res1[0, 1] < res9[0, 1] < res10[0, 1] - np.testing.assert_array_equal(res10, res11) + np.testing.assert_array_equal(res10.cpu(), res11.cpu()) expected_res0 = np.zeros((batch_size, n_class)) expected_res0[0, 1] = 1 - (478 + 480 + 9 * 2) / (480 * 4 + 588 * 2 + 578 * 2) @@ -57,9 +63,9 @@ def test_tolerance_euclidean_distance(self): expected_res0[1, 0] = 1 expected_res0[1, 1] = np.nan for b, c in np.ndindex(batch_size, n_class): - np.testing.assert_allclose(expected_res0[b, c], res0[b, c]) - np.testing.assert_array_equal(agg0, np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) - np.testing.assert_equal(not_nans, torch.tensor(2)) + np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu()) + np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) + np.testing.assert_equal(not_nans.cpu(), torch.tensor(2)) def test_tolerance_all_distances(self): batch_size = 1 @@ -275,17 +281,17 @@ def test_not_predicted_not_present(self): # test aggregation res_bgr = sur_metric_bgr.aggregate(reduction="mean") - np.testing.assert_equal(res_bgr, torch.tensor([1 / 3], dtype=torch.float64)) + np.testing.assert_equal(res_bgr, torch.tensor([1 / 3], dtype=torch.float)) res = sur_metric.aggregate() - np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float64)) + np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float)) predictions_empty = torch.zeros((2, 3, 1, 1)) sur_metric_nans = SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=True, get_not_nans=True) res_classes = sur_metric_nans(predictions_empty, predictions_empty) res, not_nans = sur_metric_nans.aggregate() np.testing.assert_array_equal(res_classes, [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]]) - np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float64)) - np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float64)) + np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float)) + np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float)) if __name__ == "__main__": diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index 4cd70b43aa..f2e2ea7144 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Tuple import numpy as np import torch @@ -18,9 +19,11 @@ from monai.metrics import SurfaceDistanceMetric +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + def create_spherical_seg_3d( - radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), im_shape: Tuple[int, int, int] = (99, 99, 99) + radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99) ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -110,8 +113,8 @@ def test_value(self, input_data, expected_value): [seg_1, seg_2] = input_data metric = "euclidean" ct = 0 - seg_1 = torch.tensor(seg_1) - seg_2 = torch.tensor(seg_2) + seg_1 = torch.tensor(seg_1, device=_device) + seg_2 = torch.tensor(seg_2, device=_device) for symmetric in [True, False]: sur_metric = SurfaceDistanceMetric(include_background=False, symmetric=symmetric, distance_metric=metric) # shape of seg_1, seg_2 are: HWD, converts to BNHWD @@ -121,7 +124,8 @@ def test_value(self, input_data, expected_value): sur_metric(batch_seg_1, batch_seg_2) result = sur_metric.aggregate() expected_value_curr = expected_value[ct] - np.testing.assert_allclose(expected_value_curr, result, rtol=1e-5) + np.testing.assert_allclose(expected_value_curr, result.cpu(), rtol=1e-5) + np.testing.assert_equal(result.device, seg_1.device) ct += 1 @parameterized.expand(TEST_CASES_NANS) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 6188d6225a..636fcc9e31 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from unittest import skipUnless diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py index 0da7054a10..116897e67d 100644 --- a/tests/test_synthetic.py +++ b/tests/test_synthetic.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_tciadataset.py b/tests/test_tciadataset.py index 301f2caac2..69f88927b9 100644 --- a/tests/test_tciadataset.py +++ b/tests/test_tciadataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 93a569186d..20831ca294 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from functools import partial from typing import TYPE_CHECKING diff --git a/tests/test_thread_buffer.py b/tests/test_thread_buffer.py index 87da22eab3..ab5dba77be 100644 --- a/tests/test_thread_buffer.py +++ b/tests/test_thread_buffer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import sys import time import unittest diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 2419b390fd..ca9fb244fc 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import time diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py index 3c0a2033ee..7fb28d413f 100644 --- a/tests/test_threshold_intensity.py +++ b/tests/test_threshold_intensity.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py index 8aade12322..d5e7e5f517 100644 --- a/tests/test_threshold_intensityd.py +++ b/tests/test_threshold_intensityd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py index 4aca252350..06b5fe65f0 100644 --- a/tests/test_tile_on_grid.py +++ b/tests/test_tile_on_grid.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Optional import numpy as np from parameterized import parameterized @@ -72,11 +73,10 @@ def make_image( tile_size: int, step: int = 0, random_offset: bool = False, - filter_mode: Optional[str] = None, + filter_mode: str | None = None, seed=123, **kwargs, ): - tile_count = int(np.sqrt(tile_count)) pad = 0 if random_offset: @@ -116,7 +116,6 @@ def make_image( class TestTileOnGrid(unittest.TestCase): @parameterized.expand(TESTS) def test_tile_patch_single_call(self, in_type, input_parameters): - img, tiles = make_image(**input_parameters) input_img = in_type(img) @@ -126,7 +125,6 @@ def test_tile_patch_single_call(self, in_type, input_parameters): @parameterized.expand(TESTS2) def test_tile_patch_random_call(self, in_type, input_parameters): - img, tiles = make_image(**input_parameters, seed=123) input_img = in_type(img) diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py index fa94dd0a70..bb8689fd3b 100644 --- a/tests/test_tile_on_grid_dict.py +++ b/tests/test_tile_on_grid_dict.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Optional import numpy as np import torch @@ -81,11 +82,10 @@ def make_image( tile_size: int, step: int = 0, random_offset: bool = False, - filter_mode: Optional[str] = None, + filter_mode: str | None = None, seed=123, **kwargs, ): - tile_count = int(np.sqrt(tile_count)) pad = 0 if random_offset: @@ -125,7 +125,6 @@ def make_image( class TestTileOnGridDict(unittest.TestCase): @parameterized.expand(TESTS) def test_tile_patch_single_call(self, in_type, input_parameters): - key = "image" input_parameters["keys"] = key @@ -148,7 +147,6 @@ def test_tile_patch_single_call(self, in_type, input_parameters): @parameterized.expand(TESTS2) def test_tile_patch_random_call(self, in_type, input_parameters): - key = "image" input_parameters["keys"] = key diff --git a/tests/test_timedcall_dist.py b/tests/test_timedcall_dist.py index a2b3ae585a..af7cf8720f 100644 --- a/tests/test_timedcall_dist.py +++ b/tests/test_timedcall_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import multiprocessing import sys import time diff --git a/tests/test_to_contiguous.py b/tests/test_to_contiguous.py index a9c2a78278..3a57ae6d8b 100644 --- a/tests/test_to_contiguous.py +++ b/tests/test_to_contiguous.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index 36edf24f3f..12a377181d 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from unittest import skipUnless diff --git a/tests/test_to_cupyd.py b/tests/test_to_cupyd.py index 3e778ae269..e9a3488489 100644 --- a/tests/test_to_cupyd.py +++ b/tests/test_to_cupyd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from unittest import skipUnless diff --git a/tests/test_to_device.py b/tests/test_to_device.py index 9f78119326..cad2b65316 100644 --- a/tests/test_to_device.py +++ b/tests/test_to_device.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_to_deviced.py b/tests/test_to_deviced.py index b3ee490566..093c3b0c4d 100644 --- a/tests/test_to_deviced.py +++ b/tests/test_to_deviced.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_to_from_meta_tensord.py b/tests/test_to_from_meta_tensord.py index 43a0e99081..470826313a 100644 --- a/tests/test_to_from_meta_tensord.py +++ b/tests/test_to_from_meta_tensord.py @@ -9,11 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import random import string import unittest from copy import deepcopy -from typing import Optional, Union import numpy as np import torch @@ -67,7 +68,7 @@ def check( shape: bool = True, vals: bool = True, ids: bool = True, - device: Optional[Union[str, torch.device]] = None, + device: str | torch.device | None = None, meta: bool = True, check_ids: bool = False, **kwargs, diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index e1f135a289..0c604fb9d4 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from unittest import skipUnless diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py index ba7cf798ef..d25bdf14a5 100644 --- a/tests/test_to_numpyd.py +++ b/tests/test_to_numpyd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from unittest import skipUnless diff --git a/tests/test_to_onehot.py b/tests/test_to_onehot.py index c08672bfb2..52307900af 100644 --- a/tests/test_to_onehot.py +++ b/tests/test_to_onehot.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_to_pil.py b/tests/test_to_pil.py index 0a1351028c..e4f74f6e1e 100644 --- a/tests/test_to_pil.py +++ b/tests/test_to_pil.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from typing import TYPE_CHECKING from unittest import skipUnless diff --git a/tests/test_to_pild.py b/tests/test_to_pild.py index d00ecf13d4..4eb5999b15 100644 --- a/tests/test_to_pild.py +++ b/tests/test_to_pild.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from typing import TYPE_CHECKING from unittest import skipUnless diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index cd1a814f21..cde845c246 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_to_tensord.py b/tests/test_to_tensord.py index 4c1f2172ae..82456786fd 100644 --- a/tests/test_to_tensord.py +++ b/tests/test_to_tensord.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py index cdf2f19eb3..ec24f388f1 100644 --- a/tests/test_torchscript_utils.py +++ b/tests/test_torchscript_utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_torchvision.py b/tests/test_torchvision.py index 68b9413e65..9cd536aa6f 100644 --- a/tests/test_torchvision.py +++ b/tests/test_torchvision.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from parameterized import parameterized diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py index d7341bc71e..5f92a1f8b4 100644 --- a/tests/test_torchvision_fc_model.py +++ b/tests/test_torchvision_fc_model.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from unittest import skipUnless diff --git a/tests/test_torchvisiond.py b/tests/test_torchvisiond.py index b72c6f86f1..b2a6bcafc5 100644 --- a/tests/test_torchvisiond.py +++ b/tests/test_torchvisiond.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py index bc6aad3a62..42906c84d2 100644 --- a/tests/test_traceable_transform.py +++ b/tests/test_traceable_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from monai.transforms.inverse import TraceableTransform @@ -28,6 +30,8 @@ class TestTraceable(unittest.TestCase): def test_default(self): expected_key = "_transforms" a = _TraceTest() + for x in a.transform_info_keys(): + self.assertTrue(x in a.get_transform_info()) self.assertEqual(a.trace_key(), expected_key) data = {"image": "test"} diff --git a/tests/test_train_mode.py b/tests/test_train_mode.py index 231e3854f0..6136e2f7db 100644 --- a/tests/test_train_mode.py +++ b/tests/test_train_mode.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_trainable_bilateral.py b/tests/test_trainable_bilateral.py new file mode 100644 index 0000000000..43b628be80 --- /dev/null +++ b/tests/test_trainable_bilateral.py @@ -0,0 +1,474 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized +from torch.autograd import gradcheck + +from monai.networks.layers.filtering import TrainableBilateralFilterFunction +from tests.utils import skip_if_no_cpp_extension, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Description + "1 dimension, 1 channel, low spatial sigmas, low color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (1.0, 1.0, 1.0, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.999997, 0.000001, 0.000000, 0.000001, 0.999997] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000001, 0.999995, 0.000001, 0.000000] + ], + ], + ], + [ + # Case Description + "1 dimension, 1 channel, low spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (1, 1, 1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.714200, 0.158126, 0.061890, 0.158126, 0.714200] + ], + # Batch 1 + [ + # Channel 0 + [0.043465, 0.158126, 0.555452, 0.158126, 0.043465] + ], + ], + ], + [ + # Case Description + "1 dimension, 1 channel, high spatial sigmas, low color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.999994, 0.000002, 0.000002, 0.000002, 0.999994] + ], + # Batch 1 + [ + # Channel 0 + [0.000001, 0.000001, 0.999986, 0.000001, 0.000001] + ], + ], + ], + [ + # Case Description + "1 dimension, 1 channel, high spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.533282, 0.245915, 0.244711, 0.245915, 0.533282] + ], + # Batch 1 + [ + # Channel 0 + [0.125052, 0.126608, 0.333592, 0.126608, 0.125052] + ], + ], + ], + [ + # Case Description + "2 dimensions, 1 channel, high spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.239789, 0.082990, 0.082630, 0.082990, 0.239789], + [0.082990, 0.081934, 0.081579, 0.081934, 0.082990], + [0.082630, 0.081579, 0.081225, 0.081579, 0.082630], + [0.082990, 0.081934, 0.081579, 0.081934, 0.082990], + [0.239789, 0.082990, 0.082630, 0.082990, 0.239789], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.024155, 0.024432, 0.024525, 0.024432, 0.024155], + [0.024432, 0.024712, 0.024806, 0.024712, 0.024432], + [0.024525, 0.024806, 0.080686, 0.024806, 0.024525], + [0.024432, 0.024712, 0.024806, 0.024712, 0.024432], + [0.024155, 0.024432, 0.024525, 0.024432, 0.024155], + ] + ], + ], + ], + [ + # Case Description + "3 dimensions, 1 channel, high spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.098142, 0.030317, 0.030191, 0.030316, 0.098142], + [0.030316, 0.029947, 0.029822, 0.029947, 0.030316], + [0.030191, 0.029822, 0.029698, 0.029822, 0.030191], + [0.030316, 0.029947, 0.029822, 0.029947, 0.030316], + [0.098142, 0.030317, 0.030191, 0.030317, 0.098142], + ], + # Frame 1 + [ + [0.030316, 0.029947, 0.029822, 0.029947, 0.030316], + [0.029947, 0.029581, 0.029458, 0.029581, 0.029947], + [0.029822, 0.029458, 0.029336, 0.029458, 0.029822], + [0.029947, 0.029581, 0.029458, 0.029581, 0.029947], + [0.030316, 0.029947, 0.029822, 0.029947, 0.030316], + ], + # Frame 2 + [ + [0.030191, 0.029822, 0.029698, 0.029822, 0.030191], + [0.029822, 0.029458, 0.029336, 0.029458, 0.029822], + [0.029698, 0.029336, 0.029214, 0.029336, 0.029698], + [0.029822, 0.029458, 0.029336, 0.029458, 0.029822], + [0.030191, 0.029822, 0.029698, 0.029822, 0.030191], + ], + # Frame 3 + [ + [0.030316, 0.029947, 0.029822, 0.029947, 0.030317], + [0.029947, 0.029581, 0.029458, 0.029581, 0.029947], + [0.029822, 0.029458, 0.029336, 0.029458, 0.029822], + [0.029947, 0.029581, 0.029458, 0.029581, 0.029947], + [0.030316, 0.029947, 0.029822, 0.029947, 0.030316], + ], + # Frame 4 + [ + [0.098142, 0.030317, 0.030191, 0.030317, 0.098142], + [0.030317, 0.029947, 0.029822, 0.029947, 0.030316], + [0.030191, 0.029822, 0.029698, 0.029822, 0.030191], + [0.030317, 0.029947, 0.029822, 0.029947, 0.030316], + [0.098142, 0.030317, 0.030191, 0.030316, 0.098142], + ], + ] + ] + ], + ], +] + + +@skip_if_no_cpp_extension +class BilateralFilterTestCaseCpuPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cpu_precise(self, test_case_description, sigmas, input, expected): + # Params to determine the implementation to test + device = torch.device("cpu") + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + + len_input = len(input_tensor.shape) + # C++ extension so far only supports 5-dim inputs. + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + output = TrainableBilateralFilterFunction.apply(input_tensor, *sigmas).cpu().numpy() + + # Make sure to return tensor of the same shape as the input. + if len_input == 3: + output = output.squeeze(4).squeeze(3) + elif len_input == 4: + output = output.squeeze(4) + + # Ensure result are as expected. + np.testing.assert_allclose(output, expected, atol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expected): + # Params to determine the implementation to test + device = torch.device("cpu") + + # Prepare input tensor + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + input_tensor.requires_grad = True + + # C++ extension so far only supports 5-dim inputs. + len_input = len(input_tensor.shape) + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + # Check gradient toward input. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False) + input_tensor = input_tensor.detach() + input_tensor.requires_grad = False + + # Check gradient toward sigma_x. + args = ( + input_tensor, + torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_y. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_z. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_color. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False) + + +@skip_if_no_cuda +@skip_if_no_cpp_extension +class BilateralFilterTestCaseCudaPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cuda_precise(self, test_case_description, sigmas, input, expected): + # Skip this test + if not torch.cuda.is_available(): + return + + # Params to determine the implementation to test + device = torch.device("cuda") + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + + len_input = len(input_tensor.shape) + # C++ extension so far only supports 5-dim inputs. + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + output = TrainableBilateralFilterFunction.apply(input_tensor, *sigmas).cpu().numpy() + + # Make sure to return tensor of the same shape as the input. + if len_input == 3: + output = output.squeeze(4).squeeze(3) + elif len_input == 4: + output = output.squeeze(4) + + # Ensure result are as expected. + np.testing.assert_allclose(output, expected, atol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_cuda_precise_backwards(self, test_case_description, sigmas, input, expected): + # Params to determine the implementation to test + device = torch.device("cuda") + + # Prepare input tensor + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + input_tensor.requires_grad = True + + # C++ extension so far only supports 5-dim inputs. + len_input = len(input_tensor.shape) + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + + # Check gradient toward input. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False) + input_tensor = input_tensor.detach() + input_tensor.requires_grad = False + + # Check gradient toward sigma_x. + args = ( + input_tensor, + torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_y. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_z. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_color. + args = ( + input_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True), + ) + gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_trainable_joint_bilateral.py b/tests/test_trainable_joint_bilateral.py new file mode 100644 index 0000000000..a42510b7c6 --- /dev/null +++ b/tests/test_trainable_joint_bilateral.py @@ -0,0 +1,608 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized +from torch.autograd import gradcheck + +from monai.networks.layers.filtering import TrainableJointBilateralFilterFunction +from tests.utils import skip_if_no_cpp_extension, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Description + "1 dimension, 1 channel, low spatial sigmas, low color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (1.0, 1.0, 1.0, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Guide + [ + # Batch 0 + [ + # Channel 0 + [1, 1, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 1] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.622459, 0.377540, 0.000001, 0.000001, 0.999997] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000001, 0.880793, 0.000002, 0.119203] + ], + ], + ], + [ + # Case Description + "1 dimension, 1 channel, low spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (1, 1, 1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Guide + [ + # Batch 0 + [ + # Channel 0 + [1, 1, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 1] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.595404, 0.302253, 0.070203, 0.163038, 0.714200] + ], + # Batch 1 + [ + # Channel 0 + [0.043465, 0.158126, 0.536864, 0.182809, 0.092537] + ], + ], + ], + [ + # Case Description + "1 dimension, 1 channel, high spatial sigmas, low color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Guide + [ + # Batch 0 + [ + # Channel 0 + [1, 1, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 1] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.623709, 0.632901, 0.000003, 0.000003, 0.680336] + ], + # Batch 1 + [ + # Channel 0 + [0.000001, 0.000001, 0.531206, 0.000001, 0.468788] + ], + ], + ], + [ + # Case Description + "1 dimension, 1 channel, high spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Guide + [ + # Batch 0 + [ + # Channel 0 + [1, 1, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 1] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.464455, 0.463098, 0.276430, 0.275530, 0.478105] + ], + # Batch 1 + [ + # Channel 0 + [0.134956, 0.138247, 0.293759, 0.141954, 0.281082] + ], + ], + ], + [ + # Case Description + "2 dimensions, 1 channel, high spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Guide + [ + # Batch 0 + [ + # Channel 0 + [[1, 1, 0, 0, 1], [0, 0, 0, 1, 0], [1, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 1, 0], [0, 0, 0, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.186535, 0.187357, 0.105377, 0.103652, 0.198665], + [0.112617, 0.108847, 0.105970, 0.189602, 0.102954], + [0.178338, 0.179829, 0.107473, 0.105256, 0.103963], + [0.117651, 0.113304, 0.109876, 0.107392, 0.105853], + [0.121557, 0.177689, 0.113150, 0.110388, 0.192877], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.047156, 0.047865, 0.048233, 0.038611, 0.047911], + [0.047607, 0.048292, 0.048633, 0.039251, 0.038611], + [0.047715, 0.048369, 0.048678, 0.048633, 0.048233], + [0.047477, 0.048094, 0.048369, 0.048292, 0.047865], + [0.039190, 0.047477, 0.047715, 0.047607, 0.047156], + ] + ], + ], + ], + [ + # Case Description + "3 dimensions, 1 channel, high spatial sigmas, high color sigma", + # (sigma_x, sigma_y, sigma_z, color_sigma) + (4, 4, 4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ] + ], + # Guide + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 1, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 1, 0, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 1, 1]], + ] + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.089316, 0.088903, 0.033707, 0.033461, 0.091881], + [0.035173, 0.034324, 0.033747, 0.033448, 0.033427], + [0.035619, 0.034710, 0.034074, 0.033720, 0.033646], + [0.036364, 0.035387, 0.034688, 0.034275, 0.034148], + [0.037401, 0.085687, 0.035583, 0.035109, 0.089741], + ], + # Frame 1 + [ + [0.034248, 0.033502, 0.033023, 0.032816, 0.032881], + [0.034339, 0.033546, 0.033020, 0.032767, 0.032785], + [0.034721, 0.033876, 0.033298, 0.032994, 0.032965], + [0.035397, 0.034490, 0.033856, 0.033501, 0.033424], + [0.036357, 0.035383, 0.034688, 0.034279, 0.034155], + ], + # Frame 2 + [ + [0.033748, 0.033047, 0.032609, 0.032441, 0.032541], + [0.033782, 0.033041, 0.032562, 0.032353, 0.032410], + [0.034104, 0.033316, 0.032792, 0.032538, 0.032554], + [0.034714, 0.033872, 0.033298, 0.032998, 0.032972], + [0.035604, 0.034702, 0.034074, 0.033727, 0.033660], + ], + # Frame 3 + [ + [0.033533, 0.032871, 0.032471, 0.032340, 0.032476], + [0.033511, 0.032815, 0.032380, 0.032212, 0.032310], + [0.033775, 0.033037, 0.032562, 0.032356, 0.032417], + [0.034324, 0.033539, 0.033020, 0.032774, 0.032799], + [0.035151, 0.034313, 0.033747, 0.033459, 0.033449], + ], + # Frame 4 + [ + [0.091383, 0.090681, 0.032608, 0.091418, 0.092851], + [0.033525, 0.032867, 0.032471, 0.032344, 0.032483], + [0.033733, 0.033039, 0.032609, 0.032448, 0.032555], + [0.034226, 0.033491, 0.033023, 0.032827, 0.032903], + [0.089445, 0.034216, 0.033707, 0.090126, 0.091748], + ], + ] + ] + ], + ], +] + + +@skip_if_no_cpp_extension +class JointBilateralFilterTestCaseCpuPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cpu_precise(self, test_case_description, sigmas, input, guide, expected): + # Params to determine the implementation to test + device = torch.device("cpu") + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + guide_tensor = torch.from_numpy(np.array(guide)).to(dtype=torch.double, device=device) + + len_input = len(input_tensor.shape) + # C++ extension so far only supports 5-dim inputs. + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + guide_tensor = guide_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + guide_tensor = guide_tensor.unsqueeze(4) + + output = TrainableJointBilateralFilterFunction.apply(input_tensor, guide_tensor, *sigmas).cpu().numpy() + + # Make sure to return tensor of the same shape as the input. + if len_input == 3: + output = output.squeeze(4).squeeze(3) + elif len_input == 4: + output = output.squeeze(4) + + # Ensure result are as expected. + np.testing.assert_allclose(output, expected, atol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_cpu_precise_backwards(self, test_case_description, sigmas, input, guide, expected): + # Params to determine the implementation to test + device = torch.device("cpu") + + # Prepare input tensor + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + input_tensor.requires_grad = True + guide_tensor = torch.from_numpy(np.array(guide)).to(dtype=torch.double, device=device) + + # C++ extension so far only supports 5-dim inputs. + len_input = len(input_tensor.shape) + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + guide_tensor = guide_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + guide_tensor = guide_tensor.unsqueeze(4) + + # Check gradient toward input. + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False) + input_tensor = input_tensor.detach() + input_tensor.requires_grad = False + + # Check gradient toward guide. + guide_tensor.requires_grad = True + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False) + guide_tensor = guide_tensor.detach() + guide_tensor.guide_tensor = False + + # Check gradient toward sigma_x. + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_y. + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_z. + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_color. + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False) + + +@skip_if_no_cuda +@skip_if_no_cpp_extension +class JointBilateralFilterTestCaseCudaPrecise(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cuda_precise(self, test_case_description, sigmas, input, guide, expected): + # Skip this test + if not torch.cuda.is_available(): + return + + # Params to determine the implementation to test + device = torch.device("cuda") + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + guide_tensor = torch.from_numpy(np.array(guide)).to(dtype=torch.double, device=device) + + len_input = len(input_tensor.shape) + # C++ extension so far only supports 5-dim inputs. + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + guide_tensor = guide_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + guide_tensor = guide_tensor.unsqueeze(4) + + output = TrainableJointBilateralFilterFunction.apply(input_tensor, guide_tensor, *sigmas).cpu().numpy() + + # Make sure to return tensor of the same shape as the input. + if len_input == 3: + output = output.squeeze(4).squeeze(3) + elif len_input == 4: + output = output.squeeze(4) + + # Ensure result are as expected. + np.testing.assert_allclose(output, expected, atol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_cuda_precise_backwards(self, test_case_description, sigmas, input, guide, expected): + # Params to determine the implementation to test + device = torch.device("cuda") + + # Prepare input tensor + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device) + input_tensor.requires_grad = True + guide_tensor = torch.from_numpy(np.array(guide)).to(dtype=torch.double, device=device) + + # C++ extension so far only supports 5-dim inputs. + len_input = len(input_tensor.shape) + if len_input == 3: + input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) + guide_tensor = guide_tensor.unsqueeze(3).unsqueeze(4) + elif len_input == 4: + input_tensor = input_tensor.unsqueeze(4) + guide_tensor = guide_tensor.unsqueeze(4) + + # Check gradient toward input. + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False) + input_tensor = input_tensor.detach() + input_tensor.requires_grad = False + + # Check gradient toward guide. + guide_tensor.requires_grad = True + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False) + guide_tensor = guide_tensor.detach() + guide_tensor.guide_tensor = False + + # Check gradient toward sigma_x. + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_y. + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_z. + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True), + torch.tensor(sigmas[3]), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False) + + # Check gradient toward sigma_color. + args = ( + input_tensor, + guide_tensor, + torch.tensor(sigmas[0]), + torch.tensor(sigmas[1]), + torch.tensor(sigmas[2]), + torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True), + ) + gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transchex.py b/tests/test_transchex.py index 713bc35f56..9ad847cdaa 100644 --- a/tests/test_transchex.py +++ b/tests/test_transchex.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_transform.py b/tests/test_transform.py index a6c5001147..57903ab88c 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index d6131d010c..f1a20b842c 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -23,7 +25,6 @@ for hidden_size in [360, 480, 600, 768]: for num_heads in [4, 8, 12]: for mlp_dim in [1024, 3072]: - test_case = [ { "hidden_size": hidden_size, diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 94a5b49c3a..0c9ae1c7e3 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_transposed.py b/tests/test_transposed.py index 14e62eb9da..ab80520fc9 100644 --- a/tests/test_transposed.py +++ b/tests/test_transposed.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest from copy import deepcopy diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index 22f57cc8c6..d1175f40c5 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_unet.py b/tests/test_unet.py index a90e32230b..9cb4af3379 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_unetr.py b/tests/test_unetr.py index b233c72f24..406d30aa12 100644 --- a/tests/test_unetr.py +++ b/tests/test_unetr.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_unetr_block.py b/tests/test_unetr_block.py index 8a4ee3a163..60004be25e 100644 --- a/tests/test_unetr_block.py +++ b/tests/test_unetr_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_unified_focal_loss.py b/tests/test_unified_focal_loss.py index 1a7bb91059..0e7217e2b4 100644 --- a/tests/test_unified_focal_loss.py +++ b/tests/test_unified_focal_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_upsample_block.py b/tests/test_upsample_block.py index 535ad80c11..a82a31b064 100644 --- a/tests/test_upsample_block.py +++ b/tests/test_upsample_block.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 7041a09f52..619ae8aee3 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_varautoencoder.py b/tests/test_varautoencoder.py index a6315ebc63..b050983d2c 100644 --- a/tests/test_varautoencoder.py +++ b/tests/test_varautoencoder.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_varnet.py b/tests/test_varnet.py index c715e7d37f..3ec6b0f087 100644 --- a/tests/test_varnet.py +++ b/tests/test_varnet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_version_leq.py b/tests/test_version_leq.py index 725c1ee128..ef9e70ad86 100644 --- a/tests/test_version_leq.py +++ b/tests/test_version_leq.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import itertools import unittest diff --git a/tests/test_video_datasets.py b/tests/test_video_datasets.py index 78e015e350..eedbe212eb 100644 --- a/tests/test_video_datasets.py +++ b/tests/test_video_datasets.py @@ -9,9 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest -from typing import Type, Union import torch @@ -31,8 +32,8 @@ class Base: class TestVideoDataset(unittest.TestCase): - video_source: Union[int, str] - ds: Type[VideoDataset] + video_source: int | str + ds: type[VideoDataset] def get_video_source(self): return self.video_source diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index 2137926424..bb3ff7237a 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_vis_gradbased.py b/tests/test_vis_gradbased.py index 5af8769872..0fbe328c83 100644 --- a/tests/test_vis_gradbased.py +++ b/tests/test_vis_gradbased.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index 8ec5e2c913..f5ba188082 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest -from typing import Any, List +from typing import Any import numpy as np import torch @@ -28,8 +30,8 @@ def __call__(self, x, adjoint_info): return super().__call__(x) -TESTS: List[Any] = [] -TESTS_ILL: List[Any] = [] +TESTS: list[Any] = [] +TESTS_ILL: list[Any] = [] for cam in (GradCAM, GradCAMpp): # 2D diff --git a/tests/test_vit.py b/tests/test_vit.py index 33a7902fad..504c1ccebd 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch @@ -16,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets.vit import ViT -from tests.utils import SkipIfBeforePyTorchVersion, test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick, test_script_save TEST_CASE_Vit = [] for dropout_rate in [0.6]: @@ -57,6 +59,7 @@ TEST_CASE_Vit.append(test_case) +@skip_if_quick class TestViT(unittest.TestCase): @parameterized.expand(TEST_CASE_Vit) def test_shape(self, input_param, input_shape, expected_shape): diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index f7685f946c..f9f489b9b9 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -8,6 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_vnet.py b/tests/test_vnet.py index add0396bd8..633893ce51 100644 --- a/tests/test_vnet.py +++ b/tests/test_vnet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_vote_ensemble.py b/tests/test_vote_ensemble.py index 79868d4706..32ff120c5d 100644 --- a/tests/test_vote_ensemble.py +++ b/tests/test_vote_ensemble.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_vote_ensembled.py b/tests/test_vote_ensembled.py index e42a57f3b7..17f9d54835 100644 --- a/tests/test_vote_ensembled.py +++ b/tests/test_vote_ensembled.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_warp.py b/tests/test_warp.py index 31f3540c9e..e614973f90 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -8,6 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest @@ -20,7 +22,13 @@ from monai.networks.blocks.warp import Warp from monai.transforms import LoadImaged from monai.utils import GridSampleMode, GridSamplePadMode -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule, download_url_or_skip_test, testing_data_config +from tests.utils import ( + SkipIfBeforePyTorchVersion, + SkipIfNoModule, + download_url_or_skip_test, + skip_if_quick, + testing_data_config, +) LOW_POWER_TEST_CASES = [ # run with BUILD_MONAI=1 to test csrc/resample, BUILD_MONAI=0 to test native grid_sample [ @@ -96,6 +104,7 @@ TEST_CASES += CPP_TEST_CASES +@skip_if_quick class TestWarp(unittest.TestCase): def setUp(self): config = testing_data_config("images", "Prostate_T2W_AX_1") diff --git a/tests/test_watershed.py b/tests/test_watershed.py index 705ddce817..a5a232ba3c 100644 --- a/tests/test_watershed.py +++ b/tests/test_watershed.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_watershedd.py b/tests/test_watershedd.py index 6e04409e1d..c12f5ad140 100644 --- a/tests/test_watershedd.py +++ b/tests/test_watershedd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_weight_init.py b/tests/test_weight_init.py index c850ff4ce6..376faacc56 100644 --- a/tests/test_weight_init.py +++ b/tests/test_weight_init.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_weighted_random_sampler_dist.py b/tests/test_weighted_random_sampler_dist.py index d5322b6482..d38bab54f0 100644 --- a/tests/test_weighted_random_sampler_dist.py +++ b/tests/test_weighted_random_sampler_dist.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py index 36d5c0c843..ec55654f07 100644 --- a/tests/test_with_allow_missing_keys.py +++ b/tests/test_with_allow_missing_keys.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_write_metrics_reports.py b/tests/test_write_metrics_reports.py index 769803682a..4f61e43fe1 100644 --- a/tests/test_write_metrics_reports.py +++ b/tests/test_write_metrics_reports.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import csv import os import tempfile diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 45f11bb138..e60661b25e 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -9,12 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import unittest -from typing import Any, Tuple +from typing import Any from unittest import skipUnless import numpy as np +import torch from numpy.testing import assert_array_equal from parameterized import parameterized @@ -25,7 +28,7 @@ from monai.transforms import Compose, LoadImaged, ToTensord from monai.utils import deprecated, first, optional_import from monai.utils.enums import PostFix, WSIPatchKeys -from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config +from tests.utils import assert_allclose, download_url_or_skip_test, skip_if_no_cuda, testing_data_config cucim, has_cucim = optional_import("cucim") has_cucim = has_cucim and hasattr(cucim, "CuImage") @@ -42,10 +45,13 @@ HEIGHT = 32914 WIDTH = 46000 -TEST_CASE_0 = [FILE_PATH, 2, (3, HEIGHT // 4, WIDTH // 4)] +TEST_CASE_WHOLE_0 = [FILE_PATH, 2, (3, HEIGHT // 4, WIDTH // 4)] TEST_CASE_TRANSFORM_0 = [FILE_PATH, 4, (HEIGHT // 16, WIDTH // 16), (1, 3, HEIGHT // 16, WIDTH // 16)] +# ---------------------------------------------------------------------------- +# Test cases for *deprecated* monai.data.image_reader.WSIReader +# ---------------------------------------------------------------------------- TEST_CASE_DEP_1 = [ FILE_PATH, {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0}, @@ -54,8 +60,8 @@ TEST_CASE_DEP_2 = [ FILE_PATH, - {"location": (0, 0), "size": (2, 1), "level": 2}, - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), + {"location": (0, 0), "size": (2, 1), "level": 8}, + np.array([[[242], [242]], [[242], [242]], [[242], [242]]]), ] TEST_CASE_DEP_3 = [ @@ -81,48 +87,144 @@ np.array([[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[237, 237], [237, 237]]]), ] +# ---------------------------------------------------------------------------- +# Test cases for monai.data.wsi_reader.WSIReader +# ---------------------------------------------------------------------------- + +TEST_CASE_0 = [ + FILE_PATH, + {"level": 8, "dtype": None}, + {"location": (0, 0), "size": (2, 1)}, + np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.float64), +] + TEST_CASE_1 = [ FILE_PATH, {}, {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0}, - np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), + np.array([[[246], [246]], [[246], [246]], [[246], [246]]], dtype=np.uint8), ] TEST_CASE_2 = [ FILE_PATH, {}, - {"location": (0, 0), "size": (2, 1), "level": 2}, - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), + {"location": (0, 0), "size": (2, 1), "level": 8}, + np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.uint8), ] TEST_CASE_3 = [ FILE_PATH, {"channel_dim": -1}, {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0}, - np.moveaxis(np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), 0, -1), + np.moveaxis(np.array([[[246], [246]], [[246], [246]], [[246], [246]]], dtype=np.uint8), 0, -1), ] TEST_CASE_4 = [ FILE_PATH, {"channel_dim": 2}, - {"location": (0, 0), "size": (2, 1), "level": 2}, - np.moveaxis(np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), 0, -1), + {"location": (0, 0), "size": (2, 1), "level": 8}, + np.moveaxis(np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.uint8), 0, -1), ] TEST_CASE_5 = [ FILE_PATH, - {"level": 2}, + {"level": 8}, + {"location": (0, 0), "size": (2, 1)}, + np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.uint8), +] + +TEST_CASE_6 = [ + FILE_PATH, + {"level": 8, "dtype": np.int32}, + {"location": (0, 0), "size": (2, 1)}, + np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.int32), +] + +TEST_CASE_7 = [ + FILE_PATH, + {"level": 8, "dtype": np.float32}, + {"location": (0, 0), "size": (2, 1)}, + np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.float32), +] + +TEST_CASE_8 = [ + FILE_PATH, + {"level": 8, "dtype": torch.uint8}, + {"location": (0, 0), "size": (2, 1)}, + torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.uint8), +] + +TEST_CASE_9 = [ + FILE_PATH, + {"level": 8, "dtype": torch.float32}, + {"location": (0, 0), "size": (2, 1)}, + torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.float32), +] + +# device tests +TEST_CASE_DEVICE_1 = [ + FILE_PATH, + {"level": 8, "dtype": torch.float32, "device": "cpu"}, + {"location": (0, 0), "size": (2, 1)}, + torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.float32), + "cpu", +] + +TEST_CASE_DEVICE_2 = [ + FILE_PATH, + {"level": 8, "dtype": torch.float32, "device": "cuda"}, + {"location": (0, 0), "size": (2, 1)}, + torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.float32), + "cuda", +] + +TEST_CASE_DEVICE_3 = [ + FILE_PATH, + {"level": 8, "dtype": np.float32, "device": "cpu"}, {"location": (0, 0), "size": (2, 1)}, - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), + np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.float32), + "cpu", +] + +TEST_CASE_DEVICE_4 = [ + FILE_PATH, + {"level": 8, "dtype": np.float32, "device": "cuda"}, + {"location": (0, 0), "size": (2, 1)}, + torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.float32), + "cuda", +] + +TEST_CASE_DEVICE_5 = [ + FILE_PATH, + {"level": 8, "device": "cuda"}, + {"location": (0, 0), "size": (2, 1)}, + torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.uint8), + "cuda", +] + +TEST_CASE_DEVICE_6 = [ + FILE_PATH, + {"level": 8}, + {"location": (0, 0), "size": (2, 1)}, + np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.uint8), + "cpu", +] + +TEST_CASE_DEVICE_7 = [ + FILE_PATH, + {"level": 8, "device": None}, + {"location": (0, 0), "size": (2, 1)}, + np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.uint8), + "cpu", ] TEST_CASE_MULTI_WSI = [ [FILE_PATH, FILE_PATH], - {"location": (0, 0), "size": (2, 1), "level": 2}, + {"location": (0, 0), "size": (2, 1), "level": 8}, np.concatenate( [ - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), + np.array([[[242], [242]], [[242], [242]], [[242], [242]]]), + np.array([[[242], [242]], [[242], [242]], [[242], [242]]]), ], axis=0, ), @@ -184,7 +286,7 @@ class WSIReaderDeprecatedTests: class Tests(unittest.TestCase): backend = None - @parameterized.expand([TEST_CASE_0]) + @parameterized.expand([TEST_CASE_WHOLE_0]) def test_read_whole_image(self, file_path, level, expected_shape): reader = WSIReaderDeprecated(self.backend, level=level) with reader.read(file_path) as img_obj: @@ -255,17 +357,19 @@ def test_read_malformats(self, img_expected): @parameterized.expand([TEST_CASE_TRANSFORM_0]) def test_with_dataloader( - self, file_path: PathLike, level: int, expected_spatial_shape: Any, expected_shape: Tuple[int, ...] + self, file_path: PathLike, level: int, expected_spatial_shape: Any, expected_shape: tuple[int, ...] ): train_transform = Compose( [ - LoadImaged(keys=["image"], reader=WSIReaderDeprecated, backend=self.backend, level=level), + LoadImaged( + keys=["image"], reader=WSIReaderDeprecated, backend=self.backend, level=level, image_only=False + ), ToTensord(keys=["image"]), ] ) dataset = Dataset([{"image": file_path}], transform=train_transform) data_loader = DataLoader(dataset) - data: dict = first(data_loader) + data: dict = first(data_loader, {}) for s in data[PostFix.meta("image")]["spatial_shape"]: assert_allclose(s, expected_spatial_shape, type_test=False) self.assertTupleEqual(data["image"].shape, expected_shape) @@ -275,7 +379,7 @@ class WSIReaderTests: class Tests(unittest.TestCase): backend = None - @parameterized.expand([TEST_CASE_0]) + @parameterized.expand([TEST_CASE_WHOLE_0]) def test_read_whole_image(self, file_path, level, expected_shape): reader = WSIReader(self.backend, level=level) with reader.read(file_path) as img_obj: @@ -284,10 +388,23 @@ def test_read_whole_image(self, file_path, level, expected_shape): self.assertEqual(meta["backend"], self.backend) self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) self.assertEqual(meta[WSIPatchKeys.LEVEL], level) - assert_array_equal(meta[WSIPatchKeys.SIZE], expected_shape[1:]) - assert_array_equal(meta[WSIPatchKeys.LOCATION], (0, 0)) - - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + assert_allclose(meta[WSIPatchKeys.SIZE], expected_shape[1:], type_test=False) + assert_allclose(meta[WSIPatchKeys.LOCATION], (0, 0), type_test=False) + + @parameterized.expand( + [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + ] + ) def test_read_region(self, file_path, kwargs, patch_info, expected_img): reader = WSIReader(self.backend, **kwargs) level = patch_info.get("level", kwargs.get("level")) @@ -298,14 +415,15 @@ def test_read_region(self, file_path, kwargs, patch_info, expected_img): img, meta = reader.get_data(img_obj, **patch_info) img2 = reader.get_data(img_obj, **patch_info)[0] self.assertTupleEqual(img.shape, img2.shape) - self.assertIsNone(assert_array_equal(img, img2)) + assert_allclose(img, img2) self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) + assert_allclose(img, expected_img) + self.assertEqual(img.dtype, expected_img.dtype) self.assertEqual(meta["backend"], self.backend) self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) self.assertEqual(meta[WSIPatchKeys.LEVEL], level) - assert_array_equal(meta[WSIPatchKeys.SIZE], patch_info["size"]) - assert_array_equal(meta[WSIPatchKeys.LOCATION], patch_info["location"]) + assert_allclose(meta[WSIPatchKeys.SIZE], patch_info["size"], type_test=False) + assert_allclose(meta[WSIPatchKeys.LOCATION], patch_info["location"], type_test=False) @parameterized.expand([TEST_CASE_MULTI_WSI]) def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img): @@ -318,14 +436,14 @@ def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img): for img_obj in img_obj_list: img_obj.close() self.assertTupleEqual(img.shape, img2.shape) - self.assertIsNone(assert_array_equal(img, img2)) + assert_allclose(img, img2) self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) + assert_allclose(img, expected_img) self.assertEqual(meta["backend"], self.backend) self.assertEqual(meta[WSIPatchKeys.PATH][0].lower(), str(os.path.abspath(file_path_list[0])).lower()) self.assertEqual(meta[WSIPatchKeys.LEVEL][0], patch_info["level"]) - assert_array_equal(meta[WSIPatchKeys.SIZE][0], expected_img.shape[1:]) - assert_array_equal(meta[WSIPatchKeys.LOCATION][0], patch_info["location"]) + assert_allclose(meta[WSIPatchKeys.SIZE][0], expected_img.shape[1:], type_test=False) + assert_allclose(meta[WSIPatchKeys.LOCATION][0], patch_info["location"], type_test=False) @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) @skipUnless(has_tiff, "Requires tifffile.") @@ -344,8 +462,8 @@ def test_read_rgba(self, img_expected): with reader.read(file_path) as img_obj: image[mode], _ = reader.get_data(img_obj) - self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) - self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) + assert_allclose(image["RGB"], img_expected) + assert_allclose(image["RGBA"], img_expected) @parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D]) @skipUnless(has_tiff, "Requires tifffile.") @@ -362,40 +480,40 @@ def test_read_malformats(self, img_expected): @parameterized.expand([TEST_CASE_TRANSFORM_0]) def test_with_dataloader( - self, file_path: PathLike, level: int, expected_spatial_shape: Any, expected_shape: Tuple[int, ...] + self, file_path: PathLike, level: int, expected_spatial_shape: Any, expected_shape: tuple[int, ...] ): train_transform = Compose( [ - LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level), + LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level, image_only=False), ToTensord(keys=["image"]), ] ) dataset = Dataset([{"image": file_path}], transform=train_transform) data_loader = DataLoader(dataset) - data: dict = first(data_loader) + data: dict = first(data_loader, {}) for s in data[PostFix.meta("image")]["spatial_shape"]: assert_allclose(s, expected_spatial_shape, type_test=False) self.assertTupleEqual(data["image"].shape, expected_shape) @parameterized.expand([TEST_CASE_TRANSFORM_0]) def test_with_dataloader_batch( - self, file_path: PathLike, level: int, expected_spatial_shape: Any, expected_shape: Tuple[int, ...] + self, file_path: PathLike, level: int, expected_spatial_shape: Any, expected_shape: tuple[int, ...] ): train_transform = Compose( [ - LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level), + LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level, image_only=False), ToTensord(keys=["image"]), ] ) dataset = Dataset([{"image": file_path}, {"image": file_path}], transform=train_transform) batch_size = 2 data_loader = DataLoader(dataset, batch_size=batch_size) - data: dict = first(data_loader) + data: dict = first(data_loader, {}) for s in data[PostFix.meta("image")]["spatial_shape"]: assert_allclose(s, expected_spatial_shape, type_test=False) self.assertTupleEqual(data["image"].shape, (batch_size, *expected_shape[1:])) - @parameterized.expand([TEST_CASE_0]) + @parameterized.expand([TEST_CASE_WHOLE_0]) def test_read_whole_image_multi_thread(self, file_path, level, expected_shape): if self.backend == "cucim": reader = WSIReader(self.backend, level=level, num_workers=4) @@ -405,8 +523,8 @@ def test_read_whole_image_multi_thread(self, file_path, level, expected_shape): self.assertEqual(meta["backend"], self.backend) self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) self.assertEqual(meta[WSIPatchKeys.LEVEL], level) - assert_array_equal(meta[WSIPatchKeys.SIZE], expected_shape[1:]) - assert_array_equal(meta[WSIPatchKeys.LOCATION], (0, 0)) + assert_allclose(meta[WSIPatchKeys.SIZE], expected_shape[1:], type_test=False) + assert_allclose(meta[WSIPatchKeys.LOCATION], (0, 0), type_test=False) @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_read_region_multi_thread(self, file_path, kwargs, patch_info, expected_img): @@ -417,14 +535,14 @@ def test_read_region_multi_thread(self, file_path, kwargs, patch_info, expected_ img, meta = reader.get_data(img_obj, **patch_info) img2 = reader.get_data(img_obj, **patch_info)[0] self.assertTupleEqual(img.shape, img2.shape) - self.assertIsNone(assert_array_equal(img, img2)) + assert_allclose(img, img2) self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) + assert_allclose(img, expected_img) self.assertEqual(meta["backend"], self.backend) self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) self.assertEqual(meta[WSIPatchKeys.LEVEL], patch_info["level"]) - assert_array_equal(meta[WSIPatchKeys.SIZE], patch_info["size"]) - assert_array_equal(meta[WSIPatchKeys.LOCATION], patch_info["location"]) + assert_allclose(meta[WSIPatchKeys.SIZE], patch_info["size"], type_test=False) + assert_allclose(meta[WSIPatchKeys.LOCATION], patch_info["location"], type_test=False) @parameterized.expand([TEST_CASE_MPP_0]) def test_resolution_mpp(self, file_path, level, expected_mpp): @@ -433,6 +551,42 @@ def test_resolution_mpp(self, file_path, level, expected_mpp): mpp = reader.get_mpp(img_obj, level) self.assertTupleEqual(mpp, expected_mpp) + @parameterized.expand( + [ + TEST_CASE_DEVICE_1, + TEST_CASE_DEVICE_2, + TEST_CASE_DEVICE_3, + TEST_CASE_DEVICE_4, + TEST_CASE_DEVICE_5, + TEST_CASE_DEVICE_6, + TEST_CASE_DEVICE_7, + ] + ) + @skip_if_no_cuda + def test_read_region_device(self, file_path, kwargs, patch_info, expected_img, device): + reader = WSIReader(self.backend, **kwargs) + level = patch_info.get("level", kwargs.get("level")) + if self.backend == "tifffile" and level < 2: + return + with reader.read(file_path) as img_obj: + # Read twice to check multiple calls + img, meta = reader.get_data(img_obj, **patch_info) + img2 = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, img2.shape) + assert_allclose(img, img2) + self.assertTupleEqual(img.shape, expected_img.shape) + assert_allclose(img, expected_img) + self.assertEqual(img.dtype, expected_img.dtype) + if isinstance(img, torch.Tensor): + self.assertEqual(img.device.type, device) + else: + self.assertEqual("cpu", device) + self.assertEqual(meta["backend"], self.backend) + self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) + self.assertEqual(meta[WSIPatchKeys.LEVEL], level) + assert_allclose(meta[WSIPatchKeys.SIZE], patch_info["size"], type_test=False) + assert_allclose(meta[WSIPatchKeys.LOCATION], patch_info["location"], type_test=False) + @skipUnless(has_cucim, "Requires cucim") class TestCuCIMDeprecated(WSIReaderDeprecatedTests.Tests): diff --git a/tests/test_zipdataset.py b/tests/test_zipdataset.py index f381e0a453..de8a8e80d6 100644 --- a/tests/test_zipdataset.py +++ b/tests/test_zipdataset.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_zoom.py b/tests/test_zoom.py index 1d0447e319..fe3cd02766 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -9,24 +9,52 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np +import torch from parameterized import parameterized from scipy.ndimage import zoom as zoom_scipy from monai.data import MetaTensor, set_track_meta from monai.transforms import Zoom -from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion +from monai.transforms.lazy.functional import apply_transforms +from tests.utils import ( + DEFAULT_TEST_AFFINE, + TEST_NDARRAYS_ALL, + NumpyImageTestCase2D, + assert_allclose, + test_local_inversion, +) -VALID_CASES = [(1.5, "nearest"), (1.5, "nearest"), (0.8, "bilinear"), (0.8, "area")] +VALID_CASES = [(1.5, "nearest", True), (1.5, "nearest", False), (0.8, "bilinear"), (0.8, "area")] INVALID_CASES = [((None, None), "bilinear", TypeError), ((0.9, 0.9), "s", ValueError)] class TestZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) - def test_correct_results(self, zoom, mode): + def test_pending_ops(self, zoom, mode, align_corners=False): + im = MetaTensor(self.imt[0], meta={"a": "b", "affine": DEFAULT_TEST_AFFINE}) + zoom_fn = Zoom(zoom=zoom, mode="bilinear", keep_size=False, dtype=torch.float64, align_corners=align_corners) + # non-lazy + expected = zoom_fn(im) + self.assertIsInstance(expected, MetaTensor) + # lazy + zoom_fn.lazy_evaluation = True + pending_result = zoom_fn(im) + self.assertIsInstance(pending_result, MetaTensor) + assert_allclose(pending_result.peek_pending_affine(), expected.affine) + assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) + result = apply_transforms(pending_result, mode="bilinear", dtype=np.float64, align_corners=align_corners)[0] + # compare + match_ratio = np.sum(np.isclose(result, expected)) / np.prod(result.shape) + self.assertGreater(match_ratio, 0.95) + + @parameterized.expand(VALID_CASES) + def test_correct_results(self, zoom, mode, *_): for p in TEST_NDARRAYS_ALL: zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) im = p(self.imt[0]) diff --git a/tests/test_zoom_affine.py b/tests/test_zoom_affine.py index 3c4bcd302c..dc39a4f1c2 100644 --- a/tests/test_zoom_affine.py +++ b/tests/test_zoom_affine.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import nibabel as nib diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py index b6ff86e474..5c755c1c4d 100644 --- a/tests/test_zoomd.py +++ b/tests/test_zoomd.py @@ -9,28 +9,53 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np +import torch from parameterized import parameterized from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoomd +from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion -VALID_CASES = [(1.5, "nearest", False), (0.3, "bilinear", False), (0.8, "bilinear", False)] +VALID_CASES = [ + (1.5, "nearest", False), + (0.3, "bilinear", False, True), + (0.8, "bilinear", False, False), + (1.3, "bilinear", False), +] INVALID_CASES = [("no_zoom", None, "bilinear", TypeError), ("invalid_order", 0.9, "s", ValueError)] class TestZoomd(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) - def test_correct_results(self, zoom, mode, keep_size): + def test_correct_results(self, zoom, mode, keep_size, align_corners=None): key = "img" - zoom_fn = Zoomd(key, zoom=zoom, mode=mode, keep_size=keep_size) + init_param = { + "keys": key, + "zoom": zoom, + "mode": mode, + "keep_size": keep_size, + "dtype": torch.float64, + "align_corners": align_corners, + } + zoom_fn = Zoomd(**init_param) for p in TEST_NDARRAYS_ALL: im = p(self.imt[0]) - zoomed = zoom_fn({key: im}) + call_param = {"data": {key: im}} + zoomed = zoom_fn(**call_param) + + # test lazy + # TODO: temporarily skip "nearest" test + if mode == "bilinear": + test_resampler_lazy(zoom_fn, zoomed, init_param, call_param, output_key=key) + zoom_fn.lazy_evaluation = False + test_local_inversion(zoom_fn, zoomed, {key: im}, key) _order = 0 if mode.endswith("linear"): diff --git a/tests/testing_data/config_fl_train.json b/tests/testing_data/config_fl_train.json index f53a95bc02..bdb9792fce 100644 --- a/tests/testing_data/config_fl_train.json +++ b/tests/testing_data/config_fl_train.json @@ -1,6 +1,7 @@ { "bundle_root": "tests/testing_data", "dataset_dir": "@bundle_root", + "val_interval": 1, "imports": [ "$import os" ], @@ -66,13 +67,6 @@ "min_zoom": 0.9, "max_zoom": 1.1, "prob": 0.5 - }, - { - "_target_": "ToTensord", - "keys": [ - "image", - "label" - ] } ], "preprocessing": { @@ -104,6 +98,12 @@ "_target_": "SimpleInferer" }, "handlers": [ + { + "_target_": "ValidationHandler", + "validator": "@validate#evaluator", + "epoch_level": true, + "interval": "@val_interval" + }, { "_target_": "StatsHandler", "tag_name": "train_loss", @@ -121,5 +121,82 @@ "inferer": "@train#inferer", "train_handlers": "@train#handlers" } - } + }, + "validate": { + "preprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "LoadImaged", + "keys": [ + "image" + ], + "image_only": true + }, + { + "_target_": "EnsureChannelFirstD", + "keys": [ + "image" + ] + }, + { + "_target_": "ScaleIntensityd", + "keys": [ + "image" + ] + } + ] + }, + "dataset": { + "_target_": "Dataset", + "data": [ + { + "image": "$os.path.join(@dataset_dir, 'image0.jpeg')", + "label": 0 + }, + { + "image": "$os.path.join(@dataset_dir, 'image1.jpeg')", + "label": 1 + } + ], + "transform": "@validate#preprocessing" + }, + "dataloader": { + "_target_": "DataLoader", + "dataset": "@validate#dataset", + "batch_size": 1, + "shuffle": true, + "num_workers": 2 + }, + "inferer": { + "_target_": "SimpleInferer" + }, + "postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "Activationsd", + "keys": "pred", + "softmax": true + } + ] + }, + "evaluator": { + "_target_": "SupervisedEvaluator", + "device": "@device", + "val_data_loader": "@validate#dataloader", + "network": "@network", + "inferer": "@validate#inferer", + "postprocessing": "@validate#postprocessing" + } + }, + "initialize": [ + "$monai.utils.set_determinism(seed=123)" + ], + "run": [ + "$@train#trainer.run()" + ], + "finalize": [ + "$monai.utils.set_determinism(seed=None)" + ] } diff --git a/tests/testing_data/cpp_resample_answers.py b/tests/testing_data/cpp_resample_answers.py index 93f596619e..266c45a508 100644 --- a/tests/testing_data/cpp_resample_answers.py +++ b/tests/testing_data/cpp_resample_answers.py @@ -9,14 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import csv import os import warnings -from typing import List, Optional -def _read_testing_data_answers(fname: Optional[str] = None, delimiter=",") -> List: - answers: List = [] +def _read_testing_data_answers(fname: str | None = None, delimiter=",") -> list: + answers: list = [] if not fname: return answers # read answers from directory of the current file @@ -37,5 +38,5 @@ def _read_testing_data_answers(fname: Optional[str] = None, delimiter=",") -> Li return answers -Expected_1D_GP_fwd: List = _read_testing_data_answers(fname="1D_BP_fwd.txt") -Expected_1D_GP_bwd: List = _read_testing_data_answers(fname="1D_BP_bwd.txt") +Expected_1D_GP_fwd: list = _read_testing_data_answers(fname="1D_BP_fwd.txt") +Expected_1D_GP_bwd: list = _read_testing_data_answers(fname="1D_BP_bwd.txt") diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index 788d664439..c2d2ba9635 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -54,6 +54,26 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MNI152_T1_2mm_strucseg.nii.gz", "hash_type": "sha256", "hash_val": "eb4f1e596ca85aadaefc359d409fb9a3e27d733e6def04b996953b7c54bc26d4" + }, + "copd1_highres_INSP_STD_COPD_img": { + "url": "https://data.kitware.com/api/v1/file/62a0f067bddec9d0c4175c5a/download", + "hash_type": "sha512", + "hash_val": "60193cd6ef0cf055c623046446b74f969a2be838444801bd32ad5bedc8a7eeecb343e8a1208769c9c7a711e101c806a3133eccdda7790c551a69a64b9b3701e9" + }, + "copd1_highres_EXP_STD_COPD_img": { + "url": "https://data.kitware.com/api/v1/item/62a0f045bddec9d0c4175c44/download", + "hash_type": "sha512", + "hash_val": "841ef303958541474e66c2d1ccdc8b7ed17ba2f2681101307766b979a07979f2ec818ddf13791c3f1ac5a8ec3258d6ea45b692b4b4a838de9188602618972b6d" + }, + "CT_2D_head_fixed": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_2D_head_fixed.mha", + "hash_type": "sha256", + "hash_val": "06f2ce6fbf6a59f0874c735555fcf71717f631156b1b0697c1752442f7fc1cc5" + }, + "CT_2D_head_moving": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_2D_head_moving.mha", + "hash_type": "sha256", + "hash_val": "a37c5fe388c38b3f4ac564f456277d09d3982eda58c4da05ead8ee2332360f47" } }, "videos": { diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index c222667101..7b1f9c20cf 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -1,9 +1,10 @@ { "dataset_dir": "/workspace/data/Task09_Spleen", + "bundle_root": "will override", "output_dir": "need override", "prediction_shape": "prediction shape:", "import_glob": "$import glob", - "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", + "device": "$torch.device('cpu')", "print_test_name": "$print('json_test')", "print_glob_file": "$print(glob.__file__)", "network_def": { @@ -112,8 +113,10 @@ "postprocessing": "@postprocessing", "amp": false }, - "evaluating": [ - "$monai.utils.set_determinism(0)", + "initialize": [ + "$monai.utils.set_determinism(0)" + ], + "run": [ "$@evaluator.run()" ] } diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml index a289b549db..0343ea0bae 100644 --- a/tests/testing_data/inference.yaml +++ b/tests/testing_data/inference.yaml @@ -1,8 +1,9 @@ --- dataset_dir: "/workspace/data/Task09_Spleen" +bundle_root: "will override" output_dir: "need override" prediction_shape: "prediction shape:" -device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" +device: "$torch.device('cpu')" print_test_name: "$print('yaml_test')" network_def: _target_: UNet @@ -80,6 +81,9 @@ evaluator: inferer: "@inferer" postprocessing: "@postprocessing" amp: false -evaluating: +initialize: - "$monai.utils.set_determinism(0)" +run: - "$@evaluator.run()" +finalize: + - "$print('test finalize section.')" diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py index 1907e3c4a5..b1f115e1d3 100644 --- a/tests/testing_data/integration_answers.py +++ b/tests/testing_data/integration_answers.py @@ -9,10 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import numpy as np EXPECTED_ANSWERS = [ - { # test answers for PyTorch 1.6 + { # test answers for PyTorch 1.12.1 "integration_classification_2d": { "losses": [0.776835828070428, 0.1615355300011149, 0.07492854832938523, 0.04591309238865877], "best_metric": 0.9999184380485994, @@ -20,56 +22,56 @@ }, "integration_segmentation_3d": { "losses": [ - 0.5367561340332031, - 0.478084459900856, - 0.4581540793180466, - 0.44623913466930387, - 0.42341493666172025, - 0.42569945752620697, + 0.5428894340991974, + 0.47331981360912323, + 0.4482289582490921, + 0.4452722787857056, + 0.4289989799261093, + 0.4359133839607239, ], - "best_metric": 0.9295084029436111, - "infer_metric": 0.9296411260962486, + "best_metric": 0.933259129524231, + "infer_metric": 0.9332860708236694, "output_sums": [ - 0.14302121377204619, - 0.15321686701244813, - 0.15267064069005093, - 0.1408481434833016, - 0.18862719991649474, - 0.16992848513054068, - 0.1479306037291329, - 0.1691071594535633, - 0.15804366588267224, - 0.18019304183940157, - 0.1635089455927468, - 0.16851606024285842, - 0.1454348651039073, - 0.11584957890961554, - 0.16255468027312903, - 0.20118089432240313, - 0.176187783307603, - 0.1004243279488101, - 0.19385348502657657, - 0.2030768555124136, - 0.196251372926592, - 0.20823046240222043, - 0.1631389353339986, - 0.13299661219478043, - 0.14917081129077908, - 0.14383374638201593, - 0.23050183928776746, - 0.1614747942341212, - 0.14913436515470202, - 0.10443081170610946, - 0.11978674347415241, - 0.13126176432899028, - 0.11570832453348577, - 0.15306806147195887, - 0.163673089782912, - 0.19394971756732426, - 0.22197501007172804, - 0.1812147930033603, - 0.19051659118682873, - 0.0774867922747158, + 0.142167581604417, + 0.15195543400875847, + 0.1512754523215521, + 0.13962938779108452, + 0.18835719348918614, + 0.16943498693483486, + 0.1465709827477569, + 0.16806483607477135, + 0.1568844609697224, + 0.17911090857818554, + 0.16252098157181355, + 0.16806016936625395, + 0.14430124467305516, + 0.11316135548315168, + 0.16183771025615476, + 0.2009426314066978, + 0.1760258010156966, + 0.09700864497950844, + 0.1938495370314683, + 0.20319147575335647, + 0.19629641404249798, + 0.20852344793102826, + 0.16185073630020633, + 0.13184196857669161, + 0.1480959525354053, + 0.14232924377085415, + 0.23177739882790951, + 0.16094610375534632, + 0.14832771888168225, + 0.10259365443625812, + 0.11850632233099603, + 0.1294100326098242, + 0.11364228279017609, + 0.15181947897584674, + 0.16319358155815072, + 0.1940284526521386, + 0.22306137879066443, + 0.18083137638759522, + 0.1903135237574692, + 0.07402317520619131, ], }, "integration_workflows": { @@ -163,7 +165,7 @@ ], }, }, - { # test answers for PyTorch 1.7 + { # test answers for PyTorch 1.8 "integration_classification_2d": { "losses": [0.777176220515731, 0.16019743723664315, 0.07480076164197011, 0.045643698364780966], "best_metric": 0.9999418774120775, @@ -171,56 +173,56 @@ }, "integration_segmentation_3d": { "losses": [ - 0.5427072256803512, - 0.46434969305992124, - 0.45358552038669586, - 0.4363856494426727, - 0.42080804109573366, - 0.42058534920215607, + 0.5326887160539627, + 0.4685510128736496, + 0.46245276033878324, + 0.4411882758140564, + 0.4198471873998642, + 0.43021280467510226, ], - "best_metric": 0.9292903542518616, - "infer_metric": 0.9306288316845894, + "best_metric": 0.931993305683136, + "infer_metric": 0.9326668977737427, "output_sums": [ - 0.14192493409895743, - 0.15182314591386872, - 0.15143080738742032, - 0.13972497034181824, - 0.18790884439406313, - 0.16933812661492562, - 0.14664343345928132, - 0.1678599094806423, - 0.1568852615222309, - 0.17882538307200632, - 0.16226220644853354, - 0.16756325103417588, - 0.1449974856885373, - 0.1160602083671129, - 0.1614830941632057, - 0.20060717335382267, - 0.17543495742507476, - 0.10308107883493946, - 0.19289222718691168, - 0.20225689438356148, - 0.19587806881756237, - 0.20773073456322155, - 0.16193015294299506, - 0.13181961683097554, - 0.14850995284454005, - 0.14238637655756, - 0.2307113922277095, - 0.1608335768948913, - 0.1480752874532259, - 0.1038477413165911, - 0.11880665574424197, - 0.13084873656303445, - 0.1141965805147642, - 0.1531586543003841, - 0.16275008603701097, - 0.19320476187766733, - 0.2217811250932611, - 0.18027048819200148, - 0.18958803602663193, - 0.08653716931250294, + 0.1418775228871769, + 0.15188869120317386, + 0.15140863737688195, + 0.1396146850007127, + 0.18784343811575696, + 0.16909487431163164, + 0.14649608249452073, + 0.1677767130878611, + 0.1568122289811143, + 0.17874181729735056, + 0.16213703658980205, + 0.16754335171970686, + 0.14444824920997243, + 0.11432402622850306, + 0.16143210936221247, + 0.20055289634107482, + 0.17543571757219317, + 0.09920729163334538, + 0.19297325815057875, + 0.2023200127892273, + 0.1956677579845722, + 0.20774045016425718, + 0.16193278944159428, + 0.13174198906539808, + 0.14830508550670007, + 0.14241105864278342, + 0.23090631643085724, + 0.16056153813499532, + 0.1480353269419819, + 0.10318719171632634, + 0.11867462580989198, + 0.12997011485830187, + 0.11401220332210203, + 0.15242746700662088, + 0.1628489107974574, + 0.19327235354175412, + 0.22184902863377548, + 0.18028049625972334, + 0.18958059106892552, + 0.07884601267057013, ], }, "integration_workflows": { @@ -314,67 +316,6 @@ ], }, }, - { # test answers for PyTorch 21.04, cuda 11.3 - "integration_classification_2d": { - "losses": [0.7772567988770782, 0.16357883198815545, 0.0748426011840629, 0.045560025710873545], - "best_metric": 0.9999362036681547, - "infer_prop": [1030, 898, 981, 1033, 960, 1046], - }, - "integration_segmentation_3d": { - "losses": [ - 0.5462346076965332, - 0.4699550330638885, - 0.4407052755355835, - 0.4473582059144974, - 0.4345871120691299, - 0.4268435090780258, - ], - "best_metric": 0.9325245052576066, - "infer_metric": 0.9326683700084686, - "output_sums": [ - 0.14224469870198278, - 0.15221021012369151, - 0.15124158255724182, - 0.13988812880932433, - 0.18869885039284465, - 0.16944664085835437, - 0.14679946398855015, - 0.1681337815374021, - 0.1572538225010156, - 0.179386563044054, - 0.162734465243387, - 0.16831902111202945, - 0.1447043535420074, - 0.11343210557896033, - 0.16199135405262954, - 0.20095180481987404, - 0.17613484080473857, - 0.09717457016552708, - 0.1940439758638305, - 0.2033698355271389, - 0.19628583555443793, - 0.20852096425983455, - 0.16202004771083997, - 0.13206408917949392, - 0.14840973098125526, - 0.14237425379050472, - 0.23165483128059614, - 0.16098621485325398, - 0.14831028015056963, - 0.10317099380415945, - 0.118716576251689, - 0.13002315213569166, - 0.11436407827087304, - 0.1522274707636008, - 0.16314910792851098, - 0.1941135852761834, - 0.22309890968242424, - 0.18111804948625987, - 0.19043976068601465, - 0.07442812452084423, - ], - }, - }, { # test answers for PyTorch 1.9 "integration_workflows": { "output_sums_2": [ @@ -491,158 +432,11 @@ "infer_metric": 0.9316383600234985, }, }, - { # test answers for PyTorch 21.10 - "integration_classification_2d": { - "losses": [0.7806222991199251, 0.16259610306495315, 0.07529311385124353, 0.04640352608529246], - "best_metric": 0.9999369155431564, - "infer_prop": [1030, 898, 981, 1033, 960, 1046], - }, - "integration_segmentation_3d": { - "losses": [ - 0.5462362408638001, - 0.4913381844758987, - 0.4526856362819672, - 0.43404580652713776, - 0.42532919645309447, - 0.4160102754831314, - ], - "best_metric": 0.9357608556747437, - "infer_metric": 0.9359462857246399, - "output_sums": [ - 0.14133183650702907, - 0.15129517085134564, - 0.15039408698301698, - 0.1388800895551786, - 0.18765019147239637, - 0.16847158867677473, - 0.14567945622102715, - 0.16728557092807228, - 0.15601444057659314, - 0.17816339678760573, - 0.1616256801482474, - 0.16733042976922818, - 0.14342795433701588, - 0.1122946416901734, - 0.16105778942392063, - 0.20017543167070598, - 0.17512204704647916, - 0.09592956823274325, - 0.19316383411238341, - 0.2022308530579937, - 0.19527218778022315, - 0.2075871950564991, - 0.16083565516485876, - 0.13111518931029637, - 0.1473909261474288, - 0.14161210629657228, - 0.23102446985179093, - 0.15980667305916593, - 0.14760356792082058, - 0.1018092235719272, - 0.11792260857122504, - 0.1285278390386459, - 0.11275165891441473, - 0.15101653432548032, - 0.16236351926994622, - 0.1932631773335222, - 0.2221395787381994, - 0.18003549292918666, - 0.18940543270178078, - 0.07430261166443994, - ], - }, - "integration_workflows": { - "output_sums": [ - 0.14211511611938477, - 0.1516571044921875, - 0.1381092071533203, - 0.13403034210205078, - 0.18480682373046875, - 0.16382598876953125, - 0.14140796661376953, - 0.1665945053100586, - 0.15700864791870117, - 0.17697620391845703, - 0.16163396835327148, - 0.16488313674926758, - 0.1442713737487793, - 0.11060476303100586, - 0.16111087799072266, - 0.19617986679077148, - 0.1744403839111328, - 0.052786827087402344, - 0.19046974182128906, - 0.19913578033447266, - 0.19527721405029297, - 0.2032318115234375, - 0.16050148010253906, - 0.13228464126586914, - 0.1512293815612793, - 0.1372208595275879, - 0.22692251205444336, - 0.16164922714233398, - 0.14729642868041992, - 0.10398292541503906, - 0.1195836067199707, - 0.13096046447753906, - 0.11221647262573242, - 0.1521167755126953, - 0.1599421501159668, - 0.1898345947265625, - 0.21675777435302734, - 0.1777491569519043, - 0.18526840209960938, - 0.035144805908203125, - ], - "output_sums_2": [ - 0.14200592041015625, - 0.15146303176879883, - 0.13796186447143555, - 0.1339101791381836, - 0.18489742279052734, - 0.1637406349182129, - 0.14113903045654297, - 0.16657161712646484, - 0.15676355361938477, - 0.17683839797973633, - 0.1614980697631836, - 0.16493558883666992, - 0.14408016204833984, - 0.11035394668579102, - 0.1610560417175293, - 0.1962742805480957, - 0.17439842224121094, - 0.05285835266113281, - 0.19057941436767578, - 0.19914865493774414, - 0.19533538818359375, - 0.20333576202392578, - 0.16032838821411133, - 0.13197898864746094, - 0.1510462760925293, - 0.13703680038452148, - 0.2270984649658203, - 0.16144943237304688, - 0.1472611427307129, - 0.10393238067626953, - 0.11940813064575195, - 0.1307811737060547, - 0.11203241348266602, - 0.15186500549316406, - 0.15992307662963867, - 0.18991422653198242, - 0.21689796447753906, - 0.1777033805847168, - 0.18547868728637695, - 0.035192012786865234, - ], - }, - }, ] def test_integration_value(test_name, key, data, rtol=1e-2): - for (idx, expected) in enumerate(EXPECTED_ANSWERS): + for idx, expected in enumerate(EXPECTED_ANSWERS): if test_name not in expected: continue if key not in expected[test_name]: diff --git a/tests/utils.py b/tests/utils.py index bb5d35fc98..ef46678c06 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import copy import datetime import functools @@ -28,7 +30,7 @@ from contextlib import contextmanager from functools import partial, reduce from subprocess import PIPE, Popen -from typing import Callable, Optional, Tuple, Union +from typing import Callable from urllib.error import ContentTooShortError, HTTPError import numpy as np @@ -47,7 +49,7 @@ from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") -http_error, has_requests = optional_import("requests", name="HTTPError") +http_error, has_req = optional_import("requests", name="HTTPError") quick_test_var = "QUICKTEST" _tf32_enabled = None @@ -64,6 +66,16 @@ def testing_data_config(*keys): return reduce(operator.getitem, keys, _test_data_config) +def get_testing_algo_template_path(): + """ + a local folder to the testing algorithm template or a url to the compressed template file. + Default to None, which effectively uses bundle_gen's ``default_algo_zip`` path. + + https://github.com/Project-MONAI/MONAI/blob/1.1.0/monai/apps/auto3dseg/bundle_gen.py#L380-L381 + """ + return os.environ.get("MONAI_TESTING_ALGO_TEMPLATE", None) + + def clone(data: NdarrayTensor) -> NdarrayTensor: """ Clone data independent of type. @@ -80,7 +92,7 @@ def clone(data: NdarrayTensor) -> NdarrayTensor: def assert_allclose( actual: NdarrayOrTensor, desired: NdarrayOrTensor, - type_test: Union[bool, str] = True, + type_test: bool | str = True, device_test: bool = False, *args, **kwargs, @@ -124,7 +136,7 @@ def assert_allclose( def skip_if_downloading_fails(): try: yield - except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_requests else () as e: + except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_req else () as e: # noqa: B030 raise unittest.SkipTest(f"error while downloading: {e}") from e except ssl.SSLError as ssl_e: if "decryption failed" in str(ssl_e): @@ -307,7 +319,9 @@ def has_cupy(): HAS_CUPY = has_cupy() -def make_nifti_image(array: NdarrayOrTensor, affine=None, dir=None, fname=None, suffix=".nii.gz", verbose=False): +def make_nifti_image( + array: NdarrayOrTensor, affine=None, dir=None, fname=None, suffix=".nii.gz", verbose=False, dtype=float +): """ Create a temporary nifti image on the disk and return the image name. User is responsible for deleting the temporary file when done with it. @@ -318,7 +332,7 @@ def make_nifti_image(array: NdarrayOrTensor, affine=None, dir=None, fname=None, affine, *_ = convert_data_type(affine, np.ndarray) if affine is None: affine = np.eye(4) - test_image = nib.Nifti1Image(array, affine) + test_image = nib.Nifti1Image(array.astype(dtype), affine) # type: ignore # if dir not given, create random. Else, make sure it exists. if dir is None: @@ -339,7 +353,7 @@ def make_nifti_image(array: NdarrayOrTensor, affine=None, dir=None, fname=None, return fname -def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState] = None): +def make_rand_affine(ndim: int = 3, random_state: np.random.RandomState | None = None): """Create random affine transformation (with values == -1, 0 or 1).""" rs = np.random.random.__self__ if random_state is None else random_state # type: ignore @@ -401,13 +415,13 @@ def __init__( nnodes: int = 1, nproc_per_node: int = 1, master_addr: str = "localhost", - master_port: Optional[int] = None, - node_rank: Optional[int] = None, + master_port: int | None = None, + node_rank: int | None = None, timeout=60, init_method=None, - backend: Optional[str] = None, - daemon: Optional[bool] = None, - method: Optional[str] = "spawn", + backend: str | None = None, + daemon: bool | None = None, + method: str | None = "spawn", verbose: bool = False, ): """ @@ -537,8 +551,8 @@ class TimedCall: def __init__( self, seconds: float = 60.0, - daemon: Optional[bool] = None, - method: Optional[str] = "spawn", + daemon: bool | None = None, + method: str | None = "spawn", force_quit: bool = True, skip_timing=False, ): @@ -571,7 +585,6 @@ def run_process(func, args, kwargs, results): results.put(e) def __call__(self, obj): - if self.skip_timing: return obj @@ -780,7 +793,7 @@ def command_line_tests(cmd, copy_env=True): raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}") from e -TEST_TORCH_TENSORS: Tuple = (torch.as_tensor,) +TEST_TORCH_TENSORS: tuple = (torch.as_tensor,) if torch.cuda.is_available(): gpu_tensor: Callable = partial(torch.as_tensor, device="cuda") TEST_TORCH_TENSORS = TEST_TORCH_TENSORS + (gpu_tensor,) @@ -789,9 +802,9 @@ def command_line_tests(cmd, copy_env=True): [[2.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 1.0]] ) _metatensor_creator = partial(MetaTensor, meta={"a": "b", "affine": DEFAULT_TEST_AFFINE}) -TEST_NDARRAYS_NO_META_TENSOR: Tuple[Callable] = (np.array,) + TEST_TORCH_TENSORS # type: ignore -TEST_NDARRAYS: Tuple[Callable] = TEST_NDARRAYS_NO_META_TENSOR + (_metatensor_creator,) # type: ignore -TEST_TORCH_AND_META_TENSORS: Tuple[Callable] = TEST_TORCH_TENSORS + (_metatensor_creator,) # type: ignore +TEST_NDARRAYS_NO_META_TENSOR: tuple[Callable] = (np.array,) + TEST_TORCH_TENSORS # type: ignore +TEST_NDARRAYS: tuple[Callable] = TEST_NDARRAYS_NO_META_TENSOR + (_metatensor_creator,) # type: ignore +TEST_TORCH_AND_META_TENSORS: tuple[Callable] = TEST_TORCH_TENSORS + (_metatensor_creator,) # type: ignore # alias for branch tests TEST_NDARRAYS_ALL = TEST_NDARRAYS