Skip to content

Commit

Permalink
Optimal Transport for Multivariate Gaussians (#85)
Browse files Browse the repository at this point in the history
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
davibarreira and devmotion authored Jun 4, 2021
1 parent 4a26ec9 commit fb23ca3
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ docs/src/examples/

# Files generated by Jupyter Notebooks
*.ipynb_checkpoints
*.ipynb
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimalTransport"
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
authors = ["zsteve <stephenz@student.unimelb.edu.au>"]
version = "0.3.8"
version = "0.3.9"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Expand All @@ -10,6 +10,7 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -20,6 +21,7 @@ Distributions = "0.25"
IterativeSolvers = "0.8.4, 0.9"
LogExpFunctions = "0.2"
MathOptInterface = "0.9"
PDMats = "0.11"
QuadGK = "2"
StatsBase = "0.33.8"
julia = "1"
Expand All @@ -32,6 +34,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tulip = "6dd1b50a-3aae-11e9-10b5-ef983d2400fa"
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"

[targets]
test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip"]
test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip", "HCubature"]
2 changes: 2 additions & 0 deletions src/OptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using IterativeSolvers, SparseArrays
using LogExpFunctions: LogExpFunctions
using MathOptInterface
using Distributions
using PDMats
using QuadGK
using StatsBase: StatsBase

Expand All @@ -22,6 +23,7 @@ export ot_cost, ot_plan, wasserstein, squared2wasserstein

const MOI = MathOptInterface

include("distances/bures.jl")
include("utils.jl")
include("exact.jl")
include("wasserstein.jl")
Expand Down
74 changes: 74 additions & 0 deletions src/distances/bures.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Code from @devmotion
# https://github.com/devmotion/\
# CalibrationErrorsDistributions.jl/blob/main/src/distances/bures.jl

"""
tr_sqrt(A::AbstractMatrix)
Compute ``\\operatorname{tr}\\big(A^{1/2}\\big)``.
"""
tr_sqrt(A::AbstractMatrix) = LinearAlgebra.tr(sqrt(A))
tr_sqrt(A::PDMats.PDMat) = tr_sqrt(A.mat)
tr_sqrt(A::PDMats.PDiagMat) = sum(sqrt, A.diag)
tr_sqrt(A::PDMats.ScalMat) = A.dim * sqrt(A.value)

"""
_gaussian_ot_A(A::AbstractMatrix, B::AbstractMatrix)
Compute
```math
A^{1/2} B A^{1/2}.
```
"""
function _gaussian_ot_A(A::AbstractMatrix, B::AbstractMatrix)
sqrt_A = sqrt(A)
return sqrt_A * B * sqrt_A
end
function _gaussian_ot_A(A::PDMats.PDiagMat, B::AbstractMatrix)
return sqrt.(A.diag) .* B .* sqrt.(A.diag')
end
function _gaussian_ot_A(A::StridedMatrix, B::PDMats.PDMat)
return PDMats.X_A_Xt(B, sqrt(A))
end
_gaussian_ot_A(A::PDMats.PDMat, B::PDMats.PDMat) = _gaussian_ot_A(A.mat, B)
_gaussian_ot_A(A::AbstractMatrix, B::PDMats.PDiagMat) = _gaussian_ot_A(B, A)
_gaussian_ot_A(A::PDMats.PDMat, B::StridedMatrix) = _gaussian_ot_A(B, A)

"""
sqbures(A::AbstractMatrix, B::AbstractMatrix)
Compute the squared Bures metric
```math
\\operatorname{tr}(A) + \\operatorname{tr}(B)
- \\operatorname{tr}\\Big({\\big(A^{1/2} B A^{1/2}\\big)}^{1/2}\\Big).
```
"""
function sqbures(A::AbstractMatrix, B::AbstractMatrix)
return LinearAlgebra.tr(A) + LinearAlgebra.tr(B) - 2 * tr_sqrt(_gaussian_ot_A(A, B))
end

# diagonal matrix
function sqbures(A::PDMats.PDiagMat, B::PDMats.PDiagMat)
if !(A.dim == B.dim)
throw(ArgumentError("matrices must have the same dimensions."))
end
return sum(zip(A.diag, B.diag)) do (x, y)
abs2(sqrt(x) - sqrt(y))
end
end

# scaled identity matrix
function sqbures(A::PDMats.ScalMat, B::AbstractMatrix)
return LinearAlgebra.tr(A) + LinearAlgebra.tr(B) - 2 * sqrt(A.value) * tr_sqrt(B)
end
sqbures(A::AbstractMatrix, B::PDMats.ScalMat) = sqbures(B, A)
sqbures(A::PDMats.ScalMat, B::PDMats.ScalMat) = A.dim * abs2(sqrt(A.value) - sqrt(B.value))

# combinations
function sqbures(A::PDMats.PDiagMat, B::PDMats.ScalMat)
sqrt_B = sqrt(B.value)
return sum(A.diag) do x
abs2(sqrt(x) - sqrt_B)
end
end
sqbures(A::PDMats.ScalMat, B::PDMats.PDiagMat) = sqbures(B, A)
79 changes: 79 additions & 0 deletions src/exact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,82 @@ end
function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan)
return dot(plan, StatsBase.pairwise(c, support(μ), support(ν)))
end

################
# OT Gaussians
################

"""
ot_cost(::SqEuclidean, μ::MvNormal, ν::MvNormal)
Compute the squared 2-Wasserstein distance between normal distributions `μ` and `ν` as
source and target marginals.
In this setting, the optimal transport cost can be computed as
```math
W_2^2(\\mu, \\nu) = \\|m_\\mu - m_\\nu \\|^2 + \\mathcal{B}(\\Sigma_\\mu, \\Sigma_\\nu)^2,
```
where ``\\mu = \\mathcal{N}(m_\\mu, \\Sigma_\\mu)``,
``\\nu = \\mathcal{N}(m_\\nu, \\Sigma_\\nu)``, and ``\\mathcal{B}`` is the Bures metric.
See also: [`ot_plan`](@ref), [`emd2`](@ref)
"""
function ot_cost(::SqEuclidean, μ::MvNormal, ν::MvNormal)
return sqeuclidean.μ, ν.μ) + sqbures.Σ, ν.Σ)
end

