@@ -30,6 +30,9 @@ def low_precision_onnx_type(low_precision_type_str):
     return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16
 
 
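+# IR version ceiling handed to PrecisionConverter as max_ir_version in the tests
+# below; assumed here to be the newest IR version ONNX Runtime can load.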
+LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10
+
+
 ####################################################################################################
 # Testing with a basic GEMM->Add->Relu graph
 ####################################################################################################
@@ -1023,3 +1026,73 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
     assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
     assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
     assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"
+
+
+@pytest.fixture
+def model_with_casted_input_to_output():
+    """Create a model with a graph output produced directly by a Cast of the graph input."""
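+    # Graph built below (X is both cast straight to output Y1 and fed through two Adds to Y2):
+    #   X -> Cast -> Y1
+    #   X -> Add(const) -> Add(const) -> Y2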
+    # Create the input and the two graph outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
+    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Output of the Cast
+    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Output of the Add chain
+
+    # Create constant value
+    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
+
+    # Create constant node
+    const_node = helper.make_node(
+        "Constant",
+        [],
+        ["const"],
+        name="const",
+        value=numpy_helper.from_array(const, name="const_value"),
+    )
+
+    # Create computation nodes
+    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
+    add2 = helper.make_node("Add", ["add1_out", "const"], ["Y2"], name="add2")
+
+    # Create cast node that feeds the graph input directly to a graph output
+    cast_input = helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT)
+
+    graph = helper.make_graph(
+        [const_node, add1, add2, cast_input],
+        "model_with_casted_output",
+        [x],
+        [y1, y2],
+        [],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_casted_output")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+@pytest.mark.parametrize("keep_io_types", [True, False])
+def test_casted_input_to_output_model(
+    model_with_casted_input_to_output, low_precision_type, keep_io_types
+):
+    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=keep_io_types,
+        low_precision_type=low_precision_type,
+        min_opset=22 if low_precision_type == "bf16" else 13,
+        max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
+        trt_plugins=[],
+    )
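+    # Keep the input-to-output Cast in high precision and convert the Add chain
+    # to the requested low-precision type.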
+    converted_model = converter.convert(
+        high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
+    )
+    onnx.checker.check_model(converted_model)