Skip to content

Commit 7be71cf

Browse files
authored
[NPUW][OpenVINO 2025.4] Hot fix for Speculative Decode (#33062)
### Details: - *Sibling of #33061* - *Fixes speculative decode acceptance rate by allowing trim of draft model* ### Tickets: - *EISW-193919*
1 parent 610d5f9 commit 7be71cf

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -523,8 +523,8 @@ void ov::npuw::LLMInferRequest::update_kvcache_for(
523523

524524
void ov::npuw::LLMInferRequest::trim_kvcache_for_speculative_decoding(ov::SoPtr<ov::ITensor> position_ids) {
525525
auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc;
526-
// FIXME: It can not work with OmniThinker for now.
527-
OPENVINO_ASSERT((position_ids->get_shape().size() >= 2) && (position_ids->get_shape().back() >= 1));
526+
// FIXME: It won't work with Qwen2.5-VL/Omni for now.
527+
OPENVINO_ASSERT((position_ids->get_shape().size() == 2) && (position_ids->get_shape().back() >= 1));
528528
auto position_id = position_ids->data<int64_t>()[0];
529529
auto dirty_num = kvcache_desc.num_stored_tokens - static_cast<uint32_t>(position_id);
530530
if (dirty_num > 0) {
@@ -926,15 +926,16 @@ void ov::npuw::LLMInferRequest::infer() {
926926
// number of logits.
927927
// The outcome of two items is that prefill and generate stages
928928
// can be safely differentiated by start position id for
929-
// both main and draft models.
929+
// both main and draft models for most of LLMs.
930930
if (input_ids->get_shape()[layer_ids::INPUT_IDS_SEQ_LEN_DIM] > 1 &&
931931
position_ids->data<int64_t>()[0] == m_first_position_id) {
932932
infer_prefill(input_ids, attention_mask, position_ids, token_type_ids);
933933
} else {
934-
auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc;
935-
// Need to reconsider the solution. Some model like Qwen2.5VL, doesn't use speculative decoding,
936-
// but it may have repeated position ids, then it will trigger kvcache trim and cause AC issue.
937-
if (kvcache_desc.max_generation_token_len > 1) {
934+
// FIXME: Need to make the solution smarter.
935+
// Qwen2.5VL uses 3D position_ids but current `trim_kvcache_for_speculative_decoding`
936+
// doesn't take this into account and causes accuracy issues.
937+
// Speculative Decode isn't supposed to work with such position_ids currently.
938+
if (position_ids->get_shape().size() < 3) {
938939
trim_kvcache_for_speculative_decoding(position_ids);
939940
}
940941
infer_generate(input_ids, attention_mask, position_ids, token_type_ids);

0 commit comments

Comments
 (0)