Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Distributed data parallel training support #2464

Merged
merged 31 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
0393894
first experiment distributed
CarloLucibello Jun 7, 2024
76ae025
feat: add DistributedUtils (MPI&NCCL working)
askorupka Jun 16, 2024
181cc9c
feat: add DistributedUtils (MPI&NCCL working)
askorupka Jul 7, 2024
40bf188
fix: no need for amdgpu now
askorupka Jul 7, 2024
450f62c
chore: cleanup&propose how to use amdgpu
askorupka Jul 7, 2024
8fbde8d
chore: add preferences for CUDA-awareness
askorupka Jul 7, 2024
599f506
feat: fix devices for CUDA-awareness
askorupka Jul 21, 2024
3382010
chore: add tests
askorupka Jul 21, 2024
443875e
chore: get rid of unnecessary deps
askorupka Jul 21, 2024
330b20b
chore: update NEWS.md
askorupka Jul 21, 2024
a255ff9
chore: cleanup env
askorupka Jul 24, 2024
3aab47d
chore: update docs
askorupka Jul 24, 2024
2f54c88
chore: update docs & cleanup
askorupka Jul 24, 2024
8a984bb
chore: update docs & cleanup
askorupka Jul 24, 2024
c0aefb7
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
bd23dd3
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
cee9150
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
5c85fe8
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
e151ead
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
2797924
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
f2cedd5
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
a3b62cb
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
a144ccf
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
22b35a0
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
7d03ef7
Update docs/src/guide/gpu.md
askorupka Aug 3, 2024
6c11e3c
chore: add PR review suggestions
askorupka Aug 17, 2024
6a6951a
Merge branch 'master' into distributed
askorupka Aug 17, 2024
41acd3f
chore: fix docs
askorupka Aug 17, 2024
0e33cfa
fix: add runtests.jl
askorupka Aug 19, 2024
58ae10f
chore: small docs update
askorupka Aug 19, 2024
053dcc7
chore: remove pkgs from deps
askorupka Aug 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

## v0.14.18
* Add [support for distributed data parallel training](https://github.com/FluxML/Flux.jl/pull/2446).
* MPI and NCCL backend available with `FluxMPIExt` and `FluxMPINCCLExt` extensions respectively.

## v0.14.17
* Add [support for Enzyme](https://github.com/FluxML/Flux.jl/pull/2446) with `Flux.train!`.

Expand Down
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -26,14 +27,18 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
FluxAMDGPUExt = "AMDGPU"
FluxCUDAExt = "CUDA"
FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
FluxEnzymeExt = "Enzyme"
FluxMPIExt = "MPI"
FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
FluxMetalExt = "Metal"

[compat]
Expand All @@ -45,14 +50,17 @@ Compat = "4.10.0"
Enzyme = "0.12"
Functors = "0.4"
MLUtils = "0.4"
MPI = "0.20.19"
MacroTools = "0.5"
Metal = "0.5, 1"
NCCL = "0.1.1"
NNlib = "0.9.22"
OneHotArrays = "0.2.4"
Optimisers = "0.3.3"
Preferences = "1"
ProgressLogging = "0.1"
Reexport = "1.0"
Setfield = "1.1"
SpecialFunctions = "2.1.2"
Statistics = "1"
Zygote = "0.6.67"
Expand Down
117 changes: 117 additions & 0 deletions docs/src/guide/gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,120 @@ Flux.supported_devices
Flux.get_device
Flux.gpu_backend!
```

## Distributed data parallel training

askorupka marked this conversation as resolved.
Show resolved Hide resolved
!!! danger "Experimental"

Distributed support is experimental and could change in the future.


Flux supports now distributed data parallel training with `DistributedUtils` module.
If you want to run your code on multiple GPUs, you have to install `MPI.jl` (see [docs](https://juliaparallel.org/MPI.jl/stable/usage/) for more info).

```julia-repl
julia> using MPI

julia> MPI.install_mpiexecjl()
```

Now you can run your code with `mpiexecjl --project=. -n <np> julia <filename>.jl` from CLI.

You can use either the `MPIBackend` or `NCCLBackend`, the latter only if also `NCCL.jl` is loaded. First, initialize a backend with `DistributedUtils.initialize`, e.g.

```julia-repl
julia> using Flux, MPI, NCCL, CUDA

julia> CUDA.allowscalar(false)

julia> DistributedUtils.initialize(NCCLBackend)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

julia> backend = DistributedUtils.get_distributed_backend(NCCLBackend)
NCCLBackend{Communicator, MPIBackend{MPI.Comm}}(Communicator(Ptr{NCCL.LibNCCL.ncclComm} @0x000000000607a660), MPIBackend{MPI.Comm}(MPI.Comm(1140850688)))
```
askorupka marked this conversation as resolved.
Show resolved Hide resolved

Pass your model, as well as any data to GPU device.
```julia-repl
julia> model = Chain(Dense(1 => 256, tanh), Dense(256 => 1)) |> gpu
Chain(
Dense(1 => 256, tanh), # 512 parameters
Dense(256 => 1), # 257 parameters
) # Total: 4 arrays, 769 parameters, 744 bytes.

julia> x = rand(Float32, 1, 16) |> gpu
1×16 CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}:
0.239324 0.331029 0.924996 0.55593 0.853093 0.874513 0.810269 0.935858 0.477176 0.564591 0.678907 0.729682 0.96809 0.115833 0.66191 0.75822

julia> y = x .^ 3
1×16 CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}:
0.0137076 0.0362744 0.791443 0.171815 0.620854 0.668804 0.53197 0.819654 0.108651 0.179971 0.312918 0.388508 0.907292 0.00155418 0.29 0.435899
```

In this case, we are training on a total of `16 * number of processes` samples. You can also use `DistributedUtils.DistributedDataContainer` to split the data uniformly across processes (or do it manually).

```julia-repl
julia> data = DistributedUtils.DistributedDataContainer(backend, x)
Flux.DistributedUtils.DistributedDataContainer(Float32[0.23932439 0.33102947 … 0.66191036 0.75822026], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
```

You have to wrap your model in `DistributedUtils.FluxDistributedModel` and synchronize it (broadcast accross all processes):
```julia-repl
julia> model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0)
Chain(
Dense(1 => 256, tanh), # 512 parameters

Dense(256 => 1), # 257 parameters
) # Total: 4 arrays, 769 parameters, 744 bytes.
```

Time to set up an optimizer by using `DistributedUtils.DistributedOptimizer` and synchronize it as well.
```julia-repl
julia> using Optimisers

