From c26a65e7ef7833e2b0f5111ec9af0e507269676b Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Oct 2024 07:15:41 +0200 Subject: [PATCH 01/21] removed Flux devices --- NEWS.md | 3 + Project.toml | 2 + src/Flux.jl | 13 +++ src/deprecations.jl | 31 +++++ src/devices.jl | 14 +++ src/functor.jl | 279 -------------------------------------------- 6 files changed, 63 insertions(+), 279 deletions(-) create mode 100644 src/devices.jl diff --git a/NEWS.md b/NEWS.md index 0448b74d77..51dde806e9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,9 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. +## v0.14.22 +* Data movement between devices is now provided by [MLDataDevice.jl](https://github.com/LuxDL/MLDataDevices.jl). + ## v0.14.18 * Add [support for distributed data parallel training](https://github.com/FluxML/Flux.jl/pull/2446). * MPI and NCCL backend available with `FluxMPIExt` and `FluxMPINCCLExt` extensions respectively. diff --git a/Project.toml b/Project.toml index 49869e3bb1..d805332f20 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" @@ -49,6 +50,7 @@ ChainRulesCore = "1.12" Compat = "4.10.0" Enzyme = "0.12, 0.13" Functors = "0.4" +MLDataDevices = "1.2.0" MLUtils = "0.4" MPI = "0.20.19" MacroTools = "0.5" diff --git a/src/Flux.jl b/src/Flux.jl index 7eac8ee7d6..c47dd12d8f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -18,6 +18,17 @@ using Zygote: Params, @adjoint, gradient, pullback using Zygote.ForwardDiff: value export gradient +@reexport using MLDataDevices: MLDataDevices, gpu_backend!, supported_gpu_backends, reset_gpu_device!, + default_device_rng, + gpu_device, cpu_device, xla_device, + CPUDevice, + CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, + XLADevice, + # get_device, # we define get_device here for retrocompatibility + get_device_type, + DeviceIterator + + # Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.) Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`") @@ -92,6 +103,8 @@ include("deprecations.jl") include("losses/Losses.jl") using .Losses +include("devices.jl") + # Distributed Training include("distributed/backend.jl") include("distributed/public_api.jl") diff --git a/src/deprecations.jl b/src/deprecations.jl index 7e29dcad1b..7ffea8ccb8 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -218,6 +218,37 @@ function loadmodel!(dst::ConvTranspose, src::NamedTuple{(:σ, :weight, :bias, :s loadmodel!(dst, new_src; kw...) end +function get_device(; verbose::Bool=false) + Base.depwarn("get_device() is deprecated. Use `gpu_device()` instead.", :get_device) + return MLDataDevices.gpu_device() +end + +function get_device(backend::String, idx::Int = 0) + Base.depwarn("get_device(backend::String, idx::Int) is deprecated. Use `gpu_device(idx)` instead.", :get_device) + if backend == "AMD" + @warn "\"AMD\" backend is deprecated. Please use \"AMDGPU\" instead." maxlog=1 + backend = "AMDGPU" + end + if backend == "CPU" + return MLDataDevices.CPUDevice() + else + return _get_device(Val(Symbol(backend)), idx) + end +end + +function _get_device(::Val{D}, idx) where D + if D ∈ (:CUDA, :AMDGPU, :Metal) + error(string("Unavailable backend: ", D,". Try importing the corresponding package with `using ", D, "`.")) + else + error(string("Unsupported backend: ", D, ". Supported backends are ", (:CUDA, :AMDGPU, :Metal), ".")) + end +end + +function supported_devices() + Base.depwarn("supported_devices() is deprecated. Use `supported_gpu_backends()` instead.", :supported_devices) + return MLDataDevices.supported_gpu_backends() +end + # v0.15 deprecations # Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc: diff --git a/src/devices.jl b/src/devices.jl new file mode 100644 index 0000000000..cd7a92c3ed --- /dev/null +++ b/src/devices.jl @@ -0,0 +1,14 @@ +# TODO get docstring from MLDataDevices.get_device +get_device(x) = MLDataDevices.get_device(x) + +function (device::MLDataDevices.AbstractDevice)(d::MLUtils.DataLoader) + MLUtils.DataLoader(MLUtils.mapobs(device, d.data), + d.batchsize, + d.buffer, + d.partial, + d.shuffle, + d.parallel, + d.collate, + d.rng, + ) +end diff --git a/src/functor.jl b/src/functor.jl index eeaffab1c3..de7ba1d0f1 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -194,34 +194,6 @@ const GPU_BACKENDS = ("CUDA", "AMDGPU", "Metal", "CPU") const GPU_BACKEND_ORDER = Dict(collect(zip(GPU_BACKENDS, 1:length(GPU_BACKENDS)))) const GPU_BACKEND = @load_preference("gpu_backend", "CUDA") -""" - gpu_backend!(backend::String) - -Set the GPU backend to `backend` in the `LocalPreferences.toml` file in you project directory. -After restarting Julia, the new backend will affect all subsequent calls to [`gpu`](@ref) and [`get_device`](@ref). - -The supported backends are `"CUDA"`, `"AMDGPU"`, `"Metal"` and `"CPU"`. -""" -function gpu_backend!(backend::String) - if backend == GPU_BACKEND - @info """ - GPU backend is already set to: $backend. - No need to do anything else. - """ - return - end - - backend in GPU_BACKENDS || throw(ArgumentError(""" - Unsupported GPU backend: $backend. - Supported backends are: $GPU_BACKENDS. - """)) - - @set_preferences!("gpu_backend" => backend) - @info """ - New GPU backend set: $backend. - Restart your Julia session for this change to take effect! - """ -end """ gpu(m) @@ -478,254 +450,3 @@ function cpu(d::MLUtils.DataLoader) d.rng, ) end - -# Defining device interfaces. -""" - Flux.AbstractDevice <: Function - -An abstract type representing `device` objects for different GPU backends. The currently supported backends are `"CUDA"`, `"AMDGPU"`, `"Metal"` and `"CPU"`; the `"CPU"` backend is the fallback case when no GPU is available. GPU extensions of Flux define subtypes of this type. - -""" -abstract type AbstractDevice <: Function end - -function (device::AbstractDevice)(d::MLUtils.DataLoader) - MLUtils.DataLoader(MLUtils.mapobs(device, d.data), - d.batchsize, - d.buffer, - d.partial, - d.shuffle, - d.parallel, - d.collate, - d.rng, - ) -end - -function _get_device_name(::T)::String where {T <: AbstractDevice} end - -## check device availability; more definitions in corresponding extensions -_isavailable(::Nothing) = false -_isfunctional(::Nothing) = false - -_isavailable(::AbstractDevice) = false -_isfunctional(::AbstractDevice) = false - -""" - Flux.FluxCPUDevice <: Flux.AbstractDevice - -A type representing `device` objects for the `"CPU"` backend for Flux. This is the fallback case when no GPU is available to Flux. -""" -Base.@kwdef struct FluxCPUDevice <: AbstractDevice end - -(::FluxCPUDevice)(x) = cpu(x) -_isavailable(::FluxCPUDevice) = true -_isfunctional(::FluxCPUDevice) = true -_get_device_name(::FluxCPUDevice) = "CPU" - -""" - FluxCUDADevice <: AbstractDevice - -A type representing `device` objects for the `"CUDA"` backend for Flux. -""" -Base.@kwdef struct FluxCUDADevice <: AbstractDevice - deviceID -end - -""" - FluxAMDGPUDevice <: AbstractDevice - -A type representing `device` objects for the `"AMDGPU"` backend for Flux. -""" -Base.@kwdef struct FluxAMDGPUDevice <: AbstractDevice - deviceID -end - -""" - FluxMetalDevice <: AbstractDevice - -A type representing `device` objects for the `"Metal"` backend for Flux. -""" -Base.@kwdef struct FluxMetalDevice <: AbstractDevice - deviceID -end - -## device list. order is important -const DEVICES = Ref{Vector{Union{Nothing, AbstractDevice}}}(Vector{Union{Nothing, AbstractDevice}}(nothing, length(GPU_BACKENDS))) -DEVICES[][GPU_BACKEND_ORDER["CPU"]] = FluxCPUDevice() - -## get device - -""" - Flux.supported_devices() - -Get all supported backends for Flux, in order of preference. - -# Example - -```jldoctest -julia> using Flux; - -julia> Flux.supported_devices() -("CUDA", "AMDGPU", "Metal", "CPU") -``` -""" -supported_devices() = GPU_BACKENDS - -""" - Flux.get_device(; verbose=false)::Flux.AbstractDevice - -Returns a `device` object for the most appropriate backend for the current Julia session. - -First, the function checks whether a backend preference has been set via the [`Flux.gpu_backend!`](@ref) function. If so, an attempt is made to load this backend. If the corresponding trigger package has been loaded and the backend is functional, a `device` corresponding to the given backend is loaded. Otherwise, the backend is chosen automatically. To update the backend preference, use [`Flux.gpu_backend!`](@ref). - -If there is no preference, then for each of the `"CUDA"`, `"AMDGPU"`, `"Metal"` and `"CPU"` backends in the given order, this function checks whether the given backend has been loaded via the corresponding trigger package, and whether the backend is functional. If so, the `device` corresponding to the backend is returned. If no GPU backend is available, a `Flux.FluxCPUDevice` is returned. - -If `verbose` is set to `true`, then the function prints informative log messages. - -# Examples -For the example given below, the backend preference was set to `"AMDGPU"` via the [`gpu_backend!`](@ref) function. - -```julia-repl -julia> using Flux; - -julia> model = Dense(2 => 3) -Dense(2 => 3) # 9 parameters - -julia> device = Flux.get_device(; verbose=true) # this will just load the CPU device -[ Info: Using backend set in preferences: AMDGPU. -┌ Warning: Trying to use backend: AMDGPU but it's trigger package is not loaded. -│ Please load the package and call this function again to respect the preferences backend. -└ @ Flux ~/fluxml/Flux.jl/src/functor.jl:638 -[ Info: Using backend: CPU. -(::Flux.FluxCPUDevice) (generic function with 1 method) - -julia> model = model |> device -Dense(2 => 3) # 9 parameters - -julia> model.weight -3×2 Matrix{Float32}: - -0.304362 -0.700477 - -0.861201 0.67825 - -0.176017 0.234188 -``` - -Here is the same example, but using `"CUDA"`: - -```julia-repl -julia> using Flux, CUDA; - -julia> model = Dense(2 => 3) -Dense(2 => 3) # 9 parameters - -julia> device = Flux.get_device(; verbose=true) -[ Info: Using backend set in preferences: AMDGPU. -┌ Warning: Trying to use backend: AMDGPU but it's trigger package is not loaded. -│ Please load the package and call this function again to respect the preferences backend. -└ @ Flux ~/fluxml/Flux.jl/src/functor.jl:637 -[ Info: Using backend: CUDA. -(::Flux.FluxCUDADevice) (generic function with 1 method) - -julia> model = model |> device -Dense(2 => 3) # 9 parameters - -julia> model.weight -3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}: - 0.820013 0.527131 - -0.915589 0.549048 - 0.290744 -0.0592499 -``` -""" -function get_device(; verbose=false)::AbstractDevice - backend = @load_preference("gpu_backend", nothing) - - if backend !== nothing - allowed_backends = supported_devices() - idx = findfirst(isequal(backend), allowed_backends) - if backend ∉ allowed_backends - @warn """ - `gpu_backend` preference is set to $backend, which is not allowed. - Defaulting to automatic device selection. - """ maxlog=1 - else - verbose && @info "Using backend set in preferences: $backend." - device = DEVICES[][idx] - - if !_isavailable(device) - @warn """ - Trying to use backend: $backend but it's trigger package is not loaded. - Please load the package and call this function again to respect the preferences backend. - """ - else - if _isfunctional(device) - return device - else - @warn "Backend: $backend from the set preferences is not functional. Defaulting to automatic device selection." - end - end - end - end - - for backend in GPU_BACKENDS - device = DEVICES[][GPU_BACKEND_ORDER[backend]] - if _isavailable(device) - if _isfunctional(device) - verbose && @info "Using backend: $backend." - return device - end - end - end -end - -""" - Flux.get_device(backend::String, idx::Int = 0)::Flux.AbstractDevice - -Get a device object for a backend specified by the string `backend` and `idx`. The currently supported values -of `backend` are `"CUDA"`, `"AMDGPU"` and `"CPU"`. `idx` must be an integer value between `0` and the number of available devices. - -# Examples - -```julia-repl -julia> using Flux, CUDA; - -julia> CUDA.devices() -CUDA.DeviceIterator() for 3 devices: -0. GeForce RTX 2080 Ti -1. GeForce RTX 2080 Ti -2. TITAN X (Pascal) - -julia> device0 = Flux.get_device("CUDA", 0) -(::Flux.FluxCUDADevice) (generic function with 1 method) - -julia> device0.deviceID -CuDevice(0): GeForce RTX 2080 Ti - -julia> device1 = Flux.get_device("CUDA", 1) -(::Flux.FluxCUDADevice) (generic function with 1 method) - -julia> device1.deviceID -CuDevice(1): GeForce RTX 2080 Ti - -julia> cpu_device = Flux.get_device("CPU") -(::Flux.FluxCPUDevice) (generic function with 1 method) - -``` -""" -function get_device(backend::String, idx::Int = 0) - if backend == "AMD" - @warn "\"AMD\" backend is deprecated. Please use \"AMDGPU\" instead." maxlog=1 - backend = "AMDGPU" - end - if backend == "CPU" - return FluxCPUDevice() - else - return get_device(Val(Symbol(backend)), idx) - end -end - -# Fallback -function get_device(::Val{D}, idx) where D - if D ∈ (:CUDA, :AMDGPU, :Metal) - error("Unavailable backend: $(D). Try importing the corresponding package with `using $D`.") - else - error("Unsupported backend: $(D). Supported backends are $(GPU_BACKENDS).") - end -end From 3a6bf8678342a77638244372e147780fb8a21683 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Oct 2024 07:48:22 +0200 Subject: [PATCH 02/21] fix gpu extensions --- docs/src/guide/gpu.md | 1 + ext/FluxAMDGPUExt/FluxAMDGPUExt.jl | 11 ----------- ext/FluxAMDGPUExt/functor.jl | 8 ++------ ext/FluxCUDAExt/FluxCUDAExt.jl | 14 -------------- ext/FluxCUDAExt/functor.jl | 8 ++------ ext/FluxMetalExt/FluxMetalExt.jl | 6 ------ ext/FluxMetalExt/functor.jl | 4 ++-- src/deprecations.jl | 2 +- test/ext_amdgpu/get_devices.jl | 6 +++--- test/ext_cuda/get_devices.jl | 6 +++--- test/ext_metal/get_devices.jl | 9 --------- test/functors.jl | 10 ++-------- 12 files changed, 16 insertions(+), 69 deletions(-) diff --git a/docs/src/guide/gpu.md b/docs/src/guide/gpu.md index b9cd6d1f8c..1744bcdbed 100644 --- a/docs/src/guide/gpu.md +++ b/docs/src/guide/gpu.md @@ -232,6 +232,7 @@ More information for conditional use of GPUs in CUDA.jl can be found in its [doc ## Using device objects +///// TODO //// As a more convenient syntax, Flux allows the usage of GPU `device` objects which can be used to easily transfer models to GPUs (and defaulting to using the CPU if no GPU backend is available). This syntax has a few advantages including automatic selection of the GPU backend and type stability of data movement. To do this, the [`Flux.get_device`](@ref) function can be used. `Flux.get_device` first checks for a GPU preference, and if possible returns a device for the preference backend. For instance, consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference): diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index 0017295b03..8e8086c1a8 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -17,16 +17,6 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat # Set to boolean on the first call to check_use_amdgpu const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing) -function (device::Flux.FluxAMDGPUDevice)(x) - if device.deviceID === nothing - Flux.gpu(Flux.FluxAMDGPUAdaptor(), x) - else - return Flux.gpu(Flux.FluxAMDGPUAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer - end -end -Flux._get_device_name(::Flux.FluxAMDGPUDevice) = "AMDGPU" -Flux._isavailable(::Flux.FluxAMDGPUDevice) = true -Flux._isfunctional(::Flux.FluxAMDGPUDevice) = AMDGPU.functional() function check_use_amdgpu() if !isnothing(USE_AMDGPU[]) @@ -55,7 +45,6 @@ include("conv.jl") function __init__() Flux.AMDGPU_LOADED[] = true - Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMDGPU"]] = AMDGPU.functional() ? Flux.FluxAMDGPUDevice(AMDGPU.device()) : Flux.FluxAMDGPUDevice(nothing) end # TODO diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index d3a27e54c7..c2b6420ca1 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -108,10 +108,6 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMDGPU_CONV) Adapt.adapt(to, m.bias), _other_args(m)...) end -function Flux.get_device(::Val{:AMDGPU}, id::Int) # id should start from 0 - old_id = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ids start from 0 - AMDGPU.device!(AMDGPU.devices()[id + 1]) # adding 1 because ids start from 0 - device = Flux.FluxAMDGPUDevice(AMDGPU.device()) - AMDGPU.device!(AMDGPU.devices()[old_id + 1]) - return device +function Flux._get_device(::Val{:AMDGPU}, id::Int) # id should start from 0 + return MLDataDevices.gpu_device(id+1, force=true) end diff --git a/ext/FluxCUDAExt/FluxCUDAExt.jl b/ext/FluxCUDAExt/FluxCUDAExt.jl index 9948c5f4c0..0a0f67adc7 100644 --- a/ext/FluxCUDAExt/FluxCUDAExt.jl +++ b/ext/FluxCUDAExt/FluxCUDAExt.jl @@ -14,17 +14,6 @@ import Adapt: adapt_storage const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing) -function (device::Flux.FluxCUDADevice)(x) - if device.deviceID === nothing - return Flux.gpu(Flux.FluxCUDAAdaptor(), x) - else - return Flux.gpu(Flux.FluxCUDAAdaptor(device.deviceID.handle), x) - end -end -Flux._get_device_name(::Flux.FluxCUDADevice) = "CUDA" -Flux._isavailable(::Flux.FluxCUDADevice) = true -Flux._isfunctional(::Flux.FluxCUDADevice) = CUDA.functional() - function check_use_cuda() if !isnothing(USE_CUDA[]) return @@ -48,9 +37,6 @@ include("utils.jl") function __init__() Flux.CUDA_LOADED[] = true - ## add device to available devices - Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CUDA"]] = CUDA.functional() ? Flux.FluxCUDADevice(CUDA.device()) : Flux.FluxCUDADevice(nothing) - try Base.require(Main, :cuDNN) catch diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl index dc8649fff0..205f366b24 100644 --- a/ext/FluxCUDAExt/functor.jl +++ b/ext/FluxCUDAExt/functor.jl @@ -56,10 +56,6 @@ function _cuda(id::Union{Nothing, Int}, x) fmap(x -> Adapt.adapt(FluxCUDAAdaptor(id), x), x; exclude=Flux._isleaf) end -function Flux.get_device(::Val{:CUDA}, id::Int) - old_id = CUDA.device().handle - CUDA.device!(id) - device = Flux.FluxCUDADevice(CUDA.device()) - CUDA.device!(old_id) - return device +function Flux._get_device(::Val{:CUDA}, id::Int) + return MLDataUtils.gpu_device(id+1, force=true) end diff --git a/ext/FluxMetalExt/FluxMetalExt.jl b/ext/FluxMetalExt/FluxMetalExt.jl index a11046d244..27316c3b16 100644 --- a/ext/FluxMetalExt/FluxMetalExt.jl +++ b/ext/FluxMetalExt/FluxMetalExt.jl @@ -12,11 +12,6 @@ using Zygote const USE_METAL = Ref{Union{Nothing, Bool}}(nothing) -(::Flux.FluxMetalDevice)(x) = Flux.gpu(Flux.FluxMetalAdaptor(), x) -Flux._get_device_name(::Flux.FluxMetalDevice) = "Metal" -Flux._isavailable(::Flux.FluxMetalDevice) = true -Flux._isfunctional(::Flux.FluxMetalDevice) = Metal.functional() - function check_use_metal() isnothing(USE_METAL[]) || return @@ -35,7 +30,6 @@ include("functor.jl") function __init__() Flux.METAL_LOADED[] = true - Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]] = Metal.functional() ? Flux.FluxMetalDevice(Metal.current_device()) : Flux.FluxMetalDevice(nothing) end end diff --git a/ext/FluxMetalExt/functor.jl b/ext/FluxMetalExt/functor.jl index 8b20fdafe5..ad59cc7117 100644 --- a/ext/FluxMetalExt/functor.jl +++ b/ext/FluxMetalExt/functor.jl @@ -33,8 +33,8 @@ function _metal(x) fmap(x -> Adapt.adapt(FluxMetalAdaptor(), x), x; exclude=_isleaf) end -function Flux.get_device(::Val{:Metal}, id::Int) +function Flux._get_device(::Val{:Metal}, id::Int) @assert id == 0 "Metal backend only supports one device at the moment" - return Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]] + return MLDataDevices.gpu_device() end diff --git a/src/deprecations.jl b/src/deprecations.jl index 7ffea8ccb8..8dadadfd6d 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -224,7 +224,7 @@ function get_device(; verbose::Bool=false) end function get_device(backend::String, idx::Int = 0) - Base.depwarn("get_device(backend::String, idx::Int) is deprecated. Use `gpu_device(idx)` instead.", :get_device) + Base.depwarn("get_device(backend::String, idx::Int) is deprecated. Use `gpu_device(idx+1)` instead.", :get_device) if backend == "AMD" @warn "\"AMD\" backend is deprecated. Please use \"AMDGPU\" instead." maxlog=1 backend = "AMDGPU" diff --git a/test/ext_amdgpu/get_devices.jl b/test/ext_amdgpu/get_devices.jl index caec773720..f261f80efd 100644 --- a/test/ext_amdgpu/get_devices.jl +++ b/test/ext_amdgpu/get_devices.jl @@ -1,7 +1,7 @@ -amdgpu_device = Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMDGPU"]] +amdgpu_device = gpu_device() # should pass, whether or not AMDGPU is functional -@test typeof(amdgpu_device) <: Flux.FluxAMDGPUDevice +@test typeof(amdgpu_device) <: Flux.AMDGPUDevice @test typeof(amdgpu_device.deviceID) <: AMDGPU.HIPDevice @@ -17,7 +17,7 @@ amdgpu_device = Flux.get_device() @test Flux._get_device_name(amdgpu_device) in Flux.supported_devices() # correctness of data transfer -x = randn(5, 5) +x = randn(Float32, 5, 5) cx = x |> amdgpu_device @test cx isa AMDGPU.ROCArray @test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(amdgpu_device.deviceID) diff --git a/test/ext_cuda/get_devices.jl b/test/ext_cuda/get_devices.jl index 17944a0e8f..db93f1e514 100644 --- a/test/ext_cuda/get_devices.jl +++ b/test/ext_cuda/get_devices.jl @@ -1,7 +1,7 @@ -cuda_device = Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CUDA"]] +cuda_device = gpu_device() # should pass, whether or not CUDA is functional -@test typeof(cuda_device) <: Flux.FluxCUDADevice +@test typeof(cuda_device) <: Flux.CUDADevice @test typeof(cuda_device.deviceID) <: CUDA.CuDevice @@ -12,7 +12,7 @@ bias = copy(dense_model.bias) # store the bias cuda_device = Flux.get_device() -@test typeof(cuda_device) <: Flux.FluxCUDADevice +@test typeof(cuda_device) <: Flux.CUDADevice @test typeof(cuda_device.deviceID) <: CUDA.CuDevice @test Flux._get_device_name(cuda_device) in Flux.supported_devices() diff --git a/test/ext_metal/get_devices.jl b/test/ext_metal/get_devices.jl index 12302974bc..269f6a4b9c 100644 --- a/test/ext_metal/get_devices.jl +++ b/test/ext_metal/get_devices.jl @@ -1,12 +1,3 @@ -@testset "Flux.DEVICES" begin - metal_device = Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]] - - # should pass, whether or not Metal is functional - @test typeof(metal_device) <: Flux.FluxMetalDevice - - @test typeof(metal_device.deviceID) <: Metal.MTLDevice -end - @testset "get_device()" begin metal_device = Flux.get_device() diff --git a/test/functors.jl b/test/functors.jl index 879536a94b..280b76d6f0 100644 --- a/test/functors.jl +++ b/test/functors.jl @@ -3,16 +3,10 @@ if !(Flux.CUDA_LOADED[] || Flux.AMDGPU_LOADED[] || Flux.METAL_LOADED[]) @test x === gpu(x) end -@test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CUDA"]]) <: Nothing -@test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMDGPU"]]) <: Nothing -@test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]]) <: Nothing -@test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CPU"]]) <: Flux.FluxCPUDevice - dev = Flux.get_device() -@test typeof(dev) <: Flux.FluxCPUDevice +@test typeof(dev) <: Flux.CPUDevice @test dev(x) == x -@test Flux._get_device_name(dev) in Flux.supported_devices() # specifically getting CPU device dev = Flux.get_device("CPU") -@test typeof(dev) <: Flux.FluxCPUDevice +@test typeof(dev) <: Flux.CPUDevice From 7a60214a6aba173b6388478083ab2d6e19673fe6 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Oct 2024 08:32:22 +0200 Subject: [PATCH 03/21] ported MPI extension --- docs/src/guide/gpu.md | 10 +-- ext/FluxMPIExt/FluxMPIExt.jl | 123 +++++++++++++-------------- ext/FluxMPINCCLExt/FluxMPINCCLExt.jl | 15 ++-- src/distributed/public_api.jl | 42 +++++++-- test/ext_amdgpu/get_devices.jl | 8 +- test/ext_cuda/get_devices.jl | 4 +- 6 files changed, 110 insertions(+), 92 deletions(-) diff --git a/docs/src/guide/gpu.md b/docs/src/guide/gpu.md index 1744bcdbed..e3b05eb6a7 100644 --- a/docs/src/guide/gpu.md +++ b/docs/src/guide/gpu.md @@ -378,11 +378,11 @@ Due to a limitation in `Metal.jl`, currently this kind of data movement across d ```@docs Flux.AbstractDevice -Flux.FluxCPUDevice -Flux.FluxCUDADevice -Flux.FluxAMDGPUDevice -Flux.FluxMetalDevice -Flux.supported_devices +Flux.CPUDevice +Flux.CUDADevice +Flux.AMDGPUDevice +Flux.MetalDevice +Flux.supported_gpu_backends Flux.get_device Flux.gpu_backend! ``` diff --git a/ext/FluxMPIExt/FluxMPIExt.jl b/ext/FluxMPIExt/FluxMPIExt.jl index 7a938e2e6a..f1db1ae3a7 100644 --- a/ext/FluxMPIExt/FluxMPIExt.jl +++ b/ext/FluxMPIExt/FluxMPIExt.jl @@ -1,17 +1,9 @@ module FluxMPIExt -if Base.find_package("CUDA") !== nothing - using CUDA -end - using Flux: MPIBackend, NCCLBackend, DistributedUtils, - AbstractDevice, FluxCUDADevice, FluxAMDGPUDevice, cpu, gpu, - get_device, MPI_CUDA_AWARE, MPI_ROCM_AWARE + MPI_CUDA_AWARE, MPI_ROCM_AWARE using MPI: MPI - -if Base.find_package("AMDGPU") !== nothing - using AMDGPU -end +using MLDataDevices: AbstractDevice, CUDADevice, AMDGPUDevice, functional, set_device! function DistributedUtils.__initialize( @@ -22,28 +14,24 @@ function DistributedUtils.__initialize( local_rank = MPI.Comm_rank(MPI.COMM_WORLD) - if Base.find_package("CUDA") !== nothing - if cuda_devices !== missing && CUDA.functional() - if cuda_devices === nothing - CUDA.device!((local_rank + 1) % length(CUDA.devices())) - else - CUDA.device!(cuda_devices[local_rank + 1]) - end - elseif force_cuda - error(lazy"CUDA devices are not functional and `force_cuda` is set to `true`. This is caused by backend: $(caller).") + if cuda_devices !== missing && functional(CUDADevice) + if cuda_devices === nothing + set_device!(CUDADevice, nothing, local_rank + 1) + else + set_device!(CUDADevice, cuda_devices[local_rank + 1]) end + elseif force_cuda + error(lazy"CUDA devices are not functional and `force_cuda` is set to `true`. This is caused by backend: $(caller).") end - if Base.find_package("AMDGPU") !== nothing - if amdgpu_devices !== missing && AMDGPU.functional() - if amdgpu_devices === nothing - AMDGPU.device!((local_rank + 1) % length(AMDGPU.devices())) - else - AMDGPU.device!(amdgpu_devices[local_rank + 1]) - end - elseif force_amdgpu - error(lazy"AMDGPU devices are not functional (or `LuxAMDGPU.jl` not loaded) and `force_amdgpu` is set to `true`. This is caused by backend: $(caller).") + if amdgpu_devices !== missing && AMDGPU.functional() + if amdgpu_devices === nothing + set_device!(AMDGPUDevice, nothing, local_rank + 1) + else + set_device!(AMDGPUDevice, amdgpu_devices[local_rank + 1]) end + elseif force_amdgpu + error(lazy"AMDGPU devices are not functional (or `LuxAMDGPU.jl` not loaded) and `force_amdgpu` is set to `true`. This is caused by backend: $(caller).") end return @@ -56,16 +44,15 @@ DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm) DistributedUtils.total_workers(backend::MPIBackend) = MPI.Comm_size(backend.comm) # Broadcast -# Union with Function is because of Flux.cpu istypeof Function # We need CPU in case of non CUDA-aware implementation function DistributedUtils.__bcast!( - backend::MPIBackend, sendrecvbuf, dev::Union{AbstractDevice, Function}; root=0) + backend::MPIBackend, sendrecvbuf, dev::AbstractDevice; root=0) MPI.Bcast!(sendrecvbuf, backend.comm; root) return sendrecvbuf end function DistributedUtils.__bcast!( - backend::MPIBackend, sendbuf, recvbuf, dev::Union{AbstractDevice, Function}; root=0) + backend::MPIBackend, sendbuf, recvbuf, dev::AbstractDevice; root=0) return DistributedUtils.__bcast!( backend, ifelse(DistributedUtils.local_rank(backend) == root, sendbuf, recvbuf), dev; root) @@ -73,24 +60,26 @@ end # if MPI implementation is not CUDA-aware # we have to move data to CPU first -for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice)) +for (aware, dType) in ((MPI_CUDA_AWARE, CUDADevice), (MPI_ROCM_AWARE, AMDGPUDevice)) if !aware @eval begin function DistributedUtils.__bcast!( backend::MPIBackend, sendrecvbuf, dev::$dType; root=0) - sendrecvbuf_ = sendrecvbuf |> cpu - DistributedUtils.__bcast!(backend, sendrecvbuf_, cpu; root) - sendrecvbuf |> gpu - return sendrecvbuf + cdev = cpu_device() + sendrecvbuf_ = sendrecvbuf |> cdev + DistributedUtils.__bcast!(backend, sendrecvbuf_, cdev; root) + copyto!(sendrecvbuf, sendrecvbuf_) + return end function DistributedUtils.__bcast!( backend::MPIBackend, sendbuf, recvbuf, dev::$dType; root=0) - sendbuf_ = sendbuf |> cpu - recvbuf_ = recvbuf |> cpu - DistributedUtils.__bcast!(backend, sendbuf_, recvbuf_, cpu; root) - recvbuf |> gpu - return recvbuf + cdev = cpu_device() + sendbuf_ = sendbuf |> cdev + recvbuf_ = recvbuf |> cdev + DistributedUtils.__bcast!(backend, sendbuf_, recvbuf_, cdev; root) + copyto!(recvbuf, recvbuf_) + return end end end @@ -99,7 +88,7 @@ end # Allreduce function DistributedUtils.__allreduce!( - backend::MPIBackend, sendrecvbuf, op::F, dev::Union{AbstractDevice, Function};) where {F} + backend::MPIBackend, sendrecvbuf, op::F, ::AbstractDevice) where {F} mpiop = ifelse(op === DistributedUtils.avg, +, op) MPI.Allreduce!(sendrecvbuf, mpiop, backend.comm) if op === DistributedUtils.avg @@ -109,7 +98,7 @@ function DistributedUtils.__allreduce!( end function DistributedUtils.__allreduce!( - backend::MPIBackend, sendbuf, recvbuf, op::F, dev::Union{AbstractDevice, Function};) where {F} + backend::MPIBackend, sendbuf, recvbuf, op::F, ::AbstractDevice) where {F} mpiop = ifelse(op === DistributedUtils.avg, +, op) MPI.Allreduce!(sendbuf, recvbuf, mpiop, backend.comm) if op === DistributedUtils.avg @@ -118,24 +107,26 @@ function DistributedUtils.__allreduce!( return recvbuf end -for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice)) +for (aware, dType) in ((MPI_CUDA_AWARE, CUDADevice), (MPI_ROCM_AWARE, AMDGPUDevice)) if !aware @eval begin function DistributedUtils.__allreduce!( backend::MPIBackend, sendrecvbuf, op::F, dev::$dType) where {F} - sendrecvbuf_ = sendrecvbuf |> cpu - DistributedUtils.__allreduce!(backend, sendrecvbuf_, op, cpu) - sendrecvbuf |> gpu - return sendrecvbuf + cdev = cpu_device() + sendrecvbuf_ = sendrecvbuf |> cdev + DistributedUtils.__allreduce!(backend, sendrecvbuf_, op, cdev) + copyto!(sendrecvbuf, sendrecvbuf_) + return end function DistributedUtils.__allreduce!( backend::MPIBackend, sendbuf, recvbuf, op::F, dev::$dType) where {F} - sendbuf_ = sendbuf |> cpu - recvbuf_ = recvbuf |> cpu - DistributedUtils.__allreduce!(backend, sendbuf_, recvbuf_, op, cpu) - recvbuf |> gpu - return recvbuf + cdev = cpu_device() + sendbuf_ = sendbuf |> cdev + recvbuf_ = recvbuf |> cdev + DistributedUtils.__allreduce!(backend, sendbuf_, recvbuf_, op, cdev) + copyto!(recvbuf, recvbuf_) + return end end end @@ -143,7 +134,7 @@ end # Reduce function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F, - dev::Union{AbstractDevice, Function}; root::Int) where {F} + dev::AbstractDevice; root::Int) where {F} mpiop = ifelse(op === DistributedUtils.avg, +, op) MPI.Reduce!(sendrecvbuf, mpiop, backend.comm; root) if op === DistributedUtils.avg @@ -153,7 +144,7 @@ function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F, end function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, op::F, - dev::Union{AbstractDevice, Function}; root::Int) where {F} + dev::AbstractDevice; root::Int) where {F} mpiop = ifelse(op === DistributedUtils.avg, +, op) MPI.Reduce!(sendbuf, recvbuf, mpiop, backend.comm; root) if op === DistributedUtils.avg @@ -162,24 +153,26 @@ function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, op::F return recvbuf end -for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice)) +for (aware, dType) in ((MPI_CUDA_AWARE, CUDADevice), (MPI_ROCM_AWARE, AMDGPUDevice)) if !aware @eval begin function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F, dev::$dType; root::Int) where {F} - sendrecvbuf_ = sendrecvbuf |> cpu - DistributedUtils.__reduce!(backend, sendrecvbuf_, op, cpu; root) - sendrecvbuf |> gpu - return sendrecvbuf + cdev = cpu_device() + sendrecvbuf_ = sendrecvbuf |> cdev + DistributedUtils.__reduce!(backend, sendrecvbuf_, op, cdev; root) + copyto!(sendrecvbuf, sendrecvbuf_) + return end function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, op::F, dev::$dType; root::Int) where {F} - sendbuf_ = sendbuf |> cpu - recvbuf_ = recvbuf |> cpu - DistributedUtils.__reduce!(backend, sendbuf_, recvbuf_, op, cpu; root) - recvbuf |> gpu - return recvbuf + cdev = cpu_device() + sendbuf_ = sendbuf |> cdev + recvbuf_ = recvbuf |> cdev + DistributedUtils.__reduce!(backend, sendbuf_, recvbuf_, op, cdev; root) + copyto!(recvbuf, recvbuf_) + return end end end diff --git a/ext/FluxMPINCCLExt/FluxMPINCCLExt.jl b/ext/FluxMPINCCLExt/FluxMPINCCLExt.jl index 754a6c74c6..bed56d775a 100644 --- a/ext/FluxMPINCCLExt/FluxMPINCCLExt.jl +++ b/ext/FluxMPINCCLExt/FluxMPINCCLExt.jl @@ -1,6 +1,7 @@ module FluxMPINCCLExt -using Flux: MPIBackend, NCCLBackend, DistributedUtils, FluxCUDADevice, FluxAMDGPUDevice, AbstractDevice +using Flux: MPIBackend, NCCLBackend, DistributedUtils +using MLDataDevices: AbstractDevice, CUDADevice, AMDGPUDevice, functional, set_device! using MPI: MPI using NCCL: NCCL using Setfield: @set! @@ -35,7 +36,7 @@ DistributedUtils.total_workers(backend::NCCLBackend) = NCCL.size(backend.comm) # For non-CUDA Arrays, fallback to MPI # Broadcast function DistributedUtils.__bcast!( - backend::NCCLBackend, sendrecvbuf::CuArray, ::FluxCUDADevice; root=0) + backend::NCCLBackend, sendrecvbuf::CuArray, ::CUDADevice; root=0) NCCL.Broadcast!(sendrecvbuf, backend.comm; root) return sendrecvbuf end @@ -46,7 +47,7 @@ function DistributedUtils.__bcast!( end function DistributedUtils.__bcast!( - backend::NCCLBackend, sendbuf, recvbuf, ::FluxCUDADevice; root=0) + backend::NCCLBackend, sendbuf, recvbuf, ::CUDADevice; root=0) NCCL.Broadcast!(sendbuf, recvbuf, backend.comm; root) return recvbuf end @@ -58,7 +59,7 @@ end # Allreduce function DistributedUtils.__allreduce!( - backend::NCCLBackend, sendrecvbuf::CuArray, op::F, dev::FluxCUDADevice) where {F} + backend::NCCLBackend, sendrecvbuf::CuArray, op::F, dev::CUDADevice) where {F} op = ifelse(op === DistributedUtils.avg, NCCL.avg, op) NCCL.Allreduce!(sendrecvbuf, op, backend.comm) return sendrecvbuf @@ -70,7 +71,7 @@ function DistributedUtils.__allreduce!( end function DistributedUtils.__allreduce!( - backend::NCCLBackend, sendbuf, recvbuf, op::F, ::FluxCUDADevice) where {F} + backend::NCCLBackend, sendbuf, recvbuf, op::F, ::CUDADevice) where {F} op = ifelse(op === DistributedUtils.avg, NCCL.avg, op) NCCL.Allreduce!(sendbuf, recvbuf, op, backend.comm) return recvbuf @@ -83,7 +84,7 @@ end # Reduce function DistributedUtils.__reduce!( - backend::NCCLBackend, sendrecvbuf, op::F, ::FluxCUDADevice; root::Int) where {F} + backend::NCCLBackend, sendrecvbuf, op::F, ::CUDADevice; root::Int) where {F} op = ifelse(op === DistributedUtils.avg, NCCL.avg, op) NCCL.Reduce!(sendrecvbuf, op, backend.comm; root) return sendrecvbuf @@ -95,7 +96,7 @@ function DistributedUtils.__reduce!(backend::NCCLBackend, sendrecvbuf, op::F, end function DistributedUtils.__reduce!( - backend::NCCLBackend, sendbuf, recvbuf, op::F, ::FluxCUDADevice; root::Int) where {F} + backend::NCCLBackend, sendbuf, recvbuf, op::F, ::CUDADevice; root::Int) where {F} op = ifelse(op === DistributedUtils.avg, NCCL.avg, op) NCCL.Reduce!(sendbuf, recvbuf, op, backend.comm; root) return recvbuf diff --git a/src/distributed/public_api.jl b/src/distributed/public_api.jl index 38176a2e63..26d321814d 100644 --- a/src/distributed/public_api.jl +++ b/src/distributed/public_api.jl @@ -5,7 +5,8 @@ module DistributedUtils using ChainRulesCore: ChainRulesCore -using ..Flux: AbstractFluxDistributedBackend, MPIBackend, NCCLBackend, AbstractDevice, get_device +using ..Flux: AbstractFluxDistributedBackend, MPIBackend, NCCLBackend +using MLDataDevices: get_device, AbstractDevice using Functors: fmap using MLUtils: MLUtils, numobs using Optimisers: Optimisers, AbstractRule, Leaf @@ -99,12 +100,19 @@ Backend Agnostic API to broadcast the given buffer `sendrecvbuf` or `sendbuf` to workers into `recvbuf`. The value at `root` will be broadcasted to all other workers. """ function bcast!(backend::AbstractFluxDistributedBackend, sendrecvbuf; root::Int=0) - return __bcast!(backend, sendrecvbuf, get_device(); root) + return __bcast!(backend, sendrecvbuf, get_device(sendrecvbuf); root) end function bcast!(backend::AbstractFluxDistributedBackend, sendbuf, recvbuf; root::Int=0) - dev = ifelse(get_device() == FluxCPUDevice, cpu, gpu) - return __bcast!(backend, sendbuf, recvbuf, dev; root) + send_dev = get_device(sendbuf) + recv_dev = get_device(recvbuf) + if send_dev == recv_dev + return __bcast!(backend, sendbuf, recvbuf, send_dev; root) + else + sendbuf_ = sendbuf |> recv_dev + @warn "`sendbuf` and `recvbuf` are on different devices." maxlog=1 + return __bcast!(backend, sendbuf_, recvbuf, recv_dev; root) + end end function __bcast! end @@ -129,8 +137,16 @@ end function allreduce!( backend::AbstractFluxDistributedBackend, sendbuf, recvbuf, op::F) where {F} - dev = ifelse(get_device() == FluxCPUDevice, cpu, gpu) - return __allreduce!(backend, sendbuf, recvbuf, op, dev) + send_dev = get_device(sendbuf) + recv_dev = get_device(recvbuf) + if send_dev == recv_dev + __allreduce!(backend, sendbuf, recvbuf, op, send_dev) + else + sendbuf_ = sendbuf |> recv_dev + @warn "`sendbuf` and `recvbuf` are on different devices." maxlog=1 + __allreduce!(backend, sendbuf_, recvbuf, op, recv_dev) + end + return end function __allreduce! end @@ -149,13 +165,21 @@ workers. """ function reduce!( backend::AbstractFluxDistributedBackend, sendrecvbuf, op::F; root::Int=0) where {F} - return __reduce!(backend, sendrecvbuf, op, get_device(); root) + return __reduce!(backend, sendrecvbuf, op, get_device(sendrecvbuf); root) end function reduce!(backend::AbstractFluxDistributedBackend, sendbuf, recvbuf, op::F; root::Int=0) where {F} - dev = ifelse(get_device() == FluxCPUDevice, cpu, gpu) - return __reduce!(backend, sendbuf, recvbuf, op, dev; root) + send_dev = get_device(sendbuf) + recv_dev = get_device(recvbuf) + if send_dev == recv_dev + __reduce!(backend, sendbuf, recvbuf, op, send_dev; root) + else + sendbuf_ = sendbuf |> recv_dev + @warn "`sendbuf` and `recvbuf` are on different devices." maxlog=1 + __reduce!(backend, sendbuf_, recvbuf, op, recv_dev; root) + end + return end function __reduce! end diff --git a/test/ext_amdgpu/get_devices.jl b/test/ext_amdgpu/get_devices.jl index f261f80efd..ecda4a9759 100644 --- a/test/ext_amdgpu/get_devices.jl +++ b/test/ext_amdgpu/get_devices.jl @@ -10,9 +10,9 @@ dense_model = Dense(2 => 3) # initially lives on CPU weight = copy(dense_model.weight) # store the weight bias = copy(dense_model.bias) # store the bias -amdgpu_device = Flux.get_device() +amdgpu_device = gpu_device() -@test typeof(amdgpu_device) <: Flux.FluxAMDGPUDevice +@test typeof(amdgpu_device) <: Flux.AMDGPUDevice @test typeof(amdgpu_device.deviceID) <: AMDGPU.HIPDevice @test Flux._get_device_name(amdgpu_device) in Flux.supported_devices() @@ -37,7 +37,7 @@ for id in 0:(length(AMDGPU.devices()) - 1) @test isequal(Flux.cpu(dense_model.bias), bias) end # finally move to CPU, and see if things work -cpu_device = Flux.get_device("CPU") -dense_model = cpu_device(dense_model) +cdev = cpu_device() +dense_model = cdev(dense_model) @test dense_model.weight isa Matrix @test dense_model.bias isa Vector diff --git a/test/ext_cuda/get_devices.jl b/test/ext_cuda/get_devices.jl index db93f1e514..4eb4b6905b 100644 --- a/test/ext_cuda/get_devices.jl +++ b/test/ext_cuda/get_devices.jl @@ -37,7 +37,7 @@ for id in 0:(length(CUDA.devices()) - 1) @test isequal(Flux.cpu(dense_model.bias), bias) end # finally move to CPU, and see if things work -cpu_device = Flux.get_device("CPU") -dense_model = cpu_device(dense_model) +cdev = cpu_device() +dense_model = cdev(dense_model) @test dense_model.weight isa Matrix @test dense_model.bias isa Vector From 81659c0cb303e7dda118e03b1d44fb4960c135e9 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Oct 2024 09:04:08 +0200 Subject: [PATCH 04/21] docs --- NEWS.md | 2 +- docs/src/guide/gpu.md | 120 ++++++++++++++++++------------------------ test/runtests.jl | 2 +- 3 files changed, 53 insertions(+), 71 deletions(-) diff --git a/NEWS.md b/NEWS.md index 51dde806e9..654ae70c07 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,7 +3,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. ## v0.14.22 -* Data movement between devices is now provided by [MLDataDevice.jl](https://github.com/LuxDL/MLDataDevices.jl). +* Data movement between devices is now provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl). ## v0.14.18 * Add [support for distributed data parallel training](https://github.com/FluxML/Flux.jl/pull/2446). diff --git a/docs/src/guide/gpu.md b/docs/src/guide/gpu.md index e3b05eb6a7..7a022f9fed 100644 --- a/docs/src/guide/gpu.md +++ b/docs/src/guide/gpu.md @@ -232,20 +232,17 @@ More information for conditional use of GPUs in CUDA.jl can be found in its [doc ## Using device objects -///// TODO //// -As a more convenient syntax, Flux allows the usage of GPU `device` objects which can be used to easily transfer models to GPUs (and defaulting to using the CPU if no GPU backend is available). This syntax has a few advantages including automatic selection of the GPU backend and type stability of data movement. To do this, the [`Flux.get_device`](@ref) function can be used. +As a more convenient syntax, Flux allows the usage of GPU `device` objects which can be used to easily transfer models to GPUs (and defaulting to using the CPU if no GPU backend is available). This syntax has a few advantages including automatic selection of the GPU backend and type stability of data movement. +These features are provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, that Flux's uses internally and re-exports. -`Flux.get_device` first checks for a GPU preference, and if possible returns a device for the preference backend. For instance, consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference): +A `device` object can be created using the [`gpu_device`](@ref MLDataDevices.get_device) function. +`gpu_device` first checks for a GPU preference, and if possible returns a device for the preference backend. For instance, consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference): ```julia-repl julia> using Flux, CUDA; -julia> device = Flux.get_device(; verbose=true) # returns handle to an NVIDIA GPU -[ Info: Using backend set in preferences: CUDA. -(::Flux.FluxCUDADevice) (generic function with 1 method) - -julia> device.deviceID # check the id of the GPU -CuDevice(0): NVIDIA GeForce GTX 1650 +julia> device = gpu_device() # returns handle to an NVIDIA GPU if available +(::CUDADevice{Nothing}) (generic function with 4 methods) julia> model = Dense(2 => 3); @@ -263,77 +260,57 @@ julia> model.weight -0.984794 -0.904345 0.720379 -0.486398 0.851011 -0.586942 - ``` -The device preference can also be set via the [`Flux.gpu_backend!`](@ref) function. For instance, below we first set our device preference to `"CPU"`: +The device preference can also be set via the [`gpu_backend!`](@ref MLDataDevices.gpu_backend!) function. For instance, below we first set our device preference to `"AMDGPU"`: ```julia-repl -julia> using Flux; Flux.gpu_backend!("CPU") -┌ Info: New GPU backend set: CPU. -└ Restart your Julia session for this change to take effect! +julia> gpu_backend!("AMDGPU") +[ Info: GPU backend has been set to AMDGPU. Restart Julia to use the new backend. ``` - -Then, after restarting the Julia session, `Flux.get_device` returns a handle to the `"CPU"`: +If no functional GPU backend is available, the device will default to a CPU device. +You can also explictly request a CPU device by calling the [`cpu_device`](@ref MLDataDevices.cpu_device) function. ```julia-repl -julia> using Flux, CUDA; # even if CUDA is loaded, we'll still get a CPU device - -julia> device = Flux.get_device(; verbose=true) # get a CPU device -[ Info: Using backend set in preferences: CPU. -(::Flux.FluxCPUDevice) (generic function with 1 method) +julia> using Flux, MLDataDevices -julia> model = Dense(2 => 3); - -julia> model = model |> device -Dense(2 => 3) # 9 parameters +julia> cdev = cpu_device() +(::CPUDevice{Nothing}) (generic function with 4 methods) -julia> model.weight # no change; model still lives on CPU -3×2 Matrix{Float32}: - -0.942968 0.856258 - 0.440009 0.714106 - -0.419192 -0.471838 -``` -Clearly, this means that the same code will work for any GPU backend and the CPU. +julia> gdev = gpu_device(force=true) # force GPU device, error if no GPU is available +(::CUDADevice{Nothing}) (generic function with 4 methods) -If the preference backend isn't available or isn't functional, then [`Flux.get_device`](@ref) looks for a CUDA, AMDGPU or Metal backend, and returns a corresponding device (if the backend is available and functional). Otherwise, a CPU device is returned. In the below example, the GPU preference is `"CUDA"`: +julia> model = Dense(2 => 3); # model in CPU memory -```julia-repl -julia> using Flux; # preference is CUDA, but CUDA.jl not loaded +julia> gmodel = model |> gdev; # transfer model to GPU -julia> device = Flux.get_device(; verbose=true) # this will resort to automatic device selection -[ Info: Using backend set in preferences: CUDA. -┌ Warning: Trying to use backend: CUDA but it's trigger package is not loaded. -│ Please load the package and call this function again to respect the preferences backend. -└ @ Flux ~/fluxml/Flux.jl/src/functor.jl:637 -[ Info: Using backend: CPU. -(::Flux.FluxCPUDevice) (generic function with 1 method) +julia> cmodel = gmodel |> cdev; # transfer model back to CPU ``` -For detailed information about how the backend is selected, check the documentation for [`Flux.get_device`](@ref). ## Data movement across GPU devices -Flux also supports getting handles to specific GPU devices, and transferring models from one GPU device to another GPU -device from the same backend. Let's try it out for NVIDIA GPUs. First, we list all the available devices: +Flux also supports getting handles to specific GPU devices, and transferring models from one GPU device to another GPU device from the same backend. Let's try it out for NVIDIA GPUs. First, we list all the available devices: ```julia-repl julia> using Flux, CUDA; julia> CUDA.devices() CUDA.DeviceIterator() for 3 devices: -0. GeForce RTX 2080 Ti -1. GeForce RTX 2080 Ti -2. TITAN X (Pascal) - +0. NVIDIA TITAN RTX +1. NVIDIA TITAN RTX +2. NVIDIA TITAN RTX ``` Then, let's select the device with id `0`: ```julia-repl -julia> device0 = Flux.get_device("CUDA", 0) # the currently supported values for backend are "CUDA" and "AMDGPU" -(::Flux.FluxCUDADevice) (generic function with 1 method) +julia> device0 = gpu_device(1) +(::CUDADevice{CuDevice}) (generic function with 4 methods) +julia> device0.device +CuDevice(0): NVIDIA TITAN RTX ``` +Notice that indexing starts from `0` in the `CUDA.devices()` output, but `gpu_device!` expects the device id starting from `1`. Then, let's move a simple dense layer to the GPU represented by `device0`: @@ -344,27 +321,25 @@ Dense(2 => 3) # 9 parameters julia> dense_model = dense_model |> device0; julia> dense_model.weight -3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}: - 0.695662 0.816299 - -0.204763 -0.10232 - -0.955829 0.538412 +3×2 CuArray{Float32, 2, CUDA.DeviceMemory}: + -0.142062 -0.131455 + -0.828134 -1.06552 + 0.608595 -1.05375 julia> CUDA.device(dense_model.weight) # check the GPU to which dense_model is attached -CuDevice(0): GeForce RTX 2080 Ti - +CuDevice(0): NVIDIA TITAN RTX ``` Next, we'll get a handle to the device with id `1`, and move `dense_model` to that device: ```julia-repl -julia> device1 = Flux.get_device("CUDA", 1) -(::Flux.FluxCUDADevice) (generic function with 1 method) +julia> device1 = gpu_device(2) +(::CUDADevice{CuDevice}) (generic function with 4 methods) julia> dense_model = dense_model |> device1; # don't directly print the model; see warning below julia> CUDA.device(dense_model.weight) -CuDevice(1): GeForce RTX 2080 Ti - +CuDevice(1): NVIDIA TITAN RTX ``` Due to a limitation in `Metal.jl`, currently this kind of data movement across devices is only supported for `CUDA` and `AMDGPU` backends. @@ -377,14 +352,21 @@ Due to a limitation in `Metal.jl`, currently this kind of data movement across d ```@docs -Flux.AbstractDevice -Flux.CPUDevice -Flux.CUDADevice -Flux.AMDGPUDevice -Flux.MetalDevice -Flux.supported_gpu_backends -Flux.get_device -Flux.gpu_backend! +MLDataDevices.cpu_device +MLDataDevices.default_device_rng +MLDataDevices.get_device +MLDataDevices.gpu_device +MLDataDevices.gpu_backend! +MLDataDevices.get_device_type +MLDataDevices.reset_gpu_device! +MLDataDevices.supported_gpu_backends +MLDataDevices.CPUDevice +MLDataDevices.CUDADevice +MLDataDevices.AMDGPUDevice +MLDataDevices.MetalDevice +MLDataDevices.oneAPIDevice +MLDataDevices.XLADevice +MLDataDevices.DeviceIterator ``` ## Distributed data parallel training diff --git a/test/runtests.jl b/test/runtests.jl index c2b9f9e28e..9e06bd3919 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,7 @@ using Zygote using Pkg # ENV["FLUX_TEST_AMDGPU"] = "true" -# ENV["FLUX_TEST_CUDA"] = "true" +ENV["FLUX_TEST_CUDA"] = "true" # ENV["FLUX_TEST_METAL"] = "true" # ENV["FLUX_TEST_CPU"] = "false" # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true" From f6019f2c9b100325456e0899068921cfee00bbba Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Oct 2024 09:09:41 +0200 Subject: [PATCH 05/21] docs --- test/ext_enzyme/enzyme.jl | 1 - test/runtests.jl | 9 +++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 25284c5a3f..bae14fd246 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -5,7 +5,6 @@ using Enzyme: Enzyme, make_zero, Active, Duplicated, ReverseWithPrimal using Functors using FiniteDifferences -using CUDA function gradient_fd(f, x...) diff --git a/test/runtests.jl b/test/runtests.jl index 9e06bd3919..e61f9394c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,12 +7,14 @@ using IterTools: ncycle using Zygote using Pkg +## Uncomment below to change the default test settings # ENV["FLUX_TEST_AMDGPU"] = "true" -ENV["FLUX_TEST_CUDA"] = "true" +# ENV["FLUX_TEST_CUDA"] = "true" # ENV["FLUX_TEST_METAL"] = "true" # ENV["FLUX_TEST_CPU"] = "false" # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true" # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true" +ENV["FLUX_TEST_ENZYME"] = "false" include("test_utils.jl") @@ -140,14 +142,13 @@ Random.seed!(0) @info "Skipping Distributed tests, set FLUX_TEST_DISTRIBUTED_MPI or FLUX_TEST_DISTRIBUTED_NCCL=true to run them." end - if get(ENV, "FLUX_TEST_CUDA", "false") == "true" + if get(ENV, "FLUX_TEST_ENZYME", "true") == "true" @testset "Enzyme" begin - Pkg.add(["CUDA", "cuDNN"]) import Enzyme include("ext_enzyme/enzyme.jl") end else - @info "Skipping Enzyme tests, set FLUX_TEST_CUDA=true to run them." + @info "Skipping Enzyme tests, set FLUX_TEST_ENZYME=true to run them." end end From 3680c235bc7ee39fb94dfe9dce7d70352a6a6d56 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Oct 2024 10:38:59 +0200 Subject: [PATCH 06/21] skip enzyme tests --- test/runtests.jl | 2 +- test/train.jl | 118 +++++++++++++++++++++++++---------------------- test/utils.jl | 13 ++++-- 3 files changed, 73 insertions(+), 60 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index e61f9394c0..c48b281c92 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,7 +14,7 @@ using Pkg # ENV["FLUX_TEST_CPU"] = "false" # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true" # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true" -ENV["FLUX_TEST_ENZYME"] = "false" +ENV["FLUX_TEST_ENZYME"] = "false" # We temporarily disable Enzyme tests since they are failing include("test_utils.jl") diff --git a/test/train.jl b/test/train.jl index 4c0c12b1b6..92b25bef5d 100644 --- a/test/train.jl +++ b/test/train.jl @@ -11,72 +11,82 @@ function train_enzyme!(fn, model, args...; kwargs...) end for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) -@testset "Explicit Flux.train! with $name" begin - Random.seed!(84) - w = randn(10, 10) - w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. - @testset for rule in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(), - NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(), - Nesterov(), RMSProp(), Momentum()] - - loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) - model = (weight=copy(w2), bias=zeros(10), ignore=nothing) - @test loss(model, rand(10, 10)) > 1 - - opt = Flux.setup(rule, model) - trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) - @test loss(model, rand(10, 10)) < 0.01 - end - # Test direct use of Optimisers.jl rule, only really OK for `Descent`: - # Enzyme doesn't work with un-initialized atm, presumably due to trainmode? - if name != "Enzyme" - @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()] - loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) - model = (weight=copy(w2), bias=zeros(10), ignore=nothing) - @test loss(model, rand(10, 10)) > 1 - trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) - @test loss(model, rand(10, 10)) < 0.01 - end + if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false") + continue end -end -end -for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) -@testset "Explicit Flux.train! features with $name" begin - @testset "Stop on NaN" begin - m1 = Dense(1 => 1) - m1.weight .= 0 - CNT = Ref(0) - @test_throws DomainError trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i - CNT[] += 1 - (i == 51 ? NaN32 : 1f0) * sum(m([1.0])) + @testset "Explicit Flux.train! with $name" begin + Random.seed!(84) + w = randn(10, 10) + w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. + @testset for rule in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(), + NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(), + Nesterov(), RMSProp(), Momentum()] + + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + + opt = Flux.setup(rule, model) + trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 end - @test CNT[] == 51 # stopped early + + # Test direct use of Optimisers.jl rule, only really OK for `Descent`: + # Enzyme doesn't work with un-initialized atm, presumably due to trainmode? if name != "Enzyme" - @test m1.weight[1] ≈ -5 # did not corrupt weights - else - @test m1.weight[1] ≈ 0.0 # did not corrupt weights + @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()] + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end end end +end - @testset "non-tuple data" begin - w = randn(10, 10) - w2 = randn(10, 10) - loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) - model = (weight=copy(w2), bias=zeros(10)) - opt = Flux.setup(AdamW(), model) - trainfn!(loss, model, (rand(10) for _ in 1: 10^5), opt) - @test loss(model, rand(10, 10)) < 0.01 +for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) + + if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false") + continue end - @testset "callbacks give helpful error" begin - m1 = Dense(1 => 1) - cb = () -> println("this should not be printed") - @test_throws ErrorException trainfn!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) + @testset "Explicit Flux.train! features with $name" begin + @testset "Stop on NaN" begin + m1 = Dense(1 => 1) + m1.weight .= 0 + CNT = Ref(0) + @test_throws DomainError trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i + CNT[] += 1 + (i == 51 ? NaN32 : 1f0) * sum(m([1.0])) + end + @test CNT[] == 51 # stopped early + if name != "Enzyme" + @test m1.weight[1] ≈ -5 # did not corrupt weights + else + @test m1.weight[1] ≈ 0.0 # did not corrupt weights + end + end + + @testset "non-tuple data" begin + w = randn(10, 10) + w2 = randn(10, 10) + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10)) + opt = Flux.setup(AdamW(), model) + trainfn!(loss, model, (rand(10) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end + + @testset "callbacks give helpful error" begin + m1 = Dense(1 => 1) + cb = () -> println("this should not be printed") + @test_throws ErrorException trainfn!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) + end end end -end @testset "Explicit Flux.update! features" begin m = Chain(Dense(2=>3, tanh), Dense(3=>1), only) diff --git a/test/utils.jl b/test/utils.jl index 1910e6ecd5..e05d5f4562 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -273,11 +273,14 @@ end @testset "params gradient" begin m = (x=[1,2.0], y=[3.0]); - # Explicit -- was broken by #2054 - gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1] - @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159] - @test gnew.y ≈ [1.0] - + @test_broken begin + # Explicit -- was broken by #2054 / then fixed / now broken again on julia v0.11 + gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1] + @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159] + @test gnew.y ≈ [1.0] + true + end + # Implicit gold = gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m)) @test gold[m.x] ≈ [0.4472135954999579, 0.8944271909999159] From 2d913782082b58229dd86067c6639a675c17f4f6 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Oct 2024 10:40:51 +0200 Subject: [PATCH 07/21] fix docs --- docs/make.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 6c7b483caa..f0883b6ac8 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,11 +1,11 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics, - DataFrames, JLD2 + DataFrames, JLD2, MLDataDevices DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true) makedocs( - modules = [Flux, NNlib, Functors, MLUtils, Zygote, OneHotArrays, Optimisers, ChainRulesCore], + modules = [Flux, NNlib, Functors, MLUtils, Zygote, OneHotArrays, Optimisers, ChainRulesCore, MLDataDevices], sitename = "Flux", pages = [ "Welcome" => "index.md", From 1bfc0d332cfee275e059fa8c687b5afb9e792faf Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Oct 2024 10:48:39 +0200 Subject: [PATCH 08/21] more enzyme fixes --- test/train.jl | 73 +++++++++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/test/train.jl b/test/train.jl index 92b25bef5d..d2114d5a16 100644 --- a/test/train.jl +++ b/test/train.jl @@ -125,49 +125,54 @@ end end for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) -@testset "L2 regularisation with $name" begin - # New docs claim an exact equivalent. It's a bit long to put the example in there, - # but perhaps the tests should contain it. - - model = Dense(3 => 2, tanh); - init_weight = copy(model.weight); - data = [(randn(Float32, 3,5), randn(Float32, 2,5)) for _ in 1:10]; - - # Take 1: explicitly add a penalty in the loss function - opt = Flux.setup(Adam(0.1), model) - trainfn!(model, data, opt) do m, x, y - err = Flux.mse(m(x), y) - l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 - err + 0.33 * l2 + + if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false") + continue end - diff1 = model.weight .- init_weight + + @testset "L2 regularisation with $name" begin + # New docs claim an exact equivalent. It's a bit long to put the example in there, + # but perhaps the tests should contain it. - # Take 2: the same, but with Flux.params. Was broken for a bit, no tests! - # skipping this test for Enzyme cause implicit params is unsupported - if name == "Zygote" - model.weight .= init_weight - model.bias .= 0 - pen2(x::AbstractArray) = sum(abs2, x)/2 + model = Dense(3 => 2, tanh); + init_weight = copy(model.weight); + data = [(randn(Float32, 3,5), randn(Float32, 2,5)) for _ in 1:10]; + + # Take 1: explicitly add a penalty in the loss function opt = Flux.setup(Adam(0.1), model) trainfn!(model, data, opt) do m, x, y err = Flux.mse(m(x), y) - l2 = sum(pen2, Flux.params(m)) + l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 err + 0.33 * l2 end - diff2 = model.weight .- init_weight - @test diff1 ≈ diff2 - end + diff1 = model.weight .- init_weight + + # Take 2: the same, but with Flux.params. Was broken for a bit, no tests! + # skipping this test for Enzyme cause implicit params is unsupported + if name == "Zygote" + model.weight .= init_weight + model.bias .= 0 + pen2(x::AbstractArray) = sum(abs2, x)/2 + opt = Flux.setup(Adam(0.1), model) + trainfn!(model, data, opt) do m, x, y + err = Flux.mse(m(x), y) + l2 = sum(pen2, Flux.params(m)) + err + 0.33 * l2 + end + diff2 = model.weight .- init_weight + @test diff1 ≈ diff2 + end - # Take 3: using WeightDecay instead. Need the /2 above, to match exactly. - model.weight .= init_weight - model.bias .= 0 - decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model); - trainfn!(model, data, decay_opt) do m, x, y - Flux.mse(m(x), y) + # Take 3: using WeightDecay instead. Need the /2 above, to match exactly. + model.weight .= init_weight + model.bias .= 0 + decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model); + trainfn!(model, data, decay_opt) do m, x, y + Flux.mse(m(x), y) + end + diff3 = model.weight .- init_weight + @test diff1 ≈ diff3 end - diff3 = model.weight .- init_weight - @test diff1 ≈ diff3 -end end @testset "Flux.setup bugs" begin From 2c27af3e14050a836e82753dcebddaafcc20c08b Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 10 Oct 2024 10:54:14 +0200 Subject: [PATCH 09/21] fix metal --- test/ext_metal/get_devices.jl | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/test/ext_metal/get_devices.jl b/test/ext_metal/get_devices.jl index 269f6a4b9c..cdde261fa4 100644 --- a/test/ext_metal/get_devices.jl +++ b/test/ext_metal/get_devices.jl @@ -1,9 +1,7 @@ @testset "get_device()" begin metal_device = Flux.get_device() - @test typeof(metal_device) <: Flux.FluxMetalDevice - @test typeof(metal_device.deviceID) <: Metal.MTLDevice - @test Flux._get_device_name(metal_device) in Flux.supported_devices() + @test typeof(metal_device) <: Flux.MetalDevice # correctness of data transfer x = randn(5, 5) @@ -12,17 +10,13 @@ @test Metal.device(cx).registryID == metal_device.deviceID.registryID end -@testset "get_device(Metal)" begin - metal_device = Flux.get_device("Metal") +@testset "gpu_device()" begin + metal_device = gpu_device() - @test typeof(metal_device) <: Flux.FluxMetalDevice - @test typeof(metal_device.deviceID) <: Metal.MTLDevice - @test Flux._get_device_name(metal_device) in Flux.supported_devices() + @test typeof(metal_device) <: Flux.MetalDevice - metal_device = Flux.get_device("Metal", 0) + metal_device = gpu_device(0) - @test typeof(metal_device) <: Flux.FluxMetalDevice - @test typeof(metal_device.deviceID) <: Metal.MTLDevice - @test Flux._get_device_name(metal_device) in Flux.supported_devices() + @test typeof(metal_device) <: Flux.MetalDevice end From f975c18348f91301d396f6ce4b43803d3f49ace1 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 10 Oct 2024 12:50:30 +0200 Subject: [PATCH 10/21] fix gpu --- src/Flux.jl | 1 + src/devices.jl | 3 ++- src/functor.jl | 2 +- test/ext_amdgpu/get_devices.jl | 7 ------- test/ext_cuda/get_devices.jl | 10 ++-------- test/ext_metal/runtests.jl | 24 +++++++++++++++++++++--- 6 files changed, 27 insertions(+), 20 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index c47dd12d8f..2763e26499 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -104,6 +104,7 @@ include("losses/Losses.jl") using .Losses include("devices.jl") +export get_device # Distributed Training include("distributed/backend.jl") diff --git a/src/devices.jl b/src/devices.jl index cd7a92c3ed..ff54472eb2 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -1,6 +1,7 @@ -# TODO get docstring from MLDataDevices.get_device get_device(x) = MLDataDevices.get_device(x) +@doc (@doc MLDataDevices.get_device) get_device + function (device::MLDataDevices.AbstractDevice)(d::MLUtils.DataLoader) MLUtils.DataLoader(MLUtils.mapobs(device, d.data), d.batchsize, diff --git a/src/functor.jl b/src/functor.jl index de7ba1d0f1..c76646729a 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -192,7 +192,7 @@ _isleaf(::AbstractRNG) = true # the order below is important const GPU_BACKENDS = ("CUDA", "AMDGPU", "Metal", "CPU") const GPU_BACKEND_ORDER = Dict(collect(zip(GPU_BACKENDS, 1:length(GPU_BACKENDS)))) -const GPU_BACKEND = @load_preference("gpu_backend", "CUDA") +const GPU_BACKEND = load_preference(MLDataDevices, "gpu_backend", "CUDA") """ diff --git a/test/ext_amdgpu/get_devices.jl b/test/ext_amdgpu/get_devices.jl index ecda4a9759..7f4d8ccd7a 100644 --- a/test/ext_amdgpu/get_devices.jl +++ b/test/ext_amdgpu/get_devices.jl @@ -3,8 +3,6 @@ amdgpu_device = gpu_device() # should pass, whether or not AMDGPU is functional @test typeof(amdgpu_device) <: Flux.AMDGPUDevice -@test typeof(amdgpu_device.deviceID) <: AMDGPU.HIPDevice - # testing get_device dense_model = Dense(2 => 3) # initially lives on CPU weight = copy(dense_model.weight) # store the weight @@ -13,20 +11,15 @@ bias = copy(dense_model.bias) # store the bias amdgpu_device = gpu_device() @test typeof(amdgpu_device) <: Flux.AMDGPUDevice -@test typeof(amdgpu_device.deviceID) <: AMDGPU.HIPDevice -@test Flux._get_device_name(amdgpu_device) in Flux.supported_devices() # correctness of data transfer x = randn(Float32, 5, 5) cx = x |> amdgpu_device @test cx isa AMDGPU.ROCArray -@test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(amdgpu_device.deviceID) # moving models to specific NVIDIA devices for id in 0:(length(AMDGPU.devices()) - 1) current_amdgpu_device = Flux.get_device("AMDGPU", id) - @test typeof(current_amdgpu_device.deviceID) <: AMDGPU.HIPDevice - @test AMDGPU.device_id(current_amdgpu_device.deviceID) == id + 1 global dense_model = dense_model |> current_amdgpu_device @test dense_model.weight isa AMDGPU.ROCArray diff --git a/test/ext_cuda/get_devices.jl b/test/ext_cuda/get_devices.jl index 4eb4b6905b..2f4ea3bd98 100644 --- a/test/ext_cuda/get_devices.jl +++ b/test/ext_cuda/get_devices.jl @@ -3,8 +3,6 @@ cuda_device = gpu_device() # should pass, whether or not CUDA is functional @test typeof(cuda_device) <: Flux.CUDADevice -@test typeof(cuda_device.deviceID) <: CUDA.CuDevice - # testing get_device dense_model = Dense(2 => 3) # initially lives on CPU weight = copy(dense_model.weight) # store the weight @@ -13,20 +11,16 @@ bias = copy(dense_model.bias) # store the bias cuda_device = Flux.get_device() @test typeof(cuda_device) <: Flux.CUDADevice -@test typeof(cuda_device.deviceID) <: CUDA.CuDevice -@test Flux._get_device_name(cuda_device) in Flux.supported_devices() # correctness of data transfer x = randn(5, 5) cx = x |> cuda_device @test cx isa CUDA.CuArray -@test CUDA.device(cx).handle == cuda_device.deviceID.handle # moving models to specific NVIDIA devices for id in 0:(length(CUDA.devices()) - 1) - current_cuda_device = Flux.get_device("CUDA", id) - @test typeof(current_cuda_device.deviceID) <: CUDA.CuDevice - @test current_cuda_device.deviceID.handle == id + current_cuda_device = gpu_device(id+1) + @test typeof(current_cuda_device) <: Flux.CUDADevice global dense_model = dense_model |> current_cuda_device @test dense_model.weight isa CUDA.CuArray diff --git a/test/ext_metal/runtests.jl b/test/ext_metal/runtests.jl index cb77bce3b8..8c8af7d896 100644 --- a/test/ext_metal/runtests.jl +++ b/test/ext_metal/runtests.jl @@ -5,11 +5,29 @@ using Random, Statistics using Zygote Flux.gpu_backend!("Metal") # needs a restart -# include("../test_utils.jl") include("test_utils.jl") -@testset "get_devices" begin - include("get_devices.jl") +@testset "data movement" begin + metal_device = Flux.gpu_device() + cdev = cpu_device() + + @test metal_device isa Flux.MetalDevice + + x = randn(Float32, 5, 5) + cx = x |> metal_device + @test cx isa Metal.MtlMatrix{Float32} + x2 = cx |> cdev + @test x2 isa Matrix{Float32} + @test x ≈ x2 + + metal_device = gpu_device(1) + @test metal_device isa Flux.MetalDevice + + @test cpu(cx) isa Matrix{Float32} + @test cpu(cx) ≈ x + + @test gpu(x) isa Metal.MtlMatrix{Float32} + @test cpu(gpu(x)) ≈ x end @testset "Basic" begin From 3b3e50e6e7b598845ff5f025d3a9486a1cada5e9 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 10 Oct 2024 13:13:07 +0200 Subject: [PATCH 11/21] doc project --- docs/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/Project.toml b/docs/Project.toml index 1990368231..731bc7e84a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" From f50c62cb13c68fd9b979fed427d29821ed3947d5 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 10 Oct 2024 13:15:46 +0200 Subject: [PATCH 12/21] fix buildkite preference --- .buildkite/pipeline.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index f55033e4cf..553fb318c5 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -46,7 +46,7 @@ steps: using Pkg Pkg.resolve()' commands: | - printf "[Flux]\ngpu_backend = \"Metal\"" > LocalPreferences.toml + printf "[MLDataDevices]\ngpu_backend = \"Metal\"" > LocalPreferences.toml if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 @@ -74,7 +74,7 @@ steps: rocm: "*" rocmgpu: "*" commands: | - printf "[Flux]\ngpu_backend = \"AMDGPU\"" > LocalPreferences.toml + printf "[MLDataDevices]\ngpu_backend = \"AMDGPU\"" > LocalPreferences.toml timeout_in_minutes: 60 env: JULIA_AMDGPU_CORE_MUST_LOAD: "1" From e5ab99b627c8778cd076a3cfa3be963ed88eb8db Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 10 Oct 2024 14:12:33 +0200 Subject: [PATCH 13/21] fix docs --- docs/src/guide/gpu.md | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/docs/src/guide/gpu.md b/docs/src/guide/gpu.md index 7a022f9fed..8a08b47986 100644 --- a/docs/src/guide/gpu.md +++ b/docs/src/guide/gpu.md @@ -235,7 +235,7 @@ More information for conditional use of GPUs in CUDA.jl can be found in its [doc As a more convenient syntax, Flux allows the usage of GPU `device` objects which can be used to easily transfer models to GPUs (and defaulting to using the CPU if no GPU backend is available). This syntax has a few advantages including automatic selection of the GPU backend and type stability of data movement. These features are provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, that Flux's uses internally and re-exports. -A `device` object can be created using the [`gpu_device`](@ref MLDataDevices.get_device) function. +A `device` object can be created using the [`gpu_device`](@ref MLDataDevices.gpu_device) function. `gpu_device` first checks for a GPU preference, and if possible returns a device for the preference backend. For instance, consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference): ```julia-repl @@ -360,12 +360,6 @@ MLDataDevices.gpu_backend! MLDataDevices.get_device_type MLDataDevices.reset_gpu_device! MLDataDevices.supported_gpu_backends -MLDataDevices.CPUDevice -MLDataDevices.CUDADevice -MLDataDevices.AMDGPUDevice -MLDataDevices.MetalDevice -MLDataDevices.oneAPIDevice -MLDataDevices.XLADevice MLDataDevices.DeviceIterator ``` From 5acb42a1b29cce96394b1b57bc874cc49c37e799 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 10 Oct 2024 14:36:51 +0200 Subject: [PATCH 14/21] fix docs --- docs/src/guide/models/recurrence.md | 2 +- docs/src/guide/saving.md | 4 ++-- src/layers/basic.jl | 2 +- src/layers/macro.jl | 2 +- src/layers/recurrent.jl | 8 ++++---- src/utils.jl | 4 ++-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/src/guide/models/recurrence.md b/docs/src/guide/models/recurrence.md index a93b0dc258..6cadef0402 100644 --- a/docs/src/guide/models/recurrence.md +++ b/docs/src/guide/models/recurrence.md @@ -86,7 +86,7 @@ Chain( ), Dense(5 => 1), # 6 parameters ) # Total: 6 trainable arrays, 51 parameters, - # plus 1 non-trainable, 5 parameters, summarysize 580 bytes. + # plus 1 non-trainable, 5 parameters, summarysize 540 bytes. ``` In this example, each output has only one component. diff --git a/docs/src/guide/saving.md b/docs/src/guide/saving.md index 0b1e4fc91b..fb00454eec 100644 --- a/docs/src/guide/saving.md +++ b/docs/src/guide/saving.md @@ -62,7 +62,7 @@ julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2)) Chain( Dense(10 => 5, relu), # 55 parameters Dense(5 => 2), # 12 parameters -) # Total: 4 arrays, 67 parameters, 524 bytes. +) # Total: 4 arrays, 67 parameters, 476 bytes. julia> for epoch in 1:10 # ... train model ... @@ -131,7 +131,7 @@ julia> model Chain( Dense(10 => 5, relu), # 55 parameters Dense(5 => 2), # 12 parameters -) # Total: 4 arrays, 67 parameters, 524 bytes. +) # Total: 4 arrays, 67 parameters, 476 bytes. ``` !!! warning Saving models this way could lead to compatibility issues across julia versions diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 58a42f9b2b..43d995dca0 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -298,7 +298,7 @@ Maxout( Dense(5 => 7, tanh), # 42 parameters Dense(5 => 7, tanh), # 42 parameters Dense(5 => 7, tanh), # 42 parameters -) # Total: 6 arrays, 126 parameters, 888 bytes. +) # Total: 6 arrays, 126 parameters, 816 bytes. julia> Flux.outputsize(m3, (5, 11)) (7, 11) diff --git a/src/layers/macro.jl b/src/layers/macro.jl index dcebe551e3..065774602a 100644 --- a/src/layers/macro.jl +++ b/src/layers/macro.jl @@ -43,7 +43,7 @@ Trio( Dense(2 => 1, tanh), # 3 parameters Dense(1 => 1; bias=false), # 1 parameters Dropout(0.4), -) # Total: 3 arrays, 4 parameters, 224 bytes. +) # Total: 3 arrays, 4 parameters, 240 bytes. ``` """ diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index f55ebb1741..931eed65ca 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -232,7 +232,7 @@ julia> r = RNN(3 => 5) Recur( RNNCell(3 => 5, tanh), # 50 parameters ) # Total: 4 trainable arrays, 50 parameters, - # plus 1 non-trainable, 5 parameters, summarysize 432 bytes. + # plus 1 non-trainable, 5 parameters, summarysize 424 bytes. julia> r(rand(Float32, 3)) |> size (5,) @@ -341,7 +341,7 @@ julia> l = LSTM(3 => 5) Recur( LSTMCell(3 => 5), # 190 parameters ) # Total: 5 trainable arrays, 190 parameters, - # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB. + # plus 2 non-trainable, 10 parameters, summarysize 1.023 KiB. julia> l(rand(Float32, 3)) |> size (5,) @@ -415,7 +415,7 @@ julia> g = GRU(3 => 5) Recur( GRUCell(3 => 5), # 140 parameters ) # Total: 4 trainable arrays, 140 parameters, - # plus 1 non-trainable, 5 parameters, summarysize 792 bytes. + # plus 1 non-trainable, 5 parameters, summarysize 784 bytes. julia> g(rand(Float32, 3)) |> size (5,) @@ -485,7 +485,7 @@ julia> g = GRUv3(3 => 5) Recur( GRUv3Cell(3 => 5), # 140 parameters ) # Total: 5 trainable arrays, 140 parameters, - # plus 1 non-trainable, 5 parameters, summarysize 848 bytes. + # plus 1 non-trainable, 5 parameters, summarysize 840 bytes. julia> g(rand(Float32, 3)) |> size (5,) diff --git a/src/utils.jl b/src/utils.jl index 8fa3889a11..09f01fc715 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -186,7 +186,7 @@ julia> round(std(Flux.kaiming_normal(10, 1000)), digits=3) 0.044f0 julia> round(std(Flux.kaiming_normal(1000, 10)), digits=3) -0.449f0 +0.45f0 julia> round(std(Flux.kaiming_normal(1000, 1000)), digits=3) 0.045f0 @@ -590,7 +590,7 @@ Chain( ), Dense(64 => 10), # 650 parameters ) # Total: 6 trainable arrays, 51_018 parameters, - # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB. + # plus 2 non-trainable, 128 parameters, summarysize 200.211 KiB. julia> Flux.modules(m2) 7-element Vector{Any}: From 7e964ed5b13c75afde9170984100d1c3d2eff29a Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Oct 2024 03:48:15 +0200 Subject: [PATCH 15/21] fix docs --- docs/src/guide/models/recurrence.md | 2 +- src/layers/basic.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/guide/models/recurrence.md b/docs/src/guide/models/recurrence.md index 6cadef0402..7827062f22 100644 --- a/docs/src/guide/models/recurrence.md +++ b/docs/src/guide/models/recurrence.md @@ -71,7 +71,7 @@ julia> RNN(2, 5) # or equivalently RNN(2 => 5) Recur( RNNCell(2 => 5, tanh), # 45 parameters ) # Total: 4 trainable arrays, 45 parameters, - # plus 1 non-trainable, 5 parameters, summarysize 412 bytes. + # plus 1 non-trainable, 5 parameters, summarysize 404 bytes. ``` Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available. diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 43d995dca0..254f06db0c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -499,7 +499,7 @@ Parallel( +, α = Dense(10 => 2, tanh), # 22 parameters β = Dense(5 => 2), # 12 parameters -) # Total: 4 arrays, 34 parameters, 392 bytes. +) # Total: 4 arrays, 34 parameters, 344 bytes. julia> model2(rand32(10), rand32(5)) |> size (2,) From 56c2485ea59288ce7b4d53f59b6d50c78999a300 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Oct 2024 03:50:42 +0200 Subject: [PATCH 16/21] fix docs --- docs/src/reference/destructure.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/src/reference/destructure.md b/docs/src/reference/destructure.md index 2071b5466b..469a1465b1 100644 --- a/docs/src/reference/destructure.md +++ b/docs/src/reference/destructure.md @@ -94,4 +94,5 @@ Flux.loadmodel! Functors.KeyPath Functors.getkeypath Functors.haskeypath -``` \ No newline at end of file +Functors.setkeypath! +``` From 2fc826f9477b4d75ffc5803d74f485099171990a Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Oct 2024 03:57:13 +0200 Subject: [PATCH 17/21] some tests are broken --- test/layers/normalisation.jl | 2 +- test/train.jl | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 35f11a4adc..6c1b78919f 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -446,7 +446,7 @@ end @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3) m2 = Chain(BatchNorm(3), sum) - @test Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6) + @test_broken Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6) end @testset "ForwardDiff" begin diff --git a/test/train.jl b/test/train.jl index d2114d5a16..436a4fca69 100644 --- a/test/train.jl +++ b/test/train.jl @@ -154,11 +154,15 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) model.bias .= 0 pen2(x::AbstractArray) = sum(abs2, x)/2 opt = Flux.setup(Adam(0.1), model) - trainfn!(model, data, opt) do m, x, y - err = Flux.mse(m(x), y) - l2 = sum(pen2, Flux.params(m)) - err + 0.33 * l2 + + @test_broken begin + trainfn!(model, data, opt) do m, x, y + err = Flux.mse(m(x), y) + l2 = sum(pen2, Flux.params(m)) + err + 0.33 * l2 + end end + diff2 = model.weight .- init_weight @test diff1 ≈ diff2 end From 1a1e4cb7f7d5e37b85634f9c047afb6e5c4f1b6f Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Oct 2024 03:58:17 +0200 Subject: [PATCH 18/21] cleanup --- test/ext_metal/get_devices.jl | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 test/ext_metal/get_devices.jl diff --git a/test/ext_metal/get_devices.jl b/test/ext_metal/get_devices.jl deleted file mode 100644 index cdde261fa4..0000000000 --- a/test/ext_metal/get_devices.jl +++ /dev/null @@ -1,22 +0,0 @@ -@testset "get_device()" begin - metal_device = Flux.get_device() - - @test typeof(metal_device) <: Flux.MetalDevice - - # correctness of data transfer - x = randn(5, 5) - cx = x |> metal_device - @test cx isa Metal.MtlArray - @test Metal.device(cx).registryID == metal_device.deviceID.registryID -end - -@testset "gpu_device()" begin - metal_device = gpu_device() - - @test typeof(metal_device) <: Flux.MetalDevice - - metal_device = gpu_device(0) - - @test typeof(metal_device) <: Flux.MetalDevice -end - From c62a0e821fbf1116daf27204b6189258ff92b0bf Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Oct 2024 04:26:34 +0200 Subject: [PATCH 19/21] fix tests --- .buildkite/pipeline.yml | 4 ++-- test/train.jl | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 553fb318c5..961c056dfc 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -26,10 +26,10 @@ steps: # cuda: "*" # timeout_in_minutes: 60 - - label: "Metal with julia {{matrix.julia}}" + - label: "Metal with julia v1" plugins: - JuliaCI/julia#v1: - version: "{{matrix.julia}}" + version: "1" - JuliaCI/julia-test#v1: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: diff --git a/test/train.jl b/test/train.jl index 436a4fca69..96bd0d22a4 100644 --- a/test/train.jl +++ b/test/train.jl @@ -161,10 +161,12 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) l2 = sum(pen2, Flux.params(m)) err + 0.33 * l2 end + + diff2 = model.weight .- init_weight + @test diff1 ≈ diff2 + + true end - - diff2 = model.weight .- init_weight - @test diff1 ≈ diff2 end # Take 3: using WeightDecay instead. Need the /2 above, to match exactly. From 29129eacd394c024cfdf0955ea46120dcaae48f8 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Oct 2024 04:35:03 +0200 Subject: [PATCH 20/21] buildkite --- .buildkite/pipeline.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 961c056dfc..c9b4e450ac 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -46,7 +46,7 @@ steps: using Pkg Pkg.resolve()' commands: | - printf "[MLDataDevices]\ngpu_backend = \"Metal\"" > LocalPreferences.toml + printf "[MLDataDevices]\ngpu_backend = \"Metal\"\n" > LocalPreferences.toml if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 @@ -74,7 +74,7 @@ steps: rocm: "*" rocmgpu: "*" commands: | - printf "[MLDataDevices]\ngpu_backend = \"AMDGPU\"" > LocalPreferences.toml + printf "[MLDataDevices]\ngpu_backend = \"AMDGPU\"\n" > LocalPreferences.toml timeout_in_minutes: 60 env: JULIA_AMDGPU_CORE_MUST_LOAD: "1" From 05a0eb4e29352e3151c0c206e4345a2f4653cacc Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Oct 2024 16:03:38 +0200 Subject: [PATCH 21/21] rework rng_from_array --- ext/FluxCUDAExt/FluxCUDAExt.jl | 1 - ext/FluxCUDAExt/utils.jl | 1 - src/utils.jl | 10 +++------- 3 files changed, 3 insertions(+), 9 deletions(-) delete mode 100644 ext/FluxCUDAExt/utils.jl diff --git a/ext/FluxCUDAExt/FluxCUDAExt.jl b/ext/FluxCUDAExt/FluxCUDAExt.jl index 0a0f67adc7..9f0dae1aa9 100644 --- a/ext/FluxCUDAExt/FluxCUDAExt.jl +++ b/ext/FluxCUDAExt/FluxCUDAExt.jl @@ -32,7 +32,6 @@ end ChainRulesCore.@non_differentiable check_use_cuda() include("functor.jl") -include("utils.jl") function __init__() Flux.CUDA_LOADED[] = true diff --git a/ext/FluxCUDAExt/utils.jl b/ext/FluxCUDAExt/utils.jl deleted file mode 100644 index 07500e9eb9..0000000000 --- a/ext/FluxCUDAExt/utils.jl +++ /dev/null @@ -1 +0,0 @@ -Flux.rng_from_array(::CuArray) = CUDA.default_rng() diff --git a/src/utils.jl b/src/utils.jl index 09f01fc715..6077544178 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -37,14 +37,10 @@ epseltype(x) = eps(float(eltype(x))) rng_from_array(x) Create an instance of the RNG most appropriate for `x`. -The current defaults are: -- `x isa CuArray`: `CUDA.default_rng()` -- `x isa AbstractArray`: `Random.default_rng() +As an example, if `x` is a`CuArray`, it will return a `CUDA.default_rng()`. +If `x` is an `Array` instead, it will return a `Random.default_rng()`. """ -rng_from_array(::AbstractArray) = Random.default_rng() - -@non_differentiable rng_from_array(::Any) - +rng_from_array(x) = MLDataDevices.default_device_rng(MLDataDevices.get_device(x)) """ glorot_uniform([rng], size...; gain = 1) -> Array