[QNN EP] Add Support for If Op using Graph Transformation #24906

Status: Open. Wants to merge 1 commit into base: main.
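
For context: the new `IfToWhereTransformer` rewrites a dynamic `If` node into `Where` ops by inlining both branch subgraphs into the parent graph and selecting between their outputs. Below is a minimal sketch of driving the pass standalone, mirroring the unit tests added in this PR; the wrapper function and its name are assumptions for illustration, not part of the change.

#include <memory>

#include "core/graph/constants.h"
#include "core/graph/model.h"
#include "core/optimizer/graph_transformer_mgr.h"
#include "core/optimizer/if_to_where_transformer.h"

// Sketch only: model-loading plumbing is assumed; the calls mirror the tests below.
onnxruntime::common::Status RunIfToWhere(onnxruntime::Model& model,
                                         const onnxruntime::logging::Logger& logger) {
  onnxruntime::Graph& graph = model.MainGraph();
  // The pass only fires for nodes assigned to the QNN EP.
  for (auto& node : graph.Nodes()) {
    node.SetExecutionProviderType(onnxruntime::kQnnExecutionProvider);
  }
  onnxruntime::GraphTransformerManager mgr{5};  // up to 5 optimization rounds
  ORT_RETURN_IF_ERROR(mgr.Register(std::make_unique<onnxruntime::IfToWhereTransformer>(),
                                   onnxruntime::TransformerLevel::Level1));
  return mgr.ApplyTransformers(graph, onnxruntime::TransformerLevel::Level1, logger);
}
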
5 changes: 5 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -48,6 +48,7 @@
#include "core/optimizer/group_query_attention_fusion.h"
#include "core/optimizer/identical_children_consolidation.h"
#include "core/optimizer/identity_elimination.h"
#include "core/optimizer/if_to_where_transformer.h"
#include "core/optimizer/label_encoder_fusion.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/matmul_activation_fusion.h"
@@ -285,6 +286,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(

const InlinedHashSet<std::string_view> cuda_eps = {onnxruntime::kCudaExecutionProvider};

const InlinedHashSet<std::string_view> qnn_eps = {onnxruntime::kQnnExecutionProvider};

const InlinedHashSet<std::string_view> cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider,
onnxruntime::kRocmExecutionProvider};
const InlinedHashSet<std::string_view> cpu_cuda_rocm_eps = {onnxruntime::kCpuExecutionProvider,
@@ -345,6 +348,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(

transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_rocm_acl_armnn_js_webgpu_eps));

transformers.emplace_back(std::make_unique<IfToWhereTransformer>(qnn_eps));

transformers.emplace_back(std::make_unique<GeluFusion>(cpu_acl_cuda_dml_rocm_eps, level));
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_acl_cuda_dml_rocm_eps, level));
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps));
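
Note on the registration above: passing `qnn_eps` scopes the pass, because `ApplyImpl` skips any node whose assigned execution provider is not in the compatible set (via `graph_utils::IsSupportedProvider`). A simplified sketch of that per-node check, for illustration only:

#include <string_view>

#include "core/common/inlined_containers.h"
#include "core/graph/graph.h"

// Illustration only; the real helper is graph_utils::IsSupportedProvider.
bool IsNodeOnCompatibleEp(const onnxruntime::Node& node,
                          const onnxruntime::InlinedHashSet<std::string_view>& compatible_eps) {
  // An empty set means no restriction; otherwise the node's assigned EP
  // must be one of the compatible providers.
  return compatible_eps.empty() ||
         compatible_eps.count(node.GetExecutionProviderType()) > 0;
}
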
226 changes: 226 additions & 0 deletions onnxruntime/core/optimizer/if_to_where_transformer.cc
@@ -0,0 +1,226 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/optimizer/if_to_where_transformer.h"
#include "core/graph/graph.h"
#include "core/graph/graph_utils.h"
#include "core/graph/constants.h"
#include "core/graph/model.h"

using namespace ONNX_NAMESPACE;

Check warning on line 10 in onnxruntime/core/optimizer/if_to_where_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/core/optimizer/if_to_where_transformer.cc:10: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using namespace onnxruntime::common;

Check warning on line 11 in onnxruntime/core/optimizer/if_to_where_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/core/optimizer/if_to_where_transformer.cc:11: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
namespace onnxruntime {

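// Inlines `subgraph` into `main_graph`: maps subgraph inputs to outer-scope
// NodeArgs, copies initializers that are not already present, re-creates each
// node, and returns (via `subgraph_outputs`) the main-graph NodeArgs that
// correspond to the subgraph outputs.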
static Status InlineSubgraph(
Graph& main_graph,
const Graph& subgraph,
std::unordered_map<std::string, NodeArg*>& name_to_nodearg,
std::vector<NodeArg*>& subgraph_outputs,
const logging::Logger& logger) {
for (const auto* input : subgraph.GetInputs()) {
NodeArg* outer_arg = main_graph.GetNodeArg(input->Name());
if (!outer_arg) {
// If missing, try to add initializer from subgraph to main graph
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
const ONNX_NAMESPACE::TensorProto* initializer_input = nullptr;
bool has_initializer = subgraph.GetInitializedTensor(input->Name(), initializer);
if (has_initializer && initializer != nullptr) {
// If not in main graph, add initializer
if (!main_graph.GetInitializedTensor(input->Name(), initializer_input)) {
LOGS(logger, VERBOSE) << "Adding initializer '" << input->Name() << "' to main graph.";
main_graph.AddInitializedTensor(*initializer);
}
}
// Create NodeArg in main graph for input
outer_arg = &main_graph.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
}
name_to_nodearg[input->Name()] = outer_arg;
}

// Copy subgraph initializers to main graph
for (const auto& pair : subgraph.GetAllInitializedTensors()) {
const std::string& init_name = pair.first;
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
bool has_initializer = subgraph.GetInitializedTensor(init_name, initializer);
if (has_initializer && initializer != nullptr) {
const ONNX_NAMESPACE::TensorProto* initializer_input = nullptr;
if (!main_graph.GetInitializedTensor(init_name, initializer_input)) {
LOGS(logger, VERBOSE) << "Creating NodeArg for subgraph initializer '" << init_name << "' in main graph.";
main_graph.AddInitializedTensor(*initializer);
}
main_graph.GetOrCreateNodeArg(init_name, nullptr);
}
}

// Inline nodes
for (const auto& node : subgraph.Nodes()) {
std::vector<NodeArg*> inputs;
for (const auto* input_arg : node.InputDefs()) {
NodeArg* mapped_input = nullptr;
auto it = name_to_nodearg.find(input_arg->Name());
if (it != name_to_nodearg.end()) {
mapped_input = it->second;
} else {
mapped_input = main_graph.GetNodeArg(input_arg->Name());
if (!mapped_input) {
// Add NodeArg if missing
mapped_input = &main_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
}
}
if (!mapped_input) {
LOGS(logger, ERROR) << "Node input '" << input_arg->Name() << "' not found in main graph for node '" << node.Name() << "'";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Node input '", input_arg->Name(), "' not found in main graph for node '", node.Name(), "'");
}
inputs.push_back(mapped_input);
}

// Outputs
std::vector<NodeArg*> outputs;
for (const auto* output_arg : node.OutputDefs()) {
// Map each output name to a main-graph NodeArg (created on first use)
NodeArg* mapped_output = &main_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
name_to_nodearg[output_arg->Name()] = mapped_output;
outputs.push_back(mapped_output);
}

// Add node to main graph
main_graph.AddNode(
main_graph.GenerateNodeName(node.OpType()),
node.OpType(),
node.Description(),
inputs,
outputs,
&node.GetAttributes(),
node.Domain());
}

// Map subgraph outputs to main graph NodeArgs
for (const auto* output : subgraph.GetOutputs()) {
auto it = name_to_nodearg.find(output->Name());
NodeArg* mapped_output = (it != name_to_nodearg.end()) ? it->second : main_graph.GetNodeArg(output->Name());
if (!mapped_output) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Missing output: ", output->Name(), " in main graph after inlining.");
}
subgraph_outputs.push_back(mapped_output);
}

return Status::OK();
}

static bool CanTransformIfToWhere(const Node& if_node) {
const auto& outputs = if_node.OutputDefs();
for (const auto* output : outputs) {
const auto* type_proto = output->TypeAsProto();
if (!type_proto) return false;

// If any output is optional, we skip the transformation
if (type_proto->value_case() == onnx::TypeProto::kOptionalType) {
return false;
}

// Where operates on tensors only, so require tensor-typed outputs
if (type_proto->value_case() != onnx::TypeProto::kTensorType) {
return false;
}
}
return true;
}

Status IfToWhereTransformer::ApplyImpl(Graph& graph,
bool& modified,
int graph_level,
const logging::Logger& logger) const {
modified = false;
ORT_UNUSED_PARAMETER(graph_level);
for (auto it = graph.Nodes().begin(), end = graph.Nodes().end(); it != end; ++it) {
Node& if_node = *it;
if (if_node.OpType() != "If")
continue;

if (!graph_utils::IsSupportedProvider(if_node, GetCompatibleExecutionProviders())) {
// Skip nodes not assigned to a compatible EP (QNN by default)
continue;
}

if (!CanTransformIfToWhere(if_node)) {
LOGS(logger, VERBOSE) << "Skipping IfToWhere transformation for node " << if_node.Name()
<< " due to optional or unsupported output types.";
continue;
}

const NodeArg* cond_arg = if_node.InputDefs()[0];
if (!cond_arg)
continue;
const onnx::TypeProto* type_proto = cond_arg->TypeAsProto();
if (type_proto && type_proto->has_tensor_type()) {
if (type_proto->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_BOOL) {
continue;
}
}
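// A constant condition is left to constant folding, which can drop the
// untaken branch entirely; this rewrite targets dynamic conditions.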
if (graph_utils::IsConstantInitializer(graph, cond_arg->Name(), true))
continue;

// Extract the two subgraph branches
Graph* then_sub = if_node.GetMutableGraphAttribute("then_branch");
Graph* else_sub = if_node.GetMutableGraphAttribute("else_branch");
if (!then_sub || !else_sub) {
LOGS(logger, ERROR) << "If node missing subgraphs!";
continue;
}

std::unordered_map<std::string, NodeArg*> then_map, else_map;

std::vector<NodeArg*> then_outputs, else_outputs;
ORT_RETURN_IF_ERROR(InlineSubgraph(graph, *then_sub, then_map, then_outputs, logger));
ORT_RETURN_IF_ERROR(InlineSubgraph(graph, *else_sub, else_map, else_outputs, logger));

// Build one Where node per original If output
const auto& if_outputs = if_node.MutableOutputDefs();
size_t num_outputs = if_outputs.size();

if (then_outputs.size() != else_outputs.size() || then_outputs.size() != if_outputs.size()) {
LOGS(logger, ERROR) << "Mismatch in output sizes between then/else branches and If node.";
continue;
}

for (size_t i = 0; i < num_outputs; ++i) {
NodeArg* out_arg = if_outputs[i];
graph.AddNode(
graph.GenerateNodeName("IfToWhere"),
"Where",
"Select between then/else outputs",
{const_cast<NodeArg*>(cond_arg), then_outputs[i], else_outputs[i]},
{out_arg});
}

std::vector<std::tuple<NodeIndex, NodeIndex, int, int>> edges_to_remove;

for (auto edge_it = if_node.OutputEdgesBegin(), last = if_node.OutputEdgesEnd();
edge_it != last; ++edge_it) {
edges_to_remove.emplace_back(
if_node.Index(), // src node index
edge_it->GetNode().Index(), // dst node index
edge_it->GetSrcArgIndex(), // which output slot of 'if_node'
edge_it->GetDstArgIndex()); // which input slot on consumer
}

for (auto& edge : edges_to_remove) {
graph.RemoveEdge(
std::get<0>(edge),
std::get<1>(edge),
std::get<2>(edge),
std::get<3>(edge));
}

graph.RemoveNode(if_node.Index());
modified = true;
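// Transform one If per application; the transformer manager re-applies the
// pass while the graph keeps changing, picking up any remaining If nodes.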
break;
}

if (modified) {
ORT_RETURN_IF_ERROR(graph.Resolve());
}
return Status::OK();
}

} // namespace onnxruntime
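
For intuition, the rewrite relies on `Where(cond, then_out, else_out)` computing an elementwise select, so evaluating both branches eagerly and selecting afterwards preserves the `If` semantics whenever both branches are side-effect-free. A scalar analogue, for illustration only (not part of the PR):

#include <cassert>

// Scalar analogue of ONNX Where: elementwise cond ? a : b.
template <typename T>
T WhereScalar(bool cond, T then_val, T else_val) {
  return cond ? then_val : else_val;
}

int main() {
  int x = 41;
  // then_branch computes x + 1, else_branch is Identity(x); both are
  // evaluated, and Where picks whichever the condition selects.
  assert(WhereScalar(true, x + 1, x) == 42);
  assert(WhereScalar(false, x + 1, x) == 41);
  return 0;
}
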
17 changes: 17 additions & 0 deletions onnxruntime/core/optimizer/if_to_where_transformer.h
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

// Transformer that lowers dynamic If nodes into Where ops by inlining both branches.
class IfToWhereTransformer final : public GraphTransformer {
public:
explicit IfToWhereTransformer(const InlinedHashSet<std::string_view>& compatible_execution_providers = {kQnnExecutionProvider}) noexcept
: GraphTransformer("IfToWhereTransformer", compatible_execution_providers) {}

private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
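
Since the compatible-EP set is a constructor parameter rather than hard-coded, the pass could in principle be scoped to other providers. A hypothetical usage sketch (nothing in this PR does this):

#include <memory>
#include <string_view>

#include "core/graph/constants.h"
#include "core/optimizer/if_to_where_transformer.h"

// Hypothetical: scope the pass to QNN plus CPU.
std::unique_ptr<onnxruntime::GraphTransformer> MakeScopedIfToWhere() {
  onnxruntime::InlinedHashSet<std::string_view> eps{onnxruntime::kQnnExecutionProvider,
                                                    onnxruntime::kCpuExecutionProvider};
  return std::make_unique<onnxruntime::IfToWhereTransformer>(eps);
}
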
48 changes: 48 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -52,6 +52,7 @@
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/identity_elimination.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/if_to_where_transformer.h"
#include "core/optimizer/isinf_reducesum_fusion.h"
#include "core/optimizer/label_encoder_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
@@ -628,6 +629,53 @@ TEST_F(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) {
}
}

// Tests that a simple If node with dynamic condition is transformed into a Where node
TEST_F(GraphTransformationTests, IfToWhereTransformer_BasicDynamicCondition) {
const ORTCHAR_T* model_uri = MODEL_FOLDER "if_to_where_test.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();

// Assign all nodes to QNN EP
for (auto& node : graph.Nodes()) {
node.SetExecutionProviderType("QNNExecutionProvider");
}

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<IfToWhereTransformer>(), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["If"], 0);
EXPECT_EQ(op_to_count["Where"], 1);
EXPECT_EQ(op_to_count["Add"], 1); // from then_branch
EXPECT_EQ(op_to_count["Identity"], 1); // from else_branch
}

// Tests that the transformer skips If nodes when assigned to a non-QNN execution provider
TEST_F(GraphTransformationTests, IfToWhereTransformer_NonQnnEP) {
const ORTCHAR_T* model_uri = MODEL_FOLDER "if_to_where_test.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();

for (auto& node : graph.Nodes()) {
node.SetExecutionProviderType("CPUExecutionProvider"); // Not QNN
}

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<IfToWhereTransformer>(), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["If"], 1); // Should remain
EXPECT_EQ(op_to_count["Where"], 0); // Should not be added
}

TEST_F(GraphTransformationTests, ConstantFoldingUnsupportedFloat16) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant_float16_mul.onnx";
std::shared_ptr<Model> model;
Binary file not shown.