Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to Ghost v0.2, remove implicit dependency on CUDA #97

Merged
merged 1 commit into from
Jul 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
name = "Yota"
uuid = "cd998857-8626-517d-b929-70ad188a48f0"
authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
version = "0.5.0"
version = "0.6.0"

[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Ghost = "4f8f7498-1303-42e1-920c-5033445536df"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
ChainRulesCore = "0.10"
ChainRules = "0.8"
ChainRulesCore = "0.10"
FiniteDifferences = "0.12"
Ghost = "0.2"
OrderedCollections = "1.4"
Ghost = "0.1"
julia = "1.6"
6 changes: 4 additions & 2 deletions src/core.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import Statistics
using LinearAlgebra
using OrderedCollections
using ChainRulesCore
using ChainRules
using NNlib
using Ghost
using Ghost: Tape, Variable, V, Call, mkcall, Constant, inputs
using Ghost: bound, _getfield, compile, play!, isstruct, ungetfield, ungetindex, uncat
using Ghost: unbroadcast, unbroadcast_prod_x, unbroadcast_prod_y
using Ghost: bound, compile, play!, isstruct
using Ghost: remove_first_parameter, kwfunc_signature, call_signature


include("helpers.jl")
include("drules.jl")
include("chainrules.jl")
include("grad.jl")
Expand Down
148 changes: 148 additions & 0 deletions src/helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
    ungetfield(dy, s::Tuple, f::Int)

Adjoint of `getfield` on a tuple: build a `Tangent` for `s` in which field
`f` carries the incoming gradient `dy` and every other field is `ZeroTangent()`.
"""
function ungetfield(dy, s::Tuple, f::Int)
    T = typeof(s)
    components = ntuple(i -> i == f ? dy : ZeroTangent(), length(s))
    return Tangent{T}(components...)
end

"""
    ungetindex!(dx, x, dy, I...)

Adjoint of `getindex`: scatter-add the output gradient `dy` into `dx` at the
positions selected by `x[I...]`. Repeated indices accumulate. Mutates and
returns `dx`.
"""
function ungetindex!(dx::AbstractArray, x::AbstractArray, dy::AbstractArray, I...)
    # Resolve the (possibly partial/logical) indices into concrete Cartesian ones.
    positions = CartesianIndices(size(x))[I...]
    if positions isa CartesianIndex
        # scatter! expects a collection of indices, so wrap the single hit.
        positions = [positions]
    end
    return NNlib.scatter!(+, dx, dy, positions)
end

_array_type_only(A::AT) where AT <: AbstractArray{T, N} where {T, N} = AT


# Scalar-gradient case: lift `dy` into a one-element array allocated like `dx`
# (so CPU and GPU arrays both work), then delegate to the array method.
function ungetindex!(dx::AbstractArray, x::AbstractArray, dy::Number, I...)
    dy_arr = fill!(similar(dx, (1,)), dy)
    return ungetindex!(dx, x, dy_arr, I...)
end


# Non-mutating variant: accumulate the gradient into a fresh zero array.
function ungetindex(x::AbstractArray, dy, I...)
    return ungetindex!(zero(x), x, dy, I...)
end

# Tuple variant of the getindex adjoint: positions hit by the index receive
# `dy`, every other position receives a zero of the matching element.
function ungetindex(x::Tuple, dy, I...)
    return [pos in I ? dy : zero(x[pos]) for pos in 1:length(x)]
end

"""
    _getfield(value, fld)

Plain `getfield` wrapper. Using this instead of `getfield()` bypasses Yota's
own rules during backpropagation.
"""
function _getfield(value, fld)
    return getfield(value, fld)
end

# function ∇sum(x::AbstractArray, dy)
# dx = similar(x)
# dx .= dy
# return dx
# end


# function ∇mean(x::AbstractArray, dy, dims=1:ndims(x))
# dx = similar(x)
# dx .= dy ./ prod(size(x, d) for d in dims)
# return dx
# end


# function sum_dropdims(x::AbstractArray, dims)
# return dropdims(sum(x; dims=dims); dims=dims)
# end


# unbroadcast from Flux
# in in-place version we can consider sum!(similar(x), ds),
# but we need to carefully measure performance in each case

# Reshape Δ to have the same number of dimensions as x (the extra trailing
# dimensions of Δ must be singletons for the reshape to be valid).
function trim(x, Δ)
    target = ntuple(d -> size(Δ, d), Val(ndims(x)))
    return reshape(Δ, target)
end

"""
    unbroadcast(x, Δ)

Reduce the broadcast gradient `Δ` back to the shape of `x`: dimensions along
which `x` was expanded during broadcasting are summed out.
"""
function unbroadcast(x::AbstractArray, Δ)
    size(x) == size(Δ) && return Δ
    length(x) == length(Δ) && return trim(x, Δ)
    # Sum over every dimension where x has extent 1 (ndims(Δ)+1 is a no-op dim).
    reduce_dims = ntuple(d -> size(x, d) == 1 ? d : ndims(Δ) + 1, Val(ndims(Δ)))
    return trim(x, sum(Δ, dims=reduce_dims))
end

# A broadcast scalar touches every output element, so its gradient is the total sum.
unbroadcast(::Number, Δ) = sum(Δ)

"""
    unbroadcast_prod_x(x, y, Δ)

Gradient of broadcasted multiplication `x .* y` w.r.t. `x`: scale `Δ` by `y`,
then sum out the dimensions along which `x` was broadcast.
"""
function unbroadcast_prod_x(x::AbstractArray, y::AbstractArray, Δ)
    scaled = Δ .* y
    size(x) == size(Δ) && return scaled
    length(x) == length(Δ) && return trim(x, scaled)
    # Sum over every dimension where x has extent 1 (ndims(Δ)+1 is a no-op dim).
    reduce_dims = ntuple(d -> size(x, d) == 1 ? d : ndims(Δ) + 1, Val(ndims(Δ)))
    return trim(x, sum(scaled, dims=reduce_dims))
end

# The product rule is symmetric — swap the operands to get the gradient w.r.t. y.
unbroadcast_prod_y(x::AbstractArray, y::AbstractArray, Δ) = unbroadcast_prod_x(y, x, Δ)

# device_like(example, a) = (device = guess_device([example]); device(a))

# unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(device_like(y, [x]), y, Δ)[1]
# Mixed scalar/array products: wrap the scalar into a one-element array placed
# on the same device as the array operand, then delegate to the array methods
# above. The trailing `[1]` unwraps the one-element gradient array back to a
# scalar — it appears exactly on the methods whose *differentiated* operand is
# the scalar one.
# NOTE(review): `to_same_device` is not defined in this file — presumably a
# CPU/GPU transfer helper defined elsewhere in the package; confirm it exists
# after the CUDA-dependency removal.
unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(to_same_device([x], y), y, Δ)[1]
unbroadcast_prod_x(x::AbstractArray, y::Number, Δ) = unbroadcast_prod_x(x, to_same_device([y], x), Δ)
unbroadcast_prod_y(x::AbstractArray, y::Number, Δ) = unbroadcast_prod_y(x, to_same_device([y], x), Δ)[1]
unbroadcast_prod_y(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_y(to_same_device([x], y), y, Δ)


# Undo a vector transpose/adjoint: map a row-shaped gradient back to a column
# vector. Lazy wrappers are unwrapped by applying the same operation again.
function untranspose_vec(ds::Transpose{T, <:AbstractVector{T}}) where T
    return transpose(ds)
end

function untranspose_vec(ds::Adjoint{T, <:AbstractVector{T}}) where T
    return adjoint(ds)
end

# Plain 1×N matrices: transpose to N×1, then drop the singleton second dimension.
function untranspose_vec(ds::AbstractMatrix)
    return dropdims(transpose(ds); dims=2)
end


# function unvcat(dy::AbstractArray, n::Int, arrs::AbstractArray...)
# a = arrs[n]
# from = n == 1 ? 1 : sum(size(arr, 1) for arr in arrs[1:n-1]) + 1
# to = from + size(a, 1) - 1
# return dy[from:to, [(:) for i=1:length(size(dy)) - 1]...]
# end


# function unhcat(dy::AbstractArray, n::Int, arrs::AbstractArray...)
# a = arrs[n]
# from = n == 1 ? 1 : sum(size(arr, 2) for arr in arrs[1:n-1]) + 1
# to = from + size(a, 2) - 1
# return dy[:, from:to, [(:) for i=1:length(size(dy)) - 2]...]
# end


"""
    uncat(dy, n, arrs...; dims)

Adjoint of `cat(arrs...; dims)` for its `n`-th argument: slice out of `dy`
the region along dimension `dims` that corresponds to `arrs[n]`.

Only a single integer `dims` is supported.

# Throws
- `ArgumentError` if `dims` is not a single integer.
"""
function uncat(dy::AbstractArray, n::Int, arrs::AbstractArray...; dims)
    # Input validation must not rely on @assert: asserts may be disabled at
    # higher optimization levels, and callers should see an ArgumentError.
    dims isa Integer || throw(ArgumentError(
        "Can only undo cat() over a single dimension, " *
        "but dimensions $dims were provided"))
    dim = dims
    a = arrs[n]
    # Offset of arrs[n] along `dim`: total extent of the preceding arrays.
    from = n == 1 ? 1 : sum(size(arr, dim) for arr in arrs[1:n-1]) + 1
    to = from + size(a, dim) - 1
    # Take the full range in every dimension except `dim`.
    return dy[ntuple(_ -> Colon(), dim - 1)..., from:to,
              ntuple(_ -> Colon(), ndims(dy) - dim)...]
end


"""
    namedtuple(names, values)
    namedtuple(d::Dict)

Construct a `NamedTuple` from a tuple of field names and matching values,
or from a `Dict` whose keys become the field names.
"""
function namedtuple(names, values)
    return NamedTuple{names}(values)
end

function namedtuple(d::Dict)
    ks = tuple(keys(d)...)
    return NamedTuple{ks}(values(d))
end


"""
    rev_perm(perm::NTuple{N, Int})

Return the inverse of permutation `perm` as a tuple: the result maps each
destination position back to its source position.
"""
function rev_perm(perm::NTuple{N, Int}) where N
    inverse = Vector{Int}(undef, N)
    for src in 1:N
        inverse[perm[src]] = src
    end
    return tuple(inverse...)
end

# Adjoint of permutedims: applying the inverse permutation to dy restores the
# original axis order.
∇permutedims(dy, perm) = permutedims(dy, rev_perm(perm))
10 changes: 10 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
CUDA = "3"
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ using Yota: gradtape, gradcheck, update_chainrules_primitives!
using Yota: trace, compile, play!
import ChainRulesCore: Tangent, ZeroTangent

# test-only dependencies
using CUDA


include("test_helpers.jl")
include("test_grad.jl")
include("test_update.jl")
include("test_examples.jl")
88 changes: 88 additions & 0 deletions test/test_helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import Yota: ungetindex, ungetindex!


# Tests for ungetindex / ungetindex! — the adjoint of getindex: scatter the
# output gradient dy back into a zero gradient at the indexed positions.
# Repeated indices must accumulate (+).
@testset "helpers: getindex" begin
# scalar index into a vector
x = rand(4)
dx = zero(x)
dy = 1.0
@test ungetindex!(dx, x, dy, 2) == [0.0, 1, 0, 0]
@test ungetindex(x, dy, 2) == [0.0, 1, 0, 0]

# index 1 appears twice, so its gradient entries sum to 2
x = rand(4)
dx = zero(x)
dy = [1, 1, 1]
@test ungetindex!(dx, x, dy, [1, 3, 1]) == [2.0, 0, 1, 0]
@test ungetindex(x, dy, [1, 3, 1]) == [2.0, 0, 1, 0]

# colon + repeated column indices on a matrix; column 1 accumulates twice
x = rand(4, 5)
dx = zero(x)
dy = ones(4, 3)
expected = [2 0 1 0 0;
2 0 1 0 0;
2 0 1 0 0;
2 0 1 0 0.0]
@test ungetindex!(dx, x, dy, :, [1, 3, 1]) == expected
@test ungetindex(x, dy, :, [1, 3, 1]) == expected

## additional smoke tests
# (these only check that the calls run without error for various index shapes)

# range indices in both dimensions
x = rand(3, 4)
I = (1:2, 2:3)
y = x[I...]
dy = ones(size(y))
dx = zero(x)
ungetindex!(dx, x, dy, I...)

# scalar + range
x = rand(3, 4)
I = (1, 2:3)
y = x[I...]
dy = ones(size(y))
dx = zero(x)
ungetindex!(dx, x, dy, I...)

# scalar + colon
x = rand(3, 4)
I = (1, :)
y = x[I...]
dy = ones(size(y))
dx = zero(x)
ungetindex!(dx, x, dy, I...)

# scalar + vector of indices
x = rand(3, 4)
I = (1, [1, 3])
y = x[I...]
dy = ones(size(y))
dx = zero(x)
ungetindex!(dx, x, dy, I...)

# single-element Cartesian index
x = rand(3, 4)
I = (1, 2)
y = x[I...]
dy = 1.0
dx = zero(x)
ungetindex!(dx, x, dy, I...)

# single-element linear index
x = rand(3, 4)
I = (5,)
y = x[I...]
dy = 1.0
dx = zero(x)
ungetindex!(dx, x, dy, I...)

# CUDA
# Same expectations as the CPU cases above, but on GPU arrays; only run when
# a functional CUDA device is available (keeps CI without GPUs green).
if CUDA.functional()
x = rand(4) |> cu
dx = zero(x) |> cu
dy = 1.0f0
@test ungetindex!(dx, x, dy, 2) == cu([0.0, 1, 0, 0])
@test ungetindex(x, dy, 2) == cu([0.0, 1, 0, 0])

x = rand(4, 5) |> cu
dx = zero(x) |> cu
dy = ones(4, 3) |> cu
@test ungetindex!(dx, x, dy, :, [1, 3, 1]) == cu(expected)
@test ungetindex(x, dy, :, [1, 3, 1]) == cu(expected)
end

end