@@ -1023,3 +1023,68 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
10231023 assert utils .get_consumer_nodes (converted_model , "const_scalar" )[0 ].op_type == "Add"
10241024 assert len (utils .get_consumer_nodes (converted_model , "const_array" )) == 1
10251025 assert utils .get_consumer_nodes (converted_model , "const_array" )[0 ].op_type == "Add"
1026+
1027+
@pytest.fixture
def model_with_casted_output():
    """Create a model in which one graph output is produced directly by a Cast node.

    The graph has a single FLOAT input ``X`` of shape [2, 3] and two FLOAT
    outputs of the same shape:

    * ``Y1`` - ``X`` passed through the ``cast_input`` Cast node (to FLOAT).
    * ``Y2`` - ``(X + const) + const``, computed by the ``add1`` and ``add2`` nodes.

    Returns:
        A tuple ``(model, value_info_map, initializer_map, node_to_init_map)``:
        the shape-inferred model followed by the three mappings returned by
        ``utils.setup_mappings`` for it.
    """
    # Graph input and the two graph outputs.
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Output of the Cast node
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Output of the Add chain

    # Constant operand shared by both Add nodes.
    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
    const_node = helper.make_node(
        "Constant",
        [],
        ["const"],
        name="const",
        value=numpy_helper.from_array(const, name="const_value"),
    )

    # Computation nodes: X -> add1 -> add2 -> Y2.
    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
    add2 = helper.make_node("Add", ["add1_out", "const"], ["Y2"], name="add2")

    # Cast node that feeds the graph input straight into the graph output Y1.
    cast_input = helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT)

    graph = helper.make_graph(
        [const_node, add1, add2, cast_input],
        "model_with_casted_output",
        [x],
        [y1, y2],
        [],
    )

    model = helper.make_model(graph, producer_name="model_with_casted_output")
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)

    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    # NOTE: the model is intentionally not written to disk; the previous debug dump to a
    # hard-coded "/tmp" path was never read back and was not portable across platforms.

    return model, value_info_map, initializer_map, node_to_init_map
1073+
1074+
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
def test_casted_output_model(model_with_casted_output, low_precision_type):
    """Convert the casted-output model and check that the result is a valid ONNX model.

    The Cast node feeding ``Y1`` is requested in high precision and both Add
    nodes in low precision, with ``keep_io_types`` enabled.
    """
    # The fixture yields (model, value_info_map, initializer_map, node_to_init_map),
    # which are exactly the converter's leading positional arguments.
    precision_converter = PrecisionConverter(
        *model_with_casted_output,
        keep_io_types=True,
        low_precision_type=low_precision_type,
    )

    keep_high = ["cast_input"]
    make_low = ["add1", "add2"]
    result_model = precision_converter.convert(
        high_precision_nodes=keep_high,
        low_precision_nodes=make_low,
    )

    onnx.checker.check_model(result_model)
0 commit comments