Cost Matrix function #104

Closed

davibarreira (Member):

This is the separate PR for the cost matrix function: a helper that makes it easier to write functions such as sinkhorndivergence(c, mu, nu), where c is a cost function.

coveralls commented Jun 12, 2021:

Pull Request Test Coverage Report for Build 931149182

  • 11 of 12 (91.67%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.07%) to 94.6%

Changes missing coverage:
  File           Covered lines   Changed/added lines   %
  src/utils.jl   11              12                    91.67%
Totals:
  Change from base Build 928044156: -0.07%
  Covered lines: 473
  Relevant lines: 500

💛 - Coveralls

Review thread on Project.toml (outdated, resolved)
Review thread on src/OptimalTransport.jl (outdated, resolved):
export quadreg
export ot_cost, ot_plan, wasserstein, squared2wasserstein
export cost_matrix
Member:

I am not sure if this should be exposed to users.

davibarreira (Member Author), Jun 12, 2021:

Yeah, I remember you said this. Can you expand on why you think so? As a user, I'd like to have access to this function, since it would save me time whenever I need to build the cost matrix (I always misuse pairwise at first). I don't see the downside of making it available.

Member Author:

Also, do you think the cost_matrix function is useful at all, even if only for internal use? I remember you were not sure about it.

Member:

The problem is that the best implementation would simply forward everything to pairwise, since the types of the support and of the cost function know best what to do. E.g., if you use ColVecs or RowVecs, there is no need to concatenate vectors: you can call pairwise on the underlying matrix directly. This is how it is implemented in KernelFunctions, and it is much more efficient than extracting and recombining all columns or rows.
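To make the efficiency argument concrete, here is a minimal sketch in plain Julia (neither KernelFunctions nor Distances is loaded). `MyColVecs` is a hypothetical stand-in for ColVecs, and `sqeuclidean_cost` is a hypothetical function that works on the wrapped matrices directly instead of extracting and recombining columns. It is not how KernelFunctions actually implements this; it only illustrates why a method specialized on the support type can avoid the copies.

```julia
using LinearAlgebra

# Hypothetical stand-in for KernelFunctions.ColVecs (not the real type):
# a vector of points stored as a d×n matrix, one point per column.
struct MyColVecs{T<:Real} <: AbstractVector{Vector{T}}
    X::Matrix{T}
end
Base.size(v::MyColVecs) = (size(v.X, 2),)
Base.getindex(v::MyColVecs, i::Int) = v.X[:, i]

# Squared Euclidean cost matrix computed from the underlying matrices via
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x'y, i.e. with one matrix product.
# No column is extracted and no matrix is re-assembled.
function sqeuclidean_cost(xs::MyColVecs, ys::MyColVecs)
    X, Y = xs.X, ys.X
    C = vec(sum(abs2, X; dims=1)) .+ sum(abs2, Y; dims=1) .- 2 .* (X' * Y)
    return max.(C, zero(eltype(C)))  # clamp tiny negative values caused by rounding
end

xs = MyColVecs(rand(3, 4))  # 4 points in R^3
ys = MyColVecs(rand(3, 5))  # 5 points in R^3
C = sqeuclidean_cost(xs, ys)  # 4×5 cost matrix
```

Since `MyColVecs` is still an `AbstractVector` of points, generic code that iterates over the points keeps working; only the cost computation gets the specialized, matrix-based path.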

davibarreira (Member Author), Jun 12, 2021:

Yeah, but if I construct a FiniteDiscreteMeasure without KernelFunctions, there is no matrix attribute; that is why I wrote it like that. Sorry, I don't follow your point that the best implementation would just forward everything to pairwise. The reason for cost_matrix is that I wouldn't have to deal with the variations. For example, pairwise(sqeuclidean, mu.support, nu.support), where the cost is a plain function, behaves differently from pairwise(SqEuclidean(), mu.support, nu.support), where it is a Distances metric; the latter would require adding dims=1 and converting the support to a matrix.
What I'd like is a single function that deals with all these varying cases.
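For comparison with the branching approach, here is a minimal sketch in plain Julia (Distances.jl is not loaded) of a helper with no branches at all. `generic_cost_matrix` and `sqdist` are hypothetical names, not part of OptimalTransport.jl. Because it only calls c(x, y) on pairs of points, scalar supports, vector supports, and custom cost functions all take the same code path; the trade-off raised in this thread is that such a fallback cannot exploit type-specific optimizations.

```julia
# Hypothetical helper: C[i, j] = c(xs[i], ys[j]) for any two-argument callable c.
# It never inspects the type of c or the dimension of the points.
generic_cost_matrix(c, xs::AbstractVector, ys::AbstractVector) =
    [c(x, y) for x in xs, y in ys]

# Hypothetical custom cost that works for scalars and vectors alike
sqdist(x, y) = sum(abs2, x .- y)

# Scalar support: 3 points on the real line vs 2 points
C1 = generic_cost_matrix(sqdist, [0.0, 1.0, 3.0], [1.0, 2.0])  # 3×2

# Vector support: 2 points in the plane vs 1 point
C2 = generic_cost_matrix(sqdist, [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0]])  # 2×1
```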

Member:

> If instead the user passes sinkhorndivergence(SqEuclidean(), mu, nu), then I'd have to write C = pairwise(c, mu.support.X, nu.support.X, dims=1) instead. But since I cannot guarantee that the user uses KernelFunctions when creating the finite measures, I have to use a method that works for both cases, hence C = pairwise(c, reduce(hcat, mu.support), reduce(hcat, nu.support), dims=1).

That's exactly my point: you don't know which type of vectors is used (e.g., ColVecs or RowVecs), so reduce(hcat, mu.support) can often be an inefficient choice. If you just used pairwise(c, mu.support, nu.support) instead, you could benefit from the optimizations in packages such as KernelFunctions. So my main point is that this is probably handled at the wrong level: the packages that define, e.g., ColVecs and RowVecs should define how pairwise handles them and make sure it is efficient, since we can't handle all possible types here.

In general, though, I am a bit surprised by the problems you mention with SqEuclidean. All the desired cases work automatically thanks to https://github.com/JuliaStats/Distances.jl/blob/b52f0a10017553b311a9c9eed6f96e34a5629c2f/src/generic.jl#L333-L351 (although this is not optimized for ColVecs; IMO that's an issue for KernelFunctions or for the separate package they should be moved to):

julia> pairwise(SqEuclidean(), rand(5), rand(5))
5×5 Matrix{Float64}:
...

julia> pairwise(SqEuclidean(), [rand(5), rand(5)], [rand(5), rand(5)])
2×2 Matrix{Float64}:
...

julia> pairwise(SqEuclidean(), ColVecs(rand(5, 2)), ColVecs(rand(5, 2)))
2×2 Matrix{Float64}:
...

devmotion (Member), Jun 12, 2021:

(sqeuclidean works as well, but it uses the fallback in StatsBase; IMO this should be changed in Distances.)

davibarreira (Member Author), Jun 12, 2021:

Interesting. I was getting an error when passing mu.support without splatting, but I was also passing the dims argument, so perhaps that was the issue. If so, I agree that the cost_matrix function is not necessary.

davibarreira (Member Author), Jun 13, 2021:

I'll close the PR. Thanks for the inputs.

Member:

BTW I made a PR to Distances that would fix the SqEuclidean/sqeuclidean discrepancy: JuliaStats/Distances.jl#224

Comment on lines +136 to +148
function cost_matrix(
    c,
    μ::Union{FiniteDiscreteMeasure,DiscreteNonParametric},
    ν::Union{FiniteDiscreteMeasure,DiscreteNonParametric},
)
    if typeof(c) <: PreMetric && length(μ.support[1]) == 1
        return pairwise(c, vcat(μ.support...), vcat(ν.support...))
    elseif typeof(c) <: PreMetric && length(μ.support[1]) > 1
        return pairwise(c, vcat(μ.support'...), vcat(ν.support'...); dims=1)
    else
        return pairwise(c, μ.support, ν.support)
    end
end

Member:

I think we would want to restrict this to pairs of marginals whose supports have the same type, or at least matching dimensions in the case of arrays and scalars. So we would want a more fine-grained function signature.

In general, it is better to avoid type checks inside the function body, since they make the method harder to extend and specialize. I think it would be better to just

  • define a separate fallback for arbitrary c (the last branch)
  • define a separate method for c::PreMetric

Also, one should avoid splatting the support: it leads to massive compile times and inference problems for larger arrays.

The same comments apply to the implementation below.
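The suggested structure (a fallback for arbitrary c, a separate method for metric types, and a signature that requires matching support types) could look roughly like the sketch below. It rests on stated assumptions: Distances.jl is not loaded, so `SketchPreMetric` and `SketchSqEuclidean` are hypothetical stand-ins for `PreMetric` and `SqEuclidean`, and `sketch_cost_matrix` is a hypothetical name. Dispatch replaces the `typeof(c) <: PreMetric` checks, and the metric method stacks points with reduce(hcat, ...) instead of splatting.

```julia
# Fallback for an arbitrary callable cost c (the last branch of the original
# if/elseif chain). Requiring the same element type T for both supports rules
# out, e.g., pairing scalar points with vector points.
sketch_cost_matrix(c, xs::AbstractVector{T}, ys::AbstractVector{T}) where {T} =
    [c(x, y) for x in xs, y in ys]

# Hypothetical stand-ins for Distances.PreMetric / SqEuclidean (not the real types)
abstract type SketchPreMetric end
struct SketchSqEuclidean <: SketchPreMetric end
(::SketchSqEuclidean)(x, y) = sum(abs2, x .- y)

# Separate method for metric types with vector-valued points, chosen by
# dispatch rather than by a type check inside the function body.
function sketch_cost_matrix(
    c::SketchPreMetric, xs::AbstractVector{T}, ys::AbstractVector{T}
) where {T<:AbstractVector{<:Real}}
    X = reduce(hcat, xs)  # d×n, one point per column, no splatting
    Y = reduce(hcat, ys)  # d×m
    C = Matrix{eltype(X)}(undef, size(X, 2), size(Y, 2))
    for j in axes(Y, 2), i in axes(X, 2)
        C[i, j] = c(view(X, :, i), view(Y, :, j))
    end
    return C
end

xs = [[0.0, 0.0], [1.0, 1.0]]
ys = [[1.0, 0.0]]
C_metric = sketch_cost_matrix(SketchSqEuclidean(), xs, ys)   # metric method
C_generic = sketch_cost_matrix((x, y) -> abs(x - y), [0.0, 2.0], [1.0])  # fallback
```

Because each case is its own method, another package can add a method for its own cost or support type without editing the existing code, and pairing a scalar support with a vector support fails early with a MethodError.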

davibarreira (Member Author), Jun 12, 2021:

Hmm. So splatting is inefficient. How, then, should one efficiently construct the matrix C using Distances.pairwise? StatsBase.pairwise takes the vector of vectors and returns exactly what one wants, but Distances.pairwise requires a matrix, so how do I make a matrix from the vector of vectors?

Member:

As mentioned above, if you work with ColVecs or RowVecs, you don't want to construct a matrix at all. But if you deal with an actual vector of vectors, you would usually use, e.g., reduce(hcat, vectors_of_vectors) to avoid splatting.
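A quick illustration of the reduce(hcat, ...) suggestion in plain Julia: Base provides a specialized method for reducing hcat over a vector of vectors, so the matrix is built without passing every point as a separate argument. Each point becomes a column of a d×n matrix (in Distances.jl's pairwise, columns as observations correspond to dims=2). The variable names are illustrative only.

```julia
# 1_000 points in R^3, stored as a vector of vectors
points = [rand(3) for _ in 1:1000]

# Specialized reduce(hcat, ...) method: builds the 3×1000 matrix directly
X = reduce(hcat, points)

# Same result via splatting: hcat receives 1_000 separate arguments,
# which is the pattern the review warns against for large supports
X_splat = hcat(points...)

# One point per row instead, if a row-based layout is needed
Xt = permutedims(X)
```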

Member Author:

I agree, but since we don't require KernelFunctions as a dependency, I could not guarantee that mu.support.X would work. I'll switch to your suggestion using reduce.

davibarreira and others added 2 commits June 12, 2021 08:48
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
davibarreira (Member Author):

It seems that this helper function was not necessary, so I'm closing this PR.

coveralls commented Oct 1, 2024:

Pull Request Test Coverage Report for Build 931101555

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.


  • 11 of 12 (91.67%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.07%) to 94.6%

Changes missing coverage:
  File           Covered lines   Changed/added lines   %
  src/utils.jl   11              12                    91.67%
Totals:
  Change from base Build 928044156: -0.07%
  Covered lines: 473
  Relevant lines: 500

💛 - Coveralls


3 participants