Fix chat scenario. Add test
olpipi committed Jul 1, 2024
1 parent 727e772 commit f95544e
Showing 2 changed files with 82 additions and 4 deletions.
55 changes: 55 additions & 0 deletions .github/workflows/causal_lm_cpp.yml
@@ -524,3 +524,58 @@ jobs:
&& export PYTHONPATH=./build/:$PYTHONPATH
&& timeout 50s samples/python/greedy_causal_lm/greedy_causal_lm.py ./redpajama-3b-chat/ "Alan Turing was a"
| diff ./pred_greedy.txt -
cpp-chat_sample-ubuntu:
runs-on: ubuntu-20.04-16-cores
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install OpenVINO
run: |
mkdir ./ov/
curl https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.2.0rc1/linux/l_openvino_toolkit_ubuntu20_2024.2.0.dev20240524_x86_64.tgz | tar --directory ./ov/ --strip-components 1 -xz
sudo ./ov/install_dependencies/install_openvino_dependencies.sh
- name: Download, convert and build
run: |
source ./ov/setupvars.sh
python -m pip install --upgrade-strategy eager -r ./samples/requirements.txt --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release
python -m pip install ./thirdparty/openvino_tokenizers/[transformers] --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release
optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0
cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/
cmake --build ./build/ --config Release -j
- name: Compare
run: |
source ./ov/setupvars.sh
printf 'What is 2 + 2?\nWhat is the previous answer?\nAdd 1 to it.\nSubtract 5 from it.\nWhy is the sun yellow?\nWhat was my first question?\nStop!\n' > ./input.txt
          timeout 30s ./build/samples/cpp/chat_sample/chat_sample ./TinyLlama-1.1B-Chat-v1.0/ < input.txt > ./pred.txt
python -c "
from transformers import LlamaTokenizer, AutoModelForCausalLM
with open('pred.txt', 'r') as file:
predictions = file.read()
model_id = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'
tokenizer = LlamaTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
prompts = ['What is 2 + 2?', 'What is the previous answer?', 'Add 1 to it.', 'Subtract 5 from it.', 'Why is the sun yellow?', 'What was my first question?']
def gen_prompt(prompt):
return {'role': 'user', 'content': prompt}
def gen_answer(answer):
return {'role': 'assistant', 'content': answer}
chat_history = []
chat_prompt = ''
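# For each turn, re-apply the chat template to the full history and let the
# reference HF model answer with greedy decoding, mirroring what chat_sample builds incrementally.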
for prompt in prompts:
chat_history.append(gen_prompt(prompt))
chat_prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
tokenized = tokenizer(chat_prompt, return_tensors='pt')
answer = model.generate(**tokenized, max_length=1000, do_sample=False)
answer_str = tokenizer.decode(answer[0, tokenized['input_ids'].numel():], skip_special_tokens=True)
chat_history.append(gen_answer(answer_str))
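    # Each reference answer must appear verbatim in chat_sample's output; drop the
    # matched span so repeated answers are checked in order.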
idx = predictions.find(answer_str)
if -1 == idx:
raise RuntimeError(f'Missing "{answer_str=}" from predictions')
predictions = predictions[:idx] + predictions[idx + len(answer_str):]
"
echo "Chat sample?" passed
31 changes: 27 additions & 4 deletions src/cpp/src/llm_pipeline.cpp
@@ -14,6 +14,24 @@
#include "utils.hpp"
#include "text_callback_streamer.hpp"

namespace {

ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& first, const ov::genai::TokenizedInputs& second) {
    // `first` is expected to start with the tokens of `second`; the result keeps only the trailing part.
    auto first_size = first.input_ids.get_size();
    auto second_size = second.input_ids.get_size();
    ov::Shape new_shape{1, first_size - second_size};

    ov::Tensor new_input_ids(ov::element::i64, new_shape);
    auto data_ptr = first.input_ids.data<int64_t>();
    std::copy(data_ptr + second_size, data_ptr + first_size, new_input_ids.data<int64_t>());

    ov::Tensor new_attention_mask(ov::element::i64, new_shape);
    std::fill_n(new_attention_mask.data<int64_t>(), new_shape[1], 1);

    return {new_input_ids, new_attention_mask};
}
}
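
For context (not part of the commit): a minimal standalone sketch of the suffix-subtraction idea the helper above implements. It assumes only ov::Tensor from the OpenVINO runtime; TokenizedPair is a hypothetical stand-in for ov::genai::TokenizedInputs, and the token ids in main() are made up.

// suffix_subtraction_sketch.cpp -- illustrative only, not part of this commit.
#include <openvino/runtime/tensor.hpp>
#include <algorithm>
#include <iostream>
#include <vector>

// Hypothetical stand-in for ov::genai::TokenizedInputs.
struct TokenizedPair {
    ov::Tensor input_ids;
    ov::Tensor attention_mask;
};

TokenizedPair make_inputs(const std::vector<int64_t>& ids) {
    ov::Tensor input_ids(ov::element::i64, ov::Shape{1, ids.size()});
    std::copy(ids.begin(), ids.end(), input_ids.data<int64_t>());
    ov::Tensor attention_mask(ov::element::i64, ov::Shape{1, ids.size()});
    std::fill_n(attention_mask.data<int64_t>(), ids.size(), 1);
    return {input_ids, attention_mask};
}

// Same logic as subtract_chat_tokenized_inputs: keep only the tokens appended
// to the templated history since the previous turn.
TokenizedPair subtract(const TokenizedPair& first, const TokenizedPair& second) {
    auto first_size = first.input_ids.get_size();
    auto second_size = second.input_ids.get_size();
    ov::Shape new_shape{1, first_size - second_size};
    ov::Tensor new_input_ids(ov::element::i64, new_shape);
    auto data_ptr = first.input_ids.data<int64_t>();
    std::copy(data_ptr + second_size, data_ptr + first_size, new_input_ids.data<int64_t>());
    ov::Tensor new_attention_mask(ov::element::i64, new_shape);
    std::fill_n(new_attention_mask.data<int64_t>(), new_shape[1], 1);
    return {new_input_ids, new_attention_mask};
}

int main() {
    // Made-up token ids: the new history starts with the old history as a prefix.
    auto prev = make_inputs({1, 529, 29989, 1792});
    auto curr = make_inputs({1, 529, 29989, 1792, 3532, 13});
    auto delta = subtract(curr, prev);
    for (size_t i = 0; i < delta.input_ids.get_size(); ++i)
        std::cout << delta.input_ids.data<int64_t>()[i] << ' ';  // prints: 3532 13
    std::cout << '\n';
    return 0;
}

Built against an OpenVINO install, this should print only the two trailing ids, which is exactly what the pipeline now feeds to the model on each follow-up turn.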

namespace ov {
namespace genai {

@@ -101,12 +119,17 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
m_history.push_back({{"role", "user"}, {"content", prompt}});
constexpr bool add_generation_prompt = true;
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);

            auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history);
            if (m_is_cache_empty) {
                encoded_input = new_chat_tokens;
            } else {
                // The KV-cache already holds the previous turns, so only the tokens
                // appended to the templated history since the last turn are passed on.
                auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history);
                encoded_input = subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
            }
            m_templated_chat_history = new_templated_chat_history;
        } else {
            encoded_input = m_tokenizer.encode(prompt);
        }

auto encoded_results = generate(encoded_input, config, streamer);
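
The scenario the new CI job exercises end to end is the multi-turn chat loop. A minimal usage sketch, assuming the ov::genai::LLMPipeline chat API (start_chat/generate/finish_chat) and the model directory exported by the workflow above; the device choice and generation settings are illustrative:

// chat_flow_sketch.cpp -- minimal sketch of the multi-turn scenario the test covers.
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <string>

int main() {
    // "TinyLlama-1.1B-Chat-v1.0" is the directory produced by the optimum-cli export
    // step in the workflow above; "CPU" is an assumption.
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    pipe.start_chat();  // from this point the pipeline accumulates the templated history
    for (std::string prompt; std::getline(std::cin, prompt);) {
        // Each call receives only the new user prompt; with the fix above, only the
        // newly appended tokens are fed to the model on every turn after the first.
        std::string answer = pipe.generate(prompt, config);
        std::cout << answer << '\n';
    }
    pipe.finish_chat();  // resets the history and the KV-cache
    return 0;
}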
