Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend GPU support to Metal, ROCm, and oneAPI backends #405

Merged
merged 21 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
808d265
Initial Commit
rkierulf May 28, 2024
4c6c567
Initial implementation for supporting multiple GPU backends
rkierulf May 30, 2024
e701616
Fix printing issues and remove unecessary reclaim_gpu function and Pa…
rkierulf May 31, 2024
2549b38
Improve error message for float64 unsupported
rkierulf May 31, 2024
bc827f6
Fix some CPU / formatting issues
rkierulf Jun 3, 2024
dba1845
Add AbstractRange adapt_storage to match Flux and fix issues with sav…
rkierulf Jun 4, 2024
3da2697
Clean up GPU functions
rkierulf Jun 5, 2024
ab5f418
Fix issue with device name
rkierulf Jun 6, 2024
569c628
Remove accidental print statement and change Adapt Compatability to 3…
rkierulf Jun 6, 2024
9562294
Export print_devices() and change to also print CPU information
rkierulf Jun 7, 2024
aa9d391
Update compat versions
rkierulf Jun 7, 2024
adee5f2
Add temporary workaround for Metal cumsum and findall
rkierulf Jun 10, 2024
f729cc9
Add using statements to tests
rkierulf Jun 11, 2024
8b617b9
Allow switching backend between CPU / GPU
rkierulf Jun 11, 2024
04a0604
Merge branch 'master' of https://github.com/JuliaHealth/KomaMRI.jl in…
rkierulf Jun 11, 2024
db21aa5
Move using statements inside test items
rkierulf Jun 11, 2024
d50a108
Remove GPU functionality from main tests
rkierulf Jun 11, 2024
5411e0c
Remove GPU functionality from main tests (Project.toml)
rkierulf Jun 11, 2024
cba4a24
Reformat docstring
rkierulf Jun 12, 2024
b278734
Add print_devices and get_backend to docs
rkierulf Jun 12, 2024
64897c8
Add warning for Metal
rkierulf Jun 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion KomaMRICore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,35 @@ version = "0.8.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KomaMRIBase = "d0bc0b20-b151-4d03-b2a4-6ca51751cb9c"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
KomaAMDGPUExt = "AMDGPU"
KomaCUDAExt = "CUDA"
KomaMetalExt = "Metal"
KomaoneAPIExt = "oneAPI"

[compat]
Adapt = "3, 4"
AMDGPU = "0.9"
CUDA = "3, 4, 5"
Functors = "0.4"
KernelAbstractions = "0.9"
KomaMRIBase = "0.8"
Metal = "1"
oneAPI = "1"
Pkg = "1.4"
ProgressMeter = "1"
Reexport = "1"
Expand Down
24 changes: 24 additions & 0 deletions KomaMRICore/ext/KomaAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module KomaAMDGPUExt

using AMDGPU
import KomaMRICore

KomaMRICore.name(::ROCBackend) = "AMDGPU"
KomaMRICore.isfunctional(::ROCBackend) = AMDGPU.functional()
KomaMRICore.set_device!(::ROCBackend, dev_idx::Integer) = AMDGPU.device_id!(dev_idx)
KomaMRICore.set_device!(::ROCBackend, dev::AMDGPU.HIPDevice) = AMDGPU.device!(dev)
KomaMRICore.device_name(::ROCBackend) = AMDGPU.device().name

function KomaMRICore._print_devices(::ROCBackend)
devices = [
Symbol("($(i-1)$(i == 1 ? "*" : " "))") => d.name for
(i, d) in enumerate(AMDGPU.devices())
]
@info "$(length(AMDGPU.devices())) AMD capable device(s)." devices...
end

function __init__()
push!(KomaMRICore.LOADED_BACKENDS[], ROCBackend())
end

end
23 changes: 23 additions & 0 deletions KomaMRICore/ext/KomaCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module KomaCUDAExt

using CUDA
import KomaMRICore

KomaMRICore.name(::CUDABackend) = "CUDA"
KomaMRICore.isfunctional(::CUDABackend) = CUDA.functional()
KomaMRICore.set_device!(::CUDABackend, val) = CUDA.device!(val)
KomaMRICore.device_name(::CUDABackend) = CUDA.name(CUDA.device())

function KomaMRICore._print_devices(::CUDABackend)
devices = [
Symbol("($(i-1)$(i == 1 ? "*" : " "))") => CUDA.name(d) for
(i, d) in enumerate(CUDA.devices())
]
@info "$(length(CUDA.devices())) CUDA capable device(s)." devices...
end

function __init__()
push!(KomaMRICore.LOADED_BACKENDS[], CUDABackend())
end

end
28 changes: 28 additions & 0 deletions KomaMRICore/ext/KomaMetalExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module KomaMetalExt

using Metal
import KomaMRICore

KomaMRICore.name(::MetalBackend) = "Metal"
KomaMRICore.isfunctional(::MetalBackend) = Metal.functional()
KomaMRICore.set_device!(::MetalBackend, device_index::Integer) = device_index == 1 || @warn "Metal does not support multiple gpu devices. Ignoring the device setting."
KomaMRICore.set_device!(::MetalBackend, dev::Metal.MTLDevice) = Metal.device!(dev)
KomaMRICore.device_name(::MetalBackend) = String(Metal.current_device().name)

function KomaMRICore._print_devices(::MetalBackend)
@info "Metal device type: $(KomaMRICore.device_name(MetalBackend()))"
end

