Skip to content

Commit 586dacd

Browse files
committed
Fixed wrong input index for number of heads in K.
1 parent 3d04b1b commit 586dacd

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
206206
auto num_k_heads =
207207
extract_num_kv_heads(k_heads_unsqueeze, sdpa_node->get_input_tensor(1).get_partial_shape()[-3]);
208208
auto num_v_heads =
209-
extract_num_kv_heads(v_heads_unsqueeze, sdpa_node->get_input_tensor(1).get_partial_shape()[-3]);
209+
extract_num_kv_heads(v_heads_unsqueeze, sdpa_node->get_input_tensor(2).get_partial_shape()[-3]);
210210
const ov::element::Type kv_cache_type = real_q.get_element_type();
211211
std::string layer_index_str = std::to_string(layer_index);
212212
auto k_parameter = setName(std::make_shared<v0::Parameter>(kv_cache_type, PartialShape{-1, num_k_heads, E}),

0 commit comments

Comments
 (0)