Skip to content

Add support for batch size to --perplexity #407

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 13, 2023
36 changes: 21 additions & 15 deletions examples/perplexity/perplexity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {

int count = 0;
int seq_count = tokens.size() / params.n_ctx;
int n_vocab = llama_n_vocab(ctx);

double nll = 0.0;

fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);

for (int i = 0; i < seq_count; ++i) {
int start = i * params.n_ctx;
int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
// it is better to always be power of 2 for better performance
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
int end = start + params.n_ctx;

std::vector<float> logits;
int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
auto start_t = std::chrono::high_resolution_clock::now();
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
for (int j = 0; j < num_batches; ++j) {
int batch_start = start + j * params.n_batch;
int batch_size = std::min(end - batch_start, params.n_batch);
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
auto batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
auto end_t = std::chrono::high_resolution_clock::now();
if (i == 0) {
Expand All @@ -59,15 +66,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
// Example, we have a context window of 512, we will compute perplexity for each of the
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.

auto logits = llama_get_logits(ctx);
for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
// Calculate probability of next token, given the previous ones.
int n_vocab = llama_n_vocab(ctx);
std::vector<float> tok_logits(
logits + j * n_vocab,
logits + (j + 1) * n_vocab);
const float prob = softmax(tok_logits)[tokens[start + j + 1]];
logits.begin() + j * n_vocab,
logits.begin() + (j + 1) * n_vocab);
float prob = softmax(tok_logits)[tokens[start + j + 1]];
nll += -std::log(prob);
++count;
}
Expand All @@ -82,11 +86,13 @@ int main(int argc, char ** argv) {
gpt_params params;
params.model = "models/llama-7B/ggml-model.bin";

params.n_batch = 512;
if (gpt_params_parse(argc, argv, params) == false) {
return 1;
}

params.perplexity = true;
params.n_batch = std::min(params.n_batch, params.n_ctx);

if (params.n_ctx > 2048) {
fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
Expand Down