@@ -468,10 +468,33 @@ struct CustomGraph {
468468 std::string cast_node_name = prev.node_ptr ->OutputDefs ()[0 ]->Name () + " _cast" ;
469469 InlinedVector<NodeArg*> input_args = {(NodeArg*)(prev.node_ptr ->OutputDefs ()[0 ])};
470470 InlinedVector<NodeArg*> output_args = {&output};
471+ std::unordered_map<std::string, int > type_str_to_tensor_data_type_;
472+ type_str_to_tensor_data_type_[" tensor(float)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
473+ type_str_to_tensor_data_type_[" tensor(float16)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
474+ type_str_to_tensor_data_type_[" tensor(bfloat16)" ] = ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
475+ type_str_to_tensor_data_type_[" tensor(double)" ] = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
476+ type_str_to_tensor_data_type_[" tensor(int8)" ] = ONNX_NAMESPACE::TensorProto_DataType_INT8;
477+ type_str_to_tensor_data_type_[" tensor(int16)" ] = ONNX_NAMESPACE::TensorProto_DataType_INT16;
478+ type_str_to_tensor_data_type_[" tensor(int32)" ] = ONNX_NAMESPACE::TensorProto_DataType_INT32;
479+ type_str_to_tensor_data_type_[" tensor(int64)" ] = ONNX_NAMESPACE::TensorProto_DataType_INT64;
480+ type_str_to_tensor_data_type_[" tensor(uint8)" ] = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
481+ type_str_to_tensor_data_type_[" tensor(uint16)" ] = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
482+ type_str_to_tensor_data_type_[" tensor(uint32)" ] = ONNX_NAMESPACE::TensorProto_DataType_UINT32;
483+ type_str_to_tensor_data_type_[" tensor(uint64)" ] = ONNX_NAMESPACE::TensorProto_DataType_UINT64;
484+ type_str_to_tensor_data_type_[" tensor(complex64)" ] = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64;
485+ type_str_to_tensor_data_type_[" tensor(complex128)" ] = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128;
486+ type_str_to_tensor_data_type_[" tensor(string)" ] = ONNX_NAMESPACE::TensorProto_DataType_STRING;
487+ type_str_to_tensor_data_type_[" tensor(bool)" ] = ONNX_NAMESPACE::TensorProto_DataType_BOOL;
488+ type_str_to_tensor_data_type_[" tensor(float8e4m3fn)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
489+ type_str_to_tensor_data_type_[" tensor(float8e4m3fnuz)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
490+ type_str_to_tensor_data_type_[" tensor(float8e5m2)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
491+ type_str_to_tensor_data_type_[" tensor(float8e5m2fnuz)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
492+ type_str_to_tensor_data_type_[" tensor(uint4)" ] = ONNX_NAMESPACE::TensorProto_DataType_UINT4;
493+ type_str_to_tensor_data_type_[" tensor(int4)" ] = ONNX_NAMESPACE::TensorProto_DataType_INT4;
494+ type_str_to_tensor_data_type_[" tensor(float4e2m1)" ] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1;
471495 Node& cast_node = original_graph.AddNode (cast_node_name, " Cast" , " " , input_args, output_args, nullptr , " " );
472- auto type_dummy = dq_node_ref.OutputDefs ()[0 ]->Type ();
473- auto type_final = type_dummy->find (" tensor(float)" ) != std::string::npos ? onnx::TensorProto_DataType_FLOAT : onnx::TensorProto_DataType_FLOAT16;
474- cast_node.AddAttribute (" to" , int64_t (type_final));
496+ auto type_cast = type_str_to_tensor_data_type_[*dq_node_ref.OutputDefs ()[0 ]->Type ()];
497+ cast_node.AddAttribute (" to" , static_cast <int64_t >(type_cast));
475498 original_graph.AddEdge (prev.node_ptr ->Index (),
476499 cast_node.Index (),
477500 prev_output_index,
0 commit comments