Skip to content

Commit 0f635b9

Browse files
[OVEP] Add a check for type mismatches in QDQ stripping
When rewiring the graph after eliminating QDQ pairs, the runtime now checks whether the type matches before and after the eliminated nodes and inserts a Cast node if there is a mismatch.
1 parent 19ebc1f commit 0f635b9

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -463,11 +463,32 @@ struct CustomGraph {
463463
}
464464

465465
if (!is_prev_input) {
466-
for (const auto& edge : output_edges) {
466+
if (prev.node_ptr->OutputDefs()[0]->Type() != dq_node_ref.OutputDefs()[0]->Type()) {
467+
NodeArg& output = original_graph.GetOrCreateNodeArg(prev.node_name + "_cast_0", dq_node_ref.OutputDefs()[0]->TypeAsProto());
468+
std::string cast_node_name = prev.node_ptr->OutputDefs()[0]->Name() + "_cast";
469+
InlinedVector<NodeArg*> input_args = {(NodeArg*)(prev.node_ptr->OutputDefs()[0])};
470+
InlinedVector<NodeArg*> output_args = {&output};
471+
Node& cast_node = original_graph.AddNode(cast_node_name, "Cast", "", input_args, output_args, nullptr, "");
472+
auto type_dummy = dq_node_ref.OutputDefs()[0]->Type();
473+
auto type_final = type_dummy->find("tensor(float)") != std::string::npos ? onnx::TensorProto_DataType_FLOAT : onnx::TensorProto_DataType_FLOAT16;
474+
cast_node.AddAttribute("to", int64_t(type_final));
467475
original_graph.AddEdge(prev.node_ptr->Index(),
468-
std::get<0>(edge),
476+
cast_node.Index(),
469477
prev_output_index,
470-
std::get<2>(edge));
478+
0);
479+
for (const auto& edge : output_edges) {
480+
original_graph.AddEdge(cast_node.Index(),
481+
std::get<0>(edge),
482+
0,
483+
std::get<2>(edge));
484+
}
485+
} else {
486+
for (const auto& edge : output_edges) {
487+
original_graph.AddEdge(prev.node_ptr->Index(),
488+
std::get<0>(edge),
489+
prev_output_index,
490+
std::get<2>(edge));
491+
}
471492
}
472493
}
473494
}

0 commit comments

Comments
 (0)