Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: const wrapper for __ldg #18

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/MixedModeBroadcastAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using FastSplat
############

include("gpu/CuArrays.jl")
include("gpu/Const.jl")
include("gpu/StructOfArrays.jl")
include("gpu/GPUArrays.jl")

Expand All @@ -27,10 +28,14 @@ include("ad/primitives.jl")
####################

Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{CuArray}, s::Broadcast.ArrayStyle{<:StructOfArrays}) = s
Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{CuArray}, s::Broadcast.ArrayStyle{<:Const}) = s
Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{CuArray}, s::RecordOtherStyle) = s
Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{CuArray}, s::RecordArrayStyle) = s
Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{<:StructOfArrays}, s::Broadcast.ArrayStyle{<:Const}) = s
Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{<:StructOfArrays}, s::RecordOtherStyle) = s
Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{<:StructOfArrays}, s::RecordArrayStyle) = s
Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{<:Const}, s::RecordOtherStyle) = s
Broadcast.BroadcastStyle(::Broadcast.ArrayStyle{<:Const}, s::RecordArrayStyle) = s

DiffRules.@define_diffrule CUDAnative.exp(x) = :(CUDAnative.exp($x))
DiffRules.@define_diffrule CUDAnative.tanh(x) = :(1 - CUDAnative.tanh($x)^2)
Expand Down
8 changes: 5 additions & 3 deletions src/ad/primitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ forward!(i::BroadcastInstruction{<:Tuple{typeof(*),Any,Any}}) = invoke(forward!,
# multiple dispatch selects this implementation for fused benchmarks

function forward!(i::BroadcastInstruction)
f, input_values = first(i.input), value.(i.input[2:end])
f, _input_values = first(i.input), value.(i.input[2:end])
input_values = map(readonly, _input_values)
if isa(i.output, Tuple) # we have pre-cached memory we can reuse
output_variable, output_duals = i.output
dual_eval_broadcast!(f, output_duals, value(output_variable), input_values)
Expand Down Expand Up @@ -188,8 +189,9 @@ end

function backward!(i::BroadcastInstruction)
f, args = first(i.input), i.input[2:end]
output, output_duals = i.output
output_deriv = deriv(output)
output, _output_duals = i.output
output_duals = readonly(_output_duals)
output_deriv = readonly(deriv(output))
for (i, arg) in enumerate(args)
isa(arg, Variable) || continue
arg_deriv = deriv(arg)
Expand Down
34 changes: 34 additions & 0 deletions src/gpu/Const.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
## Const Arg
struct Const{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
data::A
Const{T,N,A}(data::A) where {T,N,A<:AbstractArray{T,N}} = new{T,N,A}(data)
end
Const{T,N,A}(::Uninitialized, shape::NTuple{N,Integer}) where {T,N,A} = Const{T,N,A}(A(uninitialized, shape))

Base.size(A::Const) = size(A.data)
Base.size(A::Const, i) = size(A.data, i)
Base.show(io::IO, a::Const{T,N,A}) where {T,N,A} = print(io, "$(length(a))-element Const{$T,$N,$A}")
Base.print_array(::IO, ::Const) = nothing

Base.getindex(A::Const, I...) = (Base.@_propagate_inbounds_meta; A.data[I...])
@inline function Base.getindex(A::Const{T,N,<:CuDeviceArray{T,N,AS}}, index::Integer) where {T,N,AS}
@boundscheck checkbounds(A, index)
align = CUDAnative.datatype_align(T)
CUDAnative.unsafe_cached_load(pointer(A.data), index, Val(align))::T
end

Base.BroadcastStyle(::Type{<:Const{T,N,A}}) where {T,N,A<:AbstractArray} = Broadcast.ArrayStyle{Const{T,N,A}}()

function Base.similar(A::Const{T1,N,AT}, ::Type{T}, dims::Dims) where {T1,N,AT,T}
similar(AT, T, dims)
end

function Base.broadcast_similar(f, ::Broadcast.ArrayStyle{Const{T1, N, A}}, ::Type{T}, inds, As...) where {T1,N,A,T}
Base.broadcast_similar(f, Base.BroadcastStyle(A), T, inds, As...)
end

Base.setindex!(::Const, x, I...) = error("setindex! is not allowed for Const array")

Base.IndexStyle(::Type{<:Const{T,N,A}}) where {T,N,A<:AbstractArray} = Base.IndexStyle(A)

readonly(a::A) where{T,N,A<:AbstractArray{T,N}} = Const{T,N,A}(a)
16 changes: 14 additions & 2 deletions src/gpu/GPUArrays.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
## broadcast
const GPUSoA = StructOfArrays{T,N,CuArray{T,N},U} where {T,N,U}
const GPUDeviceSoA = StructOfArrays{T,N,CuDeviceArray{T,N,AS},U} where {T,N,AS,U}
const GPUArrays = Union{<:CuArray{T,N}, <:GPUSoA{T,N}} where {T,N}
const GPUDeviceArrays = Union{<:CuDeviceArray{T,N,AS}, <:GPUDeviceSoA{T,N,AS}} where {T,N,AS}
const _GPUArrays = Union{<:CuArray{T,N}, <:GPUSoA{T,N}} where {T,N}
const _GPUDeviceArrays = Union{<:CuDeviceArray{T,N,AS}, <:GPUDeviceSoA{T,N,AS}} where {T,N,AS}
const GPUArrays = Union{<:_GPUArrays{T,N}, <:Const{T, N, <:_GPUArrays{T,N}}} where {T, N}
const GPUDeviceArrays = Union{<:_GPUDeviceArrays{T,N,AS}, <:Const{T, N, <:_GPUDeviceArrays{T,N,AS}}} where {T, N, AS}

function CUDAnative.cudaconvert(A::GPUSoA{T, N}) where {T, N}
arrays = map(CUDAnative.cudaconvert, A.arrays)
tt = typeof(arrays)
StructOfArrays{T, N, CuDeviceArray{T,N,AS.Global}, tt}(arrays)
end

function CUDAnative.cudaconvert(A::Const{T,N,CuArray{T,N}}) where {T, N}
Const{T,N,CuDeviceArray{T,N,AS.Global}}(CUDAnative.cudaconvert(A.data))
end

function CUDAnative.cudaconvert(A::StructOfArrays{T,N,Const{T,N,CuArray{T,N}}}) where {T,N}
arrays = map(CUDAnative.cudaconvert, A.arrays)
tt = typeof(arrays)
StructOfArrays{T, N, CuDeviceArray{T,N,AS.Global}, tt}(arrays)
end

### base interface

@inline function Base.broadcast!(f, dest::GPUArrays, ::Nothing, As::Vararg{Any, N}) where N
Expand Down
22 changes: 20 additions & 2 deletions src/gpu/StructOfArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ end
# Storage types of StructOfArrays need to implement this
_type_with_eltype(::Type{<:Array}, T, N) = Array{T, N}
_type_with_eltype(::Type{<:CuArray}, T, N) = CuArray{T, N}
_type_with_eltype(::Type{CuDeviceArray{_T,_N,AS}}, T, N) where{_T,_N,AS} = CuDeviceArray(T,N,AS)
_type_with_eltype(::Type{CuDeviceArray{_T,_N,AS}}, T, N) where{_T,_N,AS} = CuDeviceArray{T,N,AS}
_type_with_eltype(::Type{Const{_T,_N,AT}}, T, N) where{_T,_N,AT} = Const{T,N,_type_with_eltype(AT,T,N)}
_type(::Type{<:Array}) = Array
_type(::Type{<:CuArray}) = CuArray
_type(::Type{<:CuDeviceArray}) = CuDeviceArray
_type(::Type{<:Const}) = Const

function gather_eltypes(T, visited = Set{Type}())
(!isconcretetype(T) || T.mutable) && throw(ArgumentError("can only create an StructOfArrays of leaf type immutables"))
Expand Down Expand Up @@ -121,14 +123,30 @@ Base.IndexStyle(::Type{<:StructOfArrays{T,N,A}}) where {T,N,A<:AbstractArray} =
Base.BroadcastStyle(::Type{<:StructOfArrays{T,N,A}}) where {T,N,A<:AbstractArray} = Broadcast.ArrayStyle{StructOfArrays{T,N,A}}()

function Base.similar(A::StructOfArrays{T1,N,AT}, ::Type{T}, dims::Dims) where {T1,N,AT,T}
@assert !(AT<:Const)
StructOfArrays(T, AT, dims)
end

function Base.similar(A::StructOfArrays{T1,N,Const{T1,N,AT}}, ::Type{T}, dims::Dims) where {T1,N,AT,T}
StructOfArrays(T, AT, dims)
end

function Base.broadcast_similar(f, ::Broadcast.ArrayStyle{StructOfArrays{T1,N,A}}, ::Type{T}, inds, As...) where {T1,N,A,T}
@assert !(A<:Const)
StructOfArrays(T, A, Base.to_shape(inds))
end

function Base.convert(::Type{<:StructOfArrays{T,N,AT}}, A::StructOfArrays{T,N}) where {T,N,AT<:AbstractArray{T,N}}
function Base.broadcast_similar(f, ::Broadcast.ArrayStyle{StructOfArrays{T1,N,Const{T1,N,A}}}, ::Type{T}, inds, As...) where {T1,N,A,T}
StructOfArrays(T, A, Base.to_shape(inds))
end

function readonly(A::StructOfArrays{T,N,AT}) where {T,N,AT}
arrays = map(readonly, A.arrays)
tt = typeof(arrays)
StructOfArrays{T,N,Const{T,N,AT},tt}(arrays)
end

function Base.convert(::Type{<:StructOfArrays{T,N,AT}}, A::StructOfArrays{T, N}) where {T,N,AT<:AbstractArray{T,N}}
if AT <: StructOfArrays
error("Can't embed a SoA array in a SoA array")
end
Expand Down