[WIP] Add Fill-In-Middle example #2934

Closed
wants to merge 11 commits into from
1 change: 1 addition & 0 deletions .gitignore
@@ -45,6 +45,7 @@ models-mnt
/baby-llama
/beam-search
/save-load-state
/fill-in-middle
build-info.h
arm_neon.h
compile_commands.json
9 changes: 6 additions & 3 deletions Makefile
@@ -1,5 +1,5 @@
# Define the default target now so that it is always the first target
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search tests/test-c.o
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search fill-in-middle tests/test-c.o

# Binaries only useful for tests
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1
@@ -67,8 +67,8 @@ OPT = -Ofast
else
OPT = -O3
endif
CFLAGS = -I. $(OPT) -std=c11 -fPIC
CXXFLAGS = -I. -I./common $(OPT) -std=c++11 -fPIC
CFLAGS = -I. $(OPT) -std=c11 -fPIC -g -fsanitize=address
CXXFLAGS = -I. -I./common $(OPT) -std=c++11 -fPIC -g -fsanitize=address
LDFLAGS =

ifdef LLAMA_DEBUG
@@ -475,6 +475,9 @@ baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

fill-in-middle: examples/fill-in-middle/FIM.c ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

ifneq '' '$(or $(filter clean,$(MAKECMDGOALS)),$(LLAMA_METAL))'
BUILD_TARGETS += metal
endif
5 changes: 5 additions & 0 deletions examples/fill-in-middle/CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET FIM)
add_executable(${TARGET} FIM.c)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
192 changes: 192 additions & 0 deletions examples/fill-in-middle/FIM.c
@@ -0,0 +1,192 @@
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include "../../llama.h"

/*
The FIM (Fill-In-Middle) objective is useful for generating text conditioned on a prefix and a suffix.
For a quick summary of what's going on here, see issue #2818.
*/
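
/*
Prompt layout used below (PSM ordering, the default here):

    <PRE> {prefix} <SUF> {suffix} <MID> ... generated middle ... <EOT>

When `spm` is true, the prefix and suffix blocks are swapped (SPM ordering).
*/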


static inline struct llama_context*
codellama_create_fim_context(const char* model_path, const char** error_message) {
struct llama_context_params params = llama_context_default_params();
params.use_mlock = 1;
struct llama_model* model = llama_load_model_from_file(model_path, params);
if (!model) {
*error_message = "Failed to load model.";
return NULL;
}

struct llama_context* context = llama_new_context_with_model(model, params);
if (!context) {
*error_message = "Failed to create context.";
llama_free_model(model);
return NULL;
}

return context;
}

static inline char*
codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const char* suffix, size_t n_max_tokens, int n_threads, bool spm, const char** error_message) {

int num_tokens;
size_t combined_len = strlen(prefix) + strlen(suffix) + 3;
size_t initial_size = sizeof(llama_token) * combined_len;
llama_token* tokens_end = (llama_token*)malloc(initial_size);
llama_token* tokens = tokens_end;
if (!tokens) {
*error_message = "Failed to allocate memory for tokens.";
return NULL;
}

// Append first part of prompt: <PRE> (or <SUF> in SPM mode), followed by its text.
*tokens_end++ = spm ? llama_token_suffix(ctx) : llama_token_prefix(ctx);
num_tokens = llama_tokenize(ctx, spm ? suffix : prefix, tokens_end, (int)(combined_len - (size_t)(tokens_end - tokens)), 0);
if (num_tokens < 0) {
*error_message = "Failed to tokenize the prompt.";
free(tokens);
return NULL;
}
tokens_end += num_tokens;

// Append second part of prompt: <SUF> (or <PRE> in SPM mode), followed by its text.
*tokens_end++ = spm ? llama_token_prefix(ctx) : llama_token_suffix(ctx);
num_tokens = llama_tokenize(ctx, spm ? prefix : suffix, tokens_end, (int)(combined_len - (size_t)(tokens_end - tokens)), 0);
if (num_tokens < 0) {
*error_message = "Failed to tokenize the prompt.";
free(tokens);
return NULL;
}
tokens_end += num_tokens;

// Append middle token
*tokens_end++ = llama_token_middle(ctx);

// Grow to accommodate the prompt and the max amount of generated tokens
size_t prompt_len = (size_t)(tokens_end - tokens);
size_t min_len = (prompt_len + n_max_tokens);
if (min_len > combined_len) {
llama_token* new_tokens = (llama_token*)realloc(tokens, sizeof(llama_token) * min_len);
if (!new_tokens) {
*error_message = "Failed to allocate memory for tokens.";
free(tokens);
return NULL;
}
tokens = new_tokens;
}

// Evaluate the LM on the prompt.
if (llama_eval(ctx, tokens, prompt_len, 0, n_threads)) {
*error_message = "Failed to evaluate the prompt.";
free(tokens);
return NULL;
}

// Generate tokens until n_max_tokens or the <EOT> token is generated.
llama_token* generated_tokens = tokens + prompt_len;
size_t num_generated_tokens = 0;
int vocab_size = llama_n_vocab(ctx);
for (size_t i = 0; i < n_max_tokens; i++) {
// Evaluate the LM on the last generated token to obtain the logits for the next one.
// On the first iteration the logits from the prompt evaluation above are still current.
if (num_generated_tokens > 0 &&
llama_eval(ctx, &generated_tokens[num_generated_tokens - 1], 1, (int)(prompt_len + num_generated_tokens - 1), n_threads)) {
*error_message = "Failed to evaluate a generated token.";
free(tokens);
return NULL;
}
float* logits = llama_get_logits(ctx);

// From the logits, select the most likely token (greedy sampling).
float highest_logit = logits[0];
llama_token likeliest_token = 0;
for (llama_token token_id = 1; token_id < vocab_size; token_id++) {
if (logits[token_id] > highest_logit) {
highest_logit = logits[token_id];
likeliest_token = token_id;
}
}

// Stop at <EOT> without adding it to the output.
if (likeliest_token == llama_token_eot(ctx)) {
break;
}

// Append the token, so it's there for subsequent evaluations.
generated_tokens[num_generated_tokens++] = likeliest_token;

// Translate the token to a string and print it as it is generated.
char cs[32] = {0};
int token_length = llama_token_to_piece(ctx, likeliest_token, cs, sizeof(cs) - 1);
if (token_length > 0) {
cs[token_length] = '\0';
printf("%s\n", cs);
}
}

// Allocate memory for the final result
size_t result_length = 0;
size_t result_capacity = 4096;
char* result = (char*)malloc(sizeof(char) * result_capacity);
if (!result) {
*error_message = "Failed to allocate memory for result.";
free(tokens);
return NULL;
}

// Translate tokens to string, appending at the current end and growing the allocation if it's too small.
for (size_t i = 0; i < num_generated_tokens; i++) {
int appended = llama_token_to_piece(ctx, generated_tokens[i], result + result_length, result_capacity - result_length - 1);
if (appended < 0) {
i--; // retry the token with a larger buffer
size_t new_capacity = result_capacity * 2;
char* new_result = (char*)realloc(result, sizeof(char) * new_capacity);
if (!new_result) {
*error_message = "Failed to allocate memory for result.";
free(tokens);
free(result);
return NULL;
}
result = new_result;
result_capacity = new_capacity;
continue;
}

result_length += appended;
}
result[result_length] = '\0';

free(tokens);
*error_message = NULL;
return result;
}

int main(int argc, char** argv) {
if (argc != 6) {
fprintf(stderr, "Usage: %s <model> <prefix> <suffix> <n_max_tokens> <n_threads>\n", argv[0]);
return 1;
}

char* model = argv[1];
char* prefix = argv[2];
char* suffix = argv[3];
size_t n_max_tokens = atoi(argv[4]) > 0 ? atoi(argv[4]) : 64;
int n_threads = atoi(argv[5]);
bool spm = false;
const char* error_message = NULL;

puts("Loading the model. This could take quite a while...");
struct llama_context* ctx = codellama_create_fim_context(model, &error_message);
if (error_message) {
fprintf(stderr, "Error: %s\n", error_message);
return 1;
}

puts("Model loaded. Generating text...");
char* result = codellama_fill_in_middle(ctx, prefix, suffix, n_max_tokens, n_threads, spm, &error_message);
if (error_message) {
fprintf(stderr, "Error: %s\n", error_message);
return 1;
}

puts("Generated text:");
printf("%s%s%s\n", prefix, result, suffix);

free(result);
llama_free(ctx);
}
35 changes: 35 additions & 0 deletions examples/fill-in-middle/README.md
@@ -0,0 +1,35 @@

# Fill-In-Middle example

The FIM (Fill-In-Middle) objective is useful for generating text conditioned on a prefix and a suffix.
This example uses codellama to do exactly that.

For a quick summary of what's going on here, see issue #2818, and/or read [the FIM paper](https://arxiv.org/abs/2207.14255).

```
Usage: ./fill-in-middle <model> <prefix> <suffix> <n_max_tokens> <n_threads>
```
```sh
./fill-in-middle \
CodeLlama-34B-GGUF/codellama-34b.Q4_K_S.gguf \
$'def add(a, b):\n' \
$'\n' \
64 \
4
```

With prefix:
```py
def add(a, b):

```

And a newline as suffix:
```py

```

We can expect it to generate something like:
```py
return a + b
```
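
Under the hood, the example assembles the prompt in PSM order using the FIM tokens exposed by `llama.h` (`llama_token_prefix`, `llama_token_suffix`, `llama_token_middle`). A condensed sketch of that step (error handling omitted; it assumes a loaded `llama_context` and a token buffer `out` with enough room — `build_psm_prompt` is just an illustrative name):

```c
// PSM ordering: <PRE> {prefix} <SUF> {suffix} <MID>, then generate until <EOT>.
static int build_psm_prompt(struct llama_context * ctx,
                            const char * prefix, const char * suffix,
                            llama_token * out, int n_max) {
    int n = 0;
    out[n++] = llama_token_prefix(ctx);                        // <PRE>
    n += llama_tokenize(ctx, prefix, out + n, n_max - n, 0);   // prefix text
    out[n++] = llama_token_suffix(ctx);                        // <SUF>
    n += llama_tokenize(ctx, suffix, out + n, n_max - n, 0);   // suffix text
    out[n++] = llama_token_middle(ctx);                        // <MID>: the model fills in from here
    return n;                                                  // prompt length in tokens
}
```

With the `spm` flag the prefix and suffix blocks are swapped (SPM ordering), as described in the FIM paper.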
23 changes: 23 additions & 0 deletions llama.cpp
@@ -964,6 +964,13 @@ struct llama_vocab {

id linefeed_id = 13;

// codellama FIM special tokens
// TODO: load these from the vocabulary.
id special_prefix_id = 32007;
id special_middle_id = 32009;
id special_suffix_id = 32008;
id special_eot_id = 32010;
Member:

Curious, does Code Llama 34B have these special tokens?
If it does not, then how would FIM work with it?

Contributor Author:

It does, yeah. These are new, and I think only in codellama. I don't think they're in llama2. To get the token ids themselves, the codellama people run the tokenizer, and these are the values that came out.

https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/tokenizer.py#L28-L31

It should work. But I've been busy with my day job, and haven't gotten a chance to test it yet. Definitely not going to suggest merging until I'm certain.
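
For reference, a minimal (hypothetical, not part of this PR) sanity check using the new accessors: it prints the vocab pieces behind the hardcoded IDs so they can be eyeballed against the reference tokenizer. It assumes `<stdio.h>` and `llama.h` are included and that a CodeLlama context has already been loaded:

```c
static void print_fim_tokens(struct llama_context * ctx) {
    const llama_token ids[4] = {
        llama_token_prefix(ctx), llama_token_middle(ctx),
        llama_token_suffix(ctx), llama_token_eot(ctx),
    };
    for (int i = 0; i < 4; i++) {
        char buf[32] = {0};
        int n = llama_token_to_piece(ctx, ids[i], buf, sizeof(buf) - 1);
        printf("%d -> %s\n", ids[i], n > 0 ? buf : "(failed)");
    }
}
```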


int find_bpe_rank(std::string token_left, std::string token_right) const {
replace_all(token_left, " ", "\u0120");
replace_all(token_left, "\n", "\u010A");
@@ -6132,6 +6139,22 @@ llama_token llama_token_nl(const struct llama_context * ctx) {
return ctx->model.vocab.linefeed_id;
}

llama_token llama_token_prefix(const struct llama_context * ctx) {
return ctx->model.vocab.special_prefix_id;
}

llama_token llama_token_middle(const struct llama_context * ctx) {
return ctx->model.vocab.special_middle_id;
}

llama_token llama_token_suffix(const struct llama_context * ctx) {
return ctx->model.vocab.special_suffix_id;
}

llama_token llama_token_eot(const struct llama_context * ctx) {
return ctx->model.vocab.special_eot_id;
}

int llama_tokenize(
struct llama_context * ctx,
const char * text,
6 changes: 6 additions & 0 deletions llama.h
@@ -361,6 +361,12 @@ extern "C" {
LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line

// codellama FIM tokens
LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of FIM prefix
LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of FIM middle
LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of FIM suffix
LLAMA_API llama_token llama_token_eot (const struct llama_context * ctx); // End of FIM middle

//
// Tokenization
//