From 7ab0c519f2ed1e6b8c089bbccfd77e365d4882c1 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 6 Nov 2025 12:05:25 -0800 Subject: [PATCH] [ET-VK][ez] Ensure that attn_weight buffers do not exceed GPU buffer numel limit Title says it all! To give a concrete example, Llama3.2-1B-Instruct will have attn weights with size `{1, 32, max_seq_len, max_context_len}`. Usually `max_seq_len == max_context_len`, and if `max_context_len = 2048` Then the attention weight tensors will have sizes `{1, 32, 2048, 2048}` which will contain 134217728 elements. The `maxStorageBufferRange` for Adreno 750 is also 134217728 (2^27), so using context length of 2048 will produce incorrect results on Adreno 750. In practice, it is unlikely that the prompt sequence length will be equal to the context length, so the solution is to adjust down the `max_seq_len` dim of the attention weight tensors to ensure that the GPU buffer numel limit is not hit. Differential Revision: [D86443407](https://our.internmc.facebook.com/intern/diff/D86443407/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ComputeGraph.h | 4 +++ .../vulkan/runtime/graph/ops/impl/SDPA.cpp | 35 +++++++++++++++---- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index f7de7e183de..b61bd4a51c0 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -639,6 +639,10 @@ class ComputeGraph final { bool device_name_contains(const char* substr); + int64_t max_buffer_numel() { + return static_cast(context_->adapter_ptr()->max_buffer_numel()); + } + // // Graph Building // diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 4eed8b82834..d28d2c90fcb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -471,10 +471,31 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { VK_CHECK_COND(graph.val_is_none(attn_mask)); const int64_t num_q_heads = graph.size_at(-2, q_projected); - const int64_t max_seq_len = graph.size_at(-3, q_projected); - + int64_t max_seq_len = graph.size_at(-3, q_projected); const int64_t max_context_len = graph.size_at(-3, k_cache); + const utils::StorageType attn_weights_storage = + graph.storage_type_of(q_projected); + + // If using buffer storage for attn weights, we need to ensure that the buffer + // numel limit is not exceeded. If needed, manually adjust max_seq_len based + // on the buffer numel limit. + if (attn_weights_storage == utils::kBuffer) { + const int64_t max_buffer_numel = graph.max_buffer_numel(); + if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) { + // Compute the maximum possible value for max_seq_len that will hit + // the buffer numel limit. + max_seq_len = max_buffer_numel / (num_q_heads * max_context_len); + // Adjust down to the nearest multiple of 4 to make sure the limit is + // not hit. + if (max_seq_len % 4 != 0) { + max_seq_len = (max_seq_len / 4) * 4; + } else { + max_seq_len -= 4; + } + } + } + std::vector attn_weight_full_sizes = { 1, // batch num_q_heads, @@ -485,14 +506,14 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { &graph, attn_weight_full_sizes, graph.dtype_of(q_projected), - graph.storage_type_of(q_projected), + attn_weights_storage, utils::kWidthPacked); TmpTensor attn_weights_softmax( &graph, attn_weight_full_sizes, graph.dtype_of(q_projected), - graph.storage_type_of(q_projected), + attn_weights_storage, utils::kWidthPacked); add_sdpa_compute_attn_weights_node( @@ -528,9 +549,9 @@ void sdpa_with_kv_cache_impl( utils::StorageType cache_storage = graph.storage_type_of(q_projected); const ValueRef k_cache = - prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked); + graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked); const ValueRef v_cache = - prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked); + graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked); update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); @@ -573,7 +594,7 @@ void compute_attn_weight_with_kv_cache_impl( (void)sequence_len; - utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const utils::StorageType cache_storage = graph.storage_type_of(q_projected); const ValueRef k_cache = graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked); const ValueRef v_cache =