Commit

Update config by model_desc
iefode committed Oct 9, 2024
1 parent 93143e5 commit 9724723
Showing 5 changed files with 68 additions and 39 deletions.
23 changes: 23 additions & 0 deletions src/cpp/include/openvino/genai/generation_config.hpp
@@ -10,6 +10,7 @@
#include "openvino/runtime/compiled_model.hpp"
#include "openvino/runtime/infer_request.hpp"
#include "openvino/genai/tokenizer.hpp"
#include "openvino/genai/scheduler_config.hpp"
#include "lora_adapter.hpp"

namespace ov {
@@ -149,6 +150,26 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig {
void validate() const;
};

/**
* @brief ModelDesc serves to activate the speculative decoding model in the continuous batching pipeline.
* It is used to create SpeculativeDecodingImpl and should be filled with suitable values.
*/
struct ModelDesc {
std::string model_path;
std::string device;
ov::genai::SchedulerConfig scheduler_config;
ov::AnyMap plugin_config;

ModelDesc(const std::string& model_path,
const std::string& device = "",
const ov::genai::SchedulerConfig& scheduler_config = {},
const ov::AnyMap& plugin_config = {}) :
model_path(model_path),
device(device),
scheduler_config(scheduler_config),
plugin_config(plugin_config) {}
};

/*
* utils that allow to use generate and operator() in the following way:
* pipe.generate(input_ids, ov::genai::max_new_tokens(200), ov::genai::temperature(1.0f),...)
@@ -184,6 +205,8 @@ static constexpr ov::Property<NumAssistatantTokensScheduleType> num_assistant_to
static constexpr ov::Property<float> assistant_confidence_threshold{"assistant_confidence_threshold"};
static constexpr ov::Property<size_t> num_assistant_tokens{"num_assistant_tokens"};

static constexpr ov::Property<ModelDesc> draft_model{"draft_model"};

// Predefined Configs
OPENVINO_GENAI_EXPORTS GenerationConfig beam_search();
OPENVINO_GENAI_EXPORTS GenerationConfig greedy();
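For context, here is a minimal sketch of how the new `ModelDesc` struct and `draft_model` property could be wired together when constructing a continuous batching pipeline. The model paths, devices, cache size, and the exact `ContinuousBatchingPipeline` constructor arguments are assumptions for illustration, not part of this diff:

```cpp
#include "openvino/genai/continuous_batching_pipeline.hpp"
#include "openvino/genai/generation_config.hpp"

int main() {
    ov::genai::SchedulerConfig scheduler_config;
    scheduler_config.cache_size = 8;  // total KV-cache budget shared by main and draft models

    // Describe the draft model. Device, scheduler config and plugin config are optional;
    // when left empty they fall back to the main model's settings (see the impl changes below).
    ov::genai::ModelDesc draft_desc("/path/to/draft_model", "CPU");

    // The draft model reaches the pipeline through the plugin config
    // under the new draft_model property.
    ov::AnyMap plugin_config{ ov::genai::draft_model(draft_desc) };

    ov::genai::ContinuousBatchingPipeline pipe("/path/to/main_model",
                                               scheduler_config,
                                               "CPU",
                                               plugin_config);
    // pipe.generate(...) as usual; generation now runs with speculative decoding.
    return 0;
}
```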
7 changes: 0 additions & 7 deletions src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -30,13 +30,6 @@ using StringInputs = std::variant<std::string, std::vector<std::string>>;
*/
static constexpr ov::Property<SchedulerConfig> scheduler_config{"scheduler_config"};

/**
* @brief draft_model property serves to activate speculative decoding model in continuous batching pipeline.
* Create SchedulerConfig and fill it with sutable values. Copy or move it to plugin_config.
* And create LLMPipeline instance with this config.
*/
static constexpr ov::Property<SchedulerConfig> draft_model{"draft_model"};

/**
* @brief Structure to store resulting batched tokens and scores for each batch sequence.
* The first num_return_sequences elements correspond to the first batch element.
4 changes: 2 additions & 2 deletions src/cpp/src/continuous_batching_pipeline.cpp
@@ -37,7 +37,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::string& model
if (draft_model.empty()) {
m_impl = std::make_shared<ContinuousBatchingImpl>(models_path, scheduler_config, device, llm_plugin_config, tokenizer_plugin_config);
} else {
m_impl = std::make_shared<SpeculativeDecodingImpl>(models_path, draft_model, scheduler_config, device, llm_plugin_config_without_draft_model, tokenizer_plugin_config);
m_impl = std::make_shared<SpeculativeDecodingImpl>(models_path, scheduler_config, device, llm_plugin_config_without_draft_model, draft_model, tokenizer_plugin_config);
}
}

@@ -52,7 +52,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline(
if (draft_model.empty()) {
m_impl = std::make_shared<ContinuousBatchingImpl>(model_path, tokenizer, scheduler_config, device, plugin_config);
} else {
m_impl = std::make_shared<SpeculativeDecodingImpl>(model_path, draft_model, scheduler_config, device, plugin_config_without_draft_model);
m_impl = std::make_shared<SpeculativeDecodingImpl>(model_path, scheduler_config, device, plugin_config_without_draft_model, draft_model);
}
}

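The dispatch above forwards `llm_plugin_config_without_draft_model`, i.e. the plugin config with the draft model already stripped out. Below is a hypothetical helper showing how such an extraction from an `ov::AnyMap` might look; the function name and the empty-path convention are assumptions, not code from this commit:

```cpp
#include "openvino/genai/generation_config.hpp"

// Hypothetical helper: pull the draft_model property out of the plugin config so that
// only main-model properties are forwarded to the main pipeline.
static ov::genai::ModelDesc extract_draft_model(ov::AnyMap& plugin_config) {
    ov::genai::ModelDesc draft_model("");  // empty model_path = no draft model requested
    auto it = plugin_config.find("draft_model");
    if (it != plugin_config.end()) {
        draft_model = it->second.as<ov::genai::ModelDesc>();
        plugin_config.erase(it);  // leave only main-model properties behind
    }
    return draft_model;
}
```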
71 changes: 42 additions & 29 deletions src/cpp/src/speculative_decoding_impl.cpp
@@ -11,51 +11,65 @@ template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(
const std::string& main_models_path,
const std::string& draft_models_path,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config,
const SchedulerConfig& main_scheduler_config,
const std::string& main_device,
const ov::AnyMap& main_plugin_config,
const ov::genai::ModelDesc draft_model_desc,
const ov::AnyMap& tokenizer_plugin_config) {
ov::Core core;
std::shared_ptr<ov::Model> main_model = core.read_model(main_models_path + "/openvino_model.xml"),
draft_model = core.read_model(draft_models_path + "/openvino_model.xml");
std::string openvino_model_name = "/openvino_model.xml",
draft_model_path = draft_model_desc.model_path;

apply_paged_attention_transformations(main_model, scheduler_config.use_cache_eviction);
apply_paged_attention_transformations(draft_model, scheduler_config.use_cache_eviction);
std::shared_ptr<ov::Model> main_model = core.read_model(main_models_path + openvino_model_name),
draft_model = core.read_model(draft_model_path + openvino_model_name);

ov::genai::SchedulerConfig main_scheduler_config = scheduler_config,
draft_scheduler_config = scheduler_config;

// split KV cache to 2 caches for main and draft models
size_t main_model_cache_size = get_kv_cache_size(main_model),
draft_model_cache_size = get_kv_cache_size(draft_model);
auto k = static_cast<float>(draft_model_cache_size) / (main_model_cache_size + draft_model_cache_size);

size_t main_cache_size = scheduler_config.cache_size * (1 - k),
draft_cache_size = scheduler_config.cache_size * k;
if (draft_cache_size == 0) {
main_cache_size -= main_cache_size > 1 ? 1 : 0;
draft_cache_size = 1;
apply_paged_attention_transformations(main_model, main_scheduler_config.use_cache_eviction);
apply_paged_attention_transformations(draft_model, main_scheduler_config.use_cache_eviction);

std::string draft_device = draft_model_desc.device;
bool is_draft_device_undefined = false;
if (draft_device.empty()) {
draft_device = main_device;
is_draft_device_undefined = true;
}

main_scheduler_config.cache_size = main_cache_size;
draft_scheduler_config.cache_size = draft_cache_size;
ov::genai::SchedulerConfig main_scheduler_config_updated = main_scheduler_config,
draft_scheduler_config = is_draft_device_undefined ? main_scheduler_config : draft_model_desc.scheduler_config;
if (is_draft_device_undefined) {
// split KV cache to 2 caches for main and draft models
size_t main_model_cache_size = get_kv_cache_size(main_model),
draft_model_cache_size = get_kv_cache_size(draft_model);
auto k = static_cast<float>(draft_model_cache_size) / (main_model_cache_size + draft_model_cache_size);

size_t main_cache_size = main_scheduler_config.cache_size * (1 - k),
draft_cache_size = main_scheduler_config.cache_size * k;
if (draft_cache_size == 0) {
main_cache_size -= main_cache_size > 1 ? 1 : 0;
draft_cache_size = 1;
}

DeviceConfig main_device_config(core, main_scheduler_config, device, plugin_config),
draft_device_config(core, draft_scheduler_config, device, plugin_config);
main_scheduler_config_updated.cache_size = main_cache_size;
draft_scheduler_config.cache_size = draft_cache_size;
}

ov::AnyMap draft_plugin_config = is_draft_device_undefined ? main_plugin_config : draft_model_desc.plugin_config;

DeviceConfig main_device_config(core, main_scheduler_config, main_device, main_plugin_config),
draft_device_config(core, draft_scheduler_config, draft_device, draft_plugin_config);

set_kv_cache_type_and_shape(main_model, main_device_config);
set_kv_cache_type_and_shape(draft_model, draft_device_config);

// main and draft model can have different tokenizers
// to do: support retokenization: 154103
Tokenizer main_model_tokenizer(main_models_path, tokenizer_plugin_config),
draft_model_tokenizer(draft_models_path, tokenizer_plugin_config);
draft_model_tokenizer(draft_model_path, tokenizer_plugin_config);

m_tokenizer = main_model_tokenizer;

m_main_pipeline = std::make_shared<ContinuousBatchingImpl>(core, main_model, main_model_tokenizer, main_device_config, main_scheduler_config, device, plugin_config, true);
m_draft_pipeline = std::make_shared<ContinuousBatchingImpl>(core, draft_model, draft_model_tokenizer, draft_device_config, draft_scheduler_config, device, plugin_config, false);
// create `main_pipeline` with validation mode enabled and `draft_pipeline` with validation mode disabled
m_main_pipeline = std::make_shared<ContinuousBatchingImpl>(core, main_model, main_model_tokenizer, main_device_config, main_scheduler_config, main_device, main_plugin_config, true);
m_draft_pipeline = std::make_shared<ContinuousBatchingImpl>(core, draft_model, draft_model_tokenizer, draft_device_config, draft_scheduler_config, draft_device, draft_plugin_config, false);
}

GenerationHandle
@@ -82,7 +96,6 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() {
// generate candidates by draft model
m_draft_pipeline->step();


// to generate num_matches statistic
std::map<int64_t, ContinuousBatchingPipeline::ContinuousBatchingImpl::UpdateSeqResult> update_sequence_info;
// put candidates to model KV cache
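A worked example of the KV-cache split above, as a standalone sketch. The per-model cache sizes (normally returned by get_kv_cache_size) and the 8 GB budget are made-up numbers chosen so the arithmetic is exact:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // Hypothetical per-token KV-cache footprints of the two models.
    std::size_t main_model_cache_size = 3072;
    std::size_t draft_model_cache_size = 1024;
    std::size_t total_cache_size = 8;  // scheduler_config.cache_size, in GB

    // Same proportional split as in SpeculativeDecodingImpl:
    // k = draft / (main + draft) = 1024 / 4096 = 0.25
    float k = static_cast<float>(draft_model_cache_size) /
              (main_model_cache_size + draft_model_cache_size);

    std::size_t main_cache_size = total_cache_size * (1 - k);  // 8 * 0.75 = 6
    std::size_t draft_cache_size = total_cache_size * k;       // 8 * 0.25 = 2
    if (draft_cache_size == 0) {  // the draft model always gets at least 1
        main_cache_size -= main_cache_size > 1 ? 1 : 0;
        draft_cache_size = 1;
    }

    std::printf("main cache: %zu GB, draft cache: %zu GB\n", main_cache_size, draft_cache_size);
    return 0;
}
```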
2 changes: 1 addition & 1 deletion src/cpp/src/speculative_decoding_impl.hpp
@@ -56,10 +56,10 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat

public:
SpeculativeDecodingImpl(const std::string& main_models_path,
const std::string& draft_models_path,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config,
const ov::genai::ModelDesc draft_model_desc,
const ov::AnyMap& tokenizer_config = {});

GenerationHandle add_request(uint64_t request_id,
