Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradient issues with kernelmatrix_diag and use ChainRulesCore #208

Merged
merged 50 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
e525614
Use broadcasting instead of map for kerneldiagmatrix
theogf Dec 9, 2020
e56492a
Removed method for transformedkernel
theogf Dec 9, 2020
35a6306
Restored functions and applied suggestions
theogf Dec 14, 2020
25e5efd
Added tests for diagmatrix
theogf Dec 14, 2020
2f85ebc
Put changes to the right file and removed utils_AD.jl
theogf Dec 14, 2020
cae225f
Apply suggestions from code review
theogf Dec 14, 2020
3f16f07
Added colwise and fixed kerneldiagmatrix
theogf Dec 15, 2020
8c0d0a2
Added colwise for RowVecs and ColVecs
theogf Dec 16, 2020
13a10fd
Removed definition relying on Distances.colwise!
theogf Dec 21, 2020
78a2078
Merge branch 'master' into fix_diagmat
theogf Mar 16, 2021
5ca94e7
Readapt to kernelmatrix_diag
theogf Mar 16, 2021
2c60abd
Fixes for Zygote
theogf Mar 16, 2021
9214211
Remove type piracy
theogf Mar 16, 2021
87edbc8
Adding some adjoints (not everything fixed yet)
theogf Mar 17, 2021
f65556b
Fixed adjoint for polynomials
theogf Mar 17, 2021
48e2dcb
Add ChainRulesCore for defining rrule
theogf Mar 17, 2021
6cc803d
Replace broadcast by map
theogf Mar 17, 2021
0e30941
Missing return for style
theogf Mar 17, 2021
61869b1
Fixing ZygoteRules
theogf Mar 22, 2021
06bd4f0
Renamed zygote_adjoints to chainrules
theogf Mar 22, 2021
8e1e516
Apply formatting suggestions
theogf Mar 22, 2021
aaa16de
Added forward rule for Euclidean distance
theogf Mar 22, 2021
52b1ae5
Corrected rules for Row/ColVecs constructors
theogf Mar 22, 2021
4067a42
Added ZygoteRules back for the "map hack"
theogf Mar 22, 2021
641ebee
Corrected the rrules
theogf Mar 22, 2021
13d1e39
Type stable frule
theogf Mar 22, 2021
4675c2f
Corrected tests
theogf Mar 23, 2021
0b97c1a
Adapted the use of Distances.jl
theogf Mar 23, 2021
ad9838e
Added methods to make nn work
theogf Mar 23, 2021
650dc08
Missing kernelmatrix_diag
theogf Mar 23, 2021
1703db1
Formatting suggestions
theogf Mar 23, 2021
e2cd167
Added methods for FBM
theogf Mar 23, 2021
01ffac0
Last fix on Delta
theogf Mar 23, 2021
9bfb6eb
Potential fix for Euclidean
theogf Mar 23, 2021
f3fa4bc
Missing Distances.
theogf Mar 23, 2021
a0c2a64
Wrong file naming
theogf Mar 23, 2021
ff5a66b
Correct formatting
theogf Mar 23, 2021
8157b4c
Better error message
theogf Mar 23, 2021
e6bfdb1
Moar formatting
theogf Mar 23, 2021
db5e7b8
Applied suggestions
theogf Mar 24, 2021
a44a762
Fixed the dims issue with pairwise
theogf Mar 24, 2021
72889dd
Fixed formatting
theogf Mar 24, 2021
25549c1
Missing @thunk
theogf Mar 24, 2021
bbe5c7c
Putting back Composite to Any
theogf Mar 24, 2021
e08dbf4
add @thunk for -delta a
theogf Mar 24, 2021
48bd681
Update src/chainrules.jl
theogf Mar 25, 2021
3298d34
Update KernelFunctions.jl
theogf Mar 25, 2021
0b99771
Apply suggestions from code review
theogf Mar 25, 2021
c26edf3
Update Project.toml
theogf Mar 25, 2021
647862a
Merge branch 'master' into fix_diagmat
theogf Mar 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.8.24"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -17,6 +18,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRulesCore = "0.9"
Compat = "3.7"
Distances = "0.9.1, 0.10"
Functors = "0.1"
Expand Down
5 changes: 3 additions & 2 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ export IndependentMOKernel, LatentFactorMOKernel
export tensor, ⊗

