Cost Matrix function #104
…imalTransport.jl into sinkhorndivergence
- Created the struct `FiniteDiscreteMeasure`.
- Implemented two versions of `sinkhorn_divergence`.
- Disabled the use of regularization on `sinkhorn_divergence`.
- Fixed docstring with suggestions.
Pull Request Test Coverage Report for Build 931149182
💛 - Coveralls
```julia
export quadreg
export ot_cost, ot_plan, wasserstein, squared2wasserstein
export cost_matrix
```
I am not sure if this should be exposed to users.
Yeah, I remember you said this. Can you expand on why you think so? As a user, I'd like to have access to this function: if I wanted to create the cost matrix, it would save me some time (I always misuse `pairwise` at first). I don't see the downside in making this available.
Also, do you think the `cost_matrix` function is useful (even if only for internal use)? I remember you were not sure about it.
The problem is that the best implementation would be to just forward everything to `pairwise`, since the type of the support and the cost function should know best what to do. E.g., if you use `ColVecs` or `RowVecs`, then there is no need to concatenate vectors; you can just call `pairwise` with the underlying matrix. This is also how it is implemented in KernelFunctions, and it is much more efficient than extracting and combining all columns or rows.
Yeah, but if I construct a `FiniteDiscreteMeasure` without KernelFunctions, then there is no matrix attribute; this is why I wrote it like that. Sorry, I don't understand your point that the best implementation would be to just forward everything to `pairwise`. I mean, the reason for `cost_matrix` is that I wouldn't have to deal with the variations. For example, if my cost function is a custom function, then `pairwise(sqeuclidean, mu.support, nu.support)` behaves differently than `pairwise(SqEuclidean(), mu.support, nu.support)`, which would require, for example, adding `dims=1` and transforming the support to a matrix.
What I'd like to do is to write a function that deals with all these varying cases.
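A minimal Base-only sketch of the "deal with all cases" idea: a two-dimensional comprehension evaluates any two-argument cost `c(x, y)` on every pair of support points, so the same code handles scalar and vector-valued supports as well as custom cost functions. The names `generic_cost_matrix`, `abs_cost`, and `sq_cost` are hypothetical and used only for illustration; this is not code from the PR.

```julia
# Hypothetical helper: entry (i, j) is c(xs[i], ys[j]).
# Works for any callable c and any iterable supports, without splatting.
generic_cost_matrix(c, xs, ys) = [c(x, y) for x in xs, y in ys]

# A user-defined cost on scalar support points
abs_cost(x, y) = abs(x - y)

# A user-defined cost on vector-valued support points (squared Euclidean)
sq_cost(x, y) = sum(abs2, x .- y)

C_scalar = generic_cost_matrix(abs_cost, [0.0, 1.0, 3.0], [1.0, 2.0])
C_vector = generic_cost_matrix(
    sq_cost,
    [[0.0, 0.0], [1.0, 0.0]],
    [[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]],
)
```

The trade-off is that this evaluates `c` pair by pair, so it cannot use the matrix-based optimizations that specific support types or metrics may provide through `pairwise`.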
> If instead the user passes `sinkhorndivergence(SqEuclidean(), mu, nu)`, then I'd have to write `C = pairwise(c, mu.support.X, nu.support.X, dims=1)` instead. But since I cannot guarantee that the user uses KernelFunctions when creating the finite measures, I have to use a method that works for both cases, hence `C = pairwise(c, reduce(hcat, mu.support), reduce(hcat, nu.support), dims=1)`.
That's exactly my point: you don't know what type of vectors is used (e.g., whether it is a `ColVecs` or a `RowVecs`), so `reduce(hcat, mu.support)` can often be a very inefficient and suboptimal choice. If instead you just used `pairwise(c, mu.support, nu.support)`, then you could make use of the optimizations in packages such as KernelFunctions. So my main point is that this is probably handled at the wrong level here: the packages that define e.g. `ColVecs` and `RowVecs` should define how `pairwise` is handled and make sure that it is efficient, since we can't handle all possible types here.
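To illustrate where the optimization would live under this argument: a package that owns a matrix-backed support type can add its own `pairwise` method that hands the underlying matrix to `Distances.pairwise` with `dims=2`, so downstream code can simply call `pairwise(c, xs, ys)` without concatenating vectors. `ColumnSupport` is a hypothetical stand-in for a type such as KernelFunctions' `ColVecs`, not its actual definition, and the sketch assumes Distances.jl is installed.

```julia
using Distances

# Hypothetical matrix-backed support: each column of X is one support point.
struct ColumnSupport{M<:AbstractMatrix{<:Real}}
    X::M
end

# The package that owns ColumnSupport decides how pairwise costs are computed:
# forward the underlying matrices to Distances.jl, treating columns as points.
function Distances.pairwise(d::PreMetric, a::ColumnSupport, b::ColumnSupport)
    return pairwise(d, a.X, b.X; dims=2)
end

a = ColumnSupport([0.0 1.0; 0.0 0.0])          # points (0, 0) and (1, 0)
b = ColumnSupport([0.0 2.0 1.0; 1.0 0.0 1.0])  # points (0, 1), (2, 0), (1, 1)
C = pairwise(SqEuclidean(), a, b)              # 2×3 cost matrix
```

Code that only ever calls `pairwise` then benefits automatically, without needing to know about `ColumnSupport`.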
In general, though, I am a bit surprised by the problems you mention with `SqEuclidean`. All the desired cases work automatically due to https://github.com/JuliaStats/Distances.jl/blob/b52f0a10017553b311a9c9eed6f96e34a5629c2f/src/generic.jl#L333-L351 (even though it is not optimized for `ColVecs`, but IMO that's an issue of KernelFunctions or of the separate package where they should be moved):
```julia
julia> pairwise(SqEuclidean(), rand(5), rand(5))
5×5 Matrix{Float64}:
...

julia> pairwise(SqEuclidean(), [rand(5), rand(5)], [rand(5), rand(5)])
2×2 Matrix{Float64}:
...

julia> pairwise(SqEuclidean(), ColVecs(rand(5, 2)), ColVecs(rand(5, 2)))
2×2 Matrix{Float64}:
...
```
(`sqeuclidean` works as well but uses the fallback in StatsBase; IMO this should be changed in Distances.)
Interesting. I was getting an error when using `mu.support` without splatting, but I was also passing the `dims` argument, so perhaps that was the issue. If so, then I agree with you that the `cost_matrix` function is not necessary.
I'll close the PR. Thanks for the inputs.
BTW I made a PR to Distances that would fix the SqEuclidean/sqeuclidean discrepancy: JuliaStats/Distances.jl#224
```julia
function cost_matrix(
    c,
    μ::Union{FiniteDiscreteMeasure,DiscreteNonParametric},
    ν::Union{FiniteDiscreteMeasure,DiscreteNonParametric},
)
    if typeof(c) <: PreMetric && length(μ.support[1]) == 1
        # metric with scalar support points: concatenate into plain vectors
        return pairwise(c, vcat(μ.support...), vcat(ν.support...))
    elseif typeof(c) <: PreMetric && length(μ.support[1]) > 1
        # metric with vector support points: stack them as rows, hence dims=1
        return pairwise(c, vcat(μ.support'...), vcat(ν.support'...); dims=1)
    else
        # any other cost: pass the supports to pairwise unchanged
        return pairwise(c, μ.support, ν.support)
    end
end
```
I think we would want to restrict this to pairs of marginals with the same type of support, or at least the dimension should match in the case of arrays and scalars. So we would want a more fine-grained function signature.
In general, it would be better to avoid the type checks in the function body, since they make it more difficult to extend and specialize the method. I think it would be better to just
- define a separate fallback for arbitrary `c` (the last branch), and
- define a separate method for `c::PreMetric`.

Also, one should avoid splatting the support: it will lead to massive compile times and inference problems for larger arrays.
The same comments apply to the implementation below.
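A sketch of the method split suggested above, assuming Distances.jl is installed: one fallback method for an arbitrary cost `c` and one method for `c::PreMetric`, both restricted to supports with the same element type `T` and neither using splatting or type checks in the body. The name `cost_matrix` matches the PR, but these signatures are illustrative only, not code that was merged (the PR was closed).

```julia
using Distances

# Fallback for any callable cost c(x, y); both supports share the element type T.
cost_matrix(c, xs::AbstractVector{T}, ys::AbstractVector{T}) where {T} =
    [c(x, y) for x in xs, y in ys]

# Specialization for Distances.jl metrics: forward to pairwise and let the
# support type choose an efficient implementation.
cost_matrix(c::PreMetric, xs::AbstractVector{T}, ys::AbstractVector{T}) where {T} =
    pairwise(c, xs, ys)

xs = [[0.0, 0.0], [1.0, 0.0]]
ys = [[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]]
C_metric = cost_matrix(SqEuclidean(), xs, ys)                # dispatches to pairwise
C_custom = cost_matrix((x, y) -> sum(abs2, x .- y), xs, ys)  # dispatches to the fallback
```

Because the choice is made by dispatch, another package could add a more specific `cost_matrix` method for its own cost or support types without editing an `if`/`elseif` chain.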
Hmm. So splatting is inefficient. Now, how should one efficiently construct the matrix `C` using `Distances.pairwise`? I mean, `StatsBase.pairwise` takes the vector of vectors and returns exactly what one wants, but `Distances.pairwise` would require a matrix version, so how do I make a matrix from the vector of vectors?
As mentioned above, if you work with `ColVecs` or `RowVecs` you actually don't want to construct a matrix at all. But if you deal with an actual vector of vectors, then usually you would use e.g. `reduce(hcat, vectors_of_vectors)` to avoid splatting.
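A minimal Base-only comparison for the plain vector-of-vectors case: `reduce(hcat, vs)` builds the same d×n matrix as `hcat(vs...)`, but without passing n separate arguments to `hcat`. The resulting matrix stores one support point per column, so it would be passed to `Distances.pairwise` with `dims=2`. The sizes below are arbitrary choices for the sketch.

```julia
# n support points in R^d, stored as a vector of vectors
d, n = 3, 1_000
vs = [rand(d) for _ in 1:n]

# Splatting calls hcat with n separate arguments, which can cause long compile
# times for large n; Base has a specialized reduce(hcat, ...) method for
# vectors of arrays that avoids this.
X_splat = hcat(vs...)
X_reduce = reduce(hcat, vs)  # d×n matrix, column j is vs[j]
```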
I agree. But since we didn't require KernelFunctions as a dependency, I could not guarantee that `mu.support.X` would work. I'll change it to your suggestion with `reduce`.
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
It seems that this helper function was not necessary, so I'm closing this PR.
Pull Request Test Coverage Report for Build 931101555
Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
💛 - Coveralls
This is the separate PR for the cost matrix function, which would be a helper function to simplify writing functions such as `sinkhorndivergence(c, mu, nu)`, where `c` is a function.