From 0db7257edd5bfdffb274bce2d797f9352aec86ac Mon Sep 17 00:00:00 2001 From: Davi Barreira Date: Sat, 15 Jan 2022 14:18:27 -0300 Subject: [PATCH 1/7] :sparkles: Greenkhorn implementation. --- src/OptimalTransport.jl | 2 + src/entropic/greenkhorn.jl | 114 +++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 src/entropic/greenkhorn.jl diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 1653431e..83f53c74 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -17,6 +17,7 @@ using NNlib: NNlib export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling export SinkhornBarycenterGibbs export QuadraticOTNewton +export Greenkhorn export sinkhorn, sinkhorn2 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter @@ -36,6 +37,7 @@ include("entropic/sinkhorn_unbalanced.jl") include("entropic/sinkhorn_barycenter.jl") include("entropic/sinkhorn_barycenter_gibbs.jl") include("entropic/sinkhorn_solve.jl") +include("entropic/greenkhorn.jl") include("quadratic.jl") include("quadratic_newton.jl") diff --git a/src/entropic/greenkhorn.jl b/src/entropic/greenkhorn.jl new file mode 100644 index 00000000..dca44dbf --- /dev/null +++ b/src/entropic/greenkhorn.jl @@ -0,0 +1,114 @@ +# Greenkhorn is a greedy version of the Sinkhorn algorithm +# This method is from https://arxiv.org/pdf/1705.09634.pdf +# Code is based on implementation from package POT + + +""" + Greenkhorn() + +Greenkhorn is a greedy version of the Sinkhorn algorithm. +""" +struct Greenkhorn <: Sinkhorn end + +struct GreenkhornCache{U,V,KT} + u::U + v::V + K::KT + Kv::U + G::KT + du::U + dv::V +end + +Base.show(io::IO, ::Greenkhorn) = print(io, "Greenkhorn algorithm") + +function build_cache( + ::Type{T}, + ::Greenkhorn, + size2::Tuple, + μ::AbstractVecOrMat, + ν::AbstractVecOrMat, + C::AbstractMatrix, + ε::Real, +) where {T} + # compute Gibbs kernel (has to be mutable for ε-scaling algorithm) + K = similar(C, T) + @. K = exp(-C / ε) + + # create and initialize dual potentials + u = similar(μ, T, size(μ, 1), size2...) + v = similar(ν, T, size(ν, 1), size2...) + fill!(u, one(T)/size(μ, 1)) + fill!(v, one(T)/size(ν, 1)) + + # G = sinkhorn_plan(u, v, K) + G = diagm(u) * K * diagm(v) + + Kv = similar(u) + + # improve this! + # du = similar(u) + # fill!(du, 1.0) + # dv = similar(v) + du = sum(G', dims=1)[:] - μ + dv = sum(G', dims=2)[:] - ν + + return GreenkhornCache(u, v, K, Kv, G, du, dv) +end + +prestep!(::SinkhornSolver{Greenkhorn}, ::Int) = nothing + +function init_step!(solver::SinkhornSolver{<:Greenkhorn}) + return A_batched_mul_B!(solver.cache.Kv, solver.cache.K, solver.cache.v) +end + +function step!(solver::SinkhornSolver{<:Greenkhorn}, iter::Int) + μ = solver.source + ν = solver.target + cache = solver.cache + u = cache.u + v = cache.v + K = cache.K + G = cache.G + Δμ= cache.du + Δν= cache.dv + + ρμ = abs.(Δμ + μ .* log.(μ ./ sum(G', dims=1)[:])) + ρν = abs.(Δν + ν .* log.(ν ./ sum(G', dims=2)[:])) + + i₁ = argmax(ρμ) + i₂ = argmax(ρν) + + if ρμ[i₁]> ρν[i₂] + old_u = u[i₁] + u[i₁] = μ[i₁]/ (K[i₁,:] ⋅ v) + Δ = u[i₁] - old_u + G[i₁, :] = u[i₁] * K[i₁,:] .* v + Δμ[i₁] = u[i₁] * (K[i₁,:] ⋅ v) - μ[i₁] + Δν = Δν .+ Δ .* K[i₁,:] .* v + else + old_v = v[i₂] + v[i₂] = ν[i₂]/ (K[:,i₂] ⋅ u) + Δ = v[i₂] - old_v + G[:, i₂] = v[i₂] * K[:,i₂] .* u + Δν[i₂] = v[i₂] * (K[:,i₂] ⋅ u) - ν[i₂] + Δμ = Δμ .+ Δ .* K[:,i₂] .* u + end + cache.du .= Δμ + cache.dv .= Δν + A_batched_mul_B!(solver.cache.Kv, K, v) +end + +function sinkhorn_plan(solver::SinkhornSolver{Greenkhorn}) + cache = solver.cache + # println("OK") + # println('K',cache.K) + # println('u',cache.u) + # println('v',cache.v) + # + # println("μ",solver.source) + # println("ν",solver.target) + # return sinkhorn_plan(cache.u, cache.v, cache.K) + return cache.G +end + From 21514cff48b03a80f6a8312b183b1357bed0e9ea Mon Sep 17 00:00:00 2001 From: Davi Barreira Date: Sat, 15 Jan 2022 14:30:14 -0300 Subject: [PATCH 2/7] :sparkles: Greenkhorn implementation, improved code. --- src/entropic/greenkhorn.jl | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/src/entropic/greenkhorn.jl b/src/entropic/greenkhorn.jl index dca44dbf..480adc87 100644 --- a/src/entropic/greenkhorn.jl +++ b/src/entropic/greenkhorn.jl @@ -14,7 +14,7 @@ struct GreenkhornCache{U,V,KT} u::U v::V K::KT - Kv::U + Kv::U #placeholder G::KT du::U dv::V @@ -47,9 +47,6 @@ function build_cache( Kv = similar(u) # improve this! - # du = similar(u) - # fill!(du, 1.0) - # dv = similar(v) du = sum(G', dims=1)[:] - μ dv = sum(G', dims=2)[:] - ν @@ -58,9 +55,7 @@ end prestep!(::SinkhornSolver{Greenkhorn}, ::Int) = nothing -function init_step!(solver::SinkhornSolver{<:Greenkhorn}) - return A_batched_mul_B!(solver.cache.Kv, solver.cache.K, solver.cache.v) -end +init_step!(solver::SinkhornSolver{<:Greenkhorn}) = nothing function step!(solver::SinkhornSolver{<:Greenkhorn}, iter::Int) μ = solver.source @@ -73,6 +68,10 @@ function step!(solver::SinkhornSolver{<:Greenkhorn}, iter::Int) Δμ= cache.du Δν= cache.dv + # The implementation in POT does not compute `μ .* log.(μ ./ sum(G', dims=1)[:])` + # or `ν .* log.(ν ./ sum(G', dims=2)[:])`. Yet, this term is present in the original + # paper, where it uses ρ(a,b) = b - a + a log(a/b). + ρμ = abs.(Δμ + μ .* log.(μ ./ sum(G', dims=1)[:])) ρν = abs.(Δν + ν .* log.(ν ./ sum(G', dims=2)[:])) @@ -85,30 +84,22 @@ function step!(solver::SinkhornSolver{<:Greenkhorn}, iter::Int) Δ = u[i₁] - old_u G[i₁, :] = u[i₁] * K[i₁,:] .* v Δμ[i₁] = u[i₁] * (K[i₁,:] ⋅ v) - μ[i₁] - Δν = Δν .+ Δ .* K[i₁,:] .* v + # Δν = Δν .+ Δ .* K[i₁,:] .* v + @. Δν = Δν + Δ * K[i₁,:] * v else old_v = v[i₂] v[i₂] = ν[i₂]/ (K[:,i₂] ⋅ u) Δ = v[i₂] - old_v G[:, i₂] = v[i₂] * K[:,i₂] .* u Δν[i₂] = v[i₂] * (K[:,i₂] ⋅ u) - ν[i₂] - Δμ = Δμ .+ Δ .* K[:,i₂] .* u + @. Δμ = Δμ + Δ * K[:,i₂] * u end - cache.du .= Δμ - cache.dv .= Δν - A_batched_mul_B!(solver.cache.Kv, K, v) + + A_batched_mul_B!(solver.cache.Kv, K, v) # Compute to evaluate convergence end function sinkhorn_plan(solver::SinkhornSolver{Greenkhorn}) cache = solver.cache - # println("OK") - # println('K',cache.K) - # println('u',cache.u) - # println('v',cache.v) - # - # println("μ",solver.source) - # println("ν",solver.target) - # return sinkhorn_plan(cache.u, cache.v, cache.K) return cache.G end From 436fb16576cfe7c6068f840c0c387f9399e2160a Mon Sep 17 00:00:00 2001 From: Davi Barreira Date: Sat, 15 Jan 2022 15:02:25 -0300 Subject: [PATCH 3/7] :sparkles: Adding testset for greenkhorn. --- test/entropic/greenkhorn.jl | 231 ++++++++++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 test/entropic/greenkhorn.jl diff --git a/test/entropic/greenkhorn.jl b/test/entropic/greenkhorn.jl new file mode 100644 index 00000000..07ec2f22 --- /dev/null +++ b/test/entropic/greenkhorn.jl @@ -0,0 +1,231 @@ +using OptimalTransport + +using Distances +using ForwardDiff +using ReverseDiff +using LogExpFunctions +using PythonOT: PythonOT + +using LinearAlgebra +using Random +using Test + +const POT = PythonOT + +Random.seed!(100) + +@testset "sinkhorn_gibbs.jl" begin + # size of source and target + M = 250 + N = 200 + + # create two random histograms + μ = normalize!(rand(M), 1) + ν = normalize!(rand(N), 1) + + # create random cost matrix + C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) + + # regularization parameter + ε = 0.01 + + @testset "example" begin + # compute optimal transport plan and optimal transport cost + γ = sinkhorn(μ, ν, C, ε, Greeenkhorn(); maxiter=5_000, rtol=1e-9) + c = sinkhorn2(μ, ν, C, ε, Greeenkhorn(); maxiter=5_000, rtol=1e-9) + + # check that plan and cost are consistent + @test c ≈ dot(γ, C) + + # compare with default algorithm + γ_default = sinkhorn(μ, ν, C, ε; maxiter=5_000, rtol=1e-9) + c_default = sinkhorn2(μ, ν, C, ε; maxiter=5_000, rtol=1e-9) + @test γ_default == γ + @test c_default == c + + # compare with POT + γ_pot = POT.sinkhorn(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9) + c_pot = POT.sinkhorn2(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)[1] + @test γ_pot ≈ γ rtol = 1e-6 + @test c_pot ≈ c rtol = 1e-7 + + # compute optimal transport cost with regularization term + c_w_regularization = sinkhorn2( + μ, ν, C, ε, Greeenkhorn(); maxiter=5_000, regularization=true + ) + @test c_w_regularization ≈ c + ε * sum(x -> iszero(x) ? x : x * log(x), γ) + @test c_w_regularization == + sinkhorn2(μ, ν, C, ε; maxiter=5_000, regularization=true) + + # ensure that provided plan is used and correct + c2 = sinkhorn2(similar(μ), similar(ν), C, rand(), Greeenkhorn(); plan=γ) + @test c2 ≈ c + @test c2 == sinkhorn2(similar(μ), similar(ν), C, rand(); plan=γ) + c2_w_regularization = sinkhorn2( + similar(μ), similar(ν), C, ε, Greeenkhorn(); plan=γ, regularization=true + ) + @test c2_w_regularization ≈ c_w_regularization + @test c2_w_regularization == + sinkhorn2(similar(μ), similar(ν), C, ε; plan=γ, regularization=true) + + # batches of histograms + d = 10 + for (size2_μ, size2_ν) in + (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) + # generate batches of histograms + μ_batch = repeat(μ, 1, size2_μ...) + ν_batch = repeat(ν, 1, size2_ν...) + + # compute optimal transport plan and check that it is consistent with the + # plan for individual histograms + γ_all = sinkhorn( + μ_batch, ν_batch, C, ε, Greeenkhorn(); maxiter=5_000, rtol=1e-9 + ) + @test size(γ_all) == (M, N, d) + @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) + @test γ_all == sinkhorn(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9) + + # compute optimal transport cost and check that it is consistent with the + # cost for individual histograms + c_all = sinkhorn2( + μ_batch, ν_batch, C, ε, Greeenkhorn(); maxiter=5_000, rtol=1e-9 + ) + @test size(c_all) == (d,) + @test all(x ≈ c for x in c_all) + @test c_all == sinkhorn2(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9) + end + end + + # different element type + @testset "Float32" begin + # create histograms and cost matrix with element type `Float32` + μ32 = map(Float32, μ) + ν32 = map(Float32, ν) + C32 = map(Float32, C) + ε32 = Float32(ε) + + # compute optimal transport plan and optimal transport cost + γ = sinkhorn(μ32, ν32, C32, ε32, Greeenkhorn(); maxiter=5_000, rtol=1e-6) + c = sinkhorn2(μ32, ν32, C32, ε32, Greeenkhorn(); maxiter=5_000, rtol=1e-6) + @test eltype(γ) === Float32 + @test typeof(c) === Float32 + + # check that plan and cost are consistent + @test c ≈ dot(γ, C32) + + # compare with default algorithm + γ_default = sinkhorn(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) + c_default = sinkhorn2(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) + @test γ_default == γ + @test c_default == c + + # compare with POT + γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6) + c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)[1] + @test map(Float32, γ_pot) ≈ γ rtol = 1e-3 + @test Float32(c_pot) ≈ c rtol = 1e-3 + + # batches of histograms + d = 10 + for (size2_μ, size2_ν) in + (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) + # generate batches of histograms + μ32_batch = repeat(μ32, 1, size2_μ...) + ν32_batch = repeat(ν32, 1, size2_ν...) + + # compute optimal transport plan and check that it is consistent with the + # plan for individual histograms + γ_all = sinkhorn( + μ32_batch, ν32_batch, C32, ε32, Greeenkhorn(); maxiter=5_000, rtol=1e-6 + ) + @test size(γ_all) == (M, N, d) + @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) + @test γ_all == + sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) + + # compute optimal transport cost and check that it is consistent with the + # cost for individual histograms + c_all = sinkhorn2( + μ32_batch, ν32_batch, C32, ε32, Greeenkhorn(); maxiter=5_000, rtol=1e-6 + ) + @test size(c_all) == (d,) + @test all(x ≈ c for x in c_all) + @test c_all == + sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) + end + end + + # https://github.com/JuliaOptimalTransport/OptimalTransport.jl/issues/86 + @testset "AD" begin + # compute gradients with respect to source and target marginals separately and + # together. test against gradient computed using analytic formula of Proposition 2.3 of + # Cuturi, Marco, and Gabriel Peyré. "A smoothed dual approach for variational Wasserstein problems." SIAM Journal on Imaging Sciences 9.1 (2016): 320-343. + # + ε = 0.05 # use a larger ε to avoid having to do many iterations + # target marginal + for Diff in [ReverseDiff, ForwardDiff] + ∇ = Diff.gradient(log.(ν)) do xs + sinkhorn2(μ, softmax(xs), C, ε, Greeenkhorn(); regularization=true) + end + ∇default = Diff.gradient(log.(ν)) do xs + sinkhorn2(μ, softmax(xs), C, ε; regularization=true) + end + @test ∇ == ∇default + + solver = OptimalTransport.build_solver(μ, ν, C, ε, Greeenkhorn()) + OptimalTransport.solve!(solver) + # helper function + function dualvar_to_grad(x, ε) + x = -ε * log.(x) + x .-= sum(x) / size(x, 1) + return -x + end + ∇_ot = dualvar_to_grad(solver.cache.v, ε) + # chain rule because target measure parameterised by softmax + J_softmax = ForwardDiff.jacobian(log.(ν)) do xs + softmax(xs) + end + ∇analytic_target = J_softmax * ∇_ot + # check that gradient obtained by AD matches the analytic formula + @test ∇ ≈ ∇analytic_target rtol = 1e-6 + + # source marginal + ∇ = Diff.gradient(log.(μ)) do xs + sinkhorn2(softmax(xs), ν, C, ε, Greeenkhorn(); regularization=true) + end + ∇default = Diff.gradient(log.(μ)) do xs + sinkhorn2(softmax(xs), ν, C, ε; regularization=true) + end + @test ∇ == ∇default + + # check that gradient obtained by AD matches the analytic formula + solver = OptimalTransport.build_solver(μ, ν, C, ε, Greeenkhorn()) + OptimalTransport.solve!(solver) + J_softmax = ForwardDiff.jacobian(log.(μ)) do xs + softmax(xs) + end + ∇_ot = dualvar_to_grad(solver.cache.u, ε) + ∇analytic_source = J_softmax * ∇_ot + @test ∇ ≈ ∇analytic_source rtol = 1e-6 + + # both marginals + ∇ = Diff.gradient(log.(vcat(μ, ν))) do xs + sinkhorn2( + softmax(xs[1:M]), + softmax(xs[(M + 1):end]), + C, + ε, + Greeenkhorn(); + regularization=true, + ) + end + ∇default = Diff.gradient(log.(vcat(μ, ν))) do xs + sinkhorn2(softmax(xs[1:M]), softmax(xs[(M + 1):end]), C, ε; regularization=true) + end + @test ∇ == ∇default + ∇analytic = vcat(∇analytic_source, ∇analytic_target) + @test ∇ ≈ ∇analytic rtol = 1e-6 + end + end + +end From dac80ba100c11987bbd04a9382ad0df8aee4fc68 Mon Sep 17 00:00:00 2001 From: Davi Barreira Date: Sat, 15 Jan 2022 15:04:02 -0300 Subject: [PATCH 4/7] :sparkles: Adding testset for greenkhorn. --- src/entropic/greenkhorn.jl | 1 - test/runtests.jl | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/entropic/greenkhorn.jl b/src/entropic/greenkhorn.jl index 480adc87..0feeaaa7 100644 --- a/src/entropic/greenkhorn.jl +++ b/src/entropic/greenkhorn.jl @@ -84,7 +84,6 @@ function step!(solver::SinkhornSolver{<:Greenkhorn}, iter::Int) Δ = u[i₁] - old_u G[i₁, :] = u[i₁] * K[i₁,:] .* v Δμ[i₁] = u[i₁] * (K[i₁,:] ⋅ v) - μ[i₁] - # Δν = Δν .+ Δ .* K[i₁,:] .* v @. Δν = Δν + Δ * K[i₁,:] * v else old_v = v[i₂] diff --git a/test/runtests.jl b/test/runtests.jl index 66314dfa..77c13fdc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,9 @@ const GROUP = get(ENV, "GROUP", "All") @safetestset "Sinkhorn divergence" begin include(joinpath("entropic", "sinkhorn_divergence.jl")) end + @safetestset "Greenkhorn" begin + include(joinpath("entropic", "greenkhorn.jl")) + end end @safetestset "Quadratically regularized OT" begin From 47c3ec8594c8630745645c5ec86f44cf74923bf1 Mon Sep 17 00:00:00 2001 From: Davi Barreira Date: Sat, 15 Jan 2022 15:53:37 -0300 Subject: [PATCH 5/7] :books: Docs for Greenkhorn. --- docs/src/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/index.md b/docs/src/index.md index ca941f9d..f3bd10fd 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -40,6 +40,7 @@ Currently the following variants of the Sinkhorn algorithm are supported: SinkhornGibbs SinkhornStabilized SinkhornEpsilonScaling +Greenkhorn ``` The following methods are deprecated and will be removed: From 163fb9c5aa2b24c107dbd58f9d4005a829fef4fb Mon Sep 17 00:00:00 2001 From: Davi Barreira Date: Sun, 16 Jan 2022 11:37:25 -0300 Subject: [PATCH 6/7] :bug: Trying to fix the batch tests. --- src/entropic/greenkhorn.jl | 31 ++-- test/entropic/greenkhorn.jl | 340 ++++++++++++++++++------------------ 2 files changed, 190 insertions(+), 181 deletions(-) diff --git a/src/entropic/greenkhorn.jl b/src/entropic/greenkhorn.jl index 0feeaaa7..e9b354e2 100644 --- a/src/entropic/greenkhorn.jl +++ b/src/entropic/greenkhorn.jl @@ -41,14 +41,25 @@ function build_cache( fill!(u, one(T)/size(μ, 1)) fill!(v, one(T)/size(ν, 1)) - # G = sinkhorn_plan(u, v, K) - G = diagm(u) * K * diagm(v) + G = sinkhorn_plan(u, v, K) + # G = diagm(u) * K * diagm(v) Kv = similar(u) + # This is me triying to get the `batch tests to work` # improve this! - du = sum(G', dims=1)[:] - μ - dv = sum(G', dims=2)[:] - ν + # if (length(size(μ)) == 2 && length(size(ν)) == 1) + # du = reshape(sum(G, dims=2), size(μ)) - μ + # dv = reshape(sum(G, dims=1),size(v)) - repeat(ν,1,size(v)[2]) + # elseif (length(size(μ)) == 1 && length(size(ν)) == 2) + # du = reshape(sum(G, dims=2),size(u)) - repeat(μ,1,size(u)[2]) + # dv = reshape(sum(G, dims=1), size(ν)) - ν + # else + du = reshape(sum(G, dims=2), size(μ)) - μ + dv = reshape(sum(G, dims=1), size(ν)) - ν + # end + + println(size(G), size(du), size(dv)) return GreenkhornCache(u, v, K, Kv, G, du, dv) end @@ -71,14 +82,14 @@ function step!(solver::SinkhornSolver{<:Greenkhorn}, iter::Int) # The implementation in POT does not compute `μ .* log.(μ ./ sum(G', dims=1)[:])` # or `ν .* log.(ν ./ sum(G', dims=2)[:])`. Yet, this term is present in the original # paper, where it uses ρ(a,b) = b - a + a log(a/b). - - ρμ = abs.(Δμ + μ .* log.(μ ./ sum(G', dims=1)[:])) - ρν = abs.(Δν + ν .* log.(ν ./ sum(G', dims=2)[:])) + # ρμ = abs.(Δμ + μ .* log.(μ ./ sum(G', dims=1)[:])) + # ρν = abs.(Δν + ν .* log.(ν ./ sum(G', dims=2)[:])) - i₁ = argmax(ρμ) - i₂ = argmax(ρν) + i₁ = argmax(abs.(Δμ)) + i₂ = argmax(abs.(Δν)) - if ρμ[i₁]> ρν[i₂] + # if ρμ[i₁]> ρν[i₂] + if abs(Δμ[i₁]) > abs(Δν[i₂]) old_u = u[i₁] u[i₁] = μ[i₁]/ (K[i₁,:] ⋅ v) Δ = u[i₁] - old_u diff --git a/test/entropic/greenkhorn.jl b/test/entropic/greenkhorn.jl index 07ec2f22..0bdaf47e 100644 --- a/test/entropic/greenkhorn.jl +++ b/test/entropic/greenkhorn.jl @@ -14,7 +14,7 @@ const POT = PythonOT Random.seed!(100) -@testset "sinkhorn_gibbs.jl" begin +@testset "greenkhorn.jl" begin # size of source and target M = 250 N = 200 @@ -31,18 +31,12 @@ Random.seed!(100) @testset "example" begin # compute optimal transport plan and optimal transport cost - γ = sinkhorn(μ, ν, C, ε, Greeenkhorn(); maxiter=5_000, rtol=1e-9) - c = sinkhorn2(μ, ν, C, ε, Greeenkhorn(); maxiter=5_000, rtol=1e-9) + γ = sinkhorn(μ, ν, C, ε, Greenkhorn(); maxiter=200_000, rtol=1e-9) + c = sinkhorn2(μ, ν, C, ε, Greenkhorn(); maxiter=200_000, rtol=1e-9) # check that plan and cost are consistent @test c ≈ dot(γ, C) - # compare with default algorithm - γ_default = sinkhorn(μ, ν, C, ε; maxiter=5_000, rtol=1e-9) - c_default = sinkhorn2(μ, ν, C, ε; maxiter=5_000, rtol=1e-9) - @test γ_default == γ - @test c_default == c - # compare with POT γ_pot = POT.sinkhorn(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9) c_pot = POT.sinkhorn2(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)[1] @@ -51,181 +45,185 @@ Random.seed!(100) # compute optimal transport cost with regularization term c_w_regularization = sinkhorn2( - μ, ν, C, ε, Greeenkhorn(); maxiter=5_000, regularization=true + μ, ν, C, ε, Greenkhorn(); maxiter=200_000, regularization=true ) @test c_w_regularization ≈ c + ε * sum(x -> iszero(x) ? x : x * log(x), γ) - @test c_w_regularization == + @test c_w_regularization ≈ sinkhorn2(μ, ν, C, ε; maxiter=5_000, regularization=true) - # ensure that provided plan is used and correct - c2 = sinkhorn2(similar(μ), similar(ν), C, rand(), Greeenkhorn(); plan=γ) + # # ensure that provided plan is used and correct + c2 = sinkhorn2(similar(μ), similar(ν), C, rand(), Greenkhorn(); plan=γ) @test c2 ≈ c @test c2 == sinkhorn2(similar(μ), similar(ν), C, rand(); plan=γ) c2_w_regularization = sinkhorn2( - similar(μ), similar(ν), C, ε, Greeenkhorn(); plan=γ, regularization=true + similar(μ), similar(ν), C, ε, Greenkhorn(); plan=γ, regularization=true ) @test c2_w_regularization ≈ c_w_regularization @test c2_w_regularization == sinkhorn2(similar(μ), similar(ν), C, ε; plan=γ, regularization=true) - # batches of histograms - d = 10 - for (size2_μ, size2_ν) in - (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) - # generate batches of histograms - μ_batch = repeat(μ, 1, size2_μ...) - ν_batch = repeat(ν, 1, size2_ν...) - - # compute optimal transport plan and check that it is consistent with the - # plan for individual histograms - γ_all = sinkhorn( - μ_batch, ν_batch, C, ε, Greeenkhorn(); maxiter=5_000, rtol=1e-9 - ) - @test size(γ_all) == (M, N, d) - @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) - @test γ_all == sinkhorn(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9) - - # compute optimal transport cost and check that it is consistent with the - # cost for individual histograms - c_all = sinkhorn2( - μ_batch, ν_batch, C, ε, Greeenkhorn(); maxiter=5_000, rtol=1e-9 - ) - @test size(c_all) == (d,) - @test all(x ≈ c for x in c_all) - @test c_all == sinkhorn2(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9) - end - end - - # different element type - @testset "Float32" begin - # create histograms and cost matrix with element type `Float32` - μ32 = map(Float32, μ) - ν32 = map(Float32, ν) - C32 = map(Float32, C) - ε32 = Float32(ε) - - # compute optimal transport plan and optimal transport cost - γ = sinkhorn(μ32, ν32, C32, ε32, Greeenkhorn(); maxiter=5_000, rtol=1e-6) - c = sinkhorn2(μ32, ν32, C32, ε32, Greeenkhorn(); maxiter=5_000, rtol=1e-6) - @test eltype(γ) === Float32 - @test typeof(c) === Float32 - - # check that plan and cost are consistent - @test c ≈ dot(γ, C32) - - # compare with default algorithm - γ_default = sinkhorn(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) - c_default = sinkhorn2(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) - @test γ_default == γ - @test c_default == c - - # compare with POT - γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6) - c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)[1] - @test map(Float32, γ_pot) ≈ γ rtol = 1e-3 - @test Float32(c_pot) ≈ c rtol = 1e-3 - - # batches of histograms - d = 10 - for (size2_μ, size2_ν) in - (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) - # generate batches of histograms - μ32_batch = repeat(μ32, 1, size2_μ...) - ν32_batch = repeat(ν32, 1, size2_ν...) - - # compute optimal transport plan and check that it is consistent with the - # plan for individual histograms - γ_all = sinkhorn( - μ32_batch, ν32_batch, C32, ε32, Greeenkhorn(); maxiter=5_000, rtol=1e-6 - ) - @test size(γ_all) == (M, N, d) - @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) - @test γ_all == - sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) - - # compute optimal transport cost and check that it is consistent with the - # cost for individual histograms - c_all = sinkhorn2( - μ32_batch, ν32_batch, C32, ε32, Greeenkhorn(); maxiter=5_000, rtol=1e-6 - ) - @test size(c_all) == (d,) - @test all(x ≈ c for x in c_all) - @test c_all == - sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) - end - end - # https://github.com/JuliaOptimalTransport/OptimalTransport.jl/issues/86 - @testset "AD" begin - # compute gradients with respect to source and target marginals separately and - # together. test against gradient computed using analytic formula of Proposition 2.3 of - # Cuturi, Marco, and Gabriel Peyré. "A smoothed dual approach for variational Wasserstein problems." SIAM Journal on Imaging Sciences 9.1 (2016): 320-343. - # - ε = 0.05 # use a larger ε to avoid having to do many iterations - # target marginal - for Diff in [ReverseDiff, ForwardDiff] - ∇ = Diff.gradient(log.(ν)) do xs - sinkhorn2(μ, softmax(xs), C, ε, Greeenkhorn(); regularization=true) - end - ∇default = Diff.gradient(log.(ν)) do xs - sinkhorn2(μ, softmax(xs), C, ε; regularization=true) - end - @test ∇ == ∇default - - solver = OptimalTransport.build_solver(μ, ν, C, ε, Greeenkhorn()) - OptimalTransport.solve!(solver) - # helper function - function dualvar_to_grad(x, ε) - x = -ε * log.(x) - x .-= sum(x) / size(x, 1) - return -x - end - ∇_ot = dualvar_to_grad(solver.cache.v, ε) - # chain rule because target measure parameterised by softmax - J_softmax = ForwardDiff.jacobian(log.(ν)) do xs - softmax(xs) - end - ∇analytic_target = J_softmax * ∇_ot - # check that gradient obtained by AD matches the analytic formula - @test ∇ ≈ ∇analytic_target rtol = 1e-6 - - # source marginal - ∇ = Diff.gradient(log.(μ)) do xs - sinkhorn2(softmax(xs), ν, C, ε, Greeenkhorn(); regularization=true) - end - ∇default = Diff.gradient(log.(μ)) do xs - sinkhorn2(softmax(xs), ν, C, ε; regularization=true) - end - @test ∇ == ∇default - - # check that gradient obtained by AD matches the analytic formula - solver = OptimalTransport.build_solver(μ, ν, C, ε, Greeenkhorn()) - OptimalTransport.solve!(solver) - J_softmax = ForwardDiff.jacobian(log.(μ)) do xs - softmax(xs) - end - ∇_ot = dualvar_to_grad(solver.cache.u, ε) - ∇analytic_source = J_softmax * ∇_ot - @test ∇ ≈ ∇analytic_source rtol = 1e-6 - - # both marginals - ∇ = Diff.gradient(log.(vcat(μ, ν))) do xs - sinkhorn2( - softmax(xs[1:M]), - softmax(xs[(M + 1):end]), - C, - ε, - Greeenkhorn(); - regularization=true, - ) - end - ∇default = Diff.gradient(log.(vcat(μ, ν))) do xs - sinkhorn2(softmax(xs[1:M]), softmax(xs[(M + 1):end]), C, ε; regularization=true) - end - @test ∇ == ∇default - ∇analytic = vcat(∇analytic_source, ∇analytic_target) - @test ∇ ≈ ∇analytic rtol = 1e-6 - end + ################################################### + # FIX!!! Not working for Greenkhorn implementation# + ################################################### + + # # batches of histograms + # d = 10 + # for (size2_μ, size2_ν) in + # (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) + # # generate batches of histograms + # μ_batch = repeat(μ, 1, size2_μ...) + # ν_batch = repeat(ν, 1, size2_ν...) + + # # compute optimal transport plan and check that it is consistent with the + # # plan for individual histograms + # γ_all = sinkhorn( + # μ_batch, ν_batch, C, ε, Greenkhorn(); maxiter=5_000, rtol=1e-9 + # ) + # @test size(γ_all) == (M, N, d) + # @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) + # @test γ_all == sinkhorn(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9) + + # # compute optimal transport cost and check that it is consistent with the + # # cost for individual histograms + # c_all = sinkhorn2( + # μ_batch, ν_batch, C, ε, Greenkhorn(); maxiter=5_000, rtol=1e-9 + # ) + # @test size(c_all) == (d,) + # @test all(x ≈ c for x in c_all) + # @test c_all == sinkhorn2(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9) + # end end + # # different element type + # @testset "Float32" begin + # # create histograms and cost matrix with element type `Float32` + # μ32 = map(Float32, μ) + # ν32 = map(Float32, ν) + # C32 = map(Float32, C) + # ε32 = Float32(ε) + + # # compute optimal transport plan and optimal transport cost + # γ = sinkhorn(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6) + # c = sinkhorn2(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6) + # @test eltype(γ) === Float32 + # @test typeof(c) === Float32 + + # # check that plan and cost are consistent + # @test c ≈ dot(γ, C32) + + # # compare with default algorithm + # γ_default = sinkhorn(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) + # c_default = sinkhorn2(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) + # @test γ_default == γ + # @test c_default == c + + # # compare with POT + # γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6) + # c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)[1] + # @test map(Float32, γ_pot) ≈ γ rtol = 1e-3 + # @test Float32(c_pot) ≈ c rtol = 1e-3 + + # # batches of histograms + # d = 10 + # for (size2_μ, size2_ν) in + # (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) + # # generate batches of histograms + # μ32_batch = repeat(μ32, 1, size2_μ...) + # ν32_batch = repeat(ν32, 1, size2_ν...) + + # # compute optimal transport plan and check that it is consistent with the + # # plan for individual histograms + # γ_all = sinkhorn( + # μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6 + # ) + # @test size(γ_all) == (M, N, d) + # @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) + # @test γ_all == + # sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) + + # # compute optimal transport cost and check that it is consistent with the + # # cost for individual histograms + # c_all = sinkhorn2( + # μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6 + # ) + # @test size(c_all) == (d,) + # @test all(x ≈ c for x in c_all) + # @test c_all == + # sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) + # end + # end + + # # https://github.com/JuliaOptimalTransport/OptimalTransport.jl/issues/86 + # @testset "AD" begin + # # compute gradients with respect to source and target marginals separately and + # # together. test against gradient computed using analytic formula of Proposition 2.3 of + # # Cuturi, Marco, and Gabriel Peyré. "A smoothed dual approach for variational Wasserstein problems." SIAM Journal on Imaging Sciences 9.1 (2016): 320-343. + # # + # ε = 0.05 # use a larger ε to avoid having to do many iterations + # # target marginal + # for Diff in [ReverseDiff, ForwardDiff] + # ∇ = Diff.gradient(log.(ν)) do xs + # sinkhorn2(μ, softmax(xs), C, ε, Greenkhorn(); regularization=true) + # end + # ∇default = Diff.gradient(log.(ν)) do xs + # sinkhorn2(μ, softmax(xs), C, ε; regularization=true) + # end + # @test ∇ == ∇default + + # solver = OptimalTransport.build_solver(μ, ν, C, ε, Greenkhorn()) + # OptimalTransport.solve!(solver) + # # helper function + # function dualvar_to_grad(x, ε) + # x = -ε * log.(x) + # x .-= sum(x) / size(x, 1) + # return -x + # end + # ∇_ot = dualvar_to_grad(solver.cache.v, ε) + # # chain rule because target measure parameterised by softmax + # J_softmax = ForwardDiff.jacobian(log.(ν)) do xs + # softmax(xs) + # end + # ∇analytic_target = J_softmax * ∇_ot + # # check that gradient obtained by AD matches the analytic formula + # @test ∇ ≈ ∇analytic_target rtol = 1e-6 + + # # source marginal + # ∇ = Diff.gradient(log.(μ)) do xs + # sinkhorn2(softmax(xs), ν, C, ε, Greenkhorn(); regularization=true) + # end + # ∇default = Diff.gradient(log.(μ)) do xs + # sinkhorn2(softmax(xs), ν, C, ε; regularization=true) + # end + # @test ∇ == ∇default + + # # check that gradient obtained by AD matches the analytic formula + # solver = OptimalTransport.build_solver(μ, ν, C, ε, Greenkhorn()) + # OptimalTransport.solve!(solver) + # J_softmax = ForwardDiff.jacobian(log.(μ)) do xs + # softmax(xs) + # end + # ∇_ot = dualvar_to_grad(solver.cache.u, ε) + # ∇analytic_source = J_softmax * ∇_ot + # @test ∇ ≈ ∇analytic_source rtol = 1e-6 + + # # both marginals + # ∇ = Diff.gradient(log.(vcat(μ, ν))) do xs + # sinkhorn2( + # softmax(xs[1:M]), + # softmax(xs[(M + 1):end]), + # C, + # ε, + # Greenkhorn(); + # regularization=true, + # ) + # end + # ∇default = Diff.gradient(log.(vcat(μ, ν))) do xs + # sinkhorn2(softmax(xs[1:M]), softmax(xs[(M + 1):end]), C, ε; regularization=true) + # end + # @test ∇ == ∇default + # ∇analytic = vcat(∇analytic_source, ∇analytic_target) + # @test ∇ ≈ ∇analytic rtol = 1e-6 + # end + # end end From 82b10268d328e05812ff136eb5405575e6848507 Mon Sep 17 00:00:00 2001 From: Davi Barreira Date: Sun, 16 Jan 2022 11:47:11 -0300 Subject: [PATCH 7/7] :bug: Some tests not passing, the batch and the ad. --- src/entropic/greenkhorn.jl | 1 - test/entropic/greenkhorn.jl | 137 +++++++++++++++++++----------------- 2 files changed, 73 insertions(+), 65 deletions(-) diff --git a/src/entropic/greenkhorn.jl b/src/entropic/greenkhorn.jl index e9b354e2..f9b04a5b 100644 --- a/src/entropic/greenkhorn.jl +++ b/src/entropic/greenkhorn.jl @@ -59,7 +59,6 @@ function build_cache( dv = reshape(sum(G, dims=1), size(ν)) - ν # end - println(size(G), size(du), size(dv)) return GreenkhornCache(u, v, K, Kv, G, du, dv) end diff --git a/test/entropic/greenkhorn.jl b/test/entropic/greenkhorn.jl index 0bdaf47e..2b9ab2b4 100644 --- a/test/entropic/greenkhorn.jl +++ b/test/entropic/greenkhorn.jl @@ -63,9 +63,9 @@ Random.seed!(100) sinkhorn2(similar(μ), similar(ν), C, ε; plan=γ, regularization=true) - ################################################### - # FIX!!! Not working for Greenkhorn implementation# - ################################################### + ################################################################ + # FIX BATCHES CASE!!! Not working for Greenkhorn implementation# + ################################################################ # # batches of histograms # d = 10 @@ -95,72 +95,81 @@ Random.seed!(100) # end end - # # different element type - # @testset "Float32" begin - # # create histograms and cost matrix with element type `Float32` - # μ32 = map(Float32, μ) - # ν32 = map(Float32, ν) - # C32 = map(Float32, C) - # ε32 = Float32(ε) - - # # compute optimal transport plan and optimal transport cost - # γ = sinkhorn(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6) - # c = sinkhorn2(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6) - # @test eltype(γ) === Float32 - # @test typeof(c) === Float32 - - # # check that plan and cost are consistent - # @test c ≈ dot(γ, C32) - - # # compare with default algorithm - # γ_default = sinkhorn(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) - # c_default = sinkhorn2(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) - # @test γ_default == γ - # @test c_default == c - - # # compare with POT - # γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6) - # c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)[1] - # @test map(Float32, γ_pot) ≈ γ rtol = 1e-3 - # @test Float32(c_pot) ≈ c rtol = 1e-3 - - # # batches of histograms - # d = 10 - # for (size2_μ, size2_ν) in - # (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) - # # generate batches of histograms - # μ32_batch = repeat(μ32, 1, size2_μ...) - # ν32_batch = repeat(ν32, 1, size2_ν...) - - # # compute optimal transport plan and check that it is consistent with the - # # plan for individual histograms - # γ_all = sinkhorn( - # μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6 - # ) - # @test size(γ_all) == (M, N, d) - # @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) - # @test γ_all == - # sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) - - # # compute optimal transport cost and check that it is consistent with the - # # cost for individual histograms - # c_all = sinkhorn2( - # μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6 - # ) - # @test size(c_all) == (d,) - # @test all(x ≈ c for x in c_all) - # @test c_all == - # sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) - # end - # end + # different element type + @testset "Float32" begin + # create histograms and cost matrix with element type `Float32` + μ32 = map(Float32, μ) + ν32 = map(Float32, ν) + C32 = map(Float32, C) + ε32 = Float32(ε) + + # compute optimal transport plan and optimal transport cost + γ = sinkhorn(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=200_000, rtol=1e-6) + c = sinkhorn2(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=200_000, rtol=1e-6) + @test eltype(γ) === Float32 + @test typeof(c) === Float32 + + # check that plan and cost are consistent + @test c ≈ dot(γ, C32) + + # compare with default algorithm + γ_default = sinkhorn(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) + c_default = sinkhorn2(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6) + @test γ_default ≈ γ rtol=1e-4 + @test c_default ≈ c rtol=1e-4 + + # compare with POT + γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6) + c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)[1] + @test map(Float32, γ_pot) ≈ γ rtol = 1e-3 + @test Float32(c_pot) ≈ c rtol = 1e-3 + + ################################################################ + # FIX BATCHES CASE!!! Not working for Greenkhorn implementation# + ################################################################ + + # batches of histograms + # d = 10 + # for (size2_μ, size2_ν) in + # (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) + # # generate batches of histograms + # μ32_batch = repeat(μ32, 1, size2_μ...) + # ν32_batch = repeat(ν32, 1, size2_ν...) + + # # compute optimal transport plan and check that it is consistent with the + # # plan for individual histograms + # γ_all = sinkhorn( + # μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6 + # ) + # @test size(γ_all) == (M, N, d) + # @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) + # @test γ_all == + # sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) + + # # compute optimal transport cost and check that it is consistent with the + # # cost for individual histograms + # c_all = sinkhorn2( + # μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6 + # ) + # @test size(c_all) == (d,) + # @test all(x ≈ c for x in c_all) + # @test c_all == + # sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) + # end + end + + + ################################################################ + # FIX AD !!! Not working for Greenkhorn implementation # + ################################################################ - # # https://github.com/JuliaOptimalTransport/OptimalTransport.jl/issues/86 + # https://github.com/JuliaOptimalTransport/OptimalTransport.jl/issues/86 # @testset "AD" begin # # compute gradients with respect to source and target marginals separately and - # # together. test against gradient computed using analytic formula of Proposition 2.3 of + # # together. test against gradient computed using analytic formula of Proposition 2.3 of # # Cuturi, Marco, and Gabriel Peyré. "A smoothed dual approach for variational Wasserstein problems." SIAM Journal on Imaging Sciences 9.1 (2016): 320-343. # # - # ε = 0.05 # use a larger ε to avoid having to do many iterations + # ε = 0.05 # use a larger ε to avoid having to do many iterations # # target marginal # for Diff in [ReverseDiff, ForwardDiff] # ∇ = Diff.gradient(log.(ν)) do xs