Skip to content

Commit

Permalink
Fixing the last deviations from sentencepiece indicated by test-token…
Browse files Browse the repository at this point in the history
…izer-1 (ggerganov#3170)

* Fix für ggerganov#2721

* Reenable tokenizer test for LLaMa

* Add `console.cpp` dependency

* Fix dependency to `common`

* Fixing wrong fix.

* Make console usage platform specific

Work on compiler warnings.

* Adapting makefile

* Remove trailing whitespace

* Adapting the other parts of the makefile

* Fix typo.

* Fixing the last deviations from sentencepiece indicated by test-tokenizer-1

* Simplify logic

* Add missing change...

* Fix ugly compiler warning

* llama_tokenize should accept strings containing NUL now

* Adding huichen's test case
  • Loading branch information
goerch authored and pkrmf committed Sep 26, 2023
1 parent 50cf679 commit f46d26f
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 14 deletions.
4 changes: 2 additions & 2 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -801,10 +801,10 @@ std::vector<llama_token> llama_tokenize(
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
Expand Down
4 changes: 2 additions & 2 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,10 +965,10 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto

buf[size] = '\0';

int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
int n_tokens = llama_tokenize(lctx, buf.data(), buf.size(), out.data(), out.size(), false);
if (n_tokens < 0) {
out.resize(-n_tokens);
n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
n_tokens = llama_tokenize(lctx, buf.data(), buf.size(), out.data(), out.size(), false);
}
GGML_ASSERT(n_tokens >= 0);
out.resize(n_tokens);
Expand Down
6 changes: 4 additions & 2 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7032,19 +7032,21 @@ llama_token llama_token_nl(const struct llama_context * ctx) {
int llama_tokenize(
struct llama_context * ctx,
const char * text,
int text_len,
llama_token * tokens,
int n_max_tokens,
bool add_bos) {
return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos);
return llama_tokenize_with_model(&ctx->model, text, text_len, tokens, n_max_tokens, add_bos);
}

int llama_tokenize_with_model(
const struct llama_model * model,
const char * text,
int text_len,
llama_token * tokens,
int n_max_tokens,
bool add_bos) {
auto res = llama_tokenize_internal(model->vocab, text, add_bos);
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos);

if (n_max_tokens < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
Expand Down
2 changes: 2 additions & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,13 +374,15 @@ extern "C" {
LLAMA_API int llama_tokenize(
struct llama_context * ctx,
const char * text,
int text_len,
llama_token * tokens,
int n_max_tokens,
bool add_bos);

LLAMA_API int llama_tokenize_with_model(
const struct llama_model * model,
const char * text,
int text_len,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
Expand Down
1 change: 1 addition & 0 deletions tests/test-tokenizer-0-llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() {
{ " Hello" , { 1678, 15043, }, },
{ " Hello" , { 268, 15043, }, },
{ " Hello\n Hello" , { 268, 15043, 13, 1678, 15043, }, },
{ " (" , { 29871, 313, }, },
};

return _k_tests;
Expand Down
14 changes: 6 additions & 8 deletions tests/test-tokenizer-1-llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,9 @@ int main(int argc, char **argv) {
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_spm(ctx, tokens);
if (check != str) {
fprintf(stderr, "%s : error: token %d detokenizes to >%s<(%llu) but tokenization of this detokenizes to >%s<(%llu)\n",
fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
__func__, i, str.c_str(), str.length(), check.c_str(), check.length());
if(i != 3)
return 2;
return 2;
}
}

Expand All @@ -99,11 +98,10 @@ int main(int argc, char **argv) {
std::string str = codepoint_to_utf8(cp);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_spm(ctx, tokens);
if (str != check) {
fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n",
if (cp != 9601 && str != check) {
fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
if(cp != 0 && cp != 9601)
return 3;
return 3;
}
}
}
Expand All @@ -112,7 +110,7 @@ int main(int argc, char **argv) {
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_spm(ctx, tokens);
if (str != check) {
fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n",
fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
return 4;
}
Expand Down

0 comments on commit f46d26f

Please sign in to comment.