diff --git a/src/cpp/src/text_callback_streamer.cpp b/src/cpp/src/text_callback_streamer.cpp
index 39ef3bbcf..4169967bc 100644
--- a/src/cpp/src/text_callback_streamer.cpp
+++ b/src/cpp/src/text_callback_streamer.cpp
@@ -49,20 +49,6 @@ void TextCallbackStreamer::end() {
     on_finalized_text(res.str());
 }
 
-// void TextCallbackStreamer::set_tokenizer(Tokenizer tokenizer) {
-//     this->m_tokenizer = tokenizer;
-// }
-
-// void TextCallbackStreamer::set_callback(std::function<void(std::string)> callback) {
-//     on_decoded_text_callback = callback;
-//     m_enabled = true;
-// }
-
-// void TextCallbackStreamer::set_callback() {
-//     on_decoded_text_callback = [](std::string words){};
-//     m_enabled = false;
-// }
-
 void TextCallbackStreamer::on_finalized_text(const std::string& subword) {
     if (m_enabled) {
         on_decoded_text_callback(subword);
diff --git a/src/cpp/src/text_callback_streamer.hpp b/src/cpp/src/text_callback_streamer.hpp
index 766f80cf9..76e8fdb37 100644
--- a/src/cpp/src/text_callback_streamer.hpp
+++ b/src/cpp/src/text_callback_streamer.hpp
@@ -15,12 +15,7 @@ class TextCallbackStreamer: public StreamerBase {
     void end() override;
 
     TextCallbackStreamer(const Tokenizer& tokenizer, std::function<void(std::string)> callback, bool print_eos_token = false);
-    // ~TextCallbackStreamer() = default;
-
-    // void set_tokenizer(Tokenizer tokenizer);
-    // void set_callback(std::function<void(std::string)> callback);
-    // void set_callback();
-
+
     std::function<void(std::string)> on_decoded_text_callback = [](std::string words){};
     bool m_enabled = false;
     int64_t m_eos_token;
diff --git a/tests/python_tests/list_test_models.py b/tests/python_tests/list_test_models.py
index a9454fc21..b45844f2a 100644
--- a/tests/python_tests/list_test_models.py
+++ b/tests/python_tests/list_test_models.py
@@ -1,7 +1,7 @@
 def models_list():
     model_ids = [
         ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", "TinyLlama-1.1B-Chat-v1.0"),
-        ("microsoft/phi-1_5", "phi-1_5/"),
+        # ("microsoft/phi-1_5", "phi-1_5/"),
 
         # ("google/gemma-2b-it", "gemma-2b-it"),
         # ("google/gemma-7b-it", "gemma-7b-it"),
diff --git a/text_generation/causal_lm/cpp/chat_sample.cpp b/text_generation/causal_lm/cpp/chat_sample.cpp
index 08748b3bb..a288f1f23 100644
--- a/text_generation/causal_lm/cpp/chat_sample.cpp
+++ b/text_generation/causal_lm/cpp/chat_sample.cpp
@@ -1,36 +1,7 @@
 // Copyright (C) 2023-2024 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
-#include <iostream>
 #include "openvino/genai/llm_pipeline.hpp"
-#include "openvino/genai/streamer_base.hpp"
-
-using namespace std;
-
-class CustomStreamer: public ov::genai::StreamerBase {
-public:
-    void put(int64_t token) {
-        std::cout << token << std::endl;
-        /* custom decoding/tokens processing code
-        tokens_cache.push_back(token);
-        std::string text = m_tokenizer.decode(tokens_cache);
-        ...
-        */
-    };
-
-    void end() {
-        /* custom finalization */
-    };
-};
-
-std::vector<std::string> questions = {
-    "1+1=",
-    "what was the previous answer?",
-    "Why is the sky blue?",
-    "4+10=",
-    "What is Intel OpenVINO?",
-    "Can you briefly summarize what I asked you about during this session?",
-};
 
 int main(int argc, char* argv[]) try {
     std::string prompt;
@@ -41,23 +12,19 @@ int main(int argc, char* argv[]) try {
     ov::genai::GenerationConfig config = pipe.get_generation_config();
     config.max_new_tokens = 10000;
 
-    std::function<bool(std::string)> streamer = [](std::string word) { std::cout << word << std::flush; return true; };
-    std::shared_ptr<CustomStreamer> custom_streamer = std::make_shared<CustomStreamer>();
+    std::function<void(std::string)> streamer = [](std::string word) { std::cout << word << std::flush; };
 
     pipe.start_chat();
-    for (size_t i = 0; i < questions.size(); i++) {
-        // std::getline(std::cin, prompt);
-        prompt = questions[i];
-
+    for (;;) {
         std::cout << "question:\n";
-        cout << prompt << endl;
+
+        std::getline(std::cin, prompt);
+        if (prompt == "Stop!")
+            break;
 
-        // auto answer_str = pipe(prompt, config, streamer);
-        auto answer_str = pipe(prompt, ov::genai::generation_config(config), ov::genai::streamer(streamer));
-        // auto answer_str = pipe.generate(prompt, ov::genai::max_new_tokens(10000), ov::genai::streamer(streamer));
-        accumulated_str += answer_str;
+        pipe.generate(prompt, config, streamer);
 
-        cout << "\n----------\n";
+        std::cout << "\n----------\n";
     }
     pipe.finish_chat();
 } catch (const std::exception& error) {
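
Usage note (not part of the patch): a minimal sketch of the streaming API as it stands after this change, built only from the calls the diff itself exercises (get_generation_config, start_chat, generate, finish_chat). The "TinyLlama-1.1B-Chat-v1.0" model directory and the "CPU" device string are placeholder assumptions, not values mandated by the patch.

#include <functional>
#include <iostream>
#include <string>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Hypothetical path to an exported OpenVINO model directory; substitute your own.
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "CPU");

    ov::genai::GenerationConfig config = pipe.get_generation_config();
    config.max_new_tokens = 100;

    // After this patch the sample's callback returns void: it simply
    // consumes each decoded text chunk as it is produced.
    std::function<void(std::string)> streamer = [](std::string word) {
        std::cout << word << std::flush;
    };

    pipe.start_chat();
    pipe.generate("Why is the sky blue?", config, streamer);
    pipe.finish_chat();
    return 0;
}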