Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix AD issues with various kernels #154

Merged
merged 24 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/basekernels/maha.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

Mahalanobis distance-based kernel given by
```math
κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'*inv(P)*(x-y)
κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'* P *(x-y)
```
where the matrix P is the metric.

Expand Down
4 changes: 3 additions & 1 deletion src/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ The Matérn kernel is a Mercer kernel given by the formula:
```
κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖)^ν K_ν(√(2ν)‖x-y‖)
```
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use
[`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`,
[`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
"""
struct MaternKernel{Tν<:Real} <: SimpleKernel
ν::Vector{Tν}
Expand Down
28 changes: 28 additions & 0 deletions src/basekernels/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,32 @@ function (κ::NeuralNetworkKernel)(x, y)
return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
    validate_inputs(x, y)
    # Squared norm of every observation (columns of the storage matrices), plus 1.
    sx = sum(abs2, x.X; dims=1) .+ 1   # 1×n
    sy = sum(abs2, y.X; dims=1) .+ 1   # 1×m
    # Pairwise inner products between observations: n×m.
    inner = x.X' * y.X
    # κ(x, y) = asin(⟨x, y⟩ / √((1 + ‖x‖²)(1 + ‖y‖²))), computed for all pairs;
    # sx' * sy is the n×m matrix of denominator products.
    return asin.(inner ./ sqrt.(sx' * sy))
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs)
    # 1 + squared norm of each column observation (1×n row vector).
    s = sum(abs2, x.X; dims=1) .+ 1
    # Gram matrix of inner products between all observations (n×n).
    gram = x.X' * x.X
    # s' * s is the n×n outer product of denominator factors.
    return asin.(gram ./ sqrt.(s' * s))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
    validate_inputs(x, y)
    # Observations are rows, so the per-observation squared norms are column
    # vectors: X_2 is n×1 and Y_2 is m×1.
    X_2 = sum(x.X .* x.X; dims=2)
    Y_2 = sum(y.X .* y.X; dims=2)
    # Pairwise inner products between observations: n×m.
    XY = x.X * y.X'
    # FIX: the denominator must be the n×m outer product (X_2 .+ 1) * (Y_2 .+ 1)';
    # the previous (X_2 .+ 1)' * (Y_2 .+ 1) was (1×n)(m×1) — a DimensionMismatch
    # (or an accidental 1×1 scalar when n == m), mismatching the n×m numerator.
    return asin.(XY ./ sqrt.((X_2 .+ 1) * (Y_2 .+ 1)'))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
    # Observations are rows: X_2_1 is the n×1 column of (1 + ‖xᵢ‖²) values.
    X_2_1 = sum(x.X .* x.X; dims=2) .+ 1
    # Gram matrix of inner products between all observations (n×n).
    XX = x.X * x.X'
    # FIX: the denominator must be the n×n outer product X_2_1 * X_2_1';
    # the previous X_2_1' * X_2_1 collapsed to a 1×1 matrix, mismatching
    # the n×n numerator.
    return asin.(XX ./ sqrt.(X_2_1 * X_2_1'))
end

Base.show(io::IO, κ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")
63 changes: 59 additions & 4 deletions src/zygote_adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,19 @@ end
@adjoint function ColVecs(X::AbstractMatrix)
    # All pullback methods must share one name: the diff renamed only the
    # vector-of-vectors method to `ColVecs_pullback` while the NamedTuple and
    # AbstractMatrix methods kept the name `back`, so the returned pullback
    # raised MethodError on the common cotangent types. Rename all three.
    ColVecs_pullback(Δ::NamedTuple) = (Δ.X,)
    ColVecs_pullback(Δ::AbstractMatrix) = (Δ,)
    function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
        # A vector-of-vectors cotangent would force a slow generic path; fail loudly.
        throw(error("In slow method"))
    end
    return ColVecs(X), ColVecs_pullback
end

@adjoint function RowVecs(X::AbstractMatrix)
    # All pullback methods must share one name: the diff renamed only the
    # vector-of-vectors method to `RowVecs_pullback` while the NamedTuple and
    # AbstractMatrix methods kept the name `back`, so the returned pullback
    # raised MethodError on the common cotangent types. Rename all three.
    RowVecs_pullback(Δ::NamedTuple) = (Δ.X,)
    RowVecs_pullback(Δ::AbstractMatrix) = (Δ,)
    function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
        # A vector-of-vectors cotangent would force a slow generic path; fail loudly.
        throw(error("In slow method"))
    end
    return RowVecs(X), RowVecs_pullback
end

@adjoint function Base.map(t::Transform, X::ColVecs)
Expand All @@ -84,3 +84,58 @@ end
@adjoint function Base.map(t::Transform, X::RowVecs)
pullback(_map, t, X)
end

@adjoint function (dist::Distances.SqMahalanobis)(a, b)
    # Pullback for r² = (a - b)' B (a - b) with B = dist.qmat.
    function SqMahalanobis_pullback(Δ::Real)
        diff = a - b
        # ∂/∂x (x' B x) = (B + Bᵀ) x, so use the symmetrised metric matrix.
        B_sym = dist.qmat + transpose(dist.qmat)
        grad_a = (B_sym * diff) * Δ
        # ∂/∂B of (a-b)' B (a-b) is the outer product (a-b)(a-b)'.
        return (qmat = (diff * diff') * Δ,), grad_a, -grad_a
    end
    return evaluate(dist, a, b), SqMahalanobis_pullback
end


@adjoint function Distances.pairwise(
    dist::SqMahalanobis,
    a::AbstractMatrix,
    b::AbstractMatrix;
    dims::Union{Nothing,Integer}=nothing
)
    # Pullback for Δᵢⱼ-weighted sum of rᵢⱼ² = (aᵢ - bⱼ)' B (aᵢ - bⱼ).
    # FIX vs previous version (flagged in review): the old code folded Δ into the
    # differences before forming the outer products, so δB carried a factor Δ²
    # instead of Δ; and it returned δb = -δa, which is only valid for square,
    # symmetric Δ — the b-gradient must sum over i for each column j.
    function pairwise_pullback(Δ::AbstractMatrix)
        # ∂/∂x (x' B x) = (B + Bᵀ) x.
        B_Bᵀ = dist.qmat + transpose(dist.qmat)
        a_obs = collect(eachslice(a; dims=dims))
        b_obs = collect(eachslice(b; dims=dims))
        δB = zero(dist.qmat)
        δa_obs = [zero(collect(x)) for x in a_obs]
        δb_obs = [zero(collect(y)) for y in b_obs]
        for (j, bⱼ) in enumerate(b_obs), (i, aᵢ) in enumerate(a_obs)
            d_ij = aᵢ - bⱼ
            # δB = Σᵢⱼ Δᵢⱼ (aᵢ - bⱼ)(aᵢ - bⱼ)'.
            δB .+= Δ[i, j] .* (d_ij * d_ij')
            g = Δ[i, j] .* (B_Bᵀ * d_ij)
            δa_obs[i] .+= g
            δb_obs[j] .-= g
        end
        δa = reduce(hcat, δa_obs)
        δb = reduce(hcat, δb_obs)
        if dims == 1
            # Observations were rows: transpose back to the input layout.
            δa = copy(transpose(δa))
            δb = copy(transpose(δb))
        end
        return (qmat=δB,), δa, δb
    end
    return Distances.pairwise(dist, a, b; dims=dims), pairwise_pullback
end

@adjoint function Distances.pairwise(
    dist::SqMahalanobis,
    a::AbstractMatrix;
    dims::Union{Nothing,Integer}=nothing
)
    # Pullback for Δᵢⱼ-weighted sum of rᵢⱼ² = (aᵢ - aⱼ)' B (aᵢ - aⱼ).
    # FIXES vs previous version: (1) the forward call referenced an undefined
    # `b` (UndefVarError) — this is the single-matrix method; (2) δB carried a
    # factor Δ² because Δ was folded into the differences before the outer
    # products; (3) δa ignored the contribution of `a` appearing in the second
    # argument slot of each pair.
    function pairwise_pullback(Δ::AbstractMatrix)
        # ∂/∂x (x' B x) = (B + Bᵀ) x.
        B_Bᵀ = dist.qmat + transpose(dist.qmat)
        obs = collect(eachslice(a; dims=dims))
        δB = zero(dist.qmat)
        δ_obs = [zero(collect(x)) for x in obs]
        for (j, xⱼ) in enumerate(obs), (i, xᵢ) in enumerate(obs)
            d_ij = xᵢ - xⱼ
            # δB = Σᵢⱼ Δᵢⱼ (aᵢ - aⱼ)(aᵢ - aⱼ)'.
            δB .+= Δ[i, j] .* (d_ij * d_ij')
            g = Δ[i, j] .* (B_Bᵀ * d_ij)
            δ_obs[i] .+= g   # a appearing in the first slot of the pair
            δ_obs[j] .-= g   # a appearing in the second slot of the pair
        end
        δa = reduce(hcat, δ_obs)
        # Observations were rows: transpose back to the input layout.
        dims == 1 && (δa = copy(transpose(δa)))
        return (qmat=δB,), δa
    end
    return Distances.pairwise(dist, a; dims=dims), pairwise_pullback
end

3 changes: 1 addition & 2 deletions test/basekernels/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
@test metric(GammaExponentialKernel(γ=2.0)) == SqEuclidean()
@test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
@test KernelFunctions.iskroncompatible(k) == true
test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Zygote gradient given γ"
test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ])
test_params(k, ([γ],))
#Coherence :
@test GammaExponentialKernel(γ=1.0)(v1,v2) ≈ SqExponentialKernel()(v1,v2)
Expand Down
5 changes: 2 additions & 3 deletions test/basekernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] ≈ k(x1, x2) atol=1e-5

@test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))"
test_ADs(FBMKernel, ADs = [:ReverseDiff])
@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote"

test_ADs(FBMKernel, ADs = [:ReverseDiff, :Zygote])
@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff"
devmotion marked this conversation as resolved.
Show resolved Hide resolved
test_params(k, ([h],))
end
3 changes: 1 addition & 2 deletions test/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
@test k.ell ≈ 1.0 atol=1e-5
@test k.p ≈ 1.0 atol=1e-5
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
#test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p])#, ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Tests failing for Zygote on differentiating through ell and p"
test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:Zygote])
devmotion marked this conversation as resolved.
Show resolved Hide resolved
# Tests are also failing randomly for ForwardDiff and ReverseDiff but randomly
end
33 changes: 29 additions & 4 deletions test/basekernels/maha.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,39 @@
v1 = rand(rng, 3)
v2 = rand(rng, 3)

P = rand(rng, 3, 3)
k = MahalanobisKernel(P=P)

U = UpperTriangular(rand(rng, 3,3))
P = Matrix(Cholesky(U, 'U', 0))
@assert isposdef(P)
k = MahalanobisKernel(P)

sharanry marked this conversation as resolved.
Show resolved Hide resolved
@test kappa(k, x) == exp(-x)
@test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P))
@test kappa(ExponentialKernel(), x) == kappa(k, x)
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
# test_ADs(P -> MahalanobisKernel(P=P), P)

M1, M2 = rand(rng,3,2), rand(rng,3,2)
fdm = FiniteDifferences.Central(5, 1);


function FiniteDifferences.to_vec(dist::SqMahalanobis{Float64})
    # Flatten the metric matrix so FiniteDifferences can perturb it, and return
    # a closure that rebuilds a SqMahalanobis from the flat vector.
    Q = dist.qmat
    rebuild(v) = SqMahalanobis(reshape(v, size(Q)...))
    return vec(Q), rebuild
end

function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
    # Build a positive-definite metric P = UᵀU from the triangular factor and
    # evaluate the Mahalanobis kernel on the pair of vectors.
    P = Array(U' * U)
    return MahalanobisKernel(P)(v1, v2)
end

@test_broken all(j′vp(fdm, test_mahakernel, 1, U, v1, v2) .≈
Zygote.pullback(test_mahakernel, U, v1, v2)[2](1))

function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
    # Squared Mahalanobis distance with metric UᵀU, built from the triangular factor.
    metric = SqMahalanobis(Array(U' * U))
    return metric(v1, v2)
end

@test_broken all(j′vp(fdm, test_sqmaha, 1, U, v1, v2) .≈
Zygote.pullback(test_sqmaha, U, v1, v2)[2](1))

# test_ADs(U -> MahalanobisKernel(Array(U' * U)), U, ADs=[:Zygote])
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"

test_params(k, (P,))
Expand Down
7 changes: 3 additions & 4 deletions test/basekernels/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@
@test kerneldiagmatrix(k, m1) ≈ A4 atol=1e-5

A5 = ones(4,4)
@test_throws AssertionError kernelmatrix!(A5, k, m1, m2, obsdim=3)
@test_throws AssertionError kernelmatrix!(A5, k, m1, obsdim=3)
@test_throws AssertionError kernelmatrix!(A5, k, m1, m2; obsdim=3)
@test_throws AssertionError kernelmatrix!(A5, k, m1; obsdim=3)
@test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4))

@test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5
test_ADs(NeuralNetworkKernel, ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Zygote uncompatible with BaseKernel"
test_ADs(NeuralNetworkKernel)
end
10 changes: 10 additions & 0 deletions test/zygote_adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
x = rand(rng, 5)
y = rand(rng, 5)
r = rand(rng, 5)
Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
@assert isposdef(Q)


gzeucl = gradient(:Zygote, [x,y]) do xy
evaluate(Euclidean(), xy[1], xy[2])
Expand All @@ -20,6 +23,9 @@
gzsinus = gradient(:Zygote, [x,y]) do xy
evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
end
gzsqmaha = gradient(:Zygote, [Q,x,y]) do xy
sharanry marked this conversation as resolved.
Show resolved Hide resolved
evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
end

gfeucl = gradient(:FiniteDiff, [x,y]) do xy
evaluate(Euclidean(), xy[1], xy[2])
Expand All @@ -36,11 +42,15 @@
gfsinus = gradient(:FiniteDiff, [x,y]) do xy
evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
end
gfsqmaha = gradient(:FiniteDiff, [Q,x,y]) do xy
sharanry marked this conversation as resolved.
Show resolved Hide resolved
evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
end


@test all(gzeucl .≈ gfeucl)
@test all(gzsqeucl .≈ gfsqeucl)
@test all(gzdotprod .≈ gfdotprod)
@test all(gzdelta .≈ gfdelta)
@test all(gzsinus .≈ gfsinus)
@test all(gzsqmaha .≈ gfsqmaha)
end