#Temporary workaround for https://github.com/JuliaGPU/Metal.jl/issues/348
#Once run_spin_excitation! and run_spin_precession! are kernel-based, this code
#can be removed
Base.cumsum(x::MtlVector) = convert(MtlVector, cumsum(KomaMRICore.cpu(x)))
Base.cumsum(x::MtlArray{T}; dims) where T = convert(MtlArray{T}, cumsum(KomaMRICore.cpu(x), dims=dims))
Base.findall(x::MtlVector{Bool}) = convert(MtlVector, findall(KomaMRICore.cpu(x)))

function __init__()
push!(KomaMRICore.LOADED_BACKENDS[], MetalBackend())
@warn "Due to https://github.com/JuliaGPU/Metal.jl/issues/348, some functions may need to run on the CPU. Performance may be impacted as a result."
end

end
23 changes: 23 additions & 0 deletions KomaMRICore/ext/KomaoneAPIExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module KomaoneAPIExt

using oneAPI
import KomaMRICore

KomaMRICore.name(::oneAPIBackend) = "oneAPI"
KomaMRICore.isfunctional(::oneAPIBackend) = oneAPI.functional()
KomaMRICore.set_device!(::oneAPIBackend, val) = oneAPI.device!(val)
KomaMRICore.device_name(::oneAPIBackend) = oneAPI.properties(oneAPI.device()).name

function KomaMRICore._print_devices(::oneAPIBackend)
devices = [
Symbol("($(i-1)$(i == 1 ? "*" : " "))") => oneAPI.properties(d).name for
(i, d) in enumerate(oneAPI.devices())
]
@info "$(length(oneAPI.devices())) oneAPI capable device(s)." devices...
end

function __init__()
push!(KomaMRICore.LOADED_BACKENDS[], oneAPIBackend())
end

end
4 changes: 2 additions & 2 deletions KomaMRICore/src/KomaMRICore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ module KomaMRICore

# General
import Base.*, Base.abs
import KernelAbstractions as KA
using Reexport
using ThreadsX
# Printing
using ProgressMeter
# Simulation
using CUDA

# KomaMRIBase
@reexport using KomaMRIBase
Expand All @@ -18,6 +17,7 @@ include("rawdata/ISMRMRD.jl")
include("datatypes/Spinor.jl")
include("other/DiffusionModel.jl")
# Simulator
include("simulation/Functors.jl")
include("simulation/GPUFunctions.jl")
include("simulation/SimulatorCore.jl")

Expand Down
94 changes: 94 additions & 0 deletions KomaMRICore/src/simulation/Functors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import Adapt: adapt, adapt_storage
import Functors: @functor, functor, fmap, isleaf

#Aux. funcitons to check if the variable we want to convert to CuArray is numeric
_isleaf(x) = isleaf(x)
_isleaf(::AbstractArray{<:Number}) = true
_isleaf(::AbstractArray{T}) where T = isbitstype(T)
_isleaf(::AbstractRange) = true

Check warning on line 8 in KomaMRICore/src/simulation/Functors.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/Functors.jl#L7-L8

Added lines #L7 - L8 were not covered by tests

"""
gpu(x)

Tries to move `x` to the GPU backend specified in the 'backend' parameter.

This works for functions, and any struct marked with `@functor`.

Use [`cpu`](@ref) to copy back to ordinary `Array`s.

See also [`f32`](@ref) and [`f64`](@ref) to change element type only.

# Examples
```julia
x = gpu(x, CUDABackend())
```
"""
function gpu(x, backend::KA.GPU)
return fmap(x -> adapt(backend, x), x; exclude=_isleaf)

Check warning on line 27 in KomaMRICore/src/simulation/Functors.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/Functors.jl#L26-L27

Added lines #L26 - L27 were not covered by tests
end

# To CPU
"""
cpu(x)

Tries to move object to CPU. This works for functions, and any struct marked with `@functor`.

See also [`gpu`](@ref).

# Examples
```julia
x = x |> cpu
```
"""
cpu(x) = fmap(x -> adapt(KA.CPU(), x), x, exclude=_isleaf)

#Precision
paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
adapt_storage(T::Type{<:Real}, xs::Real) = convert(T, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Complex}) = convert.(Complex{T}, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Bool}) = xs
adapt_storage(T::Type{<:Real}, xs::SimpleMotion) = SimpleMotion(paramtype(T, xs.types))
adapt_storage(T::Type{<:Real}, xs::NoMotion) = NoMotion{T}()
function adapt_storage(T::Type{<:Real}, xs::ArbitraryMotion)
fields = []
for field in fieldnames(ArbitraryMotion)
push!(fields, paramtype(T, getfield(xs, field)))
end
return ArbitraryMotion(fields...)
end

"""
f32(m)

Converts the `eltype` of model's parameters to `Float32`
Recurses into structs marked with `@functor`.

See also [`f64`](@ref).
"""
f32(m) = paramtype(Float32, m)

"""
f64(m)

Converts the `eltype` of model's parameters to `Float64` (which is Koma's default)..
Recurses into structs marked with `@functor`.

See also [`f32`](@ref).
"""
f64(m) = paramtype(Float64, m)

Check warning on line 79 in KomaMRICore/src/simulation/Functors.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/Functors.jl#L79

Added line #L79 was not covered by tests

#The functor macro makes it easier to call a function in all the parameters
@functor Phantom

@functor Translation
@functor Rotation
@functor HeartBeat
@functor PeriodicTranslation
@functor PeriodicRotation
@functor PeriodicHeartBeat

@functor Spinor
@functor DiscreteSequence

export gpu, cpu, f32, f64
Loading