Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
ORT_RETURN_IF_NOT(GetShape(*input_defs[3], input_past_k_shape, logger), "Cannot get past_key shape");

NodeAttrHelper helper(node);
const int32_t local_window_size = helper.Get("local_window_size", -1);
const uint32_t kv_num_heads = helper.Get("kv_num_heads", 0);
const uint32_t num_heads = helper.Get("num_heads", 0);

Expand Down Expand Up @@ -290,18 +291,17 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
| |
+-------------------------------> Lesser <---------------------Transpose (1,0)
|
1 ---> Where <--- finfo_min (minimum value of FP32)
1 ---> Where (attn_mask) <--- finfo_min (minimum value of FP32)
|
attention_bias
*/
const std::vector<int32_t> mask_shape_ones_shape(batch_size * num_heads * qkv_sequence_length * past_sequence_length,
1);
std::string mask_shape_ones_shape_name = "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(batch_size) +
"_" + std::to_string(num_heads) + "_" + std::to_string(qkv_sequence_length) +
"_" + std::to_string(past_sequence_length);
emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant<int32_t>(
ONNX_NAMESPACE::TensorProto_DataType_INT32, mask_shape_ones_shape_name, mask_shape_ones_shape,
std::vector<uint32_t>({batch_size, num_heads, qkv_sequence_length, past_sequence_length}));
emscripten::val value_int_one_constant =
model_builder.CreateOrGetConstant<int>(ONNX_NAMESPACE::TensorProto_DataType_INT32, 1, {1});

std::vector<uint32_t> mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, past_sequence_length};
common_options.set("label", node.Name() + "_/GQA/GQA_mask_shape_ones/expand");
emscripten::val mask_shape_ones_shape_constant = model_builder.GetBuilder().call<emscripten::val>(
"expand", value_int_one_constant, emscripten::val::array(mask_shape_ones_shape), common_options);

emscripten::val cumsum_options = emscripten::val::object();
cumsum_options.set("label", node.Name() + "_range_of_mask_shape");
Expand All @@ -315,7 +315,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
std::iota(pre_neq_right_data_range.begin(), pre_neq_right_data_range.end(), 1);

std::string pre_neq_right_data_range_name =
"webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(qkv_sequence_length);
"webnn_GQA_pre_neq_right_data_range_" + std::to_string(qkv_sequence_length);
emscripten::val pre_neq_right_data_range_constant = model_builder.CreateOrGetConstant<int32_t>(
ONNX_NAMESPACE::TensorProto_DataType_INT32, pre_neq_right_data_range_name, pre_neq_right_data_range,
std::vector<uint32_t>({qkv_sequence_length}));
Expand All @@ -333,10 +333,46 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
emscripten::val neq_right =
model_builder.GetBuilder().call<emscripten::val>("transpose", expanded_neq_right, transpose_options);

common_options.set("label", node.Name() + "_/GQA/attn_mask/condition");
emscripten::val condition =
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_1");
emscripten::val condition_1 =
model_builder.GetBuilder().call<emscripten::val>("lesser", neq_left, neq_right, common_options);

emscripten::val condition = condition_1;
// When local_window_size is not -1, build a new attention mask pattern that applies the sliding window.
/*
condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true)
| |
| Lesser <--- local_window_size
| |
LogicalAnd <----------------- condition_2
|
new attn_mask
*/
if (local_window_size != -1) {
// Cast condition_1 to int32 before taking the cumulative sum.
common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cast");
emscripten::val casted_condition_1 =
model_builder.GetBuilder().call<emscripten::val>("cast", condition_1, emscripten::val("int32"), common_options);

cumsum_options = emscripten::val::object();
cumsum_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cumsum");
cumsum_options.set("exclusive", true);
cumsum_options.set("reversed", true);
emscripten::val neq_left_2 = model_builder.GetBuilder().call<emscripten::val>(
"cumulativeSum", casted_condition_1, gsl::narrow<uint32_t>(3), cumsum_options);

emscripten::val local_window_size_constant =
model_builder.CreateOrGetConstant<int>(ONNX_NAMESPACE::TensorProto_DataType_INT32, local_window_size, {1});

common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2");
emscripten::val condition_2 =
model_builder.GetBuilder().call<emscripten::val>("lesser", neq_left_2, local_window_size_constant, common_options);

common_options.set("label", node.Name() + "_/GQA/attn_mask/condition/and");
condition = model_builder.GetBuilder().call<emscripten::val>(
"logicalAnd", condition_1, condition_2, common_options);
}

emscripten::val value_one_constant =
model_builder.CreateOrGetConstant<float>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1, {1});

Expand Down
Loading