"""
ot_cost(::SqEuclidean, μ::Normal, ν::Normal)
Compute the squared 2-Wasserstein distance between univariate normal distributions `μ` and
`ν` as source and target marginals.
See also: [`ot_plan`](@ref), [`emd2`](@ref)
"""
function ot_cost(::SqEuclidean, μ::Normal, ν::Normal)
return.μ - ν.μ)^2 +.σ - ν.σ)^2
end

"""
ot_plan(::SqEuclidean, μ::MvNormal, ν::MvNormal)
Compute the optimal transport plan for the Monge-Kantorovich problem with multivariate
normal distributions `μ` and `ν` as source and target marginals and cost function
``c(x, y) = \\|x - y\\|_2^2``.
In this setting, for ``\\mu = \\mathcal{N}(m_\\mu, \\Sigma_\\mu)`` and
``\\nu = \\mathcal{N}(m_\\nu, \\Sigma_\\nu)``, the optimal transport plan is the Monge
map
```math
T \\colon x \\mapsto m_\\nu
+ \\Sigma_\\mu^{-1/2}
{\\big(\\Sigma_\\mu^{1/2} \\Sigma_\\nu \\Sigma_\\mu^{1/2}\\big)}^{1/2}\\Sigma_\\mu^{-1/2}
(x - m_\\mu).
See also: [`ot_cost`](@ref), [`emd`](@ref)
"""
function ot_plan(::SqEuclidean, μ::MvNormal, ν::MvNormal)
Σμsqrt = μ.Σ^(-1 / 2)
A = Σμsqrt * sqrt(_gaussian_ot_A.Σ, ν.Σ)) * Σμsqrt
= μ.μ
= ν.μ
T(x) =+ A * (x - mμ)
return T
end

