Skip to content

Commit 58b598f

Browse files
committed
Allow ops to handle ignoring an empty tensor as input
1 parent 85ceffe commit 58b598f

File tree

5 files changed

+17
-10
lines changed

5 files changed

+17
-10
lines changed

onnxruntime/core/providers/webnn/builders/helper.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
6969
}
7070
}
7171

72-
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
72+
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
73+
const logging::Logger& logger, bool allow_empty_input) {
7374
const auto& node_arg_name = node_arg.Name();
7475
const auto* shape_proto = node_arg.Shape();
7576
// Optional tensors can be indicated by an empty name, just ignore it.
@@ -89,7 +90,7 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
8990
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
9091
return false;
9192
}
92-
if (dim.dim_value() == 0) {
93+
if (dim.dim_value() == 0 && !allow_empty_input) {
9394
LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
9495
return false;
9596
}

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
181181
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
182182
}
183183

184-
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
184+
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
185+
const logging::Logger& logger, bool allow_empty_input = false);
185186

186187
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
187188
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,

onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&
2929
bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
3030
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
3131
const logging::Logger& logger) const {
32-
if (!HasSupportedInputs(initializers, node, wnn_limits, logger))
32+
if (!HasSupportedInputs(node, wnn_limits, logger))
3333
return false;
3434

3535
if (!HasSupportedOutputs(node, wnn_limits, logger))
@@ -41,12 +41,11 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
4141
return IsOpSupportedImpl(initializers, node, device_type, logger);
4242
}
4343

44-
bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node,
45-
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
44+
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits,
45+
const logging::Logger& logger) const {
4646
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
4747
for (const auto* input : node.InputDefs()) {
48-
// ONNX initializers should have shape information, skip the shape check if the input is an initializer.
49-
if (!Contains(initializers, input->Name()) && !IsTensorShapeSupported(*input, node_name, logger)) {
48+
if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) {
5049
return false;
5150
}
5251
}

onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ class BaseOpBuilder : public IOpBuilder {
2222
const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;
2323

2424
protected:
25+
explicit BaseOpBuilder(bool allow_empty_tensor_as_input = false)
26+
: allow_empty_tensor_as_input_(allow_empty_tensor_as_input) {
27+
}
2528
virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
2629
const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0;
2730

@@ -53,9 +56,10 @@ class BaseOpBuilder : public IOpBuilder {
5356

5457
private:
5558
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
56-
bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
57-
const logging::Logger& logger) const;
59+
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
5860
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
61+
62+
const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input.
5963
};
6064

6165
} // namespace webnn

onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace webnn {
2121
class ResizeOpBuilder : public BaseOpBuilder {
2222
// Add operator related.
2323
public:
24+
// Allow roi and scales potentially being empty inputs that are ignored during processing.
25+
ResizeOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {}
2426
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
2527

2628
private:

0 commit comments

Comments
 (0)