Skip to content

Commit

Permalink
feat: add fallbacks for unknown objects (#87)
Browse files Browse the repository at this point in the history
* feat: add fallbacks for unknown objects

* feat: handle RNGs and undef arrays gracefully

* test: RNG movement

* test: functions and closures
  • Loading branch information
avik-pal authored Oct 18, 2024
1 parent bf09413 commit a9871cb
Show file tree
Hide file tree
Showing 14 changed files with 215 additions and 21 deletions.
2 changes: 1 addition & 1 deletion lib/MLDataDevices/.buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
steps:
- label: "Triggering Pipelines (Pull Request)"
if: "build.pull_request.base_branch == 'main'"
if: build.branch != "main" && build.tag == null
agents:
queue: "juliagpu"
plugins:
Expand Down
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.2.1"
version = "1.3.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 2 additions & 0 deletions lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ function Internal.get_device(x::AMDGPU.AnyROCArray)
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
return Internal.get_device(parent_x)
end
Internal.get_device(::AMDGPU.rocRAND.RNG) = AMDGPUDevice(AMDGPU.device())

Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice
Internal.get_device_type(::AMDGPU.rocRAND.RNG) = AMDGPUDevice

# Set Device
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)
Expand Down
4 changes: 4 additions & 0 deletions lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ function Internal.get_device(x::CUDA.AnyCuArray)
return MLDataDevices.get_device(parent_x)
end
Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal))
Internal.get_device(::CUDA.RNG) = CUDADevice(CUDA.device())
Internal.get_device(::CUDA.CURAND.RNG) = CUDADevice(CUDA.device())

Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice
Internal.get_device_type(::CUDA.RNG) = CUDADevice
Internal.get_device_type(::CUDA.CURAND.RNG) = CUDADevice

# Set Device
MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev)
Expand Down
11 changes: 8 additions & 3 deletions lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@ module MLDataDevicesChainRulesCoreExt
using Adapt: Adapt
using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable

using MLDataDevices: AbstractDevice, get_device, get_device_type
using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type

@non_differentiable get_device(::Any)
@non_differentiable get_device_type(::Any)

function ChainRulesCore.rrule(
::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
∇adapt_storage = let x = x
Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ))
∇adapt_storage = let dev = get_device(x)
if dev === nothing || dev isa UnknownDevice
@warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
Δ -> (NoTangent(), NoTangent(), Δ)
else
Δ -> (NoTangent(), NoTangent(), dev(Δ))
end
end
return Adapt.adapt_storage(to, x), ∇adapt_storage
end
Expand Down
5 changes: 4 additions & 1 deletion lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ module MLDataDevicesGPUArraysExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using MLDataDevices: CPUDevice
using MLDataDevices: Internal, CPUDevice
using Random: Random

Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng()

Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state)
Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.state)

end
39 changes: 32 additions & 7 deletions lib/MLDataDevices/src/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ using Preferences: load_preference
using Random: AbstractRNG

using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends,
GPU_DEVICES, loaded, functional
MetalDevice, oneAPIDevice, XLADevice, UnknownDevice,
supported_gpu_backends, GPU_DEVICES, loaded, functional

for dev in (CPUDevice, MetalDevice, oneAPIDevice)
msg = "`device_id` is not applicable for `$dev`."
Expand Down Expand Up @@ -107,31 +107,38 @@ special_aos(::AbstractArray) = false
recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number)

combine_devices(::Nothing, ::Nothing) = nothing
combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
combine_devices(::Nothing, dev::AbstractDevice) = dev
combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(dev::AbstractDevice, ::Nothing) = dev
combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T
function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice)
dev1 == dev2 && return dev1
dev1 isa UnknownDevice && return dev2
dev2 isa UnknownDevice && return dev1
throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2)."))
end

combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T
combine_devices(::Type{T}, ::Type{UnknownDevice}) where {T <: AbstractDevice} = T
combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{UnknownDevice}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{UnknownDevice}, ::Type{UnknownDevice}) = UnknownDevice
function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice})
throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2)."))
end

for op in (:get_device, :get_device_type)
cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice
unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice
not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \
$(cpu_ret_val)..."
$(unknown_ret_val)..."

@eval begin
function $(op)(x::AbstractArray{T}) where {T}
if recursive_array_eltype(T)
if any(!isassigned(x, i) for i in eachindex(x))
@warn $(not_assigned_msg)
return $(cpu_ret_val)
return $(unknown_ret_val)
end
return mapreduce(MLDataDevices.$(op), combine_devices, x)
end
Expand All @@ -147,13 +154,31 @@ for op in (:get_device, :get_device_type)
length(x) == 0 && return $(op == :get_device ? nothing : Nothing)
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x))
end

function $(op)(f::F) where {F <: Function}
Base.issingletontype(F) &&
return $(op == :get_device ? UnknownDevice() : UnknownDevice)
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices,
map(Base.Fix1(getfield, f), fieldnames(F)))
end
end

