Initialize tokenizer and simplify str_lookup prototype
- new function init_tokenizer to initialize the tokenizer with a given vocab_size
- removed vocab_size from the arguments of build_tokenizer
- applied the changes in run.c, runq.c, test.c
- pass the Tokenizer object to str_lookup
  - makes the call easy to follow: "oh! we will be checking for the string in the Tokenizer"
- passes: unit test (test.c)
- passes: integration test (./run stories15M.bin on MacBook M1)
pagakarthik committed Feb 25, 2024
1 parent b3c4b6c commit bffee1f
Showing 3 changed files with 36 additions and 27 deletions.
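At the call sites, tokenizer setup is now a two-step sequence. A minimal sketch of the new usage in C, assuming a loaded Transformer whose config carries vocab_size (as in run.c's main):

    Tokenizer tokenizer;
    // step 1: allocate the vocab arrays, sized by the model's vocab_size
    init_tokenizer(&tokenizer, transformer.config.vocab_size);
    // step 2: fill max_token_length, scores, and strings from the tokenizer .bin file
    build_tokenizer(&tokenizer, tokenizer_path);
    // lookups inside encode() now take the whole Tokenizer:
    //     int id = str_lookup(str_buffer, t);
    // (str_lookup expects t->sorted_vocab to be populated; encode() builds it lazily)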
34 changes: 19 additions & 15 deletions run.c
@@ -378,27 +378,30 @@ typedef struct {
     unsigned char byte_pieces[512]; // stores all single-byte strings
 } Tokenizer;

-int compare_tokens(const void *a, const void *b) {
-    return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
-}
-
-void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
-    // i should have written the vocab_size into the tokenizer file... sigh
+void init_tokenizer(Tokenizer* t, int vocab_size){
+    // allocate memory based on the specified vocab_size
     t->vocab_size = vocab_size;
     // malloc space to hold the scores and the strings
-    t->vocab = (char**)malloc(vocab_size * sizeof(char*));
-    t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
+    t->vocab = (char**)malloc(t->vocab_size * sizeof(char*));
+    t->vocab_scores = (float*)malloc(t->vocab_size * sizeof(float));
     t->sorted_vocab = NULL; // initialized lazily
     for (int i = 0; i < 256; i++) {
         t->byte_pieces[i * 2] = (unsigned char)i;
         t->byte_pieces[i * 2 + 1] = '\0';
     }
+}
+
+int compare_tokens(const void *a, const void *b) {
+    return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
+}
+
+void build_tokenizer(Tokenizer* t, char* tokenizer_path) {
     // read in the file
     FILE *file = fopen(tokenizer_path, "rb");
     if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
     if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
     int len;
-    for (int i = 0; i < vocab_size; i++) {
+    for (int i = 0; i < t->vocab_size; i++) {
         if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
         if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
         t->vocab[i] = (char *)malloc(len + 1);
@@ -442,10 +445,10 @@ void safe_printf(char *piece) {
     printf("%s", piece);
 }

-int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
+int str_lookup(char *str, Tokenizer* t) {
     // efficiently find the perfect match for str in vocab, return its index or -1 if not found
     TokenIndex tok = { .str = str }; // acts as the key to search for
-    TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
+    TokenIndex *res = bsearch(&tok, t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
     return res != NULL ? res->id : -1;
 }

@@ -480,7 +483,7 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
     // TODO: pretty sure this isn't correct in the general case but I don't have the
     // energy to read more of the sentencepiece code to figure out what it's doing
     if (text[0] != '\0') {
-        int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
+        int dummy_prefix = str_lookup(" ", t);
         tokens[(*n_tokens)++] = dummy_prefix;
     }

@@ -517,7 +520,7 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
         }

         // ok c+1 is not a continuation byte, so we've read in a full codepoint
-        int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+        int id = str_lookup(str_buffer, t);

         if (id != -1) {
             // we found this codepoint in vocab, add it as a token
@@ -542,7 +545,7 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
         for (int i=0; i < (*n_tokens-1); i++) {
             // check if we can merge the pair (tokens[i], tokens[i+1])
             sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
-            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+            int id = str_lookup(str_buffer, t);
             if (id != -1 && t->vocab_scores[id] > best_score) {
                 // this merge pair exists in vocab! record its score and position
                 best_score = t->vocab_scores[id];
@@ -948,7 +951,8 @@ int main(int argc, char *argv[]) {

     // build the Tokenizer via the tokenizer .bin file
     Tokenizer tokenizer;
-    build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
+    init_tokenizer(&tokenizer, transformer.config.vocab_size);
+    build_tokenizer(&tokenizer, tokenizer_path);

     // build the Sampler
     Sampler sampler;
26 changes: 15 additions & 11 deletions runq.c
@@ -501,13 +501,16 @@ int compare_tokens(const void *a, const void *b) {
     return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
 }

-void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
-    // i should have written the vocab_size into the tokenizer file... sigh
+void init_tokenizer(Tokenizer* t, int vocab_size){
+    // allocate memory based on the specified vocab_size
     t->vocab_size = vocab_size;
     // malloc space to hold the scores and the strings
-    t->vocab = (char**)malloc(vocab_size * sizeof(char*));
-    t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
+    t->vocab = (char**)malloc(t->vocab_size * sizeof(char*));
+    t->vocab_scores = (float*)malloc(t->vocab_size * sizeof(float));
     t->sorted_vocab = NULL; // initialized lazily
+}
+
+void build_tokenizer(Tokenizer* t, char* tokenizer_path) {
     for (int i = 0; i < 256; i++) {
         t->byte_pieces[i * 2] = (unsigned char)i;
         t->byte_pieces[i * 2 + 1] = '\0';
@@ -517,7 +520,7 @@ void build_tokenizer(Tokenizer* t, char* tokenizer_path) {
     if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
     if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
     int len;
-    for (int i = 0; i < vocab_size; i++) {
+    for (int i = 0; i < t->vocab_size; i++) {
         if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
         if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
         t->vocab[i] = (char *)malloc(len + 1);
@@ -561,10 +564,10 @@ void safe_printf(char *piece) {
     printf("%s", piece);
 }

-int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
+int str_lookup(char *str, Tokenizer* t) {
     // efficiently find the perfect match for str in vocab, return its index or -1 if not found
     TokenIndex tok = { .str = str }; // acts as the key to search for
-    TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
+    TokenIndex *res = bsearch(&tok, t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
     return res != NULL ? res->id : -1;
 }

@@ -599,7 +602,7 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
     // TODO: pretty sure this isn't correct in the general case but I don't have the
     // energy to read more of the sentencepiece code to figure out what it's doing
     if (text[0] != '\0') {
-        int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
+        int dummy_prefix = str_lookup(" ", t);
         tokens[(*n_tokens)++] = dummy_prefix;
     }

@@ -636,7 +639,7 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
         }

         // ok c+1 is not a continuation byte, so we've read in a full codepoint
-        int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+        int id = str_lookup(str_buffer, t);

         if (id != -1) {
             // we found this codepoint in vocab, add it as a token
@@ -661,7 +664,7 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
         for (int i=0; i < (*n_tokens-1); i++) {
             // check if we can merge the pair (tokens[i], tokens[i+1])
             sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
-            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+            int id = str_lookup(str_buffer, t);
             if (id != -1 && t->vocab_scores[id] > best_score) {
                 // this merge pair exists in vocab! record its score and position
                 best_score = t->vocab_scores[id];
@@ -1067,7 +1070,8 @@ int main(int argc, char *argv[]) {

     // build the Tokenizer via the tokenizer .bin file
     Tokenizer tokenizer;
-    build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
+    init_tokenizer(&tokenizer, transformer.config.vocab_size);
+    build_tokenizer(&tokenizer, tokenizer_path);

     // build the Sampler
     Sampler sampler;
3 changes: 2 additions & 1 deletion test.c
@@ -43,7 +43,8 @@ void test_prompt_encodings() {
     char *tokenizer_path = "tokenizer.bin";
     int vocab_size = 32000;
     Tokenizer tokenizer;
-    build_tokenizer(&tokenizer, tokenizer_path, vocab_size);
+    init_tokenizer(&tokenizer, vocab_size);
+    build_tokenizer(&tokenizer, tokenizer_path);

     // test 0 (test the empty string) (I added this as a simple case)
     char *prompt0 = "";
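Context for the simplified prototype: str_lookup now reads both the search array and its length from the Tokenizer, so it depends on t->sorted_vocab being populated (it is NULL right after init, per the "initialized lazily" comment). In upstream llama2.c that population happens at the top of encode(); a sketch of the pattern, reconstructed from the upstream code rather than shown in this diff:

    if (t->sorted_vocab == NULL) {
        // lazily allocate and sort the vocabulary so bsearch can be used
        t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
        for (int i = 0; i < t->vocab_size; i++) {
            t->sorted_vocab[i].str = t->vocab[i];
            t->sorted_vocab[i].id = i;
        }
        qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
    }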
