Cost Matrix function #104

Closed
24 commits
30de49c
Initianning sikhorn divergence
davibarreira Jun 2, 2021
1a03325
Merge branch 'master' of https://github.com/JuliaOptimalTransport/Opt…
davibarreira Jun 2, 2021
4a1f380
Sinkhorn divergence implemented
davibarreira Jun 2, 2021
bdc1b5b
Added PyCall to test dependencies
davibarreira Jun 2, 2021
416dcb4
Added tests for sinkhorn divergence
davibarreira Jun 2, 2021
f593377
Added Sinkhorn Divergence to docs
davibarreira Jun 2, 2021
21d38a8
Creating FiniteDiscreteMeasure struct
davibarreira Jun 3, 2021
e17bba5
Modifications:
davibarreira Jun 3, 2021
10e8849
FixedDiscreteMeasure normalizes the weights to sum 1
davibarreira Jun 3, 2021
52b3c7a
FixedDiscreteMeasure checks if probabilities are positive
davibarreira Jun 3, 2021
7d2924d
Created tests for FiniteDiscreteMeasure
davibarreira Jun 3, 2021
7cf44a6
Added tests for sinkhorn divergence and finite discrete measure
davibarreira Jun 3, 2021
4764b00
Fixed the code for creating cost matrices in the sinkhorn_divergence
davibarreira Jun 3, 2021
98784c5
Added costmatrix.jl to tests
davibarreira Jun 3, 2021
1fb0fc1
Fixed docstring for costmatrix
davibarreira Jun 3, 2021
808d6ac
Fixed errors in the tests
davibarreira Jun 3, 2021
3415386
Minor fixes in the tests
davibarreira Jun 3, 2021
d373d52
Created auxiliary cost matrix function
davibarreira Jun 4, 2021
933c106
Formatted code
davibarreira Jun 4, 2021
63af17a
costmatrix implementation from sinkhorndiverngce PR
davibarreira Jun 12, 2021
f952be8
Formatted code
davibarreira Jun 12, 2021
6d812e3
Added costmatrix.jl to docs
davibarreira Jun 12, 2021
97f5d77
Update Project.toml
davibarreira Jun 12, 2021
fd54e9f
Update src/OptimalTransport.jl
davibarreira Jun 12, 2021
5 changes: 5 additions & 0 deletions docs/src/index.md
@@ -36,3 +36,8 @@ sinkhorn_unbalanced2
```@docs
quadreg
```

## Utilities
```@docs
cost_matrix
```
2 changes: 2 additions & 0 deletions src/OptimalTransport.jl
@@ -18,8 +18,10 @@ export sinkhorn, sinkhorn2
export emd, emd2
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
export sinkhorn_unbalanced, sinkhorn_unbalanced2
export sinkhorn_divergence
export quadreg
export ot_cost, ot_plan, wasserstein, squared2wasserstein
export cost_matrix
Member:

I am not sure if this should be exposed to users.

Member Author (@davibarreira, Jun 12, 2021):

Yeah, I remember you said this. Can you expand on why you think so? As a user, I'd like to have access to this function, since if I wanted to create the cost matrix, it would save me some time (I always misuse pairwise at first). I don't see the downside in making this available.

Member Author:

Also, do you think the cost_matrix function is useful (even if only for internal use)? I ask because I remember you were not sure about it.

Member:

The problem is that the best implementation would be to just forward everything to pairwise since the type of the support and the cost function should know best what to do. E.g., if you use ColVecs or RowVecs then there is no need to concatenate vectors, you can just call pairwise with the underlying matrix. This is also how it is implemented in KernelFunctions and it is much much more efficient than extracting and combining all columns or rows.

Member Author (@davibarreira, Jun 12, 2021):

Yeah, but if I construct a FiniteDiscreteMeasure without KernelFunctions, then there is no matrix attribute, which is why I wrote it like that. Sorry, I don't understand your point that the best implementation would be to just forward everything to pairwise. The reason for cost_matrix is that I wouldn't have to deal with the variations. For example, if my cost function is a custom function, then pairwise(sqeuclidean, mu.support, nu.support) behaves differently from pairwise(SqEuclidean(), mu.support, nu.support), which would require, for example, adding dims=1 and transforming the support to a matrix.
What I'd like to do is write a function that deals with all these varying cases.

Member:

> If instead the user passes sinkhorndivergence(SqEuclidean(), mu, nu), then I'd have to write C = pairwise(c, mu.support.X, nu.support.X, dims=1). But since I cannot guarantee that the user used KernelFunctions when creating the finite measures, I have to use a method that works for both cases, hence C = pairwise(c,reduce(hcat, mu.support), reduce(hcat, nu.support), dims=1).

That's exactly my point: you don't know what type of vectors is used (e.g., whether it is a ColVecs or a RowVecs), so reduce(hcat, mu.support) can often be a very inefficient and suboptimal choice. If you instead just used pairwise(c, mu.support, nu.support), you could make use of the optimizations in packages such as KernelFunctions. So my main point is that this is probably handled at the wrong level here: the packages that define e.g. ColVecs and RowVecs should define how pairwise is handled and make sure that it is efficient, since we can't handle all possible types here.

In general though I am a bit surprised about the problems you mention with SqEuclidean. All the desired cases work automatically due to https://github.com/JuliaStats/Distances.jl/blob/b52f0a10017553b311a9c9eed6f96e34a5629c2f/src/generic.jl#L333-L351 (even though it is not optimized for ColVecs but IMO that's an issue of KernelFunctions or the separate package where they should be moved):

```julia
julia> pairwise(SqEuclidean(), rand(5), rand(5))
5x5 Matrix{Float64}:
...

julia> pairwise(SqEuclidean(), [rand(5), rand(5)], [rand(5), rand(5)])
2x2 Matrix{Float64}:
...

julia> pairwise(SqEuclidean(), ColVecs(rand(5, 2)), ColVecs(rand(5, 2)))
2x2 Matrix{Float64}:
...
```

Member (@devmotion, Jun 12, 2021):

(sqeuclidean works as well but uses the fallback in StatsBase - IMO this should be changed in Distances)

Member Author (@davibarreira, Jun 12, 2021):

Interesting. I was getting an error when using mu.support without splatting, but I was also passing the dims argument, so perhaps that was the issue. If so, then I agree with you that the cost_matrix function is not necessary.

Member Author (@davibarreira, Jun 13, 2021):

I'll close the PR. Thanks for the input.

Member:

BTW I made a PR to Distances that would fix the SqEuclidean/sqeuclidean discrepancy: JuliaStats/Distances.jl#224


const MOI = MathOptInterface

63 changes: 63 additions & 0 deletions src/utils.jl
@@ -106,3 +106,66 @@ end

Distributions.support(d::FiniteDiscreteMeasure) = d.support
Distributions.probs(d::FiniteDiscreteMeasure) = d.p

"""
    cost_matrix(
        c,
        μ::Union{FiniteDiscreteMeasure, DiscreteNonParametric},
        ν::Union{FiniteDiscreteMeasure, DiscreteNonParametric}
    )

Compute the cost matrix between the finite discrete measures `μ` and `ν` using the cost function `c`.

Note that metrics such as `SqEuclidean()` from `Distances.jl` perform
better than generic functions. Thus, it is preferred to use
`cost_matrix(SqEuclidean(), μ, ν)` instead of `cost_matrix((x,y)->sum((x-y).^2), μ, ν)`
or even `cost_matrix(sqeuclidean, μ, ν)`.

For custom cost functions, it is necessary to guarantee that the function `c` works
on vectors, i.e., if one wants to compute the squared Euclidean distance,
then one must define `c(x,y) = sum((x - y).^2)`.

# Example
```julia
μ = discretemeasure(rand(10),normalize!(rand(10),1))
ν = discretemeasure(rand(8))
c = TotalVariation()
C = cost_matrix(c, μ, ν)
```
"""
function cost_matrix(
    c,
    μ::Union{FiniteDiscreteMeasure,DiscreteNonParametric},
    ν::Union{FiniteDiscreteMeasure,DiscreteNonParametric},
)
    if typeof(c) <: PreMetric && length(μ.support[1]) == 1
        return pairwise(c, vcat(μ.support...), vcat(ν.support...))
    elseif typeof(c) <: PreMetric && length(μ.support[1]) > 1
        return pairwise(c, vcat(μ.support'...), vcat(ν.support'...); dims=1)
    else
        return pairwise(c, μ.support, ν.support)
    end
end
Comment on lines +136 to +148
Member:

I think we would want to restrict this to pairs of marginals with the same type of the support. Or at least the dimension should match in the case of arrays and scalars. So we would want a more fine-grained function signature.

In general, it would be better to avoid the type checks in the function definition since it makes it more difficult to extend and specialize the method. I think it would be better to just

  • define a separate fallback for arbitrary c (the last branch)
  • define a separate method for c::PreMetric

Also one should avoid the splatting of support, it will lead to massive compile times and inference problems for larger arrays.

The same comments apply to the implementation below.
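
The dispatch-based structure suggested in this comment could look roughly like the sketch below. The name `cost_matrix_sketch` and its signature on plain support vectors are hypothetical (they are not part of the PR), and the fallback uses a plain comprehension rather than any library routine; only `PreMetric`, `SqEuclidean`, and `pairwise` from Distances.jl are taken as given.

```julia
using Distances

# Fallback for an arbitrary two-argument cost function `c`:
# evaluate it on every pair of support points.
function cost_matrix_sketch(c, xs::AbstractVector, ys::AbstractVector)
    return [c(x, y) for x in xs, y in ys]
end

# More specific method: Distances.jl pre-metrics are forwarded to `pairwise`,
# which decides how to handle the support type (no splatting, no `typeof` checks).
function cost_matrix_sketch(c::PreMetric, xs::AbstractVector, ys::AbstractVector)
    return pairwise(c, xs, ys)
end

# Hypothetical example data: 4 and 5 support points in R^3.
xs = [rand(3) for _ in 1:4]
ys = [rand(3) for _ in 1:5]

C_metric = cost_matrix_sketch(SqEuclidean(), xs, ys)
C_generic = cost_matrix_sketch((x, y) -> sum(abs2, x - y), xs, ys)
```

Because the `PreMetric` case is its own method, downstream code can add further specializations (for example, for a particular support type) without editing a chain of `if` branches.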

Member Author (@davibarreira, Jun 12, 2021):

Hmm. So splatting is inefficient. Then how should one efficiently construct the matrix C using Distances.pairwise? StatsBase.pairwise takes the vector of vectors and returns exactly what one wants, but Distances.pairwise would require a matrix version, so how do I make a matrix from the vector of vectors?

Member:

As mentioned above, if you work with ColVecs or RowVecs you actually don't want to construct a matrix at all. But if you deal with an actual vector of vectors, then usually you would use e.g. reduce(hcat, vectors_of_vectors) to avoid splatting.
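
As a minimal sketch of the `reduce(hcat, ...)` approach for a plain vector of vectors (the data and variable names here are made up for illustration): each support point becomes a column of the resulting matrix, so the matrix method of `pairwise` needs `dims=2` to treat columns as observations.

```julia
using Distances

# Hypothetical supports: 4 and 5 points in R^3, stored as vectors of vectors.
xs = [rand(3) for _ in 1:4]
ys = [rand(3) for _ in 1:5]

# Concatenate without splatting; each point becomes one column.
X = reduce(hcat, xs)  # 3×4 matrix
Y = reduce(hcat, ys)  # 3×5 matrix

# Columns are the observations, hence dims=2.
C = pairwise(SqEuclidean(), X, Y; dims=2)

# Reference result computed pair by pair.
C_ref = [sum(abs2, x - y) for x in xs, y in ys]
```

With this column layout, `dims=1` would instead compare the three coordinate rows with each other and return a 3×3 matrix, which is not the cost matrix between the supports.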

Member Author:

I agree. But since we didn't require KernelFunctions as a dependency, I could not guarantee that mu.support.X would work. I'll change to your suggestion with reduce.


"""
    cost_matrix(
        c,
        μ::Union{FiniteDiscreteMeasure, DiscreteNonParametric};
        symmetric = false
    )

Compute the cost matrix from the finite discrete measure `μ` to itself using the cost function `c`.
If the cost function is symmetric, set the keyword argument `symmetric` to `true` to
improve performance.
"""
function cost_matrix(
    c, μ::Union{FiniteDiscreteMeasure,DiscreteNonParametric}; symmetric=false
)
    if typeof(c) <: PreMetric && length(μ.support[1]) == 1
        return pairwise(c, vcat(μ.support...))
    elseif typeof(c) <: PreMetric && length(μ.support[1]) > 1
        return pairwise(c, vcat(μ.support'...); dims=1)
    else
        return pairwise(c, μ.support; symmetric=symmetric)
    end
end
1 change: 1 addition & 0 deletions test/entropic/sinkhorn.jl
@@ -4,6 +4,7 @@ using Distances
using ForwardDiff
using LogExpFunctions
using PythonOT: PythonOT
using Distributions

using LinearAlgebra
using Random
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -1,3 +1,4 @@
using LinearAlgebra: symmetric
using OptimalTransport
using Pkg: Pkg
using SafeTestsets
54 changes: 54 additions & 0 deletions test/utils.jl
@@ -1,5 +1,7 @@
using OptimalTransport

using Distributions: DiscreteNonParametric
using Distances
using LinearAlgebra
using Random
using Test
@@ -144,4 +146,56 @@ Random.seed!(100)
            @test νsupp == support(ν)
        end
    end
    @testset "costmatrix.jl" begin
        @testset "Creating cost matrices from vectors" begin
            n = 100
            m = 80
            μsupp = rand(n)
            νsupp = rand(m)
            μprobs = normalize!(rand(n), 1)
            μ = OptimalTransport.discretemeasure(μsupp, μprobs)
            ν = OptimalTransport.discretemeasure(νsupp)
            c(x, y) = sum((x - y) .^ 2)
            C1 = cost_matrix(SqEuclidean(), μ, ν)
            C2 = cost_matrix(sqeuclidean, μ, ν)
            C3 = cost_matrix(c, μ, ν)
            C = pairwise(SqEuclidean(), vcat(μ.support...), vcat(ν.support...))
            @test C1 ≈ C
            @test C2 ≈ C
            @test C3 ≈ C
        end

        @testset "Creating cost matrices from matrices" begin
            n = 10
            m = 3
            μsupp = [rand(m) for i in 1:n]
            νsupp = [rand(m) for i in 1:n]
            μprobs = normalize!(rand(n), 1)
            μ = OptimalTransport.discretemeasure(μsupp, μprobs)
            ν = OptimalTransport.discretemeasure(νsupp)
            c(x, y) = sum((x - y) .^ 2)
            C1 = cost_matrix(SqEuclidean(), μ, ν)
            C2 = cost_matrix(sqeuclidean, μ, ν)
            C3 = cost_matrix(c, μ, ν)
            C = pairwise(SqEuclidean(), vcat(μ.support'...), vcat(ν.support'...); dims=1)
            @test C1 ≈ C
            @test C2 ≈ C
            @test C3 ≈ C
        end
        @testset "Creating cost matrices from μ to itself" begin
            n = 10
            m = 3
            μsupp = [rand(m) for i in 1:n]
            μprobs = normalize!(rand(n), 1)
            μ = OptimalTransport.discretemeasure(μsupp, μprobs)
            c(x, y) = sqrt(sum((x - y) .^ 2))
            C1 = cost_matrix(Euclidean(), μ; symmetric=true)
            C2 = cost_matrix(euclidean, μ; symmetric=true)
            C3 = cost_matrix(c, μ)
            C = pairwise(Euclidean(), vcat(μ.support'...), vcat(μ.support'...); dims=1)
            @test C1 ≈ C
            @test C2 ≈ C
            @test C3 ≈ C
        end
    end
end