diff --git a/src/MixedModeBroadcastAD.jl b/src/MixedModeBroadcastAD.jl
index db405ed..98d33ef 100644
--- a/src/MixedModeBroadcastAD.jl
+++ b/src/MixedModeBroadcastAD.jl
@@ -10,6 +10,7 @@ using FastSplat
 ############
 
 include("gpu/CuArrays.jl")
+include("gpu/Const.jl")
 include("gpu/StructOfArrays.jl")
 include("gpu/GPUArrays.jl")
 
@@ -27,10 +28,14 @@ include("ad/primitives.jl")
 ####################
 
 Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{CuArray}, s::Broadcast.ArrayStyle{<:StructOfArrays}) = s
+Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{CuArray}, s::Broadcast.ArrayStyle{<:Const}) = s
 Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{CuArray}, s::RecordOtherStyle) = s
 Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{CuArray}, s::RecordArrayStyle) = s
+Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{<:StructOfArrays}, s::Broadcast.ArrayStyle{<:Const}) = s
 Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{<:StructOfArrays}, s::RecordOtherStyle) = s
 Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{<:StructOfArrays}, s::RecordArrayStyle) = s
+Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{<:Const}, s::RecordOtherStyle) = s
+Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{<:Const}, s::RecordArrayStyle) = s
 
 DiffRules.@define_diffrule CUDAnative.exp(x) = :(CUDAnative.exp($x))
 DiffRules.@define_diffrule CUDAnative.tanh(x) = :(1 - CUDAnative.tanh($x)^2)
