
Commit

Add CB naive chat
Wovchena committed Jul 18, 2024
1 parent bc56ca6 commit 532a804
Showing 6 changed files with 108 additions and 26 deletions.
12 changes: 12 additions & 0 deletions src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
@@ -67,5 +67,17 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
// higher-level interface that can process multiple prompts in a continuous batching manner
std::vector<EncodedGenerationResult> generate(const std::vector<ov::Tensor>& input_ids, const std::vector<ov::genai::GenerationConfig>& sampling_params, const ov::genai::StreamerVariant& streamer=std::monostate{});
std::vector<GenerationResult> generate(const std::vector<std::string>& prompts, const std::vector<ov::genai::GenerationConfig>& sampling_params, const ov::genai::StreamerVariant& streamer=std::monostate{});

/**
* @brief Start a chat session, keeping the conversation history in the KV cache.
*
* @param system_message Optional system message, added to the history with the "system" role.
*/
void start_chat(const std::string& system_message = "");

/**
* @brief Finish the chat session and clear the KV cache together with the accumulated history.
*/
void finish_chat();
};
}
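
For context, a minimal sketch of how the new chat API is intended to be used from application code. The model directory name, the default SchedulerConfig, and the exact constructor arguments are assumptions for illustration and are not taken from this commit; only start_chat()/generate()/finish_chat() come from the header above.

#include <iostream>
#include <string>
#include <vector>
#include "openvino/genai/continuous_batching_pipeline.hpp"

int main() {
    // Constructor arguments beyond the model path (scheduler config, device) are assumed here.
    ov::genai::SchedulerConfig scheduler_config;
    ov::genai::ContinuousBatchingPipeline pipe("./TinyLlama-1.1B-Chat-v1.0", scheduler_config, "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    pipe.start_chat("You are a helpful assistant.");  // optional system message
    // While a chat is active, generate() must receive exactly one prompt per call.
    std::vector<std::string> prompts{"What is continuous batching?"};
    std::vector<ov::genai::GenerationConfig> params{config};
    std::vector<ov::genai::GenerationResult> answers = pipe.generate(prompts, params);
    std::cout << answers.at(0).m_generation_ids.at(0) << '\n';
    pipe.finish_chat();  // drops the accumulated history and KV cache
    return 0;
}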
47 changes: 41 additions & 6 deletions src/cpp/src/continuous_batching_pipeline.cpp
@@ -56,6 +56,8 @@ class ContinuousBatchingPipeline::Impl {
std::vector<SequenceGroup::Ptr> m_awaiting_requests;
// Mutex protecting access to m_awaiting_requests, so add_request and step methods can be called from different threads
std::mutex m_awaiting_requests_mutex;
bool m_is_chat_conversation = false;
ChatHistory m_history;


void _free_non_running_requests() {
@@ -305,21 +307,34 @@ class ContinuousBatchingPipeline::Impl {

std::vector<GenerationResult> generate(const std::vector<std::string>& prompts, std::vector<ov::genai::GenerationConfig> sampling_params, const StreamerVariant& streamer) {
std::vector<ov::Tensor> input_ids;
-            input_ids.reserve(prompts.size());
-            for (const std::string& prompt : prompts) {
-                static ManualTimer timer("tokenize");
+            static ManualTimer timer("tokenize");
+            if (m_is_chat_conversation) {
+                OPENVINO_ASSERT(1 == prompts.size(), "Can't chat with multiple prompts");
+                m_history.push_back({{"role", "user"}, {"content", prompts.at(0)}});
+                constexpr bool add_generation_prompt = true;
+                std::string history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
                timer.start();
-                input_ids.push_back(m_tokenizer.encode(prompt).input_ids);
+                input_ids.push_back(m_tokenizer.encode(history).input_ids);
                timer.end();
+            } else {
+                input_ids.reserve(prompts.size());
+                for (const std::string& prompt : prompts) {
+                    timer.start();
+                    input_ids.push_back(m_tokenizer.encode(prompt).input_ids);
+                    timer.end();
+                }
+            }
            }
std::vector<EncodedGenerationResult> encoded = generate(input_ids, sampling_params, streamer);
std::vector<GenerationResult> decoded;
decoded.reserve(encoded.size());
for (EncodedGenerationResult& res : encoded) {
std::vector<std::string> generated;
generated.reserve(res.m_generation_ids.size());
-                for (const std::vector<int64_t>& tokens : res.m_generation_ids) {
-                    generated.push_back(m_tokenizer.decode(tokens));
+                for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) {
+                    generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
+                    if (m_is_chat_conversation && 0 == idx) {
+                        m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
+                    }
}
decoded.push_back(GenerationResult{
res.m_request_id,
@@ -330,6 +345,18 @@
}
return decoded;
}

void start_chat(const std::string& system_message) {
if (!system_message.empty()) {
m_history.push_back({{"role", "system"}, {"content", system_message}});
}
m_is_chat_conversation = true;
};

void finish_chat() {
m_is_chat_conversation = false;
m_history.clear();
};
};

ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::string& models_path,
@@ -382,3 +409,11 @@ std::vector<EncodedGenerationResult> ContinuousBatchingPipeline::generate(const
std::vector<GenerationResult> ContinuousBatchingPipeline::generate(const std::vector<std::string>& prompts, const std::vector<ov::genai::GenerationConfig>& sampling_params, const StreamerVariant& streamer) {
return m_impl->generate(prompts, sampling_params, streamer);
}

void ContinuousBatchingPipeline::start_chat(const std::string& system_message) {
m_impl->start_chat(system_message);
};

void ContinuousBatchingPipeline::finish_chat() {
m_impl->finish_chat();
};
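
The chat branch above is deliberately naive: on every turn the full ChatHistory is rendered through the tokenizer's chat template and re-encoded, rather than appending only the new tokens. Below is a standalone sketch of that per-turn flow; the include path and the free function are assumptions for illustration, while apply_chat_template() and encode().input_ids come straight from the code above.

#include <string>
#include "openvino/genai/tokenizer.hpp"  // header path assumed

// Render the accumulated history through the model's chat template and re-tokenize it,
// mirroring what the m_is_chat_conversation branch of Impl::generate() does each turn.
ov::Tensor tokenize_turn(ov::genai::Tokenizer& tokenizer,
                         ov::genai::ChatHistory& history,
                         const std::string& user_message) {
    history.push_back({{"role", "user"}, {"content", user_message}});
    constexpr bool add_generation_prompt = true;
    std::string rendered = tokenizer.apply_chat_template(history, add_generation_prompt);
    return tokenizer.encode(rendered).input_ids;  // submitted to the scheduler as one request
}

The trade-off is that each turn re-tokenizes the entire conversation, which keeps the implementation simple at the cost of prompts that grow with the history.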
38 changes: 27 additions & 11 deletions src/cpp/src/llm_pipeline.cpp
@@ -115,6 +115,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
EncodedInputs encoded_input;

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
encoded_input = m_tokenizer.encode(*input_vector);
} else if (auto input_prompt = std::get_if<std::string>(&inputs)) {
std::string& prompt = *input_prompt;
@@ -386,16 +387,31 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
OptionalGenerationConfig generation_config,
StreamerVariant streamer
) override {
-        EncodedInputs input_ids_att = std::visit(overloaded{
-            [this](const std::string& prompt) {
-                return m_tokenizer.encode(prompt);
+        std::vector<std::string> prompts = std::visit(overloaded{
+            [](const std::string& prompt) {
+                return std::vector{prompt};
            },
-            [this](std::vector<std::string>& prompts) {
-                return m_tokenizer.encode(prompts);
+            [](std::vector<std::string>& prompts) {
+                return prompts;
            }
        }, inputs);
-        EncodedResults encoded = generate(input_ids_att, generation_config, streamer);
-        return {m_tokenizer.decode(encoded.tokens), encoded.scores};
+        const GenerationConfig& config = generation_config.has_value() ? *generation_config : m_generation_config;
+        // -1 == config.eos_token_id and config.validate() are handled in m_impl.
+        std::vector<GenerationResult> generated = m_impl.generate(
+            prompts,
+            std::vector<GenerationConfig>{prompts.size(), config},
+            streamer
+        );
+        std::vector<std::string> plain_replies;
+        std::vector<float> plain_scores;
+        for (GenerationResult& res : generated) {
+            if (GenerationStatus::FINISHED != res.m_status) {
+                OPENVINO_THROW("Got unfinished GenerationStatus");
+            }
+            std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies));
+            std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
+        }
+        return {std::move(plain_replies), std::move(plain_scores)};
}

EncodedResults generate(
@@ -457,12 +473,12 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
}

void start_chat(const std::string& system_message) override {
OPENVINO_THROW("start_chat() isn't implemented.");
}
m_impl.start_chat();
};

void finish_chat() override {
OPENVINO_THROW("finish_chat() isn't implemented.");
}
m_impl.finish_chat();
};
};
}
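
The overloaded visitor passed to std::visit in the adapter's generate() above is the usual C++17 helper; a definition along these lines is assumed to exist elsewhere in the sources (it is not part of this diff). The to_prompts() function is only an illustrative stand-in for the visit call:

#include <string>
#include <variant>
#include <vector>

// Classic two-line std::visit helper (C++17).
template <typename... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template <typename... Ts> overloaded(Ts...) -> overloaded<Ts...>;

// Normalize a string-or-vector input to a prompt list, as the adapter does.
std::vector<std::string> to_prompts(const std::variant<std::string, std::vector<std::string>>& inputs) {
    return std::visit(overloaded{
        [](const std::string& prompt) { return std::vector<std::string>{prompt}; },
        [](const std::vector<std::string>& prompts) { return prompts; }
    }, inputs);
}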

5 changes: 5 additions & 0 deletions tests/python_tests/ov_genai_test_utils.py
@@ -215,3 +215,8 @@ def load_pipe(configs: List[Tuple], temp_path):
with (temp_path / config_name).open('w') as f:
json.dump(config_json, f)
return ov_genai.LLMPipeline(str(temp_path))


@functools.lru_cache(1)
def get_continuous_batching(path):
return ov_genai.LLMPipeline(str(path), ov_genai.Tokenizer(str(path)), 'CB')
20 changes: 19 additions & 1 deletion tests/python_tests/test_chat_generate_api.py
@@ -1,6 +1,7 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import math
import openvino
import openvino_tokenizers
import openvino_genai as ov_genai
@@ -12,7 +13,8 @@
read_model,
load_tok,
model_tmp_path,
-    get_chat_templates
+    get_chat_templates,
+    get_continuous_batching,
)


@@ -163,3 +165,19 @@ def test_apply_chat_template(model_tmp_path, chat_config: Tuple[str, Dict]):
print(f'hf reference: {full_history_str_hf}')
print(f'ov_genai out: {full_history_str}')
assert full_history_str == full_history_str_hf


@pytest.mark.parametrize("generation_config", configs[1:])
@pytest.mark.parametrize("model_descr", get_chat_models_list())
@pytest.mark.precommit
def test_chat_continuous_batching_vs_stateful(model_descr, generation_config: Dict):
model_id, path, tokenizer, model, stateful = read_model(model_descr)
cb = get_continuous_batching(path)
stateful.start_chat()
cb.start_chat()
for question in quenstions:
generated = cb.generate(question, **generation_config)
reference = stateful.generate(question, **generation_config)
assert generated == reference
# Test that finish_chat() doesn't fail just in case.
cb.finish_chat()
12 changes: 4 additions & 8 deletions tests/python_tests/test_generate_api.py
@@ -11,7 +11,6 @@
import sys
from pathlib import Path
import torch
-import functools
import math
from ov_genai_test_utils import (
get_models_list,
@@ -20,6 +19,7 @@
load_tok,
model_tmp_path,
STOP_CRITERIA_MAP,
get_continuous_batching,
)


@@ -675,11 +675,6 @@ def test_left_pad():
run_hf_ov_genai_comparison_batched(models, config, prompts)


-@functools.lru_cache(1)
-def get_continuous_batching(path):
-    return ov_genai.LLMPipeline(str(path), ov_genai.Tokenizer(str(path)), 'CB')


@pytest.mark.parametrize("generation_config", test_configs)
@pytest.mark.parametrize("prompt", batched_prompts)
@pytest.mark.precommit
@@ -694,7 +689,7 @@ def test_continuous_batching_vs_stateful(prompt, generation_config):
generated = cb.generate(prompt, **generation_config)
reference = stateful.generate(prompt, **generation_config)
assert generated.texts == reference.texts
if 1 != generation_config.get("num_beams", 1):
if 1 != generation_config.get("num_return_sequences", 1):
# Stateful puts zeroes to generated.scores. Don't compare them.
for gen, ref in zip(generated.scores, reference.scores):
assert math.isclose(gen, ref, abs_tol=0.0003)
@@ -710,4 +705,5 @@ def test_cb_streamer_vs_return_vs_stateful(prompt):
streamed = []
generated = cb.generate(prompt, max_new_tokens=20, streamer=lambda subword: streamed.append(subword))
reference = stateful.generate(prompt, max_new_tokens=20)
assert generated == "".join(streamed) == reference
assert generated == "".join(streamed)
assert "".join(streamed) == reference
