Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BPE dropout support, use it in training. #2073

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
// initialize to prompt number of chars, since n_tokens <= n_prompt_chars
std::vector<llama_token> res(text.size() + (int) add_bos);
const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos, 0.0);
assert(n >= 0);
res.resize(n);

Expand Down
2 changes: 1 addition & 1 deletion examples/save-load-state/save-load-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ int main(int argc, char ** argv) {
return 1;
}
auto tokens = std::vector<llama_token>(params.n_ctx);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true, 0.0);

if (n_prompt_tokens < 1) {
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2187,7 +2187,7 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto

out.resize(buf.size());

int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false);
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false, 0.1f);
if (n_tokens >= 0) {
out.resize(n_tokens);
}
Expand Down
29 changes: 24 additions & 5 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include <mutex>
#include <sstream>
#include <numeric>
#include <random>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
Expand Down Expand Up @@ -1717,7 +1718,7 @@ struct llama_sp_bigram {
// original implementation:
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
struct llama_tokenizer {
llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
llama_tokenizer(const llama_vocab & vocab, float dropout): vocab_(vocab), dropout_(dropout) {}

void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
// split string into utf8 chars
Expand Down Expand Up @@ -1759,6 +1760,9 @@ struct llama_tokenizer {
right_sym.n = 0;

//printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
if (skip_merge()) {
continue;
}

// remove the right sym from the chain
left_sym.next = right_sym.next;
Expand Down Expand Up @@ -1814,13 +1818,27 @@ struct llama_tokenizer {
work_queue_.push(bigram);
}

// Decides whether the merge currently being processed should be dropped.
//
// Returns false when dropout_ is <= 0 and true when dropout_ is >= 1, without
// touching the random generator in either case. For values in between, one
// sample is drawn from [0, 1) and the merge is skipped when the sample is
// below dropout_.
//
// NOTE(review): rng does not appear to be seeded anywhere in the visible code,
// so each tokenizer instance would replay the same sample sequence - confirm.
bool skip_merge()
{
    const bool never_skip  = dropout_ <= 0.0;
    const bool always_skip = dropout_ >= 1.0;

    // Degenerate probabilities are answered directly; no sample is consumed.
    if (never_skip || always_skip) {
        return always_skip;
    }

    std::uniform_real_distribution<> unit_interval(0.0, 1.0);
    const double sample = unit_interval(rng);
    return sample < dropout_;
}

const llama_vocab & vocab_;
std::vector<llama_sp_symbol> symbols_;
llama_sp_bigram::queue work_queue_;
float dropout_;
std::mt19937 rng;
howard0su marked this conversation as resolved.
Show resolved Hide resolved
};

static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
llama_tokenizer tokenizer(vocab);
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos, float dropout) {
llama_tokenizer tokenizer(vocab, dropout);
std::vector<llama_vocab::id> output;

if (text.empty()) {
Expand Down Expand Up @@ -3407,8 +3425,9 @@ int llama_tokenize(
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos) {
auto res = llama_tokenize(ctx->vocab, text, add_bos);
bool add_bos,
float dropout) {
auto res = llama_tokenize(ctx->vocab, text, add_bos, dropout);

if (n_max_tokens < (int) res.size()) {
fprintf(stderr, "%s: too many tokens\n", __func__);
Expand Down
3 changes: 2 additions & 1 deletion llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ extern "C" {
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
bool add_bos,
float dropout);

LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
Expand Down
2 changes: 1 addition & 1 deletion tests/test-tokenizer-0.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ int main(int argc, char **argv) {

for (const auto & test_kv : k_tests()) {
std::vector<llama_token> res(test_kv.first.size());
const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true);
const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true, 0.0);
res.resize(n);

bool correct = res.size() == test_kv.second.size();
Expand Down