for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing)
end
end

get_device(_) = UnknownDevice()
get_device_type(_) = UnknownDevice

fast_structure(::AbstractArray) = true
fast_structure(::Union{Tuple, NamedTuple}) = true
for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@eval fast_structure(::$(T)) = true
end
fast_structure(::Function) = true
fast_structure(_) = false

function unrolled_mapreduce(f::F, op::O, itr) where {F, O}
return unrolled_mapreduce(f, op, itr, static_length(itr))
end
Expand Down
22 changes: 16 additions & 6 deletions lib/MLDataDevices/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ struct oneAPIDevice <: AbstractGPUDevice end
# TODO: Later we might want to add the client field here?
struct XLADevice <: AbstractAcceleratorDevice end

# Fallback for when we don't know the device type
struct UnknownDevice <: AbstractDevice end

"""
functional(x::AbstractDevice) -> Bool
functional(::Type{<:AbstractDevice}) -> Bool
Expand Down Expand Up @@ -229,11 +232,6 @@ const GET_DEVICE_ADMONITIONS = """
!!! note
Trigger Packages must be loaded for this to return the correct device.
!!! warning
RNG types currently don't participate in device determination. We will remove this
restriction in the future.
"""

# Query Device from Array
Expand All @@ -245,6 +243,12 @@ device. Otherwise, we throw an error. If the object is device agnostic, we retur
$(GET_DEVICE_ADMONITIONS)
## Special Returned Values
- `nothing` -- denotes that the object is device agnostic. For example, scalar, abstract
range, etc.
- `UnknownDevice()` -- denotes that the device type is unknown
See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch
based on device type.
"""
Expand All @@ -258,6 +262,12 @@ itself. This value is often a compile time constant and is recommended to be use
of [`get_device`](@ref) where ever defining dispatches based on the device type.
$(GET_DEVICE_ADMONITIONS)
## Special Returned Values
- `Nothing` -- denotes that the object is device agnostic. For example, scalar, abstract
range, etc.
- `UnknownDevice` -- denotes that the device type is unknown
"""
function get_device_type end

Expand Down Expand Up @@ -345,7 +355,7 @@ end

for op in (:get_device, :get_device_type)
@eval function $(op)(x)
hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x)
Internal.fast_structure(x) && return Internal.$(op)(x)
return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x))
end
end
Expand Down
29 changes: 29 additions & 0 deletions lib/MLDataDevices/test/amdgpu_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa AMDGPUDevice
@test get_device_type(ps_xpu.rng_default) <: AMDGPUDevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(AMDGPUDevice)
@test ps_xpu.one_elem isa ROCArray
Expand All @@ -83,7 +87,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(AMDGPUDevice)
@test ps_cpu.one_elem isa Array
Expand Down Expand Up @@ -118,6 +126,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(AMDGPUDevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> AMDGPUDevice()
@test get_device(ff_xpu) isa AMDGPUDevice
@test get_device_type(ff_xpu) <: AMDGPUDevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapped Arrays" begin
if MLDataDevices.functional(AMDGPUDevice)
x = rand(10, 10) |> AMDGPUDevice()
Expand Down
29 changes: 29 additions & 0 deletions lib/MLDataDevices/test/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa CUDADevice
@test get_device_type(ps_xpu.rng_default) <: CUDADevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(CUDADevice)
@test ps_xpu.one_elem isa CuArray
Expand All @@ -82,7 +86,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(CUDADevice)
@test ps_cpu.one_elem isa Array
Expand Down Expand Up @@ -143,6 +151,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(CUDADevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> CUDADevice()
@test get_device(ff_xpu) isa CUDADevice
@test get_device_type(ff_xpu) <: CUDADevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapped Arrays" begin
if MLDataDevices.functional(CUDADevice)
x = rand(10, 10) |> CUDADevice()
Expand Down
29 changes: 29 additions & 0 deletions lib/MLDataDevices/test/metal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa MetalDevice
@test get_device_type(ps_xpu.rng_default) <: MetalDevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(MetalDevice)
@test ps_xpu.one_elem isa MtlArray
Expand All @@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(MetalDevice)
@test ps_cpu.one_elem isa Array
Expand All @@ -107,6 +115,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(MetalDevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> MetalDevice()
@test get_device(ff_xpu) isa MetalDevice
@test get_device_type(ff_xpu) <: MetalDevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapper Arrays" begin
if MLDataDevices.functional(MetalDevice)
x = rand(Float32, 10, 10) |> MetalDevice()
Expand Down
4 changes: 2 additions & 2 deletions lib/MLDataDevices/test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,6 @@ end
@testset "undefined references array" begin
x = Matrix{Any}(undef, 10, 10)

@test get_device(x) isa CPUDevice
@test get_device_type(x) <: CPUDevice
@test get_device(x) isa MLDataDevices.UnknownDevice
@test get_device_type(x) <: MLDataDevices.UnknownDevice
end
Loading

0 comments on commit a9871cb

Please sign in to comment.