mpt - Add flags for prompt context size (-c/--ctx_size) #174

Closed · wants to merge 3 commits
examples/common.cpp: 3 additions & 0 deletions
@@ -27,6 +27,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.n_predict = std::stoi(argv[++i]);
} else if (arg == "--top_k") {
params.top_k = std::stoi(argv[++i]);
} else if (arg == "-c" || arg == "--ctx_size") {
params.n_ctx = std::stoi(argv[++i]);
} else if (arg == "--top_p") {
params.top_p = std::stof(argv[++i]);
} else if (arg == "--temp") {
@@ -76,6 +78,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");
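With this change the context size is picked up by any example that uses gpt_params_parse. A minimal sketch of how a caller consumes the new flag (the include path and the invocation in the comment are illustrative assumptions, not part of this PR):

// Sketch only. Build against examples/common.h; run e.g.: ./mpt -m ggml-model.bin -c 2048 -p "Hello"
#include "common.h"

#include <cstdio>

int main(int argc, char ** argv) {
    gpt_params params; // params.n_ctx defaults to 2048 (see examples/common.h below)
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }
    printf("prompt context size: %d tokens\n", params.n_ctx);
    return 0;
}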
examples/common.h: 3 additions & 1 deletion
@@ -14,11 +14,13 @@
// CLI argument parsing
//

- struct gpt_params {
+ struct gpt_params { // default values
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 200; // new tokens to predict

+ int32_t n_ctx = 2048; // default context size

// sampling parameters
int32_t top_k = 40;
float top_p = 0.9f;
examples/mpt/main.cpp: 14 additions & 10 deletions
@@ -18,14 +18,13 @@
#include <utility>
#include <vector>

- int n_ctx = 4096;

// no defaults for now
struct mpt_hparams {
int32_t d_model = 0;
int32_t max_seq_len = 0;
int32_t n_heads = 0;
int32_t n_layers = 0;
+ int32_t n_ctx = 4096;
int32_t n_vocab = 0;
float alibi_bias_max = 0;
float clip_qkv = 0;
@@ -65,7 +64,7 @@ struct mpt_model {
};

// load the model's weights from a file
- bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vocab) {
+ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vocab, int n_ctx) {
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());

auto fin = std::ifstream(fname, std::ios::binary);
@@ -97,11 +96,14 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
fin.read((char *)&hparams.clip_qkv, sizeof(hparams.clip_qkv));
fin.read((char *)&hparams.ftype, sizeof(hparams.ftype));

+ hparams.n_ctx = n_ctx;

printf("%s: d_model = %d\n", __func__, hparams.d_model);
printf("%s: max_seq_len = %d\n", __func__, hparams.max_seq_len);
printf("%s: n_heads = %d\n", __func__, hparams.n_heads);
printf("%s: n_layers = %d\n", __func__, hparams.n_layers);
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: alibi_bias_max = %f\n", __func__, hparams.alibi_bias_max);
printf("%s: clip_qkv = %f\n", __func__, hparams.clip_qkv);
printf("%s: ftype = %d\n", __func__, hparams.ftype);
@@ -144,6 +146,7 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
const size_t n_embd = hparams.d_model;
const size_t n_layer = hparams.n_layers;
const size_t n_vocab = hparams.n_vocab;
+ const size_t n_ctx = hparams.n_ctx;

ctx_size += n_embd * n_vocab * ggml_type_sizef(wtype); // wte_weight
ctx_size += n_embd * ggml_type_sizef(GGML_TYPE_F32); // norm_f_weight
@@ -336,12 +339,13 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
const int n_layer = hparams.n_layers;
const int n_head = hparams.n_heads;
const int n_vocab = hparams.n_vocab;
+ const int n_ctx = hparams.n_ctx;

static size_t buf_size = 256u * 1024 * 1024;
static void * buf = malloc(buf_size);

- if (mem_per_token > 0 && mem_per_token * N > buf_size) {
- const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead
+ if (mem_per_token > 0 && mem_per_token * (n_past+N) > buf_size) {
+ const size_t buf_size_new = 1.1 * (mem_per_token * (n_past+N)); // add 10% to account for ggml object overhead
// printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__,
// buf_size, buf_size_new);

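The reallocation check above now sizes the scratch buffer against the total number of tokens in flight (n_past + N) instead of only the current batch N, because the attention work per evaluation grows with the number of cached tokens. A rough illustration with a hypothetical per-token figure (not a measured value):

// Hypothetical numbers: ~200 KiB per token, 1000 cached tokens, a batch of 8 new tokens.
size_t mem_per_token = 200 * 1024;                            // per-token estimate from a previous eval
size_t n_past        = 1000;                                  // tokens already in the KV cache
size_t N             = 8;                                     // tokens in the current batch
size_t buf_size_new  = 1.1 * (mem_per_token * (n_past + N));  // ~217 MiB; the old rule (N only) gave ~1.7 MiB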
@@ -520,9 +524,9 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab);

- if (mem_per_token == 0) {
- mem_per_token = ggml_used_mem(ctx0) / N;
- }
+ // Update the per-token memory estimate on every call so the scratch buffer can grow as needed;
+ // the initial estimate is large and shrinks toward a more accurate value as more tokens are processed.
+ mem_per_token = ggml_used_mem(ctx0) / (n_past + N);
// printf("used_mem = %zu\n", ggml_used_mem(ctx0));

ggml_free(ctx0);
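Recomputing mem_per_token after every evaluation (rather than only once, when it was still zero) lets the estimate settle: the first measurement is dominated by fixed graph overhead spread over very few tokens, and later measurements amortize that overhead. A sketch of the trend, with made-up token counts:

// warm-up eval:  n_past = 0,   N = 4  ->  mem_per_token = used_mem / 4    (overhead spread over 4 tokens: large)
// later eval:    n_past = 512, N = 8  ->  mem_per_token = used_mem / 520  (overhead amortized: near the true cost)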
@@ -567,7 +571,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_us = ggml_time_us();

- if (!mpt_model_load(params.model, model, vocab)) {
+ if (!mpt_model_load(params.model, model, vocab, params.n_ctx)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
@@ -593,7 +597,7 @@ int main(int argc, char ** argv) {
}
printf("\n");

- params.n_predict = std::min(params.n_predict, n_ctx - (int)embd_inp.size());
+ params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int)embd_inp.size());

std::vector<gpt_vocab::id> embd;

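Finally, the number of tokens to predict is clamped against the configured context rather than the removed global. A self-contained sketch of the clamp with hypothetical values (not taken from the PR):

#include <algorithm>
#include <cstdio>

int main() {
    // e.g. -c 2048, a prompt of 100 tokens, and 4000 tokens requested
    int n_ctx = 2048, n_prompt = 100, n_predict = 4000;
    n_predict = std::min(n_predict, n_ctx - n_prompt); // clamped to 1948
    printf("will generate at most %d tokens\n", n_predict);
    return 0;
}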