From a6211d0adf109323b33635b1ecaee0ff800bf50a Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 16 Aug 2020 14:11:32 +0530 Subject: [PATCH 01/23] Zygote passes for Exponential and FBM kernel --- test/basekernels/exponential.jl | 3 +-- test/basekernels/fbm.jl | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl index e890a3a15..b0b31480a 100644 --- a/test/basekernels/exponential.jl +++ b/test/basekernels/exponential.jl @@ -38,8 +38,7 @@ @test metric(GammaExponentialKernel(γ=2.0)) == SqEuclidean() @test repr(k) == "Gamma Exponential Kernel (γ = $(γ))" @test KernelFunctions.iskroncompatible(k) == true - test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff]) - @test_broken "Zygote gradient given γ" + test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ]) #Coherence : @test GammaExponentialKernel(γ=1.0)(v1,v2) ≈ SqExponentialKernel()(v1,v2) @test GammaExponentialKernel(γ=0.5)(v1,v2) ≈ ExponentialKernel()(v1,v2) diff --git a/test/basekernels/fbm.jl b/test/basekernels/fbm.jl index 77ed3b537..70078cabc 100644 --- a/test/basekernels/fbm.jl +++ b/test/basekernels/fbm.jl @@ -22,6 +22,6 @@ @test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] ≈ k(x1, x2) atol=1e-5 @test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))" - test_ADs(FBMKernel, ADs = [:ReverseDiff]) - @test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote" + test_ADs(FBMKernel, ADs = [:ReverseDiff, :Zygote]) + @test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff" end From 8704f181c64e0970498e2d43ea473682bc89eff6 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 16 Aug 2020 14:40:35 +0530 Subject: [PATCH 02/23] Zygote passes NN kernel --- src/basekernels/nn.jl | 30 +++++++++++++++++++++++++++++- test/basekernels/nn.jl | 3 +-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/basekernels/nn.jl b/src/basekernels/nn.jl index c4727e991..441a7415b 100644 --- a/src/basekernels/nn.jl +++ b/src/basekernels/nn.jl @@ -20,7 +20,35 @@ Bayesian neural network with erf (Error Function) as activation function. struct NeuralNetworkKernel <: Kernel end function (κ::NeuralNetworkKernel)(x, y) - return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y)))) + return asin(dot(x, y) / sqrt((1 + sum(abs2.(x))) * (1 + sum(abs2.(y))))) +end + +function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs) + validate_inputs(x, y) + X_2 = sum(x.X .* x.X, dims=1) + Y_2 = sum(y.X .* y.X, dims=1) + XY = x.X' * y.X + return asin.(XY ./ sqrt.((X_2 .+ 1.0)' * (Y_2 .+ 1.0))) +end + +function kernelmatrix(::NeuralNetworkKernel, x::ColVecs) + X_2_1 = sum(x.X .* x.X, dims=1) .+ 1.0 + XX = x.X' * x.X + return asin.(XX ./ sqrt.(X_2_1' * X_2_1)) +end + +function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs) + validate_inputs(x, y) + X_2 = sum(x.X .* x.X, dims=2) + Y_2 = sum(y.X .* y.X, dims=2) + XY = x.X * y.X' + return asin.(XY ./ sqrt.((X_2 .+ 1.0)' * (Y_2 .+ 1.0))) +end + +function kernelmatrix(::NeuralNetworkKernel, x::RowVecs) + X_2_1 = sum(x.X .* x.X, dims=2) .+ 1.0 + XX = x.X * x.X' + return asin.(XX ./ sqrt.(X_2_1' * X_2_1)) end Base.show(io::IO, κ::NeuralNetworkKernel) = print(io, "Neural Network Kernel") diff --git a/test/basekernels/nn.jl b/test/basekernels/nn.jl index 6d6bb272c..b021055c1 100644 --- a/test/basekernels/nn.jl +++ b/test/basekernels/nn.jl @@ -43,6 +43,5 @@ @test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4)) @test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5 - test_ADs(NeuralNetworkKernel, ADs = [:ForwardDiff, :ReverseDiff]) - @test_broken "Zygote uncompatible with BaseKernel" + test_ADs(NeuralNetworkKernel) end From 8f44c511372b66f82195ffaa011e2feb8c05477d Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 16 Aug 2020 14:49:03 +0530 Subject: [PATCH 03/23] Zygote passes Gabor kernel --- test/basekernels/gabor.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/basekernels/gabor.jl b/test/basekernels/gabor.jl index 052a53eac..6488010ee 100644 --- a/test/basekernels/gabor.jl +++ b/test/basekernels/gabor.jl @@ -17,7 +17,6 @@ @test k.ell ≈ 1.0 atol=1e-5 @test k.p ≈ 1.0 atol=1e-5 @test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)" - #test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p])#, ADs = [:ForwardDiff, :ReverseDiff]) - @test_broken "Tests failing for Zygote on differentiating through ell and p" + test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:Zygote]) # Tests are also failing randomly for ForwardDiff and ReverseDiff but randomly end From 14db1f4840e7ef759972e91c425959b5c7558599 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 16 Aug 2020 16:27:45 +0530 Subject: [PATCH 04/23] Address code review --- src/basekernels/nn.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/basekernels/nn.jl b/src/basekernels/nn.jl index 441a7415b..3f2ef194f 100644 --- a/src/basekernels/nn.jl +++ b/src/basekernels/nn.jl @@ -20,7 +20,7 @@ Bayesian neural network with erf (Error Function) as activation function. struct NeuralNetworkKernel <: Kernel end function (κ::NeuralNetworkKernel)(x, y) - return asin(dot(x, y) / sqrt((1 + sum(abs2.(x))) * (1 + sum(abs2.(y))))) + return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y)))) end function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs) @@ -28,11 +28,11 @@ function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs) X_2 = sum(x.X .* x.X, dims=1) Y_2 = sum(y.X .* y.X, dims=1) XY = x.X' * y.X - return asin.(XY ./ sqrt.((X_2 .+ 1.0)' * (Y_2 .+ 1.0))) + return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1))) end function kernelmatrix(::NeuralNetworkKernel, x::ColVecs) - X_2_1 = sum(x.X .* x.X, dims=1) .+ 1.0 + X_2_1 = sum(x.X .* x.X, dims=1) .+ 1 XX = x.X' * x.X return asin.(XX ./ sqrt.(X_2_1' * X_2_1)) end @@ -42,11 +42,11 @@ function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs) X_2 = sum(x.X .* x.X, dims=2) Y_2 = sum(y.X .* y.X, dims=2) XY = x.X * y.X' - return asin.(XY ./ sqrt.((X_2 .+ 1.0)' * (Y_2 .+ 1.0))) + return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1))) end function kernelmatrix(::NeuralNetworkKernel, x::RowVecs) - X_2_1 = sum(x.X .* x.X, dims=2) .+ 1.0 + X_2_1 = sum(x.X .* x.X, dims=2) .+ 1 XX = x.X * x.X' return asin.(XX ./ sqrt.(X_2_1' * X_2_1)) end From 90c1dff5ecf7884614a9258261376899cd3ffda7 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Thu, 20 Aug 2020 13:58:47 +0530 Subject: [PATCH 05/23] Fix mutating arrays problem for maha kernel --- src/basekernels/maha.jl | 38 +++++++++++++++++++++++++++++++++++++- test/basekernels/maha.jl | 2 +- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/basekernels/maha.jl b/src/basekernels/maha.jl index 5c06b5117..166521e9f 100644 --- a/src/basekernels/maha.jl +++ b/src/basekernels/maha.jl @@ -3,7 +3,7 @@ Mahalanobis distance-based kernel given by ```math - κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'*inv(P)*(x-y) + κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'* P *(x-y) ``` where the matrix P is the metric. @@ -20,4 +20,40 @@ kappa(κ::MahalanobisKernel, d::T) where {T<:Real} = exp(-d) metric(κ::MahalanobisKernel) = SqMahalanobis(κ.P) +function dot_perslice(A::AbstractMatrix, B::AbstractMatrix; dims=2) + return reshape(sum(A .* B, dims=3-dims), :) +end + +function Distances.pairwise( + metric::SqMahalanobis, + a::AbstractMatrix, + b::AbstractMatrix; + dims::Union{Nothing,Integer}=nothing + ) + dims = Distances.deprecated_dims(dims) + Q = metric.qmat + Qa = Q * a + Qb = Q * b + sa2 = dot_perslice(a, Qa; dims=dims) + sb2 = dot_perslice(b, Qb; dims=dims) + r = a' * Qb + r = sa2 .+ sb2' - 2 * r + return r +end + +function Distances.pairwise( + metric::SqMahalanobis, + a::AbstractMatrix, + dims::Union{Nothing,Integer}=nothing + ) + dims = Distances.deprecated_dims(dims) + Q = metric.qmat + Qa = Q * a + sa2 = dot_perslice(a, Qa; dims=dims) + r = a' * Qa + r = sa2 .+ sa2' - 2 * r + r = Symmetric(r') + return r + end + Base.show(io::IO, κ::MahalanobisKernel) = print(io, "Mahalanobis Kernel (size(P) = ", size(κ.P), ")") diff --git a/test/basekernels/maha.jl b/test/basekernels/maha.jl index e5ecba3d0..4baa8946e 100644 --- a/test/basekernels/maha.jl +++ b/test/basekernels/maha.jl @@ -11,6 +11,6 @@ @test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P)) @test kappa(ExponentialKernel(), x) == kappa(k, x) @test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))" - # test_ADs(P -> MahalanobisKernel(P), P) + test_ADs(P -> MahalanobisKernel(P), P, ADs=[:Zygote]) @test_broken "Nothing passes (problem with Mahalanobis distance in Distances)" end From dcf1f6bad31253aac8080179fb6a944988bc40f6 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 21 Aug 2020 12:54:24 +0530 Subject: [PATCH 06/23] Add adjoint for maha distance metric --- src/basekernels/maha.jl | 32 -------------------------------- src/zygote_adjoints.jl | 10 ++++++++++ 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/src/basekernels/maha.jl b/src/basekernels/maha.jl index 166521e9f..d60ea2ac8 100644 --- a/src/basekernels/maha.jl +++ b/src/basekernels/maha.jl @@ -24,36 +24,4 @@ function dot_perslice(A::AbstractMatrix, B::AbstractMatrix; dims=2) return reshape(sum(A .* B, dims=3-dims), :) end -function Distances.pairwise( - metric::SqMahalanobis, - a::AbstractMatrix, - b::AbstractMatrix; - dims::Union{Nothing,Integer}=nothing - ) - dims = Distances.deprecated_dims(dims) - Q = metric.qmat - Qa = Q * a - Qb = Q * b - sa2 = dot_perslice(a, Qa; dims=dims) - sb2 = dot_perslice(b, Qb; dims=dims) - r = a' * Qb - r = sa2 .+ sb2' - 2 * r - return r -end - -function Distances.pairwise( - metric::SqMahalanobis, - a::AbstractMatrix, - dims::Union{Nothing,Integer}=nothing - ) - dims = Distances.deprecated_dims(dims) - Q = metric.qmat - Qa = Q * a - sa2 = dot_perslice(a, Qa; dims=dims) - r = a' * Qa - r = sa2 .+ sa2' - 2 * r - r = Symmetric(r') - return r - end - Base.show(io::IO, κ::MahalanobisKernel) = print(io, "Mahalanobis Kernel (size(P) = ", size(κ.P), ")") diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index f51466fb6..d44166e0a 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -84,3 +84,13 @@ end @adjoint function Base.map(t::Transform, X::RowVecs) pullback(_map, t, X) end + +@adjoint function Distances.evaluate(dist::SqMahalanobis, a, b) + function back(Δ::Real) + B_B_inv = dist.qmat + transpose(dist.qmat) + a_b = a - b + δa = B_B_inv * a_b + return (qmat = a_b * a_b',), δa, -δa + end + return evaluate(dist::SqMahalanobis, a, b), back +end From 16e8af65ce7861799693e235918bebf56773e2af Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 21 Aug 2020 13:02:50 +0530 Subject: [PATCH 07/23] Fix zygote adjoint --- src/zygote_adjoints.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index d44166e0a..be829a230 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -85,7 +85,7 @@ end pullback(_map, t, X) end -@adjoint function Distances.evaluate(dist::SqMahalanobis, a, b) +@adjoint function (dist::Distances.SqMahalanobis)(a, b) function back(Δ::Real) B_B_inv = dist.qmat + transpose(dist.qmat) a_b = a - b From ede5879898fc509b8323b21e51e54cac076e1f1f Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 21 Aug 2020 13:28:11 +0530 Subject: [PATCH 08/23] Fix adjoint typo --- src/zygote_adjoints.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index be829a230..d0f6ef4f7 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -87,10 +87,10 @@ end @adjoint function (dist::Distances.SqMahalanobis)(a, b) function back(Δ::Real) - B_B_inv = dist.qmat + transpose(dist.qmat) + B_Bᵀ = dist.qmat + transpose(dist.qmat) a_b = a - b - δa = B_B_inv * a_b + δa = B_Bᵀ * a_b return (qmat = a_b * a_b',), δa, -δa end - return evaluate(dist::SqMahalanobis, a, b), back + return evaluate(dist, a, b), back end From e8b76ec587c1f951264f22a1c5e622b964b323c7 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 21 Aug 2020 13:33:24 +0530 Subject: [PATCH 09/23] Fix buggy version of pairwise adjoint --- src/zygote_adjoints.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index d0f6ef4f7..35dd142ca 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -94,3 +94,27 @@ end end return evaluate(dist, a, b), back end + + +# FIXME +function Distances.pairwise( + dist::SqMahalanobis, + a::AbstractMatrix, + b::AbstractMatrix; + dims::Union{Nothing,Integer}=nothing + ) + function back(Δ::AbstractMatrix) + B_B_t = dist.qmat + transpose(dist.qmat) + a_b = map( + x -> (first(last(x)) - last(last(x)))*first(x), + zip( + Δ, + Iterators.product(eachslice(a, dims=dims), eachslice(b, dims=dims)) + ) + ) + δa = reduce(hcat, sum(map(x -> B_B_t*x, a_b), dims=1)) + δB = sum(map(x -> x*transpose(x), a_b)) + return (qmat=δB,), δa, -δa + end + return Distances.pairwise(dist, a, b, dims=dims), back +end From e236aaf48cb6af7fd5ec92b89876cbfdfecf8a19 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 21 Aug 2020 13:34:18 +0530 Subject: [PATCH 10/23] Fix typo --- src/zygote_adjoints.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index 35dd142ca..e402e2825 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -104,7 +104,7 @@ function Distances.pairwise( dims::Union{Nothing,Integer}=nothing ) function back(Δ::AbstractMatrix) - B_B_t = dist.qmat + transpose(dist.qmat) + B_Bᵀ = dist.qmat + transpose(dist.qmat) a_b = map( x -> (first(last(x)) - last(last(x)))*first(x), zip( @@ -112,7 +112,7 @@ function Distances.pairwise( Iterators.product(eachslice(a, dims=dims), eachslice(b, dims=dims)) ) ) - δa = reduce(hcat, sum(map(x -> B_B_t*x, a_b), dims=1)) + δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_b), dims=1)) δB = sum(map(x -> x*transpose(x), a_b)) return (qmat=δB,), δa, -δa end From d50c73fc7bd10520a51a8927ea1bfa9f291bae85 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 21 Aug 2020 13:45:50 +0530 Subject: [PATCH 11/23] Forgot to add adjoint macro --- src/zygote_adjoints.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index e402e2825..b4d7dce04 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -97,7 +97,7 @@ end # FIXME -function Distances.pairwise( +@adjoint function Distances.pairwise( dist::SqMahalanobis, a::AbstractMatrix, b::AbstractMatrix; From 090cc8a8d85e96c94a8a1a1be70215411694b0d0 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sat, 22 Aug 2020 13:38:36 +0530 Subject: [PATCH 12/23] Add pairwise sqmahalanobis adjoint and test of sqmahalanobis --- src/zygote_adjoints.jl | 3 +-- test/zygote_adjoints.jl | 10 ++++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index b4d7dce04..549180392 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -96,7 +96,6 @@ end end -# FIXME @adjoint function Distances.pairwise( dist::SqMahalanobis, a::AbstractMatrix, @@ -112,7 +111,7 @@ end Iterators.product(eachslice(a, dims=dims), eachslice(b, dims=dims)) ) ) - δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_b), dims=1)) + δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_b), dims=2)) δB = sum(map(x -> x*transpose(x), a_b)) return (qmat=δB,), δa, -δa end diff --git a/test/zygote_adjoints.jl b/test/zygote_adjoints.jl index 5e9447b37..d661e5721 100644 --- a/test/zygote_adjoints.jl +++ b/test/zygote_adjoints.jl @@ -4,6 +4,9 @@ x = rand(rng, 5) y = rand(rng, 5) r = rand(rng, 5) + Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0)) + @assert isposdef(Q) + gzeucl = gradient(:Zygote, [x,y]) do xy evaluate(Euclidean(), xy[1], xy[2]) @@ -20,6 +23,9 @@ gzsinus = gradient(:Zygote, [x,y]) do xy evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]) end + gzsqmaha = gradient(:Zygote, [Q,x,y]) do xy + evaluate(SqMahalanobis(xy[1]), xy[2], xy[3]) + end gfeucl = gradient(:FiniteDiff, [x,y]) do xy evaluate(Euclidean(), xy[1], xy[2]) @@ -36,6 +42,9 @@ gfsinus = gradient(:FiniteDiff, [x,y]) do xy evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]) end + gfsqmaha = gradient(:FiniteDiff, [Q,x,y]) do xy + evaluate(SqMahalanobis(xy[1]), xy[2], xy[3]) + end @test all(gzeucl .≈ gfeucl) @@ -43,4 +52,5 @@ @test all(gzdotprod .≈ gfdotprod) @test all(gzdelta .≈ gfdelta) @test all(gzsinus .≈ gfsinus) + @test all(gzsqmaha .≈ gfsqmaha) end From 45c14d60dd0fe408a91680949000016e563bd973 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 24 Aug 2020 18:03:18 +0530 Subject: [PATCH 13/23] Maha kernel tests --- src/basekernels/maha.jl | 4 ---- test/basekernels/maha.jl | 18 ++++++++++++++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/basekernels/maha.jl b/src/basekernels/maha.jl index d60ea2ac8..86639840a 100644 --- a/src/basekernels/maha.jl +++ b/src/basekernels/maha.jl @@ -20,8 +20,4 @@ kappa(κ::MahalanobisKernel, d::T) where {T<:Real} = exp(-d) metric(κ::MahalanobisKernel) = SqMahalanobis(κ.P) -function dot_perslice(A::AbstractMatrix, B::AbstractMatrix; dims=2) - return reshape(sum(A .* B, dims=3-dims), :) -end - Base.show(io::IO, κ::MahalanobisKernel) = print(io, "Mahalanobis Kernel (size(P) = ", size(κ.P), ")") diff --git a/test/basekernels/maha.jl b/test/basekernels/maha.jl index 4baa8946e..4ad54c848 100644 --- a/test/basekernels/maha.jl +++ b/test/basekernels/maha.jl @@ -4,13 +4,27 @@ v1 = rand(rng, 3) v2 = rand(rng, 3) - P = rand(rng, 3, 3) + U = UpperTriangular(rand(rng, 3,3)) + P = Matrix(Cholesky(U, 'U', 0)) + @assert isposdef(P) k = MahalanobisKernel(P) @test kappa(k, x) == exp(-x) @test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P)) @test kappa(ExponentialKernel(), x) == kappa(k, x) @test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))" - test_ADs(P -> MahalanobisKernel(P), P, ADs=[:Zygote]) + + M1, M2 = rand(rng,3,2), rand(rng,3,2) + fdm = FiniteDifferences.Central(5, 1); + + + FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) = vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...)) + + @test_broken j′vp(fdm, x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2]) ≈ + Zygote.pullback(x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1) + @test all(j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1] .≈ + Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1]) + + # test_ADs(U -> MahalanobisKernel(Array(U' * U)), U, ADs=[:Zygote]) @test_broken "Nothing passes (problem with Mahalanobis distance in Distances)" end From b920c196c250ba6a22ac7cfbf43672c868a6083e Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 26 Aug 2020 13:41:18 +0530 Subject: [PATCH 14/23] Fix zygote adjoint for mahalanobis --- src/zygote_adjoints.jl | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index 549180392..b18304f08 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -89,8 +89,8 @@ end function back(Δ::Real) B_Bᵀ = dist.qmat + transpose(dist.qmat) a_b = a - b - δa = B_Bᵀ * a_b - return (qmat = a_b * a_b',), δa, -δa + δa = (B_Bᵀ * a_b) * Δ + return (qmat = (a_b * a_b') * Δ,), δa, -δa end return evaluate(dist, a, b), back end @@ -117,3 +117,25 @@ end end return Distances.pairwise(dist, a, b, dims=dims), back end + +@adjoint function Distances.pairwise( + dist::SqMahalanobis, + a::AbstractMatrix; + dims::Union{Nothing,Integer}=nothing + ) + function back(Δ::AbstractMatrix) + B_Bᵀ = dist.qmat + transpose(dist.qmat) + a_a = map( + x -> (first(last(x)) - last(last(x)))*first(x), + zip( + Δ, + Iterators.product(eachslice(a, dims=dims), eachslice(a, dims=dims)) + ) + ) + δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_a), dims=2)) + δB = sum(map(x -> x*transpose(x), a_a)) + return (qmat=δB,), δa + end + return Distances.pairwise(dist, a, b, dims=dims), back +end + From 2630adcf1bb1ba28bf1dbace855bb853361b57df Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 26 Aug 2020 13:41:54 +0530 Subject: [PATCH 15/23] Fix docs for matern --- src/basekernels/matern.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 13fb455f6..6fc973853 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -5,7 +5,9 @@ The matern kernel is a Mercer kernel given by the formula: ``` κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖)^ν K_ν(√(2ν)‖x-y‖) ``` -For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`. +For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use +[`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, +[`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`. """ struct MaternKernel{Tν<:Real} <: SimpleKernel ν::Vector{Tν} From e81cb0166904891957cb12803c71a623bac96fa4 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 26 Aug 2020 13:55:40 +0530 Subject: [PATCH 16/23] Make maha tests more readable --- test/basekernels/maha.jl | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/test/basekernels/maha.jl b/test/basekernels/maha.jl index 76a405a58..51d4f278a 100644 --- a/test/basekernels/maha.jl +++ b/test/basekernels/maha.jl @@ -18,12 +18,23 @@ fdm = FiniteDifferences.Central(5, 1); - FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) = vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...)) - - @test_broken j′vp(fdm, x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2]) ≈ - Zygote.pullback(x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1) - @test all(j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1] .≈ - Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1]) + function FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) + return vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...)) + end + + function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector) + return MahalanobisKernel(Array(U'*U))(v1, v2) + end + + @test_broken all(j′vp(fdm, test_mahakernel, 1, U, v1, v2) .≈ + Zygote.pullback(test_mahakernel, U, v1, v2)[2](1)) + + function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector) + return SqMahalanobis(Array(U'*U))(v1, v2) + end + + @test_broken all(j′vp(fdm, test_sqmaha, 1, U, v1, v2) .≈ + Zygote.pullback(test_sqmaha, U, v1, v2)[2](1)) # test_ADs(U -> MahalanobisKernel(Array(U' * U)), U, ADs=[:Zygote]) @test_broken "Nothing passes (problem with Mahalanobis distance in Distances)" From 4c2f2334dbcafb1829e8722465e9c0c97762c27e Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 31 Aug 2020 12:20:11 +0530 Subject: [PATCH 17/23] Address style issues --- src/basekernels/nn.jl | 12 ++++++------ src/zygote_adjoints.jl | 30 +++++++++++++++--------------- test/basekernels/nn.jl | 4 ++-- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/basekernels/nn.jl b/src/basekernels/nn.jl index 3f2ef194f..be1d35361 100644 --- a/src/basekernels/nn.jl +++ b/src/basekernels/nn.jl @@ -25,28 +25,28 @@ end function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs) validate_inputs(x, y) - X_2 = sum(x.X .* x.X, dims=1) - Y_2 = sum(y.X .* y.X, dims=1) + X_2 = sum(x.X .* x.X; dims=1) + Y_2 = sum(y.X .* y.X; dims=1) XY = x.X' * y.X return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1))) end function kernelmatrix(::NeuralNetworkKernel, x::ColVecs) - X_2_1 = sum(x.X .* x.X, dims=1) .+ 1 + X_2_1 = sum(x.X .* x.X; dims=1) .+ 1 XX = x.X' * x.X return asin.(XX ./ sqrt.(X_2_1' * X_2_1)) end function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs) validate_inputs(x, y) - X_2 = sum(x.X .* x.X, dims=2) - Y_2 = sum(y.X .* y.X, dims=2) + X_2 = sum(x.X .* x.X; dims=2) + Y_2 = sum(y.X .* y.X; dims=2) XY = x.X * y.X' return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1))) end function kernelmatrix(::NeuralNetworkKernel, x::RowVecs) - X_2_1 = sum(x.X .* x.X, dims=2) .+ 1 + X_2_1 = sum(x.X .* x.X; dims=2) .+ 1 XX = x.X * x.X' return asin.(XX ./ sqrt.(X_2_1' * X_2_1)) end diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index b18304f08..1bfda61c8 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -62,19 +62,19 @@ end @adjoint function ColVecs(X::AbstractMatrix) back(Δ::NamedTuple) = (Δ.X,) back(Δ::AbstractMatrix) = (Δ,) - function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) + function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) throw(error("In slow method")) end - return ColVecs(X), back + return ColVecs(X), ColVecs_pullback end @adjoint function RowVecs(X::AbstractMatrix) back(Δ::NamedTuple) = (Δ.X,) back(Δ::AbstractMatrix) = (Δ,) - function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) + function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) throw(error("In slow method")) end - return RowVecs(X), back + return RowVecs(X), RowVecs_pullback end @adjoint function Base.map(t::Transform, X::ColVecs) @@ -86,13 +86,13 @@ end end @adjoint function (dist::Distances.SqMahalanobis)(a, b) - function back(Δ::Real) + function SqMahalanobis_pullback(Δ::Real) B_Bᵀ = dist.qmat + transpose(dist.qmat) a_b = a - b δa = (B_Bᵀ * a_b) * Δ return (qmat = (a_b * a_b') * Δ,), δa, -δa end - return evaluate(dist, a, b), back + return evaluate(dist, a, b), SqMahalanobis_pullback end @@ -101,8 +101,8 @@ end a::AbstractMatrix, b::AbstractMatrix; dims::Union{Nothing,Integer}=nothing - ) - function back(Δ::AbstractMatrix) +) + function pairwise_pullback(Δ::AbstractMatrix) B_Bᵀ = dist.qmat + transpose(dist.qmat) a_b = map( x -> (first(last(x)) - last(last(x)))*first(x), @@ -115,15 +115,15 @@ end δB = sum(map(x -> x*transpose(x), a_b)) return (qmat=δB,), δa, -δa end - return Distances.pairwise(dist, a, b, dims=dims), back + return Distances.pairwise(dist, a, b, dims=dims), pairwise_pullback end @adjoint function Distances.pairwise( - dist::SqMahalanobis, - a::AbstractMatrix; - dims::Union{Nothing,Integer}=nothing - ) - function back(Δ::AbstractMatrix) + dist::SqMahalanobis, + a::AbstractMatrix; + dims::Union{Nothing,Integer}=nothing +) + function pairwise_pullback(Δ::AbstractMatrix) B_Bᵀ = dist.qmat + transpose(dist.qmat) a_a = map( x -> (first(last(x)) - last(last(x)))*first(x), @@ -136,6 +136,6 @@ end δB = sum(map(x -> x*transpose(x), a_a)) return (qmat=δB,), δa end - return Distances.pairwise(dist, a, b, dims=dims), back + return Distances.pairwise(dist, a, b, dims=dims), pairwise_pullback end diff --git a/test/basekernels/nn.jl b/test/basekernels/nn.jl index b021055c1..a46208505 100644 --- a/test/basekernels/nn.jl +++ b/test/basekernels/nn.jl @@ -38,8 +38,8 @@ @test kerneldiagmatrix(k, m1) ≈ A4 atol=1e-5 A5 = ones(4,4) - @test_throws AssertionError kernelmatrix!(A5, k, m1, m2, obsdim=3) - @test_throws AssertionError kernelmatrix!(A5, k, m1, obsdim=3) + @test_throws AssertionError kernelmatrix!(A5, k, m1, m2; obsdim=3) + @test_throws AssertionError kernelmatrix!(A5, k, m1; obsdim=3) @test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4)) @test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5 From 0023292a4859942f0988c1f8cad9cd0a3f34fc2c Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 7 Sep 2020 14:18:26 +0530 Subject: [PATCH 18/23] Fix bugs in tests and adjoints --- src/zygote_adjoints.jl | 8 ++++---- test/basekernels/maha.jl | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index 1bfda61c8..7032a502b 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -60,8 +60,8 @@ end end @adjoint function ColVecs(X::AbstractMatrix) - back(Δ::NamedTuple) = (Δ.X,) - back(Δ::AbstractMatrix) = (Δ,) + ColVecs_pullback(Δ::NamedTuple) = (Δ.X,) + ColVecs_pullback(Δ::AbstractMatrix) = (Δ,) function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) throw(error("In slow method")) end @@ -69,8 +69,8 @@ end end @adjoint function RowVecs(X::AbstractMatrix) - back(Δ::NamedTuple) = (Δ.X,) - back(Δ::AbstractMatrix) = (Δ,) + RowVecs_pullback(Δ::NamedTuple) = (Δ.X,) + RowVecs_pullback(Δ::AbstractMatrix) = (Δ,) function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) throw(error("In slow method")) end diff --git a/test/basekernels/maha.jl b/test/basekernels/maha.jl index 51d4f278a..a54d59313 100644 --- a/test/basekernels/maha.jl +++ b/test/basekernels/maha.jl @@ -7,7 +7,7 @@ U = UpperTriangular(rand(rng, 3,3)) P = Matrix(Cholesky(U, 'U', 0)) @assert isposdef(P) - k = MahalanobisKernel(P) + k = MahalanobisKernel(P=P) @test kappa(k, x) == exp(-x) @test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P)) @@ -23,7 +23,7 @@ end function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector) - return MahalanobisKernel(Array(U'*U))(v1, v2) + return MahalanobisKernel(P=Array(U'*U))(v1, v2) end @test_broken all(j′vp(fdm, test_mahakernel, 1, U, v1, v2) .≈ @@ -36,7 +36,7 @@ @test_broken all(j′vp(fdm, test_sqmaha, 1, U, v1, v2) .≈ Zygote.pullback(test_sqmaha, U, v1, v2)[2](1)) - # test_ADs(U -> MahalanobisKernel(Array(U' * U)), U, ADs=[:Zygote]) + # test_ADs(U -> MahalanobisKernel(P=Array(U' * U)), U, ADs=[:Zygote]) @test_broken "Nothing passes (problem with Mahalanobis distance in Distances)" test_params(k, (P,)) From acdec1a688a5ce40c6259e7efbf9f5c2ced04539 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 7 Sep 2020 14:43:47 +0530 Subject: [PATCH 19/23] Fix maha tests --- test/basekernels/maha.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/basekernels/maha.jl b/test/basekernels/maha.jl index a54d59313..817d2a600 100644 --- a/test/basekernels/maha.jl +++ b/test/basekernels/maha.jl @@ -21,20 +21,21 @@ function FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) return vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...)) end + a = rand() function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector) return MahalanobisKernel(P=Array(U'*U))(v1, v2) end - @test_broken all(j′vp(fdm, test_mahakernel, 1, U, v1, v2) .≈ - Zygote.pullback(test_mahakernel, U, v1, v2)[2](1)) + @test all(FiniteDifferences.j′vp(fdm, test_mahakernel, a, U, v1, v2)[1] .≈ + UpperTriangular(Zygote.pullback(test_mahakernel, U, v1, v2)[2](a)[1])) function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector) return SqMahalanobis(Array(U'*U))(v1, v2) end - @test_broken all(j′vp(fdm, test_sqmaha, 1, U, v1, v2) .≈ - Zygote.pullback(test_sqmaha, U, v1, v2)[2](1)) + @test all(FiniteDifferences.j′vp(fdm, test_sqmaha, a, U, v1, v2)[1] .≈ + UpperTriangular(Zygote.pullback(test_sqmaha, U, v1, v2)[2](a)[1])) # test_ADs(U -> MahalanobisKernel(P=Array(U' * U)), U, ADs=[:Zygote]) @test_broken "Nothing passes (problem with Mahalanobis distance in Distances)" From f467162390a10fe960ced0037ccda4ec58c03095 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 7 Sep 2020 14:55:44 +0530 Subject: [PATCH 20/23] Remove pairwise maha adjoints for now. --- src/zygote_adjoints.jl | 45 ---------------------------------------- test/basekernels/maha.jl | 2 +- 2 files changed, 1 insertion(+), 46 deletions(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index 7032a502b..9dd56d3d3 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -94,48 +94,3 @@ end end return evaluate(dist, a, b), SqMahalanobis_pullback end - - -@adjoint function Distances.pairwise( - dist::SqMahalanobis, - a::AbstractMatrix, - b::AbstractMatrix; - dims::Union{Nothing,Integer}=nothing -) - function pairwise_pullback(Δ::AbstractMatrix) - B_Bᵀ = dist.qmat + transpose(dist.qmat) - a_b = map( - x -> (first(last(x)) - last(last(x)))*first(x), - zip( - Δ, - Iterators.product(eachslice(a, dims=dims), eachslice(b, dims=dims)) - ) - ) - δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_b), dims=2)) - δB = sum(map(x -> x*transpose(x), a_b)) - return (qmat=δB,), δa, -δa - end - return Distances.pairwise(dist, a, b, dims=dims), pairwise_pullback -end - -@adjoint function Distances.pairwise( - dist::SqMahalanobis, - a::AbstractMatrix; - dims::Union{Nothing,Integer}=nothing -) - function pairwise_pullback(Δ::AbstractMatrix) - B_Bᵀ = dist.qmat + transpose(dist.qmat) - a_a = map( - x -> (first(last(x)) - last(last(x)))*first(x), - zip( - Δ, - Iterators.product(eachslice(a, dims=dims), eachslice(a, dims=dims)) - ) - ) - δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_a), dims=2)) - δB = sum(map(x -> x*transpose(x), a_a)) - return (qmat=δB,), δa - end - return Distances.pairwise(dist, a, b, dims=dims), pairwise_pullback -end - diff --git a/test/basekernels/maha.jl b/test/basekernels/maha.jl index 817d2a600..43beeeb53 100644 --- a/test/basekernels/maha.jl +++ b/test/basekernels/maha.jl @@ -33,7 +33,7 @@ function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector) return SqMahalanobis(Array(U'*U))(v1, v2) end - + @test all(FiniteDifferences.j′vp(fdm, test_sqmaha, a, U, v1, v2)[1] .≈ UpperTriangular(Zygote.pullback(test_sqmaha, U, v1, v2)[2](a)[1])) From 651ae020cc828024afb55c56ca26b49dbedc765e Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 7 Sep 2020 15:38:27 +0530 Subject: [PATCH 21/23] Fix style issues --- src/zygote_adjoints.jl | 2 +- test/basekernels/maha.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index 9dd56d3d3..e3b15c115 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -92,5 +92,5 @@ end δa = (B_Bᵀ * a_b) * Δ return (qmat = (a_b * a_b') * Δ,), δa, -δa end - return evaluate(dist, a, b), SqMahalanobis_pullback + return evaluate(dist, a, b), SqMahalanobis_pullback end diff --git a/test/basekernels/maha.jl b/test/basekernels/maha.jl index 43beeeb53..4296bee83 100644 --- a/test/basekernels/maha.jl +++ b/test/basekernels/maha.jl @@ -8,7 +8,7 @@ P = Matrix(Cholesky(U, 'U', 0)) @assert isposdef(P) k = MahalanobisKernel(P=P) - + @test kappa(k, x) == exp(-x) @test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P)) @test kappa(ExponentialKernel(), x) == kappa(k, x) From 6b114d2f478261d174473aad5c53db7fae07bb07 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 7 Sep 2020 18:50:41 +0530 Subject: [PATCH 22/23] Update maha.jl --- test/basekernels/maha.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basekernels/maha.jl b/test/basekernels/maha.jl index 4296bee83..1daf3cd69 100644 --- a/test/basekernels/maha.jl +++ b/test/basekernels/maha.jl @@ -8,7 +8,7 @@ P = Matrix(Cholesky(U, 'U', 0)) @assert isposdef(P) k = MahalanobisKernel(P=P) - + @test kappa(k, x) == exp(-x) @test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P)) @test kappa(ExponentialKernel(), x) == kappa(k, x) From 86559112385180e55425bba081a2964c4ac81cd2 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 8 Sep 2020 08:21:00 +0530 Subject: [PATCH 23/23] Fix style in zygote_adjoints.jl --- test/zygote_adjoints.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/zygote_adjoints.jl b/test/zygote_adjoints.jl index d661e5721..b57750728 100644 --- a/test/zygote_adjoints.jl +++ b/test/zygote_adjoints.jl @@ -8,41 +8,41 @@ @assert isposdef(Q) - gzeucl = gradient(:Zygote, [x,y]) do xy + gzeucl = gradient(:Zygote, [x, y]) do xy evaluate(Euclidean(), xy[1], xy[2]) end - gzsqeucl = gradient(:Zygote, [x,y]) do xy + gzsqeucl = gradient(:Zygote, [x, y]) do xy evaluate(SqEuclidean(), xy[1], xy[2]) end - gzdotprod = gradient(:Zygote, [x,y]) do xy + gzdotprod = gradient(:Zygote, [x, y]) do xy evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]) end - gzdelta = gradient(:Zygote, [x,y]) do xy + gzdelta = gradient(:Zygote, [x, y]) do xy evaluate(KernelFunctions.Delta(), xy[1], xy[2]) end - gzsinus = gradient(:Zygote, [x,y]) do xy + gzsinus = gradient(:Zygote, [x, y]) do xy evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]) end - gzsqmaha = gradient(:Zygote, [Q,x,y]) do xy + gzsqmaha = gradient(:Zygote, [Q, x, y]) do xy evaluate(SqMahalanobis(xy[1]), xy[2], xy[3]) end - gfeucl = gradient(:FiniteDiff, [x,y]) do xy + gfeucl = gradient(:FiniteDiff, [x, y]) do xy evaluate(Euclidean(), xy[1], xy[2]) end - gfsqeucl = gradient(:FiniteDiff, [x,y]) do xy + gfsqeucl = gradient(:FiniteDiff, [x, y]) do xy evaluate(SqEuclidean(), xy[1], xy[2]) end - gfdotprod = gradient(:FiniteDiff, [x,y]) do xy + gfdotprod = gradient(:FiniteDiff, [x, y]) do xy evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]) end - gfdelta = gradient(:FiniteDiff, [x,y]) do xy + gfdelta = gradient(:FiniteDiff, [x, y]) do xy evaluate(KernelFunctions.Delta(), xy[1], xy[2]) end - gfsinus = gradient(:FiniteDiff, [x,y]) do xy + gfsinus = gradient(:FiniteDiff, [x, y]) do xy evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]) end - gfsqmaha = gradient(:FiniteDiff, [Q,x,y]) do xy + gfsqmaha = gradient(:FiniteDiff, [Q, x, y]) do xy evaluate(SqMahalanobis(xy[1]), xy[2], xy[3]) end