forked from ggerganov/llama.cpp
Commit: batch of input prompts, first try: an experimental app was added:
- it reads the maximum number of inputs from the command line
- it processes input with llama_eval_batch
Showing 7 changed files with 921 additions and 10 deletions.
@@ -0,0 +1,8 @@
set(TARGET input-batches-experiment)
add_executable(${TARGET} main.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama)
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
  add_dependencies(${TARGET} BUILD_INFO)
endif()
@@ -0,0 +1,270 @@
#include "common.h"

#include "console.h"
#include "llama.h"
#include "build-info.h"
#include "grammar-parser.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <tuple>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <signal.h>
#endif

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
int main(int argc, char ** argv) {
    gpt_params params;

    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    {
        // fixme: hardcoded only for tests
        const int32_t batches = 4;

        if (params.n_inputs < batches) {
            params.n_inputs = batches;
            LOG_TEE("%s: warning: maximum number of parallel inputs set to %d\n", __func__, params.n_inputs);
        }
    }

    params.perplexity = true;
    params.seed = 42;

    // save choice to use color for later
    // (note for later: this is a slightly awkward choice)
    console::init(params.simple_io, params.use_color);
    atexit([]() { console::cleanup(); });

    if (params.rope_freq_base != 10000.0) {
        LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
    }

    if (params.rope_freq_scale != 1.0) {
        LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
    }

    LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);

    LOG_TEE("%s: seed = %u\n", __func__, params.seed);

    std::mt19937 rng(params.seed);
    if (params.random_prompt) {
        params.prompt = gpt_random_prompt(rng);
    }

    LOG("%s: llama backend init\n", __func__);
    llama_backend_init(params.numa);

    llama_model * model;
    llama_context * ctx;
    llama_context * ctx_guidance = NULL;

    // load the model and apply lora adapter, if any
    LOG("%s: load the model and apply lora adapter, if any\n", __func__);
    std::tie(model, ctx) = llama_init_from_gpt_params(params);

    if (params.cfg_scale > 1.f) {
        struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
        ctx_guidance = llama_new_context_with_model(model, lparams);
    }

    if (model == NULL) {
        LOG_TEE("%s: error: unable to load model\n", __func__);
        return 1;
    }

    const int n_ctx_train = llama_n_ctx_train(ctx);
    if (params.n_ctx > n_ctx_train) {
        LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
                __func__, n_ctx_train, params.n_ctx);
    } else if (params.n_ctx < 8) {
        LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
        params.n_ctx = 8;
    }

    // print system information
    {
        LOG_TEE("\n");
        LOG_TEE("system_info: n_threads = %d / %d | %s\n",
                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
    }

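    // helper: for each of the `batch` inputs and each of its `seq` positions,
    // print the token that was fed in and the first logit of the matching row
    // of the buffer returned by llama_get_logits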
    auto print = [](llama_context *ctx_, const std::vector<llama_token>& input, int32_t batch, int32_t seq, int32_t n_past) {
        auto n_vocab = llama_n_vocab(ctx_);
        auto logits = llama_get_logits(ctx_);

        for (int i = 0; i < batch; ++i) {
            for (int j = 0; j < seq; ++j) {
                LOG_TEE("\tbatch = %d, seq = %d (n_past = %d), token = %d, first logit = %g\n", i, j, n_past, input[i * seq + j], *(logits + n_vocab * (i * seq + j)));
            }
            if (i != batch - 1) {
                LOG_TEE("\t---\n");
            }
        }
        LOG_TEE("\t==========\n");
    };

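    // helper: evaluate each input chunk sequentially with llama_eval
    // (a single prompt, no batching), advancing n_past by the chunk length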
    auto run_simple = [print, &params](const char* name, llama_context *ctx_, const std::vector<std::vector<llama_token>> &inputs) {
        LOG_TEE("RUN: %s\n", name);

        int n_past = 0;
        for (const auto &input: inputs) {
            int seq = input.size();

            llama_eval(ctx_, input.data(), seq, n_past, params.n_threads);
            print(ctx_, input, 1, seq, n_past);

            n_past += seq;
        }
    };

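    // helper: evaluate each input chunk with llama_eval_batch, treating it as
    // `batch` prompts packed back-to-back; when compare_batches is set and
    // batch > 1, slice the returned logits per prompt and report whether all
    // parallel prompts produced identical results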
    auto run_par = [print, &params](const char* name, llama_context *ctx_, int32_t batch, const std::vector<std::vector<llama_token>> &inputs, bool compare_batches = true) {
        LOG_TEE("RUN: %s\n", name);

        std::vector<std::vector<float>> logits(batch);

        int n_past = 0;
        for (const auto &input: inputs) {
            int seq = input.size() / batch;

            llama_eval_batch(ctx_, input.data(), seq, batch, n_past, params.n_threads);
            print(ctx_, input, batch, seq, n_past);

            n_past += seq;

            if (compare_batches && batch > 1) {
                auto* logits_begin = llama_get_logits(ctx_);
                auto n_vocab = llama_n_vocab(ctx_);
                for (int i = 0; i < batch; ++i) {
                    logits[i].assign(logits_begin + n_vocab * seq * i,
                                     logits_begin + n_vocab * seq * (i + 1));
                }
            }
        }

        if (compare_batches && batch > 1) {
            auto equal = std::all_of(logits.begin() + 1, logits.end(), [&](const std::vector<float> &batch_logits) {
                return logits[0] == batch_logits;
            });
            LOG_TEE("\tAll parallel input results are equal: %s\n", equal ? "TRUE" : "FALSE");
        }
    };

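    // test inputs: a tiny fixed prompt and the two batch sizes to exercise
    // (half of n_inputs, clamped to at least 1, and n_inputs itself)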
    const std::vector<llama_token> tokens{2, 3, 5, 7};
    const std::vector<int32_t> parallels{std::max(params.n_inputs / 2, 1), params.n_inputs};

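    // experiment 1: the whole prompt in a single chunk, first with llama_eval,
    // then with llama_eval_batch for batch 1 and for each parallel batch size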
    {
        const std::vector<llama_token> base_input{tokens};

        {
            const int32_t batch = 1;
            run_simple("Single prompt processing (llama_eval)", ctx, {base_input});
            run_par("Single prompt processing (llama_eval_batch)", ctx, batch, {base_input});
        }

        for (auto p: parallels) {
            std::vector<llama_token> parallel_input;
            for (int i = 0; i < p; ++i) {
                parallel_input.insert(parallel_input.end(), base_input.begin(), base_input.end());
            }

            std::string name = "Parallel prompt processing, " + std::to_string(p) + " parallel input prompts (llama_eval_batch)";
            run_par(name.c_str(), ctx, p, {parallel_input});
        }
    }
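    // experiment 2: the same prompt fed one token at a time, again single and
    // with each parallel batch size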
    {
        std::vector<std::vector<llama_token>> base_input;
        for (llama_token token : tokens) {
            base_input.push_back({token});
        }

        {
            const int32_t batch = 1;
            run_simple("Token-by-Token Single prompt processing (llama_eval)", ctx, base_input);
            run_par("Token-by-Token Single prompt processing (llama_eval_batch)", ctx, batch, base_input);
        }

        for (auto p: parallels) {
            std::vector<std::vector<llama_token>> parallel_input;
            for (const auto& input : base_input) {
                parallel_input.emplace_back();
                auto& current = parallel_input.back();
                for (int i = 0; i < p; ++i) {
                    current.insert(current.end(), input.begin(), input.end());
                }
            }

            std::string name = "Token-by-Token Parallel prompt processing, " + std::to_string(p) + " parallel input prompts (llama_eval_batch)";
            run_par(name.c_str(), ctx, p, parallel_input);
        }
    }
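    // experiment 3: two different prompts (the base tokens and the base tokens
    // with each value doubled), first evaluated separately with llama_eval,
    // then packed into a single llama_eval_batch call; each prompt's logits
    // should match its standalone run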
    {
        using logits_storage = std::array<std::vector<float>, 2>;
        logits_storage base_logits_1;
        logits_storage base_logits_2;

        auto reserve_storage = [&](logits_storage& storage) {
            auto reserve_logits = [&](logits_storage::value_type& logits) {
                logits.reserve(llama_n_vocab(ctx) * tokens.size());
            };
            reserve_logits(storage[0]);
            reserve_logits(storage[1]);
        };

        reserve_storage(base_logits_1);
        reserve_storage(base_logits_2);

        std::vector<llama_token> base_input_1{tokens};
        std::vector<llama_token> base_input_2{tokens};
        std::transform(base_input_2.begin(), base_input_2.end(), base_input_2.begin(), [](llama_token val) { return val * 2; });
        {
            run_simple("Another Single prompt processing (llama_eval)", ctx, {base_input_1});
            base_logits_1[0].assign(llama_get_logits(ctx), llama_get_logits(ctx) + llama_n_vocab(ctx) * tokens.size());

            run_simple("Yet Another Single prompt processing (llama_eval)", ctx, {base_input_2});
            base_logits_2[0].assign(llama_get_logits(ctx), llama_get_logits(ctx) + llama_n_vocab(ctx) * tokens.size());
        }
        std::vector<llama_token> combined_input{base_input_1};
        combined_input.insert(combined_input.end(), base_input_2.begin(), base_input_2.end());
        run_par("Combined prompts (llama_eval_batch)", ctx, 2, {combined_input}, false);

        base_logits_1[1].assign(llama_get_logits(ctx), llama_get_logits(ctx) + llama_n_vocab(ctx) * tokens.size());
        base_logits_2[1].assign(llama_get_logits(ctx) + llama_n_vocab(ctx) * tokens.size(), llama_get_logits(ctx) + llama_n_vocab(ctx) * tokens.size() * 2);

        bool batch_equal_1 = base_logits_1[0] == base_logits_1[1];
        bool batch_equal_2 = base_logits_2[0] == base_logits_2[1];
        LOG_TEE("\tFirst batch results are equal: %s\n", batch_equal_1 ? "TRUE" : "FALSE");
        LOG_TEE("\tSecond batch results are equal: %s\n", batch_equal_2 ? "TRUE" : "FALSE");
        LOG_TEE("\tAll parallel input results are equal: %s\n", batch_equal_1 && batch_equal_2 ? "TRUE" : "FALSE");
    }

    if (ctx_guidance) { llama_free(ctx_guidance); }
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();

    return 0;
}
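For reference, here is a minimal standalone sketch (not part of the commit) of the flat, batch-major token layout that llama_eval_batch appears to consume, inferred from its call sites and the input[i * seq + j] indexing above. The llama_token alias and the signature shown in the comment are assumptions, not the fork's actual declarations.

// sketch only: the layout below is inferred from main.cpp above; the
// signature in this comment is a guess at the fork's entry point:
//   llama_eval_batch(ctx, tokens, n_seq, n_batch, n_past, n_threads)
// where `tokens` holds n_batch * n_seq values packed batch-major.
#include <cstdint>
#include <cstdio>
#include <vector>

typedef int32_t llama_token; // assumed to match the typedef in llama.h

int main() {
    const std::vector<llama_token> base{2, 3, 5, 7}; // one prompt
    const int batch = 4;                             // parallel copies of it

    // pack `batch` copies back-to-back: prompt i occupies [i * seq, (i + 1) * seq)
    std::vector<llama_token> flat;
    for (int i = 0; i < batch; ++i) {
        flat.insert(flat.end(), base.begin(), base.end());
    }

    const int seq = (int) flat.size() / batch; // tokens per prompt
    for (int i = 0; i < batch; ++i) {
        for (int j = 0; j < seq; ++j) {
            printf("prompt %d, pos %d -> token %d\n", i, j, (int) flat[i * seq + j]);
        }
    }
    return 0;
}

Compiled on its own, the sketch only prints the index mapping; it is meant to make the bookkeeping in run_par and print above easier to follow.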