Add CB naive chat #644

Merged: 37 commits, Jul 24, 2024
Commits
d2ccfe0
Add ContinuousBatchingPipeline constructor similar to LLMPipeline
Wovchena Jul 10, 2024
ab0f43c
Use CB as backend
Wovchena Jul 11, 2024
05cf5d3
Update bindings
Wovchena Jul 11, 2024
96fcf77
comma
Wovchena Jul 11, 2024
e5c2631
Merge branch 'add-ContinuousBatchingPipeline-constructor-similar-to-L…
Wovchena Jul 11, 2024
967da3a
Merge branch 'master' into add-ContinuousBatchingPipeline-constructor…
Wovchena Jul 11, 2024
c13d623
Merge branch 'add-ContinuousBatchingPipeline-constructor-similar-to-L…
Wovchena Jul 11, 2024
e278a14
pass
Wovchena Jul 11, 2024
47fa22c
conflict
Wovchena Jul 11, 2024
2094ba6
conflict
Wovchena Jul 11, 2024
78361a9
clean
Wovchena Jul 11, 2024
691fefc
verify status
Wovchena Jul 11, 2024
1825328
conflict
Wovchena Jul 11, 2024
6a3275e
args
Wovchena Jul 11, 2024
a5f2cd6
test
Wovchena Jul 11, 2024
771fc29
tests
Wovchena Jul 11, 2024
5c615bf
tests
Wovchena Jul 11, 2024
3afc16d
test
Wovchena Jul 12, 2024
67f4717
remove caching
Wovchena Jul 12, 2024
d26723f
Clear beam search info.
popovaan Jul 12, 2024
c28a023
Merge remote-tracking branch 'popovaan/clear_beam_info' into use-CB-a…
Wovchena Jul 12, 2024
d223d68
-am cache
Wovchena Jul 15, 2024
238ea8b
updte
Wovchena Jul 15, 2024
5a4c878
Revert "Merge remote-tracking branch 'popovaan/clear_beam_info' into …
Wovchena Jul 15, 2024
cf35f19
revert spelling
Wovchena Jul 15, 2024
12061af
relax abs_tol
Wovchena Jul 15, 2024
cad13dc
Merge branch 'releases/2024/3' into use-CB-as-backend
Wovchena Jul 15, 2024
0ffdd6f
Merge branch 'releases/2024/3' into use-CB-as-backend
Wovchena Jul 15, 2024
6d7a468
lru_cache
Wovchena Jul 15, 2024
036111c
Merge branch 'releases/2024/3' into cb-streaming
Wovchena Jul 16, 2024
c6d345a
Add CB streaming
Wovchena Jul 16, 2024
bc56ca6
use StreamerVariant
Wovchena Jul 17, 2024
c70b909
Add CB naive chat
Wovchena Jul 18, 2024
5468439
correct tests
Wovchena Jul 19, 2024
4f77aa9
correct test_continuous_batching_vs_stateful
Wovchena Jul 19, 2024
b45d038
Merge branch 'releases/2024/3' into cb-naive-chat
Wovchena Jul 23, 2024
6884d04
Merge branch 'releases/2024/3' into cb-naive-chat
Wovchena Jul 24, 2024
20 changes: 18 additions & 2 deletions src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
@@ -10,6 +10,8 @@
#include "openvino/genai/tokenizer.hpp"
#include "openvino/genai/generation_config.hpp"
#include "openvino/genai/generation_handle.hpp"
#include "openvino/genai/llm_pipeline.hpp"
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/visibility.hpp"

namespace ov::genai {
@@ -55,13 +57,27 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {

PipelineMetrics get_metrics() const;

GenerationHandle add_request(uint64_t request_id, std::string prompt, ov::genai::GenerationConfig sampling_params);
GenerationHandle add_request(uint64_t request_id, const ov::Tensor& input_ids, const ov::genai::GenerationConfig& sampling_params);
GenerationHandle add_request(uint64_t request_id, const std::string& prompt, const ov::genai::GenerationConfig& sampling_params);

void step();

bool has_non_finished_requests();

// more high level interface, which can process multiple prompts in continuous batching manner
std::vector<GenerationResult> generate(const std::vector<std::string>& prompts, std::vector<ov::genai::GenerationConfig> sampling_params);
std::vector<EncodedGenerationResult> generate(const std::vector<ov::Tensor>& input_ids, const std::vector<ov::genai::GenerationConfig>& sampling_params, const ov::genai::StreamerVariant& streamer=std::monostate{});
std::vector<GenerationResult> generate(const std::vector<std::string>& prompts, const std::vector<ov::genai::GenerationConfig>& sampling_params, const ov::genai::StreamerVariant& streamer=std::monostate{});

/**
* @brief start chat with keeping history in kv cache.
*
* @param system_message optional system message.
*/
void start_chat(const std::string& system_message = "");

/**
* @brief finish chat and clear kv cache.
*/
void finish_chat();
};
}
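For context, a minimal sketch of driving the new chat API declared above; the constructor arguments (model directory, SchedulerConfig, device) and GenerationConfig::max_new_tokens are assumed from the existing headers, the rest follows the declarations in this file:

#include <iostream>
#include <string>
#include <vector>
#include "openvino/genai/continuous_batching_pipeline.hpp"

int main() {
    ov::genai::SchedulerConfig scheduler_config;   // defaults, including the 1 GB KV-cache budget from scheduler_config.hpp
    ov::genai::ContinuousBatchingPipeline pipe("./model_dir", scheduler_config, "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;

    pipe.start_chat("You are a helpful assistant.");          // system message is kept in the history
    for (std::string prompt : {"Hi, who are you?", "Repeat my previous question."}) {
        // In chat mode generate() expects exactly one prompt; the history is re-applied internally.
        std::vector<ov::genai::GenerationResult> res = pipe.generate({prompt}, {config});
        std::cout << res.at(0).m_generation_ids.at(0) << '\n';
    }
    pipe.finish_chat();                                        // drops the accumulated history
}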
15 changes: 15 additions & 0 deletions src/cpp/include/openvino/genai/generation_handle.hpp
@@ -18,6 +18,20 @@ enum class GenerationStatus {
DROPPED_BY_HANDLE = 4 // Status set when generation handle is dropped
};

struct EncodedGenerationResult {
// request ID - obsolete when handle API is approved as handle will connect results with prompts.
uint64_t m_request_id;

// in a generic case we have multiple generation results per initial prompt
// depending on sampling parameters (e.g. beam search or parallel sampling)
std::vector<std::vector<int64_t>> m_generation_ids;
// scores
std::vector<float> m_scores;

// Status of generation
GenerationStatus m_status = GenerationStatus::RUNNING;
};

struct GenerationResult {
// request ID - obsolete when handle API is approved as handle will connect results with prompts.
uint64_t m_request_id;
@@ -60,6 +74,7 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {

bool can_read();

GenerationOutputs back();
// Reads result of a generation for single iteration
GenerationOutputs read();
// Reads all generated tokens for all sequences
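A sketch of consuming the new token-level results end to end; get_tokenizer() on the pipeline is assumed from the existing API, and the pipeline is constructed as in the previous sketch:

#include <iostream>
#include <vector>
#include "openvino/genai/continuous_batching_pipeline.hpp"

// Pre-tokenized input in, token ids and scores out; decoding stays on the caller's side.
void generate_encoded(ov::genai::ContinuousBatchingPipeline& pipe) {
    ov::genai::Tokenizer tokenizer = pipe.get_tokenizer();   // accessor assumed from the existing API
    ov::Tensor input_ids = tokenizer.encode("What is OpenVINO?").input_ids;

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 64;

    std::vector<ov::genai::EncodedGenerationResult> results =
        pipe.generate(std::vector<ov::Tensor>{input_ids}, {config});

    for (const ov::genai::EncodedGenerationResult& res : results) {
        for (size_t i = 0; i < res.m_generation_ids.size(); ++i) {
            std::cout << res.m_scores.at(i) << ": "
                      << tokenizer.decode(res.m_generation_ids.at(i)) << '\n';
        }
    }
}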
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -14,7 +14,7 @@
namespace ov {
namespace genai {

// Return flag corresponds whether generation should be stopped: false means continue generation, true means stop.
// Return flag correspods whether generation should be stopped: false means continue generation, true means stop.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
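The callback form of StreamerVariant plugs straight into the new generate() overloads; a sketch (the false return value means "keep generating", as the comment above states; pipe is assumed from the earlier sketch):

#include <iostream>
#include <string>
#include "openvino/genai/continuous_batching_pipeline.hpp"

// Prints sub-words as soon as they are decoded instead of waiting for the full answer.
void stream_to_stdout(ov::genai::ContinuousBatchingPipeline& pipe) {
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;
    auto callback = [](std::string subword) {
        std::cout << subword << std::flush;
        return false;   // false -> continue generation, true -> stop early
    };
    pipe.generate({"Why is the sky blue?"}, {config}, callback);
    std::cout << std::endl;
}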
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/scheduler_config.hpp
@@ -16,7 +16,7 @@ struct SchedulerConfig {
std::size_t num_kv_blocks = 0;

// total size of KV cache in GB
std::size_t cache_size = 0;
std::size_t cache_size = 1;

// block size for KV cache
std::size_t block_size = 32;
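With the default bumped from 0 to 1 GB, running out of the box no longer requires touching the config; a sketch of overriding just the cache budget, using only fields visible in this header:

#include "openvino/genai/scheduler_config.hpp"

// Give the KV cache a 4 GB budget; block_size and num_kv_blocks keep their defaults.
ov::genai::SchedulerConfig make_scheduler_config() {
    ov::genai::SchedulerConfig scheduler_config;
    scheduler_config.cache_size = 4;   // total size of the KV cache in GB
    return scheduler_config;
}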
146 changes: 121 additions & 25 deletions src/cpp/src/continuous_batching_pipeline.cpp
@@ -6,16 +6,21 @@
#include <memory>

#include "openvino/genai/continuous_batching_pipeline.hpp"
#include "openvino/genai/generation_handle.hpp"
#include "openvino/genai/tokenizer.hpp"
#include "cache_manager.hpp"
#include "sampler.hpp"
#include "model_runner.hpp"
#include "scheduler.hpp"
#include "text_callback_streamer.hpp"
#include "timer.hpp"
#include "debug_utils.hpp"

using namespace ov::genai;

template<class... Ts> struct overloaded : Ts... {using Ts::operator()...;};
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, DeviceConfig& device_config);

class ContinuousBatchingPipeline::Impl {
@@ -51,6 +56,8 @@ class ContinuousBatchingPipeline::Impl {
std::vector<SequenceGroup::Ptr> m_awaiting_requests;
// Mutex protecting access to m_awaiting_requests, so add_request and step methods can be called from different threads
std::mutex m_awaiting_requests_mutex;
bool m_is_chat_conversation = false;
ChatHistory m_history;


void _free_non_running_requests() {
@@ -120,18 +127,9 @@
return m_tokenizer;
}

GenerationHandle add_request(uint64_t request_id, std::string prompt, ov::genai::GenerationConfig sampling_params) {
GenerationHandle add_request(uint64_t request_id, const ov::Tensor& input_ids, ov::genai::GenerationConfig sampling_params) {
sampling_params.set_eos_token_id(m_tokenizer.get_eos_token_id());
sampling_params.validate();

ov::Tensor input_ids;
{
static ManualTimer timer("tokenize");
timer.start();
input_ids = m_tokenizer.encode(prompt).input_ids;
timer.end();
}

SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, input_ids,
sampling_params, m_scheduler->get_config().block_size);
{
@@ -141,6 +139,14 @@
return std::make_unique<GenerationHandleImpl>(sequence_group->get_generation_stream(), sampling_params);
}

GenerationHandle add_request(uint64_t request_id, const std::string& prompt, ov::genai::GenerationConfig sampling_params) {
static ManualTimer timer("tokenize");
timer.start();
ov::Tensor input_ids = m_tokenizer.encode(prompt).input_ids;
timer.end();
return add_request(request_id, input_ids, sampling_params);
}

void step() {
static ManualTimer step_timer("step()");
step_timer.start();
@@ -238,25 +244,47 @@
return !m_awaiting_requests.empty() || !m_requests.empty();
}

std::vector<GenerationResult> generate(const std::vector<std::string> prompts, std::vector<ov::genai::GenerationConfig> sampling_params) {
std::vector<EncodedGenerationResult> generate(const std::vector<ov::Tensor>& input_ids, const std::vector<GenerationConfig>& sampling_params, const StreamerVariant& streamer) {
OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request");
OPENVINO_ASSERT(prompts.size() == sampling_params.size());
OPENVINO_ASSERT(input_ids.size() == sampling_params.size());
const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
[](std::monostate) -> std::shared_ptr<StreamerBase> {
return nullptr;
},
[](const std::shared_ptr<StreamerBase>& streamer) {
return streamer;
},
[this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
}
}, streamer);

std::vector<GenerationHandle> generations;
for (size_t request_id = 0; request_id < prompts.size(); ++request_id) {
generations.push_back(add_request(request_id, prompts[request_id], sampling_params[request_id]));
for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
generations.push_back(add_request(request_id, input_ids[request_id], sampling_params[request_id]));
}

std::vector<GenerationResult> results;
std::vector<EncodedGenerationResult> results;
results.reserve(m_awaiting_requests.size());

while (has_non_finished_requests()) {
bool continue_generation = true;
while (has_non_finished_requests() && continue_generation) {
step();
if (streamer_ptr) {
std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
OPENVINO_ASSERT(1 == token.size());
OPENVINO_ASSERT(1 == token.begin()->second.generated_token_ids.size());
continue_generation = !streamer_ptr->put(token.begin()->second.generated_token_ids.at(0));
}
}
if (streamer_ptr) {
streamer_ptr->end();
Comment on lines +281 to +282

Collaborator: Why do we continue to execute this function when a streamer is used? The output has already been read and pushed to the streamer in the while loop above.

Collaborator (author): That's just the API: we document that end() is called at the end. One possible use case is to let the streamer flush its output.

}

for (size_t generation_idx = 0; generation_idx < generations.size(); ++generation_idx) {
const auto& generation = generations[generation_idx];
GenerationResult result;
EncodedGenerationResult result;
result.m_request_id = 1;
std::vector<GenerationOutput> generation_outputs = generation->read_all();
std::sort(generation_outputs.begin(), generation_outputs.end(), [=] (GenerationOutput& r1, GenerationOutput& r2) {
@@ -266,17 +294,69 @@
auto num_outputs = std::min(sampling_params[generation_idx].num_return_sequences, generation_outputs.size());
for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
const auto& generation_output = generation_outputs[generation_output_idx];
std::string output_text = m_tokenizer.decode(generation_output.generated_token_ids);
result.m_generation_ids.push_back(output_text);
result.m_generation_ids.push_back(std::move(generation_output.generated_token_ids));
result.m_scores.push_back(generation_output.score);
}
result.m_status = generation->get_status();
results.push_back(result);
results.push_back(std::move(result));
}

OPENVINO_ASSERT(results.size() == prompts.size());
OPENVINO_ASSERT(results.size() == input_ids.size());
return results;
}

std::vector<GenerationResult> generate(const std::vector<std::string>& prompts, std::vector<ov::genai::GenerationConfig> sampling_params, const StreamerVariant& streamer) {
std::vector<ov::Tensor> input_ids;
static ManualTimer timer("tokenize");
if (m_is_chat_conversation) {
OPENVINO_ASSERT(1 == prompts.size(), "Can't chat with multiple prompts");
m_history.push_back({{"role", "user"}, {"content", prompts.at(0)}});
constexpr bool add_generation_prompt = true;
std::string history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
timer.start();
input_ids.push_back(m_tokenizer.encode(history).input_ids);
timer.end();
} else {
input_ids.reserve(prompts.size());
for (const std::string& prompt : prompts) {
timer.start();
input_ids.push_back(m_tokenizer.encode(prompt).input_ids);
timer.end();
}
}
std::vector<EncodedGenerationResult> encoded = generate(input_ids, sampling_params, streamer);
std::vector<GenerationResult> decoded;
decoded.reserve(encoded.size());
for (EncodedGenerationResult& res : encoded) {
std::vector<std::string> generated;
generated.reserve(res.m_generation_ids.size());
for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) {
generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
if (m_is_chat_conversation && 0 == idx) {
m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
}
}
decoded.push_back(GenerationResult{
res.m_request_id,
std::move(generated),
std::move(res.m_scores),
res.m_status
});
}
return decoded;
}

void start_chat(const std::string& system_message) {
if (!system_message.empty()) {
m_history.push_back({{"role", "system"}, {"content", system_message}});
}
m_is_chat_conversation = true;
};

void finish_chat() {
m_is_chat_conversation = false;
m_history.clear();
};
};

ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::string& models_path,
@@ -306,10 +386,14 @@ PipelineMetrics ContinuousBatchingPipeline::get_metrics() const{
return m_impl->get_metrics();
}

GenerationHandle ContinuousBatchingPipeline::add_request(uint64_t request_id, std::string prompt, ov::genai::GenerationConfig sampling_params) {
GenerationHandle ContinuousBatchingPipeline::add_request(uint64_t request_id, const std::string& prompt, const ov::genai::GenerationConfig& sampling_params) {
return m_impl->add_request(request_id, prompt, sampling_params);
}

GenerationHandle ContinuousBatchingPipeline::add_request(uint64_t request_id, const ov::Tensor& input_ids, const ov::genai::GenerationConfig& sampling_params) {
return m_impl->add_request(request_id, input_ids, sampling_params);
}

void ContinuousBatchingPipeline::step() {
m_impl->step();
}
@@ -318,6 +402,18 @@ bool ContinuousBatchingPipeline::has_non_finished_requests() {
return m_impl->has_non_finished_requests();
}

std::vector<GenerationResult> ContinuousBatchingPipeline::generate(const std::vector<std::string>& prompts, std::vector<ov::genai::GenerationConfig> sampling_params) {
return m_impl->generate(prompts, sampling_params);
}
std::vector<EncodedGenerationResult> ContinuousBatchingPipeline::generate(const std::vector<ov::Tensor>& input_ids, const std::vector<ov::genai::GenerationConfig>& sampling_params, const StreamerVariant& streamer) {
return m_impl->generate(input_ids, sampling_params, streamer);
}

std::vector<GenerationResult> ContinuousBatchingPipeline::generate(const std::vector<std::string>& prompts, const std::vector<ov::genai::GenerationConfig>& sampling_params, const StreamerVariant& streamer) {
return m_impl->generate(prompts, sampling_params, streamer);
}

void ContinuousBatchingPipeline::start_chat(const std::string& system_message) {
m_impl->start_chat(system_message);
};

void ContinuousBatchingPipeline::finish_chat() {
m_impl->finish_chat();
};
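The overloaded helper added near the top of this file is the usual C++17 std::visit idiom; a self-contained sketch with simplified stand-in types (not the pipeline's own) shows how it collapses a streamer variant into one dispatch point:

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <variant>

template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

struct StreamerStub { void put(const std::string& s) { std::cout << s; } };

using StreamerLike = std::variant<std::function<bool(std::string)>,
                                  std::shared_ptr<StreamerStub>,
                                  std::monostate>;

// One visit handles all three alternatives: callback, streamer object, or "no streaming".
void dispatch(const StreamerLike& streamer, const std::string& token) {
    std::visit(overloaded{
        [](std::monostate) { /* streaming not requested */ },
        [&](const std::shared_ptr<StreamerStub>& s) { s->put(token); },
        [&](const std::function<bool(std::string)>& f) { f(token); }
    }, streamer);
}

int main() {
    dispatch([](std::string t) { std::cout << t; return false; }, "hello ");
    dispatch(std::make_shared<StreamerStub>(), "world\n");
}

In the pipeline itself the same visit normalizes the variant into a single std::shared_ptr<StreamerBase> once, before the generation loop.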
4 changes: 4 additions & 0 deletions src/cpp/src/generation_handle.cpp
@@ -20,6 +20,10 @@ bool GenerationHandleImpl::can_read() {
return m_generation_stream->can_read();
}

std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::back() {
return m_generation_stream->back();
}

std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::read() {
return m_generation_stream->read();
}
3 changes: 3 additions & 0 deletions src/cpp/src/generation_stream.hpp
@@ -31,6 +31,9 @@ class GenerationStream {
}

// Retriving vector of pairs <sequence_id, token_id> as we can generate multiple outputs for a single prompt
GenerationOutputs back() {
return m_output_queue.back();
}
GenerationOutputs read() {
return m_output_queue.pull();
}
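For completeness, a sketch of the lower-level flow that back() enables when the scheduler is driven manually; the pipeline and config are assumed to be set up as in the first sketch, and back() only peeks at the newest outputs while read()/read_all() consume them:

#include <iostream>
#include <vector>
#include "openvino/genai/continuous_batching_pipeline.hpp"

// Submit one request, drive step() by hand and peek at the newest tokens after each step.
void poll_one_request(ov::genai::ContinuousBatchingPipeline& pipe,
                      const ov::genai::GenerationConfig& config) {
    ov::genai::GenerationHandle handle = pipe.add_request(/*request_id=*/0, "What is OpenVINO?", config);

    while (pipe.has_non_finished_requests()) {
        pipe.step();
        if (handle->can_read()) {
            auto latest = handle->back();   // map of sequence_id -> GenerationOutput
            for (const auto& [seq_id, out] : latest)
                std::cout << "sequence " << seq_id << ": "
                          << out.generated_token_ids.size() << " tokens so far\n";
        }
    }
    // Drain everything that was produced for this request.
    std::vector<ov::genai::GenerationOutput> outputs = handle->read_all();
    std::cout << "finished with " << outputs.size() << " output(s)\n";
}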