Skip to content

Commit 5b789e6

Browse files
committed
Refactor host_flash_attention: match SDPA parameters via _sdpa_param_index_map and drop the unused _kv_cache_size field from serialization and logging.
Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
1 parent e138d4f commit 5b789e6

File tree

6 files changed

+350
-382
lines changed

6 files changed

+350
-382
lines changed

src/plugins/intel_npu/src/plugin/npuw/base_sync_infer_request.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,9 +539,12 @@ void ov::npuw::IBaseInferRequest::bind_global_params(std::size_t idx, RqPtr requ
539539
if (!is_hfa_attention) {
540540
return false; // Early return
541541
}
542+
// Check if sub_in_idx matches any SDPA parameter in the mapping
543+
// HFA parameters: PAST_KEY, PAST_VALUE, QUERY, PRESENT_KEY, PRESENT_VALUE
542544
auto& hfa_attn = proto_comp_model_desc.host_flash_attention.value()._sdpa_attention_info;
543-
return std::any_of(hfa_attn.params.begin(), hfa_attn.params.end(), [&](const auto& p) -> bool {
544-
return p.idx == sub_in_idx;
545+
const auto& param_map = hfa_attn._sdpa_param_index_map;
546+
return std::any_of(param_map.begin(), param_map.end(), [&](const auto& kv) -> bool {
547+
return kv.second == sub_in_idx;
545548
});
546549
};
547550

src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,6 @@ void ov::npuw::CompiledModel::CompiledModelDesc::serialize(std::ostream& stream,
724724
write(stream, host_flash_attention);
725725
if (host_flash_attention.has_value()) {
726726
write(stream, host_flash_attention.value()._tile_size);
727-
write(stream, host_flash_attention.value()._kv_cache_size);
728727

729728
// Serialize compiled tile model
730729
if (host_flash_attention.value()._compiled_tile_model) {
@@ -843,7 +842,6 @@ void ov::npuw::CompiledModel::CompiledModelDesc::deserialize(std::istream& strea
843842
read(stream, host_flash_attention);
844843
if (host_flash_attention.has_value()) {
845844
read(stream, host_flash_attention.value()._tile_size);
846-
read(stream, host_flash_attention.value()._kv_cache_size);
847845

848846
bool has_compiled_model = false;
849847
read(stream, has_compiled_model);
@@ -1794,8 +1792,8 @@ void ov::npuw::CompiledModel::compile_host_flash_attention_model(std::size_t id,
17941792
hfa.set_compiled_tile_model(std::move(compiled_tile_model));
17951793

17961794
LOG_INFO("Successfully compiled host flash attention regular tile model");
1797-
std::cout << "HostFlashAttention tile model compiled on " << device << " (tile_size=" << hfa._tile_size
1798-
<< ", kv_cache_size=" << hfa._kv_cache_size << ")" << std::endl;
1795+
std::cout << "HostFlashAttention tile model compiled on " << device << " (tile_size=" << hfa._tile_size << ")"
1796+
<< std::endl;
17991797
} catch (const std::exception& ex) {
18001798
LOG_ERROR("Failed to compile host flash attention tile model: " << ex.what());
18011799
OPENVINO_THROW("Host flash attention tile model compilation failed: ", ex.what());

0 commit comments

Comments
 (0)