Perplexity: Compute scores correlated to HellaSwag #2312

Merged: 6 commits, Jul 22, 2023
examples/common.cpp: 5 changes (4 additions, 1 deletion)
@@ -387,6 +387,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.antiprompt.push_back(argv[i]);
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--perplexity-lines") {
params.perplexity_lines = true;
} else if (arg == "--ignore-eos") {
params.logit_bias[llama_token_eos()] = -INFINITY;
} else if (arg == "--no-penalize-nl") {
@@ -512,7 +514,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
fprintf(stderr, " --perplexity compute perplexity over each ctx window of the prompt\n");
fprintf(stderr, " --perplexity-lines compute perplexity over each line of the prompt\n");
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
fprintf(stderr, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
if (llama_mlock_supported()) {
examples/common.h: 1 change (1 addition, 0 deletions)
@@ -82,6 +82,7 @@ struct gpt_params {
bool instruct = false; // instruction mode (used for Alpaca models)
bool penalize_nl = true; // consider newlines as a repeatable token
bool perplexity = false; // compute perplexity over the prompt
bool perplexity_lines = false; // compute perplexity over each line of the prompt
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
bool mem_test = false; // compute maximum memory usage
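With the new flag parsed in common.cpp and stored in gpt_params, line-wise scoring can be selected from the command line. A minimal usage sketch (the model path and input file below are placeholders, and the example binary is assumed to be built as perplexity):

./perplexity -m models/7B/ggml-model-q4_0.bin -f lines.txt --perplexity-lines

Each line of the input file is tokenized with a leading BOS and scored on its own, so the output is one score per line rather than one per context window.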
examples/perplexity/perplexity.cpp: 78 changes (77 additions, 1 deletion)
@@ -4,6 +4,7 @@

#include <cmath>
#include <ctime>
#include <sstream>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -120,6 +121,77 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
printf("\n");
}

void perplexity_lines(llama_context * ctx, const gpt_params & params) {
// Calculates perplexity over each line of the prompt

std::vector<std::string> prompt_lines;
std::istringstream strstream(params.prompt);
std::string line;

while (std::getline(strstream, line, '\n')) {
prompt_lines.push_back(line);
}

const int n_vocab = llama_n_vocab(ctx);

int counttotal = 0;
size_t n_lines = prompt_lines.size();

double nll = 0.0;

fprintf(stderr, "%s: calculating perplexity over %lu lines\n", __func__, n_lines);

printf("\nLine\tPPL line\tPPL cumulative\n");

for (size_t i = 0; i < n_lines; ++i) {

// Tokenize and insert BOS at start
std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true);

size_t batch_size = batch_embd.size();

// Stop if line is too long
if (batch_size > (size_t) params.n_ctx) {
fprintf(stderr, "%s : tokens in line %lu > n_ctx\n", __func__, i);
return;
}

if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}

const auto batch_logits = llama_get_logits(ctx);
std::vector<float> logits;
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);

double nllline = 0.0;
int countline = 0;

// Perplexity over second half of the line
for (size_t j = batch_size/2; j < batch_size - 1; ++j) {
// Calculate probability of next token, given the previous ones.
const std::vector<float> tok_logits(
logits.begin() + (j + 0) * n_vocab,
logits.begin() + (j + 1) * n_vocab);

const float prob = softmax(tok_logits)[batch_embd[j + 1]];

nllline += -std::log(prob);
++countline;
}

nll += nllline;
counttotal += countline;

// perplexity is e^(average negative log-likelihood)
printf("%lu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline/countline), std::exp(nll / counttotal) );
fflush(stdout);
}

printf("\n");
}

int main(int argc, char ** argv) {
gpt_params params;

@@ -168,7 +240,11 @@ int main(int argc, char ** argv) {
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}

perplexity(ctx, params);
if (params.perplexity_lines) {
perplexity_lines(ctx, params);
} else {
perplexity(ctx, params);
}

llama_print_timings(ctx);
llama_free(ctx);
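For reference, the two columns printed by perplexity_lines follow directly from the loop above. With 0-based tokens t_0 ... t_{M-1} for a line (M = batch_size) and N = countline scored positions, the per-line value is

PPL_line = exp( -(1/N) * sum_{j = M/2}^{M-2} log p(t_{j+1} | t_0, ..., t_j) )

and the cumulative column applies the same formula to nll and counttotal accumulated over all lines so far. Only the second half of each line is scored, so the first half serves purely as context.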