Compute Wasserstein distance between a density and a sum of Diracs #184

Closed
Vilin97 opened this issue May 3, 2024 · 6 comments

Vilin97 commented May 3, 2024

I have two distributions in d-dimensional space, and I want to compute the Wasserstein distance between them. One distribution is a sum of Dirac delta functions (i.e. an empirical distribution), and the other is given by a density (e.g. a Gaussian). Is my best option to compute histograms of both and compute the distance between the histograms? I don't like this approach because the result will depend on the bin width, and choosing the bin width is a hard problem. Is there a better way?

Here is what I have so far:

using Distributions, LinearAlgebra, StatsBase, OptimalTransport
σ = MvNormal(I(2))
μ = rand(σ, 1000)
μ_hist = fit(Histogram, (μ[1,:], μ[2,:])) # make histogram of the empirical distribution
μ_mass = reshape(μ_hist.weights ./ 1000, :)
# bin centers: left edges shifted by half a bin width
support = (μ_hist.edges[1][1:end-1] .+ step(μ_hist.edges[1]) / 2, μ_hist.edges[2][1:end-1] .+ step(μ_hist.edges[2]) / 2)
σ_mass = reshape([pdf(σ, [x,y]) for x in support[1], y in support[2]], :)
σ_mass ./= sum(σ_mass) # normalize to 1
@assert sum(μ_mass) ≈ sum(σ_mass) # make sure the OT problem is balanced
C = reshape([sum(abs2, [x1,y1] .- [x2,y2]) for x1 in support[1], y1 in support[2], x2 in support[1], y2 in support[2]], length(μ_mass), length(σ_mass)) # |x-y|²
transport_plan = sinkhorn(μ_mass, σ_mass, C, 0.01)

Questions:

  1. How to obtain the cost of the plan now?
  2. What if I want to do exact regularized OT, how can I do it?
  3. Can I circumvent making histograms and compute W₂(μ,σ) directly?
  4. Do I have to do the reshaping into vectors? Seems annoying but I was getting errors from the statement size(C) == (size(μ, 1), size(ν, 1)) in checksize. I don't quite understand what C should be when μ and ν are not vector-valued.
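
For questions 1 and 4, a minimal sketch may help, assuming the measures are stored as plain weight vectors and the support points as columns of a matrix. The toy data `X`, `Y` and the product coupling `plan` are hypothetical and only stand in for real samples and for a plan returned by a solver. It shows that `C[i, j]` is the cost between support point `i` of the first measure and support point `j` of the second, and that the cost of any plan is the inner product of the plan with `C`.

```julia
using LinearAlgebra

# Hypothetical toy data: columns are support points in d = 2 dimensions.
X = [0.0 1.0 2.0;
     0.0 0.0 0.0]          # 3 support points of μ
Y = [0.0 2.0;
     1.0 1.0]              # 2 support points of ν

μ_mass = fill(1 / 3, 3)    # weights of μ, a plain vector
ν_mass = fill(1 / 2, 2)    # weights of ν, a plain vector

# C[i, j] = |xᵢ - yⱼ|², so size(C) == (length(μ_mass), length(ν_mass))
C = [sum(abs2, X[:, i] .- Y[:, j]) for i in axes(X, 2), j in axes(Y, 2)]

# Any coupling with the right marginals works for illustration; the
# independent (product) coupling stands in for a plan returned by a solver.
plan = μ_mass * ν_mass'

# Transport cost of a plan: the Frobenius inner product ⟨C, plan⟩
cost = dot(C, plan)
```

With a plan obtained from `sinkhorn`, the same `dot(C, transport_plan)` expression should give its transport cost, provided `C` has one row per entry of the first weight vector and one column per entry of the second.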
@davibarreira
Member

Hey, @Vilin97.

You don't need to use histograms per se. You can sample the distributions and compute either the sinkhorn distance or the exact distance.

It has been some time since I last used the package. I'll dig up some notebooks I have, and perhaps I can put together a quick example for your case. It should be straightforward.

@Vilin97
Author

Vilin97 commented May 6, 2024

Thank you for the answer, @davibarreira . Given X, Y, both d x n matrices (d is the dimension and n is the size of the sample), how can I compute the W2 distance between the empirical distributions given by X and Y? I did not understand how to do it from the documentation of sinkhorn.

@davibarreira
Member

davibarreira commented May 6, 2024

using Random, LinearAlgebra, Distributions, Distances, OptimalTransport, Tulip

Random.seed!(3)
σ1 = MvNormal(I(2))
N = 100
μ = fill(1 / N, N)
μsupport = rand(σ1,N)'

M = 50
σ2 = MvNormal([5,5],I(2))
ν = fill(1 / M, M)
νsupport = rand(σ2,M)';

C = pairwise(sqeuclidean, μsupport', νsupport'; dims=2);

# This is the exact total cost
γ = emd2(μ, ν, C, Tulip.Optimizer());

ε = .5

# This is the sinkhorn cost
s = sinkhorn2(μ, ν, C, ε);

@davibarreira
Member

@Vilin97, does the code above answer your questions? I'm sampling two multivariate normal distributions and then constructing the corresponding empirical (Dirac) distributions with uniform weights. Then I compute the cost matrix C using the squared Euclidean distance. I'm using Distances.jl for the sqeuclidean function, and Tulip.jl for the Tulip.Optimizer().
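
To connect the snippet to the W2 distance asked about earlier, here is the standard discrete formulation, written as a sketch with the symbols x_i, y_j (rows of μsupport and νsupport) and Π(μ, ν) (the set of couplings with marginals μ and ν) introduced here for illustration:

```latex
W_2^2(\mu, \nu)
  = \min_{\gamma \in \Pi(\mu, \nu)} \sum_{i=1}^{N} \sum_{j=1}^{M} \gamma_{ij} \, C_{ij},
\qquad
C_{ij} = \lVert x_i - y_j \rVert^2 .
```

Because C holds squared distances, the value returned by emd2 corresponds to W2², so the distance itself would be its square root. The sinkhorn2 value is an entropically regularized approximation of the same quantity, and it depends on the choice of ε.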

@Vilin97
Author

Vilin97 commented May 6, 2024

Thank you so much for this snippet! I will play around with it when I get to my laptop, but at first glance it looks like exactly what I wanted. Thank you!

@Vilin97
Author

Vilin97 commented May 7, 2024

The code you gave works. Thank you!

@Vilin97 Vilin97 closed this as completed May 7, 2024