@@ -523,8 +523,8 @@ void ov::npuw::LLMInferRequest::update_kvcache_for(
523523
524524void ov::npuw::LLMInferRequest::trim_kvcache_for_speculative_decoding (ov::SoPtr<ov::ITensor> position_ids) {
525525 auto & kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc ;
526- // FIXME: It can not work with OmniThinker for now.
527- OPENVINO_ASSERT ((position_ids->get_shape ().size () > = 2 ) && (position_ids->get_shape ().back () >= 1 ));
526+ // FIXME: It won't work with Qwen2.5-VL/Omni for now.
527+ OPENVINO_ASSERT ((position_ids->get_shape ().size () = = 2 ) && (position_ids->get_shape ().back () >= 1 ));
528528 auto position_id = position_ids->data <int64_t >()[0 ];
529529 auto dirty_num = kvcache_desc.num_stored_tokens - static_cast <uint32_t >(position_id);
530530 if (dirty_num > 0 ) {
@@ -926,15 +926,16 @@ void ov::npuw::LLMInferRequest::infer() {
926926 // number of logits.
927927 // The outcome of two items is that prefill and generate stages
928928 // can be safely differentiated by start position id for
929- // both main and draft models.
929+ // both main and draft models for most of LLMs .
930930 if (input_ids->get_shape ()[layer_ids::INPUT_IDS_SEQ_LEN_DIM] > 1 &&
931931 position_ids->data <int64_t >()[0 ] == m_first_position_id) {
932932 infer_prefill (input_ids, attention_mask, position_ids, token_type_ids);
933933 } else {
934- auto & kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc ;
935- // Need to reconsider the solution. Some model like Qwen2.5VL, doesn't use speculative decoding,
936- // but it may have repeated position ids, then it will trigger kvcache trim and cause AC issue.
937- if (kvcache_desc.max_generation_token_len > 1 ) {
934+ // FIXME: Need to make the solution smarter.
935+ // Qwen2.5VL uses 3D position_ids but current `trim_kvcache_for_speculative_decoding`
936+ // doesn't take this into account and causes accuracy issues.
937+ // Speculative Decode isn't supposed to work with such position_ids currently.
938+ if (position_ids->get_shape ().size () < 3 ) {
938939 trim_kvcache_for_speculative_decoding (position_ids);
939940 }
940941 infer_generate (input_ids, attention_mask, position_ids, token_type_ids);
0 commit comments