fix chat_sample build on Win
pavel-esir committed May 30, 2024
1 parent 7d1d616 commit 9208110
Showing 4 changed files with 10 additions and 62 deletions.
14 changes: 0 additions & 14 deletions src/cpp/src/text_callback_streamer.cpp
@@ -49,20 +49,6 @@ void TextCallbackStreamer::end() {
     on_finalized_text(res.str());
 }
 
-// void TextCallbackStreamer::set_tokenizer(Tokenizer tokenizer) {
-//     this->m_tokenizer = tokenizer;
-// }
-
-// void TextCallbackStreamer::set_callback(std::function<void (std::string)> callback) {
-//     on_decoded_text_callback = callback;
-//     m_enabled = true;
-// }
-
-// void TextCallbackStreamer::set_callback() {
-//     on_decoded_text_callback = [](std::string words){};
-//     m_enabled = false;
-// }
-
 void TextCallbackStreamer::on_finalized_text(const std::string& subword) {
     if (m_enabled) {
         on_decoded_text_callback(subword);
7 changes: 1 addition & 6 deletions src/cpp/src/text_callback_streamer.hpp
@@ -15,12 +15,7 @@ class TextCallbackStreamer: public StreamerBase {
     void end() override;
 
     TextCallbackStreamer(const Tokenizer& tokenizer, std::function<void(std::string)> callback, bool print_eos_token = false);
-    // ~TextCallbackStreamer() = default;
-
-    // void set_tokenizer(Tokenizer tokenizer);
-    // void set_callback(std::function<void (std::string)> callback);
-    // void set_callback();
-
+
     std::function<void (std::string)> on_decoded_text_callback = [](std::string words){};
     bool m_enabled = false;
     int64_t m_eos_token;
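Both streamer changes point the same way: the commented-out set_tokenizer/set_callback escape hatches are deleted, and the callback is bound once through the constructor. A minimal usage sketch under stated assumptions — that TextCallbackStreamer sits in the ov::genai namespace like the rest of the library, that this internal header from src/cpp/src is on the include path, and that ov::genai::Tokenizer accepts a converted-model directory (the path and token ids below are illustrative):

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

#include "text_callback_streamer.hpp"  // internal header from src/cpp/src (assumed reachable)

int main() {
    // Illustrative model directory containing converted tokenizer models.
    ov::genai::Tokenizer tokenizer("TinyLlama-1.1B-Chat-v1.0");
    // The callback is fixed at construction time; the removed setters no longer exist.
    ov::genai::TextCallbackStreamer streamer(tokenizer, [](std::string subword) {
        std::cout << subword << std::flush;
    });
    for (int64_t token : {64, 65, 66})  // illustrative token ids
        streamer.put(token);            // decodes the token cache and emits finished words
    streamer.end();                     // flushes whatever text remains
}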
2 changes: 1 addition & 1 deletion tests/python_tests/list_test_models.py
@@ -1,7 +1,7 @@
 def models_list():
     model_ids = [
         ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", "TinyLlama-1.1B-Chat-v1.0"),
-        ("microsoft/phi-1_5", "phi-1_5/"),
+        # ("microsoft/phi-1_5", "phi-1_5/"),
 
         # ("google/gemma-2b-it", "gemma-2b-it"),
         # ("google/gemma-7b-it", "gemma-7b-it"),
49 changes: 8 additions & 41 deletions text_generation/causal_lm/cpp/chat_sample.cpp
@@ -1,36 +1,7 @@
 // Copyright (C) 2023-2024 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
-#include <openvino/openvino.hpp>
 #include "openvino/genai/llm_pipeline.hpp"
-#include "openvino/genai/streamer_base.hpp"
-
-using namespace std;
-
-class CustomStreamer: public ov::genai::StreamerBase {
-public:
-    void put(int64_t token) {
-        std::cout << token << std::endl;
-        /* custom decoding/tokens processing code
-        tokens_cache.push_back(token);
-        std::string text = m_tokenizer.decode(tokens_cache);
-        ...
-        */
-    };
-
-    void end() {
-        /* custom finalization */
-    };
-};
-
-std::vector<string> questions = {
-    "1+1=",
-    "what was the previous answer?",
-    "Why is the sky blue?",
-    "4+10=",
-    "What is Intel OpenVINO?",
-    "Can you briefly summarize what I asked you about during this session?",
-};
 
 int main(int argc, char* argv[]) try {
     std::string prompt;
@@ -41,23 +12,19 @@ int main(int argc, char* argv[]) try {
 
     ov::genai::GenerationConfig config = pipe.get_generation_config();
     config.max_new_tokens = 10000;
-    std::function<bool(std::string)> streamer = [](std::string word) { std::cout << word << std::flush; return true;};
-    std::shared_ptr<ov::genai::StreamerBase> custom_streamer = std::make_shared<CustomStreamer>();
+    std::function<void(std::string)> streamer = [](std::string word) { std::cout << word << std::flush; };
 
     pipe.start_chat();
-    for (size_t i = 0; i < questions.size(); i++) {
-        // std::getline(std::cin, prompt);
-        prompt = questions[i];
-
+    for (;;) {
         std::cout << "question:\n";
-        cout << prompt << endl;
+
+        std::getline(std::cin, prompt);
+        if (prompt == "Stop!")
+            break;
 
-        // auto answer_str = pipe(prompt, config, streamer);
-        auto answer_str = pipe(prompt, ov::genai::generation_config(config), ov::genai::streamer(streamer));
-        // auto answer_str = pipe.generate(prompt, ov::genai::max_new_tokens(10000), ov::genai::streamer(streamer));
-        accumulated_str += answer_str;
+        pipe.generate(prompt, config, streamer);
 
-        cout << "\n----------\n";
+        std::cout << "\n----------\n";
     }
     pipe.finish_chat();
 } catch (const std::exception& error) {
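Taken together, the two hunks reduce the sample to a plain interactive loop: the hardcoded questions, the CustomStreamer class, and the accumulated-answer bookkeeping are gone, and output goes through one void(std::string) callback. For readability, here is a sketch of the post-commit sample reassembled as one listing. The LLMPipeline construction sits in the lines elided between the hunks, so reading the model directory from argv[1] and targeting the CPU device are assumptions, as is the catch body:

#include <cstdlib>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main(int argc, char* argv[]) try {
    if (argc != 2)
        throw std::runtime_error(std::string("Usage: ") + argv[0] + " <MODEL_DIR>");
    std::string prompt;
    // Assumed from the elided hunk: model directory from argv[1], CPU device.
    ov::genai::LLMPipeline pipe(argv[1], "CPU");

    ov::genai::GenerationConfig config = pipe.get_generation_config();
    config.max_new_tokens = 10000;
    std::function<void(std::string)> streamer = [](std::string word) { std::cout << word << std::flush; };

    pipe.start_chat();
    for (;;) {
        std::cout << "question:\n";
        std::getline(std::cin, prompt);
        if (prompt == "Stop!")
            break;
        pipe.generate(prompt, config, streamer);  // streams words through the callback as they arrive
        std::cout << "\n----------\n";
    }
    pipe.finish_chat();
} catch (const std::exception& error) {
    std::cerr << error.what() << '\n';
    return EXIT_FAILURE;
}

The void(std::string) signature matches on_decoded_text_callback in the streamer header above, so the lambda can be passed straight through to generate.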
