
[WIP] Fix AD issues with various kernels #154

Merged: 24 commits into master from sharan/fix-AD-issues on Sep 8, 2020

Conversation


@sharanry (Contributor, Author) commented:

I think the reason Zygote fails with MahalanobisKernel is the array mutation at
https://github.com/JuliaStats/Distances.jl/blob/44036b573ec85287022f4368c4e1e279698bd031/src/mahalanobis.jl#L75.

@devmotion Do you suggest I override their pairwise implementation for now or open an issue/PR to Distances.jl?

@devmotion (Member) commented:

I guess that should/could be resolved by adding a custom ChainRules-based adjoint for pairwise(::Mahalanobis, ...), similar to https://github.com/FluxML/Zygote.jl/blob/956575ee2c732dee25324b59ba43fbb471a52d9a/src/lib/distances.jl#L19-L25. It seems it's not yet decided if Distances would accept ChainRules PRs (see JuliaStats/Distances.jl#172), so one could either make a PR to Zygote or add some piracy to KernelFunctions for now, I guess.
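For illustration, such a pirated adjoint for the single-pair case could look roughly like the sketch below. This is an unverified illustration rather than the PR's implementation, and hooking Distances.evaluate (rather than some other entry point) is an assumption:

using Distances, LinearAlgebra, Zygote

# Sketch of an adjoint for the squared Mahalanobis distance
# d(x, y) = (x - y)' * Q * (x - y); should be checked against FiniteDifferences.
Zygote.@adjoint function Distances.evaluate(d::SqMahalanobis, x::AbstractVector, y::AbstractVector)
    Q = d.qmat
    δ = x - y
    function sqmahalanobis_pullback(Δ)
        g = Δ * ((Q + Q') * δ)                  # cotangent for x; the one for y is -g
        return ((qmat = Δ * (δ * δ'),), g, -g)  # cotangent for d is a NamedTuple over its fields
    end
    return dot(δ, Q * δ), sqmahalanobis_pullback
end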

@sharanry (Contributor, Author) commented Aug 19, 2020:

Sorry for the delay. I am having a hard time defining adjoints for the Mahalanobis kernel that aren't very computationally expensive.

@devmotion (Member) commented Aug 19, 2020:

I guess you shouldn't need anything in particular for the Mahalanobis kernel, but rather (just) a custom adjoint for the distance computations in Distances? In this case the Matrix Cookbook, and in particular equations 72 and 81, are helpful. They show that

d((x-y)'*Q*(x-y))/dx = (Q + Q') * (x - y)
d((x-y)'*Q*(x-y))/dy = -(Q + Q') * (x - y)
d((x-y)'*Q*(x-y))/dQ = (x - y) * (x - y)'

(you should recheck that I didn't make any stupid mistakes 😄). These expressions aren't too bad, I guess, but in general the Mahalanobis distance isn't computationally cheap, so it's not too surprising that the adjoints aren't either.
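These identities are easy to sanity-check numerically, for example (illustrative snippet, independent of the PR):

using FiniteDifferences, LinearAlgebra

# Numerical check of the three identities above.
x, y, Q = randn(3), randn(3), randn(3, 3)
f(x, y, Q) = dot(x - y, Q * (x - y))
fdm = FiniteDifferences.Central(5, 1)

FiniteDifferences.grad(fdm, x -> f(x, y, Q), x)[1] ≈ (Q + Q') * (x - y)   # d/dx
FiniteDifferences.grad(fdm, y -> f(x, y, Q), y)[1] ≈ -(Q + Q') * (x - y)  # d/dy
FiniteDifferences.grad(fdm, Q -> f(x, y, Q), Q)[1] ≈ (x - y) * (x - y)'   # d/dQ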

BTW I just noticed that the docstring of the kernel is incorrect, since the distance computation (metric(κ::MahalanobisKernel) = SqMahalanobis(κ.P)) uses P but not the inverse of P (Distances doesn't use the inverse either, according to their README).

@sharanry (Contributor, Author) commented Aug 20, 2020:

@devmotion Thanks for pointing out the typo in the docstring!
The problem I am facing is with the pairwise computation; defining an efficient adjoint seems quite tricky given the implementation here: https://github.com/JuliaStats/Distances.jl/blob/44036b573ec85287022f4368c4e1e279698bd031/src/mahalanobis.jl#L62

Comment on lines 108 to 116 of src/zygote_adjoints.jl:
a_b = map(
x -> (first(last(x)) - last(last(x)))*first(x),
zip(
Δ,
Iterators.product(eachslice(a, dims=dims), eachslice(b, dims=dims))
)
)
δa = reduce(hcat, sum(map(x -> B_B_t*x, a_b), dims=1))
δB = sum(map(x -> x*transpose(x), a_b))
Member:

I would assume it should be possible to vectorize this code? What's the mathematical formula that you use here?

@sharanry (Contributor, Author) commented Aug 21, 2020:

They are the same equations you mentioned earlier:

d((x-y)'*Q*(x-y))/dx = (Q + Q') * (x - y)
d((x-y)'*Q*(x-y))/dy = -(Q + Q') * (x - y)
d((x-y)'*Q*(x-y))/dQ = (x - y) * (x - y)'

But this is being done for all pairwise combinations together using map; it later sums these differences to get δB and the others.
Please note that the current implementation is not correct; I am still debugging it (it only partially matches the intended result). If you happen to find any obvious mistakes, please let me know. I am having trouble reducing the results of the individual pairwise pullbacks to the final pullback; the way I am summing them is probably wrong.
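For reference, a fully vectorized form of the pairwise pullback (for dims=2, with Δ the incoming cotangent of size (size(a, 2), size(b, 2))) would look something like the sketch below. The function name is hypothetical and the formula is an unverified illustration that should be checked against FiniteDifferences; it is not the code in this PR:

using LinearAlgebra

# Sketch: pullback of D = pairwise(SqMahalanobis(Q), a, b; dims=2),
# where D[i, j] = (a_i - b_j)' * Q * (a_i - b_j).
function sqmaha_pairwise_pullback(Δ, Q, a, b)
    S  = Q + Q'
    δa = S * (a .* sum(Δ; dims=2)' .- b * Δ')  # column i: Σ_j Δ[i, j] * S * (a_i - b_j)
    δb = S * (b .* sum(Δ; dims=1) .- a * Δ)    # column j: -Σ_i Δ[i, j] * S * (a_i - b_j)
    δQ = a * Diagonal(vec(sum(Δ; dims=2))) * a' .- a * Δ * b' .-
         b * Δ' * a' .+ b * Diagonal(vec(sum(Δ; dims=1))) * b'
    return (qmat = δQ,), δa, δb
end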

@sharanry (Contributor, Author) commented Aug 21, 2020:

julia> using Distances, Random;

julia> rng = MersenneTwister(123);

julia> M1, M2 = rand(rng, 2,3), rand(rng, 2,3);

julia> dist = SqMahalanobis(rand(rng, 2,2))
SqMahalanobis{Float64}([0.8654121434083455 0.2856979003853177; 0.617491887982287 0.46384720826189474])

julia> pairwise(dist, M1, M2; dims=2)
3×3 Array{Float64,2}:
  0.371673   0.856348  0.742803
  0.0233992  0.274278  0.276694
 -0.036568   0.118487  0.0748149

julia> map(x -> evaluate(dist, first(x), last(x)), Iterators.product(eachslice(M1, dims=2), eachslice(M2, dims=2)))
3×3 Array{Float64,2}:
 0.541253   0.912421  0.673273
 0.0886328  0.285181  0.192394
 0.0868399  0.166227  0.0616321

@devmotion isn't this wrong, or have I done something silly? They are equal in the Euclidean case. I feel this is the root of the problem.

Member:

It should work if dist.qmat is positive definite: JuliaStats/Distances.jl#174
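For example (illustrative, reusing rng, M1, and M2 from the snippet above), with a symmetric positive definite qmat the two computations should agree up to floating point:

using LinearAlgebra

A = rand(rng, 2, 2)
dist_pd = SqMahalanobis(A' * A + I)   # symmetric positive definite qmat

pairwise(dist_pd, M1, M2; dims=2) ≈
    map(x -> evaluate(dist_pd, first(x), last(x)),
        Iterators.product(eachslice(M1, dims=2), eachslice(M2, dims=2)))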

@sharanry (Contributor, Author) commented Aug 22, 2020:

This still does not resolve the discrepancy in the computed adjoints for the covariance matrix Q; my current implementation matches the second one (#B).

julia> using Distances, LinearAlgebra, FiniteDifferences, Random

julia> FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) = vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))

julia> rng = MersenneTwister(123);

julia> M1, M2 = rand(rng,3,1), rand(rng,3,1)
([0.7684476751965699; 0.940515000715187; 0.6739586945680673], [0.3954531123351086; 0.3132439558075186; 0.6625548164736534])

julia> Q = Matrix(Cholesky(rand(rng, 3, 3), 'U', 0))
3×3 Array{Float64,2}:
 0.343422   0.0638007  0.507151
 0.0638007  0.0386393  0.19528
 0.507151   0.19528    1.21186

julia> isposdef(Q)
true

julia> dist = SqMahalanobis(Q);

julia> fdm=FiniteDifferences.Central(5, 1);

julia> j′vp(fdm, pairwise, ones(1,1), dist, M1, M2)[1].qmat #A
3×3 Array{Float64,2}:
 0.139125  0.365187  -0.238366
 0.102751  0.393469  -0.404876
 0.246873  0.419183   0.000130048

julia> j′vp(fdm, evaluate, 1, dist, M1[:, 1], M2[:, 1])[1].qmat #B
3×3 Array{Float64,2}:
 0.139125    0.233969    0.00425358
 0.233969    0.393469    0.00715332
 0.00425358  0.00715332  0.000130048

Member:

> IMO it is best if (Sq)Mahalanobis distance is actually parameterized by the decomposition of Q, i.e., the upper or lower triangular matrix, which is not constrained.

Yes, that would be the most natural way to ensure that it is always positive semi-definite (if the diagonal is non-negative) and that optimization is performed in the correct space. So I guess users would want to use this parameterization, even if it is not enforced by KernelFunctions and not directly supported by SqMahalanobis, by using something like

function mykernel(L)
    idxs = diagind(L)
    @inbounds for i in idxs
        L[i] = softplus(L[i])
    end
    return MahalanobisKernel(Array(L * L'))
end

Of course, it would be nice if (Sq)Mahalanobis supported specifying e.g. a Cholesky decomposition or PDMat directly (it could even be used to simplify the computations, since x'*Q*x = (L'*x)'*(L'*x) in this case), but can't we work around this by checking gradients of the mykernel setup instead of computing Q -> MahalanobisKernel(Q) directly? That's at least how we do it in DistributionsAD, e.g. in https://github.com/TuringLang/DistributionsAD.jl/blob/a96b159ab25aab67d1a2076726e8b9c392eb6fc7/test/ad/distributions.jl#L18-L34.

Contributor (Author):

> but can't we work around this by checking gradients of the mykernel setup instead of computing Q -> MahalanobisKernel(Q) directly?

Yeah, that should work. Will try that out.

Regarding the issue with the pairwise implementation, which messes up the FiniteDifferences results: do you suggest I override the implementation for the time being?

@devmotion (Member) commented Aug 24, 2020:

If you test the suggested parameterization, the implementation of pairwise shouldn't matter (since we do not test the intermediate step that might be affected by it).

Contributor (Author):

True. Could we also change the parameterization on our side, i.e., the way it is stored in the struct? We could continue to allow initialization with a full matrix. This should allow seamless AD regardless of how the user decides to initialize it.

Member:

I'm not sure we want to do that; I think this deserves some discussion first (and then possibly a separate PR). Ideally, Distances would just support arbitrary matrices and contain optimized implementations for specific array types. We just forward P to SqMahalanobis, so ideally we wouldn't perform any transformations or computations. I'm also a bit worried that focusing on a specific parameterization might make it difficult for users who would like to use a different one (but still not a dense matrix) or might lead to confusing behaviour.

Comment on lines 23 to 26
@test_broken j′vp(fdm, x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2]) ≈
Zygote.pullback(x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)
@test all(j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1] .≈
Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1])
Contributor (Author):

@devmotion I tried doing what you suggested. The tests still fail. This error probably propagates and causes even the first test to fail.

julia> j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1]
3×3 UpperTriangular{Float64,Array{Float64,2}}:
 0.228808   0.00318764   -0.107503
  ⋅        -0.000391803   0.0132135
  ⋅          ⋅            0.0438772

julia> Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1]
3×3 Array{Float64,2}:
  0.228808    0.00318764   -0.107503
 -0.0281234  -0.000391803   0.0132135
 -0.0933875  -0.00130103    0.0438772

Member:

To me your output indicates that it basically works apart from the fact that Zygote incorrectly returns a dense matrix instead of an upper triangular matrix. Since U was upper triangular, only the values above and on the diagonal should be returned.

@sharanry (Contributor, Author) commented Aug 26, 2020:

FiniteDifferences is pretty good at matching the types; Zygote isn't. Do you suggest we manually check that the upper triangular parts match for now?

Edit: I don't think we are addressing the major issue here. Our goal is to make the overall adjoint correct for kernelmatrix, so maybe defining a custom Zygote adjoint for UpperTriangular that outputs an UpperTriangular would solve the problem.
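Something like the following, perhaps (hypothetical sketch; Zygote may already ship a rule for this constructor, in which case it is redundant):

using LinearAlgebra, Zygote

# Make the cotangent returned for UpperTriangular(A) itself UpperTriangular,
# so that it matches what FiniteDifferences produces for triangular inputs.
Zygote.@adjoint LinearAlgebra.UpperTriangular(A::AbstractMatrix) =
    UpperTriangular(A), Δ -> (UpperTriangular(Δ),)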

Member:

If the call to UpperTriangular were inside the function, then the adjoint that you get from Zygote would also be UpperTriangular. Maybe just do that?

fdm = FiniteDifferences.Central(5, 1);


FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) = vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))
Member:

Is this needed? If possible, we should avoid this type piracy.

Contributor (Author):

Yes, j′vp only works when there is a to_vec function defined for each argument.

Member:

I'm wondering because, according to the docs, to_vec is only needed for the inputs xs... but not for the evaluated function f in j′vp(fdm, f, xs...).

@sharanry (Contributor, Author) commented Aug 26, 2020:

From what I understand, it is also needed for objects like SqMahalanobis if they have parameters like qmat.

Member:

That's correct, but for some reason we've not yet made FiniteDifferences handle functions-with-data properly, so you'll have to build the SqMahalanobis object inside the function that you're differentiating.
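That is, instead of the to_vec overload, one can differentiate a closure that builds the distance from a plain matrix, e.g. (illustrative, reusing fdm, Q, M1, and M2 from the earlier snippet):

# No to_vec piracy needed: SqMahalanobis is constructed inside the differentiated
# closure, so j′vp only ever sees plain arrays as inputs.
j′vp(fdm, Q -> pairwise(SqMahalanobis(Q), M1, M2; dims=2), ones(size(M1, 2), size(M2, 2)), Q)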

@theogf (Member) left a comment:

Somehow the solution of having to define kernelmatrix again for the NeuralNetworkKernel seems very hacky; isn't there another solution?

@willtebbutt (Member) left a comment:

This is looking good. Just some style things.

@willtebbutt mentioned this pull request on Aug 30, 2020.
)
δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_b), dims=2))
δB = sum(map(x -> x*transpose(x), a_b))
return (qmat=δB,), δa, -δa
Member:

There is some discrepancy between the simple case above and this pullback: intuitively, from the simple case above I would assume that δB = sum_{i, j} (a_i - b_j) * (a_i - b_j)^T * Δ_{i,j}. However, here you compute δB = sum_{i, j} (a_i - b_j) * (a_i - b_j)^T * Δ_{i,j}^2. Probably one of them is incorrect (table 7 in https://notendur.hi.is/jonasson/greinar/blas-rmd.pdf indicates that the pairwise one is incorrect). Can we add the derivation of the adjoints according to https://www.juliadiff.org/ChainRulesCore.jl/dev/arrays.html as docstrings or comments, or maybe even have a separate PR for the Mahalanobis fixes?

Contributor (Author):

Thanks for pointing this out. I think a separate PR for the Mahalanobis fixes makes more sense.

@devmotion (Member) commented:

> Somehow the solution of having to define kernelmatrix again for the NeuralNetworkKernel seems very hacky; isn't there another solution?

I guess one could define a "PreMetric" that evaluates dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))) (or asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))), similar to DotProduct, and make NeuralNetworkKernel a SimpleKernel. But even in this case one might want to implement (a) specialized version(s) of pairwise, so I'm not sure how much one would gain.
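A rough sketch of such a pre-metric (the name NNDot and the details are hypothetical, not an existing API):

using Distances, LinearAlgebra

# Hypothetical pre-metric computing dot(x, y) / sqrt((1 + ||x||^2) * (1 + ||y||^2)).
struct NNDot <: Distances.PreMetric end

Distances.evaluate(::NNDot, x, y) = dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y)))
(d::NNDot)(x, y) = Distances.evaluate(d, x, y)

NeuralNetworkKernel could then, in principle, be a SimpleKernel whose metric is NNDot() and whose kappa applies asin, but a specialized pairwise for NNDot would probably still be needed for performance, as noted above.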

@sharanry (Contributor, Author) commented Sep 7, 2020:

Can we merge this and tackle each of the remaining AD issues in separate PRs? It is getting increasingly tricky to address multiple issues at once.

Currently this PR does the following:

  • Defines a kernelmatrix method for NeuralNetworkKernel.
  • Defines Zygote adjoints for the Mahalanobis distance metric.
  • Zygote tests pass for the Exponential, FBM, NN, and Gabor kernels.

@devmotion (Member) commented Sep 7, 2020:

IMO this PR already contains too many changes; we should just focus on one AD problem at a time.

> Defines Zygote adjoints for the Mahalanobis distance metric.

I thought the idea was not to include these adjoints, since they were missing a clean derivation/documentation and were incorrect? Or are you talking about the non-pairwise adjoints only?

@sharanry (Contributor, Author) commented Sep 7, 2020:

> I thought the idea was not to include these adjoints, since they were missing a clean derivation/documentation and were incorrect? Or are you talking about the non-pairwise adjoints only?

I meant only the non-pairwise adjoint. I will be removing the pairwise adjoints for now.

@sharanry (Contributor, Author) commented Sep 7, 2020:

Any objections to merging this?

@willtebbutt (Member) left a comment:

I have no objections other than these tiny style-related things. This is a great PR.

@sharanry merged commit 5c24f1c into master on Sep 8, 2020.
@sharanry deleted the sharan/fix-AD-issues branch on Sep 8, 2020 at 09:13.
4 participants