Skip to content

Commit

Permalink
Merge pull request #197 from wangzhaode/feature/add_config
Browse files Browse the repository at this point in the history
Refactor tokenizer, add config.json.
  • Loading branch information
wangzhaode authored Jun 4, 2024
2 parents 0f25187 + 2e9d685 commit 97a976f
Show file tree
Hide file tree
Showing 9 changed files with 694 additions and 990 deletions.
9 changes: 6 additions & 3 deletions demo/cli_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,23 @@ void benchmark(Llm* llm, std::string prompt_file) {
if (prompt.substr(0, 1) == "#") {
continue;
}
std::string::size_type pos = 0;
while ((pos = prompt.find("\\n", pos)) != std::string::npos) {
prompt.replace(pos, 2, "\n");
pos += 1;
}
prompts.push_back(prompt);
}
int prompt_len = 0;
int decode_len = 0;
int64_t prefill_time = 0;
int64_t decode_time = 0;
llm->warmup();
for (int i = 0; i < prompts.size(); i++) {
llm->response(prompts[i]);
prompt_len += llm->prompt_len_;
decode_len += llm->gen_seq_len_;
prefill_time += llm->prefill_us_;
decode_time += llm->decode_us_;
llm->reset();
}
float prefill_s = prefill_time / 1e6;
float decode_s = decode_time / 1e6;
Expand All @@ -54,7 +57,7 @@ int main(int argc, const char* argv[]) {
std::string model_dir = argv[1];
std::cout << "model path is " << model_dir << std::endl;
std::unique_ptr<Llm> llm(Llm::createLLM(model_dir));
llm->load(model_dir);
llm->load();
if (argc < 3) {
llm->chat();
}
Expand Down
2 changes: 1 addition & 1 deletion demo/memory_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ int main(int argc, const char* argv[]) {
if (argc == 4) {
auto llm_dir = argv[3];
std::shared_ptr<Llm> llm(Llm::createLLM(llm_dir));
llm->load(llm_dir);
llm->load();
chat_memory->summarize(llm);
chat_memory->save(memory_dir);
}
Expand Down
42 changes: 28 additions & 14 deletions demo/tokenizer_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,41 @@
//

#include "tokenizer.hpp"
#include <fstream>

int main(int argc, const char* argv[]) {
    // This demo needs both a tokenizer model and a prompt file, so require
    // two positional arguments (argv[1] and argv[2]).  The original guard
    // checked `argc < 2` while unconditionally reading argv[2] — an
    // out-of-bounds access when only the tokenizer path was supplied.
    if (argc < 3) {
        std::cout << "Usage: " << argv[0] << " tokenizer.txt prompt.txt" << std::endl;
        return 0;
    }
    std::string tokenizer_path = argv[1];
    std::string prompt_file = argv[2];
    // Factory selects the concrete tokenizer implementation from the model
    // file; ownership is held by the unique_ptr for the whole run.
    std::unique_ptr<Tokenizer> tokenizer(Tokenizer::createTokenizer(tokenizer_path));

    std::ifstream prompt_fs(prompt_file);
    std::string prompt;
    // Round-trip every prompt line: build the chat template, encode to token
    // ids, then decode each id back and print both for inspection.
    while (std::getline(prompt_fs, prompt)) {
        // prompt start with '#' will be ignored
        if (prompt.substr(0, 1) == "#") {
            continue;
        }
        // The prompt file stores newlines as the two-character escape "\n";
        // expand each occurrence into a real newline character.
        std::string::size_type pos = 0;
        while ((pos = prompt.find("\\n", pos)) != std::string::npos) {
            prompt.replace(pos, 2, "\n");
            pos += 1;
        }
        // Wrap the raw prompt in the ChatML-style template the model expects.
        const std::string query = "\n<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n";
        std::cout << query;
        auto tokens = tokenizer->encode(query);
        std::string decode_str;
        printf("encode tokens = [ ");
        for (auto token : tokens) {
            printf("%d, ", token);
            decode_str += tokenizer->decode(token);
        }
        printf("]\n");
        // If the tokenizer is lossless, decode_str should equal `query`.
        printf("decode str = %s\n", decode_str.c_str());
    }
    return 0;
}
2 changes: 1 addition & 1 deletion demo/web_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ int main(int argc, const char* argv[]) {
std::string web_dir = argv[2];
std::cout << "model path is " << model_dir << std::endl;
std::unique_ptr<Llm> llm(Llm::createLLM(model_dir));
llm->load(model_dir);
llm->load();

std::stringstream ss;
httplib::Server svr;
Expand Down
Loading

0 comments on commit 97a976f

Please sign in to comment.