diff --git a/src/basekernels/maha.jl b/src/basekernels/maha.jl
index ca3d7436f..979bc0ed6 100644
--- a/src/basekernels/maha.jl
+++ b/src/basekernels/maha.jl
@@ -3,7 +3,7 @@
 
 Mahalanobis distance-based kernel given by
 ```math
-    κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'*inv(P)*(x-y)
+    κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'* P *(x-y)
 ```
 where the matrix P is the metric.
 
diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl
index e02a9850a..e6b8ee8ea 100644
--- a/src/basekernels/matern.jl
+++ b/src/basekernels/matern.jl
@@ -5,7 +5,9 @@ The matern kernel is a Mercer kernel given by the formula:
 ```
     κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖)^ν K_ν(√(2ν)‖x-y‖)
 ```
-For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
+For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use
+[`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref) for `n=1`,
+[`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
 """
 struct MaternKernel{Tν<:Real} <: SimpleKernel
     ν::Vector{Tν}
diff --git a/src/basekernels/nn.jl b/src/basekernels/nn.jl
index c4727e991..be1d35361 100644
--- a/src/basekernels/nn.jl
+++ b/src/basekernels/nn.jl
@@ -23,4 +23,38 @@ function (κ::NeuralNetworkKernel)(x, y)
     return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
 end
 
+# Vectorised Gram matrices. Entry (i, j) must equal the pointwise kernel above:
+# asin(xᵢ'yⱼ / sqrt((1 + ‖xᵢ‖²) * (1 + ‖yⱼ‖²))).
+function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
+    validate_inputs(x, y)
+    X_2 = sum(x.X .* x.X; dims=1)
+    Y_2 = sum(y.X .* y.X; dims=1)
+    XY = x.X' * y.X
+    # X_2 and Y_2 are 1×n and 1×m rows, so column * row is the n×m outer product.
+    return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1)))
+end
+
+function kernelmatrix(::NeuralNetworkKernel, x::ColVecs)
+    X_2_1 = sum(x.X .* x.X; dims=1) .+ 1
+    XX = x.X' * x.X
+    return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
+end
+
+function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
+    validate_inputs(x, y)
+    X_2 = sum(x.X .* x.X; dims=2)
+    Y_2 = sum(y.X .* y.X; dims=2)
+    XY = x.X * y.X'
+    # With dims=2 the squared norms are n×1 and m×1 columns, so the n×m outer
+    # product is column * row (the transpose goes on the second factor).
+    return asin.(XY ./ sqrt.((X_2 .+ 1) * (Y_2 .+ 1)'))
+end
+
+function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
+    X_2_1 = sum(x.X .* x.X; dims=2) .+ 1
+    XX = x.X * x.X'
+    # X_2_1 is an n×1 column here: column * row gives the n×n outer product.
+    return asin.(XX ./ sqrt.(X_2_1 * X_2_1'))
+end
+
 Base.show(io::IO, κ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")
diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl
index f51466fb6..e3b15c115 100644
--- a/src/zygote_adjoints.jl
+++ b/src/zygote_adjoints.jl
@@ -60,21 +60,21 @@ end
 end
 
 @adjoint function ColVecs(X::AbstractMatrix)
-    back(Δ::NamedTuple) = (Δ.X,)
-    back(Δ::AbstractMatrix) = (Δ,)
-    function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
+    ColVecs_pullback(Δ::NamedTuple) = (Δ.X,)
+    ColVecs_pullback(Δ::AbstractMatrix) = (Δ,)
+    function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
         throw(error("In slow method"))
     end
-    return ColVecs(X), back
+    return ColVecs(X), ColVecs_pullback
 end
 
 @adjoint function RowVecs(X::AbstractMatrix)
-    back(Δ::NamedTuple) = (Δ.X,)
-    back(Δ::AbstractMatrix) = (Δ,)
-    function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
+    RowVecs_pullback(Δ::NamedTuple) = (Δ.X,)
+    RowVecs_pullback(Δ::AbstractMatrix) = (Δ,)
+    function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
         throw(error("In slow method"))
     end
-    return RowVecs(X), back
+    return RowVecs(X), RowVecs_pullback
 end
 
 @adjoint function Base.map(t::Transform, X::ColVecs)
@@ -84,3 +84,16 @@ end
 @adjoint function Base.map(t::Transform, X::RowVecs)
     pullback(_map, t, X)
 end
+
+# Reverse rule assuming r² = (a - b)' * Q * (a - b) with Q = dist.qmat:
+#   ∂r²/∂Q = (a - b) * (a - b)',  ∂r²/∂a = (Q + Q') * (a - b) = -∂r²/∂b.
+@adjoint function (dist::Distances.SqMahalanobis)(a, b)
+    function SqMahalanobis_pullback(Δ::Real)
+        B_Bᵀ = dist.qmat + transpose(dist.qmat)
+        a_b = a - b
+        δa = (B_Bᵀ * a_b) * Δ
+        # Cotangents for (dist, a, b); the NamedTuple mirrors the `qmat` field.
+        return (qmat = (a_b * a_b') * Δ,), δa, -δa
+    end
+    return evaluate(dist, a, b), SqMahalanobis_pullback
+end
diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl
index 692e0983c..389ab2be5 100644
--- a/test/basekernels/exponential.jl
+++ b/test/basekernels/exponential.jl
@@ -38,8 +38,7 @@
         @test metric(GammaExponentialKernel(γ=2.0)) == SqEuclidean()
         @test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
         @test KernelFunctions.iskroncompatible(k) == true
-        test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
-        @test_broken "Zygote gradient given γ"
+        test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ])
         test_params(k, ([γ],))
         #Coherence :
         @test GammaExponentialKernel(γ=1.0)(v1,v2) ≈ SqExponentialKernel()(v1,v2)
diff --git a/test/basekernels/fbm.jl b/test/basekernels/fbm.jl
index 0428d8c80..c7b0eb620 100644
--- a/test/basekernels/fbm.jl
+++ b/test/basekernels/fbm.jl
@@ -22,8 +22,7 @@
     @test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] ≈ k(x1, x2) atol=1e-5
 
     @test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))"
-    test_ADs(FBMKernel, ADs = [:ReverseDiff])
-    @test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote"
-
+    test_ADs(FBMKernel, ADs = [:ReverseDiff, :Zygote])
+    @test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff"
     test_params(k, ([h],))
 end
diff --git a/test/basekernels/gabor.jl b/test/basekernels/gabor.jl
index 052a53eac..6488010ee 100644
--- a/test/basekernels/gabor.jl
+++ b/test/basekernels/gabor.jl
@@ -17,7 +17,6 @@
     @test k.ell ≈ 1.0 atol=1e-5
     @test k.p ≈ 1.0 atol=1e-5
     @test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
-    #test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p])#, ADs = [:ForwardDiff, :ReverseDiff])
-    @test_broken "Tests failing for Zygote on differentiating through ell and p"
+    test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:Zygote])
     # Tests are also failing randomly for ForwardDiff and ReverseDiff but randomly
 end
diff --git a/test/basekernels/maha.jl b/test/basekernels/maha.jl
index 898df7b6e..1daf3cd69 100644
--- a/test/basekernels/maha.jl
+++ b/test/basekernels/maha.jl
@@ -4,14 +4,40 @@
     v1 = rand(rng, 3)
     v2 = rand(rng, 3)
 
-    P = rand(rng, 3, 3)
+    U = UpperTriangular(rand(rng, 3,3))
+    P = Matrix(Cholesky(U, 'U', 0))
+    @assert isposdef(P)
     k = MahalanobisKernel(P=P)
 
     @test kappa(k, x) == exp(-x)
     @test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P))
     @test kappa(ExponentialKernel(), x) == kappa(k, x)
     @test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
-    # test_ADs(P -> MahalanobisKernel(P=P), P)
+
+    M1, M2 = rand(rng,3,2), rand(rng,3,2)
+    fdm = FiniteDifferences.Central(5, 1);
+
+
+    function FiniteDifferences.to_vec(dist::SqMahalanobis{Float64})
+        return vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))
+    end
+    a = rand(rng)
+
+    function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
+        return MahalanobisKernel(P=Array(U'*U))(v1, v2)
+    end
+
+    @test all(FiniteDifferences.j′vp(fdm, test_mahakernel, a, U, v1, v2)[1] .≈
+        UpperTriangular(Zygote.pullback(test_mahakernel, U, v1, v2)[2](a)[1]))
+
+    function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
+        return SqMahalanobis(Array(U'*U))(v1, v2)
+    end
+
+    @test all(FiniteDifferences.j′vp(fdm, test_sqmaha, a, U, v1, v2)[1] .≈
+        UpperTriangular(Zygote.pullback(test_sqmaha, U, v1, v2)[2](a)[1]))
+
+    # test_ADs(U -> MahalanobisKernel(P=Array(U' * U)), U, ADs=[:Zygote])
     @test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
 
     test_params(k, (P,))
diff --git a/test/basekernels/nn.jl b/test/basekernels/nn.jl
index 6d6bb272c..a46208505 100644
--- a/test/basekernels/nn.jl
+++ b/test/basekernels/nn.jl
@@ -38,11 +38,10 @@
     @test kerneldiagmatrix(k, m1) ≈ A4 atol=1e-5
 
     A5 = ones(4,4)
-    @test_throws AssertionError kernelmatrix!(A5, k, m1, m2, obsdim=3)
-    @test_throws AssertionError kernelmatrix!(A5, k, m1, obsdim=3)
+    @test_throws AssertionError kernelmatrix!(A5, k, m1, m2; obsdim=3)
+    @test_throws AssertionError kernelmatrix!(A5, k, m1; obsdim=3)
     @test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4))
 
     @test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5
-    test_ADs(NeuralNetworkKernel, ADs = [:ForwardDiff, :ReverseDiff])
-    @test_broken "Zygote uncompatible with BaseKernel"
+    test_ADs(NeuralNetworkKernel)
 end
diff --git a/test/zygote_adjoints.jl b/test/zygote_adjoints.jl
index 5e9447b37..b57750728 100644
--- a/test/zygote_adjoints.jl
+++ b/test/zygote_adjoints.jl
@@ -4,38 +4,47 @@
     x = rand(rng, 5)
     y = rand(rng, 5)
     r = rand(rng, 5)
+    Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
+    @assert isposdef(Q)
 
-    gzeucl = gradient(:Zygote, [x,y]) do xy
+
+    gzeucl = gradient(:Zygote, [x, y]) do xy
         evaluate(Euclidean(), xy[1], xy[2])
     end
-    gzsqeucl = gradient(:Zygote, [x,y]) do xy
+    gzsqeucl = gradient(:Zygote, [x, y]) do xy
         evaluate(SqEuclidean(), xy[1], xy[2])
     end
-    gzdotprod = gradient(:Zygote, [x,y]) do xy
+    gzdotprod = gradient(:Zygote, [x, y]) do xy
         evaluate(KernelFunctions.DotProduct(), xy[1], xy[2])
     end
-    gzdelta = gradient(:Zygote, [x,y]) do xy
+    gzdelta = gradient(:Zygote, [x, y]) do xy
         evaluate(KernelFunctions.Delta(), xy[1], xy[2])
     end
-    gzsinus = gradient(:Zygote, [x,y]) do xy
+    gzsinus = gradient(:Zygote, [x, y]) do xy
         evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
     end
+    gzsqmaha = gradient(:Zygote, [Q, x, y]) do xy
+        evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
+    end
 
-    gfeucl = gradient(:FiniteDiff, [x,y]) do xy
+    gfeucl = gradient(:FiniteDiff, [x, y]) do xy
         evaluate(Euclidean(), xy[1], xy[2])
     end
-    gfsqeucl = gradient(:FiniteDiff, [x,y]) do xy
+    gfsqeucl = gradient(:FiniteDiff, [x, y]) do xy
         evaluate(SqEuclidean(), xy[1], xy[2])
     end
-    gfdotprod = gradient(:FiniteDiff, [x,y]) do xy
+    gfdotprod = gradient(:FiniteDiff, [x, y]) do xy
         evaluate(KernelFunctions.DotProduct(), xy[1], xy[2])
     end
-    gfdelta = gradient(:FiniteDiff, [x,y]) do xy
+    gfdelta = gradient(:FiniteDiff, [x, y]) do xy
         evaluate(KernelFunctions.Delta(), xy[1], xy[2])
     end
-    gfsinus = gradient(:FiniteDiff, [x,y]) do xy
+    gfsinus = gradient(:FiniteDiff, [x, y]) do xy
         evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
     end
+    gfsqmaha = gradient(:FiniteDiff, [Q, x, y]) do xy
+        evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
+    end
 
 
     @test all(gzeucl .≈ gfeucl)
@@ -43,4 +52,5 @@
     @test all(gzdotprod .≈ gfdotprod)
     @test all(gzdelta .≈ gfdelta)
     @test all(gzsinus .≈ gfsinus)
+    @test all(gzsqmaha .≈ gfsqmaha)
 end