
Conversation

@ytl0623 ytl0623 commented Jan 22, 2026

Fixes #8276

Description

  • Added a new argument apply_inverse_to_pred. Defaults to True to preserve backward compatibility. When set to False, it skips the inverse transformation step and aggregates the model predictions directly.
  • Added a new unit test to simulate a classification task with spatial augmentation, verifying that the aggregation works correctly without spatial inversion.
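The control flow the new flag introduces can be sketched in plain Python (illustrative only; the names `aggregate_predictions` and `inverter` are stand-ins, not MONAI's actual API):

```python
def aggregate_predictions(batch_preds, inverter, apply_inverse_to_pred=True):
    """Collect per-realization predictions, optionally undoing the spatial transform.

    Sketch of the new TTA behavior: with apply_inverse_to_pred=False, raw model
    outputs (e.g. classification logits) are aggregated without spatial inversion.
    """
    outs = []
    for pred in batch_preds:
        # Invert only when predictions live in the transformed spatial frame.
        outs.append(inverter(pred) if apply_inverse_to_pred else pred)
    return outs

# Toy "inverse": pretend the forward augmentation doubled every value.
halve = lambda x: x / 2
print(aggregate_predictions([2.0, 4.0], halve))                              # [1.0, 2.0]
print(aggregate_predictions([2.0, 4.0], halve, apply_inverse_to_pred=False))  # [2.0, 4.0]
```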

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Signed-off-by: ytl0623 <david89062388@gmail.com>

coderabbitai bot commented Jan 22, 2026

📝 Walkthrough

Walkthrough

Adds a boolean parameter apply_inverse_to_pred (default True) to TestTimeAugmentation, stored as self.apply_inverse_to_pred. __call__ now conditionally applies inverse transforms to predictions when apply_inverse_to_pred is True; when False, raw predictions are collected (enabling non-spatial outputs). _check_transforms warning logic was adjusted to consider invertibility relative to apply_inverse_to_pred. Docstrings and constructor signature updated. Tests: new test_non_spatial_output validates behavior with apply_inverse_to_pred=False; a test UNet call changes strides=(2, 2) → strides=(2,).

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 3 passed | ❌ 2 failed

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 42.86%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Out of Scope Changes check — ❓ Inconclusive: the UNet stride change (strides=(2, 2) to strides=(2,)) in the test file appears unrelated to the core objective of supporting non-spatial predictions. Resolution: clarify why the stride modification is required for non-spatial TTA support, or revert to the original strides if it is an unintended change.

✅ Passed checks (3 passed)
  • Title check — ✅ Passed: the title accurately summarizes the main change, generalizing TestTimeAugmentation to support non-spatial predictions via the new apply_inverse_to_pred parameter.
  • Description check — ✅ Passed: the description covers the key changes, preserves backward compatibility, documents the new parameter, and confirms new tests were added.
  • Linked Issues check — ✅ Passed: the implementation directly addresses issue #8276 by adding the apply_inverse_to_pred parameter to skip inverse transforms for non-spatial predictions, with test coverage for the classification use case.


@ericspod (Member) commented:

hi @ytl0623 thanks for this change, I think it's fine in principle. The _check_transforms method should be changed to account for when the new argument is False, in which case it doesn't need to check for invertibility of transforms. I noticed other issues with the original version of this method so I'd propose something like the following (which I haven't tested):

def _check_transforms(self):
    """There should be at least one random transform, and when predictions are inverted
    all random transforms should be invertible."""
    transforms = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
    warns = []
    randoms = []
    for idx, t in enumerate(transforms):
        if isinstance(t, Randomizable):
            randoms.append(t)
            if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform):
                warns.append(f"Transform {idx} (type {type(t).__name__}) not invertible.")

    if len(randoms) == 0:
        warns.append("TTA usually requires at least one `Randomizable` transform in the given transform sequence.")

    if len(warns) > 0:
        warnings.warn("TTA has encountered issues with the given transforms:\n  " + "\n  ".join(warns))

Please check this logic, it might be that we need to check all transforms for invertibility whether they're random or not, but what I have here is equivalent to the original.
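A self-contained rendering of that proposal behaves as follows. The classes `Randomizable`, `InvertibleTransform`, and `RandFlip` here are minimal stand-ins for the MONAI types, so the logic can be exercised without the library:

```python
import warnings

# Stand-ins for monai.transforms.Randomizable / InvertibleTransform (demo only).
class Randomizable: ...
class InvertibleTransform: ...

def check_transforms(transforms, apply_inverse_to_pred=True):
    """Warn once, listing all issues: missing random transforms, and (only when
    predictions will be inverted) random transforms that are not invertible."""
    warns, randoms = [], []
    for idx, t in enumerate(transforms):
        if isinstance(t, Randomizable):
            randoms.append(t)
            if apply_inverse_to_pred and not isinstance(t, InvertibleTransform):
                warns.append(f"Transform {idx} (type {type(t).__name__}) not invertible.")
    if not randoms:
        warns.append("TTA usually requires at least one `Randomizable` transform.")
    if warns:
        warnings.warn(
            "TTA has encountered issues with the given transforms:\n  " + "\n  ".join(warns),
            stacklevel=2,
        )
    return warns

class RandFlip(Randomizable):  # random but NOT invertible
    pass

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    strict = check_transforms([RandFlip()])                                # flags non-invertibility
    relaxed = check_transforms([RandFlip()], apply_inverse_to_pred=False)  # no complaint
print(len(strict), len(relaxed))  # 1 0
```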

ytl0623 and others added 2 commits January 24, 2026 00:07
Signed-off-by: ytl0623 <david89062388@gmail.com>

ytl0623 commented Jan 23, 2026

Hi @ericspod, thanks for the suggestion!

@coderabbitai bot left a comment
Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@monai/data/test_time_augmentation.py`:
- Around line 67-71: The docstring for test-time augmentation (function/class
using parameter names transform, batch_size, and apply_inverse_to_pred)
incorrectly states "All random transforms must be of type InvertibleTransform";
update the docstring (and the transform type hint if present) to reflect that
non-invertible random transforms are allowed when apply_inverse_to_pred=False
and only need to be invertible when apply_inverse_to_pred=True; change the
wording in both occurrences (the block around the transform description and the
later paragraph at lines ~115-118) to describe this conditional requirement and,
if applicable, broaden the transform type hint to accept non-invertible
Randomizable types when apply_inverse_to_pred is False.
- Around line 174-175: The warning message built from the local variable warns
is missing a newline after the colon and does not set a stacklevel, so update
the warnings.warn call to prepend a newline (e.g., "TTA has encountered issues
with the given transforms:\n  " + "\n  ".join(warns)) and pass an appropriate
stacklevel (e.g., stacklevel=2) so user stack traces point to the caller; locate
and modify the warnings.warn(...) invocation that uses the warns list in
test_time_augmentation.py.
- Around line 208-213: The non-inverse branch currently returns raw predictions
and skips all Invertd post-processing (to_tensor, output_device, post_func);
update the branch so decollated items still go through the same inverter
pipeline (or a post-processing-only path) before extracting self._pred_key.
Concretely, in the else branch replace outs.extend([i[self._pred_key] for i in
decollate_batch(b)]) with code that calls self.inverter on each
PadListDataCollate.inverse(i) (or calls an Invertd method/flag that runs only
to_tensor/output_device/post_func but not spatial inverse) and then extracts
[self._pred_key]; ensure the call honors to_tensor, output_device and post_func
parameters so behavior matches the apply_inverse_to_pred=True path.

Comment on lines 67 to 71
Args:
transform: transform (or composed) to be applied to each realization. At least one transform must be of type
`RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).
All random transforms must be of type `InvertibleTransform`.
batch_size: number of realizations to infer at once.

⚠️ Potential issue | 🟡 Minor

Docstring still claims all random transforms must be invertible.

With apply_inverse_to_pred=False, non‑invertible random transforms are allowed. Update the docstring (and consider widening the transform type hint) to match behavior. As per coding guidelines, keep docstrings aligned with behavior.

✅ Suggested docstring fix
-        All random transforms must be of type `InvertibleTransform`.
+        When `apply_inverse_to_pred` is True, all random transforms must be of type `InvertibleTransform`.

Also applies to: 115-118


Comment on lines +174 to +175
if len(warns) > 0:
warnings.warn("TTA has encountered issues with the given transforms:" + "\n ".join(warns))

⚠️ Potential issue | 🟡 Minor

Improve warning formatting and stacklevel.

Current warning concatenates without a newline and lacks a stacklevel, making diagnostics noisy.

🛠️ Proposed fix
-        warnings.warn("TTA has encountered issues with the given transforms:" + "\n  ".join(warns))
+        warnings.warn(
+            "TTA has encountered issues with the given transforms:\n  " + "\n  ".join(warns),
+            stacklevel=2,
+        )
🧰 Tools
🪛 Ruff (0.14.13)

175-175: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)
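The effect of stacklevel=2 can be checked in isolation: the warning is attributed to the caller's line rather than to the warnings.warn call inside the helper, which is what Ruff's B028 is asking for:

```python
import warnings

def helper_default():
    warnings.warn("attributed to the helper")  # default stacklevel=1: points at this line

def helper_with_stacklevel():
    warnings.warn("attributed to the caller", stacklevel=2)  # points at the call site

def run():
    with warnings.catch_warnings(record=True) as rec:
        warnings.simplefilter("always")
        helper_default()          # warning reports helper_default's body
        helper_with_stacklevel()  # warning reports THIS line instead
    return rec

records = run()
# The two warnings carry different line attributions.
print(records[0].lineno, records[1].lineno)
```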


Comment on lines 208 to 213
     b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
-    outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
+    if self.apply_inverse_to_pred:
+        outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
+    else:
+        outs.extend([i[self._pred_key] for i in decollate_batch(b)])


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Step 1: Examine Invertd class and its post-processing behavior
rg -n "class Invertd" monai/transforms/post/dictionary.py -A 30

Repository: Project-MONAI/MONAI

Length of output: 1743


🏁 Script executed:

#!/bin/bash
# Step 2: Check the test_time_augmentation.py file to understand context
head -250 monai/data/test_time_augmentation.py | tail -70

Repository: Project-MONAI/MONAI

Length of output: 3385


🏁 Script executed:

#!/bin/bash
# Step 3: Search for post_func and output_device in Invertd
rg -n "post_func|output_device|to_tensor" monai/transforms/post/dictionary.py | head -30

Repository: Project-MONAI/MONAI

Length of output: 1179


🏁 Script executed:

#!/bin/bash
# Search for tests covering apply_inverse_to_pred behavior
rg -n "apply_inverse_to_pred" --type py -B 2 -A 2

Repository: Project-MONAI/MONAI

Length of output: 3372


🏁 Script executed:

#!/bin/bash
# Check how inverter is initialized - what post_func/to_tensor/output_device values are used
rg -n "self.inverter|Invertd" monai/data/test_time_augmentation.py -B 2 -A 2

Repository: Project-MONAI/MONAI

Length of output: 961


🏁 Script executed:

#!/bin/bash
# Look at the full __init__ of TestTimeAugmentation to understand inverter setup
sed -n '1,207p' monai/data/test_time_augmentation.py | tail -100

Repository: Project-MONAI/MONAI

Length of output: 4408


Non-inverse branch ignores to_tensor, output_device, and post_func parameters.

When apply_inverse_to_pred=False, the code skips Invertd entirely (line 212), so these configured post-processing steps are never applied. Users may expect them to work regardless of inversion. Either apply post-processing in both branches, or add a post-processing-only path to Invertd when inversion is not needed.
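The "post-processing-only path" the review asks for could look roughly like this. This is a hypothetical sketch — the function name `postprocess_only` and its parameters are illustrative, and Invertd's real API may require a different hook (the device move is omitted):

```python
def postprocess_only(pred, to_tensor=True, post_func=None):
    """Run the non-spatial steps Invertd would apply — conversion and post_func —
    without the spatial inverse. Hypothetical sketch, not MONAI code."""
    out = list(pred) if to_tensor else pred  # stand-in for convert-to-tensor
    if post_func is not None:
        out = post_func(out)
    return out

# The else branch would route raw predictions through this instead of
# returning them untouched, so to_tensor/post_func still take effect.
print(postprocess_only((1, 2, 3), post_func=lambda xs: [x * 10 for x in xs]))
# [10, 20, 30]
```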




Development

Successfully merging this pull request may close these issues.

Generalize TestTimeAugmentation to non-spatial predictions

2 participants