julia> opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(0.001f0))
DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8))

julia> st_opt = Optimisers.setup(opt, model)
(layers = ((weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0], Float32[0.0], (0.9, 0.999))), σ = ())),)

julia> st_opt = DistributedUtils.synchronize!!(backend, st_opt; root=0)
(layers = ((weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0], Float32[0.0], (0.9, 0.999))), σ = ())),)
```

Now you can define loss and train the model.
```julia-repl
julia> loss(model) = mean((model(x) .- y).^2)
loss (generic function with 1 method)

julia> for epoch in 1:100
global model, st_opt
l, grad = Zygote.withgradient(loss, model)
println("Epoch $epoch: Loss $l")
st_opt, model = Optimisers.update(st_opt, model, grad[1])
end
Epoch 1: Loss 0.011638729
Epoch 2: Loss 0.0116432225
Epoch 3: Loss 0.012763695
...
```

Remember that in order to run it on multiple GPUs you have to run from CLI `mpiexecjl --project=. -n <np> julia <filename>.jl`,
where `<np>` is the number of processes that you want to use. The number of processes usually corresponds to the number of gpus.

By default `MPI.jl` MPI installation is CUDA-unaware so if you want to run it in CUDA-aware mode, read more [here](https://juliaparallel.org/MPI.jl/stable/usage/#CUDA-aware-MPI-support) on custom installation and rebuilding `MPI.jl`.
Then test if your MPI is CUDA-aware by
```julia-repl
julia> import Pkg
julia> Pkg.test("MPI"; test_args=["--backend=CUDA"])
```

If it is, set your local preference as below
```julia-repl
julia> using Preferences
julia> set_preferences!("Flux", "FluxDistributedMPICUDAAware" => true)
```

!!! warning "Known shortcomings"

We don't run CUDA-aware tests so you're running it at own risk.

183 changes: 183 additions & 0 deletions ext/FluxMPIExt/FluxMPIExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
module FluxMPIExt

using CUDA
using Flux: MPIBackend, NCCLBackend, DistributedUtils,
AbstractDevice, FluxCUDADevice, FluxAMDGPUDevice, cpu, gpu,
get_device, MPI_CUDA_AWARE, MPI_ROCM_AWARE
using MPI: MPI

if Base.find_package("AMDGPU") !== nothing
using AMDGPU
end


function DistributedUtils.__initialize(
::Type{MPIBackend}; cuda_devices=nothing, amdgpu_devices=nothing,
force_cuda::Bool=false, caller::String="", force_amdgpu::Bool=false) # Undocumented internal kwarg
!MPI.Initialized() && MPI.Init()
DistributedUtils.MPI_Initialized[] = true

local_rank = MPI.Comm_rank(MPI.COMM_WORLD)

if cuda_devices !== missing && CUDA.functional()
if cuda_devices === nothing
CUDA.device!((local_rank + 1) % length(CUDA.devices()))
else
CUDA.device!(cuda_devices[local_rank + 1])
end
elseif force_cuda
error(lazy"CUDA devices are not functional and `force_cuda` is set to `true`. This is caused by backend: $(caller).")
end

if Base.find_package("AMDGPU") !== nothing
if amdgpu_devices !== missing && AMDGPU.functional()
if amdgpu_devices === nothing
AMDGPU.device!((local_rank + 1) % length(AMDGPU.devices()))
else
AMDGPU.device!(amdgpu_devices[local_rank + 1])
end
elseif force_amdgpu
error(lazy"AMDGPU devices are not functional (or `LuxAMDGPU.jl` not loaded) and `force_amdgpu` is set to `true`. This is caused by backend: $(caller).")
end
end

return
end

DistributedUtils.__get_distributed_backend(::Type{MPIBackend}) = MPIBackend(MPI.COMM_WORLD)

DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm)

DistributedUtils.total_workers(backend::MPIBackend) = MPI.Comm_size(backend.comm)

# Broadcast
# Union with Function is because of Flux.cpu istypeof Function
# We need CPU in case of non CUDA-aware implementation
function DistributedUtils.__bcast!(
backend::MPIBackend, sendrecvbuf, dev::Union{AbstractDevice, Function}; root=0)
MPI.Bcast!(sendrecvbuf, backend.comm; root)
return sendrecvbuf
end

function DistributedUtils.__bcast!(
backend::MPIBackend, sendbuf, recvbuf, dev::Union{AbstractDevice, Function}; root=0)
return DistributedUtils.__bcast!(
backend, ifelse(DistributedUtils.local_rank(backend) == root, sendbuf, recvbuf),
dev; root)
end

# if MPI implementation is not CUDA-aware
# we have to move data to CPU first
for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
if !aware
@eval begin
function DistributedUtils.__bcast!(
backend::MPIBackend, sendrecvbuf, dev::$dType; root=0)
sendrecvbuf_ = sendrecvbuf |> cpu
DistributedUtils.__bcast!(backend, sendrecvbuf_, cpu; root)
sendrecvbuf |> gpu
return sendrecvbuf
end

function DistributedUtils.__bcast!(
backend::MPIBackend, sendbuf, recvbuf, dev::$dType; root=0)
sendbuf_ = sendbuf |> cpu
recvbuf_ = recvbuf |> cpu
DistributedUtils.__bcast!(backend, sendbuf_, recvbuf_, cpu; root)
recvbuf |> gpu
return recvbuf
end
end
end
end


# Allreduce
function DistributedUtils.__allreduce!(
backend::MPIBackend, sendrecvbuf, op::F, dev::Union{AbstractDevice, Function};) where {F}
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Allreduce!(sendrecvbuf, mpiop, backend.comm)
if op === DistributedUtils.avg
sendrecvbuf ./= DistributedUtils.total_workers(backend)
end
return sendrecvbuf
end

function DistributedUtils.__allreduce!(
backend::MPIBackend, sendbuf, recvbuf, op::F, dev::Union{AbstractDevice, Function};) where {F}
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Allreduce!(sendbuf, recvbuf, mpiop, backend.comm)
if op === DistributedUtils.avg
recvbuf ./= DistributedUtils.total_workers(backend)
end
return recvbuf
end

for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
if !aware
@eval begin
function DistributedUtils.__allreduce!(
backend::MPIBackend, sendrecvbuf, op::F, dev::$dType) where {F}
sendrecvbuf_ = sendrecvbuf |> cpu
DistributedUtils.__allreduce!(backend, sendrecvbuf_, op, cpu)
sendrecvbuf |> gpu
return sendrecvbuf
end

function DistributedUtils.__allreduce!(
backend::MPIBackend, sendbuf, recvbuf, op::F, dev::$dType) where {F}
sendbuf_ = sendbuf |> cpu
recvbuf_ = recvbuf |> cpu
DistributedUtils.__allreduce!(backend, sendbuf_, recvbuf_, op, cpu)
recvbuf |> gpu
return recvbuf
end
end
end
end

# Reduce
function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F,
dev::Union{AbstractDevice, Function}; root::Int) where {F}
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Reduce!(sendrecvbuf, mpiop, backend.comm; root)
if op === DistributedUtils.avg
sendrecvbuf ./= DistributedUtils.total_workers(backend)
end
return sendrecvbuf
end

function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, op::F,
dev::Union{AbstractDevice, Function}; root::Int) where {F}
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Reduce!(sendbuf, recvbuf, mpiop, backend.comm; root)
if op === DistributedUtils.avg
recvbuf ./= DistributedUtils.total_workers(backend)
end
return recvbuf
end

for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
if !aware
@eval begin
function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F,
dev::$dType; root::Int) where {F}
sendrecvbuf_ = sendrecvbuf |> cpu
DistributedUtils.__reduce!(backend, sendrecvbuf_, op, cpu; root)
sendrecvbuf |> gpu
return sendrecvbuf
end

function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf,
op::F, dev::$dType; root::Int) where {F}
sendbuf_ = sendbuf |> cpu
recvbuf_ = recvbuf |> cpu
DistributedUtils.__reduce!(backend, sendbuf_, recvbuf_, op, cpu; root)
recvbuf |> gpu
return recvbuf
end
end
end
end

end
Loading
Loading