From 6af634668d467472405c4b990ff8f2a5e48dc1d0 Mon Sep 17 00:00:00 2001 From: peishenyan Date: Fri, 8 Aug 2025 14:14:18 +0800 Subject: [PATCH 1/2] temp local window size add log info temp Support local_window_size for WebNN GQA --- .../webnn/builders/impl/gqa_op_builder.cc | 50 +++++++++++++++++-- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc index 0b927075402fe..5e76c6e636081 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc @@ -107,6 +107,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b ORT_RETURN_IF_NOT(GetShape(*input_defs[3], input_past_k_shape, logger), "Cannot get past_key shape"); NodeAttrHelper helper(node); + const int32_t local_window_size = helper.Get("local_window_size", -1); const uint32_t kv_num_heads = helper.Get("kv_num_heads", 0); const uint32_t num_heads = helper.Get("num_heads", 0); @@ -290,13 +291,13 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b | | +-------------------------------> Lesser <---------------------Transpose (1,0) | - 1 ---> Where <--- finfo_min (minimum value of FP32) + 1 ---> Where (attn_mask) <--- finfo_min (minimum value of FP32) | attention_bias */ const std::vector mask_shape_ones_shape(batch_size * num_heads * qkv_sequence_length * past_sequence_length, 1); - std::string mask_shape_ones_shape_name = "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(batch_size) + + std::string mask_shape_ones_shape_name = "webnn_GQA_mask_shape_ones_" + std::to_string(batch_size) + "_" + std::to_string(num_heads) + "_" + std::to_string(qkv_sequence_length) + "_" + std::to_string(past_sequence_length); emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant( @@ -315,7 +316,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b std::iota(pre_neq_right_data_range.begin(), pre_neq_right_data_range.end(), 1); std::string pre_neq_right_data_range_name = - "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(qkv_sequence_length); + "webnn_GQA_pre_neq_right_data_range_" + std::to_string(qkv_sequence_length); emscripten::val pre_neq_right_data_range_constant = model_builder.CreateOrGetConstant( ONNX_NAMESPACE::TensorProto_DataType_INT32, pre_neq_right_data_range_name, pre_neq_right_data_range, std::vector({qkv_sequence_length})); @@ -333,10 +334,49 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b emscripten::val neq_right = model_builder.GetBuilder().call("transpose", expanded_neq_right, transpose_options); - common_options.set("label", node.Name() + "_/GQA/attn_mask/condition"); - emscripten::val condition = + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_1"); + emscripten::val condition_1 = model_builder.GetBuilder().call("lesser", neq_left, neq_right, common_options); + emscripten::val condition = condition_1; + // For local window size not equal to -1, new attention mask pattern for applying sliding window + /* + condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true) + | | + | Lesser <--- local_window_size + | | + LogicalAnd <----------------- condition_2 + | + new attn_mask + */ + if (local_window_size != -1) { + emscripten::val console = emscripten::val::global("console"); + console.call("log", emscripten::val("local window size is not -1.")); + // Cast condition + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cast"); + emscripten::val casted_condition_1 = + model_builder.GetBuilder().call("cast", condition_1, emscripten::val("int32"), common_options); + + cumsum_options = emscripten::val::object(); + cumsum_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cumsum"); + cumsum_options.set("exclusive", true); + cumsum_options.set("reversed", true); + emscripten::val neq_left_2 = model_builder.GetBuilder().call( + "cumulativeSum", casted_condition_1, gsl::narrow(3), cumsum_options); + + emscripten::val local_window_size_constant = + model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_INT32, local_window_size, {1}); + + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2"); + emscripten::val condition_2 = + model_builder.GetBuilder().call("lesser", neq_left_2, local_window_size_constant, common_options); + + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition/and"); + condition = model_builder.GetBuilder().call( + "logicalAnd", condition_1, condition_2, common_options); + console.call("log", condition_2); + } + emscripten::val value_one_constant = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1, {1}); From e9ef8ae05e8a9822476965ba2ae4309d4485ac9a Mon Sep 17 00:00:00 2001 From: peishenyan Date: Mon, 3 Nov 2025 16:13:12 +0800 Subject: [PATCH 2/2] decompose large constant --- .../webnn/builders/impl/gqa_op_builder.cc | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc index 5e76c6e636081..a29fbdb91e79f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc @@ -295,14 +295,13 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b | attention_bias */ - const std::vector mask_shape_ones_shape(batch_size * num_heads * qkv_sequence_length * past_sequence_length, - 1); - std::string mask_shape_ones_shape_name = "webnn_GQA_mask_shape_ones_" + std::to_string(batch_size) + - "_" + std::to_string(num_heads) + "_" + std::to_string(qkv_sequence_length) + - "_" + std::to_string(past_sequence_length); - emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant( - ONNX_NAMESPACE::TensorProto_DataType_INT32, mask_shape_ones_shape_name, mask_shape_ones_shape, - std::vector({batch_size, num_heads, qkv_sequence_length, past_sequence_length})); + emscripten::val value_int_one_constant = + model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_INT32, 1, {1}); + + std::vector mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, past_sequence_length}; + common_options.set("label", node.Name() + "_/GQA/GQA_mask_shape_ones/expand"); + emscripten::val mask_shape_ones_shape_constant = model_builder.GetBuilder().call( + "expand", value_int_one_constant, emscripten::val::array(mask_shape_ones_shape), common_options); emscripten::val cumsum_options = emscripten::val::object(); cumsum_options.set("label", node.Name() + "_range_of_mask_shape"); @@ -350,8 +349,6 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b new attn_mask */ if (local_window_size != -1) { - emscripten::val console = emscripten::val::global("console"); - console.call("log", emscripten::val("local window size is not -1.")); // Cast condition common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cast"); emscripten::val casted_condition_1 = @@ -374,7 +371,6 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b common_options.set("label", node.Name() + "_/GQA/attn_mask/condition/and"); condition = model_builder.GetBuilder().call( "logicalAnd", condition_1, condition_2, common_options); - console.call("log", condition_2); } emscripten::val value_one_constant =