using Compat
using ChainRulesCore
using Requires
using Distances, LinearAlgebra
using Functors
using SpecialFunctions: loggamma, besselk, polygamma
using ZygoteRules: @adjoint, pullback
# using ZygoteRules: @adjoint, pullback, ZygoteRules
using StatsFuns: logtwo
using StatsBase
using TensorCore
Expand Down Expand Up @@ -112,7 +113,7 @@ include(joinpath("mokernels", "moinput.jl"))
include(joinpath("mokernels", "independent.jl"))
include(joinpath("mokernels", "slfm.jl"))

include("zygote_adjoints.jl")
include("chainrules.jl")

include("test_utils.jl")

Expand Down
142 changes: 142 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
## Reverse Rules Delta

function rrule(::typeof(Distances.evaluate), s::Delta, x::AbstractVector, y::AbstractVector)
d = evaluate(s, x, y)
function evaluate_pullback(::Any)
theogf marked this conversation as resolved.
Show resolved Hide resolved
return NO_FIELDS, Zero(), Zero()
end
return d, evaluate_pullback
end

function rrule(
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2
)
P = Distances.pairwise(d, X, Y; dims=dims)
function pairwise_pullback(::Any)
theogf marked this conversation as resolved.
Show resolved Hide resolved
return NO_FIELDS, Zero(), Zero()
end
return P, pairwise_pullback
end

function rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2)
P = Distances.pairwise(d, X; dims=dims)
function pairwise_pullback(::Any)
theogf marked this conversation as resolved.
Show resolved Hide resolved
return NO_FIELDS, Zero()
end
return P, pairwise_pullback
end

function rrule(::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix)
C = Distances.colwise(d, X, Y)
function colwise_pullback(::AbstractVector)
return NO_FIELDS, Zero(), Zero()
end
return C, colwise_pullback
end

## Reverse Rules DotProduct
function rrule(
::typeof(Distances.evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector
)
d = dot(x, y)
function evaluate_pullback(Δ)
return NO_FIELDS, Δ .* y, Δ .* x
end
return d, evaluate_pullback
end

function rrule(
::typeof(Distances.pairwise),
d::DotProduct,
X::AbstractMatrix,
Y::AbstractMatrix;
dims=2,
)
P = Distances.pairwise(d, X, Y; dims=dims)
if dims == 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this check lead to any type inference problems?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean the if dims == 1 ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Mainly because you return two different functions depending on the value of dim, so I wonder if it messes up type inference and if it would be better to move the check inside the pullback function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I don't think the dims argument passed to the pullback ?

function pairwise_pullback_cols(Δ)
return NO_FIELDS, Δ * Y, Δ' * X
end
return P, pairwise_pullback_cols
else
function pairwise_pullback_rows(Δ)
return NO_FIELDS, Y * Δ', X * Δ
end
return P, pairwise_pullback_rows
end
end

function rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2)
P = Distances.pairwise(d, X; dims=dims)
if dims == 1
function pairwise_pullback_cols(Δ)
return NO_FIELDS, 2 * Δ * X
end
return P, pairwise_pullback_cols
else
function pairwise_pullback_rows(Δ)
return NO_FIELDS, 2 * X * Δ
end
return P, pairwise_pullback_rows
end
end

function rrule(
::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix
)
C = Distances.colwise(d, X, Y)
function colwise_pullback(Δ::AbstractVector)
return (nothing, Δ' .* Y, Δ' .* X)
theogf marked this conversation as resolved.
Show resolved Hide resolved
end
return C, colwise_pullback
end

