Commit 16d5875 (1 parent: cf6f1d4)

[Autocast] Fix edge case casting input directly to output

File tree: 2 files changed, +99 -1 lines

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 34 additions & 1 deletion
@@ -586,9 +586,42 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
                 consumer.input[i] = input_tensor
 
     def _remove_preexisting_casts(self) -> None:
-        nodes_to_remove = []
+        # First check for special case where an input is casted directly to an output
+        model_input_names = {input.name for input in self.model.graph.input}
+        model_output_names = {output.name for output in self.model.graph.output}
+        # Ensure that special casts that we add are not removed by the following logic
+        casts_to_skip = []
+        # Add casts as a separate step to avoid modifying the graph while iterating over it
+        casts_to_add = []
         for node in self.model.graph.node:
             if node.op_type == "Cast":
+                if node.input[0] in model_input_names and node.output[0] in model_output_names:
+                    # Create a special cast just for the input-output case.
+                    new_cast = helper.make_node(
+                        "Cast",
+                        name=node.name,
+                        inputs=[node.input[0]],
+                        outputs=[node.output[0]],
+                        to=node.attribute[0].i,
+                    )
+                    casts_to_skip.append(node.name)
+                    casts_to_add.append(new_cast)
+                    # Now adjust the old cast's name, consumers and producers
+                    node.name = f"{node.name}_io_special_case"
+                    node_new_output_name = f"{node.output[0]}_io_special_case"
+                    for consumer in utils.get_consumer_nodes(self.model, node.output[0]):
+                        for i, input_name in enumerate(consumer.input):
+                            if input_name == node.output[0]:
+                                consumer.input[i] = node_new_output_name
+                    node.output[0] = node_new_output_name
+
+        for cast in casts_to_add:
+            self.model.graph.node.append(cast)
+        casts_to_skip = set(casts_to_skip)
+
+        nodes_to_remove = []
+        for node in self.model.graph.node:
+            if node.op_type == "Cast" and node.name not in casts_to_skip:
                 cast_from_type = self._get_tensor_type(node.input[0])
                 cast_to_type = utils.get_cast_to_type(node)
                 is_fp_cast = cast_to_type in [
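
For context: the new branch detects a Cast whose input is a graph input and whose output is a graph output, appends a copy of that Cast under the original name (kept out of the removal loop via casts_to_skip) so the graph output still has a producer, and renames the original cast and its output tensor so any downstream consumers read the *_io_special_case tensor instead. Below is a minimal sketch of the triggering graph pattern, built with the same onnx.helper API the new test fixture uses; tensor and node names here are illustrative, not from the commit, and without the special handling, bypassing or removing such a cast would presumably leave the graph output with no producing node.

import onnx
from onnx import TensorProto, helper

# Graph input "X" is cast straight to graph output "Y" -- the edge case that
# _remove_preexisting_casts now detects via model_input_names / model_output_names.
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
cast = helper.make_node("Cast", ["X"], ["Y"], name="cast_io", to=TensorProto.FLOAT)
graph = helper.make_graph([cast], "io_cast_example", [x], [y])
model = helper.make_model(graph)
onnx.checker.check_model(model)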

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 65 additions & 0 deletions
@@ -1023,3 +1023,68 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
     assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
     assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
     assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"
+
+
+@pytest.fixture
+def model_with_casted_output():
+    """Create a model with an output produced by a Cast node."""
+    # Create input and outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
+    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Intermediate output
+    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Final output
+
+    # Create constant value
+    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
+
+    # Create constant node
+    const_node = helper.make_node(
+        "Constant",
+        [],
+        ["const"],
+        name="const",
+        value=numpy_helper.from_array(const, name="const_value"),
+    )
+
+    # Create computation nodes
+    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
+    add2 = helper.make_node("Add", ["add1_out", "const"], ["Y2"], name="add2")
+
+    # Create cast node that feeds directly from input to output
+    cast_input = helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT)
+
+    graph = helper.make_graph(
+        [const_node, add1, add2, cast_input],
+        "model_with_casted_output",
+        [x],
+        [y1, y2],
+        [],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_casted_output")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+    onnx.save(model, "/tmp/model_with_casted_output.onnx")
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_casted_output_model(model_with_casted_output, low_precision_type):
+    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_output
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=True,
+        low_precision_type=low_precision_type,
+    )
+    converted_model = converter.convert(
+        high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
+    )
+    onnx.checker.check_model(converted_model)
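
A note on what the new test covers: cast_input takes graph input X directly to graph output Y1, is pinned to high_precision_nodes, and keep_io_types=True is set (which, as the name suggests, keeps the original float32 I/O dtypes), so the converter has to keep a valid producer for Y1 while converting add1/add2 to the low-precision type; onnx.checker.check_model(converted_model) is the pass/fail criterion.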
