Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some test failures #406

Merged
merged 9 commits into from
Jun 14, 2024
7 changes: 7 additions & 0 deletions KomaMRICore/ext/KomaAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@ module KomaAMDGPUExt

using AMDGPU
import KomaMRICore
import Adapt

KomaMRICore.name(::ROCBackend) = "AMDGPU"
KomaMRICore.isfunctional(::ROCBackend) = AMDGPU.functional()
KomaMRICore.set_device!(::ROCBackend, dev_idx::Integer) = AMDGPU.device_id!(dev_idx)
KomaMRICore.set_device!(::ROCBackend, dev::AMDGPU.HIPDevice) = AMDGPU.device!(dev)
KomaMRICore.device_name(::ROCBackend) = AMDGPU.device().name

function Adapt.adapt_storage(
::ROCBackend, x::Vector{KomaMRICore.LinearInterpolator{T,V}}
) where {T<:Real,V<:AbstractVector{T}}
return AMDGPU.rocconvert.(x)
end

function KomaMRICore._print_devices(::ROCBackend)
devices = [
Symbol("($(i-1)$(i == 1 ? "*" : " "))") => d.name for
Expand Down
7 changes: 7 additions & 0 deletions KomaMRICore/ext/KomaCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@ module KomaCUDAExt

using CUDA
import KomaMRICore
import Adapt

KomaMRICore.name(::CUDABackend) = "CUDA"
KomaMRICore.isfunctional(::CUDABackend) = CUDA.functional()
KomaMRICore.set_device!(::CUDABackend, val) = CUDA.device!(val)
KomaMRICore.device_name(::CUDABackend) = CUDA.name(CUDA.device())

function Adapt.adapt_storage(
::CUDABackend, x::Vector{KomaMRICore.LinearInterpolator{T,V}}
) where {T<:Real,V<:AbstractVector{T}}
return Adapt.adapt.(CuArray, x)
end

function KomaMRICore._print_devices(::CUDABackend)
devices = [
Symbol("($(i-1)$(i == 1 ? "*" : " "))") => CUDA.name(d) for
Expand Down
7 changes: 7 additions & 0 deletions KomaMRICore/ext/KomaMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@ module KomaMetalExt

using Metal
import KomaMRICore
import Adapt

KomaMRICore.name(::MetalBackend) = "Metal"
KomaMRICore.isfunctional(::MetalBackend) = Metal.functional()
KomaMRICore.set_device!(::MetalBackend, device_index::Integer) = device_index == 1 || @warn "Metal does not support multiple gpu devices. Ignoring the device setting."
KomaMRICore.set_device!(::MetalBackend, dev::Metal.MTLDevice) = Metal.device!(dev)
KomaMRICore.device_name(::MetalBackend) = String(Metal.current_device().name)

function Adapt.adapt_storage(
::MetalBackend, x::Vector{KomaMRICore.LinearInterpolator{T,V}}
) where {T<:Real,V<:AbstractVector{T}}
return Metal.mtl.(x)
end

function KomaMRICore._print_devices(::MetalBackend)
@info "Metal device type: $(KomaMRICore.device_name(MetalBackend()))"
end
Expand Down
7 changes: 7 additions & 0 deletions KomaMRICore/ext/KomaoneAPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@ module KomaoneAPIExt

using oneAPI
import KomaMRICore
import Adapt

KomaMRICore.name(::oneAPIBackend) = "oneAPI"
KomaMRICore.isfunctional(::oneAPIBackend) = oneAPI.functional()
KomaMRICore.set_device!(::oneAPIBackend, val) = oneAPI.device!(val)
KomaMRICore.device_name(::oneAPIBackend) = oneAPI.properties(oneAPI.device()).name

function Adapt.adapt_storage(
::oneAPIBackend, x::Vector{KomaMRICore.LinearInterpolator{T,V}}
) where {T<:Real,V<:AbstractVector{T}}
return Adapt.adapt.(oneArray, a)
end

function KomaMRICore._print_devices(::oneAPIBackend)
devices = [
Symbol("($(i-1)$(i == 1 ? "*" : " "))") => oneAPI.properties(d).name for
Expand Down
19 changes: 17 additions & 2 deletions KomaMRICore/src/simulation/Functors.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import Adapt: adapt, adapt_storage
import Functors: @functor, functor, fmap, isleaf

#Aux. funcitons to check if the variable we want to convert to CuArray is numeric
#Aux. funcitons to check if the variable we want to move to the GPU is numeric
_isleaf(x) = isleaf(x)
_isleaf(::AbstractArray{<:Number}) = true
_isleaf(::AbstractArray{T}) where T = isbitstype(T)
_isleaf(::AbstractRange) = true

"""
gpu(x)
gpu(x, backend)

Tries to move `x` to the GPU backend specified in the 'backend' parameter.

Expand Down Expand Up @@ -42,6 +42,21 @@
"""
cpu(x) = fmap(x -> adapt(KA.CPU(), x), x, exclude=_isleaf)

#MotionModel structs
adapt_storage(::KA.GPU, x::NoMotion) = x
adapt_storage(::KA.GPU, x::SimpleMotion) = x
function adapt_storage(backend::KA.GPU, xs::ArbitraryMotion)
fields = []
for field in fieldnames(ArbitraryMotion)
if field in (:ux, :uy, :uz)
push!(fields, adapt(backend, getfield(xs, field)))

Check warning on line 52 in KomaMRICore/src/simulation/Functors.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/Functors.jl#L46-L52

Added lines #L46 - L52 were not covered by tests
else
push!(fields, getfield(xs, field))

Check warning on line 54 in KomaMRICore/src/simulation/Functors.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/Functors.jl#L54

Added line #L54 was not covered by tests
end
end
return KomaMRICore.ArbitraryMotion(fields...)

Check warning on line 57 in KomaMRICore/src/simulation/Functors.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/Functors.jl#L56-L57

Added lines #L56 - L57 were not covered by tests
end

#Precision
paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
adapt_storage(T::Type{<:Real}, xs::Real) = convert(T, xs)
Expand Down
2 changes: 1 addition & 1 deletion KomaMRICore/src/simulation/SimulatorCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@
backend = get_backend(sim_params["gpu"])
sim_params["gpu"] &= backend isa KA.GPU
if !KA.supports_float64(backend) && sim_params["precision"] == "f64"
sim_params[precision] = "f32"
sim_params["precision"] = "f32"

Check warning on line 338 in KomaMRICore/src/simulation/SimulatorCore.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/SimulatorCore.jl#L338

Added line #L338 was not covered by tests
@info """ Backend: '$(name(backend))' does not support 64-bit precision
floating point operations. Automatically converting to type Float32.
(set sim_param["precision"] = "f32" to avoid seeing this message).
Expand Down