## Reverse Rules Sinus
function rrule(::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::AbstractVector)
d = (x - y)
theogf marked this conversation as resolved.
Show resolved Hide resolved
sind = sinpi.(d)
val = sum(abs2, sind ./ s.r)
gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2)
theogf marked this conversation as resolved.
Show resolved Hide resolved
function evaluate_pullback(Δ)
return (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx
theogf marked this conversation as resolved.
Show resolved Hide resolved
end
return val, evaluate_pullback
end

## Reverse Rules for matrix wrappers

function rrule(::ColVecs, X::AbstractMatrix)
ColVecs_pullback(Δ::NamedTuple) = (Δ.X,)
ColVecs_pullback(Δ::AbstractMatrix) = (Δ,)
theogf marked this conversation as resolved.
Show resolved Hide resolved
function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
return throw(error("In slow method"))
theogf marked this conversation as resolved.
Show resolved Hide resolved
end
return ColVecs(X), ColVecs_pullback
end

function rrule(::RowVecs, X::AbstractMatrix)
RowVecs_pullback(Δ::NamedTuple) = (Δ.X,)
RowVecs_pullback(Δ::AbstractMatrix) = (Δ,)
theogf marked this conversation as resolved.
Show resolved Hide resolved
function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
return throw(error("In slow method"))
end
return RowVecs(X), RowVecs_pullback
end

theogf marked this conversation as resolved.
Show resolved Hide resolved
# function rrule(::typeof(Base.map), t::Transform, X::ColVecs)
# return pullback(_map, t, X)
# end

# function rrule(::typeof(Base.map), t::Transform, X::RowVecs)
# return pullback(_map, t, X)
# end

# @adjoint function (dist::Distances.SqMahalanobis)(a, b)
# function SqMahalanobis_pullback(Δ::Real)
# B_Bᵀ = dist.qmat + transpose(dist.qmat)
# a_b = a - b
# δa = (B_Bᵀ * a_b) * Δ
# return (qmat=(a_b * a_b') * Δ,), δa, -δa
# end
# return evaluate(dist, a, b), SqMahalanobis_pullback
# end
26 changes: 26 additions & 0 deletions src/distances/pairwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,29 @@ function pairwise!(
)
return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
end

# Also defines the colwise method for abstractvectors

function colwise(d::PreMetric, x::ColVecs)
return Distances.colwise(d, x.X, x.X)
end

function colwise(d::PreMetric, x::RowVecs)
return Distances.colwise(d, x.X', x.X')
end

function colwise(d::PreMetric, x::AbstractVector)
return map(d, x, x)
end

function colwise(d::PreMetric, x::ColVecs, y::ColVecs)
return Distances.colwise(d, x.X, y.X)
end

function colwise(d::PreMetric, x::RowVecs, y::RowVecs)
return Distances.colwise(d, x.X', y.X')
end

function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector)
return map(d, x, y)
end
theogf marked this conversation as resolved.
Show resolved Hide resolved
theogf marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 10 additions & 0 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::Abstrac
return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x))
end

function kernelmatrix_diag!(
K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
end

function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix!(K, kernel(κ), _map(κ.transform, x))
end
Expand All @@ -94,6 +100,10 @@ function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector)
return kernelmatrix_diag(κ.kernel, _map(κ.transform, x))
end

function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix_diag(κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector)
return kernelmatrix(kernel(κ), _map(κ.transform, x))
end
Expand Down
16 changes: 12 additions & 4 deletions src/matrix/kernelmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,32 @@ kernelmatrix_diag(κ::Kernel, x::AbstractVector, y::AbstractVector) = map(κ, x,
function kernelmatrix!(K::AbstractMatrix, κ::SimpleKernel, x::AbstractVector)
validate_inplace_dims(K, x)
pairwise!(K, metric(κ), x)
return map!(d -> kappa(κ, d), K, K)
return map!(Base.Fix1(kappa, κ), K, K)
end

function kernelmatrix!(
K::AbstractMatrix, κ::SimpleKernel, x::AbstractVector, y::AbstractVector
)
validate_inplace_dims(K, x, y)
pairwise!(K, metric(κ), x, y)
return map!(d -> kappa(κ, d), K, K)
return map!(Base.Fix1(kappa, κ), K, K)
end

function kernelmatrix(κ::SimpleKernel, x::AbstractVector)
return map(d -> kappa(κ, d), pairwise(metric(κ), x))
return map(Base.Fix1(kappa, κ), pairwise(metric(κ), x))
end

function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector)
validate_inputs(x, y)
return map(d -> kappa(κ, d), pairwise(metric(κ), x, y))
return map(Base.Fix1(kappa, κ), pairwise(metric(κ), x, y))
end

function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector)
return map(Base.Fix1(kappa, κ), colwise(metric(κ), x))
end

function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector, y::AbstractVector)
return map(Base.Fix1(kappa, κ), colwise(metric(κ), x, y))
end

#
Expand Down
98 changes: 0 additions & 98 deletions src/zygote_adjoints.jl

This file was deleted.

Loading