From f85785f6504b08bbf973b56b6a423e94f206b4ef Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Mon, 3 Jul 2023 21:51:05 -0400 Subject: [PATCH 01/23] MPI support, first cut --- examples/main/main.cpp | 2 + examples/perplexity/perplexity.cpp | 2 + ggml.c | 95 ++++++++++++++++++++ ggml.h | 13 +++ llama.cpp | 137 ++++++++++++++++++++++------- llama.h | 4 + 6 files changed, 220 insertions(+), 33 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3a171925ba510..154a4113ad657 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -671,5 +671,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + llama_finalize_backend(); + return 0; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index dd54ed3c4bd6c..5c45237955bab 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -172,5 +172,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + llama_finalize_backend(); + return 0; } diff --git a/ggml.c b/ggml.c index afeb72ff00ccf..2f00428d3ba5c 100644 --- a/ggml.c +++ b/ggml.c @@ -26,6 +26,8 @@ #include #include +#include + #ifdef GGML_USE_METAL #include #endif @@ -4648,6 +4650,35 @@ struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggm return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL); } +struct ggml_tensor * ggml_send_tensor( + struct ggml_context * ctx, + const struct ggml_tensor *src, + int dst_rank) { + + struct ggml_tensor * result = ggml_new_i32(ctx, 0); + + result->op = GGML_OP_SEND; + result->src0 = src; + result->extra = (void *)dst_rank; + + return result; +} + +struct ggml_tensor * ggml_recv_tensor( + struct ggml_context * ctx, + const struct ggml_tensor *parent, + struct ggml_tensor *dst, + int src_rank) { + + struct ggml_tensor * result = dst; + + result->op = GGML_OP_RECV; + result->src0 = parent; // just used for graph computation + result->extra = (void *)src_rank; + + return result; +} + struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { memset(tensor->data, 0, ggml_nbytes(tensor)); return tensor; @@ -8191,6 +8222,52 @@ static void ggml_compute_forward_dup( } } +// ggml_compute_forward_recv + +static void ggml_compute_forward_recv( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#ifdef GGML_USE_MPI + MPI_Status status; + int my_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); + // fprintf(stderr, "(%d) Receiving from (%d)\n", my_rank, (int)dst->extra); + int retval = MPI_Recv(dst->data, dst->ne[0] * dst->ne[1], MPI_FLOAT, (int)dst->extra, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + // fprintf(stderr, "(%d) Received from (%d)\n", my_rank, (int)dst->extra); + GGML_ASSERT(retval == MPI_SUCCESS); +#else + GGML_ASSERT(false); +#endif +} + +// ggml_compute_forward_send + +static void ggml_compute_forward_send( + const struct ggml_compute_params * params, + struct ggml_tensor * src, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); +#ifdef GGML_USE_MPI + int my_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); + // fprintf(stderr, "(%d) Sending to (%d)\n", my_rank, (int)dst->extra); + int retval = MPI_Send(src->data, src->ne[0] * src->ne[1], MPI_FLOAT, (int)dst->extra, 0, MPI_COMM_WORLD); + // fprintf(stderr, "(%d) Sent to (%d)\n", my_rank, (int)dst->extra); + ggml_set_i32(dst, retval); + GGML_ASSERT(retval == MPI_SUCCESS); +#else + GGML_ASSERT(false); +#endif +} + // ggml_compute_forward_add static void ggml_compute_forward_add_f32( @@ -15420,6 +15497,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_dup(params, tensor->src0, tensor); } break; + case GGML_OP_SEND: + { + ggml_compute_forward_send(params, tensor->src0, tensor); + } break; + case GGML_OP_RECV: + { + ggml_compute_forward_recv(params, tensor); + } break; case GGML_OP_ADD: { ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor); @@ -15710,6 +15795,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); } } break; + case GGML_OP_SEND: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_RECV: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_ADD: { if (src0->grad) { @@ -17058,6 +17151,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; } break; + case GGML_OP_SEND: + case GGML_OP_RECV: case GGML_OP_SET: case GGML_OP_CONT: case GGML_OP_RESHAPE: diff --git a/ggml.h b/ggml.h index 11b51f8bd656a..aa78f17dd0254 100644 --- a/ggml.h +++ b/ggml.h @@ -353,6 +353,9 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS_BACK, GGML_OP_COUNT, + + GGML_OP_SEND, + GGML_OP_RECV, }; @@ -556,6 +559,16 @@ extern "C" { GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src); + GGML_API struct ggml_tensor * ggml_send_tensor( + struct ggml_context * ctx, + const struct ggml_tensor *src, + int dst_rank); + GGML_API struct ggml_tensor * ggml_recv_tensor( + struct ggml_context * ctx, + const struct ggml_tensor *parent, + struct ggml_tensor *dst, + int src_rank); + GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); diff --git a/llama.cpp b/llama.cpp index a869bbac80304..c7de0bc60b67e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -49,6 +49,8 @@ #include #include +#include + #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif @@ -330,6 +332,9 @@ struct llama_context { ggml_metal_context * ctx_metal = NULL; #endif + int mpi_rank; + int mpi_size; + int buf_last = 0; size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; @@ -864,6 +869,15 @@ void llama_init_backend(bool numa) { if (numa) { ggml_numa_init(); } +#ifdef GGML_USE_MPI + MPI_Init(NULL, NULL); +#endif +} + +void llama_finalize_backend() { +#ifdef GGML_USE_MPI + MPI_Finalize(); +#endif } int64_t llama_time_us() { @@ -1307,7 +1321,16 @@ static bool llama_eval_internal( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (tokens) { + if (lctx.mpi_rank > 0) { +#ifdef GGML_USE_MPI + inpL = ggml_recv_tensor(ctx0, NULL, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), + lctx.mpi_rank-1); + ggml_set_name(inpL, "recv"); +#else + GGML_ASSERT(false); +#endif + } else if (tokens) { struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); ggml_set_name(embd, "embd"); memcpy(embd->data, tokens, N*ggml_element_size(embd)); @@ -1341,7 +1364,9 @@ static bool llama_eval_internal( } #endif // GGML_USE_CUBLAS - for (int il = 0; il < n_layer; ++il) { + // EMM TODO distribute work more evenly - maybe rank=0 gets the smallest amount? + int slice_size = (n_layer + (lctx.mpi_size - 1)) / lctx.mpi_size; + for (int il = lctx.mpi_rank * slice_size; il < n_layer && il < (lctx.mpi_rank + 1) * slice_size; ++il) { offload_func_t offload_func = llama_nop; #ifdef GGML_USE_CUBLAS @@ -1556,25 +1581,36 @@ static bool llama_eval_internal( // used at the end to optionally extract the embeddings struct ggml_tensor * embeddings = NULL; +#ifdef GGML_USE_MPI + cur = ggml_send_tensor(ctx0, cur, (lctx.mpi_rank+1)%lctx.mpi_size); + ggml_set_name(cur, "send"); +#endif + if (lctx.mpi_rank == 0) { +#ifdef GGML_USE_MPI + cur = ggml_recv_tensor(ctx0, cur, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), + lctx.mpi_size-1); + ggml_set_name(cur, "recv"); +#endif + // norm + { + cur = ggml_rms_norm(ctx0, cur); + offload_func_nr(cur); + ggml_set_name(cur, "rms_norm_2"); - // norm - { - cur = ggml_rms_norm(ctx0, inpL); - offload_func_nr(cur); - ggml_set_name(cur, "rms_norm_2"); - - // cur = cur*norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.norm); - // offload_func_nr(cur); // TODO CPU + GPU mirrored backend - ggml_set_name(cur, "result_norm"); + // cur = cur*norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.norm); + // offload_func_nr(cur); // TODO CPU + GPU mirrored backend + ggml_set_name(cur, "result_norm"); - embeddings = cur; - } + embeddings = cur; + } - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + } lctx.use_buf(ctx0, -1); @@ -1632,26 +1668,28 @@ static bool llama_eval_internal( // update kv token count lctx.kv_self.n = n_past + N; - // extract logits - { - auto & logits_out = lctx.logits; + if (lctx.mpi_rank == 0) { + // extract logits + { + auto & logits_out = lctx.logits; - if (lctx.logits_all) { - logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N); - } else { - // return result for just the last token - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + if (lctx.logits_all) { + logits_out.resize(n_vocab * N); + memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N); + } else { + // return result for just the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + } } - } - // extract embeddings - if (!lctx.embedding.empty()) { - auto & embedding_out = lctx.embedding; + // extract embeddings + if (!lctx.embedding.empty()) { + auto & embedding_out = lctx.embedding; - embedding_out.resize(n_embd); - memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); + } } if (mem_per_token == 0) { @@ -2603,6 +2641,14 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; +#ifdef GGML_USE_MPI + MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size); + MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank); +#else + ctx->mpi_size = 1; + ctx->mpi_rank = 0; +#endif + ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; // reserve memory for context buffers @@ -2675,6 +2721,16 @@ struct llama_context * llama_new_context_with_model( } #endif + if (ctx->mpi_rank > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + const std::vector tmp = { llama_token_bos(), }; + while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)); +#ifdef GGML_USE_MPI + MPI_Finalize(); +#endif + exit(1); + } + return ctx; } @@ -3351,6 +3407,13 @@ int llama_eval( int n_tokens, int n_past, int n_threads) { +#ifdef GGML_USE_MPI + // Synchronize the worker node parameters with the root node + MPI_Barrier(MPI_COMM_WORLD); + MPI_Bcast(&n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(&n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(&n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); +#endif if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; @@ -3434,6 +3497,14 @@ int llama_n_embd(const struct llama_context * ctx) { return ctx->model.hparams.n_embd; } +int llama_mpi_rank(const struct llama_context * ctx) { + return ctx->mpi_rank; +} + +int llama_mpi_size(const struct llama_context * ctx) { + return ctx->mpi_size; +} + int llama_get_vocab( const struct llama_context * ctx, const char * * strings, diff --git a/llama.h b/llama.h index 5bb1964bd390d..1920584c56b7f 100644 --- a/llama.h +++ b/llama.h @@ -145,6 +145,8 @@ extern "C" { // If numa is true, use NUMA optimizations // Call once at the start of the program LLAMA_API void llama_init_backend(bool numa); + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_finalize_backend(); LLAMA_API int64_t llama_time_us(); @@ -257,6 +259,8 @@ extern "C" { LLAMA_API int llama_n_vocab(const struct llama_context * ctx); LLAMA_API int llama_n_ctx (const struct llama_context * ctx); LLAMA_API int llama_n_embd (const struct llama_context * ctx); + LLAMA_API int llama_mpi_rank (const struct llama_context * ctx); + LLAMA_API int llama_mpi_size (const struct llama_context * ctx); // Get the vocabulary as output parameters. // Returns number of results. From d05ca74dd81018dff00afbdd226c40f38c9ed2f5 Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Mon, 3 Jul 2023 23:53:43 -0400 Subject: [PATCH 02/23] fix warnings, update README --- Makefile | 5 +++++ README.md | 29 +++++++++++++++++++++++++++++ ggml.c | 5 +++-- ggml.h | 4 ++-- 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 03f38bdba04ec..f3bda7b9f3dd7 100644 --- a/Makefile +++ b/Makefile @@ -149,6 +149,11 @@ ifndef LLAMA_NO_ACCELERATE endif endif # LLAMA_NO_ACCELERATE +ifdef LLAMA_MPI + CFLAGS += -DGGML_USE_MPI + CXXFLAGS += -DGGML_USE_MPI +endif # LLAMA_MPI + ifdef LLAMA_OPENBLAS CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas LDFLAGS += -lopenblas diff --git a/README.md b/README.md index e890dc9c22e75..5f5dee6bfac57 100644 --- a/README.md +++ b/README.md @@ -267,6 +267,35 @@ Any value larger than 0 will offload the computation to the GPU. For example: ./main -m ./models/7B/ggml-model-q4_0.bin -n 128 -ngl 1 ``` +### MPI Build + +MPI lets you distribute the computation over a cluster of machines. Because of the serial nature of LLM prediction, this won't yield any end-to-end speed-ups, but it will let you run larger models than would otherwise fit into RAM on a single machine. + +First, build llama.cpp and download/convert the weights on all of the machines in your cluster. The paths to the weights and programs should be identical on all machines. You will need to build llama.cpp with an MPI-capable compiler, for example, + +```bash +make CC=mpicc CXX=mpicxx LLAMA_MPI=1 +``` + +Once the programs are built and the weights are downloaded on all machines, ensure password-less SSH access to each machine from the primary host. + +Next, create a `hostfile` with a list of the hostnames and their relative "weights" (slots). If you want to use localhost for computation, use its local subnet IP address rather than the loopback address or "localhost". + +Here is an example hostfile: + +``` +192.168.0.1:2 +malvolio.local:1 +``` + +The above will distribute the computation across 2 processes on the first host and 1 process on the second host. Each process will use roughly an equal amount of RAM. Try to keep these numbers small, as inter-process (intra-host) communication is expensive. + +Finally, you're ready to run a computation using `mpirun`: + +```bash +mpirun -hostfile hostfile -n 3 ./main -m ./models/7B/ggml-model-q4_0.bin -n 128 +``` + ### BLAS Build Building the program with BLAS support may lead to some performance improvements in prompt processing using batch sizes higher than 32 (the default is 512). BLAS doesn't affect the normal generation performance. There are currently three different implementations of it: diff --git a/ggml.c b/ggml.c index 2f00428d3ba5c..074e63cc7c1fb 100644 --- a/ggml.c +++ b/ggml.c @@ -4652,7 +4652,7 @@ struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggm struct ggml_tensor * ggml_send_tensor( struct ggml_context * ctx, - const struct ggml_tensor *src, + struct ggml_tensor *src, int dst_rank) { struct ggml_tensor * result = ggml_new_i32(ctx, 0); @@ -4666,9 +4666,10 @@ struct ggml_tensor * ggml_send_tensor( struct ggml_tensor * ggml_recv_tensor( struct ggml_context * ctx, - const struct ggml_tensor *parent, + struct ggml_tensor *parent, struct ggml_tensor *dst, int src_rank) { + UNUSED(ctx); struct ggml_tensor * result = dst; diff --git a/ggml.h b/ggml.h index aa78f17dd0254..de7bd26409bb5 100644 --- a/ggml.h +++ b/ggml.h @@ -561,11 +561,11 @@ extern "C" { GGML_API struct ggml_tensor * ggml_send_tensor( struct ggml_context * ctx, - const struct ggml_tensor *src, + struct ggml_tensor *src, int dst_rank); GGML_API struct ggml_tensor * ggml_recv_tensor( struct ggml_context * ctx, - const struct ggml_tensor *parent, + struct ggml_tensor *parent, struct ggml_tensor *dst, int src_rank); From 668ba5fe0b62d2b628b9ddf132ffbf85992afbcc Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Tue, 4 Jul 2023 00:09:02 -0400 Subject: [PATCH 03/23] fixes --- Makefile | 4 ++-- examples/simple/simple.cpp | 2 ++ llama.cpp | 20 ++++++++++---------- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index f3bda7b9f3dd7..1639715259c1b 100644 --- a/Makefile +++ b/Makefile @@ -150,8 +150,8 @@ ifndef LLAMA_NO_ACCELERATE endif # LLAMA_NO_ACCELERATE ifdef LLAMA_MPI - CFLAGS += -DGGML_USE_MPI - CXXFLAGS += -DGGML_USE_MPI + CFLAGS += -DGGML_USE_MPI -Wno-cast-qual -Wno-int-to-void-pointer-cast -Wno-void-pointer-to-int-cast + CXXFLAGS += -DGGML_USE_MPI -Wno-cast-qual endif # LLAMA_MPI ifdef LLAMA_OPENBLAS diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 2d913cebb813a..57a0fb7c5585d 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -173,6 +173,8 @@ int main(int argc, char ** argv) llama_free( ctx ); llama_free_model( model ); + llama_finalize_backend(); + return 0; } diff --git a/llama.cpp b/llama.cpp index c7de0bc60b67e..a4435897e6ff6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1581,17 +1581,17 @@ static bool llama_eval_internal( // used at the end to optionally extract the embeddings struct ggml_tensor * embeddings = NULL; -#ifdef GGML_USE_MPI - cur = ggml_send_tensor(ctx0, cur, (lctx.mpi_rank+1)%lctx.mpi_size); - ggml_set_name(cur, "send"); -#endif + if (lctx.mpi_size > 1) { + cur = ggml_send_tensor(ctx0, cur, (lctx.mpi_rank+1)%lctx.mpi_size); + ggml_set_name(cur, "send"); + } if (lctx.mpi_rank == 0) { -#ifdef GGML_USE_MPI - cur = ggml_recv_tensor(ctx0, cur, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), - lctx.mpi_size-1); - ggml_set_name(cur, "recv"); -#endif + if (lctx.mpi_size > 1) { + cur = ggml_recv_tensor(ctx0, cur, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), + lctx.mpi_size-1); + ggml_set_name(cur, "recv"); + } // norm { cur = ggml_rms_norm(ctx0, cur); From 042c5b278fd8ba947811bb3b5d150d4e3601b11a Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Tue, 4 Jul 2023 00:13:20 -0400 Subject: [PATCH 04/23] wrap includes --- ggml.c | 2 ++ llama.cpp | 2 ++ 2 files changed, 4 insertions(+) diff --git a/ggml.c b/ggml.c index 074e63cc7c1fb..71e77d015cfad 100644 --- a/ggml.c +++ b/ggml.c @@ -26,7 +26,9 @@ #include #include +#ifdef GGML_USE_MPI #include +#endif #ifdef GGML_USE_METAL #include diff --git a/llama.cpp b/llama.cpp index a4435897e6ff6..af22bf64bb735 100644 --- a/llama.cpp +++ b/llama.cpp @@ -49,7 +49,9 @@ #include #include +#ifdef GGML_USE_MPI #include +#endif #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data From 06a239343c96c16f7e75e58989ec19a43f1e4d5e Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Thu, 6 Jul 2023 20:18:41 -0400 Subject: [PATCH 05/23] PR comments --- Makefile | 6 +++- ggml-mpi.c | 81 ++++++++++++++++++++++++++++++++++++++++++++ ggml-mpi.h | 22 ++++++++++++ ggml.c | 98 ------------------------------------------------------ ggml.h | 13 -------- llama.cpp | 31 +++++++++-------- llama.h | 2 -- 7 files changed, 125 insertions(+), 128 deletions(-) create mode 100644 ggml-mpi.c create mode 100644 ggml-mpi.h diff --git a/Makefile b/Makefile index 937b195b8170f..b250debd4abfb 100644 --- a/Makefile +++ b/Makefile @@ -148,8 +148,12 @@ ifndef LLAMA_NO_ACCELERATE endif # LLAMA_NO_ACCELERATE ifdef LLAMA_MPI - CFLAGS += -DGGML_USE_MPI -Wno-cast-qual -Wno-int-to-void-pointer-cast -Wno-void-pointer-to-int-cast + CFLAGS += -DGGML_USE_MPI -Wno-cast-qual CXXFLAGS += -DGGML_USE_MPI -Wno-cast-qual + OBJS += ggml-mpi.o + +ggml-mpi.o: ggml-mpi.c ggml-mpi.h + $(CC) $(CFLAGS) -c $< -o $@ endif # LLAMA_MPI ifdef LLAMA_OPENBLAS diff --git a/ggml-mpi.c b/ggml-mpi.c new file mode 100644 index 0000000000000..bf301d08b5aee --- /dev/null +++ b/ggml-mpi.c @@ -0,0 +1,81 @@ +#include "ggml-mpi.h" + +#include "ggml.h" + +#include +#include +#include +#define UNUSED GGML_UNUSED + +struct ggml_mpi_tensor_info { + int rank; +}; + +// ggml_compute_forward_send + +static void ggml_mpi_compute_forward_send( + struct ggml_tensor * src, + const struct ggml_tensor * orig) { + UNUSED(orig); + GGML_ASSERT(src->type == GGML_TYPE_F32); + + int my_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); + + int dst_rank = ((struct ggml_mpi_tensor_info *)src->extra)->rank; + // fprintf(stderr, "(%d) Sending to (%d)\n", my_rank, (int)dst->extra); + int retval = MPI_Send(src->data, ggml_nelements(src), MPI_FLOAT, dst_rank, 0, MPI_COMM_WORLD); + // fprintf(stderr, "(%d) Sent to (%d)\n", my_rank, (int)dst->extra); + GGML_ASSERT(retval == MPI_SUCCESS); +} + +// ggml_compute_forward_recv + +static void ggml_mpi_compute_forward_recv( + struct ggml_tensor * dst, + const struct ggml_tensor * orig, + const struct ggml_tensor * parent) { + UNUSED(parent); + UNUSED(orig); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + MPI_Status status; + + int my_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); + + int src_rank = ((struct ggml_mpi_tensor_info *)dst->extra)->rank; + // fprintf(stderr, "(%d) Receiving from (%d)\n", my_rank, src_extra); + int retval = MPI_Recv(dst->data, ggml_nelements(dst), MPI_FLOAT, src_rank, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + // fprintf(stderr, "(%d) Received from (%d)\n", my_rank, src_extra); + GGML_ASSERT(retval == MPI_SUCCESS); +} + +struct ggml_tensor * ggml_mpi_send_tensor( + struct ggml_context * ctx, + struct ggml_tensor *src, + int dst_rank) { + + struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send); + + // TODO how/when to free this struct? + struct ggml_mpi_tensor_info *info = calloc(1, sizeof(struct ggml_mpi_tensor_info)); + info->rank = dst_rank; + result->extra = info; + + return result; +} + +struct ggml_tensor * ggml_mpi_recv_tensor( + struct ggml_context * ctx, + struct ggml_tensor *parent, + struct ggml_tensor *dst, + int src_rank) { + struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv); + + // TODO how/when to free this struct? + struct ggml_mpi_tensor_info *info = calloc(1, sizeof(struct ggml_mpi_tensor_info)); + info->rank = src_rank; + result->extra = info; + + return result; +} diff --git a/ggml-mpi.h b/ggml-mpi.h new file mode 100644 index 0000000000000..ef5269dc5c74f --- /dev/null +++ b/ggml-mpi.h @@ -0,0 +1,22 @@ +#pragma once + +struct ggml_context; +struct ggml_tensor; + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_tensor * ggml_mpi_send_tensor( + struct ggml_context * ctx, + struct ggml_tensor *src, + int dst_rank); +struct ggml_tensor * ggml_mpi_recv_tensor( + struct ggml_context * ctx, + struct ggml_tensor *parent, + struct ggml_tensor *dst, + int src_rank); + +#ifdef __cplusplus +} +#endif diff --git a/ggml.c b/ggml.c index 99b7b75a81ef4..d257c3d657b34 100644 --- a/ggml.c +++ b/ggml.c @@ -26,10 +26,6 @@ #include #include -#ifdef GGML_USE_MPI -#include -#endif - #ifdef GGML_USE_METAL #include #endif @@ -4688,36 +4684,6 @@ struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggm return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL); } -struct ggml_tensor * ggml_send_tensor( - struct ggml_context * ctx, - struct ggml_tensor *src, - int dst_rank) { - - struct ggml_tensor * result = ggml_new_i32(ctx, 0); - - result->op = GGML_OP_SEND; - result->src0 = src; - result->extra = (void *)dst_rank; - - return result; -} - -struct ggml_tensor * ggml_recv_tensor( - struct ggml_context * ctx, - struct ggml_tensor *parent, - struct ggml_tensor *dst, - int src_rank) { - UNUSED(ctx); - - struct ggml_tensor * result = dst; - - result->op = GGML_OP_RECV; - result->src0 = parent; // just used for graph computation - result->extra = (void *)src_rank; - - return result; -} - struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { memset(tensor->data, 0, ggml_nbytes(tensor)); return tensor; @@ -8323,52 +8289,6 @@ static void ggml_compute_forward_dup( } } -// ggml_compute_forward_recv - -static void ggml_compute_forward_recv( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#ifdef GGML_USE_MPI - MPI_Status status; - int my_rank; - MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); - // fprintf(stderr, "(%d) Receiving from (%d)\n", my_rank, (int)dst->extra); - int retval = MPI_Recv(dst->data, dst->ne[0] * dst->ne[1], MPI_FLOAT, (int)dst->extra, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - // fprintf(stderr, "(%d) Received from (%d)\n", my_rank, (int)dst->extra); - GGML_ASSERT(retval == MPI_SUCCESS); -#else - GGML_ASSERT(false); -#endif -} - -// ggml_compute_forward_send - -static void ggml_compute_forward_send( - const struct ggml_compute_params * params, - struct ggml_tensor * src, - struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_I32); -#ifdef GGML_USE_MPI - int my_rank; - MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); - // fprintf(stderr, "(%d) Sending to (%d)\n", my_rank, (int)dst->extra); - int retval = MPI_Send(src->data, src->ne[0] * src->ne[1], MPI_FLOAT, (int)dst->extra, 0, MPI_COMM_WORLD); - // fprintf(stderr, "(%d) Sent to (%d)\n", my_rank, (int)dst->extra); - ggml_set_i32(dst, retval); - GGML_ASSERT(retval == MPI_SUCCESS); -#else - GGML_ASSERT(false); -#endif -} - // ggml_compute_forward_add static void ggml_compute_forward_add_f32( @@ -14655,14 +14575,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_dup(params, tensor->src0, tensor); } break; - case GGML_OP_SEND: - { - ggml_compute_forward_send(params, tensor->src0, tensor); - } break; - case GGML_OP_RECV: - { - ggml_compute_forward_recv(params, tensor); - } break; case GGML_OP_ADD: { ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor); @@ -14961,14 +14873,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); } } break; - case GGML_OP_SEND: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_RECV: - { - GGML_ASSERT(false); // TODO: not implemented - } break; case GGML_OP_ADD: { if (src0->grad) { @@ -16307,8 +16211,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; } break; - case GGML_OP_SEND: - case GGML_OP_RECV: case GGML_OP_SET: case GGML_OP_CONT: case GGML_OP_RESHAPE: diff --git a/ggml.h b/ggml.h index f204d13420418..d0710c5559170 100644 --- a/ggml.h +++ b/ggml.h @@ -381,9 +381,6 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS_BACK, GGML_OP_COUNT, - - GGML_OP_SEND, - GGML_OP_RECV, }; @@ -587,16 +584,6 @@ extern "C" { GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src); - GGML_API struct ggml_tensor * ggml_send_tensor( - struct ggml_context * ctx, - struct ggml_tensor *src, - int dst_rank); - GGML_API struct ggml_tensor * ggml_recv_tensor( - struct ggml_context * ctx, - struct ggml_tensor *parent, - struct ggml_tensor *dst, - int src_rank); - GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); diff --git a/llama.cpp b/llama.cpp index 99abde3482f0b..42b2f6155fa88 100644 --- a/llama.cpp +++ b/llama.cpp @@ -19,6 +19,9 @@ #ifdef GGML_USE_METAL #include "ggml-metal.h" #endif +#ifdef GGML_USE_MPI +#include "ggml-mpi.h" +#endif #ifdef GGML_USE_K_QUANTS #ifndef QK_K #ifdef GGML_QKK_64 @@ -1332,10 +1335,10 @@ static bool llama_eval_internal( if (lctx.mpi_rank > 0) { #ifdef GGML_USE_MPI - inpL = ggml_recv_tensor(ctx0, NULL, + inpL = ggml_mpi_recv_tensor(ctx0, NULL, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), lctx.mpi_rank-1); - ggml_set_name(inpL, "recv"); + ggml_set_name(inpL, "mpi_recv"); #else GGML_ASSERT(false); #endif @@ -1591,15 +1594,23 @@ static bool llama_eval_internal( struct ggml_tensor * embeddings = NULL; if (lctx.mpi_size > 1) { - cur = ggml_send_tensor(ctx0, cur, (lctx.mpi_rank+1)%lctx.mpi_size); - ggml_set_name(cur, "send"); +#ifdef GGML_USE_MPI + cur = ggml_mpi_send_tensor(ctx0, cur, (lctx.mpi_rank+1)%lctx.mpi_size); + ggml_set_name(cur, "mpi_send"); +#else + GGML_ASSERT(false); +#endif } if (lctx.mpi_rank == 0) { if (lctx.mpi_size > 1) { - cur = ggml_recv_tensor(ctx0, cur, +#ifdef GGML_USE_MPI + cur = ggml_mpi_recv_tensor(ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), lctx.mpi_size-1); - ggml_set_name(cur, "recv"); + ggml_set_name(cur, "mpi_recv"); +#else + GGML_ASSERT(false); +#endif } // norm { @@ -3504,14 +3515,6 @@ int llama_n_embd(const struct llama_context * ctx) { return ctx->model.hparams.n_embd; } -int llama_mpi_rank(const struct llama_context * ctx) { - return ctx->mpi_rank; -} - -int llama_mpi_size(const struct llama_context * ctx) { - return ctx->mpi_size; -} - int llama_get_vocab( const struct llama_context * ctx, const char * * strings, diff --git a/llama.h b/llama.h index 14bc432c73a32..b90c523555da8 100644 --- a/llama.h +++ b/llama.h @@ -273,8 +273,6 @@ extern "C" { LLAMA_API int llama_n_vocab(const struct llama_context * ctx); LLAMA_API int llama_n_ctx (const struct llama_context * ctx); LLAMA_API int llama_n_embd (const struct llama_context * ctx); - LLAMA_API int llama_mpi_rank (const struct llama_context * ctx); - LLAMA_API int llama_mpi_size (const struct llama_context * ctx); // Get the vocabulary as output parameters. // Returns number of results. From 1f0a2cfeda0f0e5e4ff4855a68535d2b5d7c0430 Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Thu, 6 Jul 2023 21:25:34 -0400 Subject: [PATCH 06/23] Update CMakeLists.txt --- CMakeLists.txt | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index a2404548f90d4..adc76d94e6465 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,6 +75,7 @@ option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" OFF) +option(LLAMA_MPI "llama: use MPI" OFF) option(LLAMA_K_QUANTS "llama: use k-quants" ON) option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) @@ -305,6 +306,23 @@ if (LLAMA_METAL) ) endif() +if (LLAMA_MPI) + cmake_minimum_required(VERSION 3.10) + find_package(MPI) + if (MPI_C_FOUND) + message(STATUS "MPI found") + set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) + add_compile_definitions(GGML_USE_MPI) + add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) + set(cxx_flags ${cxx_flags} -Wno-cast-qual) + set(c_flags ${c_flags} -Wno-cast-qual) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) + else() + message(WARNING "MPI not found") + endif() +endif() + if (LLAMA_CLBLAST) find_package(CLBlast) if (CLBlast_FOUND) @@ -473,6 +491,7 @@ add_library(ggml OBJECT ${GGML_SOURCES_CUDA} ${GGML_SOURCES_OPENCL} ${GGML_SOURCES_METAL} + ${GGML_SOURCES_MPI} ${GGML_SOURCES_EXTRA} ) From 55207ba2b8a1cb14bac8a749d7268af266030d2c Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Thu, 6 Jul 2023 21:40:18 -0400 Subject: [PATCH 07/23] Add GH workflow, fix test --- .github/workflows/build.yml | 28 ++++++++++++++++++++++++++++ tests/test-tokenizer-0.cpp | 4 ++++ 2 files changed, 32 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 12481e8be7cf7..e98ef5b5d5ba9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -101,6 +101,34 @@ jobs: cd build ctest --verbose + ubuntu-latest-cmake-mpi: + runs-on: ubuntu-latest + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v1 + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential mpich + + - name: Build + id: cmake_build + run: | + mkdir build + cd build + cmake -DLLAMA_MPI=ON .. + cmake --build . --config Release + + - name: Test + id: cmake_test + run: | + cd build + ctest --verbose + macOS-latest-make: runs-on: macos-latest diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 20abe710018ee..1d8759a27ebb9 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -31,6 +31,8 @@ int main(int argc, char **argv) { llama_model * model; llama_context * ctx; + llama_init_backend(false); + // load the vocab { auto lparams = llama_context_default_params(); @@ -97,5 +99,7 @@ int main(int argc, char **argv) { llama_free_model(model); llama_free(ctx); + llama_finalize_backend(); + return 0; } From ef61acfbf5b07f5fed76c1a808232633d318c58f Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Fri, 7 Jul 2023 09:02:23 -0400 Subject: [PATCH 08/23] Add info to README --- README.md | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index af0fef66ffb39..e3497033c468b 100644 --- a/README.md +++ b/README.md @@ -272,15 +272,25 @@ Any value larger than 0 will offload the computation to the GPU. For example: MPI lets you distribute the computation over a cluster of machines. Because of the serial nature of LLM prediction, this won't yield any end-to-end speed-ups, but it will let you run larger models than would otherwise fit into RAM on a single machine. -First, build llama.cpp and download/convert the weights on all of the machines in your cluster. The paths to the weights and programs should be identical on all machines. You will need to build llama.cpp with an MPI-capable compiler, for example, +First you will need MPI libraries installed on your system. The two most popular (only?) options are [MPICH](https://www.mpich.org) and [OpenMPI](https://www.open-mpi.org). Either can be installed with a package manager (`apt`, Homebrew, MacPorts, etc). -```bash -make CC=mpicc CXX=mpicxx LLAMA_MPI=1 -``` +Next you will need to build the project with `LLAMA_MPI` set to true on all machines; if you're building with `make`, you will also need to specify an MPI-capable compiler (when building with CMake, this is configured automatically): + +- Using `make`: + + ```bash + make CC=mpicc CXX=mpicxx LLAMA_MPI=1 + ``` + +- Using `CMake`: + + ```bash + cmake -S . -B build -DLLAMA_MPI=ON + ``` -Once the programs are built and the weights are downloaded on all machines, ensure password-less SSH access to each machine from the primary host. +Once the programs are built, download/convert the weights on all of the machines in your cluster. The paths to the weights and programs should be identical on all machines. -Next, create a `hostfile` with a list of the hostnames and their relative "weights" (slots). If you want to use localhost for computation, use its local subnet IP address rather than the loopback address or "localhost". +Next, ensure password-less SSH access to each machine from the primary host, and create a `hostfile` with a list of the hostnames and their relative "weights" (slots). If you want to use localhost for computation, use its local subnet IP address rather than the loopback address or "localhost". Here is an example hostfile: From 3232db628c8faf595f022ba19203acc104efddb0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 14:08:53 +0300 Subject: [PATCH 09/23] mpi : trying to move more MPI stuff into ggml-mpi (WIP) (#2099) --- examples/embd-input/embd-input-lib.cpp | 2 +- examples/embedding/embedding.cpp | 4 +- examples/main/main.cpp | 4 +- examples/perplexity/perplexity.cpp | 4 +- examples/quantize/quantize.cpp | 4 +- examples/server/server.cpp | 4 +- examples/simple/simple.cpp | 4 +- ggml-mpi.c | 70 +++++++++++++++++++++--- ggml-mpi.h | 28 ++++++++-- llama.cpp | 73 +++++++++++--------------- llama.h | 4 +- 11 files changed, 134 insertions(+), 67 deletions(-) diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 5fa4942be7aaf..26563821a1078 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -34,7 +34,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) { } fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 03e801c2a6d4b..5192d6df5c2f8 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -35,7 +35,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -93,5 +93,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + llama_backend_free(); + return 0; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index ef57a8982c64a..07d8fc6ac0781 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -105,7 +105,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -671,7 +671,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); - llama_finalize_backend(); + llama_backend_free(); return 0; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 68f44ba805966..7e120ff12cb42 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -147,7 +147,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); - llama_finalize_backend(); + llama_backend_free(); return 0; } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 1eb0f75d6dc79..797d2f0c5a279 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -180,7 +180,7 @@ int main(int argc, char ** argv) { usage(argv[0]); } - llama_init_backend(false); + llama_backend_init(false); // parse command line arguments const std::string fname_inp = argv[arg_idx]; @@ -257,5 +257,7 @@ int main(int argc, char ** argv) { printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0); } + llama_backend_free(); + return 0; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2cbfc0018de3a..296c5d6468f16 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1079,7 +1079,7 @@ int main(int argc, char **argv) params.model_alias = params.model; } - llama_init_backend(params.numa); + llama_backend_init(params.numa); LOG_INFO("build info", {{"build", BUILD_NUMBER}, {"commit", BUILD_COMMIT}}); @@ -1309,5 +1309,7 @@ int main(int argc, char **argv) return 1; } + llama_backend_free(); + return 0; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 57a0fb7c5585d..aa2c4352df294 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -66,7 +66,7 @@ int main(int argc, char ** argv) // Init LLM : //--------------------------------- - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -173,7 +173,7 @@ int main(int argc, char ** argv) llama_free( ctx ); llama_free_model( model ); - llama_finalize_backend(); + llama_backend_free(); return 0; } diff --git a/ggml-mpi.c b/ggml-mpi.c index bf301d08b5aee..b68e2c42b4432 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -2,9 +2,11 @@ #include "ggml.h" +#include + #include #include -#include + #define UNUSED GGML_UNUSED struct ggml_mpi_tensor_info { @@ -52,9 +54,8 @@ static void ggml_mpi_compute_forward_recv( struct ggml_tensor * ggml_mpi_send_tensor( struct ggml_context * ctx, - struct ggml_tensor *src, - int dst_rank) { - + struct ggml_tensor * src, + int dst_rank) { struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send); // TODO how/when to free this struct? @@ -67,9 +68,9 @@ struct ggml_tensor * ggml_mpi_send_tensor( struct ggml_tensor * ggml_mpi_recv_tensor( struct ggml_context * ctx, - struct ggml_tensor *parent, - struct ggml_tensor *dst, - int src_rank) { + struct ggml_tensor * parent, + struct ggml_tensor * dst, + int src_rank) { struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv); // TODO how/when to free this struct? @@ -79,3 +80,58 @@ struct ggml_tensor * ggml_mpi_recv_tensor( return result; } + +struct ggml_mpi_context { + int mpi_rank; + int mpi_size; +}; + +void ggml_mpi_backend_init(void) { + MPI_Init(NULL, NULL); +} + +void ggml_mpi_backend_free(void) { + MPI_Finalize(); +} + +struct ggml_mpi_context * ggml_mpi_init(void) { + struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context)); + + MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank); + MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size); + + return ctx; +} + +void ggml_mpi_free(struct ggml_mpi_context * ctx) { + free(ctx); +} + +int ggml_mpi_rank(struct ggml_mpi_context * ctx) { + return ctx->mpi_rank; +} + +struct ggml_tensor * ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + struct ggml_context * ctx, + int n_embd, + int * n_tokens, + int * n_past, + int * n_threads) { + struct ggml_tensor * res = NULL; + + // synchronize the worker node parameters with the root node + MPI_Barrier(MPI_COMM_WORLD); + + MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); + + if (ctx_mpi->mpi_rank > 0) { + res = ggml_mpi_recv_tensor(ctx, NULL, + ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, *n_tokens), ctx_mpi->mpi_rank - 1); + ggml_set_name(res, "mpi_recv"); + } + + return res; +} diff --git a/ggml-mpi.h b/ggml-mpi.h index ef5269dc5c74f..157c6255d4b75 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -9,13 +9,31 @@ extern "C" { struct ggml_tensor * ggml_mpi_send_tensor( struct ggml_context * ctx, - struct ggml_tensor *src, - int dst_rank); + struct ggml_tensor * src, + int dst_rank); struct ggml_tensor * ggml_mpi_recv_tensor( struct ggml_context * ctx, - struct ggml_tensor *parent, - struct ggml_tensor *dst, - int src_rank); + struct ggml_tensor * parent, + struct ggml_tensor * dst, + int src_rank); + +struct ggml_mpi_context; + +void ggml_mpi_backend_init(void); +void ggml_mpi_backend_free(void); + +struct ggml_mpi_context * ggml_mpi_init(void); +void ggml_mpi_free(struct ggml_mpi_context * ctx); + +int ggml_mpi_rank(struct ggml_mpi_context * ctx); + +struct ggml_tensor * ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + struct ggml_context * ctx, + int n_embd, + int * n_tokens, + int * n_past, + int * n_threads); #ifdef __cplusplus } diff --git a/llama.cpp b/llama.cpp index 42b2f6155fa88..d84e827c33025 100644 --- a/llama.cpp +++ b/llama.cpp @@ -52,10 +52,6 @@ #include #include -#ifdef GGML_USE_MPI -#include -#endif - #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif @@ -337,8 +333,9 @@ struct llama_context { ggml_metal_context * ctx_metal = NULL; #endif - int mpi_rank; - int mpi_size; +#ifdef GGML_USE_MPI + ggml_mpi_context * ctx_mpi = NULL; +#endif int buf_last = 0; size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; @@ -859,7 +856,7 @@ bool llama_mlock_supported() { return llama_mlock::SUPPORTED; } -void llama_init_backend(bool numa) { +void llama_backend_init(bool numa) { ggml_time_init(); // needed to initialize f16 tables @@ -872,14 +869,15 @@ void llama_init_backend(bool numa) { if (numa) { ggml_numa_init(); } + #ifdef GGML_USE_MPI - MPI_Init(NULL, NULL); + ggml_mpi_backend_init(); #endif } -void llama_finalize_backend() { +void llama_backend_free() { #ifdef GGML_USE_MPI - MPI_Finalize(); + ggml_mpi_backend_free(); #endif } @@ -1282,9 +1280,9 @@ static bool llama_eval_internal( llama_context & lctx, const llama_token * tokens, const float * embd, - const int n_tokens, - const int n_past, - const int n_threads, + int n_tokens, + int n_past, + int n_threads, const char * cgraph_fname) { LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); @@ -1333,16 +1331,14 @@ static bool llama_eval_internal( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (lctx.mpi_rank > 0) { #ifdef GGML_USE_MPI - inpL = ggml_mpi_recv_tensor(ctx0, NULL, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), - lctx.mpi_rank-1); - ggml_set_name(inpL, "mpi_recv"); -#else - GGML_ASSERT(false); + inpL = ggml_mpi_eval_init(lctx.ctx_mpi, ctx0, n_embd, &n_tokens, &n_past, &n_threads); + + if (inpL) { + // only rank 0 loads uses the input + } else #endif - } else if (tokens) { + if (tokens) { struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); ggml_set_name(embd, "embd"); memcpy(embd->data, tokens, N*ggml_element_size(embd)); @@ -1585,7 +1581,6 @@ static bool llama_eval_internal( // input for next layer inpL = cur; - } lctx.use_buf(ctx0, 0); @@ -1601,6 +1596,7 @@ static bool llama_eval_internal( GGML_ASSERT(false); #endif } + if (lctx.mpi_rank == 0) { if (lctx.mpi_size > 1) { #ifdef GGML_USE_MPI @@ -1688,7 +1684,11 @@ static bool llama_eval_internal( // update kv token count lctx.kv_self.n = n_past + N; - if (lctx.mpi_rank == 0) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(lctx.ctx_mpi) == 0) { +#else + { +#endif // extract logits { auto & logits_out = lctx.logits; @@ -2659,14 +2659,6 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; -#ifdef GGML_USE_MPI - MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size); - MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank); -#else - ctx->mpi_size = 1; - ctx->mpi_rank = 0; -#endif - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; // reserve memory for context buffers @@ -2739,15 +2731,17 @@ struct llama_context * llama_new_context_with_model( } #endif - if (ctx->mpi_rank > 0) { +#ifdef GGML_USE_MPI + ctx->ctx_mpi = ggml_mpi_init(); + + if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { // Enter a blocking eval loop with dummy input, letting rank=0 drive the process const std::vector tmp = { llama_token_bos(), }; - while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)); -#ifdef GGML_USE_MPI - MPI_Finalize(); -#endif + while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + llama_backend_free(); exit(1); } +#endif return ctx; } @@ -3425,13 +3419,6 @@ int llama_eval( int n_tokens, int n_past, int n_threads) { -#ifdef GGML_USE_MPI - // Synchronize the worker node parameters with the root node - MPI_Barrier(MPI_COMM_WORLD); - MPI_Bcast(&n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); - MPI_Bcast(&n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); - MPI_Bcast(&n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); -#endif if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; diff --git a/llama.h b/llama.h index b90c523555da8..686463aa25af8 100644 --- a/llama.h +++ b/llama.h @@ -158,9 +158,9 @@ extern "C" { // Initialize the llama + ggml backend // If numa is true, use NUMA optimizations // Call once at the start of the program - LLAMA_API void llama_init_backend(bool numa); + LLAMA_API void llama_backend_init(bool numa); // Call once at the end of the program - currently only used for MPI - LLAMA_API void llama_finalize_backend(); + LLAMA_API void llama_backend_free(); LLAMA_API int64_t llama_time_us(); From e339d355795936ad54bd26f212ae42f90a0128b8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 14:42:36 +0300 Subject: [PATCH 10/23] mpi : add names for layer inputs + prep ggml_mpi_graph_compute() --- ggml-mpi.c | 9 +++++++ ggml-mpi.h | 8 ++++++ llama.cpp | 76 ++++++++++++++++++++---------------------------------- 3 files changed, 45 insertions(+), 48 deletions(-) diff --git a/ggml-mpi.c b/ggml-mpi.c index b68e2c42b4432..16a088e57cb06 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -135,3 +135,12 @@ struct ggml_tensor * ggml_mpi_eval_init( return res; } + +void ggml_mpi_graph_compute( + struct ggml_mpi_context * ctx_mpi, + struct ggml_cgraph * gf, + int n_layers, + int n_embd, + int n_tokens) { + +} diff --git a/ggml-mpi.h b/ggml-mpi.h index 157c6255d4b75..fc3d0ce5172ee 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -2,6 +2,7 @@ struct ggml_context; struct ggml_tensor; +struct ggml_cgraph; #ifdef __cplusplus extern "C" { @@ -35,6 +36,13 @@ struct ggml_tensor * ggml_mpi_eval_init( int * n_past, int * n_threads); +void ggml_mpi_graph_compute( + struct ggml_mpi_context * ctx_mpi, + struct ggml_cgraph * gf, + int n_layers, + int n_embd, + int n_tokens); + #ifdef __cplusplus } #endif diff --git a/llama.cpp b/llama.cpp index d84e827c33025..88ccd4999a549 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1361,20 +1361,20 @@ static bool llama_eval_internal( offload_func_t offload_func_v = llama_nop; #ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func_nr = ggml_cuda_assign_buffers; - } - if (n_gpu_layers > n_layer + 1) { - offload_func_v = ggml_cuda_assign_buffers; - } - if (n_gpu_layers > n_layer + 2) { - offload_func_kq = ggml_cuda_assign_buffers; - } + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers; + } #endif // GGML_USE_CUBLAS - // EMM TODO distribute work more evenly - maybe rank=0 gets the smallest amount? - int slice_size = (n_layer + (lctx.mpi_size - 1)) / lctx.mpi_size; - for (int il = lctx.mpi_rank * slice_size; il < n_layer && il < (lctx.mpi_rank + 1) * slice_size; ++il) { + for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); + offload_func_t offload_func = llama_nop; #ifdef GGML_USE_CUBLAS @@ -1588,46 +1588,24 @@ static bool llama_eval_internal( // used at the end to optionally extract the embeddings struct ggml_tensor * embeddings = NULL; - if (lctx.mpi_size > 1) { -#ifdef GGML_USE_MPI - cur = ggml_mpi_send_tensor(ctx0, cur, (lctx.mpi_rank+1)%lctx.mpi_size); - ggml_set_name(cur, "mpi_send"); -#else - GGML_ASSERT(false); -#endif - } - - if (lctx.mpi_rank == 0) { - if (lctx.mpi_size > 1) { -#ifdef GGML_USE_MPI - cur = ggml_mpi_recv_tensor(ctx0, cur, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), - lctx.mpi_size-1); - ggml_set_name(cur, "mpi_recv"); -#else - GGML_ASSERT(false); -#endif - } - // norm - { - cur = ggml_rms_norm(ctx0, cur); - offload_func_nr(cur); - ggml_set_name(cur, "rms_norm_2"); - - // cur = cur*norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.norm); - // offload_func_nr(cur); // TODO CPU + GPU mirrored backend - ggml_set_name(cur, "result_norm"); - - embeddings = cur; - } + // norm + { + cur = ggml_rms_norm(ctx0, inpL); + offload_func_nr(cur); + ggml_set_name(cur, "rms_norm_2"); + // cur = cur*norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.norm); + // offload_func_nr(cur); // TODO CPU + GPU mirrored backend + ggml_set_name(cur, "result_norm"); - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); + embeddings = cur; } + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + lctx.use_buf(ctx0, -1); // logits -> probs @@ -1659,6 +1637,8 @@ static bool llama_eval_internal( ggml_graph_compute(ctx0, &gf); } +#elif GGML_USE_MPI + ggml_mpi_graph_compute(lctx.ctx_mpi, &gf, n_layer, n_embd, n_tokens); #else ggml_graph_compute(ctx0, &gf); #endif From 01abb3b3b95750db515d6846fcc6240c86fe8ed4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 16:04:27 +0300 Subject: [PATCH 11/23] mpi : move all MPI logic into ggml-mpi Not tested yet --- ggml-mpi.c | 216 +++++++++++++++++++++++++++++++---------------------- ggml-mpi.h | 19 +---- llama.cpp | 14 ++-- 3 files changed, 136 insertions(+), 113 deletions(-) diff --git a/ggml-mpi.c b/ggml-mpi.c index 16a088e57cb06..8bf4468a19510 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -6,84 +6,15 @@ #include #include +#include -#define UNUSED GGML_UNUSED - -struct ggml_mpi_tensor_info { - int rank; -}; - -// ggml_compute_forward_send - -static void ggml_mpi_compute_forward_send( - struct ggml_tensor * src, - const struct ggml_tensor * orig) { - UNUSED(orig); - GGML_ASSERT(src->type == GGML_TYPE_F32); - - int my_rank; - MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); - - int dst_rank = ((struct ggml_mpi_tensor_info *)src->extra)->rank; - // fprintf(stderr, "(%d) Sending to (%d)\n", my_rank, (int)dst->extra); - int retval = MPI_Send(src->data, ggml_nelements(src), MPI_FLOAT, dst_rank, 0, MPI_COMM_WORLD); - // fprintf(stderr, "(%d) Sent to (%d)\n", my_rank, (int)dst->extra); - GGML_ASSERT(retval == MPI_SUCCESS); -} - -// ggml_compute_forward_recv - -static void ggml_mpi_compute_forward_recv( - struct ggml_tensor * dst, - const struct ggml_tensor * orig, - const struct ggml_tensor * parent) { - UNUSED(parent); - UNUSED(orig); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - MPI_Status status; - - int my_rank; - MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); - - int src_rank = ((struct ggml_mpi_tensor_info *)dst->extra)->rank; - // fprintf(stderr, "(%d) Receiving from (%d)\n", my_rank, src_extra); - int retval = MPI_Recv(dst->data, ggml_nelements(dst), MPI_FLOAT, src_rank, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - // fprintf(stderr, "(%d) Received from (%d)\n", my_rank, src_extra); - GGML_ASSERT(retval == MPI_SUCCESS); -} +#define MIN(a, b) ((a) < (b) ? (a) : (b)) -struct ggml_tensor * ggml_mpi_send_tensor( - struct ggml_context * ctx, - struct ggml_tensor * src, - int dst_rank) { - struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send); - - // TODO how/when to free this struct? - struct ggml_mpi_tensor_info *info = calloc(1, sizeof(struct ggml_mpi_tensor_info)); - info->rank = dst_rank; - result->extra = info; - - return result; -} - -struct ggml_tensor * ggml_mpi_recv_tensor( - struct ggml_context * ctx, - struct ggml_tensor * parent, - struct ggml_tensor * dst, - int src_rank) { - struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv); - - // TODO how/when to free this struct? - struct ggml_mpi_tensor_info *info = calloc(1, sizeof(struct ggml_mpi_tensor_info)); - info->rank = src_rank; - result->extra = info; - - return result; -} +#define UNUSED GGML_UNUSED struct ggml_mpi_context { - int mpi_rank; - int mpi_size; + int rank; + int size; }; void ggml_mpi_backend_init(void) { @@ -97,8 +28,8 @@ void ggml_mpi_backend_free(void) { struct ggml_mpi_context * ggml_mpi_init(void) { struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context)); - MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank); - MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size); + MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank); + MPI_Comm_size(MPI_COMM_WORLD, &ctx->size); return ctx; } @@ -108,17 +39,15 @@ void ggml_mpi_free(struct ggml_mpi_context * ctx) { } int ggml_mpi_rank(struct ggml_mpi_context * ctx) { - return ctx->mpi_rank; + return ctx->rank; } -struct ggml_tensor * ggml_mpi_eval_init( +void ggml_mpi_eval_init( struct ggml_mpi_context * ctx_mpi, - struct ggml_context * ctx, - int n_embd, int * n_tokens, int * n_past, int * n_threads) { - struct ggml_tensor * res = NULL; + UNUSED(ctx_mpi); // synchronize the worker node parameters with the root node MPI_Barrier(MPI_COMM_WORLD); @@ -126,21 +55,130 @@ struct ggml_tensor * ggml_mpi_eval_init( MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); +} - if (ctx_mpi->mpi_rank > 0) { - res = ggml_mpi_recv_tensor(ctx, NULL, - ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, *n_tokens), ctx_mpi->mpi_rank - 1); - ggml_set_name(res, "mpi_recv"); +int ggml_graph_get_node_idx( struct ggml_cgraph * gf, const char * name) { + struct ggml_tensor * t = ggml_graph_get_tensor(gf, name); + if (t == NULL) { + fprintf(stderr, "%s: tensor %s not found\n", __func__, name); + return -1; } - return res; + for (int i = 0; i < gf->n_nodes; i++) { + if (gf->nodes[i] == t) { + return i; + } + } + + fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name); + return -1; } void ggml_mpi_graph_compute( struct ggml_mpi_context * ctx_mpi, + struct ggml_context * ctx, struct ggml_cgraph * gf, - int n_layers, - int n_embd, - int n_tokens) { + int n_layers) { + const int mpi_rank = ctx_mpi->rank; + const int mpi_size = ctx_mpi->size; + + struct ggml_tensor * embd = ggml_graph_get_tensor(gf, "layer_inp_0"); + if (embd == NULL) { + fprintf(stderr, "%s: tensor 'embd' not found\n", __func__); + return; + } + + GGML_ASSERT(embd == gf->nodes[0]); + + // distribute the compute graph into slices across the MPI nodes + // + // the main node (0) processes the last layers + the remainder of the compute graph + // and is responsible to pass the input embeddings to the first node (1) + // + // node 1: [( 0) * n_per_node, ( 1) * n_per_node) + // node 2: [( 1) * n_per_node, ( 2) * n_per_node) + // ... + // node n-1: [(n-2) * n_per_node, (n-1) * n_per_node) + // node 0: [(n-1) * n_per_node, n_nodes) + // + if (mpi_rank > 0) { + // recv input data for each node into the "embd" tensor (i.e. the first node in the compute graph) + { + MPI_Status status; UNUSED(status); + + const int mpi_rank_src = mpi_rank - 1; + + // fprintf(stderr, "(%d) Receiving from (%d)\n", mpi_rank, mpi_rank_src); + const int retval = MPI_Recv(embd, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + GGML_ASSERT(retval == MPI_SUCCESS); + // fprintf(stderr, "(%d) Received from (%d)\n", mpi_rank, mpi_rank_src); + } + } else { + // node 0 sends the input data to node 1 + { + const int mpi_rank_dst = mpi_rank + 1; + + const int retval = MPI_Send(embd, ggml_nelements(embd), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); + GGML_ASSERT(retval == MPI_SUCCESS); + // fprintf(stderr, "(%d) Sent to (%d)\n", mpi_rank, mpi_rank_dst); + } + + // recv the output data from the last node + { + MPI_Status status; UNUSED(status); + + const int mpi_rank_src = mpi_size - 1; + + const int retval = MPI_Recv(embd, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + GGML_ASSERT(retval == MPI_SUCCESS); + } + } + + { + const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size; + + const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1; + + const int il0 = (mpi_idx + 0) * n_per_node; + const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node); + + char name_l0[64]; + char name_l1[64]; + + snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0); + snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1); + const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0); + const int idx_l1 = mpi_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) : gf->n_nodes; + + if (idx_l0 < 0 || idx_l1 < 0) { + fprintf(stderr, "%s: layer input nodes not found\n", __func__); + return; + } + + // attach the input data to the first layer for this node + gf->nodes[idx_l0 + 1]->src0 = gf->nodes[1]->src0; + gf->nodes[idx_l0 + 1]->src1 = gf->nodes[1]->src1; + + memcpy(gf->nodes[idx_l0 + 1]->opt, gf->nodes[1]->opt, sizeof(gf->nodes[idx_l0 + 1]->opt)); + + for (int i = 1; i < idx_l1 - idx_l0; i++) { + gf->nodes[i] = gf->nodes[idx_l0 + i]; + gf->grads[i] = gf->grads[idx_l0 + i]; + } + + gf->n_nodes = idx_l1 - idx_l0; + } + + ggml_graph_compute(ctx, gf); + + // send the output data to the next node + if (mpi_rank > 0) { + struct ggml_tensor * output = gf->nodes[gf->n_nodes - 1]; + + const int mpi_rank_dst = (mpi_rank + 1) % mpi_size; + + const int retval = MPI_Send(output, ggml_nelements(output), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); + GGML_ASSERT(retval == MPI_SUCCESS); + } } diff --git a/ggml-mpi.h b/ggml-mpi.h index fc3d0ce5172ee..02e125cfb624b 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -8,16 +8,6 @@ struct ggml_cgraph; extern "C" { #endif -struct ggml_tensor * ggml_mpi_send_tensor( - struct ggml_context * ctx, - struct ggml_tensor * src, - int dst_rank); -struct ggml_tensor * ggml_mpi_recv_tensor( - struct ggml_context * ctx, - struct ggml_tensor * parent, - struct ggml_tensor * dst, - int src_rank); - struct ggml_mpi_context; void ggml_mpi_backend_init(void); @@ -28,20 +18,17 @@ void ggml_mpi_free(struct ggml_mpi_context * ctx); int ggml_mpi_rank(struct ggml_mpi_context * ctx); -struct ggml_tensor * ggml_mpi_eval_init( +void ggml_mpi_eval_init( struct ggml_mpi_context * ctx_mpi, - struct ggml_context * ctx, - int n_embd, int * n_tokens, int * n_past, int * n_threads); void ggml_mpi_graph_compute( struct ggml_mpi_context * ctx_mpi, + struct ggml_context * ctx, struct ggml_cgraph * gf, - int n_layers, - int n_embd, - int n_tokens); + int n_layers); #ifdef __cplusplus } diff --git a/llama.cpp b/llama.cpp index 88ccd4999a549..fa8030c36229c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1332,15 +1332,11 @@ static bool llama_eval_internal( struct ggml_tensor * inpL; #ifdef GGML_USE_MPI - inpL = ggml_mpi_eval_init(lctx.ctx_mpi, ctx0, n_embd, &n_tokens, &n_past, &n_threads); - - if (inpL) { - // only rank 0 loads uses the input - } else + ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); #endif + if (tokens) { struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - ggml_set_name(embd, "embd"); memcpy(embd->data, tokens, N*ggml_element_size(embd)); inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); } else { @@ -1348,6 +1344,8 @@ static bool llama_eval_internal( memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); } + ggml_set_name(inpL, "embd"); + const int i_gpu_start = n_layer - n_gpu_layers; (void) i_gpu_start; @@ -1638,7 +1636,7 @@ static bool llama_eval_internal( ggml_graph_compute(ctx0, &gf); } #elif GGML_USE_MPI - ggml_mpi_graph_compute(lctx.ctx_mpi, &gf, n_layer, n_embd, n_tokens); + ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_embd, n_tokens); #else ggml_graph_compute(ctx0, &gf); #endif @@ -2716,7 +2714,7 @@ struct llama_context * llama_new_context_with_model( if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { // Enter a blocking eval loop with dummy input, letting rank=0 drive the process - const std::vector tmp = { llama_token_bos(), }; + const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos()); while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; llama_backend_free(); exit(1); From c717c5185f76a07422fbf6d66b58bfe7b6f0fd9a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 16:40:16 +0300 Subject: [PATCH 12/23] mpi : various fixes - communication now works but results are wrong --- ggml-mpi.c | 21 ++++++++++++++------- llama.cpp | 10 +++++----- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/ggml-mpi.c b/ggml-mpi.c index 8bf4468a19510..e890d24d1316f 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -108,19 +108,17 @@ void ggml_mpi_graph_compute( const int mpi_rank_src = mpi_rank - 1; - // fprintf(stderr, "(%d) Receiving from (%d)\n", mpi_rank, mpi_rank_src); - const int retval = MPI_Recv(embd, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + //printf("%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(embd), mpi_rank_src); + const int retval = MPI_Recv(embd->data, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); GGML_ASSERT(retval == MPI_SUCCESS); - // fprintf(stderr, "(%d) Received from (%d)\n", mpi_rank, mpi_rank_src); } } else { // node 0 sends the input data to node 1 { const int mpi_rank_dst = mpi_rank + 1; - const int retval = MPI_Send(embd, ggml_nelements(embd), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); + const int retval = MPI_Send(embd->data, ggml_nelements(embd), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); GGML_ASSERT(retval == MPI_SUCCESS); - // fprintf(stderr, "(%d) Sent to (%d)\n", mpi_rank, mpi_rank_dst); } // recv the output data from the last node @@ -129,7 +127,8 @@ void ggml_mpi_graph_compute( const int mpi_rank_src = mpi_size - 1; - const int retval = MPI_Recv(embd, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + //fprintf(stderr, "%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(embd), mpi_rank_src); + const int retval = MPI_Recv(embd->data, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); GGML_ASSERT(retval == MPI_SUCCESS); } } @@ -165,20 +164,28 @@ void ggml_mpi_graph_compute( for (int i = 1; i < idx_l1 - idx_l0; i++) { gf->nodes[i] = gf->nodes[idx_l0 + i]; gf->grads[i] = gf->grads[idx_l0 + i]; + + //fprintf(stderr, "%s: node %d: %d -> %d\n", __func__, mpi_rank, idx_l0 + i, i); } gf->n_nodes = idx_l1 - idx_l0; + + //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1); } ggml_graph_compute(ctx, gf); + //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank); + // send the output data to the next node if (mpi_rank > 0) { struct ggml_tensor * output = gf->nodes[gf->n_nodes - 1]; const int mpi_rank_dst = (mpi_rank + 1) % mpi_size; - const int retval = MPI_Send(output, ggml_nelements(output), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); + //fprintf(stderr, "%s: node %d: sending %d elements to node %d\n", __func__, mpi_rank, ggml_nelements(output), mpi_rank_dst); + + const int retval = MPI_Send(output->data, ggml_nelements(output), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); GGML_ASSERT(retval == MPI_SUCCESS); } } diff --git a/llama.cpp b/llama.cpp index fa8030c36229c..08a5bd2841500 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1287,6 +1287,10 @@ static bool llama_eval_internal( LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); +#ifdef GGML_USE_MPI + ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); +#endif + // enforce that the first token is BOS if (tokens && n_past == 0 && tokens[0] != llama_token_bos()) { fprintf(stderr, "%s: first token must be BOS\n", __func__); @@ -1331,10 +1335,6 @@ static bool llama_eval_internal( struct ggml_tensor * cur; struct ggml_tensor * inpL; -#ifdef GGML_USE_MPI - ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); -#endif - if (tokens) { struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); memcpy(embd->data, tokens, N*ggml_element_size(embd)); @@ -1636,7 +1636,7 @@ static bool llama_eval_internal( ggml_graph_compute(ctx0, &gf); } #elif GGML_USE_MPI - ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_embd, n_tokens); + ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer); #else ggml_graph_compute(ctx0, &gf); #endif From ef37dd14e744b6323f95196b00b88f29512de697 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 17:01:08 +0300 Subject: [PATCH 13/23] mpi : fix output tensor after MPI compute (still not working) --- ggml-mpi.c | 11 ++++++++--- llama.cpp | 2 ++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/ggml-mpi.c b/ggml-mpi.c index e890d24d1316f..46ee5bacb2bd5 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -57,7 +57,7 @@ void ggml_mpi_eval_init( MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); } -int ggml_graph_get_node_idx( struct ggml_cgraph * gf, const char * name) { +int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { struct ggml_tensor * t = ggml_graph_get_tensor(gf, name); if (t == NULL) { fprintf(stderr, "%s: tensor %s not found\n", __func__, name); @@ -141,8 +141,8 @@ void ggml_mpi_graph_compute( const int il0 = (mpi_idx + 0) * n_per_node; const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node); - char name_l0[64]; - char name_l1[64]; + char name_l0[GGML_MAX_NAME]; + char name_l1[GGML_MAX_NAME]; snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0); snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1); @@ -175,6 +175,11 @@ void ggml_mpi_graph_compute( ggml_graph_compute(ctx, gf); + //if (mpi_rank == 0) { + // ggml_graph_print(gf); + // ggml_graph_dump_dot(gf, NULL, "llama.dot"); + //} + //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank); // send the output data to the next node diff --git a/llama.cpp b/llama.cpp index 08a5bd2841500..4bf1e75d2f787 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1637,6 +1637,8 @@ static bool llama_eval_internal( } #elif GGML_USE_MPI ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer); + + cur = gf.nodes[gf.n_nodes - 1]; #else ggml_graph_compute(ctx0, &gf); #endif From beadbf33809f5dd7761de538a19d2ca3dcff9446 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 18:26:20 +0300 Subject: [PATCH 14/23] mpi : fix inference --- ggml-mpi.c | 66 +++++++++++++++++++++++++++++++++--------------------- llama.cpp | 10 ++++----- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/ggml-mpi.c b/ggml-mpi.c index 46ee5bacb2bd5..6dd7e7b76299b 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -74,6 +74,7 @@ int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { return -1; } +// TODO: there are many improvements that can be done to this implementation void ggml_mpi_graph_compute( struct ggml_mpi_context * ctx_mpi, struct ggml_context * ctx, @@ -82,18 +83,24 @@ void ggml_mpi_graph_compute( const int mpi_rank = ctx_mpi->rank; const int mpi_size = ctx_mpi->size; - struct ggml_tensor * embd = ggml_graph_get_tensor(gf, "layer_inp_0"); - if (embd == NULL) { - fprintf(stderr, "%s: tensor 'embd' not found\n", __func__); + struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens"); + if (inp_tokens == NULL) { + fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__); return; } - GGML_ASSERT(embd == gf->nodes[0]); + struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0"); + if (inp0 == NULL) { + fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__); + return; + } + + GGML_ASSERT(inp0 == gf->nodes[0]); // distribute the compute graph into slices across the MPI nodes // // the main node (0) processes the last layers + the remainder of the compute graph - // and is responsible to pass the input embeddings to the first node (1) + // and is responsible to pass the input tokens to the first node (1) // // node 1: [( 0) * n_per_node, ( 1) * n_per_node) // node 2: [( 1) * n_per_node, ( 2) * n_per_node) @@ -102,22 +109,28 @@ void ggml_mpi_graph_compute( // node 0: [(n-1) * n_per_node, n_nodes) // if (mpi_rank > 0) { - // recv input data for each node into the "embd" tensor (i.e. the first node in the compute graph) - { + if (mpi_rank == 1) { // the first node receives the input tokens from the main node MPI_Status status; UNUSED(status); const int mpi_rank_src = mpi_rank - 1; - //printf("%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(embd), mpi_rank_src); - const int retval = MPI_Recv(embd->data, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + const int retval = MPI_Recv(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + GGML_ASSERT(retval == MPI_SUCCESS); + } else { // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph) + MPI_Status status; UNUSED(status); + + const int mpi_rank_src = mpi_rank - 1; + + //printf("%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(inp0), mpi_rank_src); + const int retval = MPI_Recv(inp0->data, ggml_nelements(inp0), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); GGML_ASSERT(retval == MPI_SUCCESS); } - } else { - // node 0 sends the input data to node 1 + } else if (mpi_size > 1) { + // node 0 sends the input tokens to node 1 { const int mpi_rank_dst = mpi_rank + 1; - const int retval = MPI_Send(embd->data, ggml_nelements(embd), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); + const int retval = MPI_Send(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT, mpi_rank_dst, 0, MPI_COMM_WORLD); GGML_ASSERT(retval == MPI_SUCCESS); } @@ -127,8 +140,8 @@ void ggml_mpi_graph_compute( const int mpi_rank_src = mpi_size - 1; - //fprintf(stderr, "%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(embd), mpi_rank_src); - const int retval = MPI_Recv(embd->data, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + //fprintf(stderr, "%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(inp0), mpi_rank_src); + const int retval = MPI_Recv(inp0->data, ggml_nelements(inp0), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); GGML_ASSERT(retval == MPI_SUCCESS); } } @@ -148,7 +161,7 @@ void ggml_mpi_graph_compute( snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1); const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0); - const int idx_l1 = mpi_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) : gf->n_nodes; + const int idx_l1 = mpi_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) + 1 : gf->n_nodes; if (idx_l0 < 0 || idx_l1 < 0) { fprintf(stderr, "%s: layer input nodes not found\n", __func__); @@ -156,16 +169,24 @@ void ggml_mpi_graph_compute( } // attach the input data to the first layer for this node - gf->nodes[idx_l0 + 1]->src0 = gf->nodes[1]->src0; - gf->nodes[idx_l0 + 1]->src1 = gf->nodes[1]->src1; - - memcpy(gf->nodes[idx_l0 + 1]->opt, gf->nodes[1]->opt, sizeof(gf->nodes[idx_l0 + 1]->opt)); + for (int i = idx_l0; i < idx_l1; i++) { + if (gf->nodes[i]->src0 == gf->nodes[idx_l0]) { + gf->nodes[i]->src0 = inp0; + } + if (gf->nodes[i]->src1 == gf->nodes[idx_l0]) { + gf->nodes[i]->src1 = inp0; + } + } + // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph for (int i = 1; i < idx_l1 - idx_l0; i++) { gf->nodes[i] = gf->nodes[idx_l0 + i]; gf->grads[i] = gf->grads[idx_l0 + i]; + } - //fprintf(stderr, "%s: node %d: %d -> %d\n", __func__, mpi_rank, idx_l0 + i, i); + // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node + if (mpi_idx != 0) { + gf->nodes[0]->op = GGML_OP_NONE; } gf->n_nodes = idx_l1 - idx_l0; @@ -175,11 +196,6 @@ void ggml_mpi_graph_compute( ggml_graph_compute(ctx, gf); - //if (mpi_rank == 0) { - // ggml_graph_print(gf); - // ggml_graph_dump_dot(gf, NULL, "llama.dot"); - //} - //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank); // send the output data to the next node diff --git a/llama.cpp b/llama.cpp index 4bf1e75d2f787..b7aad4c6eebcd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1336,16 +1336,16 @@ static bool llama_eval_internal( struct ggml_tensor * inpL; if (tokens) { - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(embd->data, tokens, N*ggml_element_size(embd)); - inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); } else { inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); } - ggml_set_name(inpL, "embd"); - const int i_gpu_start = n_layer - n_gpu_layers; (void) i_gpu_start; From 9da9d26c70a1cf7793cfff0b2c6d03eeff1eaa36 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 18:38:32 +0300 Subject: [PATCH 15/23] mpi : minor --- ggml-mpi.c | 4 ++-- llama.cpp | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ggml-mpi.c b/ggml-mpi.c index 6dd7e7b76299b..4bde418089f8c 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -6,7 +6,6 @@ #include #include -#include #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -168,7 +167,8 @@ void ggml_mpi_graph_compute( return; } - // attach the input data to the first layer for this node + // attach the input data to all nodes that need it + // TODO: not great - should be able to do this without modifying the compute graph (see next TODO below) for (int i = idx_l0; i < idx_l1; i++) { if (gf->nodes[i]->src0 == gf->nodes[idx_l0]) { gf->nodes[i]->src0 = inp0; diff --git a/llama.cpp b/llama.cpp index b7aad4c6eebcd..8c2d0ea4b986a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1342,6 +1342,10 @@ static bool llama_eval_internal( inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); } From 4a9a4748e93abd23cc070735c93fdc7462982273 Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Sun, 9 Jul 2023 15:05:58 -0400 Subject: [PATCH 16/23] Add OpenMPI to GH action --- .github/workflows/build.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d62e93805b0f0..0494fdc9c8eef 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -107,6 +107,10 @@ jobs: ubuntu-latest-cmake-mpi: runs-on: ubuntu-latest + strategy: + matrix: + mpi_library: [mpich, libopenmpi-dev] + steps: - name: Clone id: checkout @@ -116,7 +120,7 @@ jobs: id: depends run: | sudo apt-get update - sudo apt-get install build-essential mpich + sudo apt-get install build-essential ${{ matrix.mpi_library }} - name: Build id: cmake_build From 03cc12be0d08f708df81d52583205628c5f99984 Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Sun, 9 Jul 2023 15:10:43 -0400 Subject: [PATCH 17/23] [mpi] continue-on-error: true --- .github/workflows/build.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0494fdc9c8eef..b6e21b4ec77ca 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -107,6 +107,8 @@ jobs: ubuntu-latest-cmake-mpi: runs-on: ubuntu-latest + continue-on-error: true + strategy: matrix: mpi_library: [mpich, libopenmpi-dev] From 166db36c51e23b984c7209049244d6e9d9c6a6e2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 22:23:04 +0300 Subject: [PATCH 18/23] mpi : fix after master merge --- ggml-mpi.c | 5 +++-- ggml-mpi.h | 3 ++- llama.cpp | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml-mpi.c b/ggml-mpi.c index 4bde418089f8c..70639e0788e39 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -78,7 +78,8 @@ void ggml_mpi_graph_compute( struct ggml_mpi_context * ctx_mpi, struct ggml_context * ctx, struct ggml_cgraph * gf, - int n_layers) { + int n_layers, + int n_threads) { const int mpi_rank = ctx_mpi->rank; const int mpi_size = ctx_mpi->size; @@ -194,7 +195,7 @@ void ggml_mpi_graph_compute( //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1); } - ggml_graph_compute(ctx, gf); + ggml_graph_compute_with_ctx(ctx, gf, n_threads); //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank); diff --git a/ggml-mpi.h b/ggml-mpi.h index 02e125cfb624b..2ad0a43864691 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -28,7 +28,8 @@ void ggml_mpi_graph_compute( struct ggml_mpi_context * ctx_mpi, struct ggml_context * ctx, struct ggml_cgraph * gf, - int n_layers); + int n_layers, + int n_threads); #ifdef __cplusplus } diff --git a/llama.cpp b/llama.cpp index 325db7d56b4c5..322e37a7de72a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1657,7 +1657,7 @@ static bool llama_eval_internal( ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); } #elif GGML_USE_MPI - ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer); + ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_threads); cur = gf.nodes[gf.n_nodes - 1]; #else From f085a57d1a29544faf358de9d76141fb6eabf978 Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Sun, 9 Jul 2023 15:31:53 -0400 Subject: [PATCH 19/23] [mpi] Link MPI C++ libraries to fix OpenMPI --- CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d992c394e1bc..cf6cd34f18ec1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -321,6 +321,11 @@ if (LLAMA_MPI) set(c_flags ${c_flags} -Wno-cast-qual) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) + # Even if you're only using the C header, C++ programs may bring in MPI + # C++ functions, so more linkage is needed + if (MPI_CXX_FOUND) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_CXX_LIBRARIES}) + endif() else() message(WARNING "MPI not found") endif() From 00b8aa1e66472806fa2bc5a342e6d1c05d817065 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 22:31:54 +0300 Subject: [PATCH 20/23] tests : fix new llama_backend API --- tests/test-tokenizer-0.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 1d8759a27ebb9..87fde16453d25 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -31,7 +31,7 @@ int main(int argc, char **argv) { llama_model * model; llama_context * ctx; - llama_init_backend(false); + llama_backend_init(false); // load the vocab { @@ -99,7 +99,7 @@ int main(int argc, char **argv) { llama_free_model(model); llama_free(ctx); - llama_finalize_backend(); + llama_backend_free(); return 0; } From ada1a2aa8bee9edff58d7cb7634ffe5a4bcc6cad Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Sun, 9 Jul 2023 15:37:33 -0400 Subject: [PATCH 21/23] [mpi] use MPI_INT32_T --- ggml-mpi.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-mpi.c b/ggml-mpi.c index 70639e0788e39..f1bbcabfea510 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -114,7 +114,7 @@ void ggml_mpi_graph_compute( const int mpi_rank_src = mpi_rank - 1; - const int retval = MPI_Recv(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + const int retval = MPI_Recv(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT32_T, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); GGML_ASSERT(retval == MPI_SUCCESS); } else { // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph) MPI_Status status; UNUSED(status); @@ -130,7 +130,7 @@ void ggml_mpi_graph_compute( { const int mpi_rank_dst = mpi_rank + 1; - const int retval = MPI_Send(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT, mpi_rank_dst, 0, MPI_COMM_WORLD); + const int retval = MPI_Send(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT32_T, mpi_rank_dst, 0, MPI_COMM_WORLD); GGML_ASSERT(retval == MPI_SUCCESS); } From c3c3ef11a60c9e5da51a4a9049e644aa10f39034 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 10 Jul 2023 18:35:38 +0300 Subject: [PATCH 22/23] mpi : factor out recv / send in functions and reuse --- ggml-mpi.c | 77 +++++++++++++++++++++++++++--------------------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/ggml-mpi.c b/ggml-mpi.c index f1bbcabfea510..6282b7276fb4d 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -56,7 +56,7 @@ void ggml_mpi_eval_init( MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); } -int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { +static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { struct ggml_tensor * t = ggml_graph_get_tensor(gf, name); if (t == NULL) { fprintf(stderr, "%s: tensor %s not found\n", __func__, name); @@ -73,6 +73,34 @@ int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { return -1; } +static void ggml_mpi_tensor_send(struct ggml_tensor * t, int mpi_rank_dst) { + MPI_Datatype mpi_type; + + switch (t->type) { + case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break; + case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break; + default: GGML_ASSERT(false && "not implemented"); + } + + const int retval = MPI_Send(t->data, ggml_nelements(t), mpi_type, mpi_rank_dst, 0, MPI_COMM_WORLD); + GGML_ASSERT(retval == MPI_SUCCESS); +} + +static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) { + MPI_Datatype mpi_type; + + switch (t->type) { + case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break; + case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break; + default: GGML_ASSERT(false && "not implemented"); + } + + MPI_Status status; UNUSED(status); + + const int retval = MPI_Recv(t->data, ggml_nelements(t), mpi_type, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + GGML_ASSERT(retval == MPI_SUCCESS); +} + // TODO: there are many improvements that can be done to this implementation void ggml_mpi_graph_compute( struct ggml_mpi_context * ctx_mpi, @@ -109,41 +137,19 @@ void ggml_mpi_graph_compute( // node 0: [(n-1) * n_per_node, n_nodes) // if (mpi_rank > 0) { - if (mpi_rank == 1) { // the first node receives the input tokens from the main node - MPI_Status status; UNUSED(status); - - const int mpi_rank_src = mpi_rank - 1; - - const int retval = MPI_Recv(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT32_T, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - GGML_ASSERT(retval == MPI_SUCCESS); - } else { // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph) - MPI_Status status; UNUSED(status); - - const int mpi_rank_src = mpi_rank - 1; - - //printf("%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(inp0), mpi_rank_src); - const int retval = MPI_Recv(inp0->data, ggml_nelements(inp0), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - GGML_ASSERT(retval == MPI_SUCCESS); + if (mpi_rank == 1) { + // the first node (1) receives the input tokens from the main node (0) + ggml_mpi_tensor_recv(inp_tokens, 0); + } else { + // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph) + ggml_mpi_tensor_recv(inp0, mpi_rank - 1); } } else if (mpi_size > 1) { // node 0 sends the input tokens to node 1 - { - const int mpi_rank_dst = mpi_rank + 1; - - const int retval = MPI_Send(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT32_T, mpi_rank_dst, 0, MPI_COMM_WORLD); - GGML_ASSERT(retval == MPI_SUCCESS); - } + ggml_mpi_tensor_send(inp_tokens, 1); // recv the output data from the last node - { - MPI_Status status; UNUSED(status); - - const int mpi_rank_src = mpi_size - 1; - - //fprintf(stderr, "%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(inp0), mpi_rank_src); - const int retval = MPI_Recv(inp0->data, ggml_nelements(inp0), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - GGML_ASSERT(retval == MPI_SUCCESS); - } + ggml_mpi_tensor_recv(inp0, mpi_size - 1); } { @@ -201,13 +207,6 @@ void ggml_mpi_graph_compute( // send the output data to the next node if (mpi_rank > 0) { - struct ggml_tensor * output = gf->nodes[gf->n_nodes - 1]; - - const int mpi_rank_dst = (mpi_rank + 1) % mpi_size; - - //fprintf(stderr, "%s: node %d: sending %d elements to node %d\n", __func__, mpi_rank, ggml_nelements(output), mpi_rank_dst); - - const int retval = MPI_Send(output->data, ggml_nelements(output), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); - GGML_ASSERT(retval == MPI_SUCCESS); + ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size); } } From eaef2d0e76d8a64b80154bc0c253408a659928e3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 10 Jul 2023 18:47:24 +0300 Subject: [PATCH 23/23] mpi : extend API to allow usage with outer backends (e.g. Metal) --- .gitignore | 1 + ggml-metal.m | 1 + ggml-mpi.c | 16 ++++++++------ ggml-mpi.h | 11 ++++++---- llama.cpp | 59 +++++++++++++++++++++++++--------------------------- 5 files changed, 47 insertions(+), 41 deletions(-) diff --git a/.gitignore b/.gitignore index 4fccec31b8114..faec869e040b2 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ build-static/ build-cublas/ build-opencl/ build-metal/ +build-mpi/ build-no-accel/ build-sanitize-addr/ build-sanitize-thread/ diff --git a/ggml-metal.m b/ggml-metal.m index 3f15f791f9f65..6473644c24204 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -450,6 +450,7 @@ void ggml_metal_graph_compute( //} switch (dst->op) { + case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_TRANSPOSE: diff --git a/ggml-mpi.c b/ggml-mpi.c index 6282b7276fb4d..872e808de7700 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -102,12 +102,10 @@ static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) { } // TODO: there are many improvements that can be done to this implementation -void ggml_mpi_graph_compute( +void ggml_mpi_graph_compute_pre( struct ggml_mpi_context * ctx_mpi, - struct ggml_context * ctx, struct ggml_cgraph * gf, - int n_layers, - int n_threads) { + int n_layers) { const int mpi_rank = ctx_mpi->rank; const int mpi_size = ctx_mpi->size; @@ -200,10 +198,16 @@ void ggml_mpi_graph_compute( //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1); } +} - ggml_graph_compute_with_ctx(ctx, gf, n_threads); +void ggml_mpi_graph_compute_post( + struct ggml_mpi_context * ctx_mpi, + struct ggml_cgraph * gf, + int n_layers) { + UNUSED(n_layers); - //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank); + const int mpi_rank = ctx_mpi->rank; + const int mpi_size = ctx_mpi->size; // send the output data to the next node if (mpi_rank > 0) { diff --git a/ggml-mpi.h b/ggml-mpi.h index 2ad0a43864691..eda119d449849 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -24,12 +24,15 @@ void ggml_mpi_eval_init( int * n_past, int * n_threads); -void ggml_mpi_graph_compute( +void ggml_mpi_graph_compute_pre( struct ggml_mpi_context * ctx_mpi, - struct ggml_context * ctx, struct ggml_cgraph * gf, - int n_layers, - int n_threads); + int n_layers); + +void ggml_mpi_graph_compute_post( + struct ggml_mpi_context * ctx_mpi, + struct ggml_cgraph * gf, + int n_layers); #ifdef __cplusplus } diff --git a/llama.cpp b/llama.cpp index 322e37a7de72a..ad7283faf1f1a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1632,6 +1632,10 @@ static bool llama_eval_internal( // run the computation ggml_build_forward_expand(&gf, cur); +#if GGML_USE_MPI + ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer); +#endif + #ifdef GGML_USE_METAL if (lctx.ctx_metal && N == 1) { ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); @@ -1656,14 +1660,19 @@ static bool llama_eval_internal( ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); } -#elif GGML_USE_MPI - ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_threads); - - cur = gf.nodes[gf.n_nodes - 1]; #else ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); #endif +#if GGML_USE_MPI + ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer); +#endif + + // update kv token count + lctx.kv_self.n = n_past + N; + + struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1]; + if (cgraph_fname) { ggml_graph_export(&gf, cgraph_fname); } @@ -1679,38 +1688,26 @@ static bool llama_eval_internal( // ggml_graph_dump_dot(&gf, NULL, "llama.dot"); //} - //embd_w.resize(n_vocab*N); - //memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N); - - // update kv token count - lctx.kv_self.n = n_past + N; - -#ifdef GGML_USE_MPI - if (ggml_mpi_rank(lctx.ctx_mpi) == 0) { -#else + // extract logits { -#endif - // extract logits - { - auto & logits_out = lctx.logits; + auto & logits_out = lctx.logits; - if (lctx.logits_all) { - logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N); - } else { - // return result for just the last token - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab); - } + if (lctx.logits_all) { + logits_out.resize(n_vocab * N); + memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N); + } else { + // return result for just the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab); } + } - // extract embeddings - if (!lctx.embedding.empty()) { - auto & embedding_out = lctx.embedding; + // extract embeddings + if (!lctx.embedding.empty()) { + auto & embedding_out = lctx.embedding; - embedding_out.resize(n_embd); - memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); - } + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); } if (mem_per_token == 0) {