src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp (15 changes: 8 additions & 7 deletions)
```diff
@@ -523,8 +523,8 @@ void ov::npuw::LLMInferRequest::update_kvcache_for(
 
 void ov::npuw::LLMInferRequest::trim_kvcache_for_speculative_decoding(ov::SoPtr<ov::ITensor> position_ids) {
     auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc;
-    // FIXME: It can not work with OmniThinker for now.
-    OPENVINO_ASSERT((position_ids->get_shape().size() >= 2) && (position_ids->get_shape().back() >= 1));
+    // FIXME: It won't work with Qwen2.5-VL/Omni for now.
+    OPENVINO_ASSERT((position_ids->get_shape().size() == 2) && (position_ids->get_shape().back() >= 1));
     auto position_id = position_ids->data<int64_t>()[0];
     auto dirty_num = kvcache_desc.num_stored_tokens - static_cast<uint32_t>(position_id);
     if (dirty_num > 0) {
```
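
For reference, the arithmetic this hunk tightens can be sketched in isolation. The `KVCacheDesc` struct and `trim_kvcache` helper below are hypothetical stand-ins for the plugin's internals, assuming only what the diff shows (`num_stored_tokens` and the request's first position id):

```cpp
#include <cstdint>

// Hypothetical stand-in for the plugin's KV-cache descriptor.
struct KVCacheDesc {
    uint32_t num_stored_tokens = 0;  // tokens currently materialized in the KV cache
};

// After speculative decoding rejects draft tokens, the next request restarts
// at an earlier position id, so every cached entry from that position onward
// is stale ("dirty") and must be logically discarded.
void trim_kvcache(KVCacheDesc& desc, int64_t first_position_id) {
    const auto pos = static_cast<uint32_t>(first_position_id);
    if (desc.num_stored_tokens > pos) {  // guard against unsigned wrap-around
        const uint32_t dirty_num = desc.num_stored_tokens - pos;
        desc.num_stored_tokens -= dirty_num;  // i.e. num_stored_tokens = pos
    }
}
```

The tightened assertion (`== 2` rather than `>= 2`) rejects higher-rank layouts up front: reading element `[0]` as the sequence start is only meaningful for 2-D `[batch, seq_len]` position ids, per the FIXME about Qwen2.5-VL/Omni.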
```diff
@@ -926,15 +926,16 @@ void ov::npuw::LLMInferRequest::infer() {
     // number of logits.
     // The outcome of two items is that prefill and generate stages
     // can be safely differentiated by start position id for
-    // both main and draft models.
+    // both main and draft models for most of LLMs.
     if (input_ids->get_shape()[layer_ids::INPUT_IDS_SEQ_LEN_DIM] > 1 &&
         position_ids->data<int64_t>()[0] == m_first_position_id) {
         infer_prefill(input_ids, attention_mask, position_ids, token_type_ids);
     } else {
-        auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc;
-        // Need to reconsider the solution. Some model like Qwen2.5VL, doesn't use speculative decoding,
-        // but it may have repeated position ids, then it will trigger kvcache trim and cause AC issue.
-        if (kvcache_desc.max_generation_token_len > 1) {
+        // FIXME: Need to make the solution smarter.
+        // Qwen2.5VL uses 3D position_ids but current `trim_kvcache_for_speculative_decoding`
+        // doesn't take this into account and causes accuracy issues.
+        // Speculative Decode isn't supposed to work with such position_ids currently.
+        if (position_ids->get_shape().size() < 3) {
             trim_kvcache_for_speculative_decoding(position_ids);
         }
         infer_generate(input_ids, attention_mask, position_ids, token_type_ids);
```
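
A condensed, standalone restatement of the dispatch rule and the new guard (the free functions below are illustrative sketches, not the plugin's interface):

```cpp
#include <cstddef>
#include <cstdint>

enum class Stage { Prefill, Generate };

// A multi-token request that starts at the model's first position id is a
// prefill; any other request is a generate step.
Stage select_stage(std::size_t input_token_count,
                   int64_t start_position_id,
                   int64_t first_position_id) {
    return (input_token_count > 1 && start_position_id == first_position_id)
               ? Stage::Prefill
               : Stage::Generate;
}

// The trim assumes a 2-D [batch, seq_len] position_ids layout, so 3-D
// layouts (e.g. Qwen2.5-VL's multimodal RoPE) must bypass it.
bool may_trim_kvcache(std::size_t position_ids_rank) {
    return position_ids_rank < 3;
}
```

Dropping the old `kvcache_desc.max_generation_token_len > 1` gate is presumably safe for the 2-D path: a regular decode step starts exactly at `num_stored_tokens`, so `dirty_num` is zero and the trim is a no-op there.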