Skip to content

Commit 2584067

Browse files
Expand type transform
1 parent 0f635b9 commit 2584067

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -468,10 +468,33 @@ struct CustomGraph {
468468
std::string cast_node_name = prev.node_ptr->OutputDefs()[0]->Name() + "_cast";
469469
InlinedVector<NodeArg*> input_args = {(NodeArg*)(prev.node_ptr->OutputDefs()[0])};
470470
InlinedVector<NodeArg*> output_args = {&output};
471+
std::unordered_map<std::string, int> type_str_to_tensor_data_type_;
472+
type_str_to_tensor_data_type_["tensor(float)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
473+
type_str_to_tensor_data_type_["tensor(float16)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
474+
type_str_to_tensor_data_type_["tensor(bfloat16)"] = ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
475+
type_str_to_tensor_data_type_["tensor(double)"] = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
476+
type_str_to_tensor_data_type_["tensor(int8)"] = ONNX_NAMESPACE::TensorProto_DataType_INT8;
477+
type_str_to_tensor_data_type_["tensor(int16)"] = ONNX_NAMESPACE::TensorProto_DataType_INT16;
478+
type_str_to_tensor_data_type_["tensor(int32)"] = ONNX_NAMESPACE::TensorProto_DataType_INT32;
479+
type_str_to_tensor_data_type_["tensor(int64)"] = ONNX_NAMESPACE::TensorProto_DataType_INT64;
480+
type_str_to_tensor_data_type_["tensor(uint8)"] = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
481+
type_str_to_tensor_data_type_["tensor(uint16)"] = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
482+
type_str_to_tensor_data_type_["tensor(uint32)"] = ONNX_NAMESPACE::TensorProto_DataType_UINT32;
483+
type_str_to_tensor_data_type_["tensor(uint64)"] = ONNX_NAMESPACE::TensorProto_DataType_UINT64;
484+
type_str_to_tensor_data_type_["tensor(complex64)"] = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64;
485+
type_str_to_tensor_data_type_["tensor(complex128)"] = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128;
486+
type_str_to_tensor_data_type_["tensor(string)"] = ONNX_NAMESPACE::TensorProto_DataType_STRING;
487+
type_str_to_tensor_data_type_["tensor(bool)"] = ONNX_NAMESPACE::TensorProto_DataType_BOOL;
488+
type_str_to_tensor_data_type_["tensor(float8e4m3fn)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
489+
type_str_to_tensor_data_type_["tensor(float8e4m3fnuz)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
490+
type_str_to_tensor_data_type_["tensor(float8e5m2)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
491+
type_str_to_tensor_data_type_["tensor(float8e5m2fnuz)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
492+
type_str_to_tensor_data_type_["tensor(uint4)"] = ONNX_NAMESPACE::TensorProto_DataType_UINT4;
493+
type_str_to_tensor_data_type_["tensor(int4)"] = ONNX_NAMESPACE::TensorProto_DataType_INT4;
494+
type_str_to_tensor_data_type_["tensor(float4e2m1)"] = ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1;
471495
Node& cast_node = original_graph.AddNode(cast_node_name, "Cast", "", input_args, output_args, nullptr, "");
472-
auto type_dummy = dq_node_ref.OutputDefs()[0]->Type();
473-
auto type_final = type_dummy->find("tensor(float)") != std::string::npos ? onnx::TensorProto_DataType_FLOAT : onnx::TensorProto_DataType_FLOAT16;
474-
cast_node.AddAttribute("to", int64_t(type_final));
496+
auto type_cast = type_str_to_tensor_data_type_[*dq_node_ref.OutputDefs()[0]->Type()];
497+
cast_node.AddAttribute("to", static_cast<int64_t>(type_cast));
475498
original_graph.AddEdge(prev.node_ptr->Index(),
476499
cast_node.Index(),
477500
prev_output_index,

0 commit comments

Comments
 (0)