Skip to content

Commit

Permalink
Merge branch 'master' into fix-respect-use-bos-token
Browse files Browse the repository at this point in the history
  • Loading branch information
KerfuffleV2 committed Nov 14, 2023
2 parents c7ff2d5 + bd90eca commit eef5ae3
Show file tree
Hide file tree
Showing 32 changed files with 2,755 additions and 2,035 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++

### Hot topics

- ⚠️ **Upcoming change that might break functionality. Help with testing is needed:** https://github.com/ggerganov/llama.cpp/pull/3912
- *No hot topics atm. Open to suggestions about what is hot today*

----

Expand Down Expand Up @@ -424,7 +424,7 @@ Building the program with BLAS support may lead to some performance improvements
```
The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
If your GPU is not officialy supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 or 11.0.0 on RDNA3.
If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 or 11.0.0 on RDNA3.
The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above):
| Option | Legal values | Default | Description |
Expand Down
1 change: 1 addition & 0 deletions common/train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct train_state * init_train_state() {
state->opt = new struct ggml_opt_context;
state->opt->ctx = NULL;
state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
state->opt->loss_after = 0.0f;

return state;
Expand Down
2 changes: 2 additions & 0 deletions common/train.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "ggml.h"
#include "llama.h"

#define LLAMA_TRAIN_MAX_NODES 16384

typedef std::string mt19937_state;

struct train_state {
Expand Down
2 changes: 1 addition & 1 deletion docs/token_generation_performance_tips.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ llama_model_load_internal: [cublas] total VRAM used: 17223 MB
If you see these lines, then the GPU is being used.

## Verifying that the CPU is not oversaturated
llama accepts a `-t N` (or `--threads N`) parameter. It's extremely important that this parameter is not too large. If your token generation is extremely slow, try setting this number to 1. If this significantly improves your token generation speed, then your CPU is being oversaturated and you need to explicitly set this parameter to the number of the physicial CPU cores on your machine (even if you utilize a GPU). If in doubt, start with 1 and double the amount until you hit a performance bottleneck, then scale the number down.
llama accepts a `-t N` (or `--threads N`) parameter. It's extremely important that this parameter is not too large. If your token generation is extremely slow, try setting this number to 1. If this significantly improves your token generation speed, then your CPU is being oversaturated and you need to explicitly set this parameter to the number of the physical CPU cores on your machine (even if you utilize a GPU). If in doubt, start with 1 and double the amount until you hit a performance bottleneck, then scale the number down.

# Example of runtime flags effect on inference speed benchmark
These runs were tested on the following machine:
Expand Down
21 changes: 12 additions & 9 deletions examples/benchmark/benchmark-matmult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ int main(int argc, char ** argv) {
struct ggml_tensor * m11xm2 = ggml_mul_mat(ctx, m11, m2);

// printf("Creating compute graph\n");
struct ggml_cgraph gf = ggml_build_forward(m11xm2);
struct ggml_cgraph * gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, m11xm2);

printf("n_threads=%i\n", benchmark_params.n_threads);

Expand All @@ -180,9 +181,9 @@ int main(int argc, char ** argv) {

std::vector<uint8_t> work_buffer;

ggml_graph_compute_helper(work_buffer, &gf, benchmark_params.n_threads);
ggml_graph_compute_helper(work_buffer, gf, benchmark_params.n_threads);

TENSOR_DUMP(gf.nodes[0]);
TENSOR_DUMP(gf->nodes[0]);

printf("\n------ Test 2 - Matrix Mult via %s code\n", ggml_type_name(qtype));

Expand All @@ -200,7 +201,8 @@ int main(int argc, char ** argv) {
struct ggml_tensor * q31 = ggml_mul_mat(ctx, q11, m2);

// printf("Creating compute graph\n");
struct ggml_cgraph gf31 = ggml_build_forward(q31);
struct ggml_cgraph * gf31 = ggml_new_graph(ctx);
ggml_build_forward_expand(gf31, q31);

// Set up a second graph computation to make sure we override the CPU cache lines
// printf("Creating new tensor q12 & Running quantize\n");
Expand All @@ -211,7 +213,8 @@ int main(int argc, char ** argv) {
struct ggml_tensor * q32 = ggml_mul_mat(ctx, q12, m2);

//printf("Creating compute graph\n");
struct ggml_cgraph gf32 = ggml_build_forward(q32);
struct ggml_cgraph * gf32 = ggml_new_graph(ctx);
ggml_build_forward_expand(gf32, q32);
printf("n_threads=%i\n", benchmark_params.n_threads);

const int dimx = sizex;
Expand All @@ -223,7 +226,7 @@ int main(int argc, char ** argv) {


// Let's use the F32 result from above as a reference for the quantized multiplication
float sum_of_F32_reference = tensor_sum_elements(gf.nodes[0]);
float sum_of_F32_reference = tensor_sum_elements(gf->nodes[0]);

printf("Iteration;NThreads; SizeX; SizeY; SizeZ; Required_FLOPS; Elapsed_u_Seconds; gigaFLOPS\n");
printf("=====================================================================================\n");
Expand All @@ -233,7 +236,7 @@ int main(int argc, char ** argv) {

long long int start = ggml_time_us();
//printf("Running ggml_graph_compute\n");
ggml_graph_compute_helper(work_buffer, &gf31, benchmark_params.n_threads);
ggml_graph_compute_helper(work_buffer, gf31, benchmark_params.n_threads);

long long int stop = ggml_time_us();
long long int usec = stop-start;
Expand All @@ -251,7 +254,7 @@ int main(int argc, char ** argv) {

// Check that the matrix multiplication result is in the right ballpark
// We cannot use the exact value from the F32 multiplication because the quantizuation will be slightly different
float sum_of_Q4_result = tensor_sum_elements(gf31.nodes[0]);
float sum_of_Q4_result = tensor_sum_elements(gf31->nodes[0]);
float delta = std::abs(sum_of_Q4_result - sum_of_F32_reference);
float allowed_delta = (sum_of_F32_reference) / 1000 / 1000; // Let's accept an epsilon of 10^-6

Expand All @@ -266,7 +269,7 @@ int main(int argc, char ** argv) {
}

// Running a different graph computation to make sure we override the CPU cache lines
ggml_graph_compute_helper(work_buffer, &gf32, benchmark_params.n_threads);
ggml_graph_compute_helper(work_buffer, gf32, benchmark_params.n_threads);
}
printf("\n");
printf("Average%78.2f\n",gflops_sum/((double)benchmark_params.n_iterations));
Expand Down
4 changes: 2 additions & 2 deletions examples/export-lora/export-lora.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ static struct lora_data * load_lora(struct lora_info * info) {
}

struct ggml_init_params params_ggml;
params_ggml.mem_size = ggml_tensor_overhead() * GGML_MAX_NODES;
params_ggml.mem_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE;
params_ggml.mem_buffer = NULL;
params_ggml.no_alloc = true;
result->ctx = ggml_init(params_ggml);
Expand Down Expand Up @@ -334,7 +334,7 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
float scaling = lora->info.scale * (float)lora->lora_alpha / (float)lora->lora_r;

struct ggml_init_params params;
params.mem_size = GGML_OBJECT_SIZE + GGML_GRAPH_SIZE + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
params.mem_size = GGML_OBJECT_SIZE + ggml_graph_overhead() + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
params.mem_buffer = NULL;
params.no_alloc = true;
struct ggml_context * ctx = NULL;
Expand Down
23 changes: 11 additions & 12 deletions examples/finetune/finetune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
if (enable_checkpointing) {
ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
} else {
*gb = *gf;
ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx, gf, gb, true);
}

Expand Down Expand Up @@ -1615,6 +1615,7 @@ int main(int argc, char ** argv) {
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
opt->params.n_threads = params.common.n_threads;
opt->params.past = params.common.opt_past;
opt->params.delta = params.common.opt_delta;
Expand Down Expand Up @@ -1741,11 +1742,9 @@ int main(int argc, char ** argv) {
ggml_allocr_free(alloc);

// context for compute tensors without their data
size_t estimated_compute_size_wo_data = (
ggml_tensor_overhead()*GGML_MAX_NODES*2
+ (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
params.common.use_checkpointing ? 3 : 2
)
const size_t estimated_compute_size_wo_data = (
2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
(params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
);
struct ggml_init_params ctx_compute_params = {
estimated_compute_size_wo_data, // mem_size
Expand All @@ -1768,11 +1767,11 @@ int main(int argc, char ** argv) {
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new_measure(tensor_alignment);
gf = ggml_new_graph(ctx_compute);
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph(ctx_compute);
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute)
? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_lora_finetune_graphs(
&model, &lora, alloc, ctx_compute,
Expand Down Expand Up @@ -1801,11 +1800,11 @@ int main(int argc, char ** argv) {
mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
gf = ggml_new_graph(ctx_compute);
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order;
gb = ggml_new_graph(ctx_compute);
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute)
? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_lora_finetune_graphs(
&model, &lora, alloc, ctx_compute,
Expand Down
4 changes: 2 additions & 2 deletions examples/llava/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
// measure mem requirement and allocate
{
static const size_t tensor_alignment = 32;
new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
new_clip->alloc = ggml_allocr_new_measure(tensor_alignment);
clip_image_f32_batch batch;
batch.size = 1;
Expand Down Expand Up @@ -761,7 +761,7 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
temp->ny = img->ny;
temp->size = img->size;
temp->data = new uint8_t[temp->size]();
*temp->data = *img->data; // copy
memcpy(&temp->data[0], &img->data[0], temp->size); // copy
}

const int nx = temp->nx;
Expand Down
2 changes: 1 addition & 1 deletion examples/main/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ The `--ctx-size` option allows you to set the size of the prompt context used by

### Extended Context Size

Some fine-tuned models have extened the context length by scaling RoPE. For example, if the original pretrained model have a context length (max sequence length) of 4096 (4k) and the fine-tuned model have 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8.
Some fine-tuned models have extended the context length by scaling RoPE. For example, if the original pre-trained model have a context length (max sequence length) of 4096 (4k) and the fine-tuned model have 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8.

- `--rope-scale N`: Where N is the linear scaling factor used by the fine-tuned model.

Expand Down
10 changes: 5 additions & 5 deletions examples/metal/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int main(int argc, char ** argv) {
struct ggml_context * ctx_data = NULL;
struct ggml_context * ctx_eval = NULL;

struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
struct ggml_cgraph * gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);

// this allocates all Metal resources and memory buffers
auto * ctx_metal = ggml_metal_init(1);
Expand All @@ -46,21 +46,21 @@ int main(int argc, char ** argv) {

// main
{
struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "embd");
struct ggml_tensor * input = ggml_graph_get_tensor(gf, "embd");
*(int32_t *) input->data = 1; // BOS

ggml_metal_set_tensor(ctx_metal, input);

// warmup
ggml_metal_graph_compute(ctx_metal, &gf);
ggml_metal_graph_compute(ctx_metal, gf);

const int n_iter = 16;

const int64_t t0 = ggml_time_us();

// the actual inference happens here
for (int i = 0; i < n_iter; ++i) {
ggml_metal_graph_compute(ctx_metal, &gf);
ggml_metal_graph_compute(ctx_metal, gf);
}

const int64_t t1 = ggml_time_us();
Expand All @@ -70,7 +70,7 @@ int main(int argc, char ** argv) {

// debug output
{
struct ggml_tensor * logits = gf.nodes[gf.n_nodes - 1];
struct ggml_tensor * logits = gf->nodes[gf->n_nodes - 1];
ggml_metal_get_tensor(ctx_metal, logits);

float * ptr = (float *) ggml_get_data(logits);
Expand Down
2 changes: 1 addition & 1 deletion examples/parallel/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# llama.cpp/example/parallel

Simplified simluation for serving incoming requests in parallel
Simplified simulation of serving incoming requests in parallel
23 changes: 11 additions & 12 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ static struct ggml_tensor * llama_build_train_graphs(
if (enable_checkpointing) {
ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
} else {
*gb = *gf;
ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx, gf, gb, true);
}

Expand Down Expand Up @@ -1006,6 +1006,7 @@ int main(int argc, char ** argv) {
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
opt->params.n_threads = params.common.n_threads;
opt->params.past = params.common.opt_past;
opt->params.delta = params.common.opt_delta;
Expand Down Expand Up @@ -1108,11 +1109,9 @@ int main(int argc, char ** argv) {
ggml_allocr_free(alloc);

// context for compute tensors without their data
size_t estimated_compute_size_wo_data = (
ggml_tensor_overhead()*GGML_MAX_NODES*2
+ (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
params.common.use_checkpointing ? 3 : 2
)
const size_t estimated_compute_size_wo_data = (
2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
(params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
);
struct ggml_init_params ctx_compute_params = {
estimated_compute_size_wo_data, // mem_size
Expand All @@ -1135,11 +1134,11 @@ int main(int argc, char ** argv) {
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new_measure(tensor_alignment);
gf = ggml_new_graph(ctx_compute);
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph(ctx_compute);
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute)
? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_train_graphs(
&model, alloc, ctx_compute,
Expand Down Expand Up @@ -1168,11 +1167,11 @@ int main(int argc, char ** argv) {
mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
gf = ggml_new_graph(ctx_compute);
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order;
gb = ggml_new_graph(ctx_compute);
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute)
? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_train_graphs(
&model, alloc, ctx_compute,
Expand Down
Loading

0 comments on commit eef5ae3

Please sign in to comment.