From 1d7d31d8608ea70e88667869d88b427e308e4dd4 Mon Sep 17 00:00:00 2001 From: Irina Efode Date: Tue, 1 Oct 2024 14:09:04 +0400 Subject: [PATCH] Dirty version --- src/cpp/src/continuous_batching_impl.cpp | 272 ++++++++++++------ src/cpp/src/continuous_batching_impl.hpp | 90 +++++- .../continuous_batching_impl_interface.cpp | 44 +++ .../continuous_batching_impl_interface.hpp | 4 +- .../src/paged_attention_transformations.cpp | 31 +- .../src/paged_attention_transformations.hpp | 13 + src/cpp/src/speculative_decoding_impl.cpp | 174 ++++++++++- src/cpp/src/speculative_decoding_impl.hpp | 12 +- 8 files changed, 519 insertions(+), 121 deletions(-) diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index 21152c717..153460f95 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -14,54 +14,36 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl( const Tokenizer& tokenizer, const SchedulerConfig& scheduler_config, const std::string& device, - const ov::AnyMap& plugin_config, - bool is_validation_mode_enabled) { + const ov::AnyMap& plugin_config) { m_tokenizer = tokenizer; - m_is_validation_mode_enabled = is_validation_mode_enabled; - ov::Core core; + ov::Core core; // The model can be compiled for GPU as well std::shared_ptr model = core.read_model(models_path + "/openvino_model.xml"); - DeviceConfig device_config(core, scheduler_config, device, plugin_config); bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction; apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control); + init(model, scheduler_config, plugin_config, device_config, core); +} - ov::InferRequest infer_request = core.compile_model(model, device_config.get_device(), plugin_config).create_infer_request(); - - // setup KV caches - m_cache_manager = std::make_shared(device_config, core); - for (size_t decoder_layer_id = 0; decoder_layer_id < device_config.get_num_layers(); ++decoder_layer_id) { - infer_request.set_tensor(std::string("key_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_key_cache(decoder_layer_id)); - infer_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_value_cache(decoder_layer_id)); - } - - SchedulerConfig updated_config = scheduler_config; - // update KV number in scheduler config - if (scheduler_config.num_kv_blocks != device_config.get_num_kv_blocks()) { - updated_config.num_kv_blocks = device_config.get_num_kv_blocks(); - } - - bool can_use_partial_preemption = true; - if (device_config.get_device().find("GPU") != std::string::npos && !updated_config.dynamic_split_fuse) { - // in case of executing a `vLLM-like` pipeline, it's better not to use partial eviction on the GPU, - // as it may lead to performance slowdown - can_use_partial_preemption = false; - } - - m_scheduler = std::make_shared(updated_config, device_config.get_num_layers(), can_use_partial_preemption); - // and finally create model runner - bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction; - if (is_use_cache_eviction) { - m_model_runner = std::make_shared(infer_request, updated_config, device_config.get_num_layers(), true); - } else { - m_model_runner = std::make_shared(infer_request, updated_config, device_config.get_num_layers()); - } - m_sampler = std::make_shared(m_tokenizer); - m_sampler->set_seed(m_generation_config.rng_seed); 
+ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl( + ov::Core& core, + const std::shared_ptr& model, + const Tokenizer& tokenizer, + const DeviceConfig& device_config, + const SchedulerConfig& scheduler_config, + const std::string& device, + const ov::AnyMap& plugin_config, + bool is_validation_mode_enabled) { + m_is_validation_mode_enabled = is_validation_mode_enabled; + init(model, scheduler_config, plugin_config, device_config, core); +} - // read default generation config +void ContinuousBatchingPipeline::ContinuousBatchingImpl::_pull_awaiting_requests() { + std::lock_guard lock{m_awaiting_requests_mutex}; + m_requests.insert(m_requests.end(), m_awaiting_requests.begin(), m_awaiting_requests.end()); + m_awaiting_requests.clear(); } GenerationHandle @@ -107,11 +89,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() { step_timer.start(); // Pull awaiting requests - { - std::lock_guard lock{m_awaiting_requests_mutex}; - m_requests.insert(m_requests.end(), m_awaiting_requests.begin(), m_awaiting_requests.end()); - m_awaiting_requests.clear(); - } + _pull_awaiting_requests(); m_pipeline_metrics.requests = m_requests.size(); Scheduler::Output scheduler_output; @@ -294,49 +272,49 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector -ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector& prompts, - std::vector sampling_params, - const StreamerVariant& streamer) { - std::vector input_ids; - static ManualTimer timer("tokenize"); - if (m_is_chat_conversation) { - OPENVINO_ASSERT(1 == prompts.size(), "Can't chat with multiple prompts"); - m_history.push_back({{"role", "user"}, {"content", prompts.at(0)}}); - constexpr bool add_generation_prompt = true; - std::string history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); - timer.start(); - input_ids.push_back(m_tokenizer.encode(history).input_ids); - timer.end(); - } else { - input_ids.reserve(prompts.size()); - for (const std::string& prompt : prompts) { - timer.start(); - input_ids.push_back(m_tokenizer.encode(prompt).input_ids); - timer.end(); - } - } - std::vector encoded = generate(input_ids, sampling_params, streamer); - std::vector decoded; - decoded.reserve(encoded.size()); - for (EncodedGenerationResult& res : encoded) { - std::vector generated; - generated.reserve(res.m_generation_ids.size()); - for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) { - generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx))); - if (m_is_chat_conversation && 0 == idx) { - m_history.push_back({{"role", "assistant"}, {"content", generated.back()}}); - } - } - decoded.push_back(GenerationResult{ - res.m_request_id, - std::move(generated), - std::move(res.m_scores), - res.m_status - }); - } - return decoded; -} +// std::vector +// ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector& prompts, +// std::vector sampling_params, +// const StreamerVariant& streamer) { +// std::vector input_ids; +// static ManualTimer timer("tokenize"); +// if (m_is_chat_conversation) { +// OPENVINO_ASSERT(1 == prompts.size(), "Can't chat with multiple prompts"); +// m_history.push_back({{"role", "user"}, {"content", prompts.at(0)}}); +// constexpr bool add_generation_prompt = true; +// std::string history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); +// timer.start(); +// input_ids.push_back(m_tokenizer.encode(history).input_ids); +// timer.end(); +// } else { +// 
input_ids.reserve(prompts.size()); +// for (const std::string& prompt : prompts) { +// timer.start(); +// input_ids.push_back(m_tokenizer.encode(prompt).input_ids); +// timer.end(); +// } +// } +// std::vector encoded = generate(input_ids, sampling_params, streamer); +// std::vector decoded; +// decoded.reserve(encoded.size()); +// for (EncodedGenerationResult& res : encoded) { +// std::vector generated; +// generated.reserve(res.m_generation_ids.size()); +// for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) { +// generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx))); +// if (m_is_chat_conversation && 0 == idx) { +// m_history.push_back({{"role", "assistant"}, {"content", generated.back()}}); +// } +// } +// decoded.push_back(GenerationResult{ +// res.m_request_id, +// std::move(generated), +// std::move(res.m_scores), +// res.m_status +// }); +// } +// return decoded; +// } void ContinuousBatchingPipeline::ContinuousBatchingImpl::_free_non_running_requests() { std::vector::iterator requests_iterator = m_requests.begin(); @@ -413,4 +391,126 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_block seq_group_ptr->register_token_eviction(num_blocks_evicted * sched_config.block_size); } } + +void ContinuousBatchingPipeline::ContinuousBatchingImpl::finish_request(int64_t request_id) { + if (request_id == -1) { + while (!m_requests.empty()) { + const auto& request = *m_requests.rbegin(); + for (const auto& sequence : request->get_sequences()) { + m_scheduler->free_sequence(sequence->get_id()); + } + m_sampler->clear_beam_search_info(request->get_request_id()); + m_requests.pop_back(); + } + } else { + for (size_t i = 0; i < m_requests.size(); ++i) { + auto& request = m_requests[i]; + if (request->get_request_id() != request_id) { + continue; + } + for (const auto& sequence : request->get_sequences()) { + m_scheduler->free_sequence(sequence->get_id()); + } + m_sampler->clear_beam_search_info(request->get_request_id()); + m_requests.erase(m_requests.begin() + i); + break; + } + } +} + +std::vector +ContinuousBatchingPipeline::ContinuousBatchingImpl::get_generated_sequences() { + _pull_awaiting_requests(); + std::vector result; + for (const auto& request : m_requests) { + const auto request_id = request->get_request_id(); + for (const auto& sequence : request->get_sequences()) { + auto generated_ids = sequence->get_generated_ids(); + auto log_probs = sequence->get_generated_log_probs(); + result.emplace_back(request_id, sequence->get_grouped_id(), generated_ids, log_probs); + } + } + return result; +} + +ContinuousBatchingPipeline::ContinuousBatchingImpl::UpdateSeqResult +ContinuousBatchingPipeline::ContinuousBatchingImpl::update_generated_sequence( + const ContinuousBatchingPipeline::ContinuousBatchingImpl::GeneratedSequence& candidate_sequence) { + _pull_awaiting_requests(); + bool is_empty_generated_tokens = false; + for (auto& request : m_requests) { + if (candidate_sequence.request_id == request->get_request_id()) { + bool is_seq_exists = false; + // todo: iefode: multiseq + size_t to_remove_tokens = 0, to_insert_tokens = 0; + for (auto& sequence : request->get_sequences()) { + if (candidate_sequence.sequence_id == sequence->get_grouped_id()) { + is_seq_exists = true; + auto present_ids = sequence->get_generated_ids(); + const auto& candidate_ids = candidate_sequence.token_ids; + + // remove extra tokens from sequence + { + auto token_idx = std::min(present_ids.size(), candidate_ids.size()); + if (token_idx) { + while (token_idx-- > 0) { 
+ if (present_ids[token_idx] == candidate_ids[token_idx]) { + break; + } + } + to_remove_tokens = present_ids.size() - (token_idx + 1); + if (to_remove_tokens > 0) { + const auto gen_ids_before = sequence->get_generated_ids(); + sequence->remove_last_tokens(to_remove_tokens); + present_ids = sequence->get_generated_ids(); + const size_t gen_len_before = gen_ids_before.size(), + gen_len_after = present_ids.size(); + if (gen_len_after == 0) { + is_empty_generated_tokens = true; + } + OPENVINO_ASSERT(gen_len_after < gen_len_before); + for (size_t i = gen_len_after; i < gen_len_before; ++i) { + // todo + // m_sampler->update_logit_processor(request->get_request_id(), gen_ids_before[i]); + } + } + } + } + // insert new tokens to sequence + { + OPENVINO_ASSERT(candidate_ids.size() >= present_ids.size()); + const auto& candidate_log_probs = candidate_sequence.log_probs; + const size_t start_id = std::min(present_ids.size(), candidate_ids.size()), + stop_id = std::max(present_ids.size(), candidate_ids.size()); + to_insert_tokens = stop_id - start_id; + for (size_t i = start_id; i < stop_id; ++i) { + sequence->append_token(candidate_ids[i], i < candidate_log_probs.size() ? candidate_log_probs[i] : 0.f); + } + } + } + break; + } + if (!is_seq_exists) { + Sequence::Ptr new_sequence(new Sequence(candidate_sequence.sequence_id)); + const auto& generated_tokens = candidate_sequence.token_ids; + const auto& generated_log_probs = candidate_sequence.log_probs; + for (size_t i = 0; i < generated_tokens.size(); ++i) { + new_sequence->append_token(generated_tokens[i], generated_log_probs[i]); + } + request->add_sequence(new_sequence); + } + if (!is_empty_generated_tokens) { + if (to_remove_tokens > 0) { + // request->decrease_processed_tokens(to_remove_tokens); + } + // to validate tokens/extend kv-cache before generation + // request->set_validation_len(to_insert_tokens); + } else if (to_remove_tokens > 0) { + request->update_processed_tokens_num(request->get_prompt_len()); + } + return ContinuousBatchingPipeline::ContinuousBatchingImpl::UpdateSeqResult(to_insert_tokens, to_remove_tokens); + } + } + return {0, 0}; +} } diff --git a/src/cpp/src/continuous_batching_impl.hpp b/src/cpp/src/continuous_batching_impl.hpp index d3dccf92c..c488f482e 100644 --- a/src/cpp/src/continuous_batching_impl.hpp +++ b/src/cpp/src/continuous_batching_impl.hpp @@ -40,22 +40,71 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc float _get_current_running_average_cache_usage() const; void maybe_evict_cache_blocks(const SchedulerConfig& sched_config); + + inline void + init(std::shared_ptr model, + const SchedulerConfig& scheduler_config, + const ov::AnyMap& plugin_config, + const DeviceConfig& device_config, + ov::Core& core) { + ov::InferRequest infer_request = core.compile_model(model, device_config.get_device(), plugin_config).create_infer_request(); + + // setup KV caches + m_cache_manager = std::make_shared(device_config, core); + for (size_t decoder_layer_id = 0; decoder_layer_id < device_config.get_num_layers(); ++decoder_layer_id) { + infer_request.set_tensor(std::string("key_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_key_cache(decoder_layer_id)); + infer_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_value_cache(decoder_layer_id)); + } + + SchedulerConfig updated_config = scheduler_config; + // update KV number in scheduler config + if (scheduler_config.num_kv_blocks != device_config.get_num_kv_blocks()) { + 
updated_config.num_kv_blocks = device_config.get_num_kv_blocks(); + } + + bool can_use_partial_preemption = true; + if (device_config.get_device().find("GPU") != std::string::npos && !updated_config.dynamic_split_fuse) { + // in case of executing a `vLLM-like` pipeline, it's better not to use partial eviction on the GPU, + // as it may lead to performance slowdown + can_use_partial_preemption = false; + } + + m_scheduler = std::make_shared(updated_config, device_config.get_num_layers(), can_use_partial_preemption); + // and finally create model runner + bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction; + if (is_use_cache_eviction) { + m_model_runner = std::make_shared(infer_request, updated_config, device_config.get_num_layers(), true); + } else { + m_model_runner = std::make_shared(infer_request, updated_config, device_config.get_num_layers()); + } + m_sampler = std::make_shared(m_tokenizer); + m_sampler->set_seed(m_generation_config.rng_seed); + }; + + void _pull_awaiting_requests(); + public: ContinuousBatchingImpl(const std::string& models_path, const Tokenizer& tokenizer, const SchedulerConfig& scheduler_config, const std::string& device, - const ov::AnyMap& plugin_config, - bool is_validation_mode_enabled = false); + const ov::AnyMap& plugin_config); ContinuousBatchingImpl(const std::string& models_path, const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& llm_plugin_config, - const ov::AnyMap& tokenizer_plugin_config, - bool is_validation_mode_enabled = false) - : ContinuousBatchingImpl{models_path, Tokenizer(models_path, tokenizer_plugin_config), scheduler_config, device, llm_plugin_config, is_validation_mode_enabled} {}; + const ov::AnyMap& tokenizer_plugin_config) + : ContinuousBatchingImpl{models_path, Tokenizer(models_path, tokenizer_plugin_config), scheduler_config, device, llm_plugin_config} {} + ContinuousBatchingImpl(ov::Core& core, + const std::shared_ptr& model, + const Tokenizer& tokenizer, + const DeviceConfig& device_config, + const SchedulerConfig& scheduler_config, + const std::string& device, + const ov::AnyMap& plugin_config, + bool is_validation_mode_enabled = false); GenerationHandle add_request(uint64_t request_id, const ov::Tensor& input_ids, @@ -72,9 +121,32 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc generate(const std::vector& input_ids, const std::vector& sampling_params, const StreamerVariant& streamer) override; - std::vector - generate(const std::vector& prompts, - std::vector sampling_params, - const StreamerVariant& streamer) override; + // std::vector + // generate(const std::vector& prompts, + // std::vector sampling_params, + // const StreamerVariant& streamer) override; + + // for speculative decoding + void finish_request(int64_t request_id = -1); + + struct GeneratedSequence { + uint64_t request_id = 0, sequence_id = 0; + std::vector token_ids; + std::vector log_probs; + + GeneratedSequence(uint64_t req_id, uint64_t seq_id, const std::vector& generated_token_ids, const std::vector& generated_log_probs) : + request_id(req_id), + sequence_id(seq_id), + token_ids(generated_token_ids), + log_probs(generated_log_probs) {}; + }; + + struct UpdateSeqResult { + size_t to_insert, to_remove; + UpdateSeqResult(size_t _to_insert = 0, size_t _to_remove = 0) : to_insert(_to_insert), to_remove(_to_remove) {}; + }; + + std::vector get_generated_sequences(); + UpdateSeqResult update_generated_sequence(const GeneratedSequence& new_sequence); }; } \ No newline at end 
of file diff --git a/src/cpp/src/continuous_batching_impl_interface.cpp b/src/cpp/src/continuous_batching_impl_interface.cpp index 7f7db465f..ecbf01229 100644 --- a/src/cpp/src/continuous_batching_impl_interface.cpp +++ b/src/cpp/src/continuous_batching_impl_interface.cpp @@ -27,4 +27,48 @@ void ContinuousBatchingPipeline::ImplInterface::finish_chat() { m_is_chat_conversation = false; m_history.clear(); }; +std::vector +ContinuousBatchingPipeline::ImplInterface::generate( + const std::vector& prompts, + std::vector sampling_params, + const StreamerVariant& streamer) { + std::vector input_ids; + static ManualTimer timer("tokenize"); + if (m_is_chat_conversation) { + OPENVINO_ASSERT(1 == prompts.size(), "Can't chat with multiple prompts"); + m_history.push_back({{"role", "user"}, {"content", prompts.at(0)}}); + constexpr bool add_generation_prompt = true; + std::string history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); + timer.start(); + input_ids.push_back(m_tokenizer.encode(history).input_ids); + timer.end(); + } else { + input_ids.reserve(prompts.size()); + for (const std::string& prompt : prompts) { + timer.start(); + input_ids.push_back(m_tokenizer.encode(prompt).input_ids); + timer.end(); + } + } + std::vector encoded = generate(input_ids, sampling_params, streamer); + std::vector decoded; + decoded.reserve(encoded.size()); + for (EncodedGenerationResult& res : encoded) { + std::vector generated; + generated.reserve(res.m_generation_ids.size()); + for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) { + generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx))); + if (m_is_chat_conversation && 0 == idx) { + m_history.push_back({{"role", "assistant"}, {"content", generated.back()}}); + } + } + decoded.push_back(GenerationResult{ + res.m_request_id, + std::move(generated), + std::move(res.m_scores), + res.m_status + }); + } + return decoded; +} } \ No newline at end of file diff --git a/src/cpp/src/continuous_batching_impl_interface.hpp b/src/cpp/src/continuous_batching_impl_interface.hpp index a3615b582..eddb07a8b 100644 --- a/src/cpp/src/continuous_batching_impl_interface.hpp +++ b/src/cpp/src/continuous_batching_impl_interface.hpp @@ -58,10 +58,10 @@ class ContinuousBatchingPipeline::ImplInterface { generate(const std::vector& input_ids, const std::vector& sampling_params, const StreamerVariant& streamer) = 0; - virtual std::vector + std::vector generate(const std::vector& prompts, std::vector sampling_params, - const StreamerVariant& streamer) = 0; + const StreamerVariant& streamer); void start_chat(const std::string& system_message); void finish_chat(); diff --git a/src/cpp/src/paged_attention_transformations.cpp b/src/cpp/src/paged_attention_transformations.cpp index 28dda4dea..9851c8119 100644 --- a/src/cpp/src/paged_attention_transformations.cpp +++ b/src/cpp/src/paged_attention_transformations.cpp @@ -14,25 +14,29 @@ inline ov::PartialShape to_partial_with_dyn_0_dim(const ov::Shape& static_shape) return partial_shape; } -/** Applies transformations to the ov::Model to enable paged attention inference. - * @param model Pointer to the ov::Model representing one of the supported LLM architectures. - * @param device_config Configuration struct for inferencing device specifics. - * @param per_layer_cache_control If true, then the transformations will enable per-layer control of KV cache blocks, allowing to specify - * different sets of KV cache blocks for different attention layers. 
If false, then the KV cache block structure will be identical across all - * decoder layers. - */ -void apply_paged_attention_transformations(std::shared_ptr model, DeviceConfig& device_config, bool per_layer_cache_control) { +size_t get_kv_cache_size(const std::shared_ptr model) { + const auto& parameters = model->get_parameters(); + // extract num_kv_heads and head_size + size_t kv_caches_inputs_offset = 2; + ov::PartialShape k_shape = parameters[kv_caches_inputs_offset]->get_partial_shape(); + OPENVINO_ASSERT(k_shape.rank().get_length() == 3, "KV cache shape is expected to have rank 3, while shape is ", k_shape); + size_t num_kv_heads = k_shape[1].get_length(), head_size = k_shape[2].get_length(); + return num_kv_heads * head_size; +} + +void apply_paged_attention_transformations(std::shared_ptr model, bool per_layer_cache_control) { const ov::op::util::VariableVector& variables = model->get_variables(); OPENVINO_ASSERT(!variables.empty(), "Model is supposed to be stateful"); bool use_block_indices_inputs = per_layer_cache_control; bool use_score_outputs = per_layer_cache_control; ov::pass::SDPAToPagedAttention(use_block_indices_inputs, use_score_outputs).run_on_model(model); +} +void set_kv_cache_type_and_shape(std::shared_ptr model, DeviceConfig& device_config) { const ov::ParameterVector& parameters = model->get_parameters(); - std::map> key_cache_params; - std::map> value_cache_params; + std::map> key_cache_params, value_cache_params; for (const auto& param_ptr : parameters) { const auto& name = param_ptr->get_friendly_name(); if (name.find("key_cache.") == 0) { @@ -43,8 +47,8 @@ void apply_paged_attention_transformations(std::shared_ptr model, Dev } } - OPENVINO_ASSERT(key_cache_params.size() == value_cache_params.size()); OPENVINO_ASSERT(key_cache_params.size() > 0); + OPENVINO_ASSERT(key_cache_params.size() == value_cache_params.size()); size_t num_layers = key_cache_params.size(); // extract num_kv_heads and head_size @@ -66,4 +70,9 @@ void apply_paged_attention_transformations(std::shared_ptr model, Dev model->validate_nodes_and_infer_types(); } + +void apply_paged_attention_transformations(std::shared_ptr model, DeviceConfig& device_config, bool per_layer_cache_control) { + apply_paged_attention_transformations(model, per_layer_cache_control); + set_kv_cache_type_and_shape(model, device_config); } +} \ No newline at end of file diff --git a/src/cpp/src/paged_attention_transformations.hpp b/src/cpp/src/paged_attention_transformations.hpp index a7bce2375..3e2e3bc49 100644 --- a/src/cpp/src/paged_attention_transformations.hpp +++ b/src/cpp/src/paged_attention_transformations.hpp @@ -7,5 +7,18 @@ #include "device_config.hpp" namespace ov::genai { +/** Applies transformations to the ov::Model to enable paged attention inference. + * @param model Pointer to the ov::Model representing one of the supported LLM architectures. + * @param device_config Configuration struct for inferencing device specifics. + * @param per_layer_cache_control If true, then the transformations will enable per-layer control of KV cache blocks, allowing to specify + * different sets of KV cache blocks for different attention layers. If false, then the KV cache block structure will be identical across all + * decoder layers. 
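+ *
+ * The declarations below also expose the two stages of this transformation separately:
+ * apply_paged_attention_transformations(model, per_layer_cache_control) runs only the SDPAToPagedAttention
+ * conversion, set_kv_cache_type_and_shape(model, device_config) then fixes the KV cache element type and
+ * shapes once a DeviceConfig is available, and get_kv_cache_size(model) reports num_kv_heads * head_size.
+ * SpeculativeDecodingImpl relies on this split to size the main and draft KV caches before creating the
+ * device configs.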
+ */ void apply_paged_attention_transformations(std::shared_ptr model, DeviceConfig& device_config, bool per_layer_cache_control = false); + +void apply_paged_attention_transformations(std::shared_ptr model, bool per_layer_cache_control = false); + +size_t get_kv_cache_size(const std::shared_ptr model); + +void set_kv_cache_type_and_shape(std::shared_ptr model, DeviceConfig& device_config); } \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding_impl.cpp index 1f7531514..c6980cf52 100644 --- a/src/cpp/src/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding_impl.cpp @@ -1,9 +1,14 @@ // Copyright (C) 2023-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 +#include "text_callback_streamer.hpp" #include "speculative_decoding_impl.hpp" +#include "paged_attention_transformations.hpp" namespace ov::genai { +template struct overloaded : Ts... {using Ts::operator()...;}; +template overloaded(Ts...) -> overloaded; + ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl( const std::string& main_models_path, const std::string& draft_models_path, @@ -11,14 +16,46 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl( const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& plugin_config) { - m_main_pipeline = std::make_shared(main_models_path, tokenizer, scheduler_config, device, plugin_config, true); - m_draft_pipeline = std::make_shared(draft_models_path, tokenizer, scheduler_config, device, plugin_config, false); + ov::Core core; + std::shared_ptr main_model = core.read_model(main_models_path + "/openvino_model.xml"), + draft_model = core.read_model(draft_models_path + "/openvino_model.xml"); + + apply_paged_attention_transformations(main_model, scheduler_config.use_cache_eviction); + apply_paged_attention_transformations(draft_model, scheduler_config.use_cache_eviction); + + ov::genai::SchedulerConfig main_scheduler_config = scheduler_config, + draft_scheduler_config = scheduler_config; + + // split KV cache to 2 caches for main and draft models + size_t main_model_cache_size = get_kv_cache_size(main_model), + draft_model_cache_size = get_kv_cache_size(draft_model); + auto k = static_cast(draft_model_cache_size) / (main_model_cache_size + draft_model_cache_size); + + size_t main_cache_size = scheduler_config.cache_size * (1 - k), + draft_cache_size = scheduler_config.cache_size * k; + if (draft_cache_size == 0) { + main_cache_size -= main_cache_size > 1 ? 
1 : 0; + draft_cache_size = 1; + } + + main_scheduler_config.cache_size = main_cache_size; + draft_scheduler_config.cache_size = draft_cache_size; + + DeviceConfig main_device_config(core, main_scheduler_config, device, plugin_config), + draft_device_config(core, draft_scheduler_config, device, plugin_config); + + set_kv_cache_type_and_shape(main_model, main_device_config); + set_kv_cache_type_and_shape(draft_model, draft_device_config); + + m_main_pipeline = std::make_shared(core, main_model, tokenizer, main_device_config, main_scheduler_config, device, plugin_config, true); + m_draft_pipeline = std::make_shared(core, draft_model, tokenizer, draft_device_config, draft_scheduler_config, device, plugin_config, false); } GenerationHandle ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id, const ov::Tensor& input_ids, ov::genai::GenerationConfig sampling_params) { + m_draft_pipeline->add_request(request_id, input_ids, sampling_params); return m_main_pipeline->add_request(request_id, input_ids, sampling_params); }; @@ -26,6 +63,7 @@ GenerationHandle ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id, const std::string& prompt, ov::genai::GenerationConfig sampling_params) { + m_draft_pipeline->add_request(request_id, prompt, sampling_params); return m_main_pipeline->add_request(request_id, prompt, sampling_params); } @@ -33,21 +71,139 @@ bool ContinuousBatchingPipeline::SpeculativeDecodingImpl::has_non_finished_reque return m_main_pipeline->has_non_finished_requests(); } +void ContinuousBatchingPipeline::SpeculativeDecodingImpl::update_strategy(size_t num_matches) { + // dynamically adjust number of generated candidates based on number of matches + // we want to balance the benefits of getting candidates tokens correct with the + // cost of forecasting incorrect candidates tokens. 
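+    // For example, with m_num_candidates == 5 and m_max_num_candidates == 10: a fully matched window
+    // (num_matches == 5) grows the next speculation window to 7 candidates (capped at 10), while any
+    // mismatch shrinks it by one candidate, never dropping below 1.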
+ if (m_num_candidates == 0) { + return; + } + if (num_matches == m_num_candidates) { + m_num_candidates = std::min(m_num_candidates + 2, m_max_num_candidates); + } else { + m_num_candidates = std::max(int64_t(m_num_candidates) - 1, int64_t(1)); + } +} + void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() { + std::vector candidate_sequences; + // find minimum(candidates_number, seq_len) to generate candidates + size_t min_candidates_number = m_num_candidates; + for (auto& request : m_to_generate_length) { + if (request.second < min_candidates_number && request.second > 0) { + min_candidates_number = request.second; + } + } + // generate candidates by speculative model + for (size_t i = 0; i < min_candidates_number; ++i) { + m_draft_pipeline->step(); + } + + // put candidates to model KV cache + candidate_sequences = m_draft_pipeline->get_generated_sequences(); + for (const auto& candidate : candidate_sequences) { + m_main_pipeline->update_generated_sequence(candidate); + } + + // validate candidates and generate 1 new token + m_main_pipeline->step(); + + auto checked_sequences = m_main_pipeline->get_generated_sequences(); + size_t max_removed_token_cnt = 0; + for (const auto& checked_sequence : checked_sequences) { + auto update_result = m_draft_pipeline->update_generated_sequence(checked_sequence); + max_removed_token_cnt = std::max(max_removed_token_cnt, update_result.to_remove); + } + OPENVINO_ASSERT(m_max_num_candidates >= max_removed_token_cnt); + auto num_matches = m_max_num_candidates - max_removed_token_cnt; + update_strategy(num_matches); + + // update to generate tokens + for (auto& request : m_to_generate_length) { + if (request.second > num_matches) { + request.second -= (num_matches + 1); + } else { + request.second = 0; + m_draft_pipeline->finish_request(request.first); + } + } } std::vector ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector& input_ids, const std::vector& sampling_params, const StreamerVariant& streamer) { - return m_main_pipeline->generate(input_ids, sampling_params, streamer); -} + OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. 
Use ContinuousBatchingPipeline::add_request"); + OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); + const std::shared_ptr& streamer_ptr = std::visit(overloaded{ + [](std::monostate) -> std::shared_ptr { + return nullptr; + }, + [](const std::shared_ptr& streamer) { + return streamer; + }, + [this](const std::function& streamer) -> std::shared_ptr { + return std::make_unique(m_tokenizer, streamer); + } + }, streamer); -std::vector -ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector& prompts, - std::vector sampling_params, - const StreamerVariant& streamer) { - return m_main_pipeline->generate(prompts, sampling_params, streamer); + std::vector main_generations, draft_generations; + for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { + OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); + main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], sampling_params[request_id])); + + auto draft_sampling_params = sampling_params[request_id]; + draft_sampling_params.max_new_tokens = SIZE_MAX; + draft_sampling_params.min_new_tokens = 0; + draft_generations.push_back(m_draft_pipeline->add_request(request_id, input_ids[request_id], draft_sampling_params)); + } + + std::vector results; + results.reserve(input_ids.size()); + + bool continue_generation = true; + while (has_non_finished_requests() && continue_generation) { + step(); + if (streamer_ptr) { + std::unordered_map token = main_generations.at(0).get()->back(); + OPENVINO_ASSERT(1 == token.size()); + OPENVINO_ASSERT(1 == token.begin()->second.generated_ids.size()); + continue_generation = !streamer_ptr->put(token.begin()->second.generated_ids.at(0)); + } + } + if (streamer_ptr) { + streamer_ptr->end(); + } + draft_generations.clear(); + + for (size_t generation_idx = 0; generation_idx < main_generations.size(); ++generation_idx) { + const auto& generation = main_generations[generation_idx]; + EncodedGenerationResult result; + result.m_request_id = 1; + std::vector generation_outputs = generation->read_all(); + std::sort(generation_outputs.begin(), generation_outputs.end(), [=] (GenerationOutput& r1, GenerationOutput& r2) { + return r1.score > r2.score; + }); + + auto num_outputs = std::min(sampling_params[generation_idx].num_return_sequences, generation_outputs.size()); + for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) { + const auto& generation_output = generation_outputs[generation_output_idx]; + result.m_generation_ids.push_back(std::move(generation_output.generated_ids)); + result.m_scores.push_back(generation_output.score); + } + result.m_status = generation->get_status(); + results.push_back(std::move(result)); + } + + OPENVINO_ASSERT(results.size() == input_ids.size()); + return results; } +// std::vector +// ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector& prompts, +// std::vector sampling_params, +// const StreamerVariant& streamer) { +// return m_main_pipeline->generate(prompts, sampling_params, streamer); +// } + } diff --git a/src/cpp/src/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding_impl.hpp index 9dc15d2f3..d7f1c0d53 100644 --- a/src/cpp/src/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding_impl.hpp @@ -10,6 +10,10 @@ namespace ov::genai { class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBatchingPipeline::ImplInterface { protected: 
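+    // Speculative decoding state: m_num_candidates is the speculation window requested from the draft
+    // pipeline on each step, m_max_num_candidates is its upper bound, and m_to_generate_length tracks the
+    // remaining number of tokens to generate per request id so finished draft requests can be released early.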
std::shared_ptr m_main_pipeline, m_draft_pipeline; + size_t m_num_candidates = 5, m_max_num_candidates = 10; + std::map m_to_generate_length; + + void update_strategy(size_t num_matches); public: SpeculativeDecodingImpl(const std::string& main_models_path, @@ -43,9 +47,9 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat generate(const std::vector& input_ids, const std::vector& sampling_params, const StreamerVariant& streamer) override; - std::vector - generate(const std::vector& prompts, - std::vector sampling_params, - const StreamerVariant& streamer) override; + // std::vector + // generate(const std::vector& prompts, + // std::vector sampling_params, + // const StreamerVariant& streamer) override; }; } \ No newline at end of file