Distributed inference via MPI #2099

Merged (29 commits), Jul 10, 2023
Conversation

evanmiller (Collaborator) commented Jul 4, 2023

Model inference is currently limited by the memory on a single node. Using MPI, we can distribute models across a locally networked cluster of machines.

This PR uses a ring pipeline architecture so that the process at rank (index) 0 handles both input and output. The layers are grouped into slices, and each MPI slot (process) handles a slice. The communication during each token prediction looks like this:

Rank 0 -> Rank 1 -> ... -> Rank N-1 -> Rank 0

Running MPI locally with N=8, you can see the 13B model distributed across 8 processes; each process takes up less than a gigabyte of system memory.

[screenshot: the 13B model split across 8 local MPI processes, each using under 1 GB of memory]

Note that this doesn't speed anything up as the processes cannot execute concurrently, but these processes can be distributed to multiple machines to take advantage of more machine RAM. No special code was required to read a subset of weights; selective weight-loading is just a consequence of mmap.

See notes added to the README to try the distributed code for yourself.
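To make the layer slicing concrete, here is a minimal sketch of one way the split could be computed per rank. It is illustrative only; the even, contiguous split and the helper name are assumptions, not necessarily what this PR does.

    // Illustrative only: split n_layer transformer layers into contiguous slices,
    // one slice per MPI rank. The exact splitting rule here is an assumption.
    static void slice_for_rank(int n_layer, int mpi_rank, int mpi_size,
                               int * layer_begin, int * layer_end) {
        const int n_per_rank = (n_layer + mpi_size - 1) / mpi_size; // ceiling division
        *layer_begin = mpi_rank * n_per_rank;
        *layer_end   = *layer_begin + n_per_rank;
        if (*layer_end > n_layer) {
            *layer_end = n_layer;
        }
    }

With the 40 layers of the 13B model and 8 ranks, this works out to 5 layers per process, which is consistent with the sub-gigabyte per-process memory shown above.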

Technical changes

The set of changes is somewhat minimal; the additions are:

  • New LLAMA_MPI compile-time option
  • New ggml_mpi_send_tensor and ggml_mpi_recv_tensor functions, possibly to be added to GGML later
  • New llama_finalize_backend() API function (calls MPI_Finalize())
  • New mpi_rank and mpi_size fields in the llama_context object

To take advantage of MPI, binary CLI programs usually need no source code changes except to call llama_finalize_backend(). This is something of a hack – I have modified llama_new_context_with_model to enter an evaluation loop on non-primary processes. This loop blocks at MPI_Barrier, waiting for the driving (rank 0) program to kick off each evaluation. I'm open to other suggestions, but this strategy let me run the example programs more or less out of the box.
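A rough sketch of that idea (not the PR's literal code; the shutdown check and the per-rank evaluation helper are hypothetical placeholders):

    // Sketch only: non-primary ranks never return control to the example program.
    // They block at the barrier until the rank 0 driver starts an evaluation; the
    // shutdown signalling is left abstract here (backend_should_stop is hypothetical).
    if (mpi_rank != 0) {
        for (;;) {
            MPI_Barrier(MPI_COMM_WORLD);       // wait for the rank 0 driver
            if (backend_should_stop(ctx)) {    // hypothetical shutdown check
                break;
            }
            eval_local_layers(ctx);            // hypothetical: this rank's share of llama_eval
        }
        MPI_Finalize();
        exit(0);                               // workers never reach the example's own main loop
    }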

The changes to the core token prediction algorithm involve sending or receiving tensors before and after the layer loop. Each process only handles a subset of layers. If the process does not handle the first layer, it receives the input tensor from the preceding process. To close the communication ring, the driving (first) process will receive the layer output from the last process, and use that output tensor to compute logits and embeddings. This ensures that all user I/O occurs within a single process.
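In pseudo-C, the shape of that change is roughly as follows; the send/recv helpers and variable names are placeholders rather than the PR's actual functions:

    // Sketch of the per-token flow on each rank (helpers are placeholders).
    if (layer_begin > 0) {
        recv_tensor_from(inpL, mpi_rank - 1);  // activations from the previous rank
    }

    for (int il = layer_begin; il < layer_end; ++il) {
        inpL = eval_layer(ctx0, inpL, il);     // only this rank's slice of the layer loop
    }

    if (layer_end < n_layer) {
        send_tensor_to(inpL, mpi_rank + 1);    // pass activations to the next rank
    } else if (mpi_rank != 0) {
        send_tensor_to(inpL, 0);               // last rank closes the ring back to rank 0
    }

    if (mpi_rank == 0 && mpi_size > 1) {
        recv_tensor_from(inpL, mpi_size - 1);  // rank 0 receives the final layer output...
    }
    // ...and only rank 0 goes on to compute logits and embeddings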

I was able to test the cluster code locally on an iMac connected to a (very slow) 12" MacBook over WiFi. It didn't win any speed awards, but it did generate plausible text, so I am confident in the overall algorithm correctness. However, there are likely bugs / oversights when it comes to handling MPI communication errors and shutdown.

Leaving this as a draft, as I presume the GGML changes should be finalized and merged before the llama.cpp changes.

See previous discussion in #946

ggerganov (Owner) commented Jul 6, 2023

That's actually neat! I'm surprised that the change is so small.

I saw the discussion in ggerganov/ggml#340

Let's try the following changes and see if we can make the implementation more decoupled from ggml and llama.cpp. Here is the plan:

  • Try to avoid introducing a tag field into ggml_tensor. Instead, utilize the extra pointer to reference a custom MPI-related struct with that information. See how we do this in the CUDA backend as an example:

    llama.cpp/ggml-cuda.cu, lines 236-239 (commit f789f2c):

    struct ggml_tensor_extra_gpu {
        void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
        cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
    };
  • Ideally, we want ggml.c to not include mpi.h. To achieve that, let's try to define ggml_send_tensor() and ggml_recv_tensor() as custom operators (see ggml_map_custom1_f32()) in the user code - i.e. llama.cpp (see the sketch after this list). Currently we don't have an example of how to use these custom ops, but I hope it is clear from the API. Let me know if not.
    If this change works out, then you will be able to avoid all modifications in ggml and isolate the MPI dependency just in llama.cpp, so the companion PR won't be necessary anymore
  • The next step would be to try to move as much code as possible into a new backend: ggml-mpi.h / ggml-mpi.cpp. This backend will provide a very simple API - mainly ggml_mpi_graph_compute() + helper init/free functions. See ggml-metal.h for reference. The more code you manage to move from llama.cpp into ggml-mpi.h/.cpp, the better. I think you will also be able to move the definitions of the custom send / recv operators from the previous point into the ggml-mpi backend. Maybe mpi_rank and mpi_size can go into a new struct ggml_mpi_context, similar to struct ggml_metal_context.
    All this will help decouple the implementation from the MPI specifics and make it easier to extend in the future.
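A sketch of how the first two points could fit together follows. It assumes the fun(dst, src) callback shape of ggml_map_custom1_f32; the struct and operator names are illustrative, not final code.

    // Illustrative sketch: carry MPI routing info in tensor->extra and perform the
    // send inside a custom operator, so ggml.c itself never needs to include mpi.h.
    struct ggml_mpi_tensor_info {
        int rank; // which MPI rank this tensor should be sent to / received from
    };

    // matches the ggml_custom1_op_f32_t shape: fun(dst, src)
    static void ggml_mpi_send_tensor_op(struct ggml_tensor * dst, const struct ggml_tensor * src) {
        const struct ggml_mpi_tensor_info * info = (const struct ggml_mpi_tensor_info *) src->extra;
        // assuming an F32 activation tensor
        MPI_Send(src->data, (int) ggml_nelements(src), MPI_FLOAT, info->rank, 0, MPI_COMM_WORLD);
        (void) dst; // the op exists only for its communication side effect
    }

    // in llama.cpp graph construction (no changes to ggml.c needed):
    //   cur->extra = info;                                      // info owned by the llama/MPI code
    //   cur = ggml_map_custom1_f32(ctx0, cur, ggml_mpi_send_tensor_op);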

In short, you can follow the GGML_USE_METAL ifdefs in llama.cpp. The goal is to have a very similar pattern for GGML_USE_MPI with as few modifications in llama.cpp as possible. I am not 100% sure that this refactoring plan will completely work out, so if you reach a blocker - let us know.

ggerganov added the "high priority" (Very important issue) label on Jul 6, 2023
evanmiller (Collaborator, Author)

  • Try to avoid introducing a tag field into ggml_tensor. Instead, utilize the extra pointer to reference a custom MPI-related struct with that information. See how we do this in the CUDA backend as an example:

    llama.cpp/ggml-cuda.cu, lines 236-239 (commit f789f2c):

    struct ggml_tensor_extra_gpu {
        void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
        cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
    };

Done; but how/when should this custom struct be freed? I could free and NULL it immediately after using the information, but this feels wrong somehow. (It's also only 4 bytes, so we could just not worry about the micro-leak.)

  • Ideally, we want ggml.c to not include mpi.h. To achieve that, let's try to define ggml_send_tensor() and ggml_recv_tensor() as custom operators (see ggml_map_custom1_f32()) in the user code - i.e. llama.cpp. Currently we don't have an example of how to use these custom ops, but I hope it is clear from the API. Let me know if not.

Done

If this change works out, then you will be able to avoid all modifications in ggml and isolate the MPI dependency just in llama.cpp, so the companion PR won't be necessary anymore

Yep, with your suggested approach, the core ggml.c changes are no longer necessary. Will close that PR.

  • The next step would be to try to move as much code as possible into a new backend: ggml-mpi.h / ggml-mpi.cpp. This backend will provide a very simple API - mainly ggml_mpi_graph_compute() + helper init/free functions. See ggml-metal.h for reference. The more code you manage to move from llama.cpp into ggml-mpi.h/.cpp, the better. I think you will also be able to move the definitions of the custom send / recv operators from the previous point into the ggml-mpi backend. Maybe mpi_rank and mpi_size can go into a new struct ggml_mpi_context, similar to struct ggml_metal_context.

I've moved code into ggml-mpi.c/h in this PR. If this is what you were thinking, I can open a new PR over in GGML.
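For reference, a header along those lines could look roughly like the following. This is only a sketch of the shape implied by the discussion above; the actual function set and signatures in the PR's ggml-mpi.h may differ.

    // Sketch of a minimal ggml-mpi.h in the spirit described above (not the actual header).
    #pragma once

    struct ggml_context;
    struct ggml_cgraph;

    #ifdef __cplusplus
    extern "C" {
    #endif

    struct ggml_mpi_context;

    struct ggml_mpi_context * ggml_mpi_init(void);
    void                      ggml_mpi_free(struct ggml_mpi_context * ctx_mpi);

    int ggml_mpi_rank(struct ggml_mpi_context * ctx_mpi);
    int ggml_mpi_size(struct ggml_mpi_context * ctx_mpi);

    // compute the local slice of a graph, exchanging activations with neighbouring ranks
    void ggml_mpi_graph_compute(struct ggml_mpi_context * ctx_mpi,
                                struct ggml_context      * ctx,
                                struct ggml_cgraph       * gf,
                                int                        n_layers);

    #ifdef __cplusplus
    }
    #endif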

All this will help decouple the implementation from the MPI specifics and make it easier to extend in the future.

In short, you can follow the GGML_USE_METAL ifdefs in llama.cpp. The goal is to have a very similar pattern for GGML_USE_MPI with as few modifications in llama.cpp as possible. I am not 100% sure that this refactoring plan will completely work out, so if you reach a blocker - let us know.

Tested out the changes with a local MPI ring, and inference still seems to work. Will peel off the Draft label; please let me know if you'd like to see other changes.

evanmiller (Collaborator, Author)

On another note, this paper outlines the parallelization strategies used in Google's PaLM: https://arxiv.org/abs/2211.05102. Not sure if they're applicable to LLaMA, but this would be a good starting point for thinking beyond simple layer-based pipeline parallelism...

evanmiller marked this pull request as ready for review on July 7, 2023 at 00:41
ggerganov (Owner) commented Jul 9, 2023

@evanmiller

I tried to factor out all the MPI logic into the ggml-mpi backend: evanmiller#1

I want to test whether this works, but I don't know how to create the hostfile that you mention in the instructions.
Can you help with a short how-to for getting this to run on a local machine?

Edit: nvm, I just saw the full instructions that you have provided. Will give it a try now

evanmiller (Collaborator, Author)

@evanmiller

I tried to factor out all the MPI logic into the ggml-mpi backend: evanmiller#1

I want to test whether this works, but I don't know how to create the hostfile that you mention in the instructions. Can you help with a short how-to for getting this to run on a local machine?

Edit: nvm, I just saw the full instructions that you have provided. Will give it a try now

Great! Note that a hostfile is only needed for a networked cluster – just omit the argument to execute locally with shared-memory communication.

ggerganov (Owner)

@evanmiller

Thanks - it works now. Please take a look at the proposed changes.
The implementation is quite hacky atm, but I am OK with merging it since it is well decoupled from llama.cpp. Later we can improve it on the master branch.

Let me know what you think

evanmiller (Collaborator, Author)

I've looked over your branch; I agree it's a little hacky, but I was able to follow the logic. Overall it makes sense to me. It's great that this will work out of the box with many other models!

My only tentative feedback would be to replace MPI_INT with MPI_INT32_T for the token tensor. However, the latter constant does not appear to be supported by OpenMPI, so you'd need an additional #ifdef to make it all work. Let me add a block to the GitHub workflow so we have both OpenMPI and MPICH coverage.
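Something along these lines would do it; this is a sketch, and LLAMA_MPI_TOKEN_DTYPE is a made-up macro name (tokens are 32-bit integers, and both datatype constants are normally preprocessor macros, so #ifdef can select between them):

    // Sketch of the suggested fallback: prefer the fixed-width datatype when available.
    #ifdef MPI_INT32_T
        #define LLAMA_MPI_TOKEN_DTYPE MPI_INT32_T
    #else
        #define LLAMA_MPI_TOKEN_DTYPE MPI_INT
    #endif
    // e.g. MPI_Bcast(tokens, n_tokens, LLAMA_MPI_TOKEN_DTYPE, 0, MPI_COMM_WORLD);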

evanmiller (Collaborator, Author)

The linker is unhappy with OpenMPI on GitHub CI:

[ 25%] Building CXX object tests/CMakeFiles/test-sampling.dir/test-sampling.cpp.o
[ 27%] Linking CXX executable ../bin/test-sampling
/usr/bin/ld: ../libllama.a(llama.cpp.o): in function `MPI::Op::Init(void (*)(void const*, void*, int, MPI::Datatype const&), bool)':
llama.cpp:(.text._ZN3MPI2Op4InitEPFvPKvPviRKNS_8DatatypeEEb[_ZN3MPI2Op4InitEPFvPKvPviRKNS_8DatatypeEEb]+0x1d): undefined reference to `ompi_mpi_cxx_op_intercept'
/usr/bin/ld: ../libllama.a(llama.cpp.o): in function `MPI::Intracomm::Clone() const':
llama.cpp:(.text._ZNK3MPI9Intracomm5CloneEv[_ZNK3MPI9Intracomm5CloneEv]+0x40): undefined reference to `MPI::Comm::Comm()'
/usr/bin/ld: ../libllama.a(llama.cpp.o): in function `MPI::Graphcomm::Clone() const':
llama.cpp:(.text._ZNK3MPI9Graphcomm5CloneEv[_ZNK3MPI9Graphcomm5CloneEv]+0x3a): undefined reference to `MPI::Comm::Comm()'
/usr/bin/ld: ../libllama.a(llama.cpp.o): in function `MPI::Cartcomm::Sub(bool const*) const':
llama.cpp:(.text._ZNK3MPI8Cartcomm3SubEPKb[_ZNK3MPI8Cartcomm3SubEPKb]+0x96): undefined reference to `MPI::Comm::Comm()'
/usr/bin/ld: ../libllama.a(llama.cpp.o): in function `MPI::Intracomm::Create_graph(int, int const*, int const*, bool) const':
llama.cpp:(.text._ZNK3MPI9Intracomm12Create_graphEiPKiS2_b[_ZNK3MPI9Intracomm12Create_graphEiPKiS2_b]+0x42): undefined reference to `MPI::Comm::Comm()'
/usr/bin/ld: ../libllama.a(llama.cpp.o): in function `MPI::Cartcomm::Clone() const':
llama.cpp:(.text._ZNK3MPI8Cartcomm5CloneEv[_ZNK3MPI8Cartcomm5CloneEv]+0x3a): undefined reference to `MPI::Comm::Comm()'
/usr/bin/ld: ../libllama.a(llama.cpp.o):llama.cpp:(.text._ZNK3MPI9Intracomm11Create_cartEiPKiPKbb[_ZNK3MPI9Intracomm11Create_cartEiPKiPKbb]+0xab): more undefined references to `MPI::Comm::Comm()' follow
/usr/bin/ld: ../libllama.a(llama.cpp.o):(.data.rel.ro._ZTVN3MPI8DatatypeE[_ZTVN3MPI8DatatypeE]+0x78): undefined reference to `MPI::Datatype::Free()'
/usr/bin/ld: ../libllama.a(llama.cpp.o):(.data.rel.ro._ZTVN3MPI3WinE[_ZTVN3MPI3WinE]+0x48): undefined reference to `MPI::Win::Free()'
collect2: error: ld returned 1 exit status
gmake[2]: *** [tests/CMakeFiles/test-sampling.dir/build.make:99: bin/test-sampling] Error 1
gmake[1]: *** [CMakeFiles/Makefile2:1299: tests/CMakeFiles/test-sampling.dir/all] Error 2
gmake: *** [Makefile:101: all] Error 2
Error: Process completed with exit code 2.

Guessing it just needs a variable added somewhere in the CMakeLists.
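For what it's worth, the undefined MPI:: symbols above come from OpenMPI's deprecated C++ bindings being pulled in when mpi.h is included from C++. Besides pointing CMake at the MPI C++ library, one common workaround (a sketch, not necessarily the fix that ends up in the CMakeLists or CI) is to skip those bindings entirely:

    // Define these before including mpi.h so the deprecated C++ bindings are skipped
    // and linking libmpi_cxx is no longer required.
    #define OMPI_SKIP_MPICXX  1   // OpenMPI
    #define MPICH_SKIP_MPICXX 1   // MPICH
    #include <mpi.h>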

ggerganov (Owner)

Yup, let's resolve the CI and MPI_INT issues and look to merge this.
I already have some neat ideas for how to improve the implementation - I'll write an issue with details.

I think we should try to utilize this approach to run a 65B LLaMA on Raspberry Pis.
The only question is whether the model can be mmap-ed from a shared network drive. If this is possible, then plugging in enough devices should eventually allow you to do the inference.

It would be a fun thing to try, and potentially achieve the world's first inference of a 65B model on a cluster of Raspberries 😄

evanmiller (Collaborator, Author)

According to a header comment, MPI_INT32_T was added in MPI standard v2.2, so we should be safe to use it. It builds locally with 4.1.4.

I will be busy the next few hours but feel free to tweak / merge / etc after reviewing CI.

I'm looking forward to seeing 65B models running on clusters of hacked home appliances!
