[Autocast] Fix edge case casting input directly to output #305
Conversation
Walkthrough

Adds GraphSanitizer logic to isolate Cast nodes that map graph inputs directly to outputs, and integrates that sanitization into PrecisionConverter (new constructor params stored and used). Adds tests and a fixture for cast-to-output models. Duplicate method and duplicate test/fixture definitions are present.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Test as Unit Test
    participant PC as PrecisionConverter
    participant GS as GraphSanitizer
    participant G as ONNX Graph
    Test->>PC: construct PrecisionConverter(..., min_opset, max_ir_version, trt_plugins)
    PC->>PC: store params (min_opset, max_ir_version, trt_plugins)
    PC->>PC: _sanitize_model()
    PC->>GS: new GraphSanitizer(model, min_opset, max_ir_version, trt_plugins)
    GS->>G: inspect graph nodes/topology
    alt Cast node maps input -> output
        GS->>G: create intermediate output name
        GS->>G: rewire Cast to emit intermediate
        GS->>G: insert Identity consuming intermediate -> original output
        GS->>G: preserve node order
        GS-->>PC: return sanitized model
    else
        GS-->>PC: return model unchanged
    end
    PC->>PC: continue precision-conversion workflow
    PC-->>Test: return converted model
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: ❌ 1 warning, ✅ 2 passed
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)
**624-634: Preexisting-cast removal misses BF16 targets**

The comment says FP16/BF16/FP32 casts are removed, but `is_fp_cast` only matches `to ∈ {FLOAT16, FLOAT}`. Include BF16 as a target to honor the contract.

```diff
-        is_fp_cast = cast_to_type in [
-            onnx.TensorProto.FLOAT16,
-            onnx.TensorProto.FLOAT,
-        ] and cast_from_type in [
+        is_fp_cast = cast_to_type in [
+            onnx.TensorProto.FLOAT16,
+            onnx.TensorProto.FLOAT,
+            onnx.TensorProto.BFLOAT16,
+        ] and cast_from_type in [
             onnx.TensorProto.FLOAT16,
             onnx.TensorProto.FLOAT,
             onnx.TensorProto.BFLOAT16,
         ]
```
**641-644: Guard for output-producing casts is ineffective**

This condition checks whether BOTH the cast input and output are network outputs, which never happens. It should keep casts that produce a network output.

```diff
-        # Keep cast nodes that are necessary producers of network outputs
-        if any(node.input[0] == out.name for out in self.model.graph.output) and any(
-            node.output[0] == out.name for out in self.model.graph.output
-        ):
+        # Keep casts that produce a network output
+        if node.output[0] in model_output_names:
             continue
```
🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (1)
**618-621: Insert duplicate IO-bridge casts deterministically (top of graph)**

Appending at the tail can shuffle topological order. Inserting at index 0 is more stable for input-driven casts.

```diff
-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
+        for cast in casts_to_add:
+            self.model.graph.node.insert(0, cast)
```

tests/unit/onnx/autocast/test_precisionconverter.py (2)
**1068-1071: Don't write artifacts to /tmp in unit tests**

The saved model isn't used. Remove it to keep tests hermetic.

```diff
-    onnx.save(model, "/tmp/model_with_casted_output.onnx")
```
**1076-1091: Strengthen assertions: verify Y1 connectivity and dtype**

Only checking `onnx.checker` is weak. Assert that Y1 remains produced by a Cast and retains FP32 (keep_io_types=True).

```diff
     converted_model = converter.convert(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
-    onnx.checker.check_model(converted_model)
+    onnx.checker.check_model(converted_model)
+    # Y1 should remain connected and produced by a Cast
+    y1_producers = utils.get_producer_nodes(converted_model, "Y1")
+    assert len(y1_producers) == 1
+    assert y1_producers[0].op_type == "Cast"
+    # keep_io_types=True -> FP32 I/O preserved
+    y1_vi = next(y for y in converted_model.graph.output if y.name == "Y1")
+    assert y1_vi.type.tensor_type.elem_type == TensorProto.FLOAT
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/onnx/autocast/precisionconverter.py (1 hunks)
- tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
- check_model (557-569)

modelopt/onnx/autocast/precisionconverter.py (1)
- convert (113-202)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
**607-617: Confirm intent: skipping only the duplicated "new_cast", not the renamed original**

`casts_to_skip` holds the original name (now assigned to `new_cast`) and will not skip the renamed original (`..._io_special_case`). If that original gets removed, `_bypass_cast_node` will reconnect its consumers directly to the model input. Is that intended? If not, add the renamed name to the skip list as well right after renaming.

```diff
             casts_to_skip.append(node.name)
             casts_to_add.append(new_cast)
             # Now adjust the old cast's name, consumers and producers
             node.name = f"{node.name}_io_special_case"
+            casts_to_skip.append(node.name)  # keep the internal IO-special-case cast as well
```
Actionable comments posted: 0
🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (3)
**600-606: Also preserve optional Cast attributes (e.g., saturate) to avoid semantic drift**

If the model uses newer opsets where Cast may carry optional attributes (like `saturate`), the duplicate should copy them.

Apply:

```diff
-            new_cast = helper.make_node(
-                "Cast",
-                name=node.name,
-                inputs=[node.input[0]],
-                outputs=[node.output[0]],
-                to=utils.get_cast_to_type(node),
-            )
+            # Copy optional attributes (e.g., 'saturate' in newer opsets)
+            saturate = next((a.i for a in node.attribute if a.name == "saturate"), None)
+            cast_attrs = {"to": utils.get_cast_to_type(node)}
+            if saturate is not None:
+                cast_attrs["saturate"] = saturate
+            new_cast = helper.make_node(
+                "Cast",
+                name=node.name,
+                inputs=[node.input[0]],
+                outputs=[node.output[0]],
+                **cast_attrs,
+            )
```
**618-621: Insert duplicate Cast adjacent to the original for better locality and readability**

Appending at the end works but scatters IO nodes. Insert near the renamed source node to keep the topology readable.

```diff
-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
+        # Preserve locality: insert duplicates next to their originals
+        for cast in casts_to_add:
+            target_idx = -1
+            for i, n in enumerate(self.model.graph.node):
+                if n.name == f"{cast.name}_io_special_case":
+                    target_idx = i
+                    break
+            if target_idx >= 0:
+                self.model.graph.node.insert(target_idx, cast)
+            else:
+                # Fallback to prepend to avoid end-append reordering
+                self.model.graph.node.insert(0, cast)
```
**592-596: Use a set for casts_to_skip from the start**

Minor nit for clarity and O(1) membership checks.

```diff
-        casts_to_skip = []
+        casts_to_skip: set[str] = set()
         # Add casts as a separate step to avoid modifying the graph while iterating over it
         casts_to_add = []
@@
-            casts_to_skip.append(node.name)
+            casts_to_skip.add(node.name)
@@
-        casts_to_skip = set(casts_to_skip)
+        # already a set
```

Also applies to: 620-621
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- modelopt/onnx/autocast/precisionconverter.py (1 hunks)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
588-617: Solid IO-cast preservation strategyDuplicating the IO-facing Cast, renaming the original, and rewiring consumers avoids disconnecting outputs while still enabling generic cast cleanup. This addresses the edge case cleanly.
```python
        value_info_map,
        initializer_map,
        node_to_init_map,
        keep_io_types=True,
```
This test fails when keep_io_types=False due to a graph input becoming a graph output directly, which violates the assertion in ModelOpt that all original input and output names should be maintained in the quantized model.
galagam left a comment:
This is indeed an edge case that challenges AutoCast's assumptions.
I'd like to propose an alternative approach: In GraphSanitizer, if an input is cast directly to output, inject an identity node.
@aboubezari @gcunhase
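The proposed pass can be sketched with lightweight stand-in structures (plain dicts instead of real ONNX protobufs; function and field names here are hypothetical, not the actual GraphSanitizer API):

```python
# Sketch of the proposed sanitization: when a Cast maps a graph input directly
# to a graph output, rewire the Cast to a fresh intermediate tensor and append
# an Identity node that restores the original output name.

def isolate_io_casts(nodes, graph_inputs, graph_outputs):
    """Return a new node list where input->output Casts feed an Identity."""
    new_nodes = []
    for node in nodes:
        is_io_cast = (
            node["op_type"] == "Cast"
            and node["inputs"][0] in graph_inputs
            and node["outputs"][0] in graph_outputs
        )
        if is_io_cast:
            original_output = node["outputs"][0]
            intermediate = f"{original_output}__io_cast_src"  # unique per output
            node["outputs"][0] = intermediate  # Cast now emits the intermediate
            new_nodes.append(node)
            # Identity consumes the intermediate and produces the original output
            new_nodes.append({
                "op_type": "Identity",
                "name": f"{node['name']}_identity",
                "inputs": [intermediate],
                "outputs": [original_output],
            })
        else:
            new_nodes.append(node)
    return new_nodes


nodes = [{"op_type": "Cast", "name": "cast_input", "inputs": ["X"], "outputs": ["Y1"]}]
sanitized = isolate_io_casts(nodes, graph_inputs={"X"}, graph_outputs={"Y1"})
print([n["op_type"] for n in sanitized])  # ['Cast', 'Identity']
```

After this rewrite, later cast-cleanup passes can freely remove or bypass the Cast without ever orphaning the graph-output name, since the Identity owns it.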
@galagam agree! @aboubezari, please make the suggested modification as a next step for this MR. Thanks!
Will give this a try. Thanks.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
**1078-1097: Recompute mappings after sanitize() and assert outputs preserved**

sanitize() mutates the graph (adds Identity and rewires). Recompute mappings to avoid stale references, and assert that both outputs remain present.

Apply this diff:

```diff
 def test_casted_input_to_output_model(
     model_with_casted_input_to_output, low_precision_type, keep_io_types
 ):
     model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output
     min_opset = 22 if low_precision_type == "bf16" else 13
     graph_sanitizer = GraphSanitizer(model, min_opset)
     graph_sanitizer.sanitize()
-    model = graph_sanitizer.model
+    model = graph_sanitizer.model
+    # Recompute mappings after graph mutation by sanitizer
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
     converter = PrecisionConverter(
         model,
         value_info_map,
         initializer_map,
         node_to_init_map,
         keep_io_types=keep_io_types,
         low_precision_type=low_precision_type,
     )
     converted_model = converter.convert(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
     onnx.checker.check_model(converted_model)
+    # Ensure both original outputs are preserved
+    assert {o.name for o in converted_model.graph.output} == {"Y1", "Y2"}
```

Optionally, to validate type expectations on the direct-IO-cast output:
- If keep_io_types is True: Y1/Y2 should be FLOAT.
- If False: Y1/Y2 should be low_precision_onnx_type(low_precision_type).
**1028-1067: Optional: Expand fixture to cover multi-Cast-from-same-input edge case**

To guard against future regressions (multiple Casts from the same input to different outputs), add a companion fixture/test. Example snippet to append to this file:

```python
@pytest.fixture
def model_with_two_casted_outputs():
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])
    cast1 = helper.make_node("Cast", ["X"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
    cast2 = helper.make_node("Cast", ["X"], ["Y2"], name="cast2", to=TensorProto.FLOAT)
    graph = helper.make_graph([cast1, cast2], "model_two_casted_outputs", [x], [y1, y2], [])
    model = helper.make_model(graph, producer_name="model_two_casted_outputs")
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)
    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    return model, value_info_map, initializer_map, node_to_init_map


@pytest.mark.parametrize("keep_io_types", [True, False])
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
def test_two_casts_from_same_input(model_with_two_casted_outputs, keep_io_types, low_precision_type):
    model, value_info_map, initializer_map, node_to_init_map = model_with_two_casted_outputs
    gsani = GraphSanitizer(model, 22 if low_precision_type == "bf16" else 13)
    gsani.sanitize()
    model = gsani.model
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    conv = PrecisionConverter(
        model, value_info_map, initializer_map, node_to_init_map,
        keep_io_types=keep_io_types, low_precision_type=low_precision_type
    )
    converted = conv.convert(high_precision_nodes=["cast1", "cast2"], low_precision_nodes=[])
    onnx.checker.check_model(converted)
    assert {o.name for o in converted.graph.output} == {"Y1", "Y2"}
```

I can send a PR update with these additions if helpful.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
- tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
- GraphSanitizer (28-450)
- sanitize (53-68)
🔇 Additional comments (2)
modelopt/onnx/autocast/graphsanitizer.py (1)
68-69: Good placement in sanitize()Running sanitize_io_casts() after cleanup and IR versioning keeps the injected Identity from being swept early. LGTM.
tests/unit/onnx/autocast/test_precisionconverter.py (1)
23-24: Import addition looks goodGraphSanitizer is used below; import is correct and scoped for tests.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
**1077-1098: Consider adding explicit assertions for the edge case.**

The test validates that the model passes `check_model` after conversion, which confirms basic correctness. However, it could be strengthened by explicitly asserting that:
- The Identity node was inserted by the sanitizer
- The original input and output names are preserved
- The Cast node is still present
This would make the test's intent clearer and provide better regression coverage for the specific edge case being addressed.
Example assertions to add before line 1098:

```python
    # Verify Identity node was inserted for the IO cast
    identity_nodes = [n for n in converted_model.graph.node if n.op_type == "Identity"]
    assert len(identity_nodes) >= 1, "Expected at least one Identity node for IO cast isolation"

    # Verify original I/O names are preserved
    assert converted_model.graph.input[0].name == "X"
    assert any(o.name == "Y1" for o in converted_model.graph.output)
    assert any(o.name == "Y2" for o in converted_model.graph.output)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/onnx/autocast/precisionconverter.py (5 hunks)
- tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
check_model(557-569)modelopt/onnx/autocast/precisionconverter.py (1)
convert(120-211)
modelopt/onnx/autocast/precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
GraphSanitizer(28-460)sanitize(53-68)
🔇 Additional comments (6)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
**33-33: LGTM! Constant definition is clear.**

The constant `LATEST_IR_VERSION_SUPPORTED_BY_ORT` is well-named and appropriately scoped for test usage.
**1031-1074: LGTM! Fixture correctly models the edge case.**

The fixture constructs a graph where a Cast node connects input X directly to output Y1, alongside a separate computation path. This accurately captures the edge case described in the PR objectives.
modelopt/onnx/autocast/precisionconverter.py (4)
**35-35: LGTM! Import is correctly placed.**

The GraphSanitizer import is appropriately added to support the new sanitization step.
**77-79: LGTM! New parameters are well-integrated.**

The three new parameters (`min_opset`, `max_ir_version`, `trt_plugins`) are properly added to the constructor with sensible defaults and stored as instance variables for use in the sanitization step.

Also applies to: 116-118
**142-142: LGTM! Sanitization is correctly positioned in the workflow.**

The `_sanitize_model()` call is appropriately placed early in the `convert()` method, after model validation but before unsupported-op filtering. This ensures the graph is normalized (e.g., IO casts are isolated with Identity nodes) before precision conversion logic executes.
**1043-1051: LGTM! Sanitization method is clean and correct.**

The `_sanitize_model()` method correctly instantiates `GraphSanitizer` with the model and relevant parameters, invokes sanitization, and updates the model. This cleanly integrates the graph normalization step (including IO cast isolation with Identity nodes) into the precision conversion workflow.
Force-pushed: 7dd644a to 4028e2a (Compare)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
**77-92: Document the new parameters in the docstring.**

The three new parameters (`min_opset`, `max_ir_version`, `trt_plugins`) are not documented in the docstring.

Apply this diff:

```diff
         Args:
             model: ONNX model to convert.
             value_info_map: Map of tensor names to value info.
             initializer_map: Map of tensor names to initializers.
             node_to_init_map: Map of node names to lists of initializer names.
             keep_io_types: Keep the input and output types of the model, otherwise they will be converted.
             low_precision_type: Precision to convert to.
             init_conversion_max_bytes: Maximum size in bytes for initializer conversion.
                 Larger initializers will be cast at runtime.
             custom_ops: List of custom ops.
+            min_opset: Minimum opset version to use for sanitization (default: 13).
+            max_ir_version: Maximum IR version supported by ORT (default: None).
+            trt_plugins: List of TensorRT plugin library paths in .so format (default: []).
```
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
**1063-1071: Consider matching parameter order to the GraphSanitizer constructor signature.**

While keyword arguments make the order irrelevant, matching the constructor's parameter order improves readability.

Apply this diff:

```diff
     def _sanitize_model(self):
         graph_sanitizer = GraphSanitizer(
             self.model,
             self.min_opset,
-            trt_plugins=self.trt_plugins,
             max_ir_version=self.max_ir_version,
+            trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
         )
         graph_sanitizer.sanitize()
         self.model = graph_sanitizer.model
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
- modelopt/onnx/autocast/precisionconverter.py (5 hunks)
- tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/unit/onnx/autocast/test_precisionconverter.py
- modelopt/onnx/autocast/graphsanitizer.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/autocast/precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
- GraphSanitizer (28-560)
- sanitize (53-69)
🔇 Additional comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)
**35-35: LGTM!**

The import is correctly placed and necessary for the sanitization functionality.
**142-143: LGTM!**

The sanitization is correctly invoked early in the conversion workflow, after validation but before precision conversion operations.
```python
        min_opset: int = 13,
        max_ir_version: int | None = None,
        trt_plugins: list[str] | None = [],
```
Replace mutable default argument with None.
The default value trt_plugins=[] creates a mutable default argument, which can lead to shared state between instances if the list is mutated.
Apply this diff:
```diff
-        trt_plugins: list[str] | None = [],
+        trt_plugins: list[str] | None = None,
```

Then update line 1067 to handle the None case:
```diff
 graph_sanitizer = GraphSanitizer(
     self.model,
     self.min_opset,
-    trt_plugins=self.trt_plugins,
+    trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
     max_ir_version=self.max_ir_version,
 )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    def __init__(
        ...,
        min_opset: int = 13,
        max_ir_version: int | None = None,
        trt_plugins: list[str] | None = None,
    ):
        ...
```
```python
        graph_sanitizer = GraphSanitizer(
            self.model,
            self.min_opset,
            trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
            max_ir_version=self.max_ir_version,
        )
```
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 77 to 79, change
the function signature to use trt_plugins: list[str] | None = None instead of a
mutable default list, and then at line 1067 update the code to treat a None
value as an empty list (e.g., set local_trt_plugins = trt_plugins or [] before
using it) so any subsequent iterations or mutations operate on a fresh list
rather than a shared default.
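The shared-state pitfall behind this suggestion is easy to reproduce in isolation; a minimal sketch (hypothetical class names, not the real PrecisionConverter):

```python
# A mutable default argument is created once, at function-definition time, and
# is then shared by every call that relies on the default. Any mutation leaks
# into otherwise independent instances.

class Converter:
    def __init__(self, trt_plugins=[]):  # BAD: one shared default list
        self.trt_plugins = trt_plugins


class SafeConverter:
    def __init__(self, trt_plugins=None):  # GOOD: fresh list per instance
        self.trt_plugins = trt_plugins if trt_plugins is not None else []


a = Converter()
a.trt_plugins.append("libfoo.so")  # mutates the shared default
b = Converter()
print(b.trt_plugins)  # ['libfoo.so'] -- b sees a's mutation

c = SafeConverter()
c.trt_plugins.append("libfoo.so")
d = SafeConverter()
print(d.trt_plugins)  # [] -- instances stay isolated
```

This is why the review recommends `trt_plugins: list[str] | None = None` with a `self.trt_plugins or []` style fallback at the call site.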
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #305   +/-   ##
=======================================
+ Coverage   73.37%   73.41%   +0.04%
=======================================
  Files         180      180
  Lines       17934    17957      +23
=======================================
+ Hits        13159    13184      +25
+ Misses       4775     4773       -2
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
**28-28: Consider documenting or relocating the constant.**

The constant `LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10` is defined at module level but only used in one test. Consider either:
- Adding a comment explaining why version 10 is the latest supported by ORT, or
- Moving it to a shared constants file if other tests might need it
**1154-1175: Consider adding assertions to verify the sanitization behavior.**

The test validates that the model remains valid after conversion but doesn't verify that the sanitization actually occurred as expected. Consider adding assertions to check:
- That an Identity node was inserted between the Cast and the output
- That the Cast node's output was rewired to the intermediate name
- That the output remains correctly connected
This would make the test more robust and easier to debug if the sanitization logic changes in the future.
Example assertions you could add:
```python
    # Verify sanitization occurred
    identity_nodes = [n for n in converted_model.graph.node if n.op_type == "Identity"]
    assert len(identity_nodes) >= 1, "Expected at least one Identity node after sanitization"

    # Verify output Y1 is still correctly connected
    y1_output = next(o for o in converted_model.graph.output if o.name == "Y1")
    assert y1_output.name == "Y1", "Output Y1 should maintain its original name"
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
- tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
modelopt/onnx/autocast/precisionconverter.py (1)
- convert (120-214)
🔇 Additional comments (1)
modelopt/onnx/autocast/graphsanitizer.py (1)
**347-379: LGTM! The sanitization logic correctly handles the edge case.**

The implementation properly addresses the issue where Cast nodes connecting inputs directly to outputs can cause disconnection problems. The approach of inserting Identity nodes with unique intermediate names is sound, and the offset-based insertion logic correctly maintains graph ordering.
Key strengths:
- Unique intermediate naming per output (`f"{cast_output_name}__io_cast_src"`) prevents collisions when multiple outputs are cast from the same input
- Defensive checks on lines 359-360 prevent crashes if nodes lack expected inputs/outputs
- Insertion list with offset tracking (lines 378-379) correctly handles index shifts as nodes are inserted
The concerns raised in previous reviews about name collisions and insertion ordering have been properly addressed.
Force-pushed: ac76dc7 to d61c4ad (Compare)
@aboubezari The PR is approved. Can you please rebase and sign your commits? Once commits are signed and rebased, the PR should be able to auto-merge.
Head branch was pushed to by a user without write access
Force-pushed: d61c4ad to 76d4505 (Compare)
@galagam I can confirm my commits are verified and signed off. I have just fetched upstream main and rebased.
Force-pushed: f125429 to fac6a17 (Compare)
Commits:

- Update modelopt/onnx/autocast/precisionconverter.py (Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>; Signed-off-by: aboubezari <126983138+aboubezari@users.noreply.github.com>)
- cleanup (Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>)
- Inject identity nodes in sanitizer; revert existing logic; update test (Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>)
- move pass (Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>)
- call sanitizer in precision converter (Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>)
- address review comments (Signed-off-by: Ali Boubezari <aboubezari@nuro.ai>; Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>)
Force-pushed: fac6a17 to 315ef04 (Compare)
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
**77-79: Replace mutable default argument with None.**

The default value `trt_plugins=[]` creates a mutable default argument, which can lead to shared state between instances if the list is mutated.
🧹 Nitpick comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
**1154-1175: LGTM! Test exercises the new parameters and edge case.**

The test correctly uses the new `min_opset`, `max_ir_version`, and `trt_plugins` parameters, and exercises the edge case of a Cast node connecting an input directly to an output. The test demonstrates that the model validates successfully after the fix.

Optional enhancement: Consider adding more specific assertions to verify the expected behavior of the sanitization step. For example, you could check that an Identity node was inserted or that the graph structure matches expectations. This would make the test more robust and self-documenting.
Example assertions you could add:
```python
    # Verify the graph structure after conversion
    identity_nodes = [n for n in converted_model.graph.node if n.op_type == "Identity"]
    assert len(identity_nodes) > 0, "Expected Identity node(s) to be inserted by sanitization"

    # Verify output names are preserved
    assert converted_model.graph.output[0].name == "Y1"
    assert converted_model.graph.output[1].name == "Y2"
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/onnx/autocast/graphsanitizer.py (2 hunks)
- modelopt/onnx/autocast/precisionconverter.py (5 hunks)
- tests/unit/onnx/autocast/test_precisionconverter.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/onnx/autocast/graphsanitizer.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/autocast/precisionconverter.py (1)
modelopt/onnx/autocast/graphsanitizer.py (2)
- GraphSanitizer (28-557)
- sanitize (53-69)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
- check_model (557-569)

modelopt/onnx/autocast/precisionconverter.py (1)
- convert (120-214)
🔇 Additional comments (3)
modelopt/onnx/autocast/precisionconverter.py (1)
**142-143: LGTM! GraphSanitizer integration is clean.**

The sanitization step is correctly placed after initial validation and before low-precision processing. The `_sanitize_model` implementation properly instantiates GraphSanitizer with the new parameters and updates the model.

Also applies to: 1063-1071
tests/unit/onnx/autocast/test_precisionconverter.py (2)
**1108-1151: LGTM! Fixture correctly reproduces the edge case.**

The fixture creates a model with a Cast node that directly connects input `X` to output `Y1`, which is the exact edge case scenario described in the PR objectives. The model structure is valid and appropriate for testing the sanitization behavior.
28-28: No update needed: IR version constant is accurate. IR version 10 is confirmed as the latest supported by ONNX Runtime per the compatibility table and release notes.
/ok to test 315ef04
What does this PR do?
Type of change: Bug fix
Overview: If there is a Cast node connecting an input directly to an output, the output becomes completely disconnected due to naming issues. This fix creates specialized cast nodes for such edge cases and avoids removing them in the initial pass.
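The failure mode can be illustrated without ONNX, using plain-dict stand-ins for nodes (hypothetical helper, not the project's actual code): a generic cast-removal pass that bypasses a Cast by reconnecting its consumers to the Cast's input leaves a graph output without a producer when the Cast fed that output directly, because graph outputs are matched by tensor name.

```python
# Demonstrates why bypassing an input->output Cast disconnects the output:
# after the bypass, no remaining node produces the output name "Y1".

def bypass_cast(nodes, cast):
    """Remove `cast` and reconnect its consumers to the cast's input tensor."""
    removed_out, replacement = cast["outputs"][0], cast["inputs"][0]
    remaining = [n for n in nodes if n is not cast]
    for n in remaining:
        n["inputs"] = [replacement if t == removed_out else t for t in n["inputs"]]
    return remaining


nodes = [
    {"op_type": "Cast", "name": "cast_input", "inputs": ["X"], "outputs": ["Y1"]},
    {"op_type": "Add", "name": "add1", "inputs": ["X", "W"], "outputs": ["Y2"]},
]
graph_outputs = ["Y1", "Y2"]

nodes = bypass_cast(nodes, nodes[0])
produced = {t for n in nodes for t in n["outputs"]}
disconnected = [o for o in graph_outputs if o not in produced]
print(disconnected)  # ['Y1'] -- output Y1 no longer has a producer
```

Keeping such Casts (or, per the merged approach, routing them through an injected Identity) ensures every graph-output name retains a producer after cleanup.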
Usage
Autocast precision converter
Testing
Added a unittest that fails before my change, and passes after my fix.
Before your PR is "Ready for review"
Summary by CodeRabbit:

- Bug Fixes
- New Features
- Tests
- Chores