Skip to content

Commit

Permalink
feat: Distributed data parallel training support (#2464)
Browse files Browse the repository at this point in the history
* first experiment distributed

* feat: add DistributedUtils (MPI&NCCL working)

* feat: add DistributedUtils (MPI&NCCL working)

* fix: no need for amdgpu now

* chore: cleanup&propose how to use amdgpu

* chore: add preferences for CUDA-awareness

* feat: fix devices for CUDA-awareness

* chore: add tests

* chore: get rid of unnecessary deps

* chore: update NEWS.md

* chore: cleanup env

* chore: update docs

* chore: update docs & cleanup

* chore: update docs & cleanup

* Update docs/src/guide/gpu.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>

* Update docs/src/guide/gpu.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>

* Update docs/src/guide/gpu.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>

* Update docs/src/guide/gpu.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>

* Update docs/src/guide/gpu.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>

* Update docs/src/guide/gpu.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>

* Update docs/src/guide/gpu.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>

* Update docs/src/guide/gpu.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>

* Update docs/src/guide/gpu.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>

* Update docs/src/guide/gpu.md

* Update docs/src/guide/gpu.md

* chore: add PR review suggestions

* chore: fix docs

* fix: add runtests.jl

* chore: small docs update

* chore: remove pkgs from deps

---------

Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>
  • Loading branch information
3 people authored Aug 19, 2024
1 parent 033f4b2 commit d1ff714
Show file tree
Hide file tree
Showing 14 changed files with 1,042 additions and 0 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

## v0.14.18
* Add [support for distributed data parallel training](https://github.com/FluxML/Flux.jl/pull/2446).
* MPI and NCCL backend available with `FluxMPIExt` and `FluxMPINCCLExt` extensions respectively.

## v0.14.17
* Add [support for Enzyme](https://github.com/FluxML/Flux.jl/pull/2446) with `Flux.train!`.

Expand Down
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -26,14 +27,18 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
FluxAMDGPUExt = "AMDGPU"
FluxCUDAExt = "CUDA"
FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
FluxEnzymeExt = "Enzyme"
FluxMPIExt = "MPI"
FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
FluxMetalExt = "Metal"

[compat]
Expand All @@ -45,14 +50,17 @@ Compat = "4.10.0"
Enzyme = "0.12"
Functors = "0.4"
MLUtils = "0.4"
MPI = "0.20.19"
MacroTools = "0.5"
Metal = "0.5, 1"
NCCL = "0.1.1"
NNlib = "0.9.22"
OneHotArrays = "0.2.4"
Optimisers = "0.3.3"
Preferences = "1"
ProgressLogging = "0.1"
Reexport = "1.0"
Setfield = "1.1"
SpecialFunctions = "2.1.2"
Statistics = "1"
Zygote = "0.6.67"
Expand Down
117 changes: 117 additions & 0 deletions docs/src/guide/gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,120 @@ Flux.supported_devices
Flux.get_device
Flux.gpu_backend!
```

## Distributed data parallel training

!!! danger "Experimental"

Distributed support is experimental and could change in the future.


Flux supports now distributed data parallel training with `DistributedUtils` module.
If you want to run your code on multiple GPUs, you have to install `MPI.jl` (see [docs](https://juliaparallel.org/MPI.jl/stable/usage/) for more info).

```julia-repl
julia> using MPI
julia> MPI.install_mpiexecjl()
```

Now you can run your code with `mpiexecjl --project=. -n <np> julia <filename>.jl` from CLI.

You can use either the `MPIBackend` or `NCCLBackend`, the latter only if also `NCCL.jl` is loaded. First, initialize a backend with `DistributedUtils.initialize`, e.g.

```julia-repl
julia> using Flux, MPI, NCCL, CUDA
julia> CUDA.allowscalar(false)
julia> DistributedUtils.initialize(NCCLBackend)
julia> backend = DistributedUtils.get_distributed_backend(NCCLBackend)
NCCLBackend{Communicator, MPIBackend{MPI.Comm}}(Communicator(Ptr{NCCL.LibNCCL.ncclComm} @0x000000000607a660), MPIBackend{MPI.Comm}(MPI.Comm(1140850688)))
```

Pass your model, as well as any data to GPU device.
```julia-repl
julia> model = Chain(Dense(1 => 256, tanh), Dense(256 => 1)) |> gpu
Chain(
Dense(1 => 256, tanh), # 512 parameters
Dense(256 => 1), # 257 parameters
) # Total: 4 arrays, 769 parameters, 744 bytes.
julia> x = rand(Float32, 1, 16) |> gpu
1×16 CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}:
0.239324 0.331029 0.924996 0.55593 0.853093 0.874513 0.810269 0.935858 0.477176 0.564591 0.678907 0.729682 0.96809 0.115833 0.66191 0.75822
julia> y = x .^ 3
1×16 CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}:
0.0137076 0.0362744 0.791443 0.171815 0.620854 0.668804 0.53197 0.819654 0.108651 0.179971 0.312918 0.388508 0.907292 0.00155418 0.29 0.435899
```

In this case, we are training on a total of `16 * number of processes` samples. You can also use `DistributedUtils.DistributedDataContainer` to split the data uniformly across processes (or do it manually).

```julia-repl
julia> data = DistributedUtils.DistributedDataContainer(backend, x)
Flux.DistributedUtils.DistributedDataContainer(Float32[0.23932439 0.33102947 … 0.66191036 0.75822026], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
```

You have to wrap your model in `DistributedUtils.FluxDistributedModel` and synchronize it (broadcast accross all processes):
```julia-repl
julia> model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0)
Chain(
Dense(1 => 256, tanh), # 512 parameters
Dense(256 => 1), # 257 parameters
) # Total: 4 arrays, 769 parameters, 744 bytes.
```

Time to set up an optimizer by using `DistributedUtils.DistributedOptimizer` and synchronize it as well.
```julia-repl
julia> using Optimisers
julia> opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(0.001f0))
DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8))
julia> st_opt = Optimisers.setup(opt, model)
(layers = ((weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0], Float32[0.0], (0.9, 0.999))), σ = ())),)
julia> st_opt = DistributedUtils.synchronize!!(backend, st_opt; root=0)
(layers = ((weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0], Float32[0.0], (0.9, 0.999))), σ = ())),)
```

Now you can define loss and train the model.
```julia-repl
julia> loss(model) = mean((model(x) .- y).^2)
loss (generic function with 1 method)
julia> for epoch in 1:100
global model, st_opt
l, grad = Zygote.withgradient(loss, model)
println("Epoch $epoch: Loss $l")
st_opt, model = Optimisers.update(st_opt, model, grad[1])
end
Epoch 1: Loss 0.011638729
Epoch 2: Loss 0.0116432225
Epoch 3: Loss 0.012763695
...
```

Remember that in order to run it on multiple GPUs you have to run from CLI `mpiexecjl --project=. -n <np> julia <filename>.jl`,
where `<np>` is the number of processes that you want to use. The number of processes usually corresponds to the number of gpus.

By default `MPI.jl` MPI installation is CUDA-unaware so if you want to run it in CUDA-aware mode, read more [here](https://juliaparallel.org/MPI.jl/stable/usage/#CUDA-aware-MPI-support) on custom installation and rebuilding `MPI.jl`.
Then test if your MPI is CUDA-aware by
```julia-repl
julia> import Pkg
julia> Pkg.test("MPI"; test_args=["--backend=CUDA"])
```

If it is, set your local preference as below
```julia-repl
julia> using Preferences
julia> set_preferences!("Flux", "FluxDistributedMPICUDAAware" => true)
```

!!! warning "Known shortcomings"

We don't run CUDA-aware tests so you're running it at own risk.

183 changes: 183 additions & 0 deletions ext/FluxMPIExt/FluxMPIExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
module FluxMPIExt

using CUDA
using Flux: MPIBackend, NCCLBackend, DistributedUtils,
AbstractDevice, FluxCUDADevice, FluxAMDGPUDevice, cpu, gpu,
get_device, MPI_CUDA_AWARE, MPI_ROCM_AWARE
using MPI: MPI

if Base.find_package("AMDGPU") !== nothing
using AMDGPU
end


function DistributedUtils.__initialize(
::Type{MPIBackend}; cuda_devices=nothing, amdgpu_devices=nothing,
force_cuda::Bool=false, caller::String="", force_amdgpu::Bool=false) # Undocumented internal kwarg
!MPI.Initialized() && MPI.Init()
DistributedUtils.MPI_Initialized[] = true

local_rank = MPI.Comm_rank(MPI.COMM_WORLD)

if cuda_devices !== missing && CUDA.functional()
if cuda_devices === nothing
CUDA.device!((local_rank + 1) % length(CUDA.devices()))
else
CUDA.device!(cuda_devices[local_rank + 1])
end
elseif force_cuda
error(lazy"CUDA devices are not functional and `force_cuda` is set to `true`. This is caused by backend: $(caller).")
end

if Base.find_package("AMDGPU") !== nothing
if amdgpu_devices !== missing && AMDGPU.functional()
if amdgpu_devices === nothing
AMDGPU.device!((local_rank + 1) % length(AMDGPU.devices()))
else
AMDGPU.device!(amdgpu_devices[local_rank + 1])
end
elseif force_amdgpu
error(lazy"AMDGPU devices are not functional (or `LuxAMDGPU.jl` not loaded) and `force_amdgpu` is set to `true`. This is caused by backend: $(caller).")
end
end

return
end

DistributedUtils.__get_distributed_backend(::Type{MPIBackend}) = MPIBackend(MPI.COMM_WORLD)

DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm)

DistributedUtils.total_workers(backend::MPIBackend) = MPI.Comm_size(backend.comm)

# Broadcast
# Union with Function is because of Flux.cpu istypeof Function
# We need CPU in case of non CUDA-aware implementation
function DistributedUtils.__bcast!(
backend::MPIBackend, sendrecvbuf, dev::Union{AbstractDevice, Function}; root=0)
MPI.Bcast!(sendrecvbuf, backend.comm; root)
return sendrecvbuf
end

function DistributedUtils.__bcast!(
backend::MPIBackend, sendbuf, recvbuf, dev::Union{AbstractDevice, Function}; root=0)
return DistributedUtils.__bcast!(
backend, ifelse(DistributedUtils.local_rank(backend) == root, sendbuf, recvbuf),
dev; root)
end

# if MPI implementation is not CUDA-aware
# we have to move data to CPU first
for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
if !aware
@eval begin
function DistributedUtils.__bcast!(
backend::MPIBackend, sendrecvbuf, dev::$dType; root=0)
sendrecvbuf_ = sendrecvbuf |> cpu
DistributedUtils.__bcast!(backend, sendrecvbuf_, cpu; root)
sendrecvbuf |> gpu
return sendrecvbuf
end

function DistributedUtils.__bcast!(
backend::MPIBackend, sendbuf, recvbuf, dev::$dType; root=0)
sendbuf_ = sendbuf |> cpu
recvbuf_ = recvbuf |> cpu
DistributedUtils.__bcast!(backend, sendbuf_, recvbuf_, cpu; root)
recvbuf |> gpu
return recvbuf
end
end
end
end


# Allreduce
function DistributedUtils.__allreduce!(
backend::MPIBackend, sendrecvbuf, op::F, dev::Union{AbstractDevice, Function};) where {F}
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Allreduce!(sendrecvbuf, mpiop, backend.comm)
if op === DistributedUtils.avg
sendrecvbuf ./= DistributedUtils.total_workers(backend)
end
return sendrecvbuf
end

function DistributedUtils.__allreduce!(
backend::MPIBackend, sendbuf, recvbuf, op::F, dev::Union{AbstractDevice, Function};) where {F}
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Allreduce!(sendbuf, recvbuf, mpiop, backend.comm)
if op === DistributedUtils.avg
recvbuf ./= DistributedUtils.total_workers(backend)
end
return recvbuf
end

for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
if !aware
@eval begin
function DistributedUtils.__allreduce!(
backend::MPIBackend, sendrecvbuf, op::F, dev::$dType) where {F}
sendrecvbuf_ = sendrecvbuf |> cpu
DistributedUtils.__allreduce!(backend, sendrecvbuf_, op, cpu)
sendrecvbuf |> gpu
return sendrecvbuf
end

function DistributedUtils.__allreduce!(
backend::MPIBackend, sendbuf, recvbuf, op::F, dev::$dType) where {F}
sendbuf_ = sendbuf |> cpu
recvbuf_ = recvbuf |> cpu
DistributedUtils.__allreduce!(backend, sendbuf_, recvbuf_, op, cpu)
recvbuf |> gpu
return recvbuf
end
end
end
end

# Reduce
function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F,
dev::Union{AbstractDevice, Function}; root::Int) where {F}
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Reduce!(sendrecvbuf, mpiop, backend.comm; root)
if op === DistributedUtils.avg
sendrecvbuf ./= DistributedUtils.total_workers(backend)
end
return sendrecvbuf
end

function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, op::F,
dev::Union{AbstractDevice, Function}; root::Int) where {F}
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Reduce!(sendbuf, recvbuf, mpiop, backend.comm; root)
if op === DistributedUtils.avg
recvbuf ./= DistributedUtils.total_workers(backend)
end
return recvbuf
end

for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
if !aware
@eval begin
function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F,
dev::$dType; root::Int) where {F}
sendrecvbuf_ = sendrecvbuf |> cpu
DistributedUtils.__reduce!(backend, sendrecvbuf_, op, cpu; root)
sendrecvbuf |> gpu
return sendrecvbuf
end

function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf,
op::F, dev::$dType; root::Int) where {F}
sendbuf_ = sendbuf |> cpu
recvbuf_ = recvbuf |> cpu
DistributedUtils.__reduce!(backend, sendbuf_, recvbuf_, op, cpu; root)
recvbuf |> gpu
return recvbuf
end
end
end
end

end
Loading

0 comments on commit d1ff714

Please sign in to comment.