diff --git a/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp b/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp
index 3e6abb1dc1c35a..702473ed7e938c 100644
--- a/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp
+++ b/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp
@@ -523,8 +523,8 @@ void ov::npuw::LLMInferRequest::update_kvcache_for(
 
 void ov::npuw::LLMInferRequest::trim_kvcache_for_speculative_decoding(ov::SoPtr<ov::ITensor> position_ids) {
     auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc;
-    // FIXME: It can not work with OmniThinker for now.
-    OPENVINO_ASSERT((position_ids->get_shape().size() >= 2) && (position_ids->get_shape().back() >= 1));
+    // FIXME: It won't work with Qwen2.5-VL/Omni for now.
+    OPENVINO_ASSERT((position_ids->get_shape().size() == 2) && (position_ids->get_shape().back() >= 1));
     auto position_id = position_ids->data<int64_t>()[0];
     auto dirty_num = kvcache_desc.num_stored_tokens - static_cast<uint32_t>(position_id);
     if (dirty_num > 0) {
@@ -926,15 +926,16 @@ void ov::npuw::LLMInferRequest::infer() {
     // number of logits.
     // The outcome of two items is that prefill and generate stages
     // can be safely differentiated by start position id for
-    // both main and draft models.
+    // both main and draft models for most LLMs.
     if (input_ids->get_shape()[layer_ids::INPUT_IDS_SEQ_LEN_DIM] > 1 &&
         position_ids->data<int64_t>()[0] == m_first_position_id) {
         infer_prefill(input_ids, attention_mask, position_ids, token_type_ids);
     } else {
-        auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc;
-        // Need to reconsider the solution. Some model like Qwen2.5VL, doesn't use speculative decoding,
-        // but it may have repeated position ids, then it will trigger kvcache trim and cause AC issue.
-        if (kvcache_desc.max_generation_token_len > 1) {
+        // FIXME: Need to make the solution smarter.
+        // Qwen2.5-VL uses 3D position_ids, but the current `trim_kvcache_for_speculative_decoding`
+        // doesn't take this into account and causes accuracy issues.
+        // Speculative decoding isn't supposed to work with such position_ids currently.
+        if (position_ids->get_shape().size() < 3) {
             trim_kvcache_for_speculative_decoding(position_ids);
         }
         infer_generate(input_ids, attention_mask, position_ids, token_type_ids);
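
For reviewers, a minimal standalone sketch of the guard this patch settles on. All names here (KVCacheDesc, maybe_trim_kvcache, the shape/position arguments) are hypothetical stand-ins, not the plugin's real types; the point is the shape of the logic: the speculative-decoding rollback drops every cached token at or past the first position id of the incoming generate step, and is skipped entirely for rank-3 position_ids such as Qwen2.5-VL's multimodal rope ids.

#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for the plugin's kv-cache bookkeeping.
struct KVCacheDesc {
    uint32_t num_stored_tokens = 0;
};

// Sketch of the patched guard: trim only for 2D position_ids of shape
// [batch, seq_len]. The first position id of a generate step marks how many
// stored tokens are still valid; anything past it is a rejected draft token.
// Qwen2.5-VL-style 3D position_ids (three rope planes with repeated position
// values) break that assumption, so rank >= 3 skips the trim entirely.
void maybe_trim_kvcache(KVCacheDesc& kvcache_desc,
                        const std::vector<size_t>& position_ids_shape,
                        int64_t first_position_id) {
    if (position_ids_shape.size() >= 3) {
        return;  // 3D position_ids: speculative-decoding trim unsupported
    }
    // Count of stale ("dirty") tokens past the new start position.
    const int64_t dirty_num =
        static_cast<int64_t>(kvcache_desc.num_stored_tokens) - first_position_id;
    if (dirty_num > 0) {
        kvcache_desc.num_stored_tokens -= static_cast<uint32_t>(dirty_num);
    }
}

int main() {
    KVCacheDesc desc;
    desc.num_stored_tokens = 12;              // 12 tokens cached so far
    maybe_trim_kvcache(desc, {1, 12}, 9);     // generate resumes at position 9
    // Three rejected draft tokens dropped: num_stored_tokens is now 9.
    maybe_trim_kvcache(desc, {3, 1, 12}, 5);  // 3D ids (e.g. Qwen2.5-VL): no-op
    return desc.num_stored_tokens == 9 ? 0 : 1;
}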