Commit 2b02316

[ET-VK][ez] Don't copy zeros for cache tensors (#15601)
Currently, cache tensors for SDPA are prepacked even though the mutable buffer data contains only zeros. For fused SDPA, this step can be skipped: the caches only need to be allocated on device, since the cache-update shader now zero-fills any slots it does not overwrite with projected data.

Differential Revision: [D86226137](https://our.internmc.facebook.com/intern/diff/D86226137/)
1 parent c66ea27 · commit 2b02316
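
As a rough illustration of the work being removed, here is a minimal C++ sketch using stand-in types rather than the real ET-VK API: prepacking stages a host buffer and copies it to the device, while an allocate-only path (analogous to `add_tensor_like` in the SDPA.cpp diff below) skips the transfer.

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Stand-in for a device allocation; the real ET-VK types differ.
struct DeviceTensor {
  std::vector<float> storage;
  explicit DeviceTensor(std::size_t numel) : storage(numel, 0.0f) {}
};

// Old path: stage the host buffer and copy it to the device, even
// though every element is zero.
DeviceTensor prepack_zeros(const std::vector<float>& host_zeros) {
  DeviceTensor t(host_zeros.size());
  std::memcpy(t.storage.data(), host_zeros.data(),
              host_zeros.size() * sizeof(float));  // redundant transfer
  return t;
}

// New path: allocate a tensor of matching size and skip the copy; the
// cache-update shader is now responsible for zero-filling.
DeviceTensor allocate_like(const std::vector<float>& host_zeros) {
  return DeviceTensor(host_zeros.size());
}
```

Since both the K and V caches are sized by the maximum context length, skipping the copy should save two context-length-sized host-to-device transfers at load time.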

2 files changed (+14, -6)

backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl
12 additions, 4 deletions

@@ -14,6 +14,10 @@ ${define_required_extensions(DTYPE)}
 
 layout(std430) buffer;
 
+#define DEBUG_MODE
+
+#extension GL_EXT_debug_printf : enable
+
 #include "common.glslh"
 
 ${layout_declare_tensor(B, "w", "t_cache", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)}
@@ -80,13 +84,17 @@ void main() {
   const int S = projected_sizes.z;
   const int H = projected_sizes.y;
 
-  if (d4 >= D4 || s >= S || h >= H) {
+  const int c = s + input_pos; // idx along max_context_len dim
+  const int C = cache_sizes.z;
+
+  if (d4 >= D4 || c >= C || h >= H) {
     return;
   }
 
-  const int c = s + input_pos; // idx along max_context_len dim
-  const int C = cache_sizes.y;
+  IN_VEC4_T in_texel = IN_VEC4_T(0.0);
+  if (s < S) {
+    in_texel = read_projected_d4(d4, h, s, D4, H, S);
+  }
 
-  IN_VEC4_T in_texel = read_projected_d4(d4, h, s, D4, H, S);
   write_cache_d4(in_texel, d4, c, h, D4, C, H);
 }
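
The shader change above is what makes skipping the prepack safe. The following is a scalar C++ model of the new main() logic, not the actual GLSL: the bounds check now culls against the cache extent C rather than the input extent S, and surviving invocations with s >= S write a zero texel, so the shader itself zero-fills cache slots that have no projected input. The flat array layouts here are illustrative assumptions.

```cpp
// Scalar model of one sdpa_kv_cache_update invocation at (d4, s, h).
// projected: input texels, modeled as [S][H][D4][4] floats.
// cache:     cache texels, modeled as [C][H][D4][4] floats.
void kv_cache_update(int d4, int s, int h,
                     int D4, int S, int H, int C, int input_pos,
                     const float* projected, float* cache) {
  const int c = s + input_pos;  // slot along the max_context_len dim
  // Cull against the cache extent C, not the input extent S, so
  // invocations with s >= S (but c < C) stay alive...
  if (d4 >= D4 || c >= C || h >= H) {
    return;
  }
  float texel[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  if (s < S) {
    // ...but only in-range positions read projected data (the GLSL
    // calls read_projected_d4); everyone else keeps the zero texel.
    const int src = ((s * H + h) * D4 + d4) * 4;
    for (int i = 0; i < 4; ++i) {
      texel[i] = projected[src + i];
    }
  }
  // Mirrors write_cache_d4(in_texel, d4, c, h, D4, C, H): positions
  // past the input zero-fill their cache slot instead of relying on
  // the cache having been prepacked with zeros.
  const int dst = ((c * H + h) * D4 + d4) * 4;
  for (int i = 0; i < 4; ++i) {
    cache[dst + i] = texel[i];
  }
}
```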

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
2 additions, 2 deletions

@@ -575,9 +575,9 @@ void compute_attn_weight_with_kv_cache_impl(
 
   utils::StorageType cache_storage = graph.storage_type_of(q_projected);
   const ValueRef k_cache =
-      prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
   const ValueRef v_cache =
-      prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked);
 
   update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
   update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
