Implementing the Greenkhorn #159

Open · wants to merge 7 commits into base: master
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -40,6 +40,7 @@ Currently the following variants of the Sinkhorn algorithm are supported:
SinkhornGibbs
SinkhornStabilized
SinkhornEpsilonScaling
Greenkhorn
```

The following methods are deprecated and will be removed:
2 changes: 2 additions & 0 deletions src/OptimalTransport.jl
@@ -17,6 +17,7 @@ using NNlib: NNlib
export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
export SinkhornBarycenterGibbs
export QuadraticOTNewton
export Greenkhorn

export sinkhorn, sinkhorn2
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
@@ -36,6 +37,7 @@ include("entropic/sinkhorn_unbalanced.jl")
include("entropic/sinkhorn_barycenter.jl")
include("entropic/sinkhorn_barycenter_gibbs.jl")
include("entropic/sinkhorn_solve.jl")
include("entropic/greenkhorn.jl")

include("quadratic.jl")
include("quadratic_newton.jl")
114 changes: 114 additions & 0 deletions src/entropic/greenkhorn.jl
@@ -0,0 +1,114 @@
# Greenkhorn is a greedy variant of the Sinkhorn algorithm.
# The method is from Altschuler et al. (2017), https://arxiv.org/pdf/1705.09634.pdf
# The code is based on the implementation in the POT package.
Comment on lines +1 to +3

Member: It would be helpful to describe what the differences are (if there are any apart from implementation details).

Member (Author): Yeah, there are. The paper implementation is actually just a couple of lines commented out. I'll point them out in the code.

Member (Author): Just realized that it's already in the code, inside the `step!` function.



"""
Greenkhorn()

Greenkhorn is a greedy version of the Sinkhorn algorithm.
"""
struct Greenkhorn <: Sinkhorn end

struct GreenkhornCache{U,V,KT}
u::U
v::V
K::KT
Kv::U # kept up to date for the shared convergence check (see discussion below)
Member: Why "placeholder"?

Member (Author): This is a partial solution. Without this `Kv` I got an error; it seems the Sinkhorn structs further on require it. I could not find out how to get rid of it without changing the code for Sinkhorn.

Member: I mean, you use it explicitly below for checking convergence?

Member (Author): Yeah, it's used to check convergence. I think the convergence check could be done more efficiently another way, but I was having trouble getting that to work with the existing "API" for Sinkhorn, so I update `Kv` and use the convergence verification already in place. You are right that, as is, `Kv` is not actually just a placeholder, since it's used in the convergence check.
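For context, a rough sketch of the marginal-based check the shared solver code can perform once `Kv` is kept up to date (variable names and the tolerance are assumptions; the actual check lives in the common Sinkhorn solver code):

    using LinearAlgebra: norm
    # Kv holds K * v, so u .* Kv is the row marginal of the current plan;
    # convergence is then measured by its distance to the source marginal μ
    converged = norm(u .* Kv .- μ) < atol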

G::KT
du::U
dv::V
end

Base.show(io::IO, ::Greenkhorn) = print(io, "Greenkhorn algorithm")

function build_cache(
::Type{T},
::Greenkhorn,
size2::Tuple,
μ::AbstractVecOrMat,
ν::AbstractVecOrMat,
C::AbstractMatrix,
ε::Real,
) where {T}
# compute Gibbs kernel (has to be mutable for ε-scaling algorithm)
K = similar(C, T)
@. K = exp(-C / ε)

# create and initialize dual potentials
u = similar(μ, T, size(μ, 1), size2...)
v = similar(ν, T, size(ν, 1), size2...)
fill!(u, one(T)/size(μ, 1))
fill!(v, one(T)/size(ν, 1))

G = sinkhorn_plan(u, v, K)
# G = diagm(u) * K * diagm(v)

Kv = similar(u)

# TODO: batch support. The commented-out code below was an attempt to make
# the batch tests work; it needs to be revisited.
# if (length(size(μ)) == 2 && length(size(ν)) == 1)
# du = reshape(sum(G, dims=2), size(μ)) - μ
# dv = reshape(sum(G, dims=1),size(v)) - repeat(ν,1,size(v)[2])
# elseif (length(size(μ)) == 1 && length(size(ν)) == 2)
# du = reshape(sum(G, dims=2),size(u)) - repeat(μ,1,size(u)[2])
# dv = reshape(sum(G, dims=1), size(ν)) - ν
# else
du = reshape(sum(G, dims=2), size(μ)) - μ
dv = reshape(sum(G, dims=1), size(ν)) - ν
# end


return GreenkhornCache(u, v, K, Kv, G, du, dv)
end
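The commented `G = diagm(u) * K * diagm(v)` line above spells out what `sinkhorn_plan(u, v, K)` is expected to compute; a quick standalone sanity check of that identity (a sketch, assuming vector potentials):

    using LinearAlgebra
    u0, v0, K0 = rand(3), rand(4), rand(3, 4)
    # diag(u0) * K0 * diag(v0) equals the fused broadcast u0 .* K0 .* v0'
    Diagonal(u0) * K0 * Diagonal(v0) ≈ u0 .* K0 .* v0'  # true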

prestep!(::SinkhornSolver{<:Greenkhorn}, ::Int) = nothing

init_step!(::SinkhornSolver{<:Greenkhorn}) = nothing

function step!(solver::SinkhornSolver{<:Greenkhorn}, iter::Int)
μ = solver.source
ν = solver.target
cache = solver.cache
u = cache.u
v = cache.v
K = cache.K
G = cache.G
Δμ = cache.du
Δν = cache.dv

# The implementation in POT does not compute `μ .* log.(μ ./ sum(G', dims=1)[:])`
# or `ν .* log.(ν ./ sum(G', dims=2)[:])`. Yet this term is present in the original
# paper, which uses ρ(a, b) = b - a + a log(a/b).
# ρμ = abs.(Δμ + μ .* log.(μ ./ sum(G', dims=1)[:]))
# ρν = abs.(Δν + ν .* log.(ν ./ sum(G', dims=2)[:]))

i₁ = argmax(abs.(Δμ))
i₂ = argmax(abs.(Δν))
Comment on lines +87 to +88

Member: Suggested change:
-    i₁ = argmax(abs.(Δμ))
-    i₂ = argmax(abs.(Δν))
+    Δμ_max, Δμ_max_idx = findmax(abs, Δμ)
+    Δν_max, Δν_max_idx = findmax(abs, Δν)
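For reference, `findmax(f, xs)` (Julia 1.7+) returns the maximal transformed value together with its index in a single pass, avoiding the temporary array that `argmax(abs.(xs))` allocates:

    findmax(abs, [-3, 1, 2])  # (3, 1): the maximal |x| and its index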


# if ρμ[i₁]> ρν[i₂]
if abs(Δμ[i₁]) > abs(Δν[i₂])
Comment on lines +90 to +91

Member: Suggested change:
-    # if ρμ[i₁]> ρν[i₂]
-    if abs(Δμ[i₁]) > abs(Δν[i₂])
+    if Δμ_max > Δν_max

old_u = u[i₁]
Member: Also has to be changed for batch support, it seems. Suggested change:
-    old_u = u[i₁]
+    old_u = u[Δμ_max_idx]

u[i₁] = μ[i₁] / (K[i₁, :] ⋅ v)
Member: Suggested change:
-    u[i₁] = μ[i₁] / (K[i₁, :] ⋅ v)
+    u[Δμ_max_idx] = μ[Δμ_max_idx] / dot(K[Δμ_max_idx, :], v)

It would be better to select columns instead of rows and to use views. Julia uses column-major order, so a column is close in memory and hence accessing columns is faster.

Member (Author): True! I'll try to come up with something.
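A small illustration of the memory-layout point (the array and names here are arbitrary):

    A = rand(1_000, 1_000)
    col = @view A[:, 1]  # contiguous in memory under column-major storage
    row = @view A[1, :]  # strided: consecutive elements are 1_000 floats apart
    # iterating over `col` is typically noticeably faster than over `row`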

Δ = u[i₁] - old_u
Member: Suggested change:
-    Δ = u[i₁] - old_u
+    Δ = u[Δμ_max_idx] - old_u

G[i₁, :] = u[i₁] * K[i₁, :] .* v
Member: Again, better to work with columns than with rows. Also some unnecessary allocations here. Suggested change:
-    G[i₁, :] = u[i₁] * K[i₁, :] .* v
+    G[Δμ_max_idx, :] .= u[Δμ_max_idx] .* K[Δμ_max_idx, :] .* v

Member (Author): What is the unnecessary allocation?

Member: Oh, there are multiple unnecessary allocations. First of all, `K[i₁, :]` creates, i.e., allocates, a new vector. Then `u[i₁] * K[i₁, :]` scales it and allocates another vector. And finally `u[i₁] * K[i₁, :] .* v` multiplies the entries elementwise with `v` and allocates yet another vector. The alternative suggestion allocates only `K[i₁, :]` (which could be avoided by using a view), fuses all multiplications, and writes the result directly to `G` without allocating anything else.

Member (Author): Thanks for the answer, and sorry for the bad code. I'm still quite new to code optimization.
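Concretely, a non-allocating version of the row update along the lines of the suggestion (same `u`, `v`, `K`, and `G` as above; the view avoids even the remaining copy):

    Ki = @view K[i₁, :]           # a view, no copy of the row
    G[i₁, :] .= u[i₁] .* Ki .* v  # one fused broadcast, written in place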

Δμ[i₁] = u[i₁] * (K[i₁, :] ⋅ v) - μ[i₁]

Member: Same here.

@. Δν = Δν + Δ * K[i₁, :] * v

Member: And here.

else

Member: Same comments in the second branch.

old_v = v[i₂]
v[i₂] = ν[i₂] / (K[:, i₂] ⋅ u)
Δ = v[i₂] - old_v
G[:, i₂] = v[i₂] * K[:, i₂] .* u
Δν[i₂] = v[i₂] * (K[:, i₂] ⋅ u) - ν[i₂]
@. Δμ = Δμ + Δ * K[:, i₂] * u
end

A_batched_mul_B!(solver.cache.Kv, K, v) # update Kv = K * v for the shared convergence check
end
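For completeness, a sketch of the paper's selection rule using ρ(a, b) = b - a + a log(a/b), matching the commented-out lines in `step!` above (a hypothetical alternative, not part of this PR's code path; assumes strictly positive marginals):

    ρ(a, b) = b - a + a * log(a / b)  # nonnegative, zero iff a == b
    ρμ = ρ.(μ, vec(sum(G; dims=2)))   # violation of the row marginals
    ρν = ρ.(ν, vec(sum(G; dims=1)))   # violation of the column marginals
    i₁ = argmax(ρμ)
    i₂ = argmax(ρν)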

function sinkhorn_plan(solver::SinkhornSolver{<:Greenkhorn})
cache = solver.cache
return cache.G
end
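For reference, a minimal end-to-end usage sketch mirroring the tests below:

    using OptimalTransport
    using Distances, LinearAlgebra

    μ = normalize!(rand(10), 1)  # source histogram
    ν = normalize!(rand(8), 1)   # target histogram
    C = pairwise(SqEuclidean(), rand(1, 10), rand(1, 8); dims=2)  # cost matrix
    γ = sinkhorn(μ, ν, C, 0.01, Greenkhorn(); maxiter=100_000)    # transport plan
    c = sinkhorn2(μ, ν, C, 0.01, Greenkhorn(); maxiter=100_000)   # transport cost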

238 changes: 238 additions & 0 deletions test/entropic/greenkhorn.jl
@@ -0,0 +1,238 @@
using OptimalTransport

using Distances
using ForwardDiff
using ReverseDiff
using LogExpFunctions
using PythonOT: PythonOT

using LinearAlgebra
using Random
using Test

const POT = PythonOT

Random.seed!(100)

@testset "greenkhorn.jl" begin
# size of source and target
M = 250
N = 200

# create two random histograms
μ = normalize!(rand(M), 1)
ν = normalize!(rand(N), 1)

# create random cost matrix
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)

# regularization parameter
ε = 0.01

@testset "example" begin
# compute optimal transport plan and optimal transport cost
γ = sinkhorn(μ, ν, C, ε, Greenkhorn(); maxiter=200_000, rtol=1e-9)
c = sinkhorn2(μ, ν, C, ε, Greenkhorn(); maxiter=200_000, rtol=1e-9)

# check that plan and cost are consistent
@test c ≈ dot(γ, C)

# compare with POT
γ_pot = POT.sinkhorn(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)
c_pot = POT.sinkhorn2(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)[1]
@test γ_pot ≈ γ rtol = 1e-6
@test c_pot ≈ c rtol = 1e-7

# compute optimal transport cost with regularization term
c_w_regularization = sinkhorn2(
μ, ν, C, ε, Greenkhorn(); maxiter=200_000, regularization=true
)
@test c_w_regularization ≈ c + ε * sum(x -> iszero(x) ? x : x * log(x), γ)
@test c_w_regularization ≈
sinkhorn2(μ, ν, C, ε; maxiter=5_000, regularization=true)

# ensure that provided plan is used and correct
c2 = sinkhorn2(similar(μ), similar(ν), C, rand(), Greenkhorn(); plan=γ)
@test c2 ≈ c
@test c2 == sinkhorn2(similar(μ), similar(ν), C, rand(); plan=γ)
c2_w_regularization = sinkhorn2(
similar(μ), similar(ν), C, ε, Greenkhorn(); plan=γ, regularization=true
)
@test c2_w_regularization ≈ c_w_regularization
@test c2_w_regularization ==
sinkhorn2(similar(μ), similar(ν), C, ε; plan=γ, regularization=true)


################################################################
# TODO: batch support is not working for the Greenkhorn
# implementation; the batch tests below are disabled until fixed.
################################################################

# # batches of histograms
# d = 10
# for (size2_μ, size2_ν) in
# (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,)))
# # generate batches of histograms
# μ_batch = repeat(μ, 1, size2_μ...)
# ν_batch = repeat(ν, 1, size2_ν...)

# # compute optimal transport plan and check that it is consistent with the
# # plan for individual histograms
# γ_all = sinkhorn(
# μ_batch, ν_batch, C, ε, Greenkhorn(); maxiter=5_000, rtol=1e-9
# )
# @test size(γ_all) == (M, N, d)
# @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3))
# @test γ_all == sinkhorn(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9)

# # compute optimal transport cost and check that it is consistent with the
# # cost for individual histograms
# c_all = sinkhorn2(
# μ_batch, ν_batch, C, ε, Greenkhorn(); maxiter=5_000, rtol=1e-9
# )
# @test size(c_all) == (d,)
# @test all(x ≈ c for x in c_all)
# @test c_all == sinkhorn2(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9)
# end
end

# different element type
@testset "Float32" begin
# create histograms and cost matrix with element type `Float32`
μ32 = map(Float32, μ)
ν32 = map(Float32, ν)
C32 = map(Float32, C)
ε32 = Float32(ε)

# compute optimal transport plan and optimal transport cost
γ = sinkhorn(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=200_000, rtol=1e-6)
c = sinkhorn2(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=200_000, rtol=1e-6)
@test eltype(γ) === Float32
@test typeof(c) === Float32

# check that plan and cost are consistent
@test c ≈ dot(γ, C32)

# compare with default algorithm
γ_default = sinkhorn(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6)
c_default = sinkhorn2(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6)
@test γ_default ≈ γ rtol = 1e-4
@test c_default ≈ c rtol = 1e-4

# compare with POT
γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)
c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)[1]
@test map(Float32, γ_pot) ≈ γ rtol = 1e-3
@test Float32(c_pot) ≈ c rtol = 1e-3

################################################################
# TODO: batch support is not working for the Greenkhorn
# implementation; the batch tests below are disabled until fixed.
################################################################

# batches of histograms
# d = 10
# for (size2_μ, size2_ν) in
# (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,)))
# # generate batches of histograms
# μ32_batch = repeat(μ32, 1, size2_μ...)
# ν32_batch = repeat(ν32, 1, size2_ν...)

# # compute optimal transport plan and check that it is consistent with the
# # plan for individual histograms
# γ_all = sinkhorn(
# μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6
# )
# @test size(γ_all) == (M, N, d)
# @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3))
# @test γ_all ==
# sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6)

# # compute optimal transport cost and check that it is consistent with the
# # cost for individual histograms
# c_all = sinkhorn2(
# μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6
# )
# @test size(c_all) == (d,)
# @test all(x ≈ c for x in c_all)
# @test c_all ==
# sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6)
# end
end


################################################################
# TODO: AD is not working for the Greenkhorn implementation;
# the AD tests below are disabled until fixed.
################################################################

# https://github.com/JuliaOptimalTransport/OptimalTransport.jl/issues/86
# @testset "AD" begin
# # compute gradients with respect to source and target marginals separately and
# # together. test against gradient computed using analytic formula of Proposition 2.3 of
# # Cuturi, Marco, and Gabriel Peyré. "A smoothed dual approach for variational Wasserstein problems." SIAM Journal on Imaging Sciences 9.1 (2016): 320-343.
# #
# ε = 0.05 # use a larger ε to avoid having to do many iterations
# # target marginal
# for Diff in [ReverseDiff, ForwardDiff]
# ∇ = Diff.gradient(log.(ν)) do xs
# sinkhorn2(μ, softmax(xs), C, ε, Greenkhorn(); regularization=true)
# end
# ∇default = Diff.gradient(log.(ν)) do xs
# sinkhorn2(μ, softmax(xs), C, ε; regularization=true)
# end
# @test ∇ == ∇default

# solver = OptimalTransport.build_solver(μ, ν, C, ε, Greenkhorn())
# OptimalTransport.solve!(solver)
# # helper function
# function dualvar_to_grad(x, ε)
# x = -ε * log.(x)
# x .-= sum(x) / size(x, 1)
# return -x
# end
# ∇_ot = dualvar_to_grad(solver.cache.v, ε)
# # chain rule because target measure parameterised by softmax
# J_softmax = ForwardDiff.jacobian(log.(ν)) do xs
# softmax(xs)
# end
# ∇analytic_target = J_softmax * ∇_ot
# # check that gradient obtained by AD matches the analytic formula
# @test ∇ ≈ ∇analytic_target rtol = 1e-6

# # source marginal
# ∇ = Diff.gradient(log.(μ)) do xs
# sinkhorn2(softmax(xs), ν, C, ε, Greenkhorn(); regularization=true)
# end
# ∇default = Diff.gradient(log.(μ)) do xs
# sinkhorn2(softmax(xs), ν, C, ε; regularization=true)
# end
# @test ∇ == ∇default

# # check that gradient obtained by AD matches the analytic formula
# solver = OptimalTransport.build_solver(μ, ν, C, ε, Greenkhorn())
# OptimalTransport.solve!(solver)
# J_softmax = ForwardDiff.jacobian(log.(μ)) do xs
# softmax(xs)
# end
# ∇_ot = dualvar_to_grad(solver.cache.u, ε)
# ∇analytic_source = J_softmax * ∇_ot
# @test ∇ ≈ ∇analytic_source rtol = 1e-6

# # both marginals
# ∇ = Diff.gradient(log.(vcat(μ, ν))) do xs
# sinkhorn2(
# softmax(xs[1:M]),
# softmax(xs[(M + 1):end]),
# C,
# ε,
# Greenkhorn();
# regularization=true,
# )
# end
# ∇default = Diff.gradient(log.(vcat(μ, ν))) do xs
# sinkhorn2(softmax(xs[1:M]), softmax(xs[(M + 1):end]), C, ε; regularization=true)
# end
# @test ∇ == ∇default
# ∇analytic = vcat(∇analytic_source, ∇analytic_target)
# @test ∇ ≈ ∇analytic rtol = 1e-6
# end
# end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -31,6 +31,9 @@ const GROUP = get(ENV, "GROUP", "All")
@safetestset "Sinkhorn divergence" begin
include(joinpath("entropic", "sinkhorn_divergence.jl"))
end
@safetestset "Greenkhorn" begin
include(joinpath("entropic", "greenkhorn.jl"))
end
end

@safetestset "Quadratically regularized OT" begin