Skip to content

Commit 07192cd

Browse files
committed
[webgpu] Throw errors for graph catpure when not implemented
1 parent b39e144 commit 07192cd

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,10 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
522522
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
523523
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink,
524524
const Tensor* seqlen_k, int local_window_size) {
525+
if (context.IsGraphCaptureEnabled()) {
526+
ORT_NOT_IMPLEMENTED("Graph capture not implemented for non flash attention path");
527+
}
528+
525529
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
526530
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
527531
const int total_sequence_length =

0 commit comments

Comments
 (0)