Skip to content

Commit

Permalink
fix: CUDA package optional for FluxMPIExt (#2488)
Browse files (browse the repository at this point in the history)
  • Loading branch information
askorupka authored Oct 4, 2024
1 parent 66ecf7e commit ef19396
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions ext/FluxMPIExt/FluxMPIExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
module FluxMPIExt

using CUDA
if Base.find_package("CUDA") !== nothing
using CUDA
end

using Flux: MPIBackend, NCCLBackend, DistributedUtils,
AbstractDevice, FluxCUDADevice, FluxAMDGPUDevice, cpu, gpu,
get_device, MPI_CUDA_AWARE, MPI_ROCM_AWARE
Expand All @@ -19,14 +22,16 @@ function DistributedUtils.__initialize(

local_rank = MPI.Comm_rank(MPI.COMM_WORLD)

if cuda_devices !== missing && CUDA.functional()
if cuda_devices === nothing
CUDA.device!((local_rank + 1) % length(CUDA.devices()))
else
CUDA.device!(cuda_devices[local_rank + 1])
if Base.find_package("CUDA") !== nothing
if cuda_devices !== missing && CUDA.functional()
if cuda_devices === nothing
CUDA.device!((local_rank + 1) % length(CUDA.devices()))
else
CUDA.device!(cuda_devices[local_rank + 1])
end
elseif force_cuda
error(lazy"CUDA devices are not functional and `force_cuda` is set to `true`. This is caused by backend: $(caller).")
end
elseif force_cuda
error(lazy"CUDA devices are not functional and `force_cuda` is set to `true`. This is caused by backend: $(caller).")
end

if Base.find_package("AMDGPU") !== nothing
Expand Down

0 comments on commit ef19396

Please sign in to comment.