Skip to content

Commit 20de366

Browse files
mdvoretc-intel, MayureshV1, Copilot
authored
CVS-175447-[OVEP] Add a check for type mismatches in QDQ stripping (#834)
* [OVEP] Add a check for type mismatches in QDQ stripping

  When rewiring the graph after eliminating QDQ pairs, the runtime now checks whether the type matches before and after the eliminated nodes and inserts a Cast node if there is a mismatch.

* Expand type transform

* Limit output types to f32/f16, add const_cast

* Apply null check suggestion

  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b1f7750 commit 20de366

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp

Lines changed: 27 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -463,11 +463,35 @@ struct CustomGraph {
463463
}
464464

465465
if (!is_prev_input) {
466-
for (const auto& edge : output_edges) {
466+
if (prev.node_ptr->OutputDefs()[0]->Type() != dq_node_ref.OutputDefs()[0]->Type()) {
467+
NodeArg& output = original_graph.GetOrCreateNodeArg(prev.node_name + "_cast_0", dq_node_ref.OutputDefs()[0]->TypeAsProto());
468+
std::string cast_node_name = prev.node_ptr->OutputDefs()[0]->Name() + "_cast";
469+
InlinedVector<NodeArg*> input_args = {const_cast<NodeArg*>(prev.node_ptr->OutputDefs()[0])};
470+
InlinedVector<NodeArg*> output_args = {&output};
471+
Node& cast_node = original_graph.AddNode(cast_node_name, "Cast", "", input_args, output_args, nullptr, "");
472+
auto type_str = dq_node_ref.OutputDefs()[0]->Type();
473+
ORT_ENFORCE(type_str != nullptr, "Type string is null in QDQ scales fix.");
474+
auto type_cast = type_str->find("tensor(float)") != std::string::npos ? onnx::TensorProto_DataType_FLOAT : onnx::TensorProto_DataType_FLOAT16;
475+
ORT_ENFORCE((type_cast == onnx::TensorProto_DataType_FLOAT) || (type_str->find("tensor(float16)") != std::string::npos),
476+
"QDQ type misalignment, expected float32 or float16 output");
477+
cast_node.AddAttribute("to", static_cast<int64_t>(type_cast));
467478
original_graph.AddEdge(prev.node_ptr->Index(),
468-
std::get<0>(edge),
479+
cast_node.Index(),
469480
prev_output_index,
470-
std::get<2>(edge));
481+
0);
482+
for (const auto& edge : output_edges) {
483+
original_graph.AddEdge(cast_node.Index(),
484+
std::get<0>(edge),
485+
0,
486+
std::get<2>(edge));
487+
}
488+
} else {
489+
for (const auto& edge : output_edges) {
490+
original_graph.AddEdge(prev.node_ptr->Index(),
491+
std::get<0>(edge),
492+
prev_output_index,
493+
std::get<2>(edge));
494+
}
471495
}
472496
}
473497
}

0 commit comments

Comments
 (0)