Add infer request queue for tokenizers and allow for optional plugin_config in tokenizer (#651)

This improves the performance of the continuous batching (CB) library when tested within OVMS (OpenVINO Model Server).
dkalinowski committed Jul 24, 2024
1 parent 56eeafc commit 98c2e1c
Showing 9 changed files with 179 additions and 51 deletions.
@@ -30,7 +30,8 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
ContinuousBatchingPipeline(const std::string& models_path,
const SchedulerConfig& scheduler_config,
const std::string& device = "CPU",
- const ov::AnyMap& plugin_config = {});
+ const ov::AnyMap& llm_plugin_config = {},
+ const ov::AnyMap& tokenizer_plugin_config = {});

/**
* @brief Constructs a ContinuousBatchingPipeline when ov::genai::Tokenizer is initialized manually using file from the different dirs.
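For illustration, a minimal sketch (not part of this commit) of how a caller could use the extended constructor to pass separate plugin configs to the LLM and to the tokenizer. The include path is assumed to be the library's public header, the model directory and the property values are placeholders, and SchedulerConfig is used with its defaults:

#include "openvino/genai/continuous_batching_pipeline.hpp"

int main() {
    ov::genai::SchedulerConfig scheduler_config;  // library defaults
    // Placeholder property values; shown only to illustrate the two separate maps.
    ov::AnyMap llm_plugin_config = { {"PERFORMANCE_HINT", "THROUGHPUT"} };
    ov::AnyMap tokenizer_plugin_config = { {"PERFORMANCE_HINT", "THROUGHPUT"} };

    ov::genai::ContinuousBatchingPipeline pipe(
        "/path/to/model_dir",  // placeholder: directory with the model and tokenizer IRs
        scheduler_config,
        "CPU",
        llm_plugin_config,
        tokenizer_plugin_config);
    return 0;
}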
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/tokenizer.hpp
@@ -29,7 +29,7 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {
* @brief ov::genai::Tokenizer constructor.
* @param tokenizer_path openvino_tokenizer.xml and openvino_detokenizer.xml should be located in the tokenizer_path
*/
- Tokenizer(const std::string& tokenizer_path);
+ Tokenizer(const std::string& tokenizer_path, const ov::AnyMap& plugin_config = {});

/**
* @brief encode a single prompt
100 changes: 100 additions & 0 deletions src/cpp/src/circular_buffer_queue.hpp
@@ -0,0 +1,100 @@
#pragma once

#include <queue>
#include <mutex>
#include <future>
#include <algorithm>
#include <atomic>

namespace ov::genai {

// From OVMS:
// https://github.com/openvinotoolkit/model_server/blob/d73e85cbb8ac1d761754cb2064a00551a9ffc655/src/queue.hpp#L34
template <typename T>
class CircularBufferQueue
{
    int m_front_idx;
    std::atomic<int> m_back_idx;
    std::vector<int> m_values;
    std::queue<std::promise<int>> m_promises;
    std::vector<T> m_data;
    std::mutex m_front_mut;
    std::mutex m_queue_mutex;

public:

    CircularBufferQueue(size_t length, const std::function<T()>& create_fn) :
        m_values(length),
        m_front_idx{0},
        m_back_idx{0} {
        std::iota(m_values.begin(), m_values.end(), 0);
        m_data.reserve(length);
        for (size_t i = 0; i < length; i++) {
            m_data.emplace_back(std::move(create_fn()));
        }
    }

    CircularBufferQueue(const CircularBufferQueue&) = delete;
    CircularBufferQueue(const CircularBufferQueue&&) = delete;
    CircularBufferQueue& operator=(const CircularBufferQueue&) = delete;

    T& get(int value) {
        return m_data[value];
    }

    std::future<int> get_idle() {
        int value;
        std::promise<int> idle_promise;
        std::future<int> idle_future = idle_promise.get_future();
        std::unique_lock<std::mutex> lk(m_front_mut);
        if (m_values[m_front_idx] < 0) {
            std::unique_lock<std::mutex> queueLock(m_queue_mutex);
            m_promises.push(std::move(idle_promise));
        } else {
            value = m_values[m_front_idx];
            m_values[m_front_idx] = -1;
            m_front_idx = (m_front_idx + 1) % m_values.size();
            lk.unlock();
            idle_promise.set_value(value);
        }
        return idle_future;
    }

    void return_to(int value) {
        std::unique_lock<std::mutex> lk(m_queue_mutex);
        if (m_promises.size()) {
            std::promise<int> promise = std::move(m_promises.front());
            m_promises.pop();
            lk.unlock();
            promise.set_value(value);
            return;
        }
        int old_back = m_back_idx.load();
        while (!m_back_idx.compare_exchange_weak(
            old_back,
            (old_back + 1) % m_values.size(),
            std::memory_order_relaxed)) {
        }
        m_values[old_back] = value;
    }
};

template <typename T>
class CircularBufferQueueElementGuard {
    CircularBufferQueue<T>* m_queue;
    int m_value;
public:
    CircularBufferQueueElementGuard(CircularBufferQueue<T>* queue) : m_queue(queue) {
        m_value = m_queue->get_idle().get(); // blocking until we get the element
    }

    T& get() {
        return m_queue->get(m_value);
    }

    ~CircularBufferQueueElementGuard() {
        m_queue->return_to(m_value);
    }
};

}
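A minimal usage sketch (not part of the commit) of the queue and its RAII guard: the pool is sized from ov::optimal_number_of_infer_requests, mirroring how tokenizer.cpp uses it below. The PooledModel wrapper is a hypothetical name introduced only for this example, and the queue header is internal to the library sources:

#include <openvino/openvino.hpp>
#include <utility>
#include "circular_buffer_queue.hpp"  // internal header added by this commit

struct PooledModel {
    ov::CompiledModel compiled_model;
    ov::genai::CircularBufferQueue<ov::InferRequest> queue;

    explicit PooledModel(ov::CompiledModel model)
        : compiled_model(std::move(model)),
          queue(compiled_model.get_property(ov::optimal_number_of_infer_requests),
                [this]() { return compiled_model.create_infer_request(); }) {}

    void run_one_inference() {
        // Blocks until a pooled request is idle; the guard returns it on destruction.
        ov::genai::CircularBufferQueueElementGuard<ov::InferRequest> guard(&queue);
        ov::InferRequest& request = guard.get();  // exclusive while the guard is alive
        request.start_async();
        request.wait();
    }
};

When all elements are in use, get_idle() parks the caller behind a promise that the next return_to() fulfills, so waiting callers are served in FIFO order rather than spinning.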
9 changes: 5 additions & 4 deletions src/cpp/src/continuous_batching_pipeline.cpp
@@ -105,8 +105,8 @@ class ContinuousBatchingPipeline::Impl {
// read default generation config
}

- Impl(const std::string& models_path, const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& plugin_config)
- : Impl{models_path, Tokenizer(models_path), scheduler_config, device, plugin_config} {}
+ Impl(const std::string& models_path, const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& llm_plugin_config, const ov::AnyMap& tokenizer_plugin_config)
+ : Impl{models_path, Tokenizer(models_path, tokenizer_plugin_config), scheduler_config, device, llm_plugin_config} {}

ov::genai::GenerationConfig get_config() const {
return m_generation_config;
@@ -282,8 +282,9 @@ class ContinuousBatchingPipeline::Impl {
ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::string& models_path,
const SchedulerConfig& scheduler_config,
const std::string& device,
- const ov::AnyMap& plugin_config ) {
- m_impl = std::make_shared<Impl>(models_path, scheduler_config, device, plugin_config);
+ const ov::AnyMap& llm_plugin_config,
+ const ov::AnyMap& tokenizer_plugin_config) {
+ m_impl = std::make_shared<Impl>(models_path, scheduler_config, device, llm_plugin_config, tokenizer_plugin_config);
}

ContinuousBatchingPipeline::ContinuousBatchingPipeline(
98 changes: 62 additions & 36 deletions src/cpp/src/tokenizer.cpp
@@ -7,7 +7,9 @@
#include <jinja2cpp/template.h>
#include <jinja2cpp/template_env.h>
#include "tokenizers_path.hpp"
#include "circular_buffer_queue.hpp"
#include <fstream>
#include <memory>

namespace {

@@ -55,10 +57,12 @@ namespace genai {

class Tokenizer::TokenizerImpl {
public:
- ov::InferRequest m_tokenizer_request;
- ov::InferRequest m_detokenizer_request;
- std::mutex m_tokenizer_mutex;
- std::mutex m_detokenizer_mutex;
+ ov::CompiledModel m_tokenizer;
+ ov::CompiledModel m_detokenizer;
+ std::unique_ptr<CircularBufferQueue<ov::InferRequest>> m_ireq_queue_tokenizer;
+ std::unique_ptr<CircularBufferQueue<ov::InferRequest>> m_ireq_queue_detokenizer;

int64_t m_pad_token_id = -1;
int64_t m_bos_token_id = -1;
int64_t m_eos_token_id = -1;
@@ -71,7 +75,7 @@ class Tokenizer::TokenizerImpl {

TokenizerImpl() = default;

- TokenizerImpl(std::filesystem::path tokenizer_path)
+ TokenizerImpl(std::filesystem::path tokenizer_path, const ov::AnyMap& plugin_config)
: m_chat_template{chat_template_from_tokenizer_json_if_exists(tokenizer_path)} {
ov::Core core;

Expand All @@ -92,10 +96,23 @@ class Tokenizer::TokenizerImpl {
read_tokenizer_config_if_necessary(tokenizer_path);

auto device = "CPU"; // currently openvino_tokenizer supports only CPU
- m_tokenizer_request = core.compile_model(tokenizer_path / "openvino_tokenizer.xml",
- device).create_infer_request();
- m_detokenizer_request = core.compile_model(tokenizer_path / "openvino_detokenizer.xml",
- device).create_infer_request();
+ m_tokenizer = core.compile_model(tokenizer_path / "openvino_tokenizer.xml",
+ device, plugin_config);
+ m_detokenizer = core.compile_model(tokenizer_path / "openvino_detokenizer.xml",
+ device, plugin_config);

+ const size_t INFER_REQUEST_QUEUE_SIZE = m_tokenizer.get_property(ov::optimal_number_of_infer_requests);
+ m_ireq_queue_tokenizer = std::make_unique<CircularBufferQueue<ov::InferRequest>>(
+ INFER_REQUEST_QUEUE_SIZE,
+ [this]() -> ov::InferRequest {
+ return std::move(this->m_tokenizer.create_infer_request());
+ });
+ m_ireq_queue_detokenizer = std::make_unique<CircularBufferQueue<ov::InferRequest>>(
+ INFER_REQUEST_QUEUE_SIZE,
+ [this]() -> ov::InferRequest {
+ return std::move(this->m_detokenizer.create_infer_request());
+ });

// Get special token ids by inference if they are not defined.
infer_special_tokens_if_necessary();
@@ -231,29 +248,35 @@ class Tokenizer::TokenizerImpl {
}

TokenizedInputs encode(std::string prompt) {
+ CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_tokenizer.get());
size_t batch_size = 1;
- std::unique_lock<std::mutex> lock(m_tokenizer_mutex);
- m_tokenizer_request.set_input_tensor(ov::Tensor{ov::element::string, {batch_size}, &prompt});
- m_tokenizer_request.infer();
- return get_copied_results();
+ infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::string, {batch_size}, &prompt});
+ infer_request_guard.get().start_async();
+ infer_request_guard.get().wait();
+ return get_copied_results(
+ infer_request_guard.get().get_tensor("input_ids"),
+ infer_request_guard.get().get_tensor("attention_mask")
+ );
}

TokenizedInputs encode(std::vector<std::string>& prompts) {
TokenizedInputs unpadded;
{
- std::unique_lock<std::mutex> lock(m_tokenizer_mutex);
- m_tokenizer_request.set_input_tensor(ov::Tensor{ov::element::string, {prompts.size()}, prompts.data()});
- auto size_ = m_tokenizer_request.get_input_tensor().get_shape();
- m_tokenizer_request.infer();
- unpadded = get_copied_results();
+ CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_tokenizer.get());
+ infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::string, {prompts.size()}, prompts.data()});
+ auto size_ = infer_request_guard.get().get_input_tensor().get_shape();
+ infer_request_guard.get().start_async();
+ infer_request_guard.get().wait();
+ unpadded = get_copied_results(
+ infer_request_guard.get().get_tensor("input_ids"),
+ infer_request_guard.get().get_tensor("attention_mask")
+ );
}
return pad_left(unpadded.input_ids, unpadded.attention_mask);
}

- TokenizedInputs get_copied_results() {
- auto input_ids = m_tokenizer_request.get_tensor("input_ids");
- auto attention_mask = m_tokenizer_request.get_tensor("attention_mask");
+ TokenizedInputs get_copied_results(ov::Tensor input_ids, ov::Tensor attention_mask) {
ov::Tensor input_ids_ = ov::Tensor(input_ids.get_element_type(), input_ids.get_shape());
ov::Tensor attention_mask_ = ov::Tensor(attention_mask.get_element_type(), attention_mask.get_shape());
input_ids.copy_to(input_ids_);
@@ -263,22 +286,24 @@
}

std::string decode(std::vector<int64_t> tokens) {
+ CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_detokenizer.get());
size_t batch_size = 1;
- std::unique_lock<std::mutex> lock(m_detokenizer_mutex);
- m_detokenizer_request.set_input_tensor(ov::Tensor{ov::element::i64, {batch_size, tokens.size()}, tokens.data()});
- m_detokenizer_request.infer();
- return m_detokenizer_request.get_output_tensor().data<std::string>()[0];
+ infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::i64, {batch_size, tokens.size()}, tokens.data()});
+ infer_request_guard.get().start_async();
+ infer_request_guard.get().wait();
+ return infer_request_guard.get().get_output_tensor().data<std::string>()[0];
}

std::vector<std::string> decode(ov::Tensor tokens) {
OPENVINO_ASSERT(tokens.get_element_type() == ov::element::i64, "tokens tensor element type should be an i64");
OPENVINO_ASSERT(tokens.get_shape().size() == 2, "tokens tensor should of rank 2 with shape [batch_size, seq_len]");

- std::unique_lock<std::mutex> lock(m_detokenizer_mutex);
- m_detokenizer_request.set_input_tensor(tokens);
- m_detokenizer_request.infer();
+ CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_detokenizer.get());
+ infer_request_guard.get().set_input_tensor(tokens);
+ infer_request_guard.get().start_async();
+ infer_request_guard.get().wait();

- auto res = m_detokenizer_request.get_output_tensor();
+ auto res = infer_request_guard.get().get_output_tensor();
auto res_data = res.data<std::string>();
return std::vector<std::string>(res_data, res_data + res.get_shape()[0]);
}
Expand All @@ -299,10 +324,11 @@ class Tokenizer::TokenizerImpl {
std::fill(tokens_data + i * max_len + line_len, tokens_data + (i + 1) * max_len, m_pad_token_id);
}

- std::unique_lock<std::mutex> lock(m_detokenizer_mutex);
- m_detokenizer_request.set_input_tensor(tokens);
- m_detokenizer_request.infer();
- auto res = m_detokenizer_request.get_output_tensor();
+ CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_detokenizer.get());
+ infer_request_guard.get().set_input_tensor(tokens);
+ infer_request_guard.get().start_async();
+ infer_request_guard.get().wait();
+ auto res = infer_request_guard.get().get_output_tensor();
auto res_data = res.data<std::string>();
return std::vector<std::string>(res_data, res_data + res.get_shape()[0]);
}
@@ -411,9 +437,9 @@ class Tokenizer::TokenizerImpl {

};

- Tokenizer::Tokenizer(const std::string& tokenizer_path) {
+ Tokenizer::Tokenizer(const std::string& tokenizer_path, const ov::AnyMap& plugin_config) {
ScopedVar env_manager(tokenizers_relative_to_genai().string());
- m_pimpl = std::make_shared<TokenizerImpl>(tokenizer_path);
+ m_pimpl = std::make_shared<TokenizerImpl>(tokenizer_path, plugin_config);
}

TokenizedInputs Tokenizer::encode(const std::string prompt) {
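The practical effect of the per-model infer-request queues is that tokenization no longer serializes on a single mutex. A hedged sketch (not from the commit) of several threads encoding through one Tokenizer, each borrowing its own pooled infer request; the tokenizer directory path is a placeholder:

#include "openvino/genai/tokenizer.hpp"
#include <string>
#include <thread>
#include <vector>

int main() {
    // New optional second argument: plugin config forwarded to compile_model().
    ov::genai::Tokenizer tokenizer("/path/to/tokenizer_dir", /*plugin_config=*/{});

    std::vector<std::thread> workers;
    for (int i = 0; i < 4; ++i) {
        workers.emplace_back([&tokenizer, i]() {
            auto tokens = tokenizer.encode("What is OpenVINO? #" + std::to_string(i));
            (void)tokens;  // each call runs on its own pooled infer request
        });
    }
    for (auto& worker : workers) {
        worker.join();
    }
    return 0;
}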
12 changes: 6 additions & 6 deletions src/python/py_generate_pipeline.cpp
@@ -436,10 +436,10 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
R"(openvino_genai.Tokenizer object is used to initialize Tokenizer
if it's located in a different path than the main model.)")

- .def(py::init([](const std::string& tokenizer_path) {
+ .def(py::init([](const std::string& tokenizer_path, const std::map<std::string, py::object>& plugin_config) {
ScopedVar env_manager(ov_tokenizers_module_path());
- return std::make_unique<ov::genai::Tokenizer>(tokenizer_path);
- }), py::arg("tokenizer_path"))
+ return std::make_unique<ov::genai::Tokenizer>(tokenizer_path, properties_to_any_map(plugin_config));
+ }), py::arg("tokenizer_path"), py::arg("plugin_config") = ov::AnyMap({}))

.def("encode", [](Tokenizer& tok, std::vector<std::string>& prompts) { return tok.encode(prompts); },
py::arg("prompts"),
@@ -596,10 +596,10 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
.def_readwrite("max_num_seqs", &SchedulerConfig::max_num_seqs);

py::class_<ContinuousBatchingPipeline>(m, "ContinuousBatchingPipeline")
- .def(py::init([](const std::string& model_path, const SchedulerConfig& scheduler_config, const std::string& device, const std::map<std::string, py::object>& plugin_config) {
+ .def(py::init([](const std::string& model_path, const SchedulerConfig& scheduler_config, const std::string& device, const std::map<std::string, py::object>& llm_plugin_config, const std::map<std::string, py::object>& tokenizer_plugin_config) {
ScopedVar env_manager(ov_tokenizers_module_path());
- return std::make_unique<ContinuousBatchingPipeline>(model_path, scheduler_config, device, properties_to_any_map(plugin_config));
- }), py::arg("model_path"), py::arg("scheduler_config"), py::arg("device") = "CPU", py::arg("plugin_config") = ov::AnyMap({}))
+ return std::make_unique<ContinuousBatchingPipeline>(model_path, scheduler_config, device, properties_to_any_map(llm_plugin_config), properties_to_any_map(tokenizer_plugin_config));
+ }), py::arg("model_path"), py::arg("scheduler_config"), py::arg("device") = "CPU", py::arg("llm_plugin_config") = ov::AnyMap({}), py::arg("tokenizer_plugin_config") = ov::AnyMap({}))
.def(py::init([](const std::string& model_path, const ov::genai::Tokenizer& tokenizer, const SchedulerConfig& scheduler_config, const std::string& device, const std::map<std::string, py::object>& plugin_config) {
ScopedVar env_manager(ov_tokenizers_module_path());
return std::make_unique<ContinuousBatchingPipeline>(model_path, tokenizer, scheduler_config, device, properties_to_any_map(plugin_config));
2 changes: 1 addition & 1 deletion tests/python_tests/common.py
@@ -273,7 +273,7 @@ def run_continuous_batching(
prompts: List[str],
generation_configs : List[GenerationConfig]
) -> List[GenerationResult]:
- pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), scheduler_config, "CPU", {})
+ pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), scheduler_config, "CPU", {}, {})
output = pipe.generate(prompts, generation_configs)
del pipe
shutil.rmtree(model_path)
2 changes: 1 addition & 1 deletion tests/python_tests/ov_genai_test_utils.py
@@ -205,7 +205,7 @@ def load_tok(configs: List[Tuple], temp_path):
for config_json, config_name in configs:
with (temp_path / config_name).open('w') as f:
json.dump(config_json, f)
- return ov_genai.Tokenizer(str(temp_path))
+ return ov_genai.Tokenizer(str(temp_path), {})


def load_pipe(configs: List[Tuple], temp_path):
2 changes: 1 addition & 1 deletion tests/python_tests/test_sampling.py
@@ -205,7 +205,7 @@ def test_post_oom_health(tmp_path):
model_path : Path = tmp_path / model_id
save_ov_model_from_optimum(model, hf_tokenizer, model_path)

- pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), Tokenizer(model_path.absolute().as_posix()), scheduler_config)
+ pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), Tokenizer(model_path.absolute().as_posix(), {}), scheduler_config, "CPU", {})
# First run should return incomplete response
output = pipe.generate(["What is OpenVINO?"], generation_configs)
assert(len(output))
