-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add support for optional conditioning in PatchInferer, SliceInferer, and SlidingWindowInferer #8400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
53345d1
to
c4c65ae
Compare
…ndt94@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR looks good to me overall, with appropriate tests that pass. Not sure if it'll be better to derive the condition from an argument defaulted to None instead of kwargs, to be more transparent. If leaving to kwargs, I believe there should still be some mention on the dosctrings that it is possible to pass a "condition". Not sure about this either, but If the shape of both input and output have to match, then checking for this at the beginning of the inferer sampler would be good (+test).
Signed-off-by: FinnBehrendt <finn.behrendt94@gmail.com>
Signed-off-by: FinnBehrendt <finn.behrendt94@gmail.com>
Thanks a lot for the suggestions! I just pushed the updated changes. |
WalkthroughThe code introduces support for an optional Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant Inferer (Patch/SlidingWindow/Slice)
participant Network
User->>Inferer: call(inputs, network, condition=cond)
Inferer->>Inferer: Validate shapes/types of inputs and condition
loop For each patch/slice/window
Inferer->>Network: network(input_patch, condition=cond_patch)
Network-->>Inferer: output_patch
end
Inferer-->>User: aggregated_output
Assessment against linked issues
Poem
✨ Finishing Touches
🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/inferers/inferer.py (1)
817-835
: Network wrapper correctly handles condition tensor!The implementation properly squeezes the condition tensor along the spatial dimension (matching the input processing) and passes it to the network. Shape validation is appropriately handled in the
__call__
method before reaching this point, addressing the concern raised in the previous review.
🧹 Nitpick comments (1)
monai/inferers/inferer.py (1)
386-407
: Consider refactoring to reduce code duplication.The inference loop has identical logic duplicated between the conditioned (lines 387-398) and unconditioned (lines 400-407) cases. Only the
kwargs["condition"]
assignment differs.Consider refactoring to eliminate duplication:
- if condition is not None: - for (patches, locations, batch_size), (condition_patches, _, _) in zip( - self._batch_sampler(patches_locations), self._batch_sampler(condition_locations) - ): - # add patched condition to kwargs - kwargs["condition"] = condition_patches - # run inference - outputs = self._run_inference(network, patches, *args, **kwargs) - # initialize the mergers - if not mergers: - mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size) - # aggregate outputs - self._aggregate(outputs, locations, batch_size, mergers, ratios) - else: - for patches, locations, batch_size in self._batch_sampler(patches_locations): - # run inference - outputs = self._run_inference(network, patches, *args, **kwargs) - # initialize the mergers - if not mergers: - mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size) - # aggregate outputs - self._aggregate(outputs, locations, batch_size, mergers, ratios) + # Create an iterator for condition patches if condition is provided + condition_iter = self._batch_sampler(condition_locations) if condition is not None else None + + for patches, locations, batch_size in self._batch_sampler(patches_locations): + # Get condition patches if available + if condition_iter is not None: + condition_patches, _, _ = next(condition_iter) + kwargs["condition"] = condition_patches + + # run inference + outputs = self._run_inference(network, patches, *args, **kwargs) + # initialize the mergers + if not mergers: + mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size) + # aggregate outputs + self._aggregate(outputs, locations, batch_size, mergers, ratios)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
monai/inferers/inferer.py
(6 hunks)monai/inferers/utils.py
(3 hunks)tests/inferers/test_patch_inferer.py
(1 hunks)tests/inferers/test_slice_inferer.py
(1 hunks)tests/inferers/test_sliding_window_inference.py
(1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/inferers/test_patch_inferer.py (5)
monai/inferers/splitter.py (1)
SlidingWindowSplitter
(94-292)monai/inferers/merger.py (1)
AvgMerger
(103-203)monai/transforms/utils_pytorch_numpy_unification.py (1)
mean
(484-503)monai/inferers/inferer.py (1)
PatchInferer
(100-417)tests/test_utils.py (1)
assert_allclose
(99-139)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (18)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: packaging
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.4.1)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.10)
🔇 Additional comments (16)
monai/inferers/utils.py (5)
156-157
: Clean extraction of condition parameter.The condition tensor is properly extracted from kwargs and defaulted to None, ensuring it doesn't interfere with other keyword arguments passed to the predictor.
173-174
: Consistent padding logic for condition tensor.The condition tensor is padded identically to the input tensor when padding is required, ensuring synchronized spatial dimensions throughout the sliding window process.
227-229
: Proper handling of condition patches for batched windows.The condition patches are concatenated and moved to the sliding window device in parallel with input patches, maintaining consistency in batch processing.
232-234
: Consistent condition handling for single windows.The condition patch is extracted and moved to the sliding window device for single window processing, mirroring the input handling logic.
237-239
: Condition properly passed to predictor.The condition tensor is included in kwargs when calling the predictor, enabling conditional inference while maintaining backward compatibility when no condition is provided.
tests/inferers/test_slice_inferer.py (1)
56-88
: Comprehensive test coverage for SliceInferer with condition support.The test class properly validates:
- Model forward method accepts condition parameter
- Condition tensor processed synchronously with input
- Output correctness verified through sum comparison
- Multiple execution stability confirmed
- Consistent test parameterization with existing tests
tests/inferers/test_patch_inferer.py (2)
308-478
: Comprehensive test case variants for condition support.The test cases properly extend existing scenarios with condition support:
- Network functions consistently modified to accept condition parameter
- Expected results correctly adjusted for conditioning logic (e.g.,
TENSOR_4x4 * 2
for additive conditioning)- All major test scenarios covered including preprocessing, postprocessing, padding, and multi-threading
481-532
: Well-structured test class for PatchInferer condition support.The test class provides thorough validation:
- Handles both tensor and list inputs for condition (lines 503-506, 513-516, 524-527)
- Proper condition tensor creation with cloning for tensor inputs
- Comprehensive coverage of tensor, list, and dict outputs
- Consistent test structure with existing test patterns
tests/inferers/test_sliding_window_inference.py (4)
376-409
: Comprehensive basic testing for sliding window inference with condition.The test method properly validates:
- Condition tensor creation and device placement
- Compute function accepts condition parameter
- Expected results correctly calculated for conditioning
- Both
sliding_window_inference
function andSlidingWindowInferer
class tested
430-452
: Thorough device handling validation with condition support.The test verifies:
- Condition tensor device placement consistency with inputs
- Proper device handling for sliding window and compute devices
- Device assertions for both data and condition tensors in compute function
- Gradient propagation maintained with condition support
568-614
: Comprehensive args/kwargs handling with condition support.The test validates:
- Condition parameter properly passed alongside other args/kwargs
- Compute function signature supports condition parameter
- Expected results correctly calculated with multiple parameters
- Both function and class interfaces tested with condition
615-686
: Thorough multi-output testing with condition support.The test comprehensively validates:
- Tuple and dict output formats with condition support
- Condition tensor properly handled for different output resolutions
- Expected results correctly calculated for each output type
- Both
sliding_window_inference
function andSlidingWindowInferer
class testedmonai/inferers/inferer.py (4)
325-327
: Documentation looks good!The new
condition
parameter is well documented with clear requirements about shape matching.
330-353
: Comprehensive validation logic for condition tensor.The shape and type validation is thorough, covering both tensor and list cases with clear error messages.
569-576
: Condition handling looks correct!The implementation properly validates the condition shape and passes it through to
sliding_window_inference
. Usingkwargs.get()
instead ofkwargs.pop()
is appropriate here since the condition needs to remain in kwargs for the downstream function.
784-815
: SliceInferer condition support implemented correctly!The implementation properly validates the condition and passes it through to the network_wrapper. The approach of using the lambda function to inject the condition parameter works well.
Fixes #8220
Description
This PR adds support for optional conditioning in MONAI’s inferers, allowing models to receive auxiliary inputs for conditioning that are processed (patched, sliced) the same way as the inputs. This is particularly relevant for generative models like conditional GANs or DMs.
Example Usage:
Types of changes
PatchInferer
,SliceInferer
, andSlidingWindowInferer
to optionally accept acondition
tensor (passed as a kwarg).condition
can now be:None
(default)inputs
./runtests.sh -f -u --net --coverage
../runtests.sh --quick --unittests --disttests
.make html
command in thedocs/
folder.Additional extensions such as support for dense vector conditioning (e.g., (1, C, Z), with Z being the conditional dimension) could be explored in a follow-up PR if there’s interest.
Summary by CodeRabbit
New Features
Tests