diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index b03959e4f067b..75ca2a6ca091a 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -48,6 +48,7 @@ #include "core/optimizer/group_query_attention_fusion.h" #include "core/optimizer/identical_children_consolidation.h" #include "core/optimizer/identity_elimination.h" +#include "core/optimizer/if_to_where_transformer.h" #include "core/optimizer/label_encoder_fusion.h" #include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/matmul_activation_fusion.h" @@ -285,6 +286,8 @@ InlinedVector> GenerateTransformers( const InlinedHashSet cuda_eps = {onnxruntime::kCudaExecutionProvider}; + const InlinedHashSet qnn_eps = {onnxruntime::kQnnExecutionProvider}; + const InlinedHashSet cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider}; const InlinedHashSet cpu_cuda_rocm_eps = {onnxruntime::kCpuExecutionProvider, @@ -345,6 +348,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_webgpu_eps)); + transformers.emplace_back(std::make_unique(qnn_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); diff --git a/onnxruntime/core/optimizer/if_to_where_transformer.cc b/onnxruntime/core/optimizer/if_to_where_transformer.cc new file mode 100644 index 0000000000000..523254ddc44b3 --- /dev/null +++ b/onnxruntime/core/optimizer/if_to_where_transformer.cc @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/if_to_where_transformer.h" +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; +namespace onnxruntime { + +static Status InlineSubgraph( + Graph& main_graph, + const Graph& subgraph, + std::unordered_map& name_to_nodearg, + std::vector& subgraph_outputs, + const logging::Logger& logger) { + for (const auto* input : subgraph.GetInputs()) { + NodeArg* outer_arg = main_graph.GetNodeArg(input->Name()); + if (!outer_arg) { + // If missing, try to add initializer from subgraph to main graph + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + const ONNX_NAMESPACE::TensorProto* initializer_input = nullptr; + bool has_initializer = subgraph.GetInitializedTensor(input->Name(), initializer); + if (has_initializer && initializer != nullptr) { + // If not in main graph, add initializer + if (!main_graph.GetInitializedTensor(input->Name(), initializer_input)) { + LOGS(logger, VERBOSE) << "Adding initializer '" << input->Name() << "' to main graph."; + main_graph.AddInitializedTensor(*initializer); + } + } + // Create NodeArg in main graph for input + outer_arg = &main_graph.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); + } + name_to_nodearg[input->Name()] = outer_arg; + } + + // Copy subgraph initializers to main graph + for (const auto& pair : subgraph.GetAllInitializedTensors()) { + const std::string& init_name = pair.first; + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + bool has_initializer = subgraph.GetInitializedTensor(init_name, initializer); + if (has_initializer && initializer != nullptr) { + const ONNX_NAMESPACE::TensorProto* initializer_input = nullptr; + if (!main_graph.GetInitializedTensor(init_name, initializer_input)) { + LOGS(logger, VERBOSE) << "Creating NodeArg for subgraph initializer '" << init_name << "' in main graph."; + main_graph.AddInitializedTensor(*initializer); + } + main_graph.GetOrCreateNodeArg(init_name, nullptr); + } + } + + // Inline nodes + for (const auto& node : subgraph.Nodes()) { + std::vector inputs; + for (const auto* input_arg : node.InputDefs()) { + NodeArg* mapped_input = nullptr; + auto it = name_to_nodearg.find(input_arg->Name()); + if (it != name_to_nodearg.end()) { + mapped_input = it->second; + } else { + mapped_input = main_graph.GetNodeArg(input_arg->Name()); + if (!mapped_input) { + // Add NodeArg if missing + mapped_input = &main_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + } + } + if (!mapped_input) { + LOGS(logger, ERROR) << "Node input '" << input_arg->Name() << "' not found in main graph for node '" << node.Name() << "'"; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Node input '", input_arg->Name(), "' not found in main graph for node '", node.Name(), "'"); + } + inputs.push_back(mapped_input); + } + + // Outputs + std::vector outputs; + for (const auto* output_arg : node.OutputDefs()) { + // Create a unique NodeArg in main graph for each output + NodeArg* mapped_output = &main_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + name_to_nodearg[output_arg->Name()] = mapped_output; + outputs.push_back(mapped_output); + } + + // Add node to main graph + main_graph.AddNode( + main_graph.GenerateNodeName(node.OpType()), + node.OpType(), + node.Description(), + inputs, + outputs, + &node.GetAttributes(), + node.Domain()); + } + + // Map subgraph outputs to main graph NodeArgs + for (const auto* output : subgraph.GetOutputs()) { + auto it = name_to_nodearg.find(output->Name()); + NodeArg* mapped_output = (it != name_to_nodearg.end()) ? it->second : main_graph.GetNodeArg(output->Name()); + if (!mapped_output) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Missing output: ", output->Name(), " in main graph after inlining."); + } + subgraph_outputs.push_back(mapped_output); + } + + return Status::OK(); +} + +static bool CanTransformIfToWhere(const Node& if_node) { + const auto& outputs = if_node.OutputDefs(); + for (const auto* output : outputs) { + const auto* type_proto = output->TypeAsProto(); + if (!type_proto) return false; + + // If any output is optional, we skip the transformation + if (type_proto->value_case() == onnx::TypeProto::kOptionalType) { + return false; + } + + // Optionally, ensure it's a tensor + if (type_proto->value_case() != onnx::TypeProto::kTensorType) { + return false; + } + } + return true; +} + +Status IfToWhereTransformer::ApplyImpl(Graph& graph, + bool& modified, + int graph_level, + const logging::Logger& logger) const { + modified = false; + (void)graph_level; + for (auto it = graph.Nodes().begin(), end = graph.Nodes().end(); it != end; ++it) { + Node& if_node = *it; + if (if_node.OpType() != "If") + continue; + + if (!graph_utils::IsSupportedProvider(if_node, GetCompatibleExecutionProviders())) { + // If not Qnn Execution Provider + continue; + } + + if (!CanTransformIfToWhere(if_node)) { + LOGS(logger, VERBOSE) << "Skipping IfToWhere transformation for node " << if_node.Name() + << " due to optional or unsupported output types."; + continue; + } + + const NodeArg* cond_arg = if_node.InputDefs()[0]; + if (!cond_arg) + continue; + const onnx::TypeProto* type_proto = cond_arg->TypeAsProto(); + if (type_proto && type_proto->has_tensor_type()) { + if (type_proto->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_BOOL) { + continue; + } + } + if (graph_utils::IsConstantInitializer(graph, cond_arg->Name(), true)) + continue; + + // Extract the two subgraph branches + Graph* then_sub = if_node.GetMutableGraphAttribute("then_branch"); + Graph* else_sub = if_node.GetMutableGraphAttribute("else_branch"); + if (!then_sub || !else_sub) { + LOGS(logger, INFO) << "If node missing subgraphs!"; + continue; + } + + std::unordered_map then_map, else_map; + std::vector then_outputs, else_outputs; + ORT_RETURN_IF_ERROR(InlineSubgraph(graph, *then_sub, then_map, then_outputs, logger)); + ORT_RETURN_IF_ERROR(InlineSubgraph(graph, *else_sub, else_map, else_outputs, logger)); + + // Build one Where node per original If output + const auto& if_outputs = if_node.MutableOutputDefs(); + size_t num_outputs = if_outputs.size(); + + if (then_outputs.size() != else_outputs.size() || then_outputs.size() != if_outputs.size()) { + LOGS(logger, INFO) << "Mismatch in output sizes between then/else branches and If node."; + continue; + } + + for (size_t i = 0; i < num_outputs; ++i) { + NodeArg* out_arg = if_outputs[i]; + graph.AddNode( + graph.GenerateNodeName("IfToWhere"), + "Where", + "Select between then/else outputs", + {const_cast(cond_arg), then_outputs[i], else_outputs[i]}, + {out_arg}); + } + + std::vector> edges_to_remove; + for (auto edge_it = if_node.OutputEdgesBegin(), last = if_node.OutputEdgesEnd(); + edge_it != last; ++edge_it) { + edges_to_remove.emplace_back( + if_node.Index(), // src node index + edge_it->GetNode().Index(), // dst node index + edge_it->GetSrcArgIndex(), // which output slot of 'if_node' + edge_it->GetDstArgIndex()); // which input slot on consumer + } + + for (auto& edge : edges_to_remove) { + graph.RemoveEdge( + std::get<0>(edge), + std::get<1>(edge), + std::get<2>(edge), + std::get<3>(edge)); + } + + graph.RemoveNode(if_node.Index()); + modified = true; + break; + } + + if (modified) { + ORT_RETURN_IF_ERROR(graph.Resolve()); + } + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/if_to_where_transformer.h b/onnxruntime/core/optimizer/if_to_where_transformer.h new file mode 100644 index 0000000000000..26eabdfacbc56 --- /dev/null +++ b/onnxruntime/core/optimizer/if_to_where_transformer.h @@ -0,0 +1,17 @@ +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +// Transformer that lowers dynamic If nodes into Where ops by inlining both branches. +class IfToWhereTransformer final : public GraphTransformer { + public: + IfToWhereTransformer(const InlinedHashSet& compatible_execution_providers = {kQnnExecutionProvider}) noexcept + : GraphTransformer("IfToWhereTransformer", compatible_execution_providers) {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const; +}; + +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 8b3f55c7df756..e187e02cba5b4 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -52,6 +52,7 @@ #include "core/optimizer/graph_transformer.h" #include "core/optimizer/identity_elimination.h" #include "core/optimizer/initializer.h" +#include "core/optimizer/if_to_where_transformer.h" #include "core/optimizer/isinf_reducesum_fusion.h" #include "core/optimizer/label_encoder_fusion.h" #include "core/optimizer/matmul_add_fusion.h" @@ -628,6 +629,53 @@ TEST_F(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) { } } +// Tests that a simple If node with dynamic condition is transformed into a Where node +TEST_F(GraphTransformationTests, IfToWhereTransformer_BasicDynamicCondition) { + const ORTCHAR_T* model_uri = MODEL_FOLDER "if_to_where_test.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + + // Assign all nodes to QNN EP + for (auto& node : graph.Nodes()) { + node.SetExecutionProviderType("QNNExecutionProvider"); + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["If"], 0); + EXPECT_EQ(op_to_count["Where"], 1); + EXPECT_EQ(op_to_count["Add"], 1); // from then_branch + EXPECT_EQ(op_to_count["Identity"], 1); // from else_branch +} + +// Tests that the transformer skips If nodes when assigned to a non-QNN execution provider +TEST_F(GraphTransformationTests, IfToWhereTransformer_NonQnnEP) { + const ORTCHAR_T* model_uri = MODEL_FOLDER "if_to_where_test.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + + for (auto& node : graph.Nodes()) { + node.SetExecutionProviderType("CPUExecutionProvider"); // Not QNN + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["If"], 1); // Should remain + EXPECT_EQ(op_to_count["Where"], 0); // Should not be added +} + TEST_F(GraphTransformationTests, ConstantFoldingUnsupportedFloat16) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant_float16_mul.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/if_to_where_test.onnx b/onnxruntime/test/testdata/transform/if_to_where_test.onnx new file mode 100644 index 0000000000000..5e8bc9c149197 Binary files /dev/null and b/onnxruntime/test/testdata/transform/if_to_where_test.onnx differ