@@ -1023,3 +1023,81 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
10231023 assert utils .get_consumer_nodes (converted_model , "const_scalar" )[0 ].op_type == "Add"
10241024 assert len (utils .get_consumer_nodes (converted_model , "const_array" )) == 1
10251025 assert utils .get_consumer_nodes (converted_model , "const_array" )[0 ].op_type == "Add"
1026+
1027+
1028+ @pytest .fixture
1029+ def model_with_multiple_output_node_casted_to_output ():
1030+ """Create a model with a Cast node connecting a consumer with multiple outputs to a graph output."""
1031+ # Create inputs and outputs
1032+ x1 = helper .make_tensor_value_info ("X1" , TensorProto .FLOAT , [1 , 2 , 16 , 16 ])
1033+ x2 = helper .make_tensor_value_info ("X2" , TensorProto .FLOAT , [1 , 3 , 16 , 16 ])
1034+ x3 = helper .make_tensor_value_info ("X3" , TensorProto .FLOAT , [1 , 4 , 16 , 16 ])
1035+ y1 = helper .make_tensor_value_info ("Y1" , TensorProto .FLOAT , [1 , 5 , 16 , 16 ])
1036+ y2 = helper .make_tensor_value_info ("Y2" , TensorProto .FLOAT , [1 , 9 , 16 , 16 ])
1037+
1038+ # Create computation nodes
1039+ concat_1_node = helper .make_node (
1040+ "Concat" ,
1041+ ["X1" , "X2" ],
1042+ ["concat_1_out" ],
1043+ name = "concat_1" ,
1044+ axis = 1 ,
1045+ )
1046+ concat_2_node = helper .make_node (
1047+ "Concat" ,
1048+ ["concat_1_out" , "X3" ],
1049+ ["Y2" ],
1050+ name = "concat_2" ,
1051+ axis = 1 ,
1052+ )
1053+
1054+ # Create a Cast node between 'concat_1' and the graph output
1055+ cast_node = helper .make_node (
1056+ "Cast" ,
1057+ ["concat_1_out" ],
1058+ ["Y1" ],
1059+ name = "cast_0" ,
1060+ to = TensorProto .FLOAT ,
1061+ )
1062+
1063+ graph = helper .make_graph (
1064+ [concat_1_node , concat_2_node , cast_node ],
1065+ "model_with_multiple_output_node_casted_to_output" ,
1066+ [x1 , x2 , x3 ],
1067+ [y1 , y2 ],
1068+ [],
1069+ )
1070+
1071+ model = helper .make_model (
1072+ graph , producer_name = "model_with_multiple_output_node_casted_to_output"
1073+ )
1074+ model .opset_import [0 ].version = 20
1075+ model .ir_version = 10
1076+ onnx .checker .check_model (model )
1077+
1078+ model = onnx_utils .infer_shapes (model )
1079+ value_info_map , initializer_map , node_to_init_map = utils .setup_mappings (model )
1080+
1081+ return model , value_info_map , initializer_map , node_to_init_map
1082+
1083+
1084+ @pytest .mark .parametrize ("low_precision_type" , ["fp16" , "bf16" ])
1085+ def test_multiple_output_node_casted_to_output (
1086+ model_with_multiple_output_node_casted_to_output , low_precision_type
1087+ ):
1088+ model , value_info_map , initializer_map , node_to_init_map = (
1089+ model_with_multiple_output_node_casted_to_output
1090+ )
1091+
1092+ converter = PrecisionConverter (
1093+ model ,
1094+ value_info_map ,
1095+ initializer_map ,
1096+ node_to_init_map ,
1097+ keep_io_types = True ,
1098+ low_precision_type = low_precision_type ,
1099+ )
1100+ converted_model = converter .convert (
1101+ high_precision_nodes = [], low_precision_nodes = ["concat_1" , "concat_2" ]
1102+ )
1103+ onnx .checker .check_model (converted_model )
0 commit comments