batch of input prompts, first try: experimental app was added:
- it reads the maximum number of inputs from the command line (--inputs)
- it processes input with llama_eval_batch (assumed interface sketched below)
Xarbirus committed Sep 27, 2023
1 parent 99115f3 commit ed25b2e
Showing 7 changed files with 921 additions and 10 deletions.
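
The declaration of llama_eval_batch is not visible in the files shown below (presumably it lives in the llama.h/llama.cpp changes that are not displayed in this view). The following is only a sketch of the interface as it can be inferred from the call sites in common.cpp and examples/input-batches-experiment/main.cpp; the actual declaration in the commit may differ.

// Assumed interface (not part of the diff shown here), inferred from the call sites below.
// tokens holds n_inputs sequences of n_tokens tokens each, laid out back to back;
// the logits for input i are then read starting at llama_get_logits(ctx) + n_vocab * n_tokens * i.
int llama_eval_batch(
        struct llama_context * ctx,
        const llama_token    * tokens,     // n_inputs * n_tokens tokens, one input after another
        int                    n_tokens,   // sequence length per input
        int                    n_inputs,   // number of parallel inputs
        int                    n_past,     // tokens already evaluated (shared by all inputs)
        int                    n_threads);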
16 changes: 14 additions & 2 deletions common/common.cpp
@@ -299,6 +299,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_batch = std::stoi(argv[i]);
} else if (arg == "--inputs") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_inputs = std::stoi(argv[i]);
} else if (arg == "--keep") {
if (++i >= argc) {
invalid_param = true;
@@ -623,6 +629,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --inputs N maximum number of parallel inputs (default: %d)\n", params.n_inputs);
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -742,6 +749,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.embedding = params.embedding;
lparams.rope_freq_base = params.rope_freq_base;
lparams.rope_freq_scale = params.rope_freq_scale;
lparams.n_inputs = params.n_inputs;

return lparams;
}
@@ -782,8 +790,12 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
{
LOG("warming up the model with an empty run\n");

const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
std::vector<llama_token> tmp;
for (int i = 0; i < params.n_inputs; ++i) {
tmp.push_back(llama_token_bos(lctx));
tmp.push_back(llama_token_eos(lctx));
}
llama_eval_batch(lctx, tmp.data(), std::min(tmp.size()/params.n_inputs, (size_t) params.n_batch), params.n_inputs, 0, params.n_threads);
llama_reset_timings(lctx);
}

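The warm-up call above relies on the input-major token layout used throughout this commit: the tokens of all parallel inputs are concatenated into one buffer, and the per-input sequence length is the buffer size divided by n_inputs. Below is a minimal, self-contained sketch of that layout; the names and values are illustrative, not taken from the commit.

#include <cstdio>
#include <vector>

int main() {
    const int n_inputs = 4;                            // e.g. --inputs 4
    const std::vector<int> base = {1, 2};              // stands in for { BOS, EOS } in the warm-up
    std::vector<int> batch;
    for (int i = 0; i < n_inputs; ++i) {               // concatenate one copy per input
        batch.insert(batch.end(), base.begin(), base.end());
    }
    const size_t n_tokens = batch.size() / n_inputs;   // 8 / 4 == 2 tokens per input
    for (int i = 0; i < n_inputs; ++i) {
        for (size_t j = 0; j < n_tokens; ++j) {
            // token j of input i lives at index i * n_tokens + j
            std::printf("input %d, token %zu -> index %zu\n", i, j, i * n_tokens + j);
        }
    }
    return 0;
}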
1 change: 1 addition & 0 deletions common/common.h
@@ -50,6 +50,7 @@ struct gpt_params {
int32_t n_beams = 0; // if non-zero then use beam search of given width.
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
int32_t n_inputs = 1; // number of inputs (if > 1 -> batch of inputs)

// sampling parameters
int32_t top_k = 40; // <= 0 to use vocab size
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -27,6 +27,7 @@ else()
add_subdirectory(embd-input)
add_subdirectory(llama-bench)
add_subdirectory(beam-search)
add_subdirectory(input-batches-experiment)
if (LLAMA_METAL)
add_subdirectory(metal)
endif()
8 changes: 8 additions & 0 deletions examples/input-batches-experiment/CMakeLists.txt
@@ -0,0 +1,8 @@
set(TARGET input-batches-experiment)
add_executable(${TARGET} main.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama)
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()
270 changes: 270 additions & 0 deletions examples/input-batches-experiment/main.cpp
@@ -0,0 +1,270 @@
#include "common.h"

#include "console.h"
#include "llama.h"
#include "build-info.h"
#include "grammar-parser.h"

#include <algorithm> // std::all_of, std::max, std::transform
#include <array>     // std::array (logits_storage)
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <signal.h>
#endif

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

int main(int argc, char ** argv) {
gpt_params params;

if (!gpt_params_parse(argc, argv, params)) {
return 1;
}

{
// fixme: hardcoded only for tests
const int32_t batches = 4;

if (params.n_inputs < batches) {
params.n_inputs = batches;
LOG_TEE("%s: warning: maximum number of parallel inputs set to %d\n", __func__, params.n_inputs);
}
}

params.perplexity = true;
params.seed = 42;

// save choice to use color for later
// (note for later: this is a slightly awkward choice)
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });

if (params.rope_freq_base != 10000.0) {
LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
}

if (params.rope_freq_scale != 1.0) {
LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
}

LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);

LOG_TEE("%s: seed = %u\n", __func__, params.seed);

std::mt19937 rng(params.seed);
if (params.random_prompt) {
params.prompt = gpt_random_prompt(rng);
}

LOG("%s: llama backend init\n", __func__);
llama_backend_init(params.numa);

llama_model * model;
llama_context * ctx;
llama_context * ctx_guidance = NULL;

// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params);

if (params.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams);
}

if (model == NULL) {
LOG_TEE("%s: error: unable to load model\n", __func__);
return 1;
}

const int n_ctx_train = llama_n_ctx_train(ctx);
if (params.n_ctx > n_ctx_train) {
LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
} else if (params.n_ctx < 8) {
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
}

// print system information
{
LOG_TEE("\n");
LOG_TEE("system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}

auto print = [](llama_context *ctx_, const std::vector<llama_token>& input, int32_t batch, int32_t seq, int32_t n_past) {
auto n_vocab = llama_n_vocab(ctx_);
auto logits = llama_get_logits(ctx_);

for (int i = 0; i < batch; ++i) {
for (int j = 0; j < seq; ++j) {
LOG_TEE("\tbatch = %d, seq = %d (n_past = %d), token = %d, first logit = %g\n", i, j, n_past, input[i * seq + j], *(logits + n_vocab * (i * seq + j)))
}
if (i != batch - 1) {
LOG_TEE("\t---\n")
}
}
LOG_TEE("\t==========\n")
};

auto run_simple = [print, &params](const char* name, llama_context *ctx_, const std::vector<std::vector<llama_token>> &inputs) {
LOG_TEE("RUN: %s\n", name)

int n_past = 0;
for (const auto &input: inputs) {
int seq = input.size();

llama_eval(ctx_, input.data(), seq, n_past, params.n_threads);
print(ctx_, input, 1, seq, n_past);

n_past += seq;
}
};

auto run_par = [print, &params](const char* name, llama_context *ctx_, int32_t batch, const std::vector<std::vector<llama_token>> &inputs, bool compare_batches = true) {
LOG_TEE("RUN: %s\n", name)

std::vector<std::vector<float>> logits(batch);

int n_past = 0;
for (const auto &input: inputs) {
int seq = input.size() / batch;

llama_eval_batch(ctx_, input.data(), seq, batch, n_past, params.n_threads);
print(ctx_, input, batch, seq, n_past);

n_past += seq;

if (compare_batches && batch > 1) {
auto* logits_begin = llama_get_logits(ctx_);
auto n_vocab = llama_n_vocab(ctx_);
for (int i = 0; i < batch; ++i) {
logits[i].assign(logits_begin + n_vocab * seq * i,
logits_begin + n_vocab * seq * (i + 1));
}
}
}

if (compare_batches && batch > 1) {
auto equal = std::all_of(logits.begin() + 1, logits.end(), [&](const std::vector<float> &batch_logits) {
return logits[0] == batch_logits;
});
LOG_TEE("\tAll parallel input results are equal: %s\n", equal ? "TRUE" : "FALSE")
}
};
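// note (assumption, based on the indexing used in this file): after llama_eval_batch,
// llama_get_logits() is expected to return n_vocab floats per token with the parallel
// inputs laid out back to back, so the logits of token j of input i start at offset
// n_vocab * (i * seq + j) -- this is what the print and run_par lambdas above rely on.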

const std::vector<llama_token> tokens{2, 3, 5, 7};
const std::vector<int32_t> parallels{std::max(params.n_inputs / 2, 1), params.n_inputs};

{
const std::vector<llama_token> base_input{tokens};

{
const int32_t batch = 1;
run_simple("Single prompt processing (llama_eval)", ctx, {base_input});
run_par("Single prompt processing (llama_eval_batch)", ctx, batch, {base_input});
}

for (auto p: parallels) {
std::vector<llama_token> parallel_input;
for (int i = 0; i < p; ++i) {
parallel_input.insert(parallel_input.end(), base_input.begin(), base_input.end());
}

std::string name = "Parallel prompt processing, " + std::to_string(p) + " parallel input prompts (llama_eval_batch)";
run_par(name.c_str(), ctx, p, {parallel_input});
}
}
{
std::vector<std::vector<llama_token>> base_input;
for (llama_token token : tokens) {
base_input.push_back({token});
}

{
const int32_t batch = 1;
run_simple("Token-by-Token Single prompt processing (llama_eval)",ctx, base_input);
run_par("Token-by-Token Single prompt processing (llama_eval_batch)", ctx, batch, base_input);
}

for (auto p: parallels) {
std::vector<std::vector<llama_token>> parallel_input;
for (const auto& input : base_input) {
parallel_input.emplace_back();
auto& current = parallel_input.back();
for (int i = 0; i < p; ++i) {
current.insert(current.end(), input.begin(), input.end());
}
}

std::string name = "Token-by-Token Parallel prompt processing, " + std::to_string(p) + " parallel input prompts (llama_eval_batch)";
run_par(name.c_str(), ctx, p, parallel_input);
}
}
{
using logits_storage = std::array<std::vector<float>, 2>;
logits_storage base_logits_1;
logits_storage base_logits_2;

auto reserve_storage = [&](logits_storage& storage){
auto reserve_logits = [&](logits_storage::value_type& logits){
logits.reserve(llama_n_vocab(ctx) * tokens.size());
};
reserve_logits(storage[0]);
reserve_logits(storage[1]);
};

reserve_storage(base_logits_1);
reserve_storage(base_logits_2);

std::vector<llama_token> base_input_1{tokens};
std::vector<llama_token> base_input_2{tokens};
std::transform(base_input_2.begin(), base_input_2.end(), base_input_2.begin(), [](llama_token val){return val*2;});
{
run_simple("Another Single prompt processing (llama_eval)", ctx, {base_input_1});
base_logits_1[0].assign(llama_get_logits(ctx), llama_get_logits(ctx) + llama_n_vocab(ctx) * tokens.size());

run_simple("Yet Another Single prompt processing (llama_eval)", ctx, {base_input_2});
base_logits_2[0].assign(llama_get_logits(ctx), llama_get_logits(ctx) + llama_n_vocab(ctx) * tokens.size());
}
std::vector<llama_token> combined_input{base_input_1};
combined_input.insert(combined_input.end(), base_input_2.begin(), base_input_2.end());
run_par("Combined prompts (llama_eval_batch)", ctx, 2, {combined_input}, false);

base_logits_1[1].assign(llama_get_logits(ctx), llama_get_logits(ctx) + llama_n_vocab(ctx) * tokens.size());
base_logits_2[1].assign(llama_get_logits(ctx) + llama_n_vocab(ctx) * tokens.size(), llama_get_logits(ctx) + llama_n_vocab(ctx) * tokens.size() * 2);

bool batch_equal_1 = base_logits_1[0] == base_logits_1[1];
bool batch_equal_2 = base_logits_2[0] == base_logits_2[1];
LOG_TEE("\tFirst batch results are equal: %s\n", batch_equal_1 ? "TRUE" : "FALSE")
LOG_TEE("\tSecond batch results are equal: %s\n", batch_equal_2 ? "TRUE" : "FALSE")
LOG_TEE("\tAll parallel input results are equal: %s\n", batch_equal_1 && batch_equal_2 ? "TRUE" : "FALSE")
}

if (ctx_guidance) { llama_free(ctx_guidance); }
llama_free(ctx);
llama_free_model(model);
llama_backend_free();

return 0;
}