"""
ot_plan(::SqEuclidean, μ::Normal, ν::Normal)
Compute the optimal transport plan for the Monge-Kantorovich problem with
normal distributions `μ` and `ν` as source and target marginals and cost function
``c(x, y) = \\|x - y\\|_2^2``.
See also: [`ot_cost`](@ref), [`emd`](@ref)
"""
function ot_plan(::SqEuclidean, μ::Normal, ν::Normal)
= μ.μ
= ν.μ
a = ν.σ / μ.σ
T(x) =+ a * (x - mμ)
return T
end
27 changes: 27 additions & 0 deletions test/bures.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Code from @devmotion
# https://github.com/devmotion/\
# CalibrationErrorsDistributions.jl/blob/main/src/distances/bures.jl
using OptimalTransport

using LinearAlgebra
using Random
using PDMats

@testset "bures.jl" begin
function _sqbures(A, B)
sqrt_A = sqrt(A)
return tr(A) + tr(B) - 2 * tr(sqrt(sqrt_A * B * sqrt_A'))
end

function rand_matrices(n)
A = randn(n, n)
B = A' * A + I
return B, PDMat(B), PDiagMat(diag(B)), ScalMat(n, B[1])
end

for (x, y) in Iterators.product(rand_matrices(10), rand_matrices(10))
xfull = Matrix(x)
yfull = Matrix(y)
@test OptimalTransport.sqbures(x, y) _sqbures(xfull, yfull)
end
end
52 changes: 52 additions & 0 deletions test/exact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using PythonOT: PythonOT
using Tulip
using MathOptInterface
using Distributions
using HCubature

using LinearAlgebra
using Random
Expand Down Expand Up @@ -164,4 +165,55 @@ Random.seed!(100)
@test c2 c
end
end

@testset "Multivariate Gaussians" begin
@testset "translation with constant covariance" begin
m = randn(100)
τ = rand(100)
Σ = Matrix(Hermitian(rand(100, 100) + 100I))
μ = MvNormal(m, Σ)
ν = MvNormal(m .+ τ, Σ)
@test ot_cost(SqEuclidean(), μ, ν) norm(τ)^2

x = rand(100, 10)
T = ot_plan(SqEuclidean(), μ, ν)
@test pdf(ν, mapslices(T, x; dims=1)) pdf(μ, x)
end

@testset "comparison to grid approximation" begin
μ = MvNormal([0, 0], [1 0; 0 2])
ν = MvNormal([10, 10], [2 0; 0 1])
# Constructing circular grid approximation
# Angular grid step
θ = collect(0:0.2:(2π))
θx = cos.(θ)
θy = sin.(θ)
# Radius grid step
δ = collect(0:0.2:1)
μsupp = [0.0 0.0]
νsupp = [10.0 10.0]
for i in δ[2:end]
a = [θx .* i θy .* i * 2]
b = [θx .* i * 2 θy .* i] .+ [10 10]
μsupp = vcat(μsupp, a)
νsupp = vcat(νsupp, b)
end

# Create discretized distribution
μprobs = pdf(μ, μsupp')
μprobs = μprobs ./ sum(μprobs)
νprobs = pdf(ν, νsupp')
νprobs = νprobs ./ sum(νprobs)
C = pairwise(SqEuclidean(), μsupp', νsupp')
@test emd2(μprobs, νprobs, C, Tulip.Optimizer()) ot_cost(SqEuclidean(), μ, ν) rtol =
1e-3

# Use hcubature integration to perform ``\\int c(x,T(x)) d\\mu``
T = ot_plan(SqEuclidean(), μ, ν)
c_hcubature, _ = hcubature([-10, -10], [10, 10]) do x
return sqeuclidean(x, T(x)) * pdf(μ, x)
end
@test ot_cost(SqEuclidean(), μ, ν) c_hcubature rtol = 1e-3
end
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ const GROUP = get(ENV, "GROUP", "All")
@safetestset "Wasserstein distance" begin
include("wasserstein.jl")
end
@safetestset "Bures distance" begin
include("bures.jl")
end
end

# CUDA requires Julia >= 1.6
Expand Down

2 comments on commit fb23ca3

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/38197

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.9 -m "<description of version>" fb23ca3d8810b3992c7c10e97006cfb520cd5c9a
git push origin v0.3.9

Please sign in to comment.