-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementing the Greenkhorn #159
base: master
Are you sure you want to change the base?
Changes from all commits
0db7257
21514cf
436fb16
dac80ba
47c3ec8
163fb9c
82b1026
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,114 @@ | ||||||||||
# Greenkhorn is a greedy version of the Sinkhorn algorithm | ||||||||||
# This method is from https://arxiv.org/pdf/1705.09634.pdf | ||||||||||
# Code is based on implementation from package POT | ||||||||||
|
||||||||||
|
||||||||||
""" | ||||||||||
Greenkhorn() | ||||||||||
|
||||||||||
Greenkhorn is a greedy version of the Sinkhorn algorithm. | ||||||||||
""" | ||||||||||
struct Greenkhorn <: Sinkhorn end | ||||||||||
|
||||||||||
struct GreenkhornCache{U,V,KT} | ||||||||||
u::U | ||||||||||
v::V | ||||||||||
K::KT | ||||||||||
Kv::U #placeholder | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why "placeholder"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a partial solution. Without this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean you use it explicitly below for checking convergence? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, it's been used to check the convergence. But I think the convergence check might be done another way more efficiently. But I was having trouble getting it to work with the existing "api" for Sinkhorn, so I updated You are right that, as is, |
||||||||||
G::KT | ||||||||||
du::U | ||||||||||
dv::V | ||||||||||
end | ||||||||||
|
||||||||||
Base.show(io::IO, ::Greenkhorn) = print(io, "Greenkhorn algorithm") | ||||||||||
|
||||||||||
function build_cache( | ||||||||||
::Type{T}, | ||||||||||
::Greenkhorn, | ||||||||||
size2::Tuple, | ||||||||||
μ::AbstractVecOrMat, | ||||||||||
ν::AbstractVecOrMat, | ||||||||||
C::AbstractMatrix, | ||||||||||
ε::Real, | ||||||||||
) where {T} | ||||||||||
# compute Gibbs kernel (has to be mutable for ε-scaling algorithm) | ||||||||||
K = similar(C, T) | ||||||||||
@. K = exp(-C / ε) | ||||||||||
|
||||||||||
# create and initialize dual potentials | ||||||||||
u = similar(μ, T, size(μ, 1), size2...) | ||||||||||
v = similar(ν, T, size(ν, 1), size2...) | ||||||||||
fill!(u, one(T)/size(μ, 1)) | ||||||||||
fill!(v, one(T)/size(ν, 1)) | ||||||||||
|
||||||||||
G = sinkhorn_plan(u, v, K) | ||||||||||
# G = diagm(u) * K * diagm(v) | ||||||||||
|
||||||||||
Kv = similar(u) | ||||||||||
|
||||||||||
# This is me triying to get the `batch tests to work` | ||||||||||
# improve this! | ||||||||||
# if (length(size(μ)) == 2 && length(size(ν)) == 1) | ||||||||||
# du = reshape(sum(G, dims=2), size(μ)) - μ | ||||||||||
# dv = reshape(sum(G, dims=1),size(v)) - repeat(ν,1,size(v)[2]) | ||||||||||
# elseif (length(size(μ)) == 1 && length(size(ν)) == 2) | ||||||||||
# du = reshape(sum(G, dims=2),size(u)) - repeat(μ,1,size(u)[2]) | ||||||||||
# dv = reshape(sum(G, dims=1), size(ν)) - ν | ||||||||||
# else | ||||||||||
du = reshape(sum(G, dims=2), size(μ)) - μ | ||||||||||
dv = reshape(sum(G, dims=1), size(ν)) - ν | ||||||||||
# end | ||||||||||
|
||||||||||
|
||||||||||
return GreenkhornCache(u, v, K, Kv, G, du, dv) | ||||||||||
end | ||||||||||
|
||||||||||
prestep!(::SinkhornSolver{Greenkhorn}, ::Int) = nothing | ||||||||||
|
||||||||||
init_step!(solver::SinkhornSolver{<:Greenkhorn}) = nothing | ||||||||||
|
||||||||||
function step!(solver::SinkhornSolver{<:Greenkhorn}, iter::Int) | ||||||||||
μ = solver.source | ||||||||||
ν = solver.target | ||||||||||
cache = solver.cache | ||||||||||
u = cache.u | ||||||||||
v = cache.v | ||||||||||
K = cache.K | ||||||||||
G = cache.G | ||||||||||
Δμ= cache.du | ||||||||||
Δν= cache.dv | ||||||||||
|
||||||||||
# The implementation in POT does not compute `μ .* log.(μ ./ sum(G', dims=1)[:])` | ||||||||||
# or `ν .* log.(ν ./ sum(G', dims=2)[:])`. Yet, this term is present in the original | ||||||||||
# paper, where it uses ρ(a,b) = b - a + a log(a/b). | ||||||||||
# ρμ = abs.(Δμ + μ .* log.(μ ./ sum(G', dims=1)[:])) | ||||||||||
# ρν = abs.(Δν + ν .* log.(ν ./ sum(G', dims=2)[:])) | ||||||||||
|
||||||||||
i₁ = argmax(abs.(Δμ)) | ||||||||||
i₂ = argmax(abs.(Δν)) | ||||||||||
Comment on lines
+87
to
+88
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
# if ρμ[i₁]> ρν[i₂] | ||||||||||
if abs(Δμ[i₁]) > abs(Δν[i₂]) | ||||||||||
Comment on lines
+90
to
+91
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
old_u = u[i₁] | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also has to be changed for batch support it seems.
Suggested change
|
||||||||||
u[i₁] = μ[i₁]/ (K[i₁,:] ⋅ v) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
It would be better to select columns instead of rows and to use views. Julia uses column major order, so a column is close in memory and hence accessing columns is faster. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True! I'll try to come up with something. |
||||||||||
Δ = u[i₁] - old_u | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
G[i₁, :] = u[i₁] * K[i₁,:] .* v | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, better to work with columns than with rows. Also some unnecessary allocations here:
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the unnecessary allocation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh there are multiple unnecessary allocations. First of all, Whereas the alternative suggestion allocates only There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the answer, and sorry for the bad code. I'm still very crude with code optimization. |
||||||||||
Δμ[i₁] = u[i₁] * (K[i₁,:] ⋅ v) - μ[i₁] | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||||||||||
@. Δν = Δν + Δ * K[i₁,:] * v | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And here. |
||||||||||
else | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comments in the second branch. |
||||||||||
old_v = v[i₂] | ||||||||||
v[i₂] = ν[i₂]/ (K[:,i₂] ⋅ u) | ||||||||||
Δ = v[i₂] - old_v | ||||||||||
G[:, i₂] = v[i₂] * K[:,i₂] .* u | ||||||||||
Δν[i₂] = v[i₂] * (K[:,i₂] ⋅ u) - ν[i₂] | ||||||||||
@. Δμ = Δμ + Δ * K[:,i₂] * u | ||||||||||
end | ||||||||||
|
||||||||||
A_batched_mul_B!(solver.cache.Kv, K, v) # Compute to evaluate convergence | ||||||||||
end | ||||||||||
|
||||||||||
function sinkhorn_plan(solver::SinkhornSolver{Greenkhorn}) | ||||||||||
cache = solver.cache | ||||||||||
return cache.G | ||||||||||
end | ||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,238 @@ | ||
using OptimalTransport | ||
|
||
using Distances | ||
using ForwardDiff | ||
using ReverseDiff | ||
using LogExpFunctions | ||
using PythonOT: PythonOT | ||
|
||
using LinearAlgebra | ||
using Random | ||
using Test | ||
|
||
const POT = PythonOT | ||
|
||
Random.seed!(100) | ||
|
||
@testset "greenkhorn.jl" begin | ||
# size of source and target | ||
M = 250 | ||
N = 200 | ||
|
||
# create two random histograms | ||
μ = normalize!(rand(M), 1) | ||
ν = normalize!(rand(N), 1) | ||
|
||
# create random cost matrix | ||
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) | ||
|
||
# regularization parameter | ||
ε = 0.01 | ||
|
||
@testset "example" begin | ||
# compute optimal transport plan and optimal transport cost | ||
γ = sinkhorn(μ, ν, C, ε, Greenkhorn(); maxiter=200_000, rtol=1e-9) | ||
c = sinkhorn2(μ, ν, C, ε, Greenkhorn(); maxiter=200_000, rtol=1e-9) | ||
|
||
# check that plan and cost are consistent | ||
@test c ≈ dot(γ, C) | ||
|
||
# compare with POT | ||
γ_pot = POT.sinkhorn(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9) | ||
c_pot = POT.sinkhorn2(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)[1] | ||
@test γ_pot ≈ γ rtol = 1e-6 | ||
@test c_pot ≈ c rtol = 1e-7 | ||
|
||
# compute optimal transport cost with regularization term | ||
c_w_regularization = sinkhorn2( | ||
μ, ν, C, ε, Greenkhorn(); maxiter=200_000, regularization=true | ||
) | ||
@test c_w_regularization ≈ c + ε * sum(x -> iszero(x) ? x : x * log(x), γ) | ||
@test c_w_regularization ≈ | ||
sinkhorn2(μ, ν, C, ε; maxiter=5_000, regularization=true) | ||
|
||
# # ensure that provided plan is used and correct | ||
c2 = sinkhorn2(similar(μ), similar(ν), C, rand(), Greenkhorn(); plan=γ) | ||
@test c2 ≈ c | ||
@test c2 == sinkhorn2(similar(μ), similar(ν), C, rand(); plan=γ) | ||
c2_w_regularization = sinkhorn2( | ||
similar(μ), similar(ν), C, ε, Greenkhorn(); plan=γ, regularization=true | ||
) | ||
@test c2_w_regularization ≈ c_w_regularization | ||
@test c2_w_regularization == | ||
sinkhorn2(similar(μ), similar(ν), C, ε; plan=γ, regularization=true) | ||
|
||
|
||
################################################################ | ||
# FIX BATCHES CASE!!! Not working for Greenkhorn implementation# | ||
################################################################ | ||
|
||
# # batches of histograms | ||
# d = 10 | ||
# for (size2_μ, size2_ν) in | ||
# (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) | ||
# # generate batches of histograms | ||
# μ_batch = repeat(μ, 1, size2_μ...) | ||
# ν_batch = repeat(ν, 1, size2_ν...) | ||
|
||
# # compute optimal transport plan and check that it is consistent with the | ||
# # plan for individual histograms | ||
# γ_all = sinkhorn( | ||
# μ_batch, ν_batch, C, ε, Greenkhorn(); maxiter=5_000, rtol=1e-9 | ||
# ) | ||
# @test size(γ_all) == (M, N, d) | ||
# @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) | ||
# @test γ_all == sinkhorn(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9) | ||
|
||
# # compute optimal transport cost and check that it is consistent with the | ||
# # cost for individual histograms | ||
# c_all = sinkhorn2( | ||
# μ_batch, ν_batch, C, ε, Greenkhorn(); maxiter=5_000, rtol=1e-9 | ||
# ) | ||
# @test size(c_all) == (d,) | ||
# @test all(x ≈ c for x in c_all) | ||
# @test c_all == sinkhorn2(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9) | ||
# end | ||
end | ||
|
||
# different element type | ||
@testset "Float32" begin | ||
# create histograms and cost matrix with element type `Float32` | ||
μ32 = map(Float32, μ) | ||
ν32 = map(Float32, ν) | ||
C32 = map(Float32, C) | ||
ε32 = Float32(ε) | ||
|
||
# compute optimal transport plan and optimal transport cost | ||
γ = sinkhorn(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=200_000, rtol=1e-6) | ||
c = sinkhorn2(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=200_000, rtol=1e-6) | ||
@test eltype(γ) === Float32 | ||
@test typeof(c) === Float32 | ||
|
||
# check that plan and cost are consistent | ||
@test c ≈ dot(γ, C32) | ||
|
||
# compare with default algorithm | ||
γ_default = sinkhorn(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) | ||
c_default = sinkhorn2(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) | ||
@test γ_default ≈ γ rtol=1e-4 | ||
@test c_default ≈ c rtol=1e-4 | ||
|
||
# compare with POT | ||
γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6) | ||
c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)[1] | ||
@test map(Float32, γ_pot) ≈ γ rtol = 1e-3 | ||
@test Float32(c_pot) ≈ c rtol = 1e-3 | ||
|
||
################################################################ | ||
# FIX BATCHES CASE!!! Not working for Greenkhorn implementation# | ||
################################################################ | ||
|
||
# batches of histograms | ||
# d = 10 | ||
# for (size2_μ, size2_ν) in | ||
# (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) | ||
# # generate batches of histograms | ||
# μ32_batch = repeat(μ32, 1, size2_μ...) | ||
# ν32_batch = repeat(ν32, 1, size2_ν...) | ||
|
||
# # compute optimal transport plan and check that it is consistent with the | ||
# # plan for individual histograms | ||
# γ_all = sinkhorn( | ||
# μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6 | ||
# ) | ||
# @test size(γ_all) == (M, N, d) | ||
# @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) | ||
# @test γ_all == | ||
# sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) | ||
|
||
# # compute optimal transport cost and check that it is consistent with the | ||
# # cost for individual histograms | ||
# c_all = sinkhorn2( | ||
# μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6 | ||
# ) | ||
# @test size(c_all) == (d,) | ||
# @test all(x ≈ c for x in c_all) | ||
# @test c_all == | ||
# sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) | ||
# end | ||
end | ||
|
||
|
||
################################################################ | ||
# FIX AD !!! Not working for Greenkhorn implementation # | ||
################################################################ | ||
|
||
# https://github.com/JuliaOptimalTransport/OptimalTransport.jl/issues/86 | ||
# @testset "AD" begin | ||
# # compute gradients with respect to source and target marginals separately and | ||
# # together. test against gradient computed using analytic formula of Proposition 2.3 of | ||
# # Cuturi, Marco, and Gabriel Peyré. "A smoothed dual approach for variational Wasserstein problems." SIAM Journal on Imaging Sciences 9.1 (2016): 320-343. | ||
# # | ||
# ε = 0.05 # use a larger ε to avoid having to do many iterations | ||
# # target marginal | ||
# for Diff in [ReverseDiff, ForwardDiff] | ||
# ∇ = Diff.gradient(log.(ν)) do xs | ||
# sinkhorn2(μ, softmax(xs), C, ε, Greenkhorn(); regularization=true) | ||
# end | ||
# ∇default = Diff.gradient(log.(ν)) do xs | ||
# sinkhorn2(μ, softmax(xs), C, ε; regularization=true) | ||
# end | ||
# @test ∇ == ∇default | ||
|
||
# solver = OptimalTransport.build_solver(μ, ν, C, ε, Greenkhorn()) | ||
# OptimalTransport.solve!(solver) | ||
# # helper function | ||
# function dualvar_to_grad(x, ε) | ||
# x = -ε * log.(x) | ||
# x .-= sum(x) / size(x, 1) | ||
# return -x | ||
# end | ||
# ∇_ot = dualvar_to_grad(solver.cache.v, ε) | ||
# # chain rule because target measure parameterised by softmax | ||
# J_softmax = ForwardDiff.jacobian(log.(ν)) do xs | ||
# softmax(xs) | ||
# end | ||
# ∇analytic_target = J_softmax * ∇_ot | ||
# # check that gradient obtained by AD matches the analytic formula | ||
# @test ∇ ≈ ∇analytic_target rtol = 1e-6 | ||
|
||
# # source marginal | ||
# ∇ = Diff.gradient(log.(μ)) do xs | ||
# sinkhorn2(softmax(xs), ν, C, ε, Greenkhorn(); regularization=true) | ||
# end | ||
# ∇default = Diff.gradient(log.(μ)) do xs | ||
# sinkhorn2(softmax(xs), ν, C, ε; regularization=true) | ||
# end | ||
# @test ∇ == ∇default | ||
|
||
# # check that gradient obtained by AD matches the analytic formula | ||
# solver = OptimalTransport.build_solver(μ, ν, C, ε, Greenkhorn()) | ||
# OptimalTransport.solve!(solver) | ||
# J_softmax = ForwardDiff.jacobian(log.(μ)) do xs | ||
# softmax(xs) | ||
# end | ||
# ∇_ot = dualvar_to_grad(solver.cache.u, ε) | ||
# ∇analytic_source = J_softmax * ∇_ot | ||
# @test ∇ ≈ ∇analytic_source rtol = 1e-6 | ||
|
||
# # both marginals | ||
# ∇ = Diff.gradient(log.(vcat(μ, ν))) do xs | ||
# sinkhorn2( | ||
# softmax(xs[1:M]), | ||
# softmax(xs[(M + 1):end]), | ||
# C, | ||
# ε, | ||
# Greenkhorn(); | ||
# regularization=true, | ||
# ) | ||
# end | ||
# ∇default = Diff.gradient(log.(vcat(μ, ν))) do xs | ||
# sinkhorn2(softmax(xs[1:M]), softmax(xs[(M + 1):end]), C, ε; regularization=true) | ||
# end | ||
# @test ∇ == ∇default | ||
# ∇analytic = vcat(∇analytic_source, ∇analytic_target) | ||
# @test ∇ ≈ ∇analytic rtol = 1e-6 | ||
# end | ||
# end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be helpful to describe what the differences are (if there are any apart from implementation details).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, there are. The paper implementation is actually just a couple of lines commented out. I'll point out in the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just realized that it's already in the code. Inside the
step!
function.