
Conversation

@aboubezari aboubezari commented Sep 9, 2025

What does this PR do?

Type of change: Bug fix

Overview: If there is a cast node connecting an input directly to an output, the output will be totally disconnected due to naming issues. This fix creates specialized cast nodes for such edge cases and avoids removing them in the initial pass.

Usage

Autocast precision converter

Testing

Added a unittest that fails before my change, and passes after my fix.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: No

Summary by CodeRabbit

  • Bug Fixes

    • Improved graph sanitization to handle cast nodes that map model inputs directly to outputs, avoiding conversion failures in low‑precision workflows.
  • New Features

    • Sanitization pass now runs before precision conversion and inserts safe pass‑through nodes where needed.
    • New configuration options for minimum opset, maximum IR version, and plugin handling.
  • Tests

    • Added unit tests covering casted input→output scenarios, parameterized for low‑precision types and I/O preservation.
  • Chores

    • Duplicate sanitizer logic and duplicate test definitions introduced (cleanup needed).

@aboubezari aboubezari requested a review from a team as a code owner September 9, 2025 00:47
@aboubezari aboubezari requested a review from ajrasane September 9, 2025 00:47

copy-pr-bot bot commented Sep 9, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 9, 2025

Walkthrough

Adds GraphSanitizer logic to isolate Cast nodes that map graph inputs directly to outputs and integrates that sanitization into PrecisionConverter (new ctor params stored and used). Adds tests and fixture for cast-to-output models. Duplicate method and duplicate test/fixture definitions are present.

Changes

Cohort / File(s) Summary
Graph sanitization
modelopt/onnx/autocast/graphsanitizer.py
Adds sanitize_io_casts() to detect Cast nodes whose input is a graph input and whose output is a graph output; creates an intermediate output name, rewires the Cast to emit the intermediate, inserts an Identity node that produces the original output, preserves node order, and sanitize() now calls sanitize_io_casts(). Note: sanitize_io_casts is defined twice (duplicate).
Precision conversion
modelopt/onnx/autocast/precisionconverter.py
Imports GraphSanitizer; extends PrecisionConverter.__init__ with min_opset: int = 13, max_ir_version: int | None = None, and trt_plugins: list[str] | None = []; the new parameters are stored as instance attributes and used by a new _sanitize_model() step.
ONNX autocast tests
tests/unit/onnx/autocast/test_precisionconverter.py
Adds module constant LATEST_IR_VERSION_SUPPORTED_BY_ORT; adds fixture model_with_casted_input_to_output() returning (model, value_info_map, initializer_map, node_to_init_map) and parameterized test test_casted_input_to_output_model(...) (params: low_precision_type "fp16"/"bf16", keep_io_types True/False) that constructs PrecisionConverter using max_ir_version. Fixture/test pair appears duplicated in file.
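The sanitize_io_casts() rewiring summarized above can be sketched on a toy graph representation — plain dicts rather than ONNX protos; the function and field names below are invented for this sketch and are not ModelOpt's actual API. Only the intermediate naming scheme (`__io_cast_src`) mirrors the real implementation:

```python
# Toy sketch of the sanitize_io_casts rewiring: detect a Cast whose input is a
# graph input and whose output is a graph output, reroute the Cast through an
# intermediate tensor, and add an Identity that restores the original output.
def sanitize_io_casts(nodes, graph_inputs, graph_outputs):
    new_nodes = []
    for node in nodes:
        is_io_cast = (
            node["op"] == "Cast"
            and node["inputs"][0] in graph_inputs
            and node["outputs"][0] in graph_outputs
        )
        if is_io_cast:
            out = node["outputs"][0]
            intermediate = f"{out}__io_cast_src"  # unique per output name
            # Rewire the Cast to emit the intermediate tensor...
            node = {**node, "outputs": [intermediate]}
            new_nodes.append(node)
            # ...and insert an Identity producing the original output name,
            # immediately after the Cast to preserve node order.
            new_nodes.append(
                {
                    "op": "Identity",
                    "name": f"{node['name']}_identity",
                    "inputs": [intermediate],
                    "outputs": [out],
                }
            )
        else:
            new_nodes.append(node)
    return new_nodes

nodes = [{"op": "Cast", "name": "cast_io", "inputs": ["X"], "outputs": ["Y"]}]
sanitized = sanitize_io_casts(nodes, {"X"}, {"Y"})
```

After sanitization the graph output "Y" is produced by an Identity rather than directly by the Cast, so the generic cast-cleanup pass can no longer disconnect it.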

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Test as Unit Test
  participant PC as PrecisionConverter
  participant GS as GraphSanitizer
  participant G as ONNX Graph

  Test->>PC: construct PrecisionConverter(..., min_opset, max_ir_version, trt_plugins)
  PC->>PC: store params (min_opset, max_ir_version, trt_plugins)
  PC->>PC: _sanitize_model()
  PC->>GS: new GraphSanitizer(model, min_opset, max_ir_version, trt_plugins)
  GS->>G: inspect graph nodes/topology
  alt Cast node maps input -> output
    GS->>G: create intermediate output name
    GS->>G: rewire Cast to emit intermediate
    GS->>G: insert Identity consuming intermediate -> original output
    GS->>G: preserve node order
    GS-->>PC: return sanitized model
  else
    GS-->>PC: return model unchanged
  end
  PC->>PC: continue precision-conversion workflow
  PC-->>Test: return converted model

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I nibbled wires where Casts did meet,
I slipped an Identity in to keep things neat.
fp16 and bf16 took a tiny hop,
Sanitized hops kept the outputs on top.
— a rabbit, cheering on the CI 🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The PR title "[Autocast] Fix edge case casting input directly to output" directly reflects the main objectives and changes in this pull request. The title clearly identifies a bug fix for a specific edge case, which aligns with the PR's purpose of addressing issues where Cast nodes directly connect graph inputs to outputs. The changes demonstrate exactly what the title describes: a new sanitize_io_casts() method in GraphSanitizer that injects identity nodes to handle this edge case, integration into PrecisionConverter's conversion pipeline, and new test coverage for this scenario. The title is concise, specific, and accurately represents the primary change without being vague or misleading.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)

624-634: Preexisting-cast removal misses BF16 targets

Comment says FP16/BF16/FP32 casts are removed, but is_fp_cast only matches to ∈ {FLOAT16, FLOAT}. Include BF16 as a target to honor the contract.

-                is_fp_cast = cast_to_type in [
-                    onnx.TensorProto.FLOAT16,
-                    onnx.TensorProto.FLOAT,
-                ] and cast_from_type in [
+                is_fp_cast = cast_to_type in [
+                    onnx.TensorProto.FLOAT16,
+                    onnx.TensorProto.FLOAT,
+                    onnx.TensorProto.BFLOAT16,
+                ] and cast_from_type in [
                     onnx.TensorProto.FLOAT16,
                     onnx.TensorProto.FLOAT,
                     onnx.TensorProto.BFLOAT16,
                 ]

641-644: Guard for output-producing casts is ineffective

This condition checks if BOTH the cast input and output are network outputs, which never happens. It should keep casts that produce a network output.

-                    # Keep cast nodes that are necessary producers of network outputs
-                    if any(node.input[0] == out.name for out in self.model.graph.output) and any(
-                        node.output[0] == out.name for out in self.model.graph.output
-                    ):
+                    # Keep casts that produce a network output
+                    if node.output[0] in model_output_names:
                         continue
🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (1)

618-621: Insert duplicate IO-bridge casts deterministically (top of graph)

Appending at the tail can shuffle topo order. Inserting at index 0 is more stable for input-driven casts.

-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
+        for cast in casts_to_add:
+            self.model.graph.node.insert(0, cast)
tests/unit/onnx/autocast/test_precisionconverter.py (2)

1068-1071: Don’t write artifacts to /tmp in unit tests

The saved model isn’t used. Remove to keep tests hermetic.

-    onnx.save(model, "/tmp/model_with_casted_output.onnx")

1076-1091: Strengthen assertions: verify Y1 connectivity and dtype

Only checking onnx.checker is weak. Assert that Y1 remains produced by a Cast and retains FP32 (keep_io_types=True).

     converted_model = converter.convert(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
-    onnx.checker.check_model(converted_model)
+    onnx.checker.check_model(converted_model)
+    # Y1 should remain connected and produced by a Cast
+    y1_producers = utils.get_producer_nodes(converted_model, "Y1")
+    assert len(y1_producers) == 1
+    assert y1_producers[0].op_type == "Cast"
+    # keep_io_types=True -> FP32 I/O preserved
+    y1_vi = next(y for y in converted_model.graph.output if y.name == "Y1")
+    assert y1_vi.type.tensor_type.elem_type == TensorProto.FLOAT
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 512dbb7 and 16d5875.

📒 Files selected for processing (2)
  • modelopt/onnx/autocast/precisionconverter.py (1 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
  • check_model (557-569)
modelopt/onnx/autocast/precisionconverter.py (1)
  • convert (113-202)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

607-617: Confirm intent: skipping only the duplicated “new_cast”, not the renamed original

casts_to_skip holds the original name (now assigned to new_cast) and will not skip the renamed original (..._io_special_case). If that original gets removed, _bypass_cast_node will reconnect its consumers directly to the model input. Is that intended? If not, add the renamed name to the skip list as well right after renaming.

-                    casts_to_skip.append(node.name)
+                    casts_to_skip.append(node.name)
                     casts_to_add.append(new_cast)
                     # Now adjust the old cast's name, consumers and producers
                     node.name = f"{node.name}_io_special_case"
+                    casts_to_skip.append(node.name)  # keep the internal IO-special-case cast as well


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (3)

600-606: Also preserve optional Cast attributes (e.g., saturate) to avoid semantic drift

If the model uses newer opsets where Cast may carry optional attributes (like saturate), the duplicate should copy them.

Apply:

-                    new_cast = helper.make_node(
-                        "Cast",
-                        name=node.name,
-                        inputs=[node.input[0]],
-                        outputs=[node.output[0]],
-                        to=utils.get_cast_to_type(node),
-                    )
+                    # Copy optional attributes (e.g., 'saturate' in newer opsets)
+                    saturate = next((a.i for a in node.attribute if a.name == "saturate"), None)
+                    cast_attrs = {"to": utils.get_cast_to_type(node)}
+                    if saturate is not None:
+                        cast_attrs["saturate"] = saturate
+                    new_cast = helper.make_node(
+                        "Cast",
+                        name=node.name,
+                        inputs=[node.input[0]],
+                        outputs=[node.output[0]],
+                        **cast_attrs,
+                    )

618-621: Insert duplicate Cast adjacent to the original for better locality and readability

Appending at the end works but scatters IO nodes. Insert near the renamed source node to keep topology readable.

-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
+        # Preserve locality: insert duplicates next to their originals
+        for cast in casts_to_add:
+            target_idx = -1
+            for i, n in enumerate(self.model.graph.node):
+                if n.name == f"{cast.name}_io_special_case":
+                    target_idx = i
+                    break
+            if target_idx >= 0:
+                self.model.graph.node.insert(target_idx, cast)
+            else:
+                # Fallback to prepend to avoid end-append reordering
+                self.model.graph.node.insert(0, cast)

592-596: Use a set for casts_to_skip from the start

Minor nit for clarity and O(1) membership checks.

-        casts_to_skip = []
+        casts_to_skip: set[str] = set()
         # Add casts as a separate step to avoid modifying the graph while iterating over it
         casts_to_add = []
@@
-                    casts_to_skip.append(node.name)
+                    casts_to_skip.add(node.name)
@@
-        casts_to_skip = set(casts_to_skip)
+        # already a set

Also applies to: 620-621

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 16d5875 and 01308d6.

📒 Files selected for processing (1)
  • modelopt/onnx/autocast/precisionconverter.py (1 hunks)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

588-617: Solid IO-cast preservation strategy

Duplicating the IO-facing Cast, renaming the original, and rewiring consumers avoids disconnecting outputs while still enabling generic cast cleanup. This addresses the edge case cleanly.

Inline comment on:

    value_info_map,
    initializer_map,
    node_to_init_map,
    keep_io_types=True,

This test fails when keep_io_types=False due to a graph input becoming a graph output directly, which violates the assertion in ModelOpt that all original input and output names should be maintained in the quantized model.


@galagam galagam left a comment


This is indeed an edge case that challenges AutoCast's assumptions.
I'd like to propose an alternative approach: In GraphSanitizer, if an input is cast directly to output, inject an identity node.
@aboubezari @gcunhase

@gcunhase

This is indeed an edge case that challenges AutoCast's assumptions. I'd like to propose an alternative approach: In GraphSanitizer, if an input is cast directly to output, inject an identity node. @aboubezari @gcunhase

@galagam agree! @aboubezari, please make the suggested modification as a next step for this MR. Thanks!

@aboubezari

This is indeed an edge case that challenges AutoCast's assumptions. I'd like to propose an alternative approach: In GraphSanitizer, if an input is cast directly to output, inject an identity node. @aboubezari @gcunhase

@galagam agree! @aboubezari, please make the suggested modification as a next step for this MR. Thanks!

Will give this a try. Thanks.

@aboubezari

@gcunhase @galagam I've implemented the suggestion. Can you take another look?


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/unit/onnx/autocast/test_precisionconverter.py (2)

1078-1097: Recompute mappings after sanitize() and assert outputs preserved

sanitize() mutates the graph (adds Identity and rewires). Recompute mappings to avoid stale references, and assert that both outputs remain present.

Apply this diff:

 def test_casted_input_to_output_model(
     model_with_casted_input_to_output, low_precision_type, keep_io_types
 ):
     model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output

     min_opset = 22 if low_precision_type == "bf16" else 13
     graph_sanitizer = GraphSanitizer(model, min_opset)
     graph_sanitizer.sanitize()
-    model = graph_sanitizer.model
+    model = graph_sanitizer.model
+    # Recompute mappings after graph mutation by sanitizer
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
     converter = PrecisionConverter(
         model,
         value_info_map,
         initializer_map,
         node_to_init_map,
         keep_io_types=keep_io_types,
         low_precision_type=low_precision_type,
     )
     converted_model = converter.convert(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
     onnx.checker.check_model(converted_model)
+    # Ensure both original outputs are preserved
+    assert {o.name for o in converted_model.graph.output} == {"Y1", "Y2"}

Optionally, to validate type expectations on the direct-IO-cast output:

  • If keep_io_types is True: Y1/Y2 should be FLOAT.
  • If False: Y1/Y2 should be low_precision_onnx_type(low_precision_type).

1028-1067: Optional: Expand fixture to cover multi-Cast-from-same-input edge case

To guard against future regressions (multiple Casts from the same input to different outputs), add a companion fixture/test. Example snippet to append to this file:

@pytest.fixture
def model_with_two_casted_outputs():
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])
    cast1 = helper.make_node("Cast", ["X"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
    cast2 = helper.make_node("Cast", ["X"], ["Y2"], name="cast2", to=TensorProto.FLOAT)
    graph = helper.make_graph([cast1, cast2], "model_two_casted_outputs", [x], [y1, y2], [])
    model = helper.make_model(graph, producer_name="model_two_casted_outputs")
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)
    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    return model, value_info_map, initializer_map, node_to_init_map

@pytest.mark.parametrize("keep_io_types", [True, False])
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
def test_two_casts_from_same_input(model_with_two_casted_outputs, keep_io_types, low_precision_type):
    model, value_info_map, initializer_map, node_to_init_map = model_with_two_casted_outputs
    gsani = GraphSanitizer(model, 22 if low_precision_type == "bf16" else 13)
    gsani.sanitize()
    model = gsani.model
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    conv = PrecisionConverter(
        model, value_info_map, initializer_map, node_to_init_map,
        keep_io_types=keep_io_types, low_precision_type=low_precision_type
    )
    converted = conv.convert(high_precision_nodes=["cast1", "cast2"], low_precision_nodes=[])
    onnx.checker.check_model(converted)
    assert {o.name for o in converted.graph.output} == {"Y1", "Y2"}

I can send a PR update with these additions if helpful.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9363b09 and 03529fc.

📒 Files selected for processing (2)
  • modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
  • GraphSanitizer (28-450)
  • sanitize (53-68)
🔇 Additional comments (2)
modelopt/onnx/autocast/graphsanitizer.py (1)

68-69: Good placement in sanitize()

Running sanitize_io_casts() after cleanup and IR versioning keeps the injected Identity from being swept early. LGTM.

tests/unit/onnx/autocast/test_precisionconverter.py (1)

23-24: Import addition looks good

GraphSanitizer is used below; import is correct and scoped for tests.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)

1077-1098: Consider adding explicit assertions for the edge case.

The test validates that the model passes check_model after conversion, which confirms basic correctness. However, it could be strengthened by explicitly asserting that:

  1. The Identity node was inserted by the sanitizer
  2. The original input and output names are preserved
  3. The Cast node is still present

This would make the test's intent clearer and provide better regression coverage for the specific edge case being addressed.

Example assertions to add before line 1098:

# Verify Identity node was inserted for the IO cast
identity_nodes = [n for n in converted_model.graph.node if n.op_type == "Identity"]
assert len(identity_nodes) >= 1, "Expected at least one Identity node for IO cast isolation"

# Verify original I/O names are preserved
assert converted_model.graph.input[0].name == "X"
assert any(o.name == "Y1" for o in converted_model.graph.output)
assert any(o.name == "Y2" for o in converted_model.graph.output)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between caf9d39 and 0e0d11a.

📒 Files selected for processing (2)
  • modelopt/onnx/autocast/precisionconverter.py (5 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
  • check_model (557-569)
modelopt/onnx/autocast/precisionconverter.py (1)
  • convert (120-211)
modelopt/onnx/autocast/precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
  • GraphSanitizer (28-460)
  • sanitize (53-68)
🔇 Additional comments (6)
tests/unit/onnx/autocast/test_precisionconverter.py (2)

33-33: LGTM! Constant definition is clear.

The constant LATEST_IR_VERSION_SUPPORTED_BY_ORT is well-named and appropriately scoped for test usage.


1031-1074: LGTM! Fixture correctly models the edge case.

The fixture constructs a graph where a Cast node connects input X directly to output Y1, alongside a separate computation path. This accurately captures the edge case described in the PR objectives.

modelopt/onnx/autocast/precisionconverter.py (4)

35-35: LGTM! Import is correctly placed.

The GraphSanitizer import is appropriately added to support the new sanitization step.


77-79: LGTM! New parameters are well-integrated.

The three new parameters (min_opset, max_ir_version, trt_plugins) are properly added to the constructor with sensible defaults and stored as instance variables for use in the sanitization step.

Also applies to: 116-118


142-142: LGTM! Sanitization is correctly positioned in the workflow.

The _sanitize_model() call is appropriately placed early in the convert() method, after model validation but before unsupported op filtering. This ensures the graph is normalized (e.g., IO casts are isolated with Identity nodes) before precision conversion logic executes.


1043-1051: LGTM! Sanitization method is clean and correct.

The _sanitize_model() method correctly instantiates GraphSanitizer with the model and relevant parameters, invokes sanitization, and updates the model. This cleanly integrates the graph normalization step (including IO cast isolation with Identity nodes) into the precision conversion workflow.

@aboubezari aboubezari force-pushed the fix_auto_cast_io_special_nodes_bug branch 2 times, most recently from 7dd644a to 4028e2a Compare October 1, 2025 15:10
@aboubezari

@galagam @gcunhase I've improved the fix and confirmed it fixes some customer issues we've been having. Please take a look and let me know if there's anything else to change. Thanks!


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

77-92: Document the new parameters in the docstring.

The three new parameters (min_opset, max_ir_version, trt_plugins) are not documented in the docstring.

Apply this diff:

         Args:
             model: ONNX model to convert.
             value_info_map: Map of tensor names to value info.
             initializer_map: Map of tensor names to initializers.
             node_to_init_map: Map of node names to lists of initializer names.
             keep_io_types: Keep the input and output types of the model, otherwise they will be converted.
             low_precision_type: Precision to convert to.
             init_conversion_max_bytes: Maximum size in bytes for initializer conversion. Larger initializers will be
                                        cast at runtime.
             custom_ops: List of custom ops.
+            min_opset: Minimum opset version to use for sanitization (default: 13).
+            max_ir_version: Maximum IR version supported by ORT (default: None).
+            trt_plugins: List of TensorRT plugin library paths in .so format (default: []).
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

1063-1071: Consider matching parameter order to the GraphSanitizer constructor signature.

While using keyword arguments makes order irrelevant, matching the constructor's parameter order improves readability.

Apply this diff:

     def _sanitize_model(self):
         graph_sanitizer = GraphSanitizer(
             self.model,
             self.min_opset,
-            trt_plugins=self.trt_plugins,
             max_ir_version=self.max_ir_version,
+            trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
         )
         graph_sanitizer.sanitize()
         self.model = graph_sanitizer.model
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0e0d11a and 4028e2a.

📒 Files selected for processing (3)
  • modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
  • modelopt/onnx/autocast/precisionconverter.py (5 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/unit/onnx/autocast/test_precisionconverter.py
  • modelopt/onnx/autocast/graphsanitizer.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/autocast/precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
  • GraphSanitizer (28-560)
  • sanitize (53-69)
🔇 Additional comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)

35-35: LGTM!

The import is correctly placed and necessary for the sanitization functionality.


142-143: LGTM!

The sanitization is correctly invoked early in the conversion workflow, after validation but before precision conversion operations.

Comment on lines +77 to +79
min_opset: int = 13,
max_ir_version: int | None = None,
trt_plugins: list[str] | None = [],

⚠️ Potential issue | 🟠 Major

Replace mutable default argument with None.

The default value trt_plugins=[] creates a mutable default argument, which can lead to shared state between instances if the list is mutated.

Apply this diff:

-        trt_plugins: list[str] | None = [],
+        trt_plugins: list[str] | None = None,

Then update line 1067 to handle the None case:

         graph_sanitizer = GraphSanitizer(
             self.model,
             self.min_opset,
-            trt_plugins=self.trt_plugins,
+            trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
             max_ir_version=self.max_ir_version,
         )
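For background, the shared-state hazard behind this warning can be demonstrated in isolation (toy function names here, unrelated to the ModelOpt code):

```python
# Demonstrates why a mutable default like `plugins=[]` leaks state: the list
# is created once when the function is defined and reused across all calls
# that omit the argument.
def add_plugin(plugin, plugins=[]):  # anti-pattern
    plugins.append(plugin)
    return plugins

first = add_plugin("a.so")
second = add_plugin("b.so")  # appends to the very same default list


def add_plugin_safe(plugin, plugins=None):  # the recommended fix
    if plugins is None:
        plugins = []  # fresh list on every call
    plugins.append(plugin)
    return plugins
```

After the two calls to add_plugin, both return values alias one list containing both plugins, which is exactly the cross-instance leakage the review flags for trt_plugins.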
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (constructor signature):

    def __init__(
        ...,
        min_opset: int = 13,
        max_ir_version: int | None = None,
        trt_plugins: list[str] | None = None,
    ):
        ...

Suggested change (GraphSanitizer call):

    graph_sanitizer = GraphSanitizer(
        self.model,
        self.min_opset,
        trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
        max_ir_version=self.max_ir_version,
    )
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 77 to 79, change
the function signature to use trt_plugins: list[str] | None = None instead of a
mutable default list, and then at line 1067 update the code to treat a None
value as an empty list (e.g., set local_trt_plugins = trt_plugins or [] before
using it) so any subsequent iterations or mutations operate on a fresh list
rather than a shared default.

@codecov

codecov bot commented Oct 10, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.41%. Comparing base (718fd9e) to head (315ef04).
⚠️ Report is 6 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #305      +/-   ##
==========================================
+ Coverage   73.37%   73.41%   +0.04%     
==========================================
  Files         180      180              
  Lines       17934    17957      +23     
==========================================
+ Hits        13159    13184      +25     
+ Misses       4775     4773       -2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/unit/onnx/autocast/test_precisionconverter.py (2)

28-28: Consider documenting or relocating the constant.

The constant LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10 is defined at module level but only used in one test. Consider either:

  • Adding a comment explaining why version 10 is the latest supported by ORT, or
  • Moving it to a shared constants file if other tests might need it

1154-1175: Consider adding assertions to verify the sanitization behavior.

The test validates that the model remains valid after conversion but doesn't verify that the sanitization actually occurred as expected. Consider adding assertions to check:

  • That an Identity node was inserted between the Cast and the output
  • That the Cast node's output was rewired to the intermediate name
  • That the output remains correctly connected

This would make the test more robust and easier to debug if the sanitization logic changes in the future.

Example assertions you could add:

# Verify sanitization occurred
identity_nodes = [n for n in converted_model.graph.node if n.op_type == "Identity"]
assert len(identity_nodes) >= 1, "Expected at least one Identity node after sanitization"

# Verify output Y1 is still correctly connected
y1_outputs = [o for o in converted_model.graph.output if o.name == "Y1"]
assert len(y1_outputs) == 1, "Output Y1 should still be present exactly once"
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4028e2a and ac76dc7.

📒 Files selected for processing (2)
  • modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
modelopt/onnx/autocast/precisionconverter.py (1)
  • convert (120-214)
🔇 Additional comments (1)
modelopt/onnx/autocast/graphsanitizer.py (1)

347-379: LGTM! The sanitization logic correctly handles the edge case.

The implementation properly addresses the issue where Cast nodes connecting inputs directly to outputs can cause disconnection problems. The approach of inserting Identity nodes with unique intermediate names is sound, and the offset-based insertion logic correctly maintains graph ordering.

Key strengths:

  • Unique intermediate naming per output (f"{cast_output_name}__io_cast_src") prevents collisions when multiple outputs are cast from the same input
  • Defensive checks on lines 359-360 prevent crashes if nodes lack expected inputs/outputs
  • Insertion list with offset tracking (lines 378-379) correctly handles index shifts as nodes are inserted

The concerns raised in previous reviews about name collisions and insertion ordering have been properly addressed.

@gcunhase gcunhase enabled auto-merge (squash) October 14, 2025 16:16
@gcunhase gcunhase disabled auto-merge October 14, 2025 16:16
@gcunhase gcunhase enabled auto-merge (squash) October 14, 2025 16:17
@galagam galagam force-pushed the fix_auto_cast_io_special_nodes_bug branch from ac76dc7 to d61c4ad Compare October 15, 2025 06:19
@galagam
Copy link
Contributor

galagam commented Oct 15, 2025

@aboubezari The PR is approved. Can you please rebase and sign your commits?
See: https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md#%EF%B8%8F-signing-your-work

Once commits are signed and rebased, PR should be able to auto-merge.

auto-merge was automatically disabled October 15, 2025 19:57

Head branch was pushed to by a user without write access

@aboubezari aboubezari force-pushed the fix_auto_cast_io_special_nodes_bug branch from d61c4ad to 76d4505 Compare October 15, 2025 19:57
@aboubezari
Copy link
Contributor Author

@aboubezari The PR is approved. Can you please rebase and sign your commits? See: https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md#%EF%B8%8F-signing-your-work

Once commits are signed and rebased, PR should be able to auto-merge.

@galagam I can confirm my commits are verified and signed off. I have just fetched upstream main and rebased.

@galagam galagam enabled auto-merge (squash) October 16, 2025 05:41
@galagam galagam force-pushed the fix_auto_cast_io_special_nodes_bug branch 3 times, most recently from f125429 to fac6a17 Compare October 16, 2025 08:24
Update modelopt/onnx/autocast/precisionconverter.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: aboubezari <126983138+aboubezari@users.noreply.github.com>

cleanup

Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>

Inject identity nodes in sanitizer; revert existing logic; update test

Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>

Inject identity nodes in sanitizer; revert existing logic; update test

Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>

move pass

Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>

call sanitizer in precision converter

Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>

address review comments

Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
@galagam galagam force-pushed the fix_auto_cast_io_special_nodes_bug branch from fac6a17 to 315ef04 Compare October 16, 2025 08:27
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

77-79: Replace mutable default argument with None.

The default value trt_plugins=[] creates a mutable default argument, which can lead to shared state between instances if the list is mutated.

🧹 Nitpick comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)

1154-1175: LGTM! Test exercises the new parameters and edge case.

The test correctly uses the new min_opset, max_ir_version, and trt_plugins parameters, and exercises the edge case of a Cast node connecting an input directly to an output. The test demonstrates that the model validates successfully after the fix.

Optional enhancement: Consider adding more specific assertions to verify the expected behavior of the sanitization step. For example, you could check that an Identity node was inserted or that the graph structure matches expectations. This would make the test more robust and self-documenting.

Example assertions you could add:

# Verify the graph structure after conversion
identity_nodes = [n for n in converted_model.graph.node if n.op_type == "Identity"]
assert len(identity_nodes) > 0, "Expected Identity node(s) to be inserted by sanitization"

# Verify output names are preserved
assert converted_model.graph.output[0].name == "Y1"
assert converted_model.graph.output[1].name == "Y2"
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fac6a17 and 315ef04.

📒 Files selected for processing (3)
  • modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
  • modelopt/onnx/autocast/precisionconverter.py (5 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/onnx/autocast/graphsanitizer.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/autocast/precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
  • GraphSanitizer (28-557)
  • sanitize (53-69)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
  • check_model (557-569)
modelopt/onnx/autocast/precisionconverter.py (1)
  • convert (120-214)
🔇 Additional comments (3)
modelopt/onnx/autocast/precisionconverter.py (1)

142-143: LGTM! GraphSanitizer integration is clean.

The sanitization step is correctly placed after initial validation and before low-precision processing. The _sanitize_model implementation properly instantiates GraphSanitizer with the new parameters and updates the model.

Also applies to: 1063-1071

tests/unit/onnx/autocast/test_precisionconverter.py (2)

1108-1151: LGTM! Fixture correctly reproduces the edge case.

The fixture creates a model with a Cast node that directly connects input X to output Y1, which is the exact edge case scenario described in the PR objectives. The model structure is valid and appropriate for testing the sanitization behavior.


28-28: No update needed: IR version constant is accurate. IR version 10 is confirmed as the latest supported by ONNX Runtime per the compatibility table and release notes.

@galagam
Copy link
Contributor

galagam commented Oct 16, 2025

/ok to test 315ef04

@galagam galagam merged commit ae78b9f into NVIDIA:main Oct 20, 2025
30 of 32 checks passed