From 50be8d2fad7638aed857da0a85d522fd6c2db33a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 29 Oct 2024 23:14:36 -0400 Subject: [PATCH 1/2] re-write the start of GPU docs --- docs/src/guide/gpu.md | 206 ++++++++++++++++++++++-------------------- 1 file changed, 108 insertions(+), 98 deletions(-) diff --git a/docs/src/guide/gpu.md b/docs/src/guide/gpu.md index 57d87c57b7..79e8c6a61f 100644 --- a/docs/src/guide/gpu.md +++ b/docs/src/guide/gpu.md @@ -1,140 +1,150 @@ # GPU Support -Starting with v0.14, Flux doesn't force a specific GPU backend and the corresponding package dependencies on the users. -Thanks to the [package extension mechanism](https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) introduced in julia v1.9, Flux conditionally loads GPU specific code once a GPU package is made available (e.g. through `using CUDA`). +Most work on neural networks involves the use of GPUs, as they can typically perform the required computation much faster. +This page describes how Flux co-operates with various other packages, which talk to GPU hardware. -NVIDIA GPU support requires the packages `CUDA.jl` and `cuDNN.jl` to be installed in the environment. In the julia REPL, type `] add CUDA, cuDNN` to install them. For more details see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) readme. +## Basic GPU use: from `Array` to `CuArray` with `cu` -AMD GPU support is available since Julia 1.9 on systems with ROCm and MIOpen installed. For more details refer to the [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl) repository. +Julia's GPU packages work with special array types, in place of the built-in `Array`. +The most used is `CuArray` provided by [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), for GPUs made by NVIDIA. +That package provides a function `cu` which converts an ordinary `Array` (living in CPu memory) to a `CuArray` (living in GPU memory). +Functions like `*` and broadcasting specialise so that, when given `CuArray`s, all the computation happens on the GPU: -Metal GPU acceleration is available on Apple Silicon hardware. For more details refer to the [Metal.jl](https://github.com/JuliaGPU/Metal.jl) repository. Metal support in Flux is experimental and many features are not yet available. +```julia +W = randn(3, 4) # some weights, on CPU: 3×4 Array{Float64, 2} +x = randn(4) # fake data +y = tanh.(W * x) # computation on the CPU -In order to trigger GPU support in Flux, you need to call `using CUDA`, `using AMDGPU` or `using Metal` -in your code. Notice that for CUDA, explicitly loading also `cuDNN` is not required, but the package has to be installed in the environment. +using CUDA +cu(W) isa CuArray{Float32} +(cW, cx) = (W, x) |> cu # move both to GPU +cy = tanh.(cW * cx) # computation on the GPU +``` -!!! compat "Flux ≤ 0.13" - Old versions of Flux automatically installed CUDA.jl to provide GPU support. Starting from Flux v0.14, CUDA.jl is not a dependency anymore and has to be installed manually. +Notice that `cu` doesn't only move arrays, it also recurses into many structures, such as the tuple `(W, x)` above. +(Notice also that it converts Julia's default `Float64` numbers to `Float32`, as this is what most GPUs support efficiently -- it calls itself "opinionated". Flux defaults to `Float32` in all cases.) +To use CUDA with Flux, you can simply use `cu` to move both the model, and the data. +It will create a copy of the Flux model, with all of its parameter arrays moved to the GPU: -## Basic GPU Usage +```julia +using Pkg; Pkg.add(["CUDA", "cuDNN"]) # do this once -Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl), and [Metal.jl](https://github.com/JuliaGPU/Metal.jl). -Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it. +using Flux, CUDA +CUDA.allowscalar(false) # recommended -For example, we can use `CUDA.CuArray` (with the `CUDA.cu` converter) to run our [basic example](@ref man-basics) on an NVIDIA GPU. +model = Dense(W, true, tanh) # wrap the same matrix W +model(x) ≈ y # same result, still on CPU -(Note that you need to have CUDA available to use CUDA.CuArray – please see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) instructions for more details.) +c_model = cu(model) # move all the arrays within model to the GPU +c_model(cx) # computation on the GPU +``` + +Notice that you need `using CUDA` (every time) but also `] add cuDNN` (once, when installing packages). +This is a quirk of how these packages are set up. + +Flux's `gradient`, and training functions like `setup`, `update!`, and `train!`, are all equally happy to accept GPU arrays and GPU models, and then perform all computations on the GPU. +It is recommended that you move the model to the GPU before calling `setup`. ```julia -using CUDA +Flux.gradient((f,x) -> sum(abs2, f(x)), model, x) -W = cu(rand(2, 5)) # a 2×5 CuArray -b = cu(rand(2)) +c_grads = Flux.gradient((f,x) -> sum(abs2, f(x)), c_model, cx) # same result, all on GPU -predict(x) = W*x .+ b -loss(x, y) = sum((predict(x) .- y).^2) +c_opt = Flux.setup(Adam(), c_model) # setup optimiser after moving model to GPU -x, y = cu(rand(5)), cu(rand(2)) # Dummy data -loss(x, y) # ~ 3 +Flux.update!(c_opt, c_model, c_grads[1]) # mutates c_model but not model ``` -Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before. - -If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once. +To move arrays and other objects back to the CPU, Flux provides a function `cpu`. +This is recommended when saving models, `Flux.state(c_model |> cpu)`, see below. ```julia -d = Dense(10 => 5, σ) -d = fmap(cu, d) -d.weight # CuArray -d(cu(rand(10))) # CuArray output - -m = Chain(Dense(10 => 5, σ), Dense(5 => 2), softmax) -m = fmap(cu, m) -m(cu(rand(10))) +cpu(cW) isa Array{Float32, 2} + +model2 = cpu(c_model) # copy model back to CPU +model2(x) ``` -As a convenience, Flux provides the `gpu` function to convert models and data to the GPU if one is available. By default, it'll do nothing. So, you can safely call `gpu` on some data or model (as shown below), and the code will not error, regardless of whether the GPU is available or not. If a GPU library (e.g. CUDA) loads successfully, `gpu` will move data from the CPU to the GPU. As is shown below, this will change the type of something like a regular array to a `CuArray`. +!!! compat "Flux ≤ 0.13" + Old versions of Flux automatically loaded CUDA.jl to provide GPU support. Starting from Flux v0.14, it has to be loaded separately. Julia's [package extensions](https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) allow Flux to automatically load some GPU-specific code when needed. + +## Other GPU packages for AMD & Apple + +Non-NVIDIA graphics cards are supported by other packages. Each provides its own function which behaves like `cu`. +AMD GPU support provided by [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl), on systems with ROCm and MIOpen installed. +This package has a function `roc` which converts `Array` to `ROCArray`: ```julia -julia> using Flux, CUDA - -julia> m = Dense(10, 5) |> gpu -Dense(10 => 5) # 55 parameters - -julia> x = rand(10) |> gpu -10-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}: - 0.066846445 - ⋮ - 0.76706964 - -julia> m(x) -5-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}: - -0.99992573 - ⋮ - -0.547261 +using Flux, AMDGPU +AMDGPU.allowscalar(false) + +r_model = roc(model) +r_model(roc(x)) + +Flux.gradient((f,x) -> sum(abs2, f(x)), r_model, roc(x)) ``` -The analogue `cpu` is also available for moving models and data back off of the GPU. +Experimental support for apple devices with M-series chips is provided by [Metal.jl](https://github.com/JuliaGPU/Metal.jl). This has a function [`mtl`](https://metal.juliagpu.org/stable/api/array/#Metal.mtl) which works like `cu`, converting `Array` to `MtlArray`: ```julia -julia> x = rand(10) |> gpu -10-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}: - 0.8019236 - ⋮ - 0.7766742 - -julia> x |> cpu -10-element Vector{Float32}: - 0.8019236 - ⋮ - 0.7766742 -``` +using Flux, Metal +Metal.allowscalar(false) -## Using device objects +m_model = mtl(model) +m_y = m_model(mtl(x)) -In Flux, you can create `device` objects which can be used to easily transfer models and data to GPUs (and defaulting to using the CPU if no GPU backend is available). -These features are provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, that Flux uses internally and re-exports. +Flux.gradient((f,x) -> sum(abs2, f(x)), m_model, mtl(x)) +``` -Device objects can be automatically created using the [`cpu_device`](@ref MLDataDevices.cpu_device) and [`gpu_device`](@ref MLDataDevices.gpu_device) functions. For instance, the `gpu` and `cpu` functions are just convenience functions defined as +!!! warn "Experimental" + Metal support in Flux is experimental and many features are not yet available. + AMD support is improving, but likely to have more rough edges than CUDA. + +If you want your model to work with any brand of GPU, or none, then you may not wish to write `cu` everywhere. +One simple way to be generic is, at the top of the file, to un-comment one of several lines which import a package and assign its "adaptor" to the same name: ```julia -cpu(x) = cpu_device()(x) -gpu(x) = gpu_device()(x) +using CUDA: cu as device # after this, `device === cu` +# using AMDGPU: roc as device +# device = identity # do-nothing, for CPU + +using Flux +model = Chain(...) |> device ``` -`gpu_device` performs automatic GPU device selection and returns a device object: -- If no GPU is available, it returns a `CPUDevice` object. -- If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use `Flux.gpu_backend!()`. If the trigger package corresponding to the device is not loaded (e.g. with `using CUDA`), then a warning is displayed. -- If no LocalPreferences option is present, then the first working GPU with loaded trigger package is used. +!!! note "Adapt.jl" + The functions `cu`, `mtl`, `roc` all use [Adapt.jl](https://github.com/JuliaGPU/Adapt.jl), to work within various wrappers. + The reason they work on Flux models is that `Flux.@layer Layer` defines methods of `Adapt.adapt_structure(to, lay::Layer)`. -Consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference): -```julia-repl -julia> using Flux, CUDA; - -julia> device = gpu_device() # returns handle to an NVIDIA GPU if available -(::CUDADevice{Nothing}) (generic function with 4 methods) +## Automatic GPU choice with `gpu` -julia> model = Dense(2 => 3); +Flux also provides a more automatic way of choosing which GPU (or none) to use. This is the function `gpu`: +* By default it does nothing. +* If the package CUDA is loaded, and `CUDA.functional() === true`, then it behaves like `cu`. +* If the package AMDGPU is loaded, and `AMDGPU.functional() === true`, then it behaves like `roc`. +* If the package Metal is loaded, and `Metal.functional() === true`, then it behaves like `mtl`. +* If two differnet GPU packages are loaded, the first one takes priority. -julia> model.weight # the model initially lives in CPU memory -3×2 Matrix{Float32}: - -0.984794 -0.904345 - 0.720379 -0.486398 - 0.851011 -0.586942 +For the most part, this means that a script which says `model |> gpu` and `data |> gpu` will just work. +It should always run, and if a GPU package is loaded (and finds the correct hardware) then that will be used. -julia> model = model |> device # transfer model to the GPU -Dense(2 => 3) # 9 parameters +The function `gpu` uses a lower-level function called `get_device()` from [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl), +which checks what to do & then returns some device object. In fact, the entire implementation is just this: -julia> model.weight -3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}: - -0.984794 -0.904345 - 0.720379 -0.486398 - 0.851011 -0.586942 +```julia +gpu(x) = gpu_device()(x) +cpu(x) = cpu_device()(x) ``` +## Manually selecting devices + +?? + + ## Transferring Training Data In order to train the model using the GPU both model and the training data have to be transferred to GPU memory. Moving the data can be done in two different ways: @@ -178,7 +188,7 @@ In order to train the model using the GPU both model and the training data have ## Saving GPU-Trained Models -After the training process is done, one must always transfer the trained model back to the `cpu` memory scope before serializing or saving to disk. This can be done, as described in the previous section, with: +After the training process is done, we must always transfer the trained model back to the CPU memory before serializing or saving to disk. This can be done with `cpu`: ```julia model = cpu(model) # or model = model |> cpu ``` @@ -280,11 +290,11 @@ Due to a limitation in `Metal.jl`, currently this kind of data movement across d ## Distributed data parallel training !!! danger "Experimental" - + Distributed support is experimental and could change in the future. -Flux supports now distributed data parallel training with `DistributedUtils` module. +Flux supports now distributed data parallel training with `DistributedUtils` module. If you want to run your code on multiple GPUs, you have to install `MPI.jl` (see [docs](https://juliaparallel.org/MPI.jl/stable/usage/) for more info). ```julia-repl @@ -352,7 +362,7 @@ DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam( julia> st_opt = Optimisers.setup(opt, model) (layers = ((weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0], Float32[0.0], (0.9, 0.999))), σ = ())),) -julia> st_opt = DistributedUtils.synchronize!!(backend, st_opt; root=0) +julia> st_opt = DistributedUtils.synchronize!!(backend, st_opt; root=0) (layers = ((weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0], Float32[0.0], (0.9, 0.999))), σ = ())),) ``` @@ -376,7 +386,7 @@ Epoch 3: Loss 0.012763695 Remember that in order to run it on multiple GPUs you have to run from CLI `mpiexecjl --project=. -n julia .jl`, where `` is the number of processes that you want to use. The number of processes usually corresponds to the number of gpus. -By default `MPI.jl` MPI installation is CUDA-unaware so if you want to run it in CUDA-aware mode, read more [here](https://juliaparallel.org/MPI.jl/stable/usage/#CUDA-aware-MPI-support) on custom installation and rebuilding `MPI.jl`. +By default `MPI.jl` MPI installation is CUDA-unaware so if you want to run it in CUDA-aware mode, read more [here](https://juliaparallel.org/MPI.jl/stable/usage/#CUDA-aware-MPI-support) on custom installation and rebuilding `MPI.jl`. Then test if your MPI is CUDA-aware by ```julia-repl julia> import Pkg @@ -390,7 +400,7 @@ julia> set_preferences!("Flux", "FluxDistributedMPICUDAAware" => true) ``` !!! warning "Known shortcomings" - + We don't run CUDA-aware tests so you're running it at own risk. @@ -424,4 +434,4 @@ julia> using Metal julia> Metal.functional() true -``` \ No newline at end of file +``` From d9c03b962f8024ddbb965c2d7437520dfa2bc7b9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 29 Oct 2024 23:51:23 -0400 Subject: [PATCH 2/2] tweaks --- docs/src/guide/gpu.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/src/guide/gpu.md b/docs/src/guide/gpu.md index 79e8c6a61f..6e0ed95b3b 100644 --- a/docs/src/guide/gpu.md +++ b/docs/src/guide/gpu.md @@ -34,7 +34,7 @@ using Pkg; Pkg.add(["CUDA", "cuDNN"]) # do this once using Flux, CUDA CUDA.allowscalar(false) # recommended -model = Dense(W, true, tanh) # wrap the same matrix W +model = Dense(W, true, tanh) # wrap the same matrix W in a Flux layer model(x) ≈ y # same result, still on CPU c_model = cu(model) # move all the arrays within model to the GPU @@ -43,13 +43,13 @@ c_model(cx) # computation on the GPU Notice that you need `using CUDA` (every time) but also `] add cuDNN` (once, when installing packages). This is a quirk of how these packages are set up. +(The [`cuDNN.jl`](https://github.com/JuliaGPU/CUDA.jl/tree/master/lib/cudnn) sub-package handles operations such as convolutions, called by Flux via [NNlib.jl](https://github.com/FluxML/NNlib.jl).) Flux's `gradient`, and training functions like `setup`, `update!`, and `train!`, are all equally happy to accept GPU arrays and GPU models, and then perform all computations on the GPU. It is recommended that you move the model to the GPU before calling `setup`. ```julia -Flux.gradient((f,x) -> sum(abs2, f(x)), model, x) - +grads = Flux.gradient((f,x) -> sum(abs2, f(x)), model, x) # on CPU c_grads = Flux.gradient((f,x) -> sum(abs2, f(x)), c_model, cx) # same result, all on GPU c_opt = Flux.setup(Adam(), c_model) # setup optimiser after moving model to GPU @@ -86,7 +86,7 @@ r_model(roc(x)) Flux.gradient((f,x) -> sum(abs2, f(x)), r_model, roc(x)) ``` -Experimental support for apple devices with M-series chips is provided by [Metal.jl](https://github.com/JuliaGPU/Metal.jl). This has a function [`mtl`](https://metal.juliagpu.org/stable/api/array/#Metal.mtl) which works like `cu`, converting `Array` to `MtlArray`: +Experimental support for Apple devices with M-series chips is provided by [Metal.jl](https://github.com/JuliaGPU/Metal.jl). This has a function [`mtl`](https://metal.juliagpu.org/stable/api/array/#Metal.mtl) which works like `cu`, converting `Array` to `MtlArray`: ```julia using Flux, Metal @@ -98,7 +98,7 @@ m_y = m_model(mtl(x)) Flux.gradient((f,x) -> sum(abs2, f(x)), m_model, mtl(x)) ``` -!!! warn "Experimental" +!!! danger "Experimental" Metal support in Flux is experimental and many features are not yet available. AMD support is improving, but likely to have more rough edges than CUDA. @@ -142,7 +142,7 @@ cpu(x) = cpu_device()(x) ## Manually selecting devices -?? +I thought there was a whole `Flux.gpu_backend!` and Preferences.jl story we had to tell?? ## Transferring Training Data