Skip to content

Commit

Permalink
What if we do something crazy like add layers instead of removing them?
Browse files Browse the repository at this point in the history
  • Loading branch information
KerfuffleV2 committed Oct 20, 2023
1 parent a0c2f5c commit 74eebc6
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 63 deletions.
72 changes: 48 additions & 24 deletions examples/perplexity/perplexity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0);

const int32_t n_layers = 32; // model layer count
const int test_count = 6; // num perplexity chunks to run for each test
const size_t prune_target = 4; // prune this many of the worst results each pass
// end tunables
// model layer count
const int32_t n_layers = 32;

// num perplexity chunks to run for each test
const int test_count = 4;

// prune this many of the worst results each pass
const size_t prune_target = 2;

// start with all but first/last layers disabled and start adding them back
const bool anti_mode = true;

// **** end tunables ***

// 1 = attn, 2 = mlp, 3 = both
int32_t test_skip_type = 0; // but don't mess with this, it's set automatically.
Expand All @@ -340,11 +349,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
skip_types.resize(n_layers);
std::fill(skip_types.begin(), skip_types.end(), 0);
std::vector<std::tuple<int32_t, int32_t, double>> pass_results;
std::vector<int32_t> worsts;
worsts.resize(n_layers);
std::fill(worsts.begin(), worsts.end(), 0);
std::vector<int32_t> extremes;
extremes.resize(n_layers);
std::fill(extremes.begin(), extremes.end(), 0);
if (anti_mode) {
// No pointing in starting with first/last layer disabled.
skip_types[0] = 15;
skip_types[n_layers - 1] = 15;
skips.push_back(0); skips.push_back(0 + n_layers);
skips.push_back(n_layers - 1); skips.push_back(n_layers - 1 + n_layers);
}
int32_t curr_best_layer = -1, curr_best_type = 0;
double curr_best_ppl = -1, ref_ppl = -1;
const int32_t mask = anti_mode ? 3 : 0;

int count = 0;
double nll = 0.0;
Expand Down Expand Up @@ -372,35 +389,40 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
if (skip_layer >= n_layers) {
if (curr_best_layer == -1) break;
if (pass_results.size() >= prune_target * 2) {
if (prune_target > 0 && pass_results.size() >= prune_target * 2) {
std::sort(pass_results.begin(), pass_results.end(),
[](const std::tuple<int32_t, int32_t, double> & a, const std::tuple<int32_t, int32_t, double> & b) {
if (anti_mode) return std::get<2>(b) > std::get<2>(a);
return std::get<2>(a) > std::get<2>(b);
}
);
const size_t num_prune = std::min(pass_results.size(), prune_target);
for (size_t temp = 0; temp < num_prune; temp++) {
for (size_t temp = 0, pruned = 0; temp < pass_results.size(); temp++) {
int32_t lidx = std::get<0>(pass_results[temp]);
if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
worsts[lidx] |= std::get<1>(pass_results[temp]);
printf("\nPrune[%zu]: %d (%d) - %.2f\n", temp, lidx, std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
extremes[lidx] |= std::get<1>(pass_results[temp]);
printf("\nPrune[%zu]: %d (%d) - %.2f\n", pruned + 1, lidx,
std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
if (anti_mode) {
skip_types[lidx] |= std::get<1>(pass_results[temp]);
skips.push_back(std::get<1>(pass_results[temp]) == 1 ? lidx : -lidx);
}
if (++pruned >= num_prune) break;
}
}
pass_results.clear();
printf("\n\nADD SKIP %c%3d - ppl vs ref %.4f",
printf("\n\nADD %c%3d - ppl vs ref %.4f",
int(label[curr_best_type]), curr_best_layer,
curr_best_ppl - ref_ppl);
if (curr_best_ppl > ref_ppl * 1.75) break;
if (!anti_mode && curr_best_ppl > ref_ppl * 1.75) break;
skip_types[curr_best_layer] += curr_best_type;
if (std::find(skips.begin(), skips.end(), curr_best_layer) == skips.end()) {
skips.push_back(curr_best_layer);
}
skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
curr_best_layer = -1;
curr_best_ppl = -1;
curr_best_type = 0;
skip_layer = n_layers;
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
skip_types[new_sl] = (skip_types[new_sl] & 3) | (worsts[new_sl] << 2);
skip_types[new_sl] = (skip_types[new_sl] & 3) | (extremes[new_sl] << 2);
}
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
Expand All @@ -420,16 +442,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
logit_history.clear();
prob_history.clear();

int alive = 0;
for (int32_t i = 0; i < n_layers; i++) {
layers[i] = (skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0);
layers[i] = mask ^ ((skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0));
alive += !(layers[i] & 1) + !(layers[i] & 2);
}
layers[n_layers] = -1;
printf("\nTEST %c%3d + [", int(label[test_skip_type]), skip_layer);
for (const auto l : skips) {
printf("%c%d, ", int(label[skip_types[l] & 3]), l);
for (auto l : skips) {
printf("%c%d, ", int(label[skip_types[l % n_layers] & 3]), l % n_layers);
}
printf("] - len: %3zu, best:(%c%3d @ %.3f), last took %.2f sec\n",
skips.size() + 1,
printf("] - live: %3d/%3d, best:(%c%3d @ %.3f), last took %.2f sec\n",
alive, n_layers * 2,
int(label[curr_best_type]), curr_best_layer,
curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0,
test_t_total);
Expand Down Expand Up @@ -477,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

const auto t_end = std::chrono::high_resolution_clock::now();

if (i == 0 && skip_layer < 0 && skips.empty()) {
if (i == 0 && skip_layer < 0 && ref_ppl < 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
Expand Down Expand Up @@ -516,7 +540,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
}
fflush(stdout);
if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 3))) {
if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 30))) {
i = test_count - 1;
skip_types[skip_layer] |= test_skip_type << 2;
if (curr_best_layer == -1 || ppl < curr_best_ppl) {
Expand Down
Loading

0 comments on commit 74eebc6

Please sign in to comment.