
Add support for optional conditioning in PatchInferer, SliceInferer, and SlidingWindowInferer #8400


Open
wants to merge 17 commits into dev

Conversation


@FinnBehrendt FinnBehrendt commented Mar 25, 2025

Fixes #8220

Description

This PR adds support for optional conditioning in MONAI’s inferers, allowing models to receive auxiliary inputs for conditioning that are processed (patched, sliced) in the same way as the inputs. This is particularly relevant for generative models such as conditional GANs or diffusion models (DMs).

Example Usage:

# Given a conditioned model, inputs of shape (1, C, H, W, D) and condition of shape (1, C, H, W, D)
output = SliceInferer(...)(inputs, model, condition=cond_tensor)
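For illustration, a minimal sketch of what such a conditioned model and the corresponding inferer call can look like. Everything below (the ConditionedNet class, its layers, and the tensor shapes) is a hypothetical example, not code from this PR:

import torch
import torch.nn as nn
from monai.inferers import SliceInferer

class ConditionedNet(nn.Module):
    """Toy 2D network whose forward accepts the condition as a keyword argument."""

    def __init__(self, channels: int = 1):
        super().__init__()
        self.conv = nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1)

    def forward(self, x, condition=None):
        # Fuse the input and the condition along the channel dimension
        cond = condition if condition is not None else torch.zeros_like(x)
        return self.conv(torch.cat([x, cond], dim=1))

inputs = torch.rand(1, 1, 64, 64, 32)  # (B, C, H, W, D)
cond = torch.rand(1, 1, 64, 64, 32)    # same shape as inputs
inferer = SliceInferer(roi_size=(64, 64), sw_batch_size=4, spatial_dim=2)
output = inferer(inputs, ConditionedNet(), condition=cond)  # condition is sliced in sync with inputs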

Types of changes

  • Extended PatchInferer, SliceInferer, and SlidingWindowInferer to optionally accept a condition tensor (passed as a kwarg).
  • The condition can now be:
    • None (default)
    • A tensor of the same shape as inputs
  • The inferers now slice/patch the conditions alongside the corresponding inputs and feed them to the network (a short usage sketch follows the checklist below).
  • Updated unit tests for each inferer:
    • Verified with and without conditioning
  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.
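As referenced in the list above, a rough PatchInferer sketch with a toy additive predictor; the function and shapes below are illustrative assumptions, not code from this PR:

import torch
from monai.inferers import PatchInferer
from monai.inferers.merger import AvgMerger
from monai.inferers.splitter import SlidingWindowSplitter

def add_condition(patch, condition=None):
    # Toy predictor: adds the patched condition to the patched input
    return patch if condition is None else patch + condition

inputs = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
cond = torch.ones_like(inputs)

inferer = PatchInferer(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger)
out_plain = inferer(inputs, add_condition)                 # no conditioning
out_cond = inferer(inputs, add_condition, condition=cond)  # condition is patched in sync
assert torch.allclose(out_cond, out_plain + 1.0)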

Additional extensions such as support for dense vector conditioning (e.g., (1, C, Z), with Z being the conditional dimension) could be explored in a follow-up PR if there’s interest.

Summary by CodeRabbit

  • New Features

    • Added support for an optional "condition" tensor in patch-based, sliding window, and slice inference, allowing conditional inference workflows.
    • The "condition" tensor is validated for shape and type consistency with inputs and is processed in sync during inference.
  • Tests

    • Introduced extensive new tests for conditional inference across patch, sliding window, and slice inferers to ensure correct behavior and output validation when using the "condition" argument.


@FinnBehrendt FinnBehrendt marked this pull request as draft March 25, 2025 11:02
FinnBehrendt and others added 5 commits March 25, 2025 12:08

@FinnBehrendt FinnBehrendt marked this pull request as ready for review March 25, 2025 16:19
KumoLiu and others added 3 commits April 13, 2025 16:48

@virginiafdez virginiafdez (Contributor) left a comment

The PR looks good to me overall, with appropriate tests that pass. I am not sure whether it would be better to take the condition from an explicit argument defaulting to None instead of kwargs, to be more transparent. If it stays in kwargs, I believe the docstrings should still mention that a "condition" can be passed. I am also not sure about this, but if the shapes of the input and the condition have to match, it would be good to check for this at the beginning of the inferer call (plus a test).

@FinnBehrendt (Author) commented

Thanks a lot for the suggestions! I just pushed the updated changes.

coderabbitai bot commented Jul 11, 2025

Walkthrough

The code adds support for an optional condition tensor to the PatchInferer, SlidingWindowInferer, and SliceInferer classes, as well as to the sliding_window_inference utility. This enables conditional inference: the condition tensor is processed synchronously alongside the inputs and passed to the network. Comprehensive tests are added for the new conditional mechanism.
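At the utility level, the call can be sketched roughly as follows; the predictor and shapes here are illustrative assumptions, only the condition keyword reflects this PR:

import torch
from monai.inferers import sliding_window_inference

def predictor(window, condition=None):
    # Each condition window has the same spatial size as its input window
    return window if condition is None else window + condition

inputs = torch.rand(1, 1, 96, 96, 96)
cond = torch.rand_like(inputs)

output = sliding_window_inference(
    inputs,
    roi_size=(64, 64, 64),
    sw_batch_size=2,
    predictor=predictor,
    overlap=0.25,
    condition=cond,  # windowed in sync with inputs and forwarded to the predictor
)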

Changes

File(s) and change summary:
  • monai/inferers/inferer.py: added an optional condition argument to the PatchInferer, SlidingWindowInferer, and SliceInferer methods; updated control flow to validate and propagate the condition.
  • monai/inferers/utils.py: updated sliding_window_inference to accept and process an optional condition tensor in sync with the inputs.
  • tests/inferers/test_patch_inferer.py: added a new test class and cases to verify PatchInferer with the condition argument.
  • tests/inferers/test_slice_inferer.py: added a new test class to verify SliceInferer with condition support.
  • tests/inferers/test_sliding_window_inference.py: added a new test class to verify sliding_window_inference and SlidingWindowInferer with condition.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Inferer (Patch/SlidingWindow/Slice)
    participant Network

    User->>Inferer: call(inputs, network, condition=cond)
    Inferer->>Inferer: Validate shapes/types of inputs and condition
    loop For each patch/slice/window
        Inferer->>Network: network(input_patch, condition=cond_patch)
        Network-->>Inferer: output_patch
    end
    Inferer-->>User: aggregated_output

Assessment against linked issues

Objectives from #8220:
  • Add support for conditional-based models in SliceInferer
  • Ensure the condition tensor is processed slice-by-slice in sync with the inputs
  • Validate shape and type matching between condition and inputs for correct inference
  • Add tests verifying SliceInferer and related inferers handle the condition argument correctly

Poem

A rabbit with code in its paws,
Adds "condition"—and earns applause!
Now slices and patches, in sync they go,
Through inferers, the tensors flow.
With tests that hop and outputs that gleam,
This conditional leap fulfills the dream!
🐇✨


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
monai/inferers/inferer.py (1)

817-835: Network wrapper correctly handles condition tensor!

The implementation properly squeezes the condition tensor along the spatial dimension (matching the input processing) and passes it to the network. Shape validation is appropriately handled in the __call__ method before reaching this point, addressing the concern raised in the previous review.
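The idea being described can be sketched as follows; this is a simplified illustration with an assumed signature, not the actual MONAI network_wrapper implementation:

def network_wrapper_sketch(network, x, condition=None, spatial_dim=0, *args, **kwargs):
    # Squeeze the singleton slice dimension from the input and, if given, the condition
    x = x.squeeze(dim=spatial_dim + 2)
    if condition is not None:
        kwargs["condition"] = condition.squeeze(dim=spatial_dim + 2)
    out = network(x, *args, **kwargs)
    # Restore the slice dimension on the output
    return out.unsqueeze(dim=spatial_dim + 2)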

🧹 Nitpick comments (1)
monai/inferers/inferer.py (1)

386-407: Consider refactoring to reduce code duplication.

The inference loop has identical logic duplicated between the conditioned (lines 387-398) and unconditioned (lines 400-407) cases. Only the kwargs["condition"] assignment differs.

Consider refactoring to eliminate duplication:

-        if condition is not None:
-            for (patches, locations, batch_size), (condition_patches, _, _) in zip(
-                self._batch_sampler(patches_locations), self._batch_sampler(condition_locations)
-            ):
-                # add patched condition to kwargs
-                kwargs["condition"] = condition_patches
-                # run inference
-                outputs = self._run_inference(network, patches, *args, **kwargs)
-                # initialize the mergers
-                if not mergers:
-                    mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
-                # aggregate outputs
-                self._aggregate(outputs, locations, batch_size, mergers, ratios)
-        else:
-            for patches, locations, batch_size in self._batch_sampler(patches_locations):
-                # run inference
-                outputs = self._run_inference(network, patches, *args, **kwargs)
-                # initialize the mergers
-                if not mergers:
-                    mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
-                # aggregate outputs
-                self._aggregate(outputs, locations, batch_size, mergers, ratios)
+        # Create an iterator for condition patches if condition is provided
+        condition_iter = self._batch_sampler(condition_locations) if condition is not None else None
+        
+        for patches, locations, batch_size in self._batch_sampler(patches_locations):
+            # Get condition patches if available
+            if condition_iter is not None:
+                condition_patches, _, _ = next(condition_iter)
+                kwargs["condition"] = condition_patches
+            
+            # run inference
+            outputs = self._run_inference(network, patches, *args, **kwargs)
+            # initialize the mergers
+            if not mergers:
+                mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
+            # aggregate outputs
+            self._aggregate(outputs, locations, batch_size, mergers, ratios)
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 4b69748 and 147a686.

📒 Files selected for processing (5)
  • monai/inferers/inferer.py (6 hunks)
  • monai/inferers/utils.py (3 hunks)
  • tests/inferers/test_patch_inferer.py (1 hunks)
  • tests/inferers/test_slice_inferer.py (1 hunks)
  • tests/inferers/test_sliding_window_inference.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/inferers/test_patch_inferer.py (5)
monai/inferers/splitter.py (1)
  • SlidingWindowSplitter (94-292)
monai/inferers/merger.py (1)
  • AvgMerger (103-203)
monai/transforms/utils_pytorch_numpy_unification.py (1)
  • mean (484-503)
monai/inferers/inferer.py (1)
  • PatchInferer (100-417)
tests/test_utils.py (1)
  • assert_allclose (99-139)
🔇 Additional comments (16)
monai/inferers/utils.py (5)

156-157: Clean extraction of condition parameter.

The condition tensor is properly extracted from kwargs and defaulted to None, ensuring it doesn't interfere with other keyword arguments passed to the predictor.


173-174: Consistent padding logic for condition tensor.

The condition tensor is padded identically to the input tensor when padding is required, ensuring synchronized spatial dimensions throughout the sliding window process.


227-229: Proper handling of condition patches for batched windows.

The condition patches are concatenated and moved to the sliding window device in parallel with input patches, maintaining consistency in batch processing.


232-234: Consistent condition handling for single windows.

The condition patch is extracted and moved to the sliding window device for single window processing, mirroring the input handling logic.


237-239: Condition properly passed to predictor.

The condition tensor is included in kwargs when calling the predictor, enabling conditional inference while maintaining backward compatibility when no condition is provided.
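Taken together, the flow described in the comments above can be outlined roughly as below. This is a condensed sketch under assumed names (conditioned_windows_sketch, window_slices, pad_size), not the actual utils.py code:

from collections.abc import Callable, Sequence

import torch
import torch.nn.functional as F

def conditioned_windows_sketch(
    inputs: torch.Tensor,
    predictor: Callable,
    window_slices: Sequence[tuple],
    pad_size: Sequence[int],
    sw_device: torch.device,
    **kwargs,
) -> list:
    # 1. extract the optional condition so it does not leak into other predictor kwargs
    condition = kwargs.pop("condition", None)
    # 2. pad the condition exactly like the inputs
    if any(p > 0 for p in pad_size):
        inputs = F.pad(inputs, pad=list(pad_size), mode="constant", value=0.0)
        if condition is not None:
            condition = F.pad(condition, pad=list(pad_size), mode="constant", value=0.0)
    outputs = []
    for win in window_slices:
        # 3. window the inputs and the condition with the same slices
        win_data = inputs[win].to(sw_device)
        if condition is not None:
            kwargs["condition"] = condition[win].to(sw_device)
        # 4. the condition reaches the predictor as a keyword argument
        outputs.append(predictor(win_data, **kwargs))
    return outputs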

tests/inferers/test_slice_inferer.py (1)

56-88: Comprehensive test coverage for SliceInferer with condition support.

The test class properly validates:

  • Model forward method accepts condition parameter
  • Condition tensor processed synchronously with input
  • Output correctness verified through sum comparison
  • Multiple execution stability confirmed
  • Consistent test parameterization with existing tests
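A condensed version of what such a test can look like; the toy model and shapes are assumptions, not the actual test code:

import unittest

import torch
import torch.nn as nn
from monai.inferers import SliceInferer

class _AddConditionNet(nn.Module):
    # Identity-sized 2D "model" that adds the condition slice to the input slice
    def forward(self, x, condition=None):
        return x if condition is None else x + condition

class TestSliceInfererCondition(unittest.TestCase):
    def test_condition_is_sliced_with_input(self):
        inputs = torch.rand(1, 1, 16, 16, 8)
        cond = torch.rand_like(inputs)
        inferer = SliceInferer(roi_size=(16, 16), sw_batch_size=1, spatial_dim=2)
        out = inferer(inputs, _AddConditionNet(), condition=cond)
        torch.testing.assert_close(out, inputs + cond)

if __name__ == "__main__":
    unittest.main()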
tests/inferers/test_patch_inferer.py (2)

308-478: Comprehensive test case variants for condition support.

The test cases properly extend existing scenarios with condition support:

  • Network functions consistently modified to accept condition parameter
  • Expected results correctly adjusted for conditioning logic (e.g., TENSOR_4x4 * 2 for additive conditioning)
  • All major test scenarios covered including preprocessing, postprocessing, padding, and multi-threading

481-532: Well-structured test class for PatchInferer condition support.

The test class provides thorough validation:

  • Handles both tensor and list inputs for condition (lines 503-506, 513-516, 524-527)
  • Proper condition tensor creation with cloning for tensor inputs
  • Comprehensive coverage of tensor, list, and dict outputs
  • Consistent test structure with existing test patterns
tests/inferers/test_sliding_window_inference.py (4)

376-409: Comprehensive basic testing for sliding window inference with condition.

The test method properly validates:

  • Condition tensor creation and device placement
  • Compute function accepts condition parameter
  • Expected results correctly calculated for conditioning
  • Both sliding_window_inference function and SlidingWindowInferer class tested

430-452: Thorough device handling validation with condition support.

The test verifies:

  • Condition tensor device placement consistency with inputs
  • Proper device handling for sliding window and compute devices
  • Device assertions for both data and condition tensors in compute function
  • Gradient propagation maintained with condition support

568-614: Comprehensive args/kwargs handling with condition support.

The test validates:

  • Condition parameter properly passed alongside other args/kwargs
  • Compute function signature supports condition parameter
  • Expected results correctly calculated with multiple parameters
  • Both function and class interfaces tested with condition

615-686: Thorough multi-output testing with condition support.

The test comprehensively validates:

  • Tuple and dict output formats with condition support
  • Condition tensor properly handled for different output resolutions
  • Expected results correctly calculated for each output type
  • Both sliding_window_inference function and SlidingWindowInferer class tested
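For instance, a dictionary-output predictor with conditioning might be exercised roughly like this; the toy predictor and shapes are assumptions, not the actual test code:

import torch
from monai.inferers import SlidingWindowInferer

def dict_predictor(window, condition=None):
    base = window if condition is None else window + condition
    return {"seg": base, "aux": base * 2}

inputs = torch.rand(1, 1, 48, 48)
cond = torch.rand_like(inputs)

inferer = SlidingWindowInferer(roi_size=(16, 16), sw_batch_size=4, overlap=0.0)
out = inferer(inputs, dict_predictor, condition=cond)
torch.testing.assert_close(out["seg"], inputs + cond)
torch.testing.assert_close(out["aux"], (inputs + cond) * 2)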
monai/inferers/inferer.py (4)

325-327: Documentation looks good!

The new condition parameter is well documented with clear requirements about shape matching.


330-353: Comprehensive validation logic for condition tensor.

The shape and type validation is thorough, covering both tensor and list cases with clear error messages.
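The kind of check being praised here can be sketched as follows; this is illustrative only, not the actual inferer.py code:

import torch

def check_condition(inputs, condition):
    # Validate an optional condition against a tensor input or a list of tensor inputs
    if condition is None:
        return
    if isinstance(inputs, torch.Tensor):
        if not isinstance(condition, torch.Tensor) or condition.shape != inputs.shape:
            raise ValueError("`condition` must be a tensor with the same shape as `inputs`.")
    elif isinstance(inputs, (list, tuple)):
        if not isinstance(condition, (list, tuple)) or len(condition) != len(inputs):
            raise ValueError("`condition` must be a sequence matching `inputs` in length.")
        for inp, cond in zip(inputs, condition):
            if cond.shape != inp.shape:
                raise ValueError("each `condition` entry must match the shape of its input.")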


569-576: Condition handling looks correct!

The implementation properly validates the condition shape and passes it through to sliding_window_inference. Using kwargs.get() instead of kwargs.pop() is appropriate here since the condition needs to remain in kwargs for the downstream function.


784-815: SliceInferer condition support implemented correctly!

The implementation properly validates the condition and passes it through to the network_wrapper. The approach of using the lambda function to inject the condition parameter works well.

Labels: none yet
Projects: none yet
Development

Successfully merging this pull request may close these issues.

SliceInferer may need to handle conditional-based models
4 participants