diff --git a/src/ad/primitives.jl b/src/ad/primitives.jl
index 85d2e53..6432271 100644
--- a/src/ad/primitives.jl
+++ b/src/ad/primitives.jl
@@ -81,7 +81,8 @@ forward!(i::BroadcastInstruction{<:Tuple{typeof(*),Any,Any}}) = invoke(forward!,
 
 # multiple dispatch selects this implementation for fused benchmarks
 function forward!(i::BroadcastInstruction)
-    f, input_values = first(i.input), value.(i.input[2:end])
+    f, _input_values = first(i.input), value.(i.input[2:end])
+    input_values = map(readonly, _input_values)
     if isa(i.output, Tuple) # we have pre-cached memory we can reuse
         output_variable, output_duals = i.output
         dual_eval_broadcast!(f, output_duals, value(output_variable), input_values)
@@ -188,8 +189,9 @@ end
 
 function backward!(i::BroadcastInstruction)
     f, args = first(i.input), i.input[2:end]
-    output, output_duals = i.output
-    output_deriv = deriv(output)
+    output, _output_duals = i.output
+    output_duals = readonly(_output_duals)
+    output_deriv = readonly(deriv(output))
     for (i, arg) in enumerate(args)
         isa(arg, Variable) || continue
         arg_deriv = deriv(arg)
diff --git a/src/gpu/Const.jl b/src/gpu/Const.jl
new file mode 100644
index 0000000..c4cd239
--- /dev/null
+++ b/src/gpu/Const.jl
@@ -0,0 +1,47 @@
+## Const Arg
+#
+# `Const` is a thin wrapper that marks an array as read-only: `getindex` forwards to the
+# wrapped array and `setindex!` always throws. When the wrapped array is a
+# `CuDeviceArray`, integer-index loads go through `CUDAnative.unsafe_cached_load`.
+struct Const{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
+    data::A
+    Const{T,N,A}(data::A) where {T,N,A<:AbstractArray{T,N}} = new{T,N,A}(data)
+end
+# Allocate an uninitialized backing array of type `A` with the given shape and wrap it.
+Const{T,N,A}(::Uninitialized, shape::NTuple{N,Integer}) where {T,N,A} = Const{T,N,A}(A(uninitialized, shape))
+
+Base.size(A::Const) = size(A.data)
+Base.size(A::Const, i) = size(A.data, i)
+Base.show(io::IO, a::Const{T,N,A}) where {T,N,A} = print(io, "$(length(a))-element Const{$T,$N,$A}")
+# `print_array` is a no-op for `Const`, so it never prints the elements.
+Base.print_array(::IO, ::Const) = nothing
+
+Base.getindex(A::Const, I...) = (Base.@_propagate_inbounds_meta; A.data[I...])
+# Device-side linear indexing: bounds-check, then load via `unsafe_cached_load` with the alignment of `T`.
+@inline function Base.getindex(A::Const{T,N,<:CuDeviceArray{T,N,AS}}, index::Integer) where {T,N,AS}
+    @boundscheck checkbounds(A, index)
+    align = CUDAnative.datatype_align(T)
+    CUDAnative.unsafe_cached_load(pointer(A.data), index, Val(align))::T
+end
+
+Base.BroadcastStyle(::Type{<:Const{T,N,A}}) where {T,N,A<:AbstractArray} = Broadcast.ArrayStyle{Const{T,N,A}}()
+
+# Allocate from the wrapped array *instance*: Base's `similar(array, eltype, dims)` form takes
+# an array, not an array type. The result is of the wrapped kind, i.e. not a `Const`.
+function Base.similar(A::Const{T1,N,AT}, ::Type{T}, dims::Dims) where {T1,N,AT,T}
+    similar(A.data, T, dims)
+end
+
+# Delegate result allocation to the broadcast style of the wrapped array type `A`.
+function Base.broadcast_similar(f, ::Broadcast.ArrayStyle{Const{T1, N, A}}, ::Type{T}, inds, As...) where {T1,N,A,T}
+    Base.broadcast_similar(f, Base.BroadcastStyle(A), T, inds, As...)
+end
+
+Base.setindex!(::Const, x, I...) = error("setindex! is not allowed for Const array")
+
+Base.IndexStyle(::Type{<:Const{T,N,A}}) where {T,N,A<:AbstractArray} = Base.IndexStyle(A)
+
+# Wrap an array in `Const`. An array that is already a `Const` is returned unchanged, so
+# repeated calls never produce nested `Const{T,N,Const{...}}` wrappers.
+readonly(a::A) where{T,N,A<:AbstractArray{T,N}} = Const{T,N,A}(a)
+readonly(a::Const) = a
diff --git a/src/gpu/GPUArrays.jl b/src/gpu/GPUArrays.jl
index 133d66a..2af1bd3 100644
--- a/src/gpu/GPUArrays.jl
+++ b/src/gpu/GPUArrays.jl
@@ -1,8 +1,10 @@
 ## broadcast
 const GPUSoA = StructOfArrays{T,N,CuArray{T,N},U} where {T,N,U}
 const GPUDeviceSoA = StructOfArrays{T,N,CuDeviceArray{T,N,AS},U} where {T,N,AS,U}
-const GPUArrays = Union{<:CuArray{T,N}, <:GPUSoA{T,N}} where {T,N}
-const GPUDeviceArrays = Union{<:CuDeviceArray{T,N,AS}, <:GPUDeviceSoA{T,N,AS}} where {T,N,AS}
+const _GPUArrays = Union{<:CuArray{T,N}, <:GPUSoA{T,N}} where {T,N}
+const _GPUDeviceArrays = Union{<:CuDeviceArray{T,N,AS}, <:GPUDeviceSoA{T,N,AS}} where {T,N,AS}
+const GPUArrays = Union{<:_GPUArrays{T,N}, <:Const{T, N, <:_GPUArrays{T,N}}} where {T, N}
+const GPUDeviceArrays = Union{<:_GPUDeviceArrays{T,N,AS}, <:Const{T, N, <:_GPUDeviceArrays{T,N,AS}}} where {T, N, AS}
 
 function CUDAnative.cudaconvert(A::GPUSoA{T, N}) where {T, N}
     arrays = map(CUDAnative.cudaconvert, A.arrays)
@@ -10,6 +12,16 @@ function CUDAnative.cudaconvert(A::GPUSoA{T, N}) where {T, N}
     StructOfArrays{T, N, CuDeviceArray{T,N,AS.Global}, tt}(arrays)
 end
 
+function CUDAnative.cudaconvert(A::Const{T,N,CuArray{T,N}}) where {T, N}
+    Const{T,N,CuDeviceArray{T,N,AS.Global}}(CUDAnative.cudaconvert(A.data))
+end
+
+function CUDAnative.cudaconvert(A::StructOfArrays{T,N,Const{T,N,CuArray{T,N}}}) where {T,N}
+    arrays = map(CUDAnative.cudaconvert, A.arrays)
+    tt = typeof(arrays)
+    StructOfArrays{T, N, CuDeviceArray{T,N,AS.Global}, tt}(arrays)
+end
+
 ### base interface
 
 @inline function Base.broadcast!(f, dest::GPUArrays, ::Nothing, As::Vararg{Any, N}) where N
diff --git a/src/gpu/StructOfArrays.jl b/src/gpu/StructOfArrays.jl
index 9dacea5..a9e9206 100644
--- a/src/gpu/StructOfArrays.jl
+++ b/src/gpu/StructOfArrays.jl
@@ -18,10 +18,12 @@ end
 # Storage types of StructOfArrays need to implement this
 _type_with_eltype(::Type{<:Array}, T, N) = Array{T, N}
 _type_with_eltype(::Type{<:CuArray}, T, N) = CuArray{T, N}
-_type_with_eltype(::Type{CuDeviceArray{_T,_N,AS}}, T, N) where{_T,_N,AS} = CuDeviceArray(T,N,AS)
+_type_with_eltype(::Type{CuDeviceArray{_T,_N,AS}}, T, N) where{_T,_N,AS} = CuDeviceArray{T,N,AS}
+_type_with_eltype(::Type{Const{_T,_N,AT}}, T, N) where{_T,_N,AT} = Const{T,N,_type_with_eltype(AT,T,N)}
 _type(::Type{<:Array}) = Array
 _type(::Type{<:CuArray}) = CuArray
 _type(::Type{<:CuDeviceArray}) = CuDeviceArray
+_type(::Type{<:Const}) = Const
 
 function gather_eltypes(T, visited = Set{Type}())
     (!isconcretetype(T) || T.mutable) && throw(ArgumentError("can only create an StructOfArrays of leaf type immutables"))
@@ -121,14 +123,30 @@ Base.IndexStyle(::Type{<:StructOfArrays{T,N,A}}) where {T,N,A<:AbstractArray} =
 Base.BroadcastStyle(::Type{<:StructOfArrays{T,N,A}}) where {T,N,A<:AbstractArray} = Broadcast.ArrayStyle{StructOfArrays{T,N,A}}()
 
 function Base.similar(A::StructOfArrays{T1,N,AT}, ::Type{T}, dims::Dims) where {T1,N,AT,T}
+    @assert !(AT<:Const)
+    StructOfArrays(T, AT, dims)
+end
+
+function Base.similar(A::StructOfArrays{T1,N,Const{T1,N,AT}}, ::Type{T}, dims::Dims) where {T1,N,AT,T}
     StructOfArrays(T, AT, dims)
 end
 
 function Base.broadcast_similar(f, ::Broadcast.ArrayStyle{StructOfArrays{T1,N,A}}, ::Type{T}, inds, As...) where {T1,N,A,T}
+    @assert !(A<:Const)
     StructOfArrays(T, A, Base.to_shape(inds))
 end
 
-function Base.convert(::Type{<:StructOfArrays{T,N,AT}}, A::StructOfArrays{T,N}) where {T,N,AT<:AbstractArray{T,N}}
+function Base.broadcast_similar(f, ::Broadcast.ArrayStyle{StructOfArrays{T1,N,Const{T1,N,A}}}, ::Type{T}, inds, As...) where {T1,N,A,T}
+    StructOfArrays(T, A, Base.to_shape(inds))
+end
+
+function readonly(A::StructOfArrays{T,N,AT}) where {T,N,AT}
+    arrays = map(readonly, A.arrays)
+    tt = typeof(arrays)
+    StructOfArrays{T,N,Const{T,N,AT},tt}(arrays)
+end
+
+function Base.convert(::Type{<:StructOfArrays{T,N,AT}}, A::StructOfArrays{T, N}) where {T,N,AT<:AbstractArray{T,N}}
     if AT <: StructOfArrays
         error("Can't embed a SoA array in a